In [1]:
import pickle
import re
import numpy as np
import pandas as pd
import chemprop as cp
import torch
from glob import glob
import lightning as L
from tempfile import TemporaryDirectory
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from tqdm.auto import tqdm
import wandb
import random
from typing import NamedTuple, Iterable
from itertools import chain
from pytorch_lightning.utilities import move_data_to_device

# Global seed used by the cross-validation driver.
RANDOM_SEED = 42


def set_seeds(seed):
    """Seed every RNG this notebook draws from.

    Applies ``seed`` (in this order) to Python's ``random``, NumPy's legacy
    global generator, torch's CPU generator and the generators of all CUDA
    devices.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

In [2]:
class RandomPairDataPoint(NamedTuple):
    """One item returned by ``RandomPairDataset.__getitem__``."""

    # The molecule being compared against the candidates.
    anchor: cp.data.datasets.Datum
    # Candidates drawn from molecules whose target is > 50
    # (see RandomPairDataset.get_exemplar_candidates).
    exemplar: list[cp.data.datasets.Datum]
    # Candidates drawn uniformly from the whole dataset.
    random: list[cp.data.datasets.Datum]


class RandomPairTrainBatch(NamedTuple):
    """Collated batch returned by ``RandomPairDataset.collate_function``."""

    # Collated anchors: B molecules.
    anchor: cp.data.collate.TrainingBatch
    # Collated exemplar candidates, flattened item-major: B * C molecules.
    exemplar: cp.data.collate.TrainingBatch
    # Collated random candidates, flattened item-major: B * C molecules.
    random: cp.data.collate.TrainingBatch
    # Number of anchors (items) in the batch.
    B: int
    # Candidates per anchor (read from the first item's random list).
    C: int
    

class RandomPairDataset(torch.utils.data.Dataset):
    """Wraps a molecule dataset so each item is an anchor plus two candidate sets.

    Item ``idx`` is ``RandomPairDataPoint(anchor, exemplars, randoms)``:

    * ``anchor`` is ``mol_dataset[idx]``;
    * ``exemplars`` holds ``n_candidates`` molecules whose target exceeds
      ``exemplar_threshold``;
    * ``randoms`` holds ``n_candidates`` molecules drawn uniformly from the
      whole dataset.

    Both candidate sets are resampled from NumPy's global RNG on every access.
    Sampling is without replacement when the pool has at least
    ``n_candidates`` members. Otherwise it falls back to sampling with
    replacement, so every item always carries exactly ``n_candidates`` of each.
    The downstream ``(B, C, d)`` reshapes rely on that.
    """

    def __init__(self, mol_dataset, n_candidates, exemplar_threshold: float = 50.0):
        super().__init__()
        self.mol_dataset: cp.data.datasets.MoleculeDataset = mol_dataset
        self.n_candidates: int = n_candidates
        # Targets strictly greater than this count as exemplars ("hits").
        self.exemplar_threshold: float = exemplar_threshold

    def __len__(self):
        return len(self.mol_dataset)

    def _targets(self) -> np.ndarray:
        """Return the dataset targets as a flat 1-D array.

        Assumes a single-task dataset.
        """
        # reshape(-1) instead of squeeze(): squeeze() yields a 0-d array
        # for a one-row dataset.
        return np.asarray(self.mol_dataset.Y).reshape(-1)

    def _sample(self, pool: np.ndarray) -> list:
        """Draw ``n_candidates`` indices from ``pool`` and fetch those molecules."""
        replace = pool.size < self.n_candidates
        picked = np.random.choice(pool, size=self.n_candidates, replace=replace)
        return [self.mol_dataset[i] for i in picked]

    def get_exemplar_candidates(self):
        """Sample ``n_candidates`` molecules whose target exceeds ``exemplar_threshold``.

        Raises:
            ValueError: if no molecule in the dataset exceeds the threshold.
        """
        exemplar_pool = np.flatnonzero(self._targets() > self.exemplar_threshold)
        if exemplar_pool.size == 0:
            raise ValueError(
                f"no exemplars: no target in the dataset exceeds {self.exemplar_threshold}"
            )
        return self._sample(exemplar_pool)

    def get_random_candidates(self):
        """Sample ``n_candidates`` molecules uniformly from the whole dataset."""
        return self._sample(np.arange(len(self.mol_dataset)))

    def __getitem__(self, idx) -> "RandomPairDataPoint":
        return RandomPairDataPoint(
            self.mol_dataset[idx],
            self.get_exemplar_candidates(),
            self.get_random_candidates(),
        )

    @staticmethod
    def collate_function(batch):
        """Collate ``RandomPairDataPoint`` items into a ``RandomPairTrainBatch``.

        Candidate lists are flattened item-major. The exemplar and random
        batches therefore hold ``B * C`` molecules ordered
        ``[item0 cand0..C-1, item1 cand0..C-1, ...]``.
        """
        batch_anchors, batch_exemplars, batch_candidates = zip(*batch)
        B = len(batch)
        C = len(batch_candidates[0])
        batch_anchors = cp.data.dataloader.collate_batch(batch_anchors)
        batch_exemplars = cp.data.dataloader.collate_batch(chain.from_iterable(batch_exemplars))
        batch_candidates = cp.data.dataloader.collate_batch(chain.from_iterable(batch_candidates))
        return RandomPairTrainBatch(batch_anchors, batch_exemplars, batch_candidates, B, C)
    

def seed_worker(worker_id):
    """DataLoader ``worker_init_fn`` that seeds NumPy and ``random`` per worker.

    The seed is the worker's torch seed reduced to 32 bits (NumPy's seed
    range). It is applied to NumPy's global RNG first, then to Python's
    ``random``. ``worker_id`` is required by the ``worker_init_fn``
    signature and is unused.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    np.random.seed(derived_seed)
    random.seed(derived_seed)

