In [None]:
import sys
sys.path.append('../')

import random
import pandas as pd
import lightning as L
import numpy as np
import torch
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
import rdkit.Chem as Chem
from rdkit.Chem.Descriptors import CalcMolDescriptors
from rdkit.rdBase import BlockLogs
from sklearn.model_selection import GroupShuffleSplit
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
import chemprop as cp
from delta_model import DeltaProp, Encoder, Interaction
from delta_data import RandomPairDataModule
from ray import tune, train
import ray
import pickle
import wandb

# Single seed reused across the notebook: global RNG seeding, the grouped
# data splits, and the seeded training dataloader.
RANDOM_SEED = 42

def set_seeds(seed):
    """Seed every random number generator this notebook relies on.

    Applies ``seed``, in order, to Python's ``random`` module, NumPy's global
    generator, PyTorch's default generator and all CUDA generators.

    Args:
        seed: Integer seed handed unchanged to each library.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

set_seeds(RANDOM_SEED)

# Log in to Weights & Biases without embedding an API key in the notebook.
# With no explicit key, wandb.login() resolves credentials itself (for example
# from the WANDB_API_KEY environment variable or a previously stored login).
# A local, untracked secrets file can be loaded first if desired, e.g.
# load_dotenv('.env.secret').
# SECURITY: the key that used to be hard-coded on this line was exposed in
# plain text and should be revoked and rotated.
wandb.login()

In [None]:
def standardize(smiles):
    """Parse a SMILES string and return its standardized molecule, pickled.

    The SMILES is parsed with hydrogen removal disabled, then passed through
    RDKit's ``Cleanup``, ``FragmentParent`` and ``Uncharger`` steps in that
    order. RDKit log output is blocked for the duration of the call.

    Args:
        smiles: SMILES string to parse.

    Returns:
        ``pickle.dumps`` bytes of the standardized molecule, or ``None`` when
        RDKit cannot parse ``smiles``.
    """
    with BlockLogs():
        parser_params = Chem.SmilesParserParams()
        parser_params.removeHs = False
        parsed = Chem.MolFromSmiles(smiles, parser_params)  # type: ignore
        if parsed is None:
            return None

        parent = rdMolStandardize.FragmentParent(rdMolStandardize.Cleanup(parsed))
        neutral_parent = rdMolStandardize.Uncharger().uncharge(parent)
        return pickle.dumps(neutral_parent)

def mol_to_inchi(mol):
    """Return the InChI string of ``mol`` with RDKit log output blocked."""
    with BlockLogs():
        inchi = Chem.MolToInchi(mol)
    return inchi

def generate_features(mol):
    """Compute RDKit descriptors for ``mol`` keyed with a ``feat_`` prefix.

    Returns:
        Dict mapping ``feat_<descriptor name>`` to the value reported by
        ``CalcMolDescriptors``. RDKit log output is blocked while computing.
    """
    with BlockLogs():
        prefixed = {}
        for name, value in CalcMolDescriptors(mol).items():
            prefixed[f"feat_{name}"] = value
        return prefixed
    
def get_scaffold(mol) -> str:
    """Return the Murcko scaffold SMILES of ``mol``.

    When the scaffold SMILES comes back empty, the molecule's own SMILES is
    returned instead.
    """
    mol_smiles = Chem.MolToSmiles(mol)
    # An empty scaffold string is falsy, so `or` falls back to the full SMILES.
    return MurckoScaffoldSmiles(mol_smiles) or mol_smiles
    

# Draw a fixed-size subsample of the raw assay table. Passing random_state makes
# the draw depend only on RANDOM_SEED, so re-running this cell (or running it
# after other code has consumed NumPy's global RNG) selects the same rows.
df = pd.read_csv("../GSK_HepG2.csv").sample(7000, random_state=RANDOM_SEED)
# Drop the first CSV column and give the two remaining columns fixed names.
df = df.iloc[:, 1:]
df.columns = ["smiles", "per_inhibition"]

def _add_standardized_mol(row):
    """Attach the pickled standardized molecule (or None) under ``mol_ser``."""
    return row | {"mol_ser": standardize(row["smiles"])}


def _has_standardized_mol(row):
    """Tell whether standardization produced a molecule for this row."""
    return row["mol_ser"] is not None


def _add_inchi(row):
    """Attach the InChI of the row's molecule under ``inchi``."""
    return row | {"inchi": mol_to_inchi(pickle.loads(row["mol_ser"]))}


def _add_features(row):
    """Merge the ``feat_``-prefixed descriptor values into the row."""
    return row | generate_features(pickle.loads(row["mol_ser"]))


