# Analyze molecule embeddings

In [None]:
import inspect
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.stats import pearsonr, spearmanr

from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs

from jointformer.configs.task import TaskConfig
from jointformer.utils.datasets.auto import AutoDataset
from jointformer.utils.tokenizers.auto import AutoTokenizer
from jointformer.configs.model import ModelConfig
from jointformer.models.auto import AutoModel
from jointformer.utils.properties.smiles.physchem import PhysChem

## Parameters setup

Only modify the cell below, then run the rest of the notebook.

Alternatively, if you have `papermill` installed, you can run the notebook with:

`papermill 01-embeddings-analysis.ipynb out_notebook.ipynb -f embedings_analysis_config.yml`

from the command line, where `embedings_analysis_config.yml` stores the parameters below.

In [None]:
# Provide parameters for running the script (papermill can override any of these via -f <config>.yml)

# --- Input / output paths --------------------------------------------------
# Directory where figures are written; set to None to skip saving
OUTPUT_DIR = 'embeddings_analysis_data/embeddings_analysis_output/'
# Task configuration (dataset + tokenizer settings)
PATH_TO_TASK_CONFIG = '../../configs/tasks/guacamol/physchem/'
# Vocabulary file; overrides the vocabulary path stored in the task config
PATH_TO_VOCAB = "../../data/vocabularies/deepchem.txt"
# Model architecture configuration
PATH_TO_MODEL_CONFIG = '../../configs/models/jointformer/'
# Pre-trained model checkpoint to load
PATH_TO_PRETRAINED_MODEL_CKPT = "../../../../checkpoints/lm_physchem/ckpt.pt"

# --- Data ------------------------------------------------------------------
# Properties used as labels; matched case-insensitively against the PhysChem descriptor names
PROPERTIES = ['MolLogP', 'TPSA', 'QED', 'MolWT']
# Number of molecules to sample for inference; None uses the whole test split
NUM_SAMPLES = 1000

# --- Dimensionality reduction -----------------------------------------------
# If False, the raw embeddings are plotted directly
PERFORM_DIM_REDUCTION = True
# Reduction method: 'pca' or 'tsne'
DIM_REDUCTION = 'pca'

# --- Visualization -----------------------------------------------------------
# Which reduced dimensions to show in each 2D plot:
#   "first_two"      -> the first two dimensions
#   "top_correlated" -> for each property, the two dimensions most correlated with it
DIMENSIONS_FOR_VISUALIZATION = "top_correlated"
# Correlation used by "top_correlated": "pearson" or "spearman"
CORR_TYPE = "pearson"
# Figure size passed to plt.subplots
FIGSIZE = (10, 7)
# Extra keyword arguments forwarded to the scatter call
SCATTERPLOT_KWARGS = dict(cmap='viridis')

# --- Reproducibility -----------------------------------------------------------
# Seed for molecule sampling and for the reducer's random_state
SEED = 42

## Utils

In [None]:
def compute_embeddings(inputs, model, embedding_func, tokenizer, batch_size=32, **tokenizer_call_kwargs):
    """Compute embeddings for ``inputs`` in consecutive batches.

    Args:
        inputs (Sequence): Raw inputs (in this notebook, SMILES strings), sliced into
            chunks of ``batch_size``.
        model: Model object passed through to ``embedding_func``.
        embedding_func (callable): Called as ``embedding_func(model, tokenized_batch)``;
            must return a ``torch.Tensor`` whose first dimension indexes the batch.
        tokenizer (callable): Called as ``tokenizer(batch, **tokenizer_call_kwargs)``.
        batch_size (int, optional): Number of inputs per forward pass. Defaults to 32.
        **tokenizer_call_kwargs: Extra keyword arguments forwarded to ``tokenizer``.

    Returns:
        torch.Tensor: The per-batch embeddings concatenated along dim 0, detached from autograd.

    Raises:
        RuntimeError: From ``torch.cat`` if ``inputs`` is empty (no batches to concatenate).
    """
    embeddings = []
    # Pure inference: disable autograd so no computation graph is built for each batch.
    # ``.detach()`` alone only drops the graph after it has already been built.
    with torch.no_grad():
        for start in range(0, len(inputs), batch_size):
            inputs_batch = tokenizer(inputs[start:start + batch_size], **tokenizer_call_kwargs)
            # NOTE(review): inputs_batch is not moved to the model's device here; confirm that the
            # tokenizer/model handle device placement when the model runs on CUDA.
            embeddings.append(embedding_func(model, inputs_batch).detach())
    return torch.cat(embeddings)

