In [1]:
import os
import random

import lightning as L
import numpy as np
import pandas as pd
import rdkit.Chem as Chem
import torch
import wandb
from chemprop import data, featurizers, models, nn
from dotenv import load_dotenv
from rdkit.rdBase import BlockLogs
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from sklearn.model_selection import GroupShuffleSplit
from torch.utils.data import Dataset, DataLoader

from utils import standardize, get_scaffold

RANDOM_SEED = 42


def set_seeds(seed):
    """Seed the Python, NumPy and PyTorch global random number generators.

    Args:
        seed: Value handed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed``, in that order.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


# Seed once up front so the cells below start from a known RNG state.
set_seeds(RANDOM_SEED)

# Pull the W&B API key from the local dotenv file into the environment,
# then authenticate with it (a missing key raises KeyError here).
load_dotenv('.env.secret')
wandb_api_key = os.environ['WANDB_API_KEY']
wandb.login(key=wandb_api_key)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/rahul/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrahul-e-dev[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
def mol_to_inchi(mol):
    """Return the InChI string RDKit generates for ``mol``.

    The conversion runs inside a ``BlockLogs`` context so RDKit log output
    produced during the call is suppressed.
    """
    with BlockLogs():
        inchi = Chem.MolToInchi(mol)
    return inchi

# Read the assay table, discard its first column and name the two that remain.
df = pd.read_csv("./GSK_HepG2.csv").iloc[:, 1:]
df.columns = ['smiles', 'per_inhibition']
# The readout is stored with its sign flipped relative to the CSV.
df['per_inhibition'] = -df['per_inhibition']

# Standardize every structure; rows for which `standardize` produced a
# missing value are dropped.
df['mol'] = df['smiles'].map(standardize)
df = df.dropna(subset=['mol'])

# InChI is the identity key: any structure that occurs more than once is
# removed entirely (no copy of it is kept).
df['inchi'] = df['mol'].map(mol_to_inchi)
occurs_once = lambda group: len(group) == 1
df = df.groupby(["inchi"]).filter(occurs_once).reset_index(drop=True)

# One integer cluster id per row, obtained by factorizing the output of
# `get_scaffold` applied to each molecule's SMILES.
scaffolds = df['mol'].map(Chem.MolToSmiles).map(get_scaffold)  # type: ignore
cluster_codes, _ = pd.factorize(scaffolds)
clusters = pd.Series(cluster_codes)

# Only the molecule object and the target are needed from here on.
df = df.drop(columns=['smiles', 'inchi'])

In [3]:
# Group-aware train / validation / test split: rows that share a cluster id
# are never separated across subsets.
#
# GroupShuffleSplit.split yields *positional* row indices, so rows are picked
# with .iloc. (.loc only gives the same rows while the frame carries a fresh
# RangeIndex.)
splitter = GroupShuffleSplit(
    n_splits=1,
    test_size=0.2,  # the library default, spelled out for readers
    random_state=RANDOM_SEED,
)
train_idxs, val_test_idxs = next(splitter.split(df, groups=clusters))
df_train = df.iloc[train_idxs].reset_index(drop=True)
df_val_test = df.iloc[val_test_idxs].reset_index(drop=True)
clusters_val_test = clusters.iloc[val_test_idxs].reset_index(drop=True)

# Split the held-out part in half (again by cluster) into validation and test.
splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED, test_size=0.5)
val_idxs, test_idxs = next(splitter.split(df_val_test, groups=clusters_val_test))
df_val = df_val_test.iloc[val_idxs].reset_index(drop=True)
df_test = df_val_test.iloc[test_idxs].reset_index(drop=True)

# Cheap invariants: nothing lost, and no cluster shared between train and held-out rows.
assert len(df_train) + len(df_val) + len(df_test) == len(df)
assert set(clusters.iloc[train_idxs]).isdisjoint(set(clusters_val_test))

In [None]:
class ShuffledPairsDataset(Dataset):
    """Randomly paired molecules with a binary "left > right" label.

    Every row of ``df`` is paired with ``sample_ratio`` partner rows drawn
    without replacement; rows whose ``per_inhibition`` exceeds
    ``UPWEIGHT_THRESHOLD`` are ``UPWEIGHT`` times more likely to be drawn as a
    partner. Each pair is also added in mirrored order.

    Args:
        df: Frame with a ``mol`` column (input to the chemprop featurizer)
            and a numeric ``per_inhibition`` column. Rows are addressed by
            position, so any index is accepted.
        sample_ratio: Number of partners drawn per row (capped at ``len(df)``).

    Items are ``[left_datum, right_datum]``; the target ``1.0``/``0.0`` is
    stored on the left datum only.
    """

    # Rows above this per_inhibition value get the larger sampling weight.
    UPWEIGHT_THRESHOLD = -15
    UPWEIGHT = 8.0
    BASE_WEIGHT = 1.0

    def __init__(self, df, sample_ratio=5):
        super().__init__()
        self.df = df
        self.sample_ratio = sample_ratio
        self.featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
        self.pairs: list = []
        self.mg_cache: list = []

        self.build_mg_cache()
        self.update_pairs()

    def build_mg_cache(self):
        """Featurize every molecule once; list position == row position."""
        self.mg_cache = self.df['mol'].map(self.featurizer).tolist()

    def update_pairs(self):
        """Draw a fresh set of (left, right) row-position pairs.

        Uses NumPy's global RNG, so results follow ``np.random.seed``.
        """
        n_rows = len(self.df)
        if n_rows == 0:
            self.pairs = []
            return

        targets = self.df['per_inhibition'].to_numpy()
        weights = np.where(
            targets > self.UPWEIGHT_THRESHOLD, self.UPWEIGHT, self.BASE_WEIGHT
        )
        weights = weights / weights.sum()

        # replace=False cannot draw more partners than there are rows.
        n_partners = max(0, min(int(self.sample_ratio), n_rows))

        # The partner index is the weighted draw itself (previously the draw
        # was discarded and a uniform random index was used instead).
        pairs = [
            (i, int(j))
            for i in range(n_rows)
            for j in np.random.choice(
                n_rows,
                size=n_partners,
                p=weights,
                replace=False,
            )
        ]

        # Mirror every pair so both orderings are seen.
        pairs += [(j, i) for i, j in pairs]
        self.pairs = pairs

    def _pair_label(self, left_idx, right_idx):
        """Return 1.0 if the left row's target is strictly greater, else 0.0.

        Lookup is positional (``.iloc``) to stay aligned with ``mg_cache``.
        """
        targets = self.df['per_inhibition']
        return float(targets.iloc[left_idx] > targets.iloc[right_idx])

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        left_idx, right_idx = self.pairs[idx]
        left_mg, right_mg = self.mg_cache[left_idx], self.mg_cache[right_idx]
        delta = self._pair_label(left_idx, right_idx)

        # Only the left datum carries the target; the right one is input only.
        left_datum = data.datasets.Datum(
            left_mg, None, None, np.array([delta]), 1.0, None, None
        )

        right_datum = data.datasets.Datum(
            right_mg, None, None, None, 1.0, None, None
        )

        return [left_datum, right_datum]

In [None]:
class ExemplarDataset(Dataset):
    """All (regular, exemplar) molecule pairs with a "regular > exemplar" label.

    Args:
        df_regular: Frame of molecules to score; needs ``mol`` and
            ``per_inhibition`` columns.
        df_exemplars: Frame of reference molecules with the same columns.

    ``pairs`` lists ``(regular_position, exemplar_position)`` with the regular
    row varying slowest. Rows are addressed by position, so the frames may
    carry any index. Items are ``[regular_datum, exemplar_datum]``; the target
    is stored on the regular datum only.
    """

    def __init__(self, df_regular, df_exemplars) -> None:
        super().__init__()
        self.df_exemplars = df_exemplars
        self.df_regular = df_regular
        self.featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
        self.pairs = []
        self.exemplar_mg_cache: list = []
        self.regular_mg_cache: list = []

        self.build_pairs()
        self.build_mg_cache()

    def build_mg_cache(self):
        """Featurize both frames once; list position == row position."""
        self.exemplar_mg_cache = self.df_exemplars['mol'].map(self.featurizer).tolist()
        self.regular_mg_cache = self.df_regular['mol'].map(self.featurizer).tolist()

    def build_pairs(self):
        """Enumerate the full cross product of regular and exemplar rows."""
        self.pairs = [
            (i, j)
            for i in range(len(self.df_regular))
            for j in range(len(self.df_exemplars))
        ]

    def _pair_label(self, regular_idx, exemplar_idx):
        """Return 1.0 if the regular row's target is strictly greater, else 0.0.

        Lookup is positional (``.iloc``) to stay aligned with the graph caches.
        """
        regular_target = self.df_regular['per_inhibition'].iloc[regular_idx]
        exemplar_target = self.df_exemplars['per_inhibition'].iloc[exemplar_idx]
        return float(regular_target > exemplar_target)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        regular_idx, exemplar_idx = self.pairs[idx]
        regular_mol = self.regular_mg_cache[regular_idx]
        exemplar_mol = self.exemplar_mg_cache[exemplar_idx]
        delta = self._pair_label(regular_idx, exemplar_idx)

        # Only the regular datum carries the target; the exemplar is input only.
        regular_datum = data.datasets.Datum(
            regular_mol, None, None, np.array([delta]), 1.0, None, None
        )

        exemplar_datum = data.datasets.Datum(
            exemplar_mol, None, None, None, 1.0, None, None
        )

        return [regular_datum, exemplar_datum]

In [None]:
# see https://docs.pytorch.org/docs/stable/notes/randomness.html#dataloader
def seed_worker(worker_id):
    """Seed NumPy and ``random`` from the current torch initial seed.

    Used in this notebook as a DataLoader ``worker_init_fn``.

    Args:
        worker_id: Not used; present because the hook is called with it.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

