In [1]:
import sys

sys.path.append('../')

import os
import random
import pandas as pd
import lightning as L
import numpy as np
import torch
from chemprop import data, featurizers, models
from chemprop import nn as chem_nn
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from pytorch_lightning.utilities import move_data_to_device
import pandas as pd
import rdkit.Chem as Chem
from rdkit.rdBase import BlockLogs
from sklearn.model_selection import GroupShuffleSplit
from commons.utils import get_scaffold, standardize
from typing import NamedTuple
from itertools import chain

import wandb
# from commons.data import load_and_split_gsk_dataset

RANDOM_SEED = 42


def set_seeds(seed):
    """Seed every random number generator this notebook draws from.

    Args:
        seed: Integer seed applied, in order, to Python's ``random`` module,
            NumPy's global generator, torch's default generator and all CUDA
            generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_seeds(RANDOM_SEED)

# load_dotenv('.env.secret')
# SECURITY: never commit an API key to a notebook. The key is read from the
# WANDB_API_KEY environment variable instead; when it is unset, ``key=None``
# lets wandb fall back to its own credential lookup (e.g. ~/.netrc or an
# interactive prompt). The key that was previously hardcoded here must be
# treated as leaked and revoked in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/rahul/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrahul-e-dev[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
def mol_to_inchi(mol):
    """Convert an RDKit molecule to its InChI string.

    The conversion runs inside a ``BlockLogs`` context so that RDKit log
    messages emitted during the call are blocked.

    Args:
        mol: RDKit molecule to convert.

    Returns:
        Whatever ``Chem.MolToInchi`` returns for ``mol``.
    """
    with BlockLogs():
        inchi = Chem.MolToInchi(mol)
    return inchi


def load_and_split_gsk_dataset(path, RANDOM_SEED):
    """Load the GSK CSV and split it into train/val/test by scaffold group.

    Steps:
      1. Read the CSV, drop its first column and name the remaining two
         columns ``smiles`` and ``per_inhibition``.
      2. Standardize each SMILES into a molecule, dropping rows whose
         standardized molecule is missing.
      3. Drop every molecule whose InChI occurs more than once (all copies
         are removed, not just the extras).
      4. Group molecules by scaffold and do two grouped shuffle splits:
         train vs. held-out, then held-out halved into val and test.

    Args:
        path: Path of the CSV file to read.
        RANDOM_SEED: ``random_state`` used for both ``GroupShuffleSplit`` calls.

    Returns:
        ``(df_train, df_val, df_test)``: data frames with ``per_inhibition``
        and ``mol`` columns and a fresh 0..n-1 index each.
    """
    raw = pd.read_csv(path).iloc[:, 1:]
    raw.columns = ["smiles", "per_inhibition"]

    # Standardize, discard failures, then key every molecule by its InChI.
    with_mols = raw.assign(mol=raw["smiles"].map(standardize)).dropna(subset=["mol"])
    with_mols = with_mols.assign(inchi=with_mols["mol"].map(mol_to_inchi))
    unique_mols = (
        with_mols.groupby(["inchi"])
        .filter(lambda group: len(group) == 1)
        .reset_index(drop=True)
    )

    # One integer group id per distinct scaffold, aligned with ``unique_mols``.
    scaffolds = (
        unique_mols["mol"]
        .map(Chem.MolToSmiles)  # type: ignore
        .map(get_scaffold)
    )
    scaffold_ids = pd.Series(pd.factorize(scaffolds)[0])

    frame = unique_mols.drop(["smiles", "inchi"], axis=1)

    # First split: train vs. held-out, using GroupShuffleSplit's default test size.
    outer_split = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED)
    train_pos, holdout_pos = next(outer_split.split(frame, groups=scaffold_ids))
    df_train = frame.loc[train_pos].reset_index(drop=True)
    holdout = frame.loc[holdout_pos].reset_index(drop=True)
    holdout_groups = scaffold_ids.iloc[holdout_pos].reset_index(drop=True)

    # Second split: halve the held-out part into validation and test.
    inner_split = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED, test_size=0.5)
    val_pos, test_pos = next(inner_split.split(holdout, groups=holdout_groups))
    df_val = holdout.loc[val_pos].reset_index(drop=True)
    df_test = holdout.loc[test_pos].reset_index(drop=True)

    return df_train, df_val, df_test

In [3]:
# Build the scaffold-grouped train/val/test frames once, seeded with RANDOM_SEED.
df_train, df_val, df_test = load_and_split_gsk_dataset("../GSK_HepG2.csv", RANDOM_SEED)

In [4]:
def mol_to_molecule_datapoint(x):
    """Wrap one data-frame row as a chemprop ``MoleculeDatapoint``.

    Args:
        x: Row with a ``mol`` entry and a ``per_inhibition`` entry.

    Returns:
        ``data.MoleculeDatapoint`` built from the row's molecule and target.
    """
    return data.MoleculeDatapoint(x['mol'], x['per_inhibition'])


# One shared featurizer; the three splits are converted in train/val/test order.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
train_mol_dataset, val_mol_dataset, test_mol_dataset = (
    data.MoleculeDataset(split_df.apply(mol_to_molecule_datapoint, axis=1), featurizer=featurizer)
    for split_df in (df_train, df_val, df_test)
)

# Turn on chemprop's dataset cache for every split.
for mol_dataset in (train_mol_dataset, val_mol_dataset, test_mol_dataset):
    mol_dataset.cache = True

In [5]:
class RandomPairDataPoint(NamedTuple):
    """One ``RandomPairDataset`` item: an anchor datum and its sampled candidates."""
    # Datum obtained by indexing the underlying molecule dataset.
    anchor: data.datasets.Datum
    # Randomly drawn data that the anchor is paired with.
    candidates: list[data.datasets.Datum]


class RandomPairTrainBatch(NamedTuple):
    """Collated batch produced by ``RandomPairDataset.collate_function``."""
    # Collated anchors (one per dataset item in the batch).
    anchor: data.collate.TrainingBatch
    # Collated candidates of all anchors, flattened in anchor order.
    candidates: data.collate.TrainingBatch
    # Number of anchors in the batch.
    B: int
    # Number of candidates per anchor (taken from the first item of the batch).
    C: int
    

class RandomPairDataset(Dataset):
    """Pairs every molecule (the anchor) with randomly sampled candidate molecules.

    Candidates are drawn without replacement from the whole underlying
    dataset. Molecules whose target exceeds ``active_threshold`` are
    ``active_weight`` times more likely to be drawn than the others. The
    anchor itself is not excluded from the draw.

    Args:
        mol_dataset: Underlying molecule dataset; must expose ``Y``,
            ``__len__`` and ``__getitem__``. A single target per molecule is
            assumed.
        candidate_size: Number of candidates drawn per anchor. Must not
            exceed ``len(mol_dataset)`` because sampling is without
            replacement.
        active_threshold: Target value above which a molecule is up-weighted
            (default 50, the previously hardcoded value).
        active_weight: Relative sampling weight of up-weighted molecules
            (default 4.0, the previously hardcoded value).
    """

    def __init__(self, mol_dataset, candidate_size, active_threshold=50, active_weight=4.0):
        super().__init__()
        self.mol_dataset: data.datasets.MoleculeDataset = mol_dataset
        self.candidate_size: int = candidate_size
        self.active_threshold = active_threshold
        self.active_weight = active_weight
        # Filled on first use; the distribution does not depend on the anchor,
        # so there is no reason to rebuild it for every item.
        self._candidate_probs = None

    def __len__(self):
        return len(self.mol_dataset)

    def _sampling_probs(self):
        """Return the (cached) candidate sampling distribution over the dataset."""
        if self._candidate_probs is None:
            # atleast_1d keeps a one-element dataset from collapsing to a 0-d array.
            targets = np.atleast_1d(np.asarray(self.mol_dataset.Y).squeeze())
            weights = np.where(targets > self.active_threshold, self.active_weight, 1.0)
            self._candidate_probs = weights / weights.sum()
        return self._candidate_probs

    def get_random_candidates(self):
        """Draw ``candidate_size`` distinct items from the underlying dataset.

        Uses NumPy's global random state.

        Raises:
            ValueError: From ``np.random.choice`` when ``candidate_size`` is
                larger than the dataset.
        """
        probs = self._sampling_probs()
        candidate_idxs = np.random.choice(
            probs.shape[0], size=(self.candidate_size,), p=probs, replace=False
        )
        return [self.mol_dataset[idx] for idx in candidate_idxs]

    def __getitem__(self, idx) -> "RandomPairDataPoint":
        return RandomPairDataPoint(
            self.mol_dataset[idx],
            self.get_random_candidates(),
        )

    @staticmethod
    def collate_function(batch):
        """Collate ``RandomPairDataPoint`` items into a ``RandomPairTrainBatch``.

        Anchors are collated as one batch of size ``B``; the candidates of all
        anchors are flattened in anchor order and collated as one batch of
        size ``B * C``.
        """
        batch_anchors, batch_candidates = zip(*batch)
        B = len(batch)
        C = len(batch_candidates[0])
        batch_anchors = data.dataloader.collate_batch(batch_anchors)
        batch_candidates = data.dataloader.collate_batch(chain.from_iterable(batch_candidates))
        return RandomPairTrainBatch(batch_anchors, batch_candidates, B, C)

In [6]:
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn`` that reseeds NumPy and ``random``.

    The seed is torch's initial seed in the current process reduced modulo
    2**32. ``worker_id`` is accepted to satisfy the callback signature but is
    not used.
    """
    derived_seed = torch.initial_seed() % (2**32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

class RandomPairDataModule(L.LightningDataModule):
    """Lightning data module serving ``RandomPairDataset`` batches for train and val.

    Args:
        mol_ds_train: Molecule dataset used for training.
        mol_ds_val: Molecule dataset used for validation.
        batch_size: Number of anchors per batch (default 32).
        candidate_size: Number of candidates sampled per anchor (default 16).
        num_workers: DataLoader worker processes per loader (default 8).

    NOTE(review): validation candidates are re-sampled randomly on every
    pass, so ``val_loss`` carries sampling noise from epoch to epoch.
    """

    def __init__(
        self,
        mol_ds_train,
        mol_ds_val,
        batch_size: int = 32,
        candidate_size: int = 16,
        num_workers: int = 8,
    ) -> None:
        super().__init__()
        self.mol_ds_train: data.MoleculeDataset = mol_ds_train
        self.mol_ds_val: data.MoleculeDataset = mol_ds_val
        self.batch_size = batch_size
        self.candidate_size = candidate_size
        self.num_workers = num_workers

        self.ds_train = None
        self.ds_val = None

    def setup(self, stage=None):
        """Wrap both molecule datasets in ``RandomPairDataset``; ``stage`` is ignored."""
        self.ds_train = RandomPairDataset(self.mol_ds_train, self.candidate_size)
        self.ds_val = RandomPairDataset(self.mol_ds_val, self.candidate_size)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the settings shared by train and validation."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=RandomPairDataset.collate_function,
            # Applied to both loaders so worker seeding is handled the same
            # way for training and validation.
            worker_init_fn=seed_worker,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        assert self.ds_train is not None
        return self._make_loader(self.ds_train, shuffle=True)

    def val_dataloader(self):
        assert self.ds_val is not None
        return self._make_loader(self.ds_val, shuffle=False)

In [7]:
from typing import Any, Iterable
from chemprop.nn import Aggregation, ChempropMetric, MessagePassing, Predictor
from chemprop.nn.transforms import ScaleTransform
import pytorch_lightning as pl


class ContrastiveMPNN(models.MPNN):
    """Chemprop MPNN trained to regress pairwise target differences.

    Molecules are embedded with the inherited ``encoding`` method. A pair
    ``(head, tail)`` is scored with the bilinear form ``head @ R @ tail^T``
    and trained with MSE against ``target(head) - target(tail)``, in both
    directions (anchor -> candidate and candidate -> anchor). The inherited
    predictor is not used by this loss.

    The constructor arguments are forwarded unchanged to ``models.MPNN``.
    """

    def __init__(
        self,
        message_passing: MessagePassing,
        agg: Aggregation,
        predictor: Predictor,
        batch_norm: bool = False,
        metrics: Iterable[ChempropMetric] | None = None,
        warmup_epochs: int = 2,
        init_lr: float = 0.0001,
        max_lr: float = 0.001,
        final_lr: float = 0.0001,
        X_d_transform: ScaleTransform | None = None,
    ):
        super().__init__(
            message_passing,
            agg,
            predictor,
            batch_norm,
            metrics,
            warmup_epochs,
            init_lr,
            max_lr,
            final_lr,
            X_d_transform,
        )

        # Bilinear interaction matrix. Its size follows the message-passing
        # output width instead of a hardcoded 300 (the default width, so
        # existing checkpoints still load).
        # NOTE(review): assumes no extra descriptors (X_d) are concatenated to
        # the encoding, which would widen it — confirm if X_d is ever used.
        emb_dim = message_passing.output_dim
        self.R = torch.nn.Linear(emb_dim, emb_dim, bias=False)
        self.loss_fn = torch.nn.MSELoss()

    def clf(self, head_emb, tail_emb):
        """Score head/tail embedding pairs with the bilinear form ``head @ R @ tail^T``.

        Args:
            head_emb: Embeddings of shape ``(..., n, d)``.
            tail_emb: Embeddings of shape ``(..., m, d)``.

        Returns:
            Pairwise scores of shape ``(..., n, m)`` with all size-1
            dimensions squeezed away.
        """
        # ``R.weight`` is used directly as the (d, d) matrix; matmul broadcasts
        # it over any leading batch dimensions.
        z = (head_emb @ self.R.weight) @ tail_emb.transpose(-2, -1)
        return z.squeeze()

    def embed_simple_batch(self, batch: data.collate.TrainingBatch):
        """Encode a plain collated batch.

        Returns:
            Dict with ``embeds`` (the encodings) and ``targets`` (the batch targets).
        """
        bmg, V_d, X_d, target, _, _, _ = batch
        Z = self.encoding(bmg, V_d, X_d)
        return dict(embeds=Z, targets=target)

    def get_loss(self, batch: RandomPairTrainBatch, stage: str):
        """Compute and log the two directional delta-regression losses.

        Args:
            batch: Collated anchors (``B`` items) and candidates (``B * C`` items).
            stage: Prefix for the logged metric names (``'train'`` or ``'val'``).

        Returns:
            Sum of the anchor->candidate and candidate->anchor MSE losses.
        """
        B, C = batch.B, batch.C

        bmg, V_d, X_d, target_anchor, _, _, _ = batch.anchor
        Z_anchor = self.encoding(bmg, V_d, X_d)

        bmg, V_d, X_d, target_candidates, _, _, _ = batch.candidates
        Z_candidates = self.encoding(bmg, V_d, X_d)

        Z_anchor = Z_anchor.view(B, 1, -1)              # (B, d)   -> (B, 1, d)
        Z_candidates = Z_candidates.view(B, C, -1)      # (B*C, d) -> (B, C, d)

        # target(anchor) - target(candidate) for every pair, shape (B, C).
        deltas = target_anchor.view(-1, 1) - target_candidates.view(B, C)  # type: ignore

        # The explicit reshape undoes ``clf``'s squeeze, so B == 1 or C == 1
        # cannot drop a dimension and make MSELoss broadcast silently.
        # The logged losses are detached by Lightning (no ``enable_graph``),
        # so logging does not keep autograd graphs alive.

        # left to right loss
        preds = self.clf(Z_anchor, Z_candidates).reshape(B, C)
        lr_batch_loss = self.loss_fn(preds, deltas)
        self.log(f"{stage}_lr_batch_loss", lr_batch_loss, batch_size=batch.B, on_epoch=True)

        # right to left loss
        preds = self.clf(Z_candidates, Z_anchor).reshape(B, C)
        rl_batch_loss = self.loss_fn(preds, -deltas)
        self.log(f"{stage}_rl_batch_loss", rl_batch_loss, batch_size=batch.B, on_epoch=True)

        return lr_batch_loss + rl_batch_loss

    def training_step(self, batch: RandomPairTrainBatch, batch_idx):  # type: ignore
        loss = self.get_loss(batch, 'train')
        self.log("train_loss", loss, batch_size=batch.B, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch: RandomPairTrainBatch, batch_idx):  # type: ignore
        loss = self.get_loss(batch, 'val')
        self.log("val_loss", loss, batch_size=batch.B)
        return loss

In [8]:
# Dimensions of the featurizer, given as (atom_dims, bond_dims).
fdims = featurizers.SimpleMoleculeMolGraphFeaturizer().shape

# Model components, instantiated in the same order as before.
mp = chem_nn.BondMessagePassing()
agg = chem_nn.NormAggregation()
ffn = chem_nn.BinaryClassificationFFN(n_tasks=1)
batch_norm = True
metric_list = [
    chem_nn.metrics.BinaryF1Score(),
    chem_nn.metrics.BinaryAUPRC(),
    chem_nn.metrics.BinaryAUROC(),
]

contrastive_mpnn = ContrastiveMPNN(
    message_passing=mp,
    agg=agg,
    predictor=ffn,
    batch_norm=batch_norm,
    metrics=metric_list,
)

In [9]:
# Close any W&B run left open by an earlier execution, then start a fresh logger.
wandb.finish()
wandb_logger = WandbLogger(project="chemprop_delta_reg", log_model="all", save_code=True)
wandb_logger.experiment.mark_preempting()

# Both callbacks watch the validation loss.
early_stopping = EarlyStopping(monitor="val_loss", mode="min", verbose=True, patience=5)
checkpointing = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=2)
datamodule = RandomPairDataModule(train_mol_dataset, val_mol_dataset)

