In [18]:
import pickle
import re
import numpy as np
import pandas as pd
import chemprop as cp
import torch
from glob import glob
import lightning as L
from tempfile import TemporaryDirectory
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from tqdm.auto import tqdm
import wandb
import random

RANDOM_SEED = 42


def set_seeds(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU + every CUDA device).

    Args:
        seed: integer seed applied identically to each generator.
    """
    # Applied in a fixed order: random -> numpy -> torch CPU -> torch CUDA.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

In [None]:
def get_molecule_datapoint(row, threshold=50):
    """Convert one split-DataFrame row into a chemprop ``MoleculeDatapoint``.

    Args:
        row: one row (``pd.Series``) of a split frame. It must contain ``mol``
            (the deserialised molecule object), ``per_inhibition``, and zero or
            more descriptor columns whose names start with ``"feat"``.
        threshold: a molecule is labelled active when ``per_inhibition`` is
            strictly greater than this value (default 50).

    Returns:
        ``cp.data.MoleculeDatapoint`` with a 1-element float target ``y``
        (1.0 = active, 0.0 = inactive) and a float descriptor vector ``x_d``.
        Non-numeric descriptor values are coerced to NaN.
    """
    feat_names = [name for name in row.index if name.startswith("feat")]
    # `row` mixes object and numeric columns, so coerce explicitly and pin the dtype to
    # float. Chemprop stacks x_d / y into float tensors.
    x_d = pd.to_numeric(row[feat_names], errors="coerce").to_numpy(dtype=float)
    y = np.array([row["per_inhibition"] > threshold], dtype=float)
    return cp.data.MoleculeDatapoint(mol=row["mol"], y=y, x_d=x_d)


def _deserialize_mols(df):
    """Return a copy of ``df`` with a ``mol`` column unpickled from ``mol_ser``.

    SECURITY: ``pickle.loads`` can execute arbitrary code. Only use it on split
    files produced by a trusted pipeline.
    """
    return df.assign(mol=df["mol_ser"].map(pickle.loads))


def _make_dataset(df, featurizer):
    """Build a chemprop ``MoleculeDataset`` from a frame, one datapoint per row."""
    datapoints = df.apply(get_molecule_datapoint, axis=1)
    return cp.data.MoleculeDataset(datapoints, featurizer=featurizer)


def _build_mpnn(n_extra_features, x_d_scaler):
    """Assemble the baseline MPNN.

    The pipeline is bond message passing -> norm aggregation -> batch norm ->
    binary-classification FFN over [graph embedding, extra descriptors].
    """
    mp = cp.nn.BondMessagePassing()
    agg = cp.nn.NormAggregation()
    ffn = cp.nn.BinaryClassificationFFN(n_tasks=1, input_dim=mp.output_dim + n_extra_features)
    # NOTE(review): in some chemprop 2.x releases the *first* metric is what gets logged
    # as "val_loss", and the callbacks in _train_best_mpnn minimise that value. With F1
    # first, that would pick the worst-F1 epoch. Confirm against the installed version.
    metric_list = [cp.nn.metrics.BinaryF1Score(), cp.nn.metrics.BinaryAUPRC(), cp.nn.metrics.BinaryAUROC()]
    # The transform scales raw descriptors at eval/predict time; it is a no-op while training.
    X_d_transform = cp.nn.ScaleTransform.from_standard_scaler(x_d_scaler)
    return cp.models.MPNN(mp, agg, ffn, batch_norm=True, metrics=metric_list, X_d_transform=X_d_transform)


def _train_best_mpnn(mpnn, train_loader, val_loader, max_epochs=50, patience=10):
    """Fit ``mpnn`` with early stopping on ``val_loss``.

    Returns the checkpoint with the lowest ``val_loss``.
    """
    # Checkpoints only need to live until the best one is reloaded, so keep them in a temp dir.
    with TemporaryDirectory() as tmpdir:
        trainer = L.Trainer(
            logger=None,
            enable_checkpointing=True,
            enable_progress_bar=False,
            accelerator="auto",
            devices=1,
            max_epochs=max_epochs,
            default_root_dir=tmpdir,
            callbacks=[
                EarlyStopping(monitor="val_loss", mode="min", verbose=True, patience=patience),
                ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1),
            ],
        )
        trainer.fit(mpnn, train_loader, val_loader)
        # Reload before leaving the `with`: the checkpoint file is deleted along with tmpdir.
        return cp.models.MPNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)


def _predict_probs(mpnn, loader):
    """Return the model's predicted probabilities for ``loader`` as a flat 1-D NumPy array."""
    # Inference only; nothing needs logging.
    trainer = L.Trainer(logger=False, enable_progress_bar=False, accelerator="auto", devices=1)
    batch_preds = trainer.predict(model=mpnn, dataloaders=loader)
    # Use reshape(-1) rather than squeeze(): squeeze() collapses a single-molecule test set
    # to a 0-d array, which the sklearn metrics reject.
    return torch.cat(batch_preds).reshape(-1).cpu().numpy()