def two_D_reduction(X, reducer="pca", **reducer_kwargs):
    """
    Fit a dimensionality-reduction estimator on ``X`` and return the transformed data.

    Despite the name, this function does not force two output dimensions: the output
    dimensionality is whatever the estimator is configured with via ``reducer_kwargs``
    (e.g. ``n_components``).

    Args:
        X (array-like): Data passed to the estimator's ``fit_transform``.
        reducer (str, optional): 'pca' for sklearn ``PCA`` or 'tsne' for sklearn ``TSNE``. Defaults to 'pca'.
        **reducer_kwargs: Forwarded unchanged to the estimator constructor
            (e.g. ``n_components``, ``random_state``).

    Returns:
        array-like: The result of ``fit_transform(X)``.

    Raises:
        ValueError: If ``reducer`` is not one of the supported names.
    """
    # Guard clause: reject unsupported methods before constructing anything
    if reducer not in ("pca", "tsne"):
        raise ValueError(f"Unknown reducer: {reducer}")

    estimator_cls = PCA if reducer == "pca" else TSNE
    estimator = estimator_cls(**reducer_kwargs)
    return estimator.fit_transform(X)

def plot_2D_data_matplotlib(X_2d, ax=None, axis_titles=None, title=None, **scatter_kwargs):
    """
    Scatter-plot the first two columns of ``X_2d`` on a given axis, or on a new one.

    Args:
        X_2d (array-like): 2-D data; column 0 is used for x and column 1 for y.
        ax (matplotlib.axes.Axes, optional): An existing axis to plot on. If None, a new figure and axis are created.
        axis_titles (sequence of str, optional): Labels for the x and y axes. If None, both labels are set to "".
        title (str, optional): Axis title. Left unset if None.
        **scatter_kwargs: Additional keyword arguments passed to ``ax.scatter``. If a non-None ``c`` is
            given, a colorbar is attached to the axis.

    Returns:
        matplotlib.axes.Axes: The axes containing the plot.
    """
    # If no axis is provided, create a new figure and axis (as documented above)
    if ax is None:
        _, ax = plt.subplots()

    # Accept any array-like (lists, DataFrames, ...) while supporting [:, k] indexing
    X_2d = np.asarray(X_2d)
    p = ax.scatter(X_2d[:, 0], X_2d[:, 1], **scatter_kwargs)

    # A colorbar only makes sense when points are colored by values; attach it to this
    # axis's own figure rather than pyplot's "current" figure.
    if scatter_kwargs.get("c") is not None:
        ax.figure.colorbar(p, ax=ax)

    ax.set_xlabel(axis_titles[0] if axis_titles is not None else "")
    ax.set_ylabel(axis_titles[1] if axis_titles is not None else "")
    if title is not None:
        ax.set_title(title)

    return ax
    
