In [None]:
import pandas as pd
import rdkit.Chem as Chem
from rdkit.rdBase import BlockLogs
from sklearn.model_selection import GroupShuffleSplit
import numpy as np
from utils import standardize, get_scaffold

from torch.utils.data import Dataset, DataLoader, IterableDataset
import random
from chemprop.featurizers.molgraph.reaction import CondensedGraphOfReactionFeaturizer
from chemprop.data.datapoints import ReactionDatapoint
from chemprop.data.datasets import Datum

import lightning as L
from chemprop.data.collate import collate_batch
from chemprop.data.dataloader import build_dataloader

import wandb
import os
from dotenv import load_dotenv
import torch
import networkx as nx

RANDOM_SEED = 42
# Seed every RNG this notebook draws from: python `random` (pair sampling),
# numpy (pandas `.sample` without random_state) and torch (e.g. parameter
# initialisation), so Restart & Run All is reproducible.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)


# Read credentials from a dotenv file rather than hard-coding them in the notebook;
# a missing WANDB_API_KEY fails loudly with KeyError.
load_dotenv(dotenv_path='.env.secret')
wandb.login(key=os.environ['WANDB_API_KEY'])

In [None]:
def mol_to_inchi(mol):
    """Return the InChI string of an RDKit molecule.

    RDKit logging is blocked while the conversion runs, so InChI warnings
    do not flood the notebook output.
    """
    with BlockLogs():
        inchi = Chem.MolToInchi(mol)
    return inchi

# Drop the first CSV column (NOTE(review): presumably a saved index — confirm) and name the rest.
df = pd.read_csv("./GSK_HepG2.csv").iloc[:, 1:]
df.columns = ['smiles', 'per_inhibition']
# Sign flip of the raw values. NOTE(review): the intended orientation is not visible here — confirm.
df['per_inhibition'] = -df['per_inhibition']

# Standardize structures, discard rows where standardization returned nothing,
# then key every remaining molecule by its InChI.
df['mol'] = df['smiles'].map(standardize)
df = df.dropna(subset=['mol'])
df['inchi'] = df['mol'].map(mol_to_inchi)

# Keep only InChIs that occur exactly once: every member of a duplicated group is removed.
df = df.groupby(["inchi"]).filter(lambda group: len(group) == 1).reset_index(drop=True)

# One integer cluster id per distinct scaffold; used as the grouping key for the splits.
scaffolds = df['mol'].map(Chem.MolToSmiles).map(get_scaffold)  # type: ignore
clusters = pd.Series(pd.factorize(scaffolds)[0])

df = df.drop(columns=['smiles', 'inchi'])

In [None]:
def group_split(frame, groups, test_size=None, seed=RANDOM_SEED):
    """Split the row positions of `frame` into (kept, held_out) with no group on both sides.

    Args:
        frame: table to split (only its length is used by the splitter).
        groups: one group label per row of `frame`, aligned by position.
        test_size: forwarded to GroupShuffleSplit. This is a fraction of *groups*
            (scaffolds), not of rows; None means sklearn's default (0.2 when
            train_size is unset).
        seed: random_state for the splitter.

    Returns:
        Two integer arrays of row *positions* (use with `.iloc`).
    """
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    return next(splitter.split(frame, groups=groups))


# Train vs. held-out scaffolds.
train_idxs, val_test_idxs = group_split(df, clusters)
# The splitter returns positions, so index positionally rather than by label.
df_train = df.iloc[train_idxs].reset_index(drop=True)
df_val_test = df.iloc[val_test_idxs].reset_index(drop=True)
clusters_val_test = clusters.iloc[val_test_idxs].reset_index(drop=True)

# Halve the held-out scaffolds into validation and test.
val_idxs, test_idxs = group_split(df_val_test, clusters_val_test, test_size=0.5)
df_val = df_val_test.iloc[val_idxs].reset_index(drop=True)
df_test = df_val_test.iloc[test_idxs].reset_index(drop=True)

# No scaffold may appear in more than one split.
assert set(clusters.iloc[train_idxs]).isdisjoint(clusters_val_test)
assert set(clusters_val_test.iloc[val_idxs]).isdisjoint(clusters_val_test.iloc[test_idxs])

