# Compare molecule embeddings between models

In [None]:
import inspect
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.stats import pearsonr, spearmanr

from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs

from jointformer.configs.dataset import DatasetConfig
from jointformer.configs.tokenizer import TokenizerConfig
from jointformer.utils.datasets.auto import AutoDataset
from jointformer.utils.tokenizers.auto import AutoTokenizer
from jointformer.configs.model import ModelConfig
from jointformer.models.auto import AutoModel
from jointformer.utils.properties.smiles.physchem import PhysChem

## Parameters setup

In [2]:
# Notebook-wide settings: adjust these before running the cells below.

# --- Paths -----------------------------------------------------------------
# Directory to save the outputs
OUTPUT_DIR = "embeddings_analysis_data/embeddings_analysis_output/"
# Data directory handed to the dataset loader
DATA_DIR = '../../../../data'
# Path to the dataset config
PATH_TO_DATASET_CONFIG = "../../configs/datasets/guacamol/physchem/"

# --- Data selection --------------------------------------------------------
# Properties to consider as labels
PROPERTIES = ["MolLogP", "TPSA", "QED", "MolWT"]
# Number of molecules to sample for inference; None keeps every molecule
NUM_SAMPLES = 20000
# Seed for the random sampling
SEED = 42

# --- Dimensionality reduction ----------------------------------------------
# Whether to apply dimensionality reduction to the embeddings
PERFORM_DIM_REDUCTION = True
# Type of dimensionality reduction ('pca' or 'tsne')
DIM_REDUCTION = "pca"
# Which dimensions of the reduced embeddings go on the 2D plot:
# "first_two" uses the first two dimensions, "top_correlated" searches for the
# dimensions most correlated with each property
DIMENSIONS_FOR_VISUALIZATION = 'first_two'
# Correlation measure used by "top_correlated"
CORR_TYPE = 'pearson'

# --- Plot appearance -------------------------------------------------------
FIGSIZE = (18, 8)
SCATTERPLOT_KWARGS = dict(cmap="viridis", alpha=0.6)

In [3]:
# Jointformer: locations of the tokenizer config, vocabulary, model config and checkpoint

# Tokenizer config
PATH_TO_JOINTFORMER_TOKENIZER_CONFIG = "../../configs/tokenizers/smiles"

# Vocabulary file
PATH_TO_JOINTFORMER_VOCAB = '../../data/vocabularies/deepchem.txt'

# Model config
PATH_TO_JOINTFORMER_MODEL_CONFIG = "../../configs/models/jointformer/"

# Pre-trained model checkpoint (adjacent literals are joined into one string)
PATH_TO_JOINTFORMER_PRETRAINED_MODEL_CKPT = (
    "../../../../checkpoints/jointformer/"
    "no_separate_task_token/cls_embedding/05082024/ckpt.pt"
)

In [4]:
# ChemBERTa parameters

# Path to the model config (the stray double slash after "configs" was removed
# so the path matches the form of the other config paths)
PATH_TO_CHEMBERTA_MODEL_CONFIG = '../../configs/models/chemberta'

# Path to the tokenizer config
PATH_TO_CHEMBERTA_TOKENIZER_CONFIG = "../../configs/tokenizers/chemberta"

# Identifier of the pre-trained checkpoint, assigned to the model config's
# `pretrained_filepath` before the model is built
CHEMBERTA_CHECKPOINT = "DeepChem/ChemBERTa-77M-MTR"

## Utils