class RandomPairDataModule(L.LightningDataModule):
    """Lightning data module serving ``RandomPairDataset`` train/val loaders.

    The defaults reproduce the previously hard-coded settings: batch size 32,
    8 candidates per anchor and 8 loader workers.
    """

    def __init__(
        self,
        mol_ds_train,
        mol_ds_val,
        batch_size: int = 32,
        candidate_size: int = 8,
        num_workers: int = 8,
    ) -> None:
        super().__init__()
        self.mol_ds_train: cp.data.MoleculeDataset = mol_ds_train
        self.mol_ds_val: cp.data.MoleculeDataset = mol_ds_val
        self.batch_size = batch_size
        self.candidate_size = candidate_size
        self.num_workers = num_workers

        self.ds_train = None
        self.ds_val = None

    def setup(self, stage=None):
        self.ds_train = RandomPairDataset(self.mol_ds_train, self.candidate_size)
        self.ds_val = RandomPairDataset(self.mol_ds_val, self.candidate_size)

    def _make_loader(self, dataset, shuffle: bool):
        """Build a loader sharing collate, worker seeding and worker count.

        Candidates are sampled inside the workers, so the validation loader
        gets the same ``seed_worker`` initialisation as the training loader.
        """
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=RandomPairDataset.collate_function,
            worker_init_fn=seed_worker,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        assert self.ds_train is not None
        return self._make_loader(self.ds_train, shuffle=True)

    def val_dataloader(self):
        assert self.ds_val is not None
        return self._make_loader(self.ds_val, shuffle=False)

In [3]:
class RESCALInteraction(torch.nn.Module):
    """RESCAL-style bilinear scorer: ``score(h, t) = h @ R @ t^T``.

    ``R`` is a learned square ``(ndims, ndims)`` relation matrix. Dropout is
    applied to the projected head ``h @ R`` before the product with the tails.
    Every size-1 dimension is squeezed from the result. For example,
    ``(B, 1, d)`` heads against ``(B, C, d)`` tails give ``(B, C)`` when
    B, C > 1. 2-D ``(N, d)`` heads against ``(M, d)`` tails give ``(N, M)``.
    Callers that need a fixed shape should reshape the output.
    """

    def __init__(self, ndims, dropout: float = 0.3) -> None:
        super().__init__()
        # Only ``.weight`` is used, as the relation matrix R; the Linear
        # layer itself is never called.
        self.interaction_matrix = torch.nn.Linear(ndims, ndims, bias=False)
        self.head_dropout = torch.nn.Dropout(dropout)

    def forward(self, head_emb, tail_emb):
        # A leading singleton dim lets R broadcast over any batch shape.
        R = self.interaction_matrix.weight.unsqueeze(0)
        projected_head = self.head_dropout(head_emb @ R)
        return (projected_head @ tail_emb.transpose(-2, -1)).squeeze()
    