In [None]:
class ShuffledPairsDataset(Dataset):
    """Randomly paired molecules for training a pairwise "left > right" classifier.

    Each item is the condensed-graph-of-reaction featurization of an ordered
    pair (left, right) of rows from ``data``, labelled 1.0 when the left row's
    ``per_inhibition`` is strictly greater than the right row's, else 0.0.

    Args:
        data: frame with a ``mol`` column (molecules passed to the featurizer)
            and a ``per_inhibition`` column. Rows are addressed by position,
            so the index does not need to be a RangeIndex.
        sample_ratio: number of random partners drawn for every row.
    """

    def __init__(self, data, sample_ratio=5):
        self.data = data  # raw data
        self.featurizer = CondensedGraphOfReactionFeaturizer()
        self.pairs = []  # list of (left_row, right_row) positions
        self.sample_ratio = sample_ratio
        self.update_pairs()

    def update_pairs(self):
        """Redraw the random pairs in place.

        Every row i gets ``sample_ratio`` partners j drawn uniformly with
        replacement (j may equal i). Each (i, j) is then also added as (j, i),
        so ``len(self.pairs) == 2 * len(self.data) * self.sample_ratio``.
        Draws from the global ``random`` module, so it follows ``random.seed``.
        """
        n_rows = len(self.data)
        forward = [
            (i, random.randint(0, n_rows - 1))
            for i in range(n_rows)
            for _ in range(self.sample_ratio)
        ]
        # Mirror every pair so the model sees both orderings of each sampled couple.
        self.pairs = forward + [(j, i) for i, j in forward]

    def _pair_label(self, lidx, ridx):
        """Return 1.0 if row ``lidx`` has strictly higher per_inhibition than row ``ridx``, else 0.0."""
        values = self.data['per_inhibition']
        return float(values.iat[lidx] > values.iat[ridx])

    def prepare_datum(self, lidx, ridx):
        """Featurize the ordered pair of rows (lidx, ridx) into a chemprop ``Datum``."""
        mols = self.data['mol']
        left_mol = mols.iat[lidx]
        right_mol = mols.iat[ridx]
        delta = self._pair_label(lidx, ridx)

        mg = self.featurizer((left_mol, right_mol), None, None)
        # Built only to read its weight / lt_mask / gt_mask attributes for the Datum.
        rxn_dp = ReactionDatapoint(left_mol, right_mol, np.array([delta]))
        return Datum(mg, None, None, np.array([delta]), rxn_dp.weight, rxn_dp.lt_mask, rxn_dp.gt_mask)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the featurized Datum for the idx-th (left, right) pair."""
        i, j = self.pairs[idx]
        return self.prepare_datum(i, j)
    

class ExemplarDataset(Dataset):
    """Every validation row paired with every training exemplar.

    Each item is the featurized ordered pair (val row, exemplar row), labelled
    1.0 when the validation row's ``per_inhibition`` is strictly greater than
    the exemplar's, else 0.0.

    Args:
        df_train_exemplars: exemplar rows (``mol`` and ``per_inhibition`` columns);
            the index is reset on a copy.
        df_val: validation rows with the same columns. Rows are addressed by
            position, so its index does not need to be a RangeIndex.
    """

    def __init__(self, df_train_exemplars, df_val):
        self.df_train_exemplars = df_train_exemplars.reset_index(drop=True)
        self.df_val = df_val
        self.featurizer = CondensedGraphOfReactionFeaturizer()
        # Val-major order: all exemplars for val row 0, then val row 1, ...
        # Result-aggregation code zips predictions with self.pairs, so keep this order stable.
        self.pairs = [
            (i, j)
            for i in range(len(self.df_val))
            for j in range(len(self.df_train_exemplars))
        ]

    def _pair_label(self, lidx, ridx):
        """Return 1.0 if val row ``lidx`` has strictly higher per_inhibition than exemplar ``ridx``, else 0.0."""
        val_value = self.df_val['per_inhibition'].iat[lidx]
        exemplar_value = self.df_train_exemplars['per_inhibition'].iat[ridx]
        return float(val_value > exemplar_value)

    def prepare_datum(self, lidx, ridx):
        """Featurize (val row lidx, exemplar row ridx) into a chemprop ``Datum``."""
        left_mol = self.df_val['mol'].iat[lidx]
        right_mol = self.df_train_exemplars['mol'].iat[ridx]
        delta = self._pair_label(lidx, ridx)

        mg = self.featurizer((left_mol, right_mol), None, None)
        # Built only to read its weight / lt_mask / gt_mask attributes for the Datum.
        rxn_dp = ReactionDatapoint(left_mol, right_mol, np.array([delta]))
        return Datum(mg, None, None, np.array([delta]), rxn_dp.weight, rxn_dp.lt_mask, rxn_dp.gt_mask)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the featurized Datum for the idx-th (val, exemplar) pair."""
        i, j = self.pairs[idx]
        return self.prepare_datum(i, j)