trainer = L.Trainer(
    logger=wandb_logger,
    enable_checkpointing=True,  # model checkpoints are saved during training
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=50,  # upper bound; early stopping may end training sooner
    # reload_dataloaders_every_n_epochs=1,
    log_every_n_steps=50,
    callbacks=[early_stopping, checkpointing],
)

trainer.fit(contrastive_mpnn, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                    | Params | Mode 
--------------------------------------------------------------------
0 | message_passing | BondMessagePassing      | 227 K  | train
1 | agg             | NormAggregation         | 0      | train
2 | bn              | BatchNorm1d             | 600    | train
3 | predictor       | BinaryClassificationFFN | 90.6 K | train
4 | X_d_transform   | Identity                | 0      | train
5 | metrics         | ModuleList              | 0      | train
6 | R               | Linear                  | 90.0 K | train
7 | loss_fn         | MSELoss                 | 0      | train
--------------------------------------------------------------------
408 K     Trainable params
0         Non-trainable params
408 K     T

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 2892.801


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 323.244 >= min_delta = 0.0. New best score: 2569.557


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 35.106 >= min_delta = 0.0. New best score: 2534.451


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 118.168 >= min_delta = 0.0. New best score: 2416.283


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 20.165 >= min_delta = 0.0. New best score: 2396.118


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 35.587 >= min_delta = 0.0. New best score: 2360.531


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 33.005 >= min_delta = 0.0. New best score: 2327.526


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 2.545 >= min_delta = 0.0. New best score: 2324.981


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val_loss did not improve in the last 5 records. Best score: 2324.981. Signaling Trainer to stop.


In [11]:
from pathlib import Path

# Build the artifact reference from the live run instead of a hardcoded
# entity/project, so the cell also works from another W&B account or project.
run = wandb_logger.experiment
run_id = run.id
checkpoint_reference = f"{run.entity}/{run.project}/model-{run_id}:best"
artifact_dir = wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")


# NOTE: ``weights_only=False`` unpickles arbitrary objects. That is acceptable
# only because the checkpoint is an artifact this notebook's own run uploaded;
# do not point this at untrusted files.
ckpt = torch.load(Path(artifact_dir) / "model.ckpt", map_location='cpu', weights_only=False)
hparams = ckpt.get('hyper_parameters', ckpt.get('hparams', {}))
contrastive_mpnn.load_state_dict(ckpt['state_dict'])

# trainer = L.Trainer(
#     enable_progress_bar=True,
#     accelerator="auto",
#     devices=1,
# )

[34m[1mwandb[0m:   1 of 1 files downloaded.  


<All keys matched successfully>

In [12]:
# Put the model in evaluation mode before embedding; ``eval()`` returns the
# module itself, so the name keeps pointing at the same object.
contrastive_mpnn = contrastive_mpnn.eval()

In [13]:
from tqdm.auto import tqdm

@torch.no_grad()
def embed_all(mol_dataset: data.datasets.MoleculeDataset, contrastive_mpnn):
    """Embed every molecule of ``mol_dataset`` with ``contrastive_mpnn``.

    The dataset is traversed in order in batches of 64. Each batch is moved
    to the model's device and encoded with ``embed_simple_batch``.

    Args:
        mol_dataset: Molecule dataset to embed.
        contrastive_mpnn: Model exposing ``device`` and ``embed_simple_batch``.

    Returns:
        All batch embeddings concatenated along the first dimension, in
        dataset order.
    """
    loader = DataLoader(
        mol_dataset,
        batch_size=64,
        shuffle=False,
        collate_fn=data.dataloader.collate_batch,
    )
    device = contrastive_mpnn.device
    embed_chunks = [
        contrastive_mpnn.embed_simple_batch(move_data_to_device(batch, device))['embeds']
        for batch in tqdm(loader, total=len(loader))
    ]
    return torch.cat(embed_chunks)

In [14]:
train_embeds = embed_all(train_mol_dataset, contrastive_mpnn)
test_embeds = embed_all(test_mol_dataset, contrastive_mpnn)

  0%|          | 0/162 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

In [15]:
# Exemplars are the training molecules whose target exceeds 50.
# Targets are flattened first (single-task data assumed) and ``flatnonzero``
# yields a 1-D index, so the results stay (K, d) / (K,) even when only one
# exemplar exists or ``Y`` is stored as (N, 1). The previous
# ``argwhere(...)`` + ``squeeze()`` collapsed the exemplar axis in those cases.
# Note that ``exemplar_idxs`` is now 1-D (K,) rather than (K, 1).
train_targets = np.asarray(train_mol_dataset.Y).reshape(-1)
exemplar_idxs = np.flatnonzero(train_targets > 50)
exemplar_embeds = train_embeds[exemplar_idxs]
exemplar_targets = train_targets[exemplar_idxs]

In [16]:
# Score every test embedding against every exemplar embedding; gradients are
# not needed here.
with torch.no_grad():
    pairwise_scores = contrastive_mpnn.clf(test_embeds, exemplar_embeds)
    all_preds = pairwise_scores.detach().numpy().squeeze()

In [19]:
# Mean pairwise score of each test molecule, averaged over the last
# (exemplar) axis of ``clf(test_embeds, exemplar_embeds)``.
all_preds.mean(axis=-1)

array([ -1.1050335, -33.932346 , -22.105888 , ..., -57.95315  ,
       -48.40741  , -51.42194  ], shape=(1282,), dtype=float32)

In [21]:
# Test-set targets, displayed for comparison with the mean scores above.
test_mol_dataset.Y

array([100., 100., 100., ..., -19., -21., -23.], shape=(1282,))

In [None]:
# all_preds = []
# all_true = []


# with torch.no_grad():
#     for idx in range(test_embeds.shape[0]):
#         Z_anchor = test_embeds[idx].view(1, -1).expand_as(exemplar_embeds)
#         # comb = torch.cat([Z_anchor, exemplar_embeds], dim=-1)
#         preds = contrastive_mpnn.clf(Z_anchor, exemplar_embeds).mean()

#         all_preds.append(preds)
#         all_true.append(test_mol_dataset.Y[idx] > 50)

# all_preds = np.array(all_preds)
# all_true = np.array(all_true)

In [None]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, average_precision_score, roc_auc_score, balanced_accuracy_score, recall_score

# This cell used to fail on a fresh kernel: ``all_true`` was only ever defined
# in commented-out code, and ``all_preds`` is a (n_test, n_exemplars) matrix
# of predicted target differences, not a per-molecule probability that can be
# thresholded at 0.5.
#
# ``clf(test, exemplar)`` is trained to predict target(test) - target(exemplar),
# so adding each exemplar's target back and averaging over exemplars gives a
# per-molecule activity estimate on the original target scale.
# NOTE(review): confirm this read-out is the intended scoring rule.
ACTIVITY_THRESHOLD = 50  # same cut-off used to select the exemplars

all_true = np.asarray(test_mol_dataset.Y).reshape(-1) > ACTIVITY_THRESHOLD
pairwise_deltas = np.asarray(all_preds).reshape(all_true.shape[0], -1)
pred_activity = (pairwise_deltas + np.asarray(exemplar_targets).reshape(1, -1)).mean(axis=-1)
pred_labels = pred_activity > ACTIVITY_THRESHOLD

wandb_logger.log_table(
    'final_metrics',
    ['accuracy', 'balanced_accuracy', 'f1', 'precision', 'recall', 'AUCROC', 'PRAUC'],
    [[
        accuracy_score(all_true, pred_labels),
        balanced_accuracy_score(all_true, pred_labels),
        f1_score(all_true, pred_labels),
        precision_score(all_true, pred_labels),
        recall_score(all_true, pred_labels),
        roc_auc_score(all_true, pred_activity),
        average_precision_score(all_true, pred_activity)
    ]]
)

In [None]:
# Close the active W&B run now that the final metrics table has been logged.
wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
train_loss_epoch,█▇▆▆▅▅▄▄▃▃▃▂▂▂▂▁▁▁
train_loss_step,▇█▆▇█▆▆▇▆▆▅▆▃▃▃▄▄▃▁▂▃▃▄▃▄▂▃▃▁▂▄▁▂▄▃▄▂▁▂▁
train_lr_batch_loss_epoch,█▇▆▆▅▅▄▄▃▃▃▂▂▂▂▁▁▁
train_lr_batch_loss_step,█▆▅▆▆▅▆▆▅▆▅▄▇▆▅▅▄▃▅▅▄▃▅▄▄▄▃▄▃▂▅▃▄▁▂▁▂▂▂▃
train_rl_batch_loss_epoch,█▇▆▆▅▅▄▄▃▃▃▂▂▂▂▁▁▁
train_rl_batch_loss_step,█▇▆▆▇▆▆▇▅▄▅▄▅▄▅▅▃▄▃▄▄▄▄▃▄▃▄▅▄▂▅▃▂▂▃▁▂▂▃▃
trainer/global_step,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇█
val_loss,██▅▃▃▃▂▃▃▁▂▁▁▃▃▃▄▅
val_lr_batch_loss,▆█▅▃▃▂▂▃▃▁▂▁▁▃▃▃▄▅

0,1
epoch,17.0
train_loss_epoch,0.83079
train_loss_step,0.96847
train_lr_batch_loss_epoch,0.4143
train_lr_batch_loss_step,0.48113
train_rl_batch_loss_epoch,0.4165
train_rl_batch_loss_step,0.48733
trainer/global_step,5831.0
val_loss,1.13031
val_lr_batch_loss,0.56294


In [None]:
from tqdm.auto import tqdm

# NOTE(review): this re-defines ``embed_all`` with the same behavior as the
# definition in the earlier cell; the later definition shadows the earlier one.
@torch.no_grad()
def embed_all(mol_dataset: data.datasets.MoleculeDataset, contrastive_mpnn):
    """Embed every molecule of ``mol_dataset`` with ``contrastive_mpnn``.

    The dataset is traversed in order in batches of 64. Each batch is moved
    to the model's device and encoded with ``embed_simple_batch``.

    Args:
        mol_dataset: Molecule dataset to embed.
        contrastive_mpnn: Model exposing ``device`` and ``embed_simple_batch``.

    Returns:
        All batch embeddings concatenated along the first dimension, in
        dataset order.
    """
    loader = DataLoader(
        mol_dataset,
        batch_size=64,
        shuffle=False,
        collate_fn=data.dataloader.collate_batch,
    )
    device = contrastive_mpnn.device
    embed_chunks = [
        contrastive_mpnn.embed_simple_batch(move_data_to_device(batch, device))['embeds']
        for batch in tqdm(loader, total=len(loader))
    ]
    return torch.cat(embed_chunks)

@torch.no_grad()
def mine_hard_and_rand_negatives_for_anchor(anchor_idx, all_embeds: torch.Tensor, all_targets: torch.Tensor, clf: torch.nn.Linear):
    anchor = all_embeds[anchor_idx]
    anchor = anchor.view(1, -1).expand_as(all_embeds)
    logits = clf(torch.cat([anchor, all_embeds], dim=-1)).squeeze()
    anchor_target = all_targets[anchor_idx]

    negative_mask = torch.logical_xor(logits > 0.5, anchor_target > all_targets)
    logits[~negative_mask] = float("-inf")

    _, hard_neg_idxs = torch.topk(logits, 3, largest=True)

    negative_mask[hard_neg_idxs] = False

    rand_neg_slection_prob = torch.where(negative_mask, 1.0, float("-inf")).softmax(dim=-1)
    rand_neg_idxs = torch.multinomial(rand_neg_slection_prob, 4)