class ContrastiveMPNN(cp.models.MPNN):
    """Chemprop MPNN trained with a pairwise ranking objective.

    Each anchor embedding is scored against exemplar and random candidate
    embeddings with a ``RESCALInteraction``. BCE-with-logits on the label
    ``target_head > target_tail`` is applied in both directions
    (anchor -> candidate and candidate -> anchor), and the four losses are
    averaged.
    """

    # Metric-name suffixes for the four directional pair losses.
    _LOSS_NAMES = ("lr_exemplar_loss", "lr_random_loss", "rl_exemplar_loss", "rl_random_loss")

    def __init__(
        self,
        message_passing: cp.nn.MessagePassing,
        agg: cp.nn.Aggregation,
        predictor: cp.nn.Predictor,
        batch_norm: bool = False,
        metrics: Iterable[cp.nn.ChempropMetric] | None = None,
        warmup_epochs: int = 2,
        init_lr: float = 0.0001,
        max_lr: float = 0.001,
        final_lr: float = 0.0001,
        X_d_transform: cp.nn.transforms.ScaleTransform | None = None,
        interaction_dim: int = 300,
    ):
        super().__init__(
            message_passing,
            agg,
            predictor,
            batch_norm,
            metrics,
            warmup_epochs,
            init_lr,
            max_lr,
            final_lr,
            X_d_transform,
        )
        # Must equal the width of ``self.encoding(...)``. The default 300
        # presumably matches the predictor's hidden size; TODO confirm.
        self.interaction = RESCALInteraction(interaction_dim)
        self.loss_fn = torch.nn.BCEWithLogitsLoss()

    def embed_simple_batch(self, batch: cp.data.collate.TrainingBatch):
        """Encode a plain collated batch; returns ``{"embeds": Z, "targets": Y}``."""
        bmg, V_d, X_d, target, _, _, _ = batch
        Z = self.encoding(bmg, V_d, X_d)
        return dict(embeds=Z, targets=target)

    def _encode_collated(self, batch: cp.data.collate.TrainingBatch):
        """Encode one collated ``TrainingBatch``; returns ``(embeddings, targets)``."""
        bmg, V_d, X_d, targets, _, _, _ = batch
        return self.encoding(bmg, V_d, X_d), targets

    def _ranking_loss(self, Z_head, Z_tail, target_head, target_tail, B, C):
        """BCE-with-logits for "head target > tail target" over a (B, C) grid of pairs."""
        # Reshape instead of squeeze. ``interaction`` squeezes every size-1
        # dim, which drops the batch axis when B == 1 (or the candidate axis
        # when C == 1). The logits would then no longer match the (B, C) labels.
        logits = self.interaction(Z_head, Z_tail).reshape(B, C)
        labels = (target_head > target_tail).float()  # (B, 1) vs (B, C) broadcasts to (B, C)
        return self.loss_fn(logits, labels)

    def get_losses(self, batch: RandomPairTrainBatch):
        """Compute the averaged ranking loss and its four directional components.

        Returns:
            ``(loss, (lr_exemplar, lr_random, rl_exemplar, rl_random))``. "lr"
            scores anchor -> candidate; "rl" scores candidate -> anchor.
        """
        B, C = batch.B, batch.C

        # Encode in the same order as before (anchor, exemplar, random).
        Z_anchor, target_anchor = self._encode_collated(batch.anchor)
        Z_exemplar, target_exemplar = self._encode_collated(batch.exemplar)
        Z_random, target_random = self._encode_collated(batch.random)

        Z_anchor = Z_anchor.view(B, 1, -1)          # (B, d) -> (B, 1, d)
        Z_exemplar = Z_exemplar.view(B, C, -1)      # (B*C, d) -> (B, C, d)
        Z_random = Z_random.view(B, C, -1)          # (B*C, d) -> (B, C, d)

        target_anchor = target_anchor.view(B, 1)
        target_exemplar = target_exemplar.view(B, C)
        target_random = target_random.view(B, C)

        # Left to right: does the anchor outrank each candidate?
        lr_exemplar_loss = self._ranking_loss(Z_anchor, Z_exemplar, target_anchor, target_exemplar, B, C)
        lr_random_loss = self._ranking_loss(Z_anchor, Z_random, target_anchor, target_random, B, C)

        # Right to left: does each candidate outrank the anchor?
        rl_exemplar_loss = self._ranking_loss(Z_exemplar, Z_anchor, target_exemplar, target_anchor, B, C)
        rl_random_loss = self._ranking_loss(Z_random, Z_anchor, target_random, target_anchor, B, C)

        loss = (lr_exemplar_loss + lr_random_loss + rl_exemplar_loss + rl_random_loss) / 4
        return loss, (lr_exemplar_loss, lr_random_loss, rl_exemplar_loss, rl_random_loss)

    def _log_loss_components(self, prefix, components, batch_size):
        """Log the four directional losses as ``{prefix}_{name}``, also aggregated per epoch."""
        # Logged detached (Lightning's default). Keeping the graph attached
        # while accumulating the epoch value would retain each step's
        # autograd graph.
        for name, value in zip(self._LOSS_NAMES, components):
            self.log(f"{prefix}_{name}", value, batch_size=batch_size, on_epoch=True)

    def training_step(self, batch: RandomPairTrainBatch, batch_idx):  # type: ignore
        loss, components = self.get_losses(batch)
        self._log_loss_components("train", components, batch.B)
        self.log("train_loss", loss, batch_size=batch.B, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch: RandomPairTrainBatch, batch_idx):  # type: ignore
        loss, components = self.get_losses(batch)
        self._log_loss_components("val", components, batch.B)
        self.log("val_loss", loss, batch_size=batch.B)
        return loss