def _add_scaffold(row):
    """Attach the scaffold SMILES of the row's molecule under ``scaffold``."""
    return row | {"scaffold": get_scaffold(pickle.loads(row["mol_ser"]))}


# Distribute the rows with Ray (block count: one per 64 rows, rounded down),
# standardize each SMILES, drop rows that failed to parse, then annotate the
# survivors with InChI, descriptors and scaffold before collecting to pandas.
dataset = ray.data.from_pandas(df.reset_index(), override_num_blocks=len(df) // 64)
dataset = dataset.map(_add_standardized_mol).filter(_has_standardized_mol)
for annotate in (_add_inchi, _add_features, _add_scaffold):
    dataset = dataset.map(annotate)
df = dataset.materialize().to_pandas()

# Keep only rows whose InChI occurs exactly once; every copy of a duplicated
# structure is dropped.
df = df.groupby(["inchi"]).filter(lambda group: len(group) == 1).reset_index(drop=True)

# One integer cluster id per distinct scaffold, positionally aligned with df.
clusters = pd.Series(pd.factorize(df["scaffold"])[0])

# Restore the molecules from their pickled form, then discard the helper columns.
df = df.assign(mol=df["mol_ser"].map(pickle.loads)).drop(
    columns=["smiles", "inchi", "scaffold", "mol_ser"]
)

# First split: separate a training portion from a combined validation+test
# portion, grouping rows by scaffold cluster.
outer_splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED)
train_idxs, val_test_idxs = next(outer_splitter.split(df, groups=clusters))
df_train, df_val_test = (
    df.loc[idxs].reset_index(drop=True) for idxs in (train_idxs, val_test_idxs)
)
clusters_val_test = clusters.iloc[val_test_idxs].reset_index(drop=True)

# Second split: divide the held-out portion into validation and test with
# test_size=0.5, again grouped by scaffold cluster.
inner_splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED, test_size=0.5)
val_idxs, test_idxs = next(inner_splitter.split(df_val_test, groups=clusters_val_test))
df_val, df_test = (
    df_val_test.loc[idxs].reset_index(drop=True) for idxs in (val_idxs, test_idxs)
)

In [None]:
def get_molecule_datapoint(row):
    """Convert one DataFrame row into a chemprop ``MoleculeDatapoint``.

    Uses ``row['mol']`` as the molecule, ``row['per_inhibition']`` as the single
    target value, and every entry whose label starts with ``feat`` as the extra
    descriptor vector. Descriptor values that cannot be converted to numbers
    are coerced to NaN.

    Args:
        row: A pandas Series holding ``mol``, ``per_inhibition`` and the
            ``feat``-prefixed descriptor entries.

    Returns:
        The ``cp.data.MoleculeDatapoint`` built from the row.
    """
    descriptor_labels = [label for label in row.index if label.startswith('feat')]
    descriptors = pd.to_numeric(row[descriptor_labels], errors="coerce").to_numpy()
    target = np.array([row['per_inhibition']])
    return cp.data.MoleculeDatapoint(mol=row['mol'], y=target, x_d=descriptors)

# Turn each split into a chemprop MoleculeDataset, sharing one featurizer.
featurizer = cp.featurizers.SimpleMoleculeMolGraphFeaturizer()
train_mol_dataset, val_mol_dataset, test_mol_dataset = (
    cp.data.MoleculeDataset(split.apply(get_molecule_datapoint, axis=1), featurizer=featurizer)
    for split in (df_train, df_val, df_test)
)

# The training split yields the X_d scaler, which is then reused for the
# validation and test splits.
x_d_scaler = train_mol_dataset.normalize_inputs("X_d")
for held_out_dataset in (val_mol_dataset, test_mol_dataset):
    held_out_dataset.normalize_inputs("X_d", x_d_scaler)

# Switch on caching for all three datasets.
for mol_dataset in (train_mol_dataset, val_mol_dataset, test_mol_dataset):
    mol_dataset.cache = True

In [None]:
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