class ConstrastiveDataModule(L.LightningDataModule):
    """Lightning data module serving pairwise train and validation loaders.

    Training batches come from ``ShuffledPairsDataset`` and are re-drawn on
    every call to ``train_dataloader``. Validation pairs each validation row
    with a fixed set of exemplars sampled from the training frame.

    Args:
        df_train: Training frame (``mol`` and ``per_inhibition`` columns).
        df_val: Validation frame with the same columns.
        n_exemplars_per_side: Exemplars sampled from each of the
            ``per_inhibition > 0`` and ``per_inhibition < 0`` subsets
            (capped at the subset size).
        exemplar_seed: ``random_state`` for exemplar sampling. A fixed seed
            keeps the validation set identical across epochs, so the
            monitored validation metric is comparable between epochs.
    """

    def __init__(
        self,
        df_train: pd.DataFrame,
        df_val: pd.DataFrame,
        n_exemplars_per_side: int = 50,
        exemplar_seed=RANDOM_SEED,
    ):
        super().__init__()
        self.df_train = df_train
        self.df_val = df_val
        self.n_exemplars_per_side = n_exemplars_per_side
        self.exemplar_seed = exemplar_seed
        # Built lazily and reused so molecules are featurized once, not per epoch.
        self._train_dataset = None
        self._val_dataset = None

    def setup(self, stage=None):
        pass

    def _sample_exemplars(self) -> pd.DataFrame:
        """Sample validation exemplars from this module's own training frame."""
        targets = self.df_train['per_inhibition']
        sampled = []
        for mask in (targets > 0, targets < 0):
            subset = self.df_train[mask]
            n_take = min(self.n_exemplars_per_side, len(subset))
            sampled.append(subset.sample(n_take, random_state=self.exemplar_seed))
        return pd.concat(sampled).reset_index(drop=True)

    def train_dataloader(self):
        if self._train_dataset is None:
            # The constructor already draws the first set of pairs.
            self._train_dataset = ShuffledPairsDataset(self.df_train, sample_ratio=10)
        else:
            # Later calls (dataloader reloads) only need fresh pairs.
            self._train_dataset.update_pairs()

        return DataLoader(
            dataset=self._train_dataset,
            batch_size=1024,
            shuffle=True,
            collate_fn=data.dataloader.collate_multicomponent,
            worker_init_fn=seed_worker,
            num_workers=12,
        )

    def val_dataloader(self):
        if self._val_dataset is None:
            exemplar_df = self._sample_exemplars()
            self._val_dataset = ExemplarDataset(self.df_val, exemplar_df)

        return DataLoader(
            dataset=self._val_dataset,
            batch_size=2048,
            shuffle=False,
            collate_fn=data.dataloader.collate_multicomponent,
            worker_init_fn=seed_worker,
            num_workers=12,
        )