In [20]:
def get_molecule_datapoint(row):
    """Convert one split-frame row into a chemprop ``MoleculeDatapoint``.

    * ``mol``: ``row['mol']``.
    * ``y``: the single target ``row['per_inhibition']`` as a 1-element array.
    * ``x_d``: every column whose name starts with ``'feat'``, coerced to
      numeric. Unparsable values become NaN.
    """
    descriptor_cols = [col for col in row.index if col.startswith('feat')]
    descriptors = pd.to_numeric(row[descriptor_cols], errors="coerce").to_numpy()
    target = np.array([row['per_inhibition']])
    return cp.data.MoleculeDatapoint(mol=row['mol'], y=target, x_d=descriptors)


@torch.no_grad()
def embed_all(mol_dataset: cp.data.datasets.MoleculeDataset, contrastive_mpnn, batch_size: int = 64):
    """Embed every molecule of ``mol_dataset``, in order.

    Runs ``contrastive_mpnn.embed_simple_batch`` over unshuffled batches on the
    model's device and concatenates the ``"embeds"`` outputs. The model is put
    in eval mode for the duration of the call, so dropout and batch norm
    behave deterministically. Its previous train/eval mode is restored
    afterwards, even on error.
    """
    loader = torch.utils.data.DataLoader(
        mol_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=cp.data.dataloader.collate_batch,
    )
    was_training = contrastive_mpnn.training
    contrastive_mpnn.eval()
    try:
        chunks = [
            contrastive_mpnn.embed_simple_batch(
                move_data_to_device(batch, contrastive_mpnn.device)
            )["embeds"]
            for batch in loader
        ]
    finally:
        contrastive_mpnn.train(was_training)
    return torch.cat(chunks)


def _prepare_split_frame(df: pd.DataFrame, n_samples, random_state) -> pd.DataFrame:
    """Optionally subsample ``df``, reset its index and unpickle ``mol_ser`` into a ``mol`` column."""
    if n_samples is not None:
        # Cap at the fold size so small folds do not make ``sample`` raise.
        df = df.sample(min(n_samples, len(df)), random_state=random_state)
    df = df.reset_index(drop=True)
    # SECURITY: pickle.loads executes arbitrary code embedded in the data.
    # Only use this on split files you generated or otherwise trust.
    return df.assign(mol=df["mol_ser"].map(pickle.loads))


def _to_mol_dataset(df: pd.DataFrame, featurizer):
    """Wrap each row of ``df`` as a datapoint in a chemprop ``MoleculeDataset``."""
    datapoints = [get_molecule_datapoint(row) for _, row in df.iterrows()]
    return cp.data.MoleculeDataset(datapoints, featurizer=featurizer)


def _build_contrastive_mpnn(n_extra_features: int, x_d_scaler):
    """Assemble a ContrastiveMPNN (bond message passing, norm aggregation, binary FFN head)."""
    mp = cp.nn.BondMessagePassing()
    agg = cp.nn.NormAggregation()
    ffn = cp.nn.BinaryClassificationFFN(
        n_tasks=1,
        # Message-passing width plus one input per extra-descriptor column.
        input_dim=mp.output_dim + n_extra_features,
        activation=torch.nn.ELU(),
    )
    metric_list = [cp.nn.metrics.BinaryF1Score(), cp.nn.metrics.BinaryAUPRC(), cp.nn.metrics.BinaryAUROC()]
    X_d_transform = cp.nn.ScaleTransform.from_standard_scaler(x_d_scaler)
    return ContrastiveMPNN(mp, agg, ffn, batch_norm=True, metrics=metric_list, X_d_transform=X_d_transform)