def tune_func(config, train_mol_dataset, val_mol_dataset):
    """Ray Tune trainable: build a DeltaProp model from ``config`` and fit it.

    Args:
        config: Hyperparameter dict with the keys ``depth``,
            ``ffn_hidden_dim``, ``ffn_num_layers``, ``message_hidden_dim``,
            ``batch_norm``, ``encoder_dropout`` and ``interaction_dropout``.
        train_mol_dataset: Training dataset; the last dimension of its ``X_d``
            is added to the message-passing output size to get the encoder
            input width.
        val_mol_dataset: Validation dataset handed to the data module.

    Nothing is returned; a ``TuneReportCheckpointCallback`` is attached to the
    trainer alongside early stopping on ``val_loss``.
    """
    (
        depth,
        ffn_hidden_dim,
        ffn_num_layers,
        message_hidden_dim,
        batch_norm,
        encoder_dropout,
        interaction_dropout,
    ) = (
        config[key]
        for key in (
            "depth",
            "ffn_hidden_dim",
            "ffn_num_layers",
            "message_hidden_dim",
            "batch_norm",
            "encoder_dropout",
            "interaction_dropout",
        )
    )

    # ---- model -------------------------------------------------------------
    message_passing = cp.nn.BondMessagePassing(d_h=message_hidden_dim, depth=depth)
    aggregation = cp.nn.NormAggregation()
    encoder = Encoder(
        input_dim=message_passing.output_dim + train_mol_dataset.X_d.shape[-1],
        hidden_dim=ffn_hidden_dim,
        n_layers=ffn_num_layers,
        activation=torch.nn.ELU(),
        dropout=encoder_dropout,
    )
    interaction = Interaction(encoder.output_dim, dropout=interaction_dropout)
    model = DeltaProp(message_passing, aggregation, encoder, interaction, batch_norm=batch_norm)

    # ---- training ----------------------------------------------------------
    stop_early = EarlyStopping(monitor="val_loss", mode="min", verbose=False, patience=8)
    report_to_tune = TuneReportCheckpointCallback()
    # NOTE(review): logger=None is passed through unchanged; confirm whether
    # logger=False was intended to switch trial logging off.
    trainer = L.Trainer(
        logger=None,
        enable_checkpointing=True,
        enable_progress_bar=False,
        accelerator="auto",
        devices=1,
        max_epochs=20,
        callbacks=[stop_early, report_to_tune],
    )

    pair_datamodule = RandomPairDataModule(train_mol_dataset, val_mol_dataset)
    trainer.fit(model, datamodule=pair_datamodule)

In [None]:
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from ray.tune.search import ConcurrencyLimiter


# Hyperparameter ranges sampled for each trial; the keys match what tune_func
# reads from its config.
search_space = dict(
    depth=tune.qrandint(lower=2, upper=6, q=1),
    ffn_hidden_dim=tune.qrandint(lower=300, upper=2400, q=100),
    ffn_num_layers=tune.qrandint(lower=1, upper=3, q=1),
    message_hidden_dim=tune.qrandint(lower=300, upper=2400, q=100),
    encoder_dropout=tune.uniform(lower=0.0, upper=0.3),
    interaction_dropout=tune.uniform(lower=0.0, upper=0.3),
    batch_norm=tune.choice([True, False]),
)

# Optuna-driven search, capped at three concurrent trials, with ASHA scheduling.
optuna_search = OptunaSearch(seed=42)
search_alg = ConcurrencyLimiter(optuna_search, max_concurrent=3)
scheduler = ASHAScheduler(max_t=20, grace_period=1, reduction_factor=2)

# Bind the datasets to the trainable, then request half a GPU per trial.
trainable_with_data = tune.with_parameters(
    tune_func,
    train_mol_dataset=train_mol_dataset,
    val_mol_dataset=val_mol_dataset,
)
tune_fn = tune.with_resources(trainable_with_data, resources={"GPU": 0.5})

# Checkpoint config controls the checkpointing behavior of Ray
checkpoint_config = tune.CheckpointConfig(
    num_to_keep=1, # number of checkpoints to keep
    checkpoint_score_attribute="val_loss", # Save the checkpoint based on this metric
    checkpoint_score_order="min", # Save the checkpoint with the lowest metric value
)

tuner = tune.Tuner(
    tune_fn,
    param_space=search_space,
    tune_config=tune.TuneConfig(
        metric="val_loss",
        mode="min",
        num_samples=20,
        scheduler=scheduler,
        search_alg=search_alg,
    ),
    run_config=tune.RunConfig(
        # Reuse the config defined above instead of an inline duplicate, so the
        # checkpoint policy has a single source of truth.
        checkpoint_config=checkpoint_config,
        # Taken from the same `tune` namespace as RunConfig/CheckpointConfig
        # rather than mixing in the `ray.train` variant.
        failure_config=tune.FailureConfig(max_failures=3),
    ),
)

results = tuner.fit()
# Each entry of best_checkpoints unpacks into a pair; the second element is
# the record that carries the trial's 'config'.
_, best_result = results.get_best_result().best_checkpoints[0]
best_config = best_result['config']

In [None]:
# Display the record unpacked from the best trial's checkpoint entry.
best_result

