In [1]:
import pandas as pd
import rdkit.Chem as Chem
from rdkit.rdBase import BlockLogs
from sklearn.model_selection import GroupShuffleSplit
import numpy as np
from utils import standardize, get_scaffold

from torch.utils.data import Dataset, DataLoader
import random
from chemprop.featurizers.molgraph.reaction import CondensedGraphOfReactionFeaturizer
from chemprop.data.datapoints import ReactionDatapoint
from chemprop.data.datasets import Datum

import lightning as L
from chemprop.data.collate import collate_batch
from chemprop.data.dataloader import build_dataloader

import wandb
import os
from dotenv import load_dotenv
import torch
import networkx as nx

RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)


load_dotenv('.env.secret')
wandb.login(key=os.environ['WANDB_API_KEY'])

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/rahul/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrahul-e-dev[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
def mol_to_inchi(mol):
    with BlockLogs():
        return Chem.MolToInchi(mol)

df = pd.read_csv("./GSK_HepG2.csv")
df = df.iloc[:, 1:]
df.columns = ['smiles', 'per_inhibition']
df['per_inhibition'] = -df['per_inhibition']


# standardize and convert to inchi
df['mol'] = df['smiles'].map(standardize)
df = df.dropna(subset=['mol'])
df['inchi'] = df['mol'].map(mol_to_inchi)
df = df.groupby(["inchi"]).filter(lambda x: len(x) == 1).reset_index(drop=True)

clusters, _ = pd.factorize(
    df['mol']
        .map(Chem.MolToSmiles) # type: ignore
        .map(get_scaffold)
)
clusters = pd.Series(clusters)


df = df.drop(['smiles', 'inchi'], axis=1)

In [3]:
splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED)
train_idxs, val_test_idxs = next(splitter.split(df, groups=clusters))
df_train = df.loc[train_idxs].reset_index(drop=True)
df_val_test = df.loc[val_test_idxs].reset_index(drop=True)
clusters_val_test = clusters.iloc[val_test_idxs].reset_index(drop=True)


splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED, test_size=0.5)
val_idxs, test_idxs = next(splitter.split(df_val_test, groups=clusters_val_test))
df_val = df_val_test.loc[val_idxs].reset_index(drop=True)
df_test = df_val_test.loc[test_idxs].reset_index(drop=True)

In [4]:
class ExemplarDataset(Dataset):
    def __init__(self, df_train_exemplars, df_val):
        self.df_train_exemplars = df_train_exemplars.reset_index(drop=True)
        self.df_val = df_val
        self.featurizer = CondensedGraphOfReactionFeaturizer()
        self.pairs = [
            (i, j)
            for i in range(len(self.df_val))
            for j in range(len(self.df_train_exemplars))
        ]

    def prepare_datum(self, lidx, ridx):
        left_mol = self.df_val['mol'][lidx]
        right_mol = self.df_train_exemplars['mol'][ridx]
        delta = (
            self.df_val['per_inhibition'][lidx] > self.df_train_exemplars['per_inhibition'][ridx]
        ).astype(float)

        mg = self.featurizer((left_mol, right_mol), None, None)
        rxn_dp = ReactionDatapoint(left_mol, right_mol, np.array([delta]))
        return Datum(mg, None, None, np.array([delta]), rxn_dp.weight, rxn_dp.lt_mask, rxn_dp.gt_mask)
        
    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        i, j = self.pairs[idx]
        return self.prepare_datum(i, j)

In [5]:
from lightning.pytorch.loggers import WandbLogger
from chemprop import data, featurizers, models, nn
from pathlib import Path

checkpoint_reference = 'rahul-e-dev/chemprop_delta_clf/model-t8zqdbql:v4'
wandb_logger = WandbLogger(project="chemprop_delta_clf")
artifact_dir = wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")
mpnn = models.MPNN.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")

trainer = L.Trainer(
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
)

[34m[1mwandb[0m:   1 of 1 files downloaded.  
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [34]:
exemplar_ds = ExemplarDataset(
    df_train[df_train['per_inhibition'] > 0].sample(50).reset_index(drop=True),
    df_val
)

exemplar_dl = build_dataloader(
    exemplar_ds,
    batch_size=2048,
    num_workers=4,
    shuffle=False
)

val_ds_preds = trainer.predict(model=mpnn, dataloaders=exemplar_dl)
val_ds_preds = torch.cat(val_ds_preds)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [35]:
from collections import defaultdict


asd = defaultdict(list)
for (i, j), delta in zip(exemplar_ds.pairs, val_ds_preds.squeeze()):
    exemplar_val = exemplar_ds.df_train_exemplars['per_inhibition'][j]
    asd[i].append(float(delta.item()))


df_val['diffs'] = asd
df_val['pred'] = df_val['diffs'].map(np.median) > 0.5

In [36]:
df_val['true'] = df_val['per_inhibition'] > -15

In [37]:
from sklearn.metrics import f1_score

f1_score(df_val['true'], df_val['pred'])

0.29036004645760743