def _fit_restore_best(model, datamodule, max_epochs: int) -> None:
    """Train ``model`` with early stopping on ``val_loss`` and load back its best checkpoint.

    Checkpoints are written under a temporary ``default_root_dir`` that is
    deleted afterwards, so repeated calls do not pile files up in the CWD.
    If no checkpoint was saved, the final weights are kept.
    """
    with TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
        trainer = L.Trainer(
            logger=None,
            enable_checkpointing=True,
            enable_progress_bar=True,
            accelerator="auto",
            devices=1,
            max_epochs=max_epochs,
            default_root_dir=tmpdir,
            callbacks=[
                EarlyStopping(monitor="val_loss", mode="min", verbose=True, patience=10),
                checkpoint_cb,
            ],
        )
        trainer.fit(model, datamodule=datamodule)
        best_path = checkpoint_cb.best_model_path
        if best_path:
            # Load onto CPU so this also works on CPU-only machines.
            # load_state_dict copies into the parameters wherever they live.
            ckpt = torch.load(best_path, map_location="cpu", weights_only=False)
            model.load_state_dict(ckpt["state_dict"])


def _classification_scores(labels, preds, pred_probs) -> dict:
    """Compute sklearn classification/ranking metrics for binary ``labels``.

    ROC AUC and average precision are undefined when ``labels`` contains a
    single class (sklearn raises for ROC AUC). They are reported as NaN in
    that case, so a degenerate fold does not abort the whole CV loop.
    """
    both_classes = np.unique(labels).size == 2
    return {
        "accuracy": accuracy_score(labels, preds),
        "balanced_accuracy": balanced_accuracy_score(labels, preds),
        "f1_score": f1_score(labels, preds, zero_division=0),
        "precision": precision_score(labels, preds, zero_division=0),
        "recall": recall_score(labels, preds, zero_division=0),
        "roc_auc": roc_auc_score(labels, pred_probs) if both_classes else float("nan"),
        "average_precision": average_precision_score(labels, pred_probs) if both_classes else float("nan"),
    }


def evaluate_on_split(
    df_train,
    df_val,
    df_test,
    n_samples: int | None = 100,
    random_state: int | None = None,
    max_epochs: int = 1,
):
    """Train a ContrastiveMPNN on one CV split and score it on the test fold.

    Steps:

    1. Subsample each fold (unless ``n_samples`` is None) and unpickle its
       molecules.
    2. Build chemprop datasets whose extra descriptors are standardised with
       the training-fold scaler.
    3. Train with early stopping and restore the best checkpoint.
    4. Predict each test molecule's hit probability as its mean sigmoid
       interaction score against all training-fold hits (the exemplars).

    Args:
        df_train, df_val, df_test: split frames with ``mol_ser``,
            ``per_inhibition`` and ``feat*`` columns.
        n_samples: rows to sample per fold; None uses the full folds.
        random_state: seed for the row sampling. None uses NumPy's global RNG.
        max_epochs: training epochs. TODO(review): 1 epoch with patience-10
            early stopping looks like a debugging setting.

    Returns:
        Dict of sklearn metrics keyed by metric name.

    Raises:
        ValueError: if the (subsampled) training fold contains no hits.
    """
    # Hit cutoff on per_inhibition. It must match RandomPairDataset's
    # exemplar cutoff (> 50).
    hit_threshold = 50.0

    df_train, df_val, df_test = (
        _prepare_split_frame(df, n_samples, random_state) for df in (df_train, df_val, df_test)
    )

    featurizer = cp.featurizers.SimpleMoleculeMolGraphFeaturizer()
    train_mol_dataset = _to_mol_dataset(df_train, featurizer)
    val_mol_dataset = _to_mol_dataset(df_val, featurizer)
    test_mol_dataset = _to_mol_dataset(df_test, featurizer)

    # Standardise extra descriptors with statistics from the training fold only.
    x_d_scaler = train_mol_dataset.normalize_inputs("X_d")
    val_mol_dataset.normalize_inputs("X_d", x_d_scaler)
    test_mol_dataset.normalize_inputs("X_d", x_d_scaler)

    for mol_dataset in (train_mol_dataset, val_mol_dataset, test_mol_dataset):
        mol_dataset.cache = True

    n_extra_features = sum(col.startswith("feat") for col in df_train.columns)
    contrastive_mpnn = _build_contrastive_mpnn(n_extra_features, x_d_scaler)
    _fit_restore_best(
        contrastive_mpnn,
        RandomPairDataModule(train_mol_dataset, val_mol_dataset),
        max_epochs,
    )

    # Inference must not use training-mode dropout (interaction head) or
    # batch-norm batch statistics.
    contrastive_mpnn.eval()
    train_embeds = embed_all(train_mol_dataset, contrastive_mpnn)
    test_embeds = embed_all(test_mol_dataset, contrastive_mpnn)

    train_targets = np.asarray(train_mol_dataset.Y).reshape(-1)
    exemplar_idxs = np.flatnonzero(train_targets > hit_threshold)
    if exemplar_idxs.size == 0:
        raise ValueError(
            f"training fold has no hits (per_inhibition > {hit_threshold}) to use as exemplars"
        )
    exemplar_embeds = train_embeds[torch.as_tensor(exemplar_idxs, device=train_embeds.device)]

    with torch.no_grad():
        logits = contrastive_mpnn.interaction(test_embeds, exemplar_embeds)
        # interaction() squeezes every size-1 dim. Restore
        # (n_test, n_exemplars) so the mean is always taken over exemplars,
        # even with a single exemplar or a single test molecule.
        probs = logits.sigmoid().reshape(test_embeds.shape[0], exemplar_embeds.shape[0])
        pred_probs = probs.mean(dim=-1).cpu().numpy()

    preds = (pred_probs >= 0.5).astype(int)
    labels = (df_test["per_inhibition"] > hit_threshold).to_numpy().astype(int)
    return _classification_scores(labels, preds, pred_probs)