In [None]:
# Unpack the hyperparameters selected by the tuning run.
(
    depth,
    ffn_hidden_dim,
    ffn_num_layers,
    message_hidden_dim,
    batch_norm,
    encoder_dropout,
    interaction_dropout,
) = (
    best_config[key]
    for key in (
        "depth",
        "ffn_hidden_dim",
        "ffn_num_layers",
        "message_hidden_dim",
        "batch_norm",
        "encoder_dropout",
        "interaction_dropout",
    )
)

# Plain chemprop dataloaders: the training loader is seeded, while validation
# and test are built with shuffle=False.
loader_options = dict(batch_size=32, num_workers=8)
train_loader = cp.data.build_dataloader(train_mol_dataset, seed=RANDOM_SEED, **loader_options)
val_loader = cp.data.build_dataloader(val_mol_dataset, shuffle=False, **loader_options)
test_loader = cp.data.build_dataloader(test_mol_dataset, shuffle=False, **loader_options)

###############################################################################################

# Same architecture as in tune_func, instantiated with the best hyperparameters.
mp = cp.nn.BondMessagePassing(d_h=message_hidden_dim, depth=depth)
agg = cp.nn.NormAggregation()
# Encoder input width: message-passing output plus the extra descriptor columns.
ffn_dims = mp.output_dim + train_mol_dataset.X_d.shape[-1]
encoder_options = dict(
    input_dim=ffn_dims,
    hidden_dim=ffn_hidden_dim,
    n_layers=ffn_num_layers,
    activation=torch.nn.ELU(),
    dropout=encoder_dropout,
)
encoder = Encoder(**encoder_options)
interaction = Interaction(encoder.output_dim, dropout=interaction_dropout)
model = DeltaProp(mp, agg, encoder, interaction, batch_norm=batch_norm)


################################################################################################
# Finish any W&B run that is still open, then attach a fresh logger for this
# training run.
wandb.finish()
wandb_logger = WandbLogger(project="chemprop_delta_clf", log_model="all", save_code=True)
wandb_logger.watch(model, log="gradients", log_freq=50)
wandb_logger.experiment.mark_preempting()

early_stopping = EarlyStopping(monitor="val_loss", mode="min", verbose=True, patience=10)
best_val_checkpoint = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

# NOTE(review): max_epochs=3 is smaller than the early-stopping patience of 10,
# so early stopping presumably cannot trigger here — confirm the intended
# epoch budget.
trainer = L.Trainer(
    logger=wandb_logger,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=3,
    callbacks=[early_stopping, best_val_checkpoint],
)

pair_datamodule = RandomPairDataModule(train_mol_dataset, val_mol_dataset)
trainer.fit(model, datamodule=pair_datamodule)
# Replace the in-memory model with the one stored at the best checkpoint path.
model = DeltaProp.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

##################################################################################################

# Close the training run explicitly before opening the evaluation run, the same
# way wandb.finish() is called before the training logger is created. Without
# it, what wandb.init() does with the still-active run depends on wandb's
# environment-specific re-init behaviour.
wandb.finish()
run = wandb.init(project="evaluation")
wandb.mark_preempting()


# Fresh trainer used only for inference on the test loader.
trainer = L.Trainer(
    enable_progress_bar=False,
    accelerator="auto",
    devices=1,
)

# trainer.predict returns one tensor per batch; concatenate them into one.
test_ds_preds = trainer.predict(model, dataloaders=test_loader)
test_ds_preds = torch.cat(test_ds_preds)

# Predictions are thresholded at 0.5; reference labels are per_inhibition > 50.0.
pred_probs = test_ds_preds.squeeze().numpy()
preds = (pred_probs >= 0.5).astype(float)
labels = df_test['per_inhibition'] > 50.0

In [None]:
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)


# Compute the classification metrics once, keyed by their table column name.
# Threshold-based metrics use the hard predictions; AUCROC and PRAUC use the
# raw scores.
final_metric_values = {
    'accuracy': accuracy_score(labels, preds),
    'balanced_accuracy': balanced_accuracy_score(labels, preds),
    'f1': f1_score(labels, preds),
    'precision': precision_score(labels, preds),
    'recall': recall_score(labels, preds),
    'AUCROC': roc_auc_score(labels, pred_probs),
    'PRAUC': average_precision_score(labels, pred_probs),
}

# Log them as a single-row W&B table, then close the evaluation run.
final_metrics_table = wandb.Table(
    columns=list(final_metric_values.keys()),
    data=[list(final_metric_values.values())],
)
run.log({'final_metrics': final_metrics_table})


wandb.finish()