In [1]:
import sys

sys.path.append('../')

import os
import random
import pandas as pd
import lightning as L
import numpy as np
import torch
from chemprop import data, featurizers, models
from chemprop import nn as chem_nn
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from pytorch_lightning.utilities import move_data_to_device
import pandas as pd
import rdkit.Chem as Chem
from rdkit.rdBase import BlockLogs
from sklearn.model_selection import GroupShuffleSplit
from commons.utils import get_scaffold, standardize
from typing import NamedTuple

import wandb
# from commons.data import load_and_split_gsk_dataset

RANDOM_SEED = 42

def set_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seeds(RANDOM_SEED)

# load_dotenv('.env.secret')
wandb.login(key='cf344975eb80edf6f0d52af80528cc6094234caf')

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/rahul_e_dev/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrahul-e-dev[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
def mol_to_inchi(mol):
    with BlockLogs():
        return Chem.MolToInchi(mol)


def load_and_split_gsk_dataset(path, RANDOM_SEED):
    df = pd.read_csv(path)
    df = df.iloc[:, 1:]
    df.columns = ["smiles", "per_inhibition"]

    # standardize and convert to inchi
    df["mol"] = df["smiles"].map(standardize)
    df = df.dropna(subset=["mol"])
    df["inchi"] = df["mol"].map(mol_to_inchi)
    df = df.groupby(["inchi"]).filter(lambda x: len(x) == 1).reset_index(drop=True)

    clusters, _ = pd.factorize(
        df["mol"]
        .map(Chem.MolToSmiles)  # type: ignore
        .map(get_scaffold)
    )
    clusters = pd.Series(clusters)

    df = df.drop(["smiles", "inchi"], axis=1)

    splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED)
    train_idxs, val_test_idxs = next(splitter.split(df, groups=clusters))
    df_train = df.loc[train_idxs].reset_index(drop=True)
    df_val_test = df.loc[val_test_idxs].reset_index(drop=True)
    clusters_val_test = clusters.iloc[val_test_idxs].reset_index(drop=True)

    splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED, test_size=0.5)
    val_idxs, test_idxs = next(splitter.split(df_val_test, groups=clusters_val_test))
    df_val = df_val_test.loc[val_idxs].reset_index(drop=True)
    df_test = df_val_test.loc[test_idxs].reset_index(drop=True)

    return df_train, df_val, df_test

In [None]:
df_train, df_val, df_test = load_and_split_gsk_dataset("../GSK_HepG2.csv", RANDOM_SEED)

In [None]:
def mol_to_molecule_datapoint(x):
    return data.MoleculeDatapoint(x['mol'], x['per_inhibition'])

featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
train_mol_dataset = data.MoleculeDataset(df_train.apply(mol_to_molecule_datapoint, axis=1), featurizer=featurizer)
val_mol_dataset = data.MoleculeDataset(df_val.apply(mol_to_molecule_datapoint, axis=1), featurizer=featurizer)
test_mol_dataset = data.MoleculeDataset(df_test.apply(mol_to_molecule_datapoint, axis=1), featurizer=featurizer)

train_mol_dataset.cache = True
val_mol_dataset.cache = True
test_mol_dataset.cache = True

In [None]:
class RandomPairDataPoint(NamedTuple):
    left: data.datasets.Datum
    right: data.datasets.Datum


class RandomPairTrainBatch(NamedTuple):
    left: data.collate.TrainingBatch
    right: data.collate.TrainingBatch

    def batch_size(self):
        return self.left.w.shape[0]
    

class RandomPairDataset(Dataset):
    def __init__(self, mol_dataset):
        super().__init__()
        self.mol_dataset = mol_dataset
        self.pairs: list = []

        self.update_pairs()

    def update_pairs(self):
        N = len(self.mol_dataset)

        pairs = [
            (i, random.randint(0, N-1))
            for i in range(N)
        ]
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx) -> RandomPairDataPoint:
        left_idx, right_idx = self.pairs[idx]
        return RandomPairDataPoint(
            self.mol_dataset[left_idx], self.mol_dataset[right_idx]
        )

    @staticmethod
    def collate_function(batch):
        batch_left, batch_right = zip(*batch)
        batch_left = data.dataloader.collate_batch(batch_left)
        batch_right = data.dataloader.collate_batch(batch_right)
        return RandomPairTrainBatch(batch_left, batch_right)

In [None]:
class RandomPairDataModule(L.LightningDataModule):
    def __init__(self, mol_ds_train, mol_ds_val) -> None:
        super().__init__()
        self.mol_ds_train: data.MoleculeDataset = mol_ds_train
        self.mol_ds_val: data.MoleculeDataset = mol_ds_val
        self.batch_size=1024

        self.ds_train = None
        self.ds_val = None

    def setup(self, stage=None):
        self.ds_train = RandomPairDataset(self.mol_ds_train)
        self.ds_val = RandomPairDataset(self.mol_ds_val)

    def train_dataloader(self):
        assert self.ds_train is not None
        return DataLoader(
            self.ds_train,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=RandomPairDataset.collate_function,
            num_workers=12,
        )

    def val_dataloader(self):
        assert self.ds_val is not None
        return DataLoader(
            self.ds_val,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=RandomPairDataset.collate_function,
            num_workers=12,
        )