def _binary_scores(labels, pred_probs, threshold=0.5):
    """Score predicted probabilities against boolean labels.

    Threshold-based metrics use ``pred_probs >= threshold``. ROC-AUC and
    average precision use the raw probabilities.
    """
    preds = (pred_probs >= threshold).astype(float)
    return {
        "accuracy": accuracy_score(labels, preds),
        "balanced_accuracy": balanced_accuracy_score(labels, preds),
        # zero_division=0 matches sklearn's "warn" value (0.0) without the warning spam.
        "f1_score": f1_score(labels, preds, zero_division=0),
        "precision": precision_score(labels, preds, zero_division=0),
        "recall": recall_score(labels, preds, zero_division=0),
        "roc_auc": roc_auc_score(labels, pred_probs),
        "average_precision": average_precision_score(labels, pred_probs),
    }


def evaluate_on_split(df_train, df_val, df_test):
    """Train the baseline chemprop MPNN on one CV split and score it on the test partition.

    Each frame needs a pickled-molecule column ``mol_ser``, a ``per_inhibition``
    column, and the same set of ``feat*`` descriptor columns. The binary label is
    ``per_inhibition > 50``.

    Returns:
        dict with ``accuracy``, ``balanced_accuracy``, ``f1_score``, ``precision``
        and ``recall`` (at a 0.5 probability threshold), plus the threshold-free
        ``roc_auc`` and ``average_precision``.
    """
    featurizer = cp.featurizers.SimpleMoleculeMolGraphFeaturizer()
    train_ds = _make_dataset(_deserialize_mols(df_train), featurizer)
    val_ds = _make_dataset(_deserialize_mols(df_val), featurizer)
    test_ds = _make_dataset(_deserialize_mols(df_test), featurizer)

    # Fit descriptor scaling on train only and apply it to val (no leakage). The model
    # keeps train-mode transforms during validation, so val must be pre-normalised here.
    # The test set is deliberately NOT normalised: the model's X_d_transform scales X_d at
    # predict time, so pre-normalising test as well would scale its descriptors twice.
    x_d_scaler = train_ds.normalize_inputs("X_d")
    val_ds.normalize_inputs("X_d", x_d_scaler)

    for ds in (train_ds, val_ds, test_ds):
        ds.cache = True

    train_loader = cp.data.build_dataloader(train_ds, batch_size=32, num_workers=8, seed=RANDOM_SEED)
    val_loader = cp.data.build_dataloader(val_ds, batch_size=32, num_workers=8, shuffle=False)
    # shuffle=False keeps the test predictions aligned with df_test's row order.
    test_loader = cp.data.build_dataloader(test_ds, batch_size=32, num_workers=8, shuffle=False)

    n_extra_features = sum(col.startswith("feat") for col in df_train.columns)
    mpnn = _build_mpnn(n_extra_features, x_d_scaler)
    mpnn = _train_best_mpnn(mpnn, train_loader, val_loader)

    pred_probs = _predict_probs(mpnn, test_loader)
    labels = (df_test["per_inhibition"] > 50.0).to_numpy()
    return _binary_scores(labels, pred_probs)

In [None]:
run = wandb.init(project="evaluation")
wandb.mark_preempting()


