In [1]:
import sys

sys.path.append("../")

import random
from itertools import chain
from typing import NamedTuple

import lightning as L
import numpy as np
import pandas as pd
import rdkit.Chem as Chem
import torch
import wandb
from chemprop import data, featurizers, models
from chemprop import nn as chem_nn
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from pytorch_lightning.utilities import move_data_to_device
from rdkit.Chem.Descriptors import CalcMolDescriptors
from rdkit.rdBase import BlockLogs
from sklearn.model_selection import GroupShuffleSplit
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

from commons.utils import get_scaffold, standardize

# from commons.data import load_and_split_gsk_dataset

RANDOM_SEED = 42


def set_seeds(seed):
    """Seed the random number generators this notebook relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global RNG, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed: Integer seed forwarded unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_seeds(RANDOM_SEED)

# Authenticate with Weights & Biases WITHOUT embedding credentials in the notebook.
# Called with no explicit key, wandb.login() resolves the API key from outside the
# notebook (the WANDB_API_KEY environment variable or an existing ~/.netrc entry)
# and otherwise prompts for it. To keep the key in a local, untracked file, load it
# into the environment first, e.g. load_dotenv('.env.secret').
# SECURITY: an API key used to be hardcoded on this line. Treat that key as leaked
# and revoke/rotate it in the W&B account settings.
wandb.login()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/rahul/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrahul-e-dev[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
def mol_to_inchi(mol):
    """Convert an RDKit molecule to its InChI string.

    The conversion runs inside a ``BlockLogs`` context so RDKit log messages
    emitted during the call are not shown.

    Args:
        mol: RDKit molecule passed straight to ``Chem.MolToInchi``.

    Returns:
        Whatever ``Chem.MolToInchi`` returns for ``mol``.
    """
    with BlockLogs():
        inchi = Chem.MolToInchi(mol)
    return inchi
    

def generate_features(df):
    """Append RDKit descriptor columns to ``df``.

    ``CalcMolDescriptors`` is evaluated for every entry of ``df["mol"]`` (with
    RDKit logging blocked), the resulting descriptor names are prefixed with
    ``feat_``, and the descriptor table is concatenated column-wise to ``df``.

    Args:
        df: DataFrame with a ``mol`` column of RDKit molecules.

    Returns:
        A new DataFrame: ``df`` with its index reset to 0..n-1, followed by one
        ``feat_<descriptor>`` column per descriptor.
    """
    with BlockLogs():
        descriptor_rows = [CalcMolDescriptors(mol) for mol in df["mol"]]

    descriptors = pd.DataFrame.from_records(descriptor_rows).add_prefix("feat_")

    # Reset the index so rows line up positionally with the descriptor table.
    return pd.concat([df.reset_index(drop=True), descriptors], axis=1)


def load_and_split_gsk_dataset(path, RANDOM_SEED):
    """Load the CSV at ``path`` and split it into train/val/test by scaffold group.

    Pipeline:
      1. Read the CSV, discard its first column and name the remaining columns
         ``smiles`` and ``per_inhibition``.
      2. Standardize each SMILES into a ``mol`` (rows whose ``mol`` is missing
         are dropped) and compute an InChI per molecule.
      3. Keep only rows whose InChI occurs exactly once; every member of a
         duplicated group is removed.
      4. Label ``is_cytotoxic`` as ``per_inhibition > 50.0`` and append the
         ``feat_*`` descriptor columns.
      5. Group rows by the scaffold of their molecule and split with
         ``GroupShuffleSplit`` so a group never straddles two splits: first
         into train vs. hold-out (splitter's default proportions), then the
         hold-out half-and-half into validation and test.

    Args:
        path: Location of the CSV file.
        RANDOM_SEED: ``random_state`` used for both group shuffle splits.

    Returns:
        ``(df_train, df_val, df_test)``, each with a fresh 0..n-1 index and
        without the ``smiles`` and ``inchi`` columns.
    """
    raw = pd.read_csv(path).iloc[:, 1:]
    raw.columns = ["smiles", "per_inhibition"]

    # Standardize, discard molecules that failed, and attach an InChI key.
    raw["mol"] = raw["smiles"].map(standardize)
    parsed = raw.dropna(subset=["mol"])
    parsed = parsed.assign(inchi=parsed["mol"].map(mol_to_inchi))

    # Retain only structures that appear exactly once in the dataset.
    inchi_counts = parsed.groupby("inchi")["inchi"].transform("size")
    unique_rows = parsed.loc[inchi_counts == 1].reset_index(drop=True)

    unique_rows["is_cytotoxic"] = unique_rows["per_inhibition"] > 50.0

    featurized = generate_features(unique_rows)

    # One integer group id per distinct scaffold.
    scaffolds = (
        featurized["mol"]
        .map(Chem.MolToSmiles)  # type: ignore
        .map(get_scaffold)
    )
    scaffold_ids = pd.Series(pd.factorize(scaffolds)[0])

    featurized = featurized.drop(columns=["smiles", "inchi"])

    # The splitters yield positional indices; the frames carry a 0..n-1 index here.
    outer_split = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED)
    train_pos, holdout_pos = next(outer_split.split(featurized, groups=scaffold_ids))
    df_train = featurized.iloc[train_pos].reset_index(drop=True)
    holdout = featurized.iloc[holdout_pos].reset_index(drop=True)
    holdout_ids = scaffold_ids.iloc[holdout_pos].reset_index(drop=True)

    inner_split = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED, test_size=0.5)
    val_pos, test_pos = next(inner_split.split(holdout, groups=holdout_ids))
    df_val = holdout.iloc[val_pos].reset_index(drop=True)
    df_test = holdout.iloc[test_pos].reset_index(drop=True)

    return df_train, df_val, df_test

In [3]:
# Load the GSK HepG2 CSV and build the scaffold-grouped train / validation / test
# frames, using the notebook-wide RANDOM_SEED for both group shuffle splits.
df_train, df_val, df_test = load_and_split_gsk_dataset("../GSK_HepG2.csv", RANDOM_SEED)

In [4]:
def get_molecule_datapoint(row):
    """Build a chemprop ``MoleculeDatapoint`` from one DataFrame row.

    Args:
        row: A row (pandas Series) holding ``mol``, ``is_cytotoxic`` and any
            number of entries whose label starts with ``feat``.

    Returns:
        A ``data.MoleculeDatapoint`` whose target ``y`` is a one-element array
        with the ``is_cytotoxic`` value and whose extra descriptors ``x_d`` are
        the ``feat*`` entries converted to numbers (unparseable values are
        coerced to NaN).
    """
    descriptor_names = [name for name in row.index if name.startswith("feat")]
    descriptors = pd.to_numeric(row[descriptor_names], errors="coerce").to_numpy()
    label = np.array([row["is_cytotoxic"]])
    return data.MoleculeDatapoint(mol=row["mol"], y=label, x_d=descriptors)

featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Build one dataset per split. `.tolist()` hands MoleculeDataset a plain list of
# datapoints rather than a pandas Series, so indexing inside the dataset is always
# positional.
train_mol_dataset = data.MoleculeDataset(
    df_train.apply(get_molecule_datapoint, axis=1).tolist(), featurizer=featurizer
)
val_mol_dataset = data.MoleculeDataset(
    df_val.apply(get_molecule_datapoint, axis=1).tolist(), featurizer=featurizer
)
test_mol_dataset = data.MoleculeDataset(
    df_test.apply(get_molecule_datapoint, axis=1).tolist(), featurizer=featurizer
)

# Fit the descriptor (X_d) scaler on the training split only and reuse it for
# validation, so no statistics leak from held-out data.
x_d_scaler = train_mol_dataset.normalize_inputs("X_d")
val_mol_dataset.normalize_inputs("X_d", x_d_scaler)
# The test dataset is deliberately NOT normalized here. The same scaler is given to
# the model as `X_d_transform`, which chemprop applies when the model runs in eval
# mode (test / predict) and skips during training and validation. Normalizing the
# test dataset as well would scale its descriptors twice at prediction time.

# Cache the featurized molecular graphs instead of recomputing them on every access.
train_mol_dataset.cache = True
val_mol_dataset.cache = True
test_mol_dataset.cache = True

In [None]:
# The training loader keeps build_dataloader's default shuffle setting; the
# validation and test loaders are built identically with shuffling disabled.
train_loader = data.build_dataloader(train_mol_dataset, num_workers=8)
val_loader, test_loader = (
    data.build_dataloader(eval_dataset, num_workers=8, shuffle=False)
    for eval_dataset in (val_mol_dataset, test_mol_dataset)
)

In [20]:
# Featurizer dimensions, given as (atom_dims, bond_dims).
fdims = featurizers.SimpleMoleculeMolGraphFeaturizer().shape

# Graph encoder: bond-level message passing followed by norm aggregation.
mp = chem_nn.BondMessagePassing()
agg = chem_nn.NormAggregation()

# The predictor input is the message-passing output concatenated with one value
# per `feat*` descriptor column.
ffn_dims = mp.output_dim + sum(
    column.startswith("feat") for column in df_train.columns
)
ffn = chem_nn.BinaryClassificationFFN(n_tasks=1, input_dim=ffn_dims)

batch_norm = True
metric_list = [
    chem_nn.metrics.BinaryF1Score(),
    chem_nn.metrics.BinaryAUPRC(),
    chem_nn.metrics.BinaryAUROC(),
]

# Give the fitted descriptor scaler to the model as its X_d transform.
X_d_transform = chem_nn.ScaleTransform.from_standard_scaler(x_d_scaler)

mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list, X_d_transform=X_d_transform)
# mpnn.max_lr = 0.01

In [None]:
# EarlyStopping, ModelCheckpoint and WandbLogger are already imported in the
# notebook's first cell; they are not re-imported here so that all dependencies
# stay declared in one place.

# Close any W&B run left open by a previous execution of this cell before
# starting a new one.
wandb.finish()
# log_model="all" uploads checkpoints to W&B as model artifacts during training;
# the evaluation cell later downloads the one aliased "best".
wandb_logger = WandbLogger(project="chemprop_baseline", log_model="all", save_code=True)
wandb_logger.experiment.mark_preempting()

trainer = L.Trainer(
    logger=wandb_logger,
    enable_checkpointing=True,  # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=50,  # number of epochs to train for
    log_every_n_steps=50,
    callbacks=[
        # Stop once val_loss has not improved for 10 consecutive validation checks.
        EarlyStopping(monitor="val_loss", mode="min", verbose=True, patience=10),
        # Keep the two checkpoints with the lowest val_loss.
        ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=2),
    ],
)

trainer.fit(mpnn, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                    | Params | Mode 
--------------------------------------------------------------------
0 | message_passing | BondMessagePassing      | 227 K  | train
1 | agg             | NormAggregation         | 0      | train
2 | bn              | BatchNorm1d             | 600    | train
3 | predictor       | BinaryClassificationFFN | 155 K  | train
4 | X_d_transform   | ScaleTransform          | 0      | train
5 | metrics         | ModuleList              | 0      | train
--------------------------------------------------------------------
384 K     Trainable params
0         Non-trainable params
384 K     Total params
1.536     Total estimated model params size (MB)
26        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/prc improved. New best score: 0.542


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/prc improved by 0.034 >= min_delta = 0.0. New best score: 0.576


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/prc improved by 0.071 >= min_delta = 0.0. New best score: 0.647


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/prc improved by 0.033 >= min_delta = 0.0. New best score: 0.680


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/prc improved by 0.009 >= min_delta = 0.0. New best score: 0.690


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val/prc did not improve in the last 10 records. Best score: 0.690. Signaling Trainer to stop.


In [22]:
from pathlib import Path

# Build the artifact reference from the active run instead of hardcoding a personal
# W&B entity/project, so the cell works for whichever account and project the run
# was logged to.
experiment = wandb_logger.experiment
run_id = experiment.id
checkpoint_reference = f"{experiment.entity}/{experiment.project}/model-{run_id}:best"
artifact_dir = wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")

# NOTE(security): weights_only=False makes torch.load unpickle arbitrary Python
# objects. That is acceptable only because the checkpoint was produced by this
# notebook's own training run; never use it on checkpoints from untrusted sources.
ckpt = torch.load(Path(artifact_dir) / "model.ckpt", map_location='cpu', weights_only=False)
hparams = ckpt.get('hyper_parameters', ckpt.get('hparams', {}))
# Restore the downloaded weights into the model object that was just trained.
mpnn.load_state_dict(ckpt['state_dict'])

# A fresh Trainer without the training callbacks is used for inference only.
trainer = L.Trainer(
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
)

# predict() returns one tensor per batch; concatenate them into a single tensor.
test_ds_preds = trainer.predict(model=mpnn, dataloaders=test_loader)
test_ds_preds = torch.cat(test_ds_preds)

[34m[1mwandb[0m:   1 of 1 files downloaded.  
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/rahul/delta/.venv/lib/python3.12/site-packages/lightning/pytorch/core/saving.py:363: Skipping 'metrics' parameter because it is not possible to safely dump to YAML.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [23]:
# Drop singleton dimensions from the prediction tensor and convert it to NumPy.
test_pred_probs = test_ds_preds.squeeze().numpy()

# Hard labels use a 0.5 probability threshold and are stored as floats (0.0 / 1.0).
df_test['preds'] = (test_pred_probs >= 0.5).astype(float)
df_test['pred_probs'] = test_pred_probs

In [24]:
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)

# Each entry is (table column name, scoring function, prediction column to score).
# Threshold-based metrics read the hard labels in `preds`; the ranking metrics
# (AUCROC, PRAUC) read the probabilities in `pred_probs`.
metric_specs = [
    ('accuracy', accuracy_score, "preds"),
    ('balanced_accuracy', balanced_accuracy_score, "preds"),
    ('f1', f1_score, "preds"),
    ('precision', precision_score, "preds"),
    ('recall', recall_score, "preds"),
    ('AUCROC', roc_auc_score, "pred_probs"),
    ('PRAUC', average_precision_score, "pred_probs"),
]

# Log a single-row table of test-set metrics to the W&B run.
wandb_logger.log_table(
    'final_metrics',
    [column for column, _, _ in metric_specs],
    [[
        score(df_test["is_cytotoxic"], df_test[prediction_column])
        for _, score, prediction_column in metric_specs
    ]]
)

In [25]:
# End the W&B run now that the final test metrics table has been logged.
wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇█████
train_loss_epoch,█▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train_loss_step,▇▅▆█▄▅▄▄▅▄▃▄▅▄▁▅▆▄▄▆▅▄▃▄▂▄▄▃▂▂▄▂▃▄▂▃▄▂▁▃
trainer/global_step,▁▁▂▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇██
val/f1,▄▁▆▂▇▆▆▇▆▆▇▄██▇▅▆▇█▆█▇█
val/prc,▁▃▆▄█▆▆▅▇▆▇▅█▇▆▆▆▆█▆▆▆▇
val/roc,▁▆▆▇▇▇▇▇█▇█▇████▇▇█████
val_loss,▂▁▁█▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,22.0
train_loss_epoch,0.16469
train_loss_step,0.23787
trainer/global_step,3725.0
val/f1,0.58873
val/prc,0.67029
val/roc,0.85098
val_loss,0.36246