In [None]:
from typing import Any, Iterable
from chemprop.nn import Aggregation, ChempropMetric, MessagePassing, Predictor
from chemprop.nn.transforms import ScaleTransform
import pytorch_lightning as pl


class ContrastiveMPNN(models.MPNN):
    def __init__(
        self,
        message_passing: MessagePassing,
        agg: Aggregation,
        predictor: Predictor,
        batch_norm: bool = False,
        metrics: Iterable[ChempropMetric] | None = None,
        warmup_epochs: int = 2,
        init_lr: float = 0.0001,
        max_lr: float = 0.001,
        final_lr: float = 0.0001,
        X_d_transform: ScaleTransform | None = None,
    ):
        super().__init__(
            message_passing,
            agg,
            predictor,
            batch_norm,
            metrics,
            warmup_epochs,
            init_lr,
            max_lr,
            final_lr,
            X_d_transform,
        )
        
        self.clf = torch.nn.Sequential(
            torch.nn.Linear(600, 300),
            torch.nn.ReLU(),
            torch.nn.Linear(300, 1),
            torch.nn.Sigmoid()
        )
        self.loss_fn = torch.nn.BCELoss()


    def embed_simple_batch(self, batch: data.collate.TrainingBatch):
        bmg, V_d, X_d, target, _, _, _ = batch
        Z_anchor = self.encoding(bmg, V_d, X_d)
        return dict(embeds=Z_anchor, targets=target)
    

    def get_loss(self, batch: RandomPairTrainBatch):
        bmg, V_d, X_d, target_left, _, _, _ = batch.left
        Z_left = self.encoding(bmg, V_d, X_d)

        bmg, V_d, X_d, target_right, _, _, _ = batch.right
        Z_right = self.encoding(bmg, V_d, X_d)

        Z_combined = torch.cat([Z_left, Z_right], axis=-1) # type: ignore
        logits = self.clf(Z_combined).squeeze()

        mask = target_left > target_right # type: ignore
        labels = torch.where(mask, 1.0, 0.0)
        loss = self.loss_fn(logits, labels)

        inv_labels = torch.where(~mask, 1.0, 0.0)
        inv_loss = self.loss_fn(logits, inv_labels)
        return loss + inv_loss


    def training_step(self, batch: RandomPairTrainBatch, batch_idx):  # type: ignore
        loss = self.get_loss(batch)
        self.log("train_loss", loss, batch_size=batch.batch_size(), prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch: RandomPairTrainBatch, batch_idx):  # type: ignore
        loss = self.get_loss(batch)
        self.log("val_loss", loss, batch_size=batch.batch_size())
        return loss

In [None]:
fdims = featurizers.SimpleMoleculeMolGraphFeaturizer().shape # the dimensions of the featurizer, given as (atom_dims, bond_dims).
mp = chem_nn.BondMessagePassing()
agg = chem_nn.NormAggregation()
ffn = chem_nn.BinaryClassificationFFN(n_tasks=1)
batch_norm = True
metric_list = [chem_nn.metrics.BinaryF1Score(), chem_nn.metrics.BinaryAUPRC(), chem_nn.metrics.BinaryAUROC()]
contrastive_mpnn = ContrastiveMPNN(mp, agg, ffn, batch_norm, metric_list)

In [None]:
wandb.finish()
wandb_logger = WandbLogger(project="chemprop_delta_clf", log_model="all", save_code=True)
wandb_logger.experiment.mark_preempting()

trainer = L.Trainer(
    logger=wandb_logger,
    enable_checkpointing=True,  # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=50,  # number of epochs to train for
    reload_dataloaders_every_n_epochs=1,
    log_every_n_steps=50,
    callbacks=[
        EarlyStopping(monitor="val_loss", mode="min", verbose=True, patience=10),
        ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=2)
    ]
)

trainer.fit(contrastive_mpnn, datamodule=RandomPairDataModule(train_mol_dataset, val_mol_dataset))

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                    | Params | Mode 
--------------------------------------------------------------------
0 | message_passing | BondMessagePassing      | 227 K  | train
1 | agg             | NormAggregation         | 0      | train
2 | bn              | BatchNorm1d             | 600    | train
3 | predictor       | BinaryClassificationFFN | 90.6 K | train
4 | X_d_transform   | Identity                | 0      | train
5 | metrics         | ModuleList              | 0      | train
6 | clf             | Sequential              | 180 K  | train
7 | loss_fn         | BCELoss                 | 0      | train
--------------------------------------------------------------------
499 K     Trainable params
0         Non-trainable params
499 K     T

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 0.683


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.055 >= min_delta = 0.0. New best score: 0.629


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.022 >= min_delta = 0.0. New best score: 0.606


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.016 >= min_delta = 0.0. New best score: 0.590


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.581


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.574


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.569


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
wandb.finish()