cross_val_results = []
# Sort so splits always run (and are reported) in the same order; glob order is filesystem-dependent.
for split_fpath in tqdm(sorted(glob("./generated_splits/*.parquet")), desc="splits"):
    # File names look like ".../split_<outer>x<inner>.parquet". Use a raw string, and \d+ so
    # multi-digit fold indices (e.g. 10) are parsed correctly.
    matches = re.match(r".*split_(?P<outer>\d+)x(?P<inner>\d+)", split_fpath)
    assert matches is not None, split_fpath
    outer_idx, inner_idx = int(matches["outer"]), int(matches["inner"])

    total_split_df = pd.read_parquet(split_fpath).drop(columns="index")

    df_train, df_val, df_test = (
        total_split_df.loc[total_split_df["split"] == part].drop(columns="split")
        for part in ("train", "val", "test")
    )
    assert len(df_train) and len(df_val) and len(df_test), f"empty partition in {split_fpath}"

    # Reseed before every split so each fold's weight init and shuffling are
    # reproducible, independent of which splits ran before it.
    set_seeds(RANDOM_SEED)
    scores = evaluate_on_split(df_train, df_val, df_test)
    split_result_entry = scores | {"outer": outer_idx, "inner": inner_idx}
    cross_val_results.append(split_result_entry)

    print(f"completed_{outer_idx}x{inner_idx}")
    print('---------------------------------------------------------------')

[34m[1mwandb[0m: Currently logged in as: [33mrahul-e-dev[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/3 [00:00<?, ?it/s]

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/rahul/delta/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (4) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

  | Name            | Type                    | Params | Mode 
--------------------------------------------------------------------
0 | message_passing | BondMessagePassing      | 227 K  | train
1 | agg             | NormAggregation         | 0      | train
2 | bn              | BatchNorm1d             | 600    | train
3 | predictor       | BinaryClassificationFFN | 155 K  | train
4 | X_d_transform   | ScaleTransform          | 0      | train
5 | metrics         | ModuleList          

completed_0x4
---------------------------------------------------------------


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/rahul/delta/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (4) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

  | Name            | Type                    | Params | Mode 
--------------------------------------------------------------------
0 | message_passing | BondMessagePassing      | 227 K  | train
1 | agg             | NormAggregation         | 0      | train
2 | bn              | BatchNorm1d             | 600    | train
3 | predictor       | BinaryClassificationFFN | 155 K  | train
4 | X_d_transform   | ScaleTransform          | 0      | train
5 | metrics         | ModuleList          

completed_2x3
---------------------------------------------------------------


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/rahul/delta/.venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (4) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

  | Name            | Type                    | Params | Mode 
--------------------------------------------------------------------
0 | message_passing | BondMessagePassing      | 227 K  | train
1 | agg             | NormAggregation         | 0      | train
2 | bn              | BatchNorm1d             | 600    | train
3 | predictor       | BinaryClassificationFFN | 155 K  | train
4 | X_d_transform   | ScaleTransform          | 0      | train
5 | metrics         | ModuleList          

completed_0x3
---------------------------------------------------------------


In [None]:
# pd.DataFrame accepts both the original list of per-split dicts and an already-converted
# frame, so re-running this cell is harmless.
cross_val_results = pd.DataFrame(cross_val_results)
cross_val_results["model"] = "baseline"
run.log({"Cross Val Results": wandb.Table(dataframe=cross_val_results)})

# Aggregate every metric across splits, per model. The fold indices are not metrics.
scores_by_model = cross_val_results.drop(columns=["outer", "inner"]).groupby("model")

# add_prefix runs before reset_index, so the key column stays "model" (not "mean_model").
mean_scores = scores_by_model.mean().add_prefix("mean_").reset_index()
run.log({"Mean Results": wandb.Table(dataframe=mean_scores)})

std_scores = scores_by_model.std().add_prefix("std_").reset_index()
run.log({"Std Results": wandb.Table(dataframe=std_scores)})

In [None]:
# Close the W&B run started with wandb.init() in the training cell, so the logged
# result tables are finalised.
wandb.finish()