In [None]:
# Five random partners per training compound (plus their mirrored pairs).
train_ds = ShuffledPairsDataset(data=df_train, sample_ratio=5)

# Initial validation set: validation compounds vs. 10 random training exemplars.
# The data module rebuilds this (keeping only df_val) whenever val_dataloader runs.
val_ds = ExemplarDataset(
    df_train_exemplars=df_train.sample(10).reset_index(drop=True),
    df_val=df_val,
)

In [None]:
class ConstrastiveDataModule(L.LightningDataModule):
    """Lightning data module serving the pairwise datasets through chemprop loaders.

    Training pairs are redrawn every time ``train_dataloader`` is called. By
    default a fresh set of training exemplars is also sampled every time
    ``val_dataloader`` is called. If the Trainer reloads dataloaders every
    epoch, validation metrics are then computed against a different exemplar
    set each epoch. Pass ``resample_val_exemplars=False`` to keep the given
    ``val_dataset`` fixed, e.g. when val metrics drive early stopping or
    checkpointing.

    Args:
        train_dataset: pairs used for training.
        val_dataset: validation-vs-exemplar pairs (its ``df_val`` is reused when resampling).
        n_val_exemplars: number of training rows sampled as exemplars when resampling.
        train_batch_size / val_batch_size / predict_batch_size: loader batch sizes.
        num_workers: worker processes for every loader.
        resample_val_exemplars: resample exemplars on every ``val_dataloader`` call.
    """

    def __init__(
        self,
        train_dataset: ShuffledPairsDataset,
        val_dataset: ExemplarDataset,
        n_val_exemplars: int = 10,
        train_batch_size: int = 512,
        val_batch_size: int = 2048,
        predict_batch_size: int = 512,
        num_workers: int = 4,
        resample_val_exemplars: bool = True,
    ):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.n_val_exemplars = n_val_exemplars
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.predict_batch_size = predict_batch_size
        self.num_workers = num_workers
        self.resample_val_exemplars = resample_val_exemplars

    def setup(self, stage=None):
        # Datasets are passed in ready-made; nothing to prepare per stage.
        pass

    def train_dataloader(self):
        """Redraw the random training pairs, then return a chemprop loader over them."""
        self.train_dataset.update_pairs()
        return build_dataloader(
            self.train_dataset,   # type: ignore
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return a non-shuffled loader over the validation/exemplar pairs.

        When ``resample_val_exemplars`` is true, ``self.val_dataset`` is first
        replaced by a new ExemplarDataset pairing the same validation rows with
        ``n_val_exemplars`` freshly sampled training rows.
        """
        if self.resample_val_exemplars:
            exemplars = self.train_dataset.data.sample(self.n_val_exemplars).reset_index(drop=True)
            self.val_dataset = ExemplarDataset(exemplars, self.val_dataset.df_val)
        return build_dataloader(
            self.val_dataset,   # type: ignore
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            shuffle=False
        )

    def predict_dataloader(self):
        """Return a non-shuffled loader over the current ``val_dataset``."""
        # for some reason, pytorch lightning does not like using build_dataloader function here
        # manually creating the dataloader for now
        return DataLoader(
            self.val_dataset,   # type: ignore
            collate_fn=collate_batch,
            batch_size=self.predict_batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            drop_last=False
        )

In [None]:
from chemprop import data, featurizers, models, nn

# Input dimensions of the CGR featurizer, given as (atom_dims, bond_dims).
fdims = train_ds.featurizer.shape
mp = nn.BondMessagePassing(*fdims)
agg = nn.NormAggregation()
ffn = nn.BinaryClassificationFFN(n_tasks=1)
batch_norm = True
metric_list = [nn.metrics.BinaryF1Score(), nn.metrics.BinaryAUPRC(), nn.metrics.BinaryAUROC()]
MAX_LR = 0.01
# Pass max_lr through the constructor instead of patching `mpnn.max_lr` afterwards, so the
# value the optimizer uses and the value recorded with the model agree.
# NOTE(review): relies on chemprop's MPNN saving its constructor args as hyperparameters — confirm.
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list, max_lr=MAX_LR)

In [None]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

contrastive_data_module = ConstrastiveDataModule(train_ds, val_ds)

# Both callbacks track validation F1, where higher is better.
early_stopping = EarlyStopping(monitor="val/f1", mode="max", patience=5, verbose=True)
checkpointing = ModelCheckpoint(monitor="val/f1", mode="max", save_top_k=2)

wandb_logger = WandbLogger(project="chemprop_delta_clf", log_model="all")
trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=20,
    logger=wandb_logger,
    log_every_n_steps=50,
    enable_progress_bar=True,
    enable_checkpointing=True,  # checkpoints are written to the `checkpoints` folder
    # Rebuild loaders every epoch so the data module redraws training pairs
    # and resamples validation exemplars.
    reload_dataloaders_every_n_epochs=1,
    callbacks=[early_stopping, checkpointing],
)

trainer.fit(mpnn, datamodule=contrastive_data_module)

In [None]:
# Compare every validation compound against 25 randomly chosen training compounds.
exemplar_ds = ExemplarDataset(
    df_train_exemplars=df_train.sample(25).reset_index(drop=True),
    df_val=df_val,
)

# shuffle=False: predictions are matched back to exemplar_ds.pairs by position.
exemplar_dl = build_dataloader(exemplar_ds, batch_size=2048, num_workers=4, shuffle=False)

In [None]:
# One prediction row per (val, exemplar) pair, in exemplar_ds.pairs order.
# NOTE(review): no model is passed, so Lightning picks the weights (documented as the best
# checkpoint of the preceding fit when a checkpoint callback exists) — confirm.
val_ds_preds = torch.cat(trainer.predict(dataloaders=exemplar_dl))

In [None]:
from collections import defaultdict

# Group the model's "val > exemplar" outputs by validation row.
# `val_ds_preds` was produced from `exemplar_dl`, i.e. from the 25-exemplar `exemplar_ds`
# built earlier, so the pairs must come from that same dataset. The data module's
# `val_dataset` holds a different, resampled exemplar set and would silently misalign
# predictions with pairs (zip truncates to the shorter one).
flat_preds = val_ds_preds.reshape(-1)
assert len(flat_preds) == len(exemplar_ds.pairs), "predictions and pairs are misaligned"

asd = defaultdict(list)  # val row position -> one model output per exemplar
for (i, _j), delta in zip(exemplar_ds.pairs, flat_preds):
    asd[i].append(float(delta.item()))

In [None]:
import statistics

# (val row, mean model output over its exemplars) in val-row order.
qqq = list(zip(asd.keys(), map(statistics.mean, asd.values())))

In [None]:
df_val['pred'] = [mean_score for _row, mean_score in qqq]  # qqq is ordered by val row

In [None]:
df_val.to_csv(path_or_buf='qqq.csv')  # NOTE: the next cell overwrites this same file

In [None]:
# Highest per_inhibition first: ascending sort, then reverse the row order.
df_val.sort_values('per_inhibition').iloc[::-1].to_csv('qqq.csv')

In [None]:
df.head()  # preview only; the full frame (with its RDKit mol column) is too large to display

In [None]:
# Directed bipartite graph: one node per exemplar and per validation compound, plus an
# edge regular -> exemplar wherever the model's "val > exemplar" output exceeds 0.5,
# weighted by that output. Written to GEXF for external visualisation.
node_list = [
    (f"{i}_exemplar", {"type": "exemplar"})
    for i in range(len(exemplar_ds.df_train_exemplars))
]

node_list += [
    (f"{i}_regular", {"type": "regular"})
    for i in range(len(exemplar_ds.df_val))
]

pair_preds = val_ds_preds.reshape(-1)
# zip() would silently truncate if the dataset and predictions came from different exemplar sets.
assert len(pair_preds) == len(exemplar_ds.pairs), "predictions and pairs are misaligned"

edge_list = []
for (i, j), delta in zip(exemplar_ds.pairs, pair_preds):
    score = delta.item()
    if score > 0.5:
        edge_list.append((f"{i}_regular", f"{j}_exemplar", score))

G = nx.DiGraph()
G.add_nodes_from(node_list)
G.add_weighted_edges_from(edge_list)
nx.write_gexf(G, 'asd.gexf')