In [None]:
# Feature dimensions reported by the featurizer, given as (atom_dims, bond_dims).
fdims = featurizers.SimpleMoleculeMolGraphFeaturizer().shape

# One separate bond-message-passing block per molecule of the input pair.
N_COMPONENTS = 2
mcmp = nn.MulticomponentMessagePassing(
    blocks=[nn.BondMessagePassing(*fdims) for _ in range(N_COMPONENTS)],
    n_components=N_COMPONENTS,
)
agg = nn.NormAggregation()
# Single-task binary head fed with the multicomponent encoder's output.
ffn = nn.BinaryClassificationFFN(n_tasks=1, input_dim=mcmp.output_dim)
batch_norm = True
metric_list = [
    nn.metrics.BinaryF1Score(),
    nn.metrics.BinaryAUPRC(),
    nn.metrics.BinaryAUROC(),
]

mpnn = models.multi.MulticomponentMPNN(mcmp, agg, ffn, batch_norm, metric_list)
# Override the model's max_lr attribute after construction.
mpnn.max_lr = 0.01

In [None]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

# Close any W&B run still open in this kernel before starting a new one.
wandb.finish()
wandb_logger = WandbLogger(project="chemprop_delta_clf", log_model="all")
wandb_logger.experiment.mark_preempting()

