In [1]:
import sys

sys.path.append('../')

import os
import random

import lightning as L
import numpy as np
from pathlib import Path
import torch
from chemprop import data, featurizers, models, nn
from data import ConstrastiveDataModule, ExemplarDataset
from dotenv import load_dotenv
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader

import wandb
from commons.data import load_and_split_gsk_dataset

RANDOM_SEED = 42

def set_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

set_seeds(RANDOM_SEED)

load_dotenv('.env.secret')
wandb.login(key='cf344975eb80edf6f0d52af80528cc6094234caf')

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/rahul_e_dev/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrahul-e-dev[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
df_train, df_val, df_test = load_and_split_gsk_dataset("../GSK_HepG2.csv", RANDOM_SEED)

In [3]:
from chemprop.conf import DEFAULT_HIDDEN_DIM
from chemprop.nn.ffn import MLP
from chemprop.nn.metrics import BCELoss, BinaryAUPRC, ChempropMetric
from chemprop.nn.predictors import Predictor, PredictorRegistry
from chemprop.nn.transforms import UnscaleTransform
from chemprop.utils import Factory
from lightning.pytorch.core.mixins import HyperparametersMixin
from torch import Tensor


@PredictorRegistry.register("ranking")
class RankNetPredictor(Predictor, HyperparametersMixin):

    n_targets = 1
    _T_default_criterion = BCELoss
    _T_default_metric = BinaryAUPRC

    def __init__(
        self,
        n_tasks: int = 1,
        input_dim: int = DEFAULT_HIDDEN_DIM,
        hidden_dim: int = 300,
        n_layers: int = 1,
        dropout: float = 0.0,
        activation: str | torch.nn.Module = "relu",
        criterion: ChempropMetric | None = None,
        task_weights: Tensor | None = None,
        threshold: float | None = None,
        output_transform: UnscaleTransform | None = None,
    ):
        super().__init__()
        # manually add criterion and output_transform to hparams to suppress lightning's warning
        # about double saving their state_dict values.
        ignore_list = ["criterion", "output_transform", "activation"]
        self.save_hyperparameters(ignore=ignore_list)
        self.hparams["criterion"] = criterion
        self.hparams["output_transform"] = output_transform
        self.hparams["activation"] = activation
        self.hparams["cls"] = self.__class__

        self.ffn = MLP.build(
            input_dim, n_tasks * self.n_targets, hidden_dim, n_layers, dropout, activation
        )
        task_weights = torch.ones(n_tasks) if task_weights is None else task_weights
        self.criterion = criterion or Factory.build(
            self._T_default_criterion, task_weights=task_weights, threshold=threshold
        )
        self.output_transform = output_transform if output_transform is not None else torch.nn.Identity()

    @property
    def input_dim(self) -> int:
        return self.ffn.input_dim

    @property
    def output_dim(self) -> int:
        return self.ffn.output_dim

    @property
    def n_tasks(self) -> int:
        return self.output_dim // self.n_targets

    def forward(self, Z: Tensor) -> Tensor:
        # print(Z.shape)
        A, B = torch.split(Z, self.input_dim, dim=-1)
        logit_A = self.ffn(A)
        logit_B = self.ffn(B)
        return (logit_A - logit_B).sigmoid()

    def encode(self, Z: Tensor, i: int) -> Tensor:
        A, B = torch.split(Z, self.input_dim, dim=-1)
        enc_A = self.ffn[:i](A)
        enc_B = self.ffn[:i](B)
        return torch.cat([enc_A, enc_B], dim=-1)
    
    def train_step(self, Z: Tensor) -> Tensor:
        # print(Z.shape)
        A, B = torch.split(Z, self.input_dim, dim=-1)
        logit_A = self.ffn(A)
        logit_B = self.ffn(B)
        return logit_A - logit_B

In [4]:
fdims = featurizers.SimpleMoleculeMolGraphFeaturizer().shape # the dimensions of the featurizer, given as (atom_dims, bond_dims).
mcmp = nn.MulticomponentMessagePassing(
    blocks=[nn.BondMessagePassing(*fdims)],
    n_components=2,
    shared=True
)
agg = nn.NormAggregation()
ffn = RankNetPredictor(n_tasks=1, input_dim=mcmp.output_dim // 2)
batch_norm = True
metric_list = [nn.metrics.BinaryF1Score(), nn.metrics.BinaryAUPRC(), nn.metrics.BinaryAUROC()]
mpnn = models.multi.MulticomponentMPNN(mcmp, agg, ffn, batch_norm, metric_list)
mpnn.max_lr = 0.01

In [None]:
wandb.finish()
wandb_logger = WandbLogger(project="chemprop_delta_clf", log_model="all", save_code=True)
wandb_logger.experiment.mark_preempting()

trainer = L.Trainer(
    logger=wandb_logger,
    enable_checkpointing=True,  # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=50,  # number of epochs to train for
    reload_dataloaders_every_n_epochs=1,
    log_every_n_steps=50,
    callbacks=[
        EarlyStopping(monitor="val_loss", mode="min", verbose=True, patience=10),
        ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=2)
    ]
)

contrastive_data_module = ConstrastiveDataModule(df_train, df_val)
trainer.fit(mpnn, datamodule=contrastive_data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA L4') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                         | Params | Mode 
-------------------------------------------------------------------------
0 | message_passing | MulticomponentMessagePassing | 227 K  | train
1 | agg             | NormAggregation              | 0      | train
2 | bn              | BatchNorm1d                  | 1.2 K  | train
3 | predictor       | RankNetPredictor             | 90.6 K | train
4 | X_d_tr

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 0.558


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
run_id = wandb_logger.experiment.id
checkpoint_reference = f"rahul-e-dev/chemprop_delta_clf/model-{run_id}:best"
artifact_dir = wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")


ckpt = torch.load(Path(artifact_dir) / "model.ckpt", map_location='cpu', weights_only=False)
hparams = ckpt.get('hyper_parameters', ckpt.get('hparams', {}))
mpnn.load_state_dict(ckpt['state_dict'])

trainer = L.Trainer(
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
)

[34m[1mwandb[0m:   1 of 1 files downloaded.  
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
exemplar_df = df_train[df_train['per_inhibition'] >= 50].sample(100).reset_index(drop=True)

exemplar_ds = ExemplarDataset(
    df_test,
    exemplar_df
)

exemplar_dl = DataLoader(
    dataset=exemplar_ds,
    batch_size=2048,
    shuffle=False,
    collate_fn=data.dataloader.collate_multicomponent,
    num_workers=12,
)

test_ds_preds = trainer.predict(model=mpnn, dataloaders=exemplar_dl)
test_ds_preds = torch.cat(test_ds_preds)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/rahul_e_dev/delta/.venv/lib/python3.12/site-packages/lightning/pytorch/core/saving.py:363: Skipping 'metrics' parameter because it is not possible to safely dump to YAML.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [None]:
from collections import defaultdict


def calc(x):
    x = np.array(x)
    return (x>=0.5).sum()


deltas = defaultdict(list)
for (i, j), delta in zip(exemplar_ds.pairs, test_ds_preds.squeeze()):
    deltas[i].append(float(delta.item()))


df_test['deltas'] = deltas
df_test['pred_probs'] = df_test['deltas'].map(calc)
df_test['means'] = df_test['deltas'].map(np.mean)
df_test['std'] = df_test['deltas'].map(np.std)
df_test['range'] = df_test['deltas'].map(lambda x: max(x) - min(x))
df_test['preds'] = df_test['pred_probs'] >= 10
df_test['true'] = df_test['per_inhibition'] >= 50

In [None]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

wandb_logger.log_table(
    'final_metrics', 
    ['f1', 'precision', 'recall', 'accuracy'],
    [[
        f1_score(df_test['true'], df_test['preds']),
        precision_score(df_test['true'], df_test['preds']),
        recall_score(df_test['true'], df_test['preds']),
        accuracy_score(df_test['true'], df_test['preds'])
    ]]
)

In [None]:
wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▂▂▂▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇██
train_loss_epoch,█▆▅▄▃▂▂▂▁▁▁
train_loss_step,██▇▆▆▆▅▅▅▅▄▃▄▃▃▃▃▃▂▃▃▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇█████
val/f1,▂▁▆▆▃▂██▆▇█
val/prc,▄▁▄▆▃▁▆▆▅▅█
val/roc,▄▁▅▆▃▁▇▇▅▆█
val_loss,▁▁▂▃▄▅▅▇▇██

0,1
epoch,10.0
train_loss_epoch,0.12198
train_loss_step,0.11191
trainer/global_step,2232.0
val/f1,0.7388
val/prc,0.81523
val/roc,0.8142
val_loss,1.2989


In [None]:
[[
        f1_score(df_test['true'], df_test['preds']),
        precision_score(df_test['true'], df_test['preds']),
        recall_score(df_test['true'], df_test['preds']),
        accuracy_score(df_test['true'], df_test['preds'])
    ]]

[[0.402088772845953, 0.27208480565371024, 0.77, 0.6427457098283932]]