In [5]:
def compute_embeddings(inputs, model, embedding_func, tokenizer, batch_size=32, **tokenizer_call_kwargs):
    """Compute embeddings for ``inputs`` in mini-batches.

    Args:
        inputs: Sliceable sequence of raw inputs; each slice is passed to ``tokenizer``.
        model: Model object forwarded as the first argument of ``embedding_func``.
        embedding_func (callable): Called as ``embedding_func(model, tokenized_batch)``;
            must return a ``torch.Tensor`` for the batch.
        tokenizer (callable): Called as ``tokenizer(batch, **tokenizer_call_kwargs)``.
        batch_size (int, optional): Number of inputs per batch. Defaults to 32.
        **tokenizer_call_kwargs: Extra keyword arguments forwarded to ``tokenizer``.

    Returns:
        torch.Tensor: The per-batch embeddings concatenated along dimension 0,
        detached from the autograd graph. An empty tensor if ``inputs`` is empty.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    embeddings = []
    # Inference only: disabling autograd avoids building a graph for every batch
    # that would be thrown away by the detach() below.
    with torch.no_grad():
        for start in range(0, len(inputs), batch_size):
            inputs_batch = tokenizer(inputs[start:start + batch_size], **tokenizer_call_kwargs)
            embeddings.append(embedding_func(model, inputs_batch).detach())

    # torch.cat refuses an empty list, so handle "nothing to embed" explicitly
    if not embeddings:
        return torch.empty(0)
    return torch.cat(embeddings)

def two_D_reduction(X, reducer="pca", **reducer_kwargs):
    """
    Fit a dimensionality reducer on the input data and return the transformed data.

    Args:
        X (array-like): Input data.
        reducer (str, optional): Name of the reduction method, either 'pca' or 'tsne'. Defaults to 'pca'.
        **reducer_kwargs: Keyword arguments forwarded to the constructor of the chosen reducer.

    Returns:
        array-like: The output of the reducer's ``fit_transform`` on ``X``.

    Raises:
        ValueError: If ``reducer`` is neither 'pca' nor 'tsne'.
    """
    # Reject unsupported names first so the happy path below stays flat
    if reducer not in ("pca", "tsne"):
        raise ValueError(f"Unknown reducer: {reducer}")

    reducer_cls = PCA if reducer == "pca" else TSNE
    return reducer_cls(**reducer_kwargs).fit_transform(X)

def plot_2D_data_matplotlib(X_2d, ax=None, axis_titles=None, title=None, **scatter_kwargs):
    """
    Scatter-plot two-dimensional data with matplotlib on the provided or a new axis.

    Args:
        X_2d (array-like): Data to plot; column 0 is used for x and column 1 for y.
        ax (matplotlib.axes.Axes, optional): An existing axis to plot on. If None, a new figure and axis are created.
        axis_titles (sequence of str, optional): Labels for the x axis (index 0) and y axis (index 1).
            If None, both labels are set to the empty string.
        title (str, optional): Title of the axis. If None, no title is set.
        **scatter_kwargs: Additional keyword arguments to pass to ``ax.scatter``. When ``c`` is
            among them, a colorbar is attached to the axis.

    Returns:
        matplotlib.axes.Axes: The matplotlib axes containing the plot.
    """
    # If no axis is provided, create a new figure and axis
    if ax is None:
        _, ax = plt.subplots()

    p = ax.scatter(X_2d[:, 0], X_2d[:, 1], **scatter_kwargs)
    # A colorbar is only meaningful when the points are coloured by a value
    if "c" in scatter_kwargs:
        plt.colorbar(p, ax=ax)

    ax.set_xlabel(axis_titles[0] if axis_titles is not None else "")
    ax.set_ylabel(axis_titles[1] if axis_titles is not None else "")
    if title is not None:
        ax.set_title(title)

    return ax
    
def get_most_correlated_dimensions(X, y, method="pearson", absolute_vals=True):
    """
    Get the two most correlated dimensions of X with a reference vector y w.r.t. Pearson or Spearman correlation.

    Dimensions whose correlation is undefined (NaN, e.g. a constant column) are
    ranked last instead of first.

    Args:
        X (array-like): Input data of shape (n_samples, n_dimensions).
        y (array-like): Reference vector with one value per sample.
        method (str, optional): The correlation method to use. Options are 'pearson' and 'spearman'. Defaults to 'pearson'.
        absolute_vals (bool, optional): Whether to rank by the absolute values of the correlations. Defaults to True.

    Returns:
        numpy.ndarray: Indices of the (up to) two most correlated dimensions, best first.

    Raises:
        ValueError: If ``method`` is neither 'pearson' nor 'spearman'.
    """
    if method == "pearson":
        corr_func = pearsonr
    elif method == "spearman":
        corr_func = spearmanr
    else:
        raise ValueError(f"Unknown correlation method: {method}")

    X = np.asarray(X)
    y = np.asarray(y)

    # Compute the correlation between each dimension of X and y
    correlations = np.array([corr_func(X[:, i], y)[0] for i in range(X.shape[1])], dtype=float)
    scores = np.abs(correlations) if absolute_vals else correlations

    # argsort places NaN at the end, so after reversing, undefined correlations
    # would be reported as the strongest ones. Push them to the bottom instead.
    scores = np.where(np.isnan(scores), -np.inf, scores)

    # Get the indices of the two most correlated dimensions
    return np.argsort(scores)[::-1][:2]

## Load the data for inference

In [None]:
# Load the 'test' split of the configured dataset; embeddings are computed for its molecules
dataset_config = DatasetConfig.from_config_file(PATH_TO_DATASET_CONFIG)
dataset = AutoDataset.from_config(
    dataset_config,
    split='test',
    data_dir=DATA_DIR,
)

# Report how many entries were loaded
print(f"Dataset size: {len(dataset)}")

In [None]:
# Get a list of property names
phys_chem = PhysChem()
property_names = phys_chem.descriptor_list

# Case-insensitive lookup table built once instead of re-lowercasing the whole
# list for every property. setdefault keeps the first occurrence of a name,
# which is what list.index() would return.
lowercase_name_to_idx = {}
for position, name in enumerate(property_names):
    lowercase_name_to_idx.setdefault(name.lower(), position)

# Get indexes of the properties to consider.
# Validation happens before any lookup so that an unknown property produces
# the descriptive error below rather than a bare "x is not in list".
props, idxs = [], []
for prop in PROPERTIES:
    if prop.lower() not in lowercase_name_to_idx:
        raise ValueError(f"Property {prop} not found in the list of available properties.")
    idx = lowercase_name_to_idx[prop.lower()]
    props.append(prop)
    idxs.append(idx)
    print(f"User provided property name {prop} mapped to property name {property_names[idx]} with index {idx}.")

# Mapping from the user-provided property name to its column index in dataset.target
property_idx_dict = dict(zip(props, idxs))

# Extract SMILES
molecules_list = dataset.data

# Extract proper labels corresponding to properties of choice
labels = dataset.target[:, idxs]
labels_df = pd.DataFrame(labels, columns=props)

# Make sure you have correct labels data
for prop, idx in zip(props, idxs):
    assert property_names[idx].lower() == prop.lower(), f"Property {prop} not found in the list of available properties."
    assert np.allclose(labels_df[prop].values, dataset.target[:, idx]), f"Property {prop} values do not match."

# Optionally, take sample of the data
if NUM_SAMPLES is not None:
    num_available = len(molecules_list)
    # Sampling without replacement cannot draw more items than exist
    if NUM_SAMPLES > num_available:
        print(f"Requested {NUM_SAMPLES} samples but only {num_available} molecules are available; using all of them.")
    sample_size = min(NUM_SAMPLES, num_available)

    # Sample indices
    np.random.seed(SEED)
    sample_indices = np.random.choice(num_available, sample_size, replace=False)
    molecules_list = [molecules_list[i] for i in sample_indices]
    # NOTE: iloc keeps the original row labels in the index, so rows of labels_df
    # line up with molecules_list by position, not by index label.
    labels_df = labels_df.iloc[sample_indices]

assert len(molecules_list) == len(labels_df), "Number of molecules and labels do not match."

print()
print(f"Number of molecules to infer: {len(molecules_list)}")
print(f"Number of properties: {len(labels_df.columns)}")

## Get embeddings from models

In [8]:
# Initialize dictionary to store embeddings.
# Keys are model names; each value is the embeddings computed by that model for `molecules_list`.
embeddings_dict = {}

### Jointformer

In [None]:
# Get tokenizer (the vocabulary path from the parameters cell overrides the one in the config)
tokenizer_config = TokenizerConfig.from_config_file(PATH_TO_JOINTFORMER_TOKENIZER_CONFIG)
tokenizer_config.path_to_vocabulary = PATH_TO_JOINTFORMER_VOCAB
tokenizer = AutoTokenizer.from_config(tokenizer_config)


# Build the model from its config and pick the GPU when one is available
model_config = ModelConfig.from_config_file(PATH_TO_JOINTFORMER_MODEL_CONFIG)
model = AutoModel.from_config(model_config)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load weights
model.load_pretrained(PATH_TO_JOINTFORMER_PRETRAINED_MODEL_CKPT)

# Switch to evaluation mode and move the model to the selected device
model.eval()
model.to(device)

# Compute embeddings
# NOTE(review): the model was moved to `device` above, but the encoder is created with
# device="cpu". Presumably the encoder decides where inference runs - confirm whether
# `device=device` was intended here.
smiles_encoder = model.to_smiles_encoder(tokenizer, batch_size=4, device="cpu")

embeddings = smiles_encoder.encode(molecules_list)

# Store embeddings
embeddings_dict["Jointformer"] = embeddings

### Chemberta

In [None]:
# Get tokenizer
tokenizer_config = TokenizerConfig.from_config_file(PATH_TO_CHEMBERTA_TOKENIZER_CONFIG)
tokenizer = AutoTokenizer.from_config(tokenizer_config)

# Get model (the checkpoint identifier is injected into the config before the model is built)
model_config = ModelConfig.from_config_file(PATH_TO_CHEMBERTA_MODEL_CONFIG)
model_config.pretrained_filepath = CHEMBERTA_CHECKPOINT
model = AutoModel.from_config(model_config)

smiles_encoder = model.to_smiles_encoder(tokenizer, batch_size=4, device='cpu')

# Compute embeddings
embeddings = smiles_encoder.encode(molecules_list)

# Store embeddings.
# The raw (un-reduced) embeddings are stored, as they were before. The dimensionality
# reduction that used to run here was removed because its result was never used:
# it was not stored, and `reduced_embeddings` is reassigned before it is next read.
embeddings_dict["ChemBERTa"] = embeddings

In [None]:
# Report the shape of the embeddings stored for each model
for model_name, model_embeddings in embeddings_dict.items():
    print(f"Embeddings shape for {model_name}: {model_embeddings.shape}")

## Visualize embeddings

In [None]:
# Visualize embeddings - multiple targets
# Grid layout: one row per model in `embeddings_dict`, one column per property;
# points are coloured by the property value.
if OUTPUT_DIR is not None:
    os.makedirs(OUTPUT_DIR, exist_ok=True)

# The axis-label prefix has to describe what is actually plotted, and an
# unsupported setting should fail here with a clear message rather than later
# with a NameError on `axis_alias`.
if not PERFORM_DIM_REDUCTION:
    axis_alias = 'Dim'
elif DIM_REDUCTION == "pca":
    axis_alias = 'PCA'
elif DIM_REDUCTION == "tsne":
    axis_alias = 'tSNE'
else:
    raise ValueError(f"Unknown value for DIM_REDUCTION: {DIM_REDUCTION}")

if DIMENSIONS_FOR_VISUALIZATION not in ("first_two", "top_correlated"):
    raise ValueError(f"Unknown value for DIMENSIONS_FOR_VISUALIZATION: {DIMENSIONS_FOR_VISUALIZATION}")

# Reduce each model's embeddings once, outside the property loop. Previously the
# raw embedding columns were plotted under "PCA n" labels because the stored
# embeddings were never reduced.
plot_embeddings = {}
for model_name, model_embeddings in embeddings_dict.items():
    if PERFORM_DIM_REDUCTION:
        if DIM_REDUCTION == "pca":
            # Keep every component so "top_correlated" can search all of them
            n_components = min(model_embeddings.shape)
        else:
            # t-SNE is only used here to produce a 2D layout
            n_components = 2
        plot_embeddings[model_name] = two_D_reduction(model_embeddings, reducer=DIM_REDUCTION,
                                                      n_components=n_components, random_state=SEED)
    else:
        plot_embeddings[model_name] = model_embeddings

NUM_COLS = len(PROPERTIES)
NUM_ROWS = len(plot_embeddings)

# squeeze=False keeps `axes` two-dimensional even for a single row or column
fig, axes = plt.subplots(NUM_ROWS, NUM_COLS, figsize=FIGSIZE, squeeze=False)

# Iterate over properties (columns) and models (rows)
for col, prop in enumerate(PROPERTIES):
    prop_values = labels_df[prop].values

    for row, (model_name, reduced_embeddings) in enumerate(plot_embeddings.items()):
        # Establish which dimensions to use for visualization
        if DIMENSIONS_FOR_VISUALIZATION == "first_two":
            dims = [0, 1]
        else:
            dims = get_most_correlated_dimensions(reduced_embeddings, prop_values, method=CORR_TYPE)
        current_2d_data = reduced_embeddings[:, dims]

        ax = axes[row, col]
        plot_2D_data_matplotlib(current_2d_data, ax=ax, c=prop_values,
                                title=model_name,
                                **SCATTERPLOT_KWARGS)
        # Dimension numbers on the axis labels are 1-based
        ax.set_xlabel(f"{axis_alias} {dims[0] + 1}")
        ax.set_ylabel(f"{axis_alias} {dims[1] + 1}")

        # Add colorbar title
        cbar = ax.collections[0].colorbar
        cbar.set_label(prop)

plt.tight_layout()

plt.show()

## Distance plots

In [13]:
from scipy.spatial.distance import cosine, euclidean

In [14]:
# Despite the name, this is the number of molecule *indices* sampled for the
# distance analysis; consecutive sampled indices are grouped into pairs, so
# about half as many pairs are formed.
NUM_PAIRS = 20000

In [15]:
# Get random pairs of molecules
np.random.seed(SEED)
# Sampling without replacement cannot draw more indices than there are molecules
num_indices = min(NUM_PAIRS, len(molecules_list))
pairs_indices = np.random.choice(len(molecules_list), num_indices, replace=False)
# Get tuples of indices: consecutive sampled indices form one pair. Stopping at
# len - 1 drops a trailing unpaired index instead of raising IndexError when the
# number of sampled indices is odd.
pairs = [(pairs_indices[i], pairs_indices[i + 1]) for i in range(0, len(pairs_indices) - 1, 2)]

jointformer_embeddings = embeddings_dict["Jointformer"]
chemberta_embeddings = embeddings_dict["ChemBERTa"]

# Compute the cosine *distance* between the two embeddings of every pair
# (scipy's `cosine` returns a distance, i.e. 1 - cosine similarity)
jointformer_distances = [
    cosine(jointformer_embeddings[idx1], jointformer_embeddings[idx2]) for idx1, idx2 in pairs
]
chemberta_distances = [
    cosine(chemberta_embeddings[idx1], chemberta_embeddings[idx2]) for idx1, idx2 in pairs
]

# Absolute difference of each property between the two molecules of every pair
properties_differences = {}
for prop in PROPERTIES:
    # Extract the column once; rows are addressed by position to match molecules_list
    prop_values = labels_df[prop].values
    properties_differences[prop] = [
        np.abs(prop_values[idx1] - prop_values[idx2]) for idx1, idx2 in pairs
    ]

In [None]:
# Plot the results: one panel per property, showing the embedding cosine distance
# of each pair against the absolute property difference, for both models.

# squeeze=False keeps `axes` a 2D array, so this also works with a single property
fig, axes = plt.subplots(1, len(PROPERTIES), figsize=(15, 5), squeeze=False)

for ax, prop in zip(axes[0], PROPERTIES):
    ax.scatter(jointformer_distances, properties_differences[prop], label="Jointformer", alpha=0.6)
    ax.scatter(chemberta_distances, properties_differences[prop], label="ChemBERTa", alpha=0.6)
    ax.set_xlabel("Cosine distance")
    ax.set_ylabel(f"Absolute difference in {prop}")
    ax.set_title(f"{prop}")
    ax.legend()

plt.tight_layout()
plt.show()