# Early stopping and checkpoint selection both maximize the same validation metric.
MONITORED_METRIC = "val/prc"
training_callbacks = [
    EarlyStopping(monitor=MONITORED_METRIC, mode="max", verbose=True, patience=10),
    ModelCheckpoint(monitor=MONITORED_METRIC, mode="max", save_top_k=2),
]

trainer = L.Trainer(
    logger=wandb_logger,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=50,
    # Dataloaders are rebuilt each epoch so the training pairs are re-drawn.
    reload_dataloaders_every_n_epochs=1,
    log_every_n_steps=50,
    callbacks=training_callbacks,
)

contrastive_data_module = ConstrastiveDataModule(df_train, df_val)
trainer.fit(mpnn, datamodule=contrastive_data_module)

In [None]:
from pathlib import Path

# Fetch the checkpoint W&B tagged ":best" for the run that was just trained.
run_id = wandb_logger.experiment.id
checkpoint_reference = f"rahul-e-dev/chemprop_delta_clf/model-{run_id}:best"
artifact_dir = wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")

# NOTE(security): weights_only=False unpickles arbitrary Python objects.
# Only load checkpoints produced by your own runs this way.
ckpt = torch.load(Path(artifact_dir) / "model.ckpt", map_location='cpu', weights_only=False)
hparams = ckpt.get('hyper_parameters', ckpt.get('hparams', {}))  # kept for inspection
mpnn.load_state_dict(ckpt['state_dict'])

# A plain trainer (no W&B logger, no callbacks) is enough for prediction.
trainer = L.Trainer(
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
)

# Reference molecules from the training set that every test molecule is
# compared against. Seeded so the test metrics are reproducible, and capped
# so a small pool does not make `sample` raise.
exemplar_pool = df_train[df_train['per_inhibition'] > -15]
exemplar_df = (
    exemplar_pool
    .sample(min(100, len(exemplar_pool)), random_state=RANDOM_SEED)
    .reset_index(drop=True)
)

exemplar_ds = ExemplarDataset(
    df_test,
    exemplar_df
)

exemplar_dl = DataLoader(
    dataset=exemplar_ds,
    batch_size=2048,
    shuffle=False,  # prediction order must match exemplar_ds.pairs
    collate_fn=data.dataloader.collate_multicomponent,
    num_workers=12,
)

# One tensor per batch; concatenate into a single prediction tensor.
test_ds_preds = trainer.predict(model=mpnn, dataloaders=exemplar_dl)
test_ds_preds = torch.cat(test_ds_preds)

In [None]:
from collections import defaultdict

def calc(x):
    """Return how many entries of ``x`` are greater than or equal to 0.5."""
    x = np.array(x)
    return (x >= 0.5).sum()


# Collect, per test row, the model output for each (row, exemplar) pair.
# reshape(-1) flattens the predictions and still works for a single one.
deltas = defaultdict(list)
for (i, _exemplar_idx), delta in zip(exemplar_ds.pairs, test_ds_preds.reshape(-1).tolist()):
    deltas[i].append(float(delta))

# Explicit object Series aligned on df_test's index, one list per row.
df_test['deltas'] = pd.Series(
    [deltas[i] for i in range(len(df_test))],
    index=df_test.index,
    dtype=object,
)
# NB: despite its name, 'pred_probs' is a count of exemplars with output >= 0.5.
df_test['pred_probs'] = df_test['deltas'].map(calc)
# Mean model output over all exemplars.
df_test['asd'] = df_test['deltas'].map(np.mean)
# Positive call when more than 3 exemplars have output >= 0.5.
df_test['preds'] = df_test['pred_probs'] > 3
# Ground truth comes from the test rows themselves (previously read from df_val).
df_test['true'] = df_test['per_inhibition'] > -15

In [None]:
# Show the test frame, including the prediction columns added to it above.
df_test

In [None]:
# Test-set metrics in the order f1, precision, recall, accuracy.
# The metric functions are imported in the notebook's top import cell, so
# this cell also runs on a fresh kernel.
[
    metric(df_test['true'], df_test['preds'])
    for metric in (f1_score, precision_score, recall_score, accuracy_score)
]

In [None]:
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

# Log the final *test* metrics to W&B. The 'true' and 'preds' columns exist
# on df_test only (df_val never receives them).
wandb_logger.log_table(
    'final_metrics',
    ['f1', 'precision', 'recall', 'accuracy'],
    [[
        float(f1_score(df_test['true'], df_test['preds'])),
        float(precision_score(df_test['true'], df_test['preds'])),
        float(recall_score(df_test['true'], df_test['preds'])),
        float(accuracy_score(df_test['true'], df_test['preds'])),
    ]]
)

In [None]:
# Close the W&B run now that the final metrics table has been logged.
wandb.finish()