In [1]:
import sys

sys.path.append('../')

import os
import random
import pandas as pd
import lightning as L
import numpy as np
import torch
from chemprop import data, featurizers, models
from chemprop import nn as chem_nn
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from pytorch_lightning.utilities import move_data_to_device
import pandas as pd
import rdkit.Chem as Chem
from rdkit.Chem.Descriptors import CalcMolDescriptors
from rdkit.rdBase import BlockLogs
from sklearn.model_selection import GroupShuffleSplit
from commons.utils import get_scaffold, standardize
from typing import NamedTuple
from itertools import chain
from delta_data import RandomPairDataModule
import chemprop as cp
from delta_model import DeltaProp, Encoder, Interaction
from ray import tune, train
import ray

import wandb
# from commons.data import load_and_split_gsk_dataset

RANDOM_SEED = 42


def set_seeds(seed):
    """Seed every RNG this notebook relies on with the same value.

    Seeds, in order: Python's ``random`` module, NumPy's legacy global RNG,
    torch's CPU generator and torch's generators on all CUDA devices.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


set_seeds(RANDOM_SEED)

# Authenticate with Weights & Biases. The API key comes from the WANDB_API_KEY
# environment variable, e.g. exported in the shell or loaded from `.env.secret`.
# When it is unset, wandb falls back to its usual lookup (existing ~/.netrc entry
# or an interactive prompt). Never hard-code the key: a key embedded in a notebook
# leaks to anyone who can read the file or its git history.
# NOTE(review): the key previously embedded on this line must be treated as
# compromised. Revoke/rotate it in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrahul-e-dev[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
def mol_to_inchi(mol):
    """Return the InChI string for an RDKit molecule.

    RDKit logging is suppressed (``BlockLogs``) for the duration of the
    conversion, so InChI warnings do not flood the notebook output.
    """
    with BlockLogs():
        inchi = Chem.MolToInchi(mol)
    return inchi
    

def generate_features(df):
    """Append all RDKit descriptors of ``df["mol"]`` as ``feat_<name>`` columns.

    ``CalcMolDescriptors`` is evaluated per molecule with RDKit logging
    suppressed. The resulting dicts become one column per descriptor, prefixed
    with ``feat_``. They are placed side by side with ``df``, whose index is
    reset first so the two frames align positionally.

    Returns:
        A new DataFrame with a fresh RangeIndex: the original columns followed
        by the ``feat_*`` descriptor columns.
    """
    with BlockLogs():
        descriptor_dicts = df["mol"].map(CalcMolDescriptors).tolist()

    descriptors = pd.DataFrame.from_records(descriptor_dicts)
    descriptors.columns = ["feat_" + str(name) for name in descriptors.columns]

    return pd.concat([df.reset_index(drop=True), descriptors], axis=1)


def load_and_split_gsk_dataset(path, RANDOM_SEED, threshold=50.0, test_size=0.2):
    """Load the GSK HepG2 CSV, clean and featurize it, and split it by scaffold.

    Pipeline:
      1. Drop the first CSV column and name the remaining two ``smiles`` and
         ``per_inhibition``.
      2. Standardize each SMILES into an RDKit mol with ``standardize``. Rows
         whose result is missing (NaN/None) are dropped.
      3. Compute each molecule's InChI and discard EVERY row whose InChI occurs
         more than once. All copies are removed, not deduplicated.
      4. Label ``is_cytotoxic = per_inhibition > threshold``.
      5. Append RDKit descriptor columns via ``generate_features``.
      6. Group rows by scaffold (``get_scaffold`` of the canonical SMILES).
         Split the groups into train / (val+test) with ``GroupShuffleSplit``,
         then split val+test 50/50 the same way, so no scaffold appears in more
         than one split.

    Args:
        path: CSV path. NOTE(review): assumed to hold an index-like first column
            followed by SMILES and percent inhibition. TODO confirm.
        RANDOM_SEED: ``random_state`` used for both splitters.
        threshold: exclusive per-inhibition cut-off for the positive class
            (default 50.0, as before).
        test_size: fraction of scaffold GROUPS held out from training. The
            default 0.2 is what sklearn applied implicitly when it was not passed.

    Returns:
        ``(df_train, df_val, df_test)``. Each has a fresh RangeIndex and no
        ``smiles`` / ``inchi`` columns.

    Raises:
        ValueError: if the CSV does not have exactly two columns left after
            dropping the first one.
    """
    df = pd.read_csv(path).iloc[:, 1:]
    if df.shape[1] != 2:
        raise ValueError(
            f"expected 2 data columns after dropping the first column of {path!r}, "
            f"got {df.shape[1]}: {list(df.columns)}"
        )
    df.columns = ["smiles", "per_inhibition"]

    # Standardize and convert to InChI. reset_index yields a fresh frame (not a
    # filtered view) before new columns are assigned to it.
    df["mol"] = df["smiles"].map(standardize)
    df = df.dropna(subset=["mol"]).reset_index(drop=True)
    df["inchi"] = df["mol"].map(mol_to_inchi)
    # Keep only structures that occur exactly once (drops every duplicated InChI).
    df = df.groupby(["inchi"]).filter(lambda x: len(x) == 1).reset_index(drop=True)

    df["is_cytotoxic"] = df["per_inhibition"] > threshold

    df = generate_features(df)

    # One integer cluster id per distinct scaffold; rows share an id iff they
    # share a scaffold.
    clusters, _ = pd.factorize(
        df["mol"]
        .map(Chem.MolToSmiles)  # type: ignore
        .map(get_scaffold)
    )
    clusters = pd.Series(clusters)

    df = df.drop(["smiles", "inchi"], axis=1)

    # GroupShuffleSplit yields POSITIONAL indices, so select with .iloc (the
    # previous .loc only worked because the index happened to be a RangeIndex).
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=RANDOM_SEED)
    train_idxs, val_test_idxs = next(splitter.split(df, groups=clusters))
    df_train = df.iloc[train_idxs].reset_index(drop=True)
    df_val_test = df.iloc[val_test_idxs].reset_index(drop=True)
    clusters_val_test = clusters.iloc[val_test_idxs].reset_index(drop=True)

    # Split the held-out scaffold groups evenly into validation and test.
    splitter = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=RANDOM_SEED)
    val_idxs, test_idxs = next(splitter.split(df_val_test, groups=clusters_val_test))
    df_val = df_val_test.iloc[val_idxs].reset_index(drop=True)
    df_test = df_val_test.iloc[test_idxs].reset_index(drop=True)

    return df_train, df_val, df_test

In [3]:
# Raw GSK HepG2 percent-inhibition data (path relative to the kernel's working directory).
GSK_CSV_PATH = "../GSK_HepG2.csv"
df_train, df_val, df_test = load_and_split_gsk_dataset(GSK_CSV_PATH, RANDOM_SEED)

In [4]:
def get_molecule_datapoint(row):
    """Build a chemprop ``MoleculeDatapoint`` from one dataframe row.

    Meant for ``df.apply(get_molecule_datapoint, axis=1)``. The row must carry
    ``mol``, ``per_inhibition`` and the descriptor columns.

    - ``y``: a 1-element boolean array, True when ``per_inhibition > 50``.
    - ``x_d``: every column whose name starts with ``"feat"``, coerced to
      numeric (unparseable entries become NaN).
    """
    descriptor_cols = [name for name in row.index if name.startswith('feat')]
    descriptors = pd.to_numeric(row[descriptor_cols], errors="coerce").to_numpy()
    label = np.array([row['per_inhibition'] > 50])
    return cp.data.MoleculeDatapoint(mol=row['mol'], y=label, x_d=descriptors)

# Featurize each split into a chemprop MoleculeDataset (one MoleculeDatapoint per row).
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
train_mol_dataset = data.MoleculeDataset(df_train.apply(get_molecule_datapoint, axis=1), featurizer=featurizer)
val_mol_dataset = data.MoleculeDataset(df_val.apply(get_molecule_datapoint, axis=1), featurizer=featurizer)
test_mol_dataset = data.MoleculeDataset(df_test.apply(get_molecule_datapoint, axis=1), featurizer=featurizer)

# Standardize the extra descriptors (X_d) with statistics fit on the training split
# only, and apply the same scaler to the validation split.
x_d_scaler = train_mol_dataset.normalize_inputs("X_d")
val_mol_dataset.normalize_inputs("X_d", x_d_scaler)
# The TEST split is deliberately left unscaled. The models built later wrap x_d_scaler
# in a chemprop `ScaleTransform` (X_d_transform), which standardizes X_d itself
# whenever the module is in eval mode, as it is under `trainer.predict`. Pre-scaling
# the test set here therefore standardized its descriptors twice. This mirrors
# chemprop's own training CLI, which normalizes only the train and validation sets.
# NOTE(review): validation stays correct only if chemprop's MPNN keeps X_d_transform
# inactive during validation (chemprop v2 does); confirm for the installed version.

# Enable chemprop's dataset caching (presumably of the featurized graphs, so they
# are not rebuilt every epoch).
train_mol_dataset.cache = True
val_mol_dataset.cache = True
test_mol_dataset.cache = True

In [5]:
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

def tune_func(config, train_mol_dataset, val_mol_dataset):
    """Ray Tune trainable: train one chemprop MPNN for a sampled hyperparameter config.

    Reads ``depth``, ``ffn_hidden_dim``, ``ffn_num_layers``, ``message_hidden_dim``
    and ``batch_norm`` from ``config``. It trains for at most 20 epochs with early
    stopping on ``val_loss`` (patience 10). Metrics and checkpoints are handed to
    Tune by ``TuneReportCheckpointCallback``; nothing is returned.

    NOTE(review): the function also reads the notebook globals ``RANDOM_SEED`` and
    ``x_d_scaler`` (descriptor scaler fit on the training split). Ray has to ship
    them with the function; passing them via ``tune.with_parameters`` would make
    the dependency explicit.
    """
    loader_kwargs = dict(batch_size=32, num_workers=2)
    train_loader = cp.data.build_dataloader(train_mol_dataset, seed=RANDOM_SEED, **loader_kwargs)
    val_loader = cp.data.build_dataloader(val_mol_dataset, shuffle=False, **loader_kwargs)

    # Model: bond message passing -> norm aggregation -> binary FFN head. The head's
    # input width is the message-passing output_dim plus the number of X_d descriptors.
    message_passing = cp.nn.BondMessagePassing(d_h=config["message_hidden_dim"], depth=config["depth"])
    aggregation = cp.nn.NormAggregation()
    head = cp.nn.BinaryClassificationFFN(
        n_tasks=1,
        input_dim=message_passing.output_dim + train_mol_dataset.X_d.shape[-1],
        hidden_dim=config["ffn_hidden_dim"],
        n_layers=config["ffn_num_layers"],
    )
    metrics = [
        cp.nn.metrics.BinaryF1Score(),
        cp.nn.metrics.BinaryAUPRC(),
        cp.nn.metrics.BinaryAUROC(),
    ]
    descriptor_transform = cp.nn.ScaleTransform.from_standard_scaler(x_d_scaler)
    model = cp.models.MPNN(
        message_passing,
        aggregation,
        head,
        config['batch_norm'],
        metrics,
        X_d_transform=descriptor_transform,
    )

    callbacks = [
        EarlyStopping(monitor="val_loss", mode="min", verbose=False, patience=10),
        TuneReportCheckpointCallback(),
    ]
    trainer = L.Trainer(
        logger=None,
        enable_checkpointing=True,
        enable_progress_bar=False,
        accelerator="auto",
        devices=1,
        max_epochs=20,
        callbacks=callbacks,
    )
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [6]:
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from ray.tune.search import ConcurrencyLimiter


# Search space for tune_func; each key matches a config entry it reads.
search_space = {
    "depth": tune.qrandint(lower=2, upper=6, q=1),
    "ffn_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100),
    "ffn_num_layers": tune.qrandint(lower=1, upper=3, q=1),
    "message_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100),
    "batch_norm": tune.choice([True, False]),
}

# Optuna-driven search seeded from the notebook-wide RANDOM_SEED, at most 8 trials in flight.
search_alg = ConcurrencyLimiter(OptunaSearch(seed=RANDOM_SEED), max_concurrent=8)
# ASHA stops weak trials early. max_t=20 mirrors tune_func's max_epochs=20.
# NOTE(review): assumes one Tune report per training epoch (reported at validation end).
scheduler = ASHAScheduler(max_t=20, grace_period=1, reduction_factor=2)

# Each trial gets 4 CPUs and a quarter of a GPU (up to four trials share one GPU).
tune_fn = tune.with_resources(
    tune.with_parameters(
        tune_func,
        train_mol_dataset=train_mol_dataset,
        val_mol_dataset=val_mol_dataset,
    ),
    resources={"CPU": 4, "GPU": 0.25},
)

# Keep only each trial's single best checkpoint, ranked by lowest val_loss.
checkpoint_config = tune.CheckpointConfig(
    num_to_keep=1,
    checkpoint_score_attribute="val_loss",
    checkpoint_score_order="min",
)

tuner = tune.Tuner(
    tune_fn,
    param_space=search_space,
    tune_config=tune.TuneConfig(
        metric="val_loss",
        mode="min",
        num_samples=20,
        scheduler=scheduler,
        search_alg=search_alg,
    ),
    run_config=tune.RunConfig(
        checkpoint_config=checkpoint_config,
        failure_config=train.FailureConfig(max_failures=3),
    ),
)

results = tuner.fit()

# Pick the trial with the lowest val_loss at ANY reported epoch (scope="all").
# The default scope="last" compares only each trial's final report. That can rank a
# trial that ASHA stopped after one or two epochs above a longer run that reached a
# lower val_loss and then overfit. "all" matches how checkpoints are kept above and
# how the final model is selected (ModelCheckpoint on min val_loss).
best_trial = results.get_best_result(metric="val_loss", mode="min", scope="all")
best_config = best_trial.config
# Metrics stored with that trial's lowest-val_loss checkpoint (displayed in the next cell).
_, best_result = min(
    best_trial.best_checkpoints,
    key=lambda ckpt_and_metrics: ckpt_and_metrics[1]["val_loss"],
)

0,1
Current time:,2025-10-25 03:24:50
Running for:,00:03:05.12
Memory:,30.8/503.7 GiB

Trial name,status,loc,batch_norm,depth,ffn_hidden_dim,ffn_num_layers,message_hidden_dim,iter,total time (s),train_loss,train_loss_step,val/f1
tune_func_d41379f8,TERMINATED,172.17.0.3:38201,True,3,2300,3,1600,18,176.255,0.211957,0.211957,0.543897
tune_func_ba5301c3,TERMINATED,172.17.0.3:38411,False,2,2200,2,1800,8,80.5885,0.333499,0.333499,0.563319
tune_func_cdc3619d,TERMINATED,172.17.0.3:39364,False,6,700,1,700,15,123.208,0.0496194,0.0496194,0.577406
tune_func_13857fcf,TERMINATED,172.17.0.3:39534,False,4,900,2,600,8,68.4146,0.344866,0.344866,0.570806
tune_func_b7152dc2,TERMINATED,172.17.0.3:39705,True,4,2000,1,1400,1,8.86239,0.418729,0.418729,0.177632
tune_func_ccbdfd4b,TERMINATED,172.17.0.3:39886,True,5,600,1,2300,1,15.0917,0.377505,0.377505,0.188925
tune_func_3730a31c,TERMINATED,172.17.0.3:40083,False,3,500,3,1200,8,54.7908,0.453117,0.453117,0.561247
tune_func_d13ecb2a,TERMINATED,172.17.0.3:40450,False,2,2300,1,1700,4,30.4582,0.206284,0.206284,0.579909
tune_func_84ebe927,TERMINATED,172.17.0.3:40725,True,4,700,3,2000,1,15.3689,0.416988,0.416988,0.0143885
tune_func_717630f7,TERMINATED,172.17.0.3:41252,False,4,2300,1,700,4,24.9933,0.202179,0.202179,0.581236


[36m(tune_func pid=38201)[0m 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
[36m(tune_func pid=38201)[0m GPU available: True (cuda), used: True
[36m(tune_func pid=38201)[0m TPU available: False, using: 0 TPU cores
[36m(tune_func pid=38201)[0m HPU available: False, using: 0 HPUs
[36m(tune_func pid=38201)[0m You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[36m(tune_func pid=38201)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[36m(tune_func pid=38201)[0m Loading `train_dataloader` to estimate number of steppin

In [7]:
# Metrics recorded with the selected trial's best checkpoint (includes its sampled 'config').
best_result

{'train_loss': 0.155194491147995,
 'train_loss_step': 0.155194491147995,
 'val/f1': 0.5219638347625732,
 'val/prc': 0.6764230132102966,
 'val/roc': 0.826703667640686,
 'val_loss': 0.3367610573768616,
 'train_loss_epoch': 0.36423495411872864,
 'timestamp': 1761362671,
 'checkpoint_dir_name': 'checkpoint_000001',
 'should_checkpoint': True,
 'done': False,
 'training_iteration': 2,
 'trial_id': '32f06693',
 'date': '2025-10-25_03-24-31',
 'time_this_iter_s': 19.863657474517822,
 'time_total_s': 40.803967237472534,
 'pid': 42459,
 'hostname': '9a9bf6b9023e',
 'node_ip': '172.17.0.3',
 'config': {'depth': 6,
  'ffn_hidden_dim': 1500,
  'ffn_num_layers': 2,
  'message_hidden_dim': 1600,
  'batch_norm': False},
 'time_since_restore': 40.803967237472534,
 'iterations_since_restore': 2}

In [8]:
# Retrain the MPNN from scratch with the best hyperparameters found by Ray Tune.
(depth, ffn_hidden_dim, ffn_num_layers, message_hidden_dim, batch_norm) = (
    best_config[name]
    for name in ("depth", "ffn_hidden_dim", "ffn_num_layers", "message_hidden_dim", "batch_norm")
)

# Val/test loaders pass shuffle=False. The training loader keeps build_dataloader's
# default `shuffle` and is seeded with RANDOM_SEED.
loader_kwargs = dict(batch_size=32, num_workers=1)
train_loader = cp.data.build_dataloader(train_mol_dataset, seed=RANDOM_SEED, **loader_kwargs)
val_loader = cp.data.build_dataloader(val_mol_dataset, shuffle=False, **loader_kwargs)
test_loader = cp.data.build_dataloader(test_mol_dataset, shuffle=False, **loader_kwargs)

# Architecture: bond message passing -> norm aggregation -> binary FFN head. The head's
# input width is the message-passing output_dim plus the number of X_d descriptors.
mp = cp.nn.BondMessagePassing(d_h=message_hidden_dim, depth=depth)
agg = cp.nn.NormAggregation()
ffn_dims = mp.output_dim + train_mol_dataset.X_d.shape[-1]
ffn = cp.nn.BinaryClassificationFFN(
    n_tasks=1,
    input_dim=ffn_dims,
    hidden_dim=ffn_hidden_dim,
    n_layers=ffn_num_layers,
)
metric_list = [
    cp.nn.metrics.BinaryF1Score(),
    cp.nn.metrics.BinaryAUPRC(),
    cp.nn.metrics.BinaryAUROC(),
]
# Descriptor scaling fit on the training split, embedded in the model.
X_d_transform = cp.nn.ScaleTransform.from_standard_scaler(x_d_scaler)
mpnn = cp.models.MPNN(mp, agg, ffn, batch_norm, metric_list, X_d_transform=X_d_transform)

# Train for at most 20 epochs. Stop after 10 epochs without val_loss improvement and
# keep only the checkpoint with the lowest val_loss.
early_stopping = EarlyStopping(monitor="val_loss", mode="min", verbose=False, patience=10)
best_checkpoint = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
trainer = L.Trainer(
    logger=None,
    enable_checkpointing=True,
    enable_progress_bar=False,
    accelerator="auto",
    devices=1,
    max_epochs=20,
    callbacks=[early_stopping, best_checkpoint],
)
trainer.fit(mpnn, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Reload the weights from the best (lowest val_loss) epoch rather than keeping the last one.
mpnn = cp.models.MPNN.load_from_checkpoint(best_checkpoint.best_model_path)

##################################################################################################
# Evaluate the reloaded best model on the held-out test split; results are logged to W&B.

run = wandb.init(project="evaluation")
# NOTE(review): mark_preempting is intended for preemptible/resumable jobs. Confirm it is wanted here.
wandb.mark_preempting()

# Fresh trainer used only for inference.
trainer = L.Trainer(
    enable_progress_bar=False,
    accelerator="auto",
    devices=1,
)

# trainer.predict returns one tensor per batch; concatenate them along the batch axis.
test_ds_preds = torch.cat(trainer.predict(model=mpnn, dataloaders=test_loader))

# Flatten the single-task output to a 1-D vector of scores (presumably sigmoid
# probabilities from BinaryClassificationFFN). reshape(-1) stays 1-D even for a
# single test molecule, where squeeze() would produce a 0-d array. .cpu() guards
# against predictions left on the GPU.
pred_probs = test_ds_preds.reshape(-1).cpu().numpy()
preds = (pred_probs >= 0.5).astype(float)
labels = df_test['per_inhibition'] > 50.0

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.
/root/delta/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.

  | Name            | Type                    | Params | Mode 
-----------------------------------

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/root/delta/.venv/lib/python3.12/site-packages/lightning/pytorch/core/saving.py:363: Skipping 'metrics' parameter because it is not possible to safely dump to YAML.
/root/delta/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


In [11]:
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)


# Test-set classification metrics. Threshold-based scores use the hard 0.5 predictions;
# ROC-AUC and PR-AUC (average precision) use the raw predicted scores.
final_scores = {
    'accuracy': accuracy_score(labels, preds),
    'balanced_accuracy': balanced_accuracy_score(labels, preds),
    'f1': f1_score(labels, preds),
    'precision': precision_score(labels, preds),
    'recall': recall_score(labels, preds),
    'AUCROC': roc_auc_score(labels, pred_probs),
    'PRAUC': average_precision_score(labels, pred_probs),
}

# Log as a one-row W&B table whose columns are the metric names (dict insertion order).
final_table = wandb.Table(columns=list(final_scores), data=[list(final_scores.values())])
run.log({'final_metrics': final_table})

wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.