In [None]:
run = wandb.init(project="evaluation")
wandb.mark_preempting()

# Split files are named like ".../split_<outer>x<inner>.parquet". \d+ (not \d)
# so indices >= 10 still parse; the greedy ".*" anchors on the last "split_".
SPLIT_NAME_PATTERN = re.compile(r".*split_(?P<outer>\d+)x(?P<inner>\d+)")

cross_val_results = []
# sorted() so splits run (and are recorded) in a deterministic order.
for split_fpath in tqdm(sorted(glob("./generated_splits/*.parquet"))):
    matches = SPLIT_NAME_PATTERN.match(split_fpath)
    assert matches is not None, split_fpath
    outer_idx, inner_idx = int(matches["outer"]), int(matches["inner"])

    total_split_df = pd.read_parquet(split_fpath).drop("index", axis=1)
    df_train, df_val, df_test = (
        total_split_df[total_split_df["split"] == split_name].drop("split", axis=1)
        for split_name in ("train", "val", "test")
    )

    # Reseed per split so each fold's subsampling, candidate sampling and
    # weight init are reproducible regardless of which splits ran before.
    set_seeds(RANDOM_SEED)
    scores = evaluate_on_split(df_train, df_val, df_test)
    cross_val_results.append(scores | {"outer": outer_idx, "inner": inner_idx})

    # tqdm.write keeps the progress bar intact.
    tqdm.write(f"completed_{outer_idx}x{inner_idx}")
    tqdm.write('---------------------------------------------------------------')
    print('---------------------------------------------------------------')

In [None]:
# New name (instead of rebinding the list) keeps this cell idempotent on re-run.
cv_results_df = pd.DataFrame.from_records(cross_val_results)
cv_results_df["model"] = "baseline"
run.log({"Cross Val Results": wandb.Table(dataframe=cv_results_df)})

# Aggregate only the metric columns, per model.
metrics_by_model = cv_results_df.drop(["outer", "inner"], axis=1).groupby("model")

# add_prefix before reset_index so the grouping column stays "model".
mean_scores = metrics_by_model.mean().add_prefix("mean_").reset_index()
run.log({"Mean Results": wandb.Table(dataframe=mean_scores)})

std_scores = metrics_by_model.std().add_prefix("std_").reset_index()
run.log({"Std Results": wandb.Table(dataframe=std_scores)})

In [None]:
# Mark the W&B run opened by wandb.init() in the evaluation cell as finished.
wandb.finish()