In [None]:
import pandas as pd
import rdkit.Chem as Chem
from rdkit.rdBase import BlockLogs
from sklearn.model_selection import GroupShuffleSplit
import numpy as np
from utils import standardize, get_scaffold

from torch.utils.data import Dataset, DataLoader
import random

import lightning as L
from chemprop.data.collate import collate_batch
from chemprop.data.dataloader import build_dataloader

from chemprop import data, featurizers, models, nn

import wandb
import os
from dotenv import load_dotenv
import torch

RANDOM_SEED = 42

def set_seeds(seed):
    """Seed the global RNGs of ``random``, ``numpy.random`` and ``torch``.

    Args:
        seed: integer applied, in that order, to Python's ``random``,
            NumPy's legacy global RNG and PyTorch's global generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

set_seeds(RANDOM_SEED)

# Credentials come from .env.secret (loaded into the process environment here)
# or the ambient environment -- never hardcode API keys in a notebook.
# NOTE(review): the key previously hardcoded on this line is exposed in history
# and should be revoked/rotated.
load_dotenv('.env.secret')
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
def mol_to_inchi(mol):
    """Convert an RDKit molecule to its InChI string.

    The conversion runs inside ``BlockLogs`` so RDKit's log output produced
    during InChI generation is muted for this call only.
    """
    with BlockLogs():
        inchi = Chem.MolToInchi(mol)
    return inchi

# Load the raw HepG2 table. The first CSV column is discarded (presumably a saved
# row index -- verify) and the remaining two are named explicitly.
df = pd.read_csv("./GSK_HepG2.csv")
df = df.iloc[:, 1:].copy()  # copy so the column writes below don't act on a slice
df.columns = ['smiles', 'per_inhibition']
# Flip the sign of the measurement; the labelling step thresholds this negated value.
df['per_inhibition'] = -df['per_inhibition']


# Standardize structures (rows where standardize() yields a missing value are
# dropped) and key each molecule by its InChI.
df['mol'] = df['smiles'].map(standardize)
df = df.dropna(subset=['mol']).copy()
df['inchi'] = df['mol'].map(mol_to_inchi)
# Remove EVERY row whose InChI occurs more than once (all copies, not just the
# extras) -- presumably because repeated measurements of one compound may disagree.
# Vectorized equivalent of groupby("inchi").filter(lambda g: len(g) == 1), same row order.
df = df[~df['inchi'].duplicated(keep=False)].reset_index(drop=True)

# Integer scaffold-cluster id per row, positionally aligned with df (both carry a
# fresh 0..n-1 index); used as the grouping key for the scaffold split.
clusters, _ = pd.factorize(
    df['mol']
        .map(Chem.MolToSmiles) # type: ignore
        .map(get_scaffold)
)
clusters = pd.Series(clusters)
assert len(clusters) == len(df), "scaffold clusters must align row-for-row with df"


df = df.drop(['smiles', 'inchi'], axis=1)

In [None]:
# Scaffold split: each scaffold cluster lands entirely in one split. test_size is a
# fraction of GROUPS, not rows, so row counts are only approximately 80/10/10.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=RANDOM_SEED)
train_idxs, val_test_idxs = next(splitter.split(df, groups=clusters))
# split() yields positional indices, so select with .iloc (not label-based .loc).
df_train = df.iloc[train_idxs].reset_index(drop=True)
df_val_test = df.iloc[val_test_idxs].reset_index(drop=True)
clusters_val_test = clusters.iloc[val_test_idxs].reset_index(drop=True)


# Halve the held-out groups into validation and test.
splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED, test_size=0.5)
val_idxs, test_idxs = next(splitter.split(df_val_test, groups=clusters_val_test))
df_val = df_val_test.iloc[val_idxs].reset_index(drop=True)
df_test = df_val_test.iloc[test_idxs].reset_index(drop=True)

# Sanity checks: no rows lost, and no scaffold cluster shared between splits.
assert len(df_train) + len(df_val) + len(df_test) == len(df)
train_groups = set(clusters.iloc[train_idxs])
val_groups = set(clusters_val_test.iloc[val_idxs])
test_groups = set(clusters_val_test.iloc[test_idxs])
assert train_groups.isdisjoint(val_groups | test_groups), "scaffold leak into val/test"
assert val_groups.isdisjoint(test_groups), "scaffold leak between val and test"

In [None]:
# Binarize the target for every split: 1.0 when the sign-flipped per_inhibition is
# >= -15 (i.e. the raw value read from the CSV is <= 15), otherwise 0.0.
# NOTE(review): confirm that "low raw value" is the intended positive class.
for split_frame in (df_train, df_val, df_test):
    split_frame['true'] = (split_frame['per_inhibition'] >= -15).astype(float)

In [None]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()


def make_dataset(frame, featurizer=featurizer):
    """Build a chemprop ``MoleculeDataset`` from one split DataFrame.

    Args:
        frame: DataFrame with a ``'mol'`` column (RDKit molecules) and a
            ``'true'`` column (the single-task target).
        featurizer: graph featurizer attached to the dataset; defaults to the
            notebook-wide ``featurizer`` so all splits are featurized the same way.

    Returns:
        A ``MoleculeDataset`` with one datapoint per row, in row order.
    """
    datapoints = [
        data.MoleculeDatapoint(mol, y=np.array([label]))
        for mol, label in zip(frame['mol'], frame['true'])
    ]
    return data.MoleculeDataset(datapoints, featurizer=featurizer)


train_ds = make_dataset(df_train)
val_ds = make_dataset(df_val)
test_ds = make_dataset(df_test)

In [None]:
# One worker count shared by all loaders; tune to the machine.
NUM_WORKERS = 12

# The training loader shuffles; an explicit seed makes its batch order reproducible
# on its own instead of depending on how much global RNG state earlier cells consumed.
# Validation/test keep a fixed order (shuffle=False) so predictions line up with rows.
train_loader = data.build_dataloader(train_ds, num_workers=NUM_WORKERS, seed=RANDOM_SEED)
val_loader = data.build_dataloader(val_ds, num_workers=NUM_WORKERS, shuffle=False)
test_loader = data.build_dataloader(test_ds, num_workers=NUM_WORKERS, shuffle=False)

In [None]:
# Size the message-passing input from the same featurizer the datasets use,
# given as (atom_dims, bond_dims), so model and features cannot disagree.
fdims = featurizer.shape
mp = nn.BondMessagePassing(*fdims)
agg = nn.NormAggregation()
ffn = nn.BinaryClassificationFFN(n_tasks=1)
batch_norm = True
metric_list = [nn.metrics.BinaryF1Score(), nn.metrics.BinaryAUPRC(), nn.metrics.BinaryAUROC()]
# Pass max_lr to the constructor instead of assigning mpnn.max_lr afterwards: MPNN
# records its constructor arguments as hyperparameters (saved in checkpoints and
# logged to W&B), so a post-hoc assignment leaves them reporting the default value.
mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list, max_lr=0.01)

In [None]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

# Close any run left open by a previous execution, then start a fresh W&B run that
# uploads every saved checkpoint as an artifact (log_model="all").
wandb.finish()
wandb_logger = WandbLogger(project="chemprop_baseline", log_model="all")
wandb_logger.experiment.mark_preempting()

# Early stopping and checkpointing both track the same validation metric and keep
# the highest values ("val/prc", presumably logged by nn.metrics.BinaryAUPRC).
callbacks = [
    EarlyStopping(monitor="val/prc", mode="max", verbose=True, patience=10),
    ModelCheckpoint(monitor="val/prc", mode="max", save_top_k=2),
]

trainer = L.Trainer(
    logger=wandb_logger,
    enable_checkpointing=True,  # required for the ModelCheckpoint callback above
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=50,  # upper bound; EarlyStopping may end training sooner
    reload_dataloaders_every_n_epochs=1,
    log_every_n_steps=50,
    callbacks=callbacks,
)

trainer.fit(mpnn, train_loader, val_loader)

In [None]:
from pathlib import Path

# Download this run's best checkpoint from W&B. With log_model="all", WandbLogger
# uploads checkpoints as the artifact "model-<run id>" and aliases the best one "best".
# Entity/project are read from the live run instead of being hardcoded to one account.
run = wandb_logger.experiment
run_id = run.id
checkpoint_reference = f"{run.entity}/{run.project}/model-{run_id}:best"
artifact_dir = wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")


# NOTE(review): weights_only=False unpickles arbitrary Python objects -- only load
# checkpoints you produced yourself. Only the weights ("state_dict") are used here.
ckpt = torch.load(Path(artifact_dir) / "model.ckpt", map_location='cpu', weights_only=False)
mpnn.load_state_dict(ckpt['state_dict'])

# Separate inference-only trainer, not wired to the W&B logger or the callbacks.
trainer = L.Trainer(
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
)

# predict() returns one tensor per test batch; concatenate them in loader order.
test_ds_preds = torch.cat(trainer.predict(model=mpnn, dataloaders=test_loader))

In [None]:
# Flatten the prediction tensor (presumably (n_samples, 1) probabilities -- the 0.5
# cut-off assumes so). reshape(-1), unlike squeeze(), stays 1-D even when the test
# split has a single molecule. Keep the raw score for threshold-free metrics.
df_test['prob'] = test_ds_preds.cpu().reshape(-1).numpy()
df_test['preds'] = (df_test['prob'] >= 0.5).astype(float)

In [None]:
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)

# Threshold-based metrics on the 0/1 predictions. zero_division=0 yields the same 0.0
# sklearn falls back to when there are no predicted/actual positives, without the
# UndefinedMetricWarning noise.
y_true = df_test['true']
y_pred = df_test['preds']
final_metrics = {
    'f1': f1_score(y_true, y_pred, zero_division=0),
    'precision': precision_score(y_true, y_pred, zero_division=0),
    'recall': recall_score(y_true, y_pred, zero_division=0),
    'accuracy': accuracy_score(y_true, y_pred),
}

# Threshold-free metrics on the raw scores, matching the AUPRC/AUROC tracked on the
# validation set during training. Both are undefined when the test split contains a
# single class, so report NaN rather than raising.
y_score = test_ds_preds.cpu().reshape(-1).numpy()
has_both_classes = y_true.nunique() == 2
final_metrics['auroc'] = roc_auc_score(y_true, y_score) if has_both_classes else float('nan')
final_metrics['auprc'] = average_precision_score(y_true, y_score) if has_both_classes else float('nan')

wandb_logger.log_table(
    'final_metrics',
    list(final_metrics),
    [list(final_metrics.values())],
)

In [None]:
# Close the active W&B run (the one created through wandb_logger in the training
# cell); nothing further is logged to it after this point.
wandb.finish()