def get_most_correlated_dimensions(X, y, method="pearson", absolute_vals=True, top_k=2):
    """
    Get the dimensions (columns) of X most correlated with a reference vector y.

    Args:
        X (array-like): 2-D input data; each column is correlated with ``y``.
        y (array-like): Reference vector with one entry per row of ``X``.
        method (str, optional): The correlation method to use. Options are 'pearson' and 'spearman'. Defaults to 'pearson'.
        absolute_vals (bool, optional): Whether to rank by the absolute value of the correlations
            (so strong negative correlations count as strong). Defaults to True.
        top_k (int, optional): How many dimensions to return. Defaults to 2.

    Returns:
        numpy.ndarray: Column indices of the ``top_k`` most correlated dimensions, strongest first.
        Columns whose correlation is undefined (NaN, e.g. a constant column) are ranked last.

    Raises:
        ValueError: If ``method`` is not 'pearson' or 'spearman'.
    """
    if method == "pearson":
        corr_func = pearsonr
    elif method == "spearman":
        corr_func = spearmanr
    else:
        raise ValueError(f"Unknown correlation method: {method}")

    X = np.asarray(X)
    # Correlation of each column of X with y
    correlations = np.array([corr_func(column, y)[0] for column in X.T], dtype=float)

    scores = np.abs(correlations) if absolute_vals else correlations
    # argsort places NaN at the end, so reversing it would rank undefined correlations
    # as the *strongest*; push them to the bottom instead.
    scores = np.where(np.isnan(scores), -np.inf, scores)

    # Descending order; a stable sort makes ties resolve to the lower column index.
    return np.argsort(-scores, kind="stable")[:top_k]

## Load the data for inference

In [None]:
# Load the task configuration and override its vocabulary path with PATH_TO_VOCAB
task_config = TaskConfig.from_config_file(PATH_TO_TASK_CONFIG)
task_config.path_to_vocabulary = PATH_TO_VOCAB

# Build the test split (used for inference below) and the tokenizer from the same task config.
# The vocabulary override above is applied before either object is constructed.
dataset = AutoDataset.from_config(task_config, split='test')
tokenizer = AutoTokenizer.from_config(task_config)

In [None]:
# Get a list of property names
phys_chem = PhysChem()
property_names = phys_chem.descriptor_list

# Lower-cased descriptor names for case-insensitive matching (built once, not per lookup)
property_names_lower = [name.lower() for name in property_names]

# Validate every requested property before indexing, so a typo produces a clear error
props, idxs = [], []
for prop in PROPERTIES:
    if prop.lower() not in property_names_lower:
        raise ValueError(f"Property {prop} not found in the list of available properties.")
    idx = property_names_lower.index(prop.lower())
    props.append(prop)
    idxs.append(idx)
    print(f"User provided property name {prop} mapped to property name {property_names[idx]} with index {idx}.")

# Mapping: user-provided property name -> descriptor index
property_idx_dict = dict(zip(props, idxs))

# Extract SMILES
molecules_list = dataset.data

# Extract the label columns corresponding to the properties of choice
# NOTE(review): assumes dataset.target is indexed as (molecule, <axis where 0 is taken>, descriptor) — confirm
labels = dataset.target[:, 0, idxs]
labels_df = pd.DataFrame(labels, columns=props)

# Sanity check: each DataFrame column matches the corresponding descriptor column
for prop, idx in zip(props, idxs):
    assert np.allclose(labels_df[prop].values, dataset.target[:, 0, idx]), f"Property {prop} values do not match."

# Optionally, take a reproducible random sample of the data
if NUM_SAMPLES is not None:
    # Sampling without replacement cannot draw more items than exist
    num_to_sample = min(NUM_SAMPLES, len(molecules_list))
    if num_to_sample < NUM_SAMPLES:
        print(f"NUM_SAMPLES={NUM_SAMPLES} exceeds the dataset size; using all {num_to_sample} molecules.")
    # Global seed kept so the same subset is drawn as in earlier runs of this notebook
    np.random.seed(SEED)
    sample_indices = np.random.choice(len(molecules_list), num_to_sample, replace=False)
    molecules_list = [molecules_list[i] for i in sample_indices]
    labels_df = labels_df.iloc[sample_indices]

assert len(molecules_list) == len(labels_df), "Number of molecules and labels do not match."

print()
print(f"Number of molecules to infer: {len(molecules_list)}")
print(f"Number of properties: {len(labels_df.columns)}")

## Load the model

In [None]:
# Pick the GPU when available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Instantiate the architecture described by the model config, then restore the pre-trained weights
model_config = ModelConfig.from_config_file(PATH_TO_MODEL_CONFIG)
model = AutoModel.from_config(model_config)
model.load_pretrained(PATH_TO_PRETRAINED_MODEL_CKPT)

# Switch to evaluation mode and move the model to the selected device
model.eval()
model.to(device)

## 2D visualizations of embeddings for a set of molecules

In [None]:
# Define an embeddings function: mean over axis 1 of the model's "embeddings" output
def embeddings_func(model, inputs):
    """Return the model's ``"embeddings"`` output averaged over axis 1."""
    # NOTE(review): if axis 1 is the token axis and the tokenizer pads batches, this mean
    # includes padding positions — confirm whether a mask-aware mean is needed.
    return model(**inputs)["embeddings"].mean(1)

# Compute embeddings
embeddings = compute_embeddings(molecules_list, model, embeddings_func, tokenizer, batch_size=2)
embeddings = embeddings.cpu().numpy()

# Reduce dimensionality
if PERFORM_DIM_REDUCTION:
    if DIM_REDUCTION == "tsne":
        # sklearn's default (Barnes-Hut) t-SNE only supports fewer than 4 output components
        n_components = 2
    else:
        # PCA requires n_components <= min(n_samples, n_features); keep as many components as allowed
        n_components = min(embeddings.shape[0], embeddings.shape[1])
    reduced_embeddings = two_D_reduction(embeddings, reducer=DIM_REDUCTION, n_components=n_components,
                                         random_state=SEED)
else:
    reduced_embeddings = embeddings

In [None]:
# Visualize embeddings - one subplot per target property
if OUTPUT_DIR is not None:
    # exist_ok=True makes this a no-op when the directory already exists
    os.makedirs(OUTPUT_DIR, exist_ok=True)

NUM_COLS = 2
NUM_ROWS = int(np.ceil(len(PROPERTIES) / NUM_COLS))

# squeeze=False keeps `axes` 2-D even with a single row (<= NUM_COLS properties),
# so the axes[row, col] indexing below always works.
fig, axes = plt.subplots(NUM_ROWS, NUM_COLS, figsize=FIGSIZE, squeeze=False)

# Prefix for axis labels; without dimensionality reduction the axes are raw embedding dimensions
if not PERFORM_DIM_REDUCTION:
    axis_alias = 'Embedding dim'
elif DIM_REDUCTION == "pca":
    axis_alias = 'PCA'
elif DIM_REDUCTION == "tsne":
    axis_alias = 'tSNE'
else:
    axis_alias = str(DIM_REDUCTION)

for i, prop in enumerate(PROPERTIES):
    prop_values = labels_df[prop].values

    # Establish which dimensions to use for visualization
    if DIMENSIONS_FOR_VISUALIZATION == "first_two":
        dims = [0, 1]
    elif DIMENSIONS_FOR_VISUALIZATION == "top_correlated":
        dims = get_most_correlated_dimensions(reduced_embeddings, prop_values, method=CORR_TYPE)
    else:
        raise ValueError(f"Unknown value for DIMENSIONS_FOR_VISUALIZATION: {DIMENSIONS_FOR_VISUALIZATION}")
    current_2d_data = reduced_embeddings[:, dims]
    # Dimension indices are shown 1-based in the axis labels
    axis_titles = [f"{axis_alias} {dims[0] + 1}", f"{axis_alias} {dims[1] + 1}"]

    # Plot the data
    ax = axes[i // NUM_COLS, i % NUM_COLS]
    plot_2D_data_matplotlib(current_2d_data, ax=ax, axis_titles=axis_titles, c=prop_values,
                            title=f"Property: {prop}", **SCATTERPLOT_KWARGS)

    # Add colorbar title (plot_2D_data_matplotlib attaches a colorbar because `c` is passed)
    cbar = ax.collections[0].colorbar
    if cbar is not None:
        cbar.set_label(prop)

# Hide unused grid cells when the number of properties is not a multiple of NUM_COLS
for unused_ax in axes.flat[len(PROPERTIES):]:
    unused_ax.set_visible(False)

fig.tight_layout()

# Save the plot
if OUTPUT_DIR is not None:
    fig.savefig(os.path.join(OUTPUT_DIR, f"embeddings_{DIM_REDUCTION}_{DIMENSIONS_FOR_VISUALIZATION}.pdf"))

plt.show()