In [None]:
import sys

sys.path.append('../')

import os
import random
import pandas as pd
import lightning as L
import numpy as np
import torch
from chemprop import data, featurizers, models
from chemprop import nn as chem_nn
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from pytorch_lightning.utilities import move_data_to_device
import pandas as pd
import rdkit.Chem as Chem
from rdkit.Chem.Descriptors import CalcMolDescriptors
from rdkit.rdBase import BlockLogs
from sklearn.model_selection import GroupShuffleSplit
from commons.utils import get_scaffold, standardize
from typing import NamedTuple
from itertools import chain

import wandb
# from commons.data import load_and_split_gsk_dataset

RANDOM_SEED = 42


def set_seeds(seed):
    """Seed every random number generator this notebook draws from.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator and PyTorch's
    CUDA generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_seeds(RANDOM_SEED)

# load_dotenv('.env.secret')
# wandb.login(key=os.environ["WANDB_API_KEY"])  # read the key from the environment; the literal key that used to be pasted here must be revoked

In [None]:
def mol_to_inchi(mol):
    """Return ``Chem.MolToInchi(mol)``, computed inside a ``BlockLogs`` context."""
    with BlockLogs():
        inchi = Chem.MolToInchi(mol)
    return inchi
    

def generate_features(df):
    """Append RDKit descriptor columns to ``df``.

    ``CalcMolDescriptors`` is evaluated on every entry of ``df["mol"]``; each
    resulting descriptor becomes a column named ``feat_<descriptor>``.

    Returns a new frame: ``df`` with its index reset, followed column-wise by
    the descriptor columns.
    """
    with BlockLogs():
        descriptor_records = [CalcMolDescriptors(mol) for mol in df["mol"]]
        descriptor_frame = pd.DataFrame.from_records(descriptor_records).add_prefix("feat_")
        # Both frames carry a fresh 0..n-1 index, so the column-wise concat lines rows up by position.
        base_frame = df.reset_index(drop=True)
        df = pd.concat([base_frame, descriptor_frame], axis=1)

    return df


def load_and_split_gsk_dataset(path, RANDOM_SEED):
    """Load the GSK CSV, clean it, featurize it and split it by scaffold group.

    Steps, in order:

    1. Read ``path`` and drop its first column; the two remaining columns are
       renamed ``smiles`` and ``per_inhibition``.
    2. Standardize every SMILES with ``standardize`` and drop rows for which it
       returned a missing value.
    3. Compute an InChI per molecule and drop *every* row whose InChI occurs
       more than once (no representative is kept).
    4. Add the boolean ``is_cytotoxic`` column (``per_inhibition > 50.0``) and
       the ``feat_*`` descriptor columns from ``generate_features``.
    5. Group molecules by ``get_scaffold`` of their SMILES and split with
       ``GroupShuffleSplit`` so a scaffold group never straddles two subsets:
       first train vs. rest (splitter default test size), then the rest is
       halved into validation and test.

    Args:
        path: CSV file to read.
        RANDOM_SEED: ``random_state`` passed to both ``GroupShuffleSplit`` calls.

    Returns:
        ``(df_train, df_val, df_test)``, each with a fresh 0..n-1 index and
        without the ``smiles`` and ``inchi`` columns.
    """
    df = pd.read_csv(path)
    df = df.iloc[:, 1:]
    df.columns = ["smiles", "per_inhibition"]

    # standardize and convert to inchi
    df["mol"] = df["smiles"].map(standardize)
    # .copy() detaches the filtered rows so the column assignment below writes to an
    # independent frame instead of a possible view of the unfiltered one.
    df = df.dropna(subset=["mol"]).copy()
    df["inchi"] = df["mol"].map(mol_to_inchi)
    df = df.groupby(["inchi"]).filter(lambda x: len(x) == 1).reset_index(drop=True)

    df["is_cytotoxic"] = df["per_inhibition"] > 50.0

    df = generate_features(df)

    clusters, _ = pd.factorize(
        df["mol"]
        .map(Chem.MolToSmiles)  # type: ignore
        .map(get_scaffold)
    )
    clusters = pd.Series(clusters)

    df = df.drop(["smiles", "inchi"], axis=1)

    # GroupShuffleSplit yields *positional* row indices, so select with .iloc; the
    # result then does not depend on what labels the frame's index happens to carry.
    splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED)
    train_idxs, val_test_idxs = next(splitter.split(df, groups=clusters))
    df_train = df.iloc[train_idxs].reset_index(drop=True)
    df_val_test = df.iloc[val_test_idxs].reset_index(drop=True)
    clusters_val_test = clusters.iloc[val_test_idxs].reset_index(drop=True)

    splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED, test_size=0.5)
    val_idxs, test_idxs = next(splitter.split(df_val_test, groups=clusters_val_test))
    df_val = df_val_test.iloc[val_idxs].reset_index(drop=True)
    df_test = df_val_test.iloc[test_idxs].reset_index(drop=True)

    return df_train, df_val, df_test

In [3]:
# Build the train / validation / test frames from the GSK HepG2 CSV, using the notebook-wide seed for the split.
df_train, df_val, df_test = load_and_split_gsk_dataset("../GSK_HepG2.csv", RANDOM_SEED)

In [4]:
def get_molecule_datapoint(row):
    """Turn one dataframe row into a ``data.MoleculeDatapoint``.

    The molecule comes from ``row['mol']``, the single target from
    ``row['per_inhibition']``, and the extra descriptors from every entry whose
    name starts with ``feat`` (values that cannot be parsed as numbers become NaN).
    """
    descriptor_names = [name for name in row.index if name.startswith('feat')]
    molecule = row['mol']
    target = np.array([row['per_inhibition']])
    descriptors = pd.to_numeric(row[descriptor_names], errors="coerce").to_numpy()
    return data.MoleculeDatapoint(mol=molecule, y=target, x_d=descriptors)

# One featurizer instance is shared by the three datasets.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
# Each dataframe row becomes one MoleculeDatapoint (apply(..., axis=1) returns a Series of datapoints).
train_mol_dataset = data.MoleculeDataset(df_train.apply(get_molecule_datapoint, axis=1), featurizer=featurizer)
val_mol_dataset = data.MoleculeDataset(df_val.apply(get_molecule_datapoint, axis=1), featurizer=featurizer)
test_mol_dataset = data.MoleculeDataset(df_test.apply(get_molecule_datapoint, axis=1), featurizer=featurizer)

# The scaler returned for the training set is reused for validation and test,
# so all three datasets hold descriptors scaled with the training statistics.
# NOTE(review): the same scaler is later wrapped in an X_d_transform and given to the model;
# confirm the descriptors are not scaled a second time when the model runs in eval mode.
x_d_scaler = train_mol_dataset.normalize_inputs("X_d")
val_mol_dataset.normalize_inputs("X_d", x_d_scaler)
test_mol_dataset.normalize_inputs("X_d", x_d_scaler)

# Presumably makes the datasets keep featurized molecules between accesses; verify against chemprop.
train_mol_dataset.cache = True
val_mol_dataset.cache = True
test_mol_dataset.cache = True

In [5]:
class RandomPairDataPoint(NamedTuple):
    """One item produced by ``RandomPairDataset.__getitem__``: an anchor and two candidate lists."""

    # Datum stored at the requested index of the wrapped molecule dataset.
    anchor: data.datasets.Datum
    # Result of RandomPairDataset.get_exemplar_candidates().
    exemplar: list[data.datasets.Datum]
    # Result of RandomPairDataset.get_random_candidates().
    random: list[data.datasets.Datum]


class RandomPairTrainBatch(NamedTuple):
    """Collated batch built by ``RandomPairDataset.collate_function``."""

    # Collated anchors, one per item in the batch.
    anchor: data.collate.TrainingBatch
    # Collated exemplar candidates of all items, flattened item after item.
    exemplar: data.collate.TrainingBatch
    # Collated random candidates of all items, flattened item after item.
    random: data.collate.TrainingBatch
    # Number of items (anchors) in the batch.
    B: int
    # Number of candidates per anchor (taken from the first item's random candidates).
    C: int
    

class RandomPairDataset(Dataset):
    """Wraps a molecule dataset so every item is an anchor plus sampled comparison partners.

    For index ``i`` the item contains the anchor ``mol_dataset[i]``,
    ``n_candidates`` "exemplar" molecules (targets strictly above ``threshold``)
    and ``n_candidates`` molecules drawn uniformly from the whole dataset.
    Candidates are re-drawn from NumPy's global generator on every access.

    Args:
        mol_dataset: dataset exposing ``Y``, ``__len__`` and integer ``__getitem__``.
            A single target per molecule is assumed.
        n_candidates: number of exemplar and of random candidates per anchor.
        threshold: target value above which a molecule counts as an exemplar.
    """

    def __init__(self, mol_dataset, n_candidates, threshold: float = 50.0):
        super().__init__()
        self.mol_dataset: data.datasets.MoleculeDataset = mol_dataset
        self.n_candidates: int = n_candidates
        self.threshold: float = threshold

    def __len__(self):
        return len(self.mol_dataset)

    def _flat_targets(self):
        """Targets of the wrapped dataset as a 1-D array (stays 1-D for a single molecule)."""
        return np.asarray(self.mol_dataset.Y).reshape(-1)

    def _draw(self, pool):
        """Pick ``n_candidates`` dataset indices from ``pool`` and return their datums.

        Sampling is without replacement whenever the pool is large enough;
        a smaller pool falls back to sampling with replacement so the number
        of candidates per anchor (and therefore the batch layout) stays fixed.
        """
        with_replacement = len(pool) < self.n_candidates
        picked = np.random.choice(
            pool,
            size=(self.n_candidates,),
            replace=with_replacement,
        )
        return [self.mol_dataset[int(idx)] for idx in picked]

    def get_exemplar_candidates(self):
        """Sample candidates uniformly among molecules whose target exceeds ``threshold``.

        Raises:
            ValueError: if no molecule in the dataset exceeds the threshold.
        """
        exemplar_pool = np.flatnonzero(self._flat_targets() > self.threshold)
        if exemplar_pool.size == 0:
            raise ValueError(
                f"no molecule has a target above {self.threshold}; cannot sample exemplars"
            )
        return self._draw(exemplar_pool)

    def get_random_candidates(self):
        """Sample candidates uniformly among all molecules of the dataset."""
        return self._draw(np.arange(self._flat_targets().shape[0]))

    def __getitem__(self, idx) -> RandomPairDataPoint:
        return RandomPairDataPoint(
            self.mol_dataset[idx],
            self.get_exemplar_candidates(),
            self.get_random_candidates(),
        )

    @staticmethod
    def collate_function(batch):
        """Collate a list of ``RandomPairDataPoint`` into a ``RandomPairTrainBatch``.

        Anchors are collated as one batch of ``B`` molecules. The exemplar and
        random candidates of all items are flattened (item after item) and each
        collated as one batch of ``B * C`` molecules.
        """
        batch_anchors, batch_exemplars, batch_candidates = zip(*batch)
        B = len(batch)
        C = len(batch_candidates[0])
        batch_anchors = data.dataloader.collate_batch(batch_anchors)
        batch_exemplars = data.dataloader.collate_batch(chain.from_iterable(batch_exemplars))
        batch_candidates = data.dataloader.collate_batch(chain.from_iterable(batch_candidates))
        return RandomPairTrainBatch(batch_anchors, batch_exemplars, batch_candidates, B, C)

In [6]:
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` from torch's initial seed.

    The seed is ``torch.initial_seed()`` reduced modulo 2**32; ``worker_id`` itself is not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(derived_seed)

class RandomPairDataModule(L.LightningDataModule):
    """Lightning data module serving ``RandomPairDataset`` batches for training and validation.

    Args:
        mol_ds_train: molecule dataset used for training.
        mol_ds_val: molecule dataset used for validation.
        batch_size: number of anchors per batch.
        candidate_size: number of exemplar / random candidates sampled per anchor.
        num_workers: worker processes per DataLoader.
    """

    def __init__(
        self,
        mol_ds_train,
        mol_ds_val,
        batch_size: int = 32,
        candidate_size: int = 8,
        num_workers: int = 8,
    ) -> None:
        super().__init__()
        self.mol_ds_train: data.MoleculeDataset = mol_ds_train
        self.mol_ds_val: data.MoleculeDataset = mol_ds_val
        self.batch_size = batch_size
        self.candidate_size = candidate_size
        self.num_workers = num_workers

        # Populated by setup().
        self.ds_train = None
        self.ds_val = None

    def setup(self, stage=None):
        """Wrap both molecule datasets in ``RandomPairDataset``."""
        self.ds_train = RandomPairDataset(self.mol_ds_train, self.candidate_size)
        self.ds_val = RandomPairDataset(self.mol_ds_val, self.candidate_size)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader; both splits share collation, worker count and worker seeding."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=RandomPairDataset.collate_function,
            # Candidates are drawn with np.random inside the workers, so every
            # worker needs its own NumPy seed - for validation as well as training.
            worker_init_fn=seed_worker,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        assert self.ds_train is not None
        return self._make_loader(self.ds_train, shuffle=True)

    def val_dataloader(self):
        # NOTE(review): validation candidates are re-sampled on every pass, so val_loss
        # is measured on different pairs each epoch; confirm this is intended.
        assert self.ds_val is not None
        return self._make_loader(self.ds_val, shuffle=False)

In [None]:
from typing import Any, Iterable
from chemprop.nn import Aggregation, ChempropMetric, MessagePassing, Predictor
from chemprop.nn.transforms import ScaleTransform
import pytorch_lightning as pl


class RESCALInteraction(torch.nn.Module):
    """Bilinear scorer: ``dropout(head @ R) @ tail^T`` with a learned square matrix ``R``.

    ``R`` is the weight of a bias-free ``ndims x ndims`` linear layer and is
    used directly as a matrix (the layer itself is never called).
    """

    def __init__(self, ndims) -> None:
        super().__init__()
        self.interaction_matrix = torch.nn.Linear(ndims, ndims, bias=False)
        self.head_dropout = torch.nn.Dropout(0.3)

    def forward(self, head_emb, tail_emb):
        """Score every head against every tail along the last two axes.

        All size-1 axes are squeezed out of the result.
        """
        # A leading axis is added so R broadcasts over any batch axis of the heads.
        relation = self.interaction_matrix.weight[None]
        projected_heads = self.head_dropout(torch.matmul(head_emb, relation))
        scores = torch.matmul(projected_heads, tail_emb.transpose(-2, -1))
        # z = (self.head_dropout(head_emb @ R) - tail_emb.transpose(-2, -1)).sigmoid?
        return scores.squeeze()


class ProjEInteraction(torch.nn.Module):
    """ProjE-style scorer: ``relu(d_h * head + d_r * r + b) @ tail^T + b_p``.

    ``d_h``, ``d_r``, ``b`` and ``r`` are four learned vectors stored as rows
    0, 1, 2 and 3 of one embedding table; ``b_p`` is a learned scalar bias.

    Args:
        ndims: width of the head/tail embeddings. Defaults to 300, the width
            that used to be hard-coded here.
        *args, **kwargs: forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, ndims: int = 300, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embs = torch.nn.Embedding(4, ndims)
        # Constant row indices are buffers rather than parameters: they follow the
        # module across devices and stay in the state dict under the same keys,
        # but are no longer reported as (frozen) model parameters.
        self.register_buffer("dh_emb_idx", torch.tensor([0], dtype=torch.long))
        self.register_buffer("dr_emb_idx", torch.tensor([1], dtype=torch.long))
        self.register_buffer("b_emb_idx", torch.tensor([2], dtype=torch.long))
        self.register_buffer("r_emb_idx", torch.tensor([3], dtype=torch.long))
        self.b_p = torch.nn.Parameter(torch.rand(1))

    def forward(self, head_emb, tail_emb):
        """Score heads against tails along the last two axes; size-1 axes are squeezed out."""
        dh_emb = self.embs(self.dh_emb_idx)
        dr_emb = self.embs(self.dr_emb_idx)
        r_emb = self.embs(self.r_emb_idx)
        b_emb = self.embs(self.b_emb_idx)

        combined = torch.nn.functional.relu(dh_emb * head_emb + dr_emb * r_emb + b_emb)
        scores = combined @ tail_emb.transpose(-2, -1) + self.b_p
        return scores.squeeze()


class NNInteraction(torch.nn.Module):
    """MLP scorer applied to the concatenation of a head and a tail embedding.

    The network maps ``2 * ndims`` inputs through one ReLU hidden layer of the
    same width to a single output.
    """

    def __init__(self, ndims) -> None:
        super().__init__()
        pair_width = 2 * ndims
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(pair_width, pair_width),
            torch.nn.ReLU(),
            torch.nn.Linear(pair_width, 1),
        )

    def forward(self, head_emb, tail_emb):
        """Score head/tail pairs; size-1 axes are squeezed out of the result.

        Whichever side holds a single embedding per row is expanded to the
        shape of the other side before the two are concatenated.
        """
        head_is_single = head_emb.shape[-2] == 1
        if head_is_single:
            heads, tails = head_emb.expand_as(tail_emb), tail_emb
        else:
            heads, tails = head_emb, tail_emb.expand_as(head_emb)

        pair_features = torch.cat((heads, tails), dim=-1)  # (B, C, 2*d)
        scores = self.nn(pair_features)
        return scores.squeeze()


class ContrastiveMPNN(models.MPNN):
    """chemprop ``MPNN`` trained with a pairwise "which molecule has the larger target" objective.

    Molecules are embedded with ``self.encoding`` and an interaction module
    scores ordered pairs ``(head, tail)``. The score is trained, with binary
    cross-entropy on logits, to be positive when the head's target is larger
    than the tail's. Every anchor is compared with its exemplar candidates and
    with its random candidates, in both directions, and a symmetry penalty
    pushes ``score(a, b) + score(b, a)`` towards zero.

    Constructor arguments are forwarded unchanged to ``models.MPNN``.
    """

    # Names of the six loss components, in the order ``get_losses`` returns them.
    # They are prefixed with "train_" / "val_" to form the logged metric names.
    _LOSS_COMPONENT_NAMES = (
        "exemplar_sym_loss",
        "random_sym_loss",
        "lr_exemplar_loss",
        "lr_random_loss",
        "rl_exemplar_loss",
        "rl_random_loss",
    )

    def __init__(
        self,
        message_passing: MessagePassing,
        agg: Aggregation,
        predictor: Predictor,
        batch_norm: bool = False,
        metrics: Iterable[ChempropMetric] | None = None,
        warmup_epochs: int = 2,
        init_lr: float = 0.0001,
        max_lr: float = 0.001,
        final_lr: float = 0.0001,
        X_d_transform: ScaleTransform | None = None,
    ):
        super().__init__(
            message_passing,
            agg,
            predictor,
            batch_norm,
            metrics,
            warmup_epochs,
            init_lr,
            max_lr,
            final_lr,
            X_d_transform,
        )

        # NOTE(review): 300 must equal the width of the vectors returned by
        # `self.encoding(...)`; it is presumably the predictor's hidden size - confirm
        # before changing the predictor configuration.
        self.interaction = RESCALInteraction(300)
        self.loss_fn = torch.nn.BCEWithLogitsLoss()

    def _encode_batch(self, batch: data.collate.TrainingBatch):
        """Return ``(embeddings, targets)`` for one collated chemprop batch."""
        bmg, V_d, X_d, targets, *_ = batch
        return self.encoding(bmg, V_d, X_d), targets

    def embed_simple_batch(self, batch: data.collate.TrainingBatch):
        """Embed a plain collated batch; returns ``dict(embeds=..., targets=...)``."""
        Z, target = self._encode_batch(batch)
        return dict(embeds=Z, targets=target)

    def bidirectional_interaction_loss(
        self, Z_left, Z_right, target_left, target_right
    ):
        """Compute the symmetry penalty and the two directional BCE losses.

        Args:
            Z_left, Z_right: embeddings whose pairwise scores broadcast to the
                shape of ``target_left > target_right``.
            target_left, target_right: targets matching the embeddings.

        Returns:
            ``(symm_loss, lr_loss, rl_loss)`` where ``lr`` scores left-as-head
            (label: left target is larger) and ``rl`` scores right-as-head
            (label: left target is not larger).
        """
        # left to right loss
        lr_labels = (target_left > target_right).float()  # type: ignore
        # The interaction modules squeeze *every* size-1 axis, which would also drop
        # the batch or candidate axis when B == 1 or C == 1. Reshaping to the label
        # shape keeps logits and labels aligned in those cases too.
        lr_interaction = self.interaction(Z_left, Z_right).reshape(lr_labels.shape)
        lr_loss = self.loss_fn(lr_interaction, lr_labels)

        # right to left loss
        rl_labels = (target_left <= target_right).float()  # type: ignore
        rl_interaction = self.interaction(Z_right, Z_left).reshape(rl_labels.shape)
        rl_loss = self.loss_fn(rl_interaction, rl_labels)

        # Opposite directions should give opposite logits, so their sum is penalised.
        delta = (lr_interaction + rl_interaction) ** 2
        symm_loss = delta.sum(dim=-1).mean()

        return symm_loss, lr_loss, rl_loss

    def get_losses(self, batch: RandomPairTrainBatch):
        """Compute the total loss and its six components for one pair batch.

        Returns:
            ``(loss, components)`` where ``loss`` is the unweighted mean of the
            six components and ``components`` is the tuple
            ``(exemplar_sym, random_sym, lr_exemplar, lr_random, rl_exemplar, rl_random)``.
        """
        B, C = batch.B, batch.C

        # Encoding order (anchor, exemplar, random) is kept fixed because each
        # pass goes through the same encoder modules.
        Z_anchor, target_anchor = self._encode_batch(batch.anchor)
        Z_exemplar, target_exemplar = self._encode_batch(batch.exemplar)
        Z_random, target_random = self._encode_batch(batch.random)

        Z_anchor = Z_anchor.view(B, 1, -1)  # (B, d) -> (B, 1, d)
        Z_exemplar = Z_exemplar.view(B, C, -1)  # (B*C, d) -> (B, C, d)
        Z_random = Z_random.view(B, C, -1)  # (B*C, d) -> (B, C, d)

        target_anchor = target_anchor.view(-1, 1)
        target_exemplar = target_exemplar.view(B, C)
        target_random = target_random.view(B, C)

        (exemplar_sym_loss, lr_exemplar_loss, rl_exemplar_loss) = (
            self.bidirectional_interaction_loss(
                Z_anchor, Z_exemplar, target_anchor, target_exemplar
            )
        )

        (random_sym_loss, lr_random_loss, rl_random_loss) = (
            self.bidirectional_interaction_loss(
                Z_anchor, Z_random, target_anchor, target_random
            )
        )

        components = (
            exemplar_sym_loss,
            random_sym_loss,
            lr_exemplar_loss,
            lr_random_loss,
            rl_exemplar_loss,
            rl_random_loss,
        )
        loss = sum(components) / 6

        return loss, components

    def _log_losses(self, stage: str, batch: RandomPairTrainBatch, loss, components) -> None:
        """Log the six components and the total loss under ``<stage>_<name>``.

        The values are logged without ``enable_graph``: they are only reported,
        never back-propagated through, so there is no reason to keep their
        autograd graphs alive while they are aggregated over the epoch.
        """
        for name, value in zip(self._LOSS_COMPONENT_NAMES, components):
            self.log(f"{stage}_{name}", value, batch_size=batch.B, on_epoch=True)
        self.log(f"{stage}_loss", loss, batch_size=batch.B, prog_bar=True, on_epoch=True)

    def training_step(self, batch: RandomPairTrainBatch, batch_idx):  # type: ignore
        """Compute, log (``train_*``) and return the loss for one training batch."""
        loss, components = self.get_losses(batch)
        self._log_losses("train", batch, loss, components)
        return loss

    def validation_step(self, batch: RandomPairTrainBatch, batch_idx):  # type: ignore
        """Compute, log (``val_*``) and return the loss for one validation batch."""
        loss, components = self.get_losses(batch)
        self._log_losses("val", batch, loss, components)
        return loss

In [8]:
# Assemble the model: message passing -> aggregation -> feed-forward predictor, wrapped in ContrastiveMPNN.
# NOTE(review): `fdims` is not used anywhere below in this cell.
fdims = featurizers.SimpleMoleculeMolGraphFeaturizer().shape # the dimensions of the featurizer, given as (atom_dims, bond_dims).
mp = chem_nn.BondMessagePassing()
agg = chem_nn.NormAggregation()
# Predictor input width = message-passing output width + number of `feat*` descriptor columns.
ffn_dims = mp.output_dim + len([f for f in df_train.columns if f.startswith("feat")])
ffn = chem_nn.BinaryClassificationFFN(n_tasks=1, input_dim=ffn_dims, activation=torch.nn.ELU(), dropout=0.3)
batch_norm = True
# NOTE(review): ContrastiveMPNN overrides training_step/validation_step and logs only its own losses;
# confirm whether these metrics are still evaluated anywhere.
metric_list = [chem_nn.metrics.BinaryF1Score(), chem_nn.metrics.BinaryAUPRC(), chem_nn.metrics.BinaryAUROC()]
# Built from the scaler that was fitted on the training descriptors.
X_d_transform = chem_nn.ScaleTransform.from_standard_scaler(x_d_scaler)
contrastive_mpnn = ContrastiveMPNN(mp, agg, ffn, batch_norm, metric_list, X_d_transform=X_d_transform)
# contrastive_mpnn.max_lr = 0.01

In [9]:
# Start from a clean W&B state before creating the logger for this training run
# (presumably closes a run left open by an earlier execution of this cell - verify).
wandb.finish()
wandb_logger = WandbLogger(project="chemprop_delta_clf", log_model="all", save_code=True)
wandb_logger.watch(contrastive_mpnn, log="gradients", log_freq=50) 
wandb_logger.experiment.mark_preempting()

# Early stopping and checkpointing both monitor `val_loss` (lower is better).
# With patience=10 and max_epochs=20, early stopping can only trigger in the second half of training.
trainer = L.Trainer(
    logger=wandb_logger,
    enable_checkpointing=True,  # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20,
    # reload_dataloaders_every_n_epochs=1,
    log_every_n_steps=50,
    callbacks=[
        EarlyStopping(monitor="val_loss", mode="min", verbose=True, patience=10),
        ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
    ]
)

# The data module is created with its default batch and candidate sizes.
trainer.fit(contrastive_mpnn, datamodule=RandomPairDataModule(train_mol_dataset, val_mol_dataset))

[34m[1mwandb[0m: Currently logged in as: [33mrahul-e-dev[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                    | Params | Mode 
--------------------------------------------------------------------
0 | message_passing | BondMessagePassing      | 227 K  | train
1 | agg             | NormAggregation         | 0      | train
2 | bn              | BatchNorm1d             | 600    | train
3 | predictor       | BinaryClassificationFFN | 155 K  | train
4 | X_d_transform   | ScaleTransform          | 0      | train
5 | metrics         | ModuleList              | 0      | train
6 | interaction     | RESCALInteraction       | 90.0 K | train
7 | loss_fn         | BCEWithLogitsLoss       | 0      | train
---------------------------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 0.417


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.025 >= min_delta = 0.0. New best score: 0.391


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 0.372


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.364


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.363


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 0.349


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 0.335


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


In [10]:
from pathlib import Path

# Download the checkpoint artifact tagged "best" for the run that was just trained.
run_id = wandb_logger.experiment.id
# NOTE(review): the W&B entity ("rahul-e-dev") and project are hard-coded in this reference,
# so the cell only works for that account; consider deriving them from the active run.
checkpoint_reference = f"rahul-e-dev/chemprop_delta_clf/model-{run_id}:best"
artifact_dir = wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")


# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects, which can execute
# code from the file - acceptable only because this artifact was produced by this notebook's own run.
ckpt = torch.load(Path(artifact_dir) / "model.ckpt", map_location='cpu', weights_only=False)
# NOTE(review): `hparams` is read here but not used by any later visible cell.
hparams = ckpt.get('hyper_parameters', ckpt.get('hparams', {}))
# Last expression of the cell, so the load report is displayed as the cell output.
contrastive_mpnn.load_state_dict(ckpt['state_dict'])

[34m[1mwandb[0m:   1 of 1 files downloaded.  


<All keys matched successfully>

In [11]:
# Switch the whole model to inference behaviour before embedding molecules.
contrastive_mpnn = contrastive_mpnn.eval()
# The datasets embedded below were already standardised with `normalize_inputs("X_d", ...)`.
# chemprop's ScaleTransform rescales its input only in eval mode and is a pass-through in
# train mode (chemprop itself flips it back to train mode during validation for this reason),
# so keep it in train mode here to avoid scaling the descriptors a second time.
contrastive_mpnn.X_d_transform.train();

In [12]:
from tqdm.auto import tqdm

@torch.no_grad()
def embed_all(mol_dataset: data.datasets.MoleculeDataset, contrastive_mpnn):
    """Embed every molecule of ``mol_dataset`` with ``contrastive_mpnn``.

    The dataset is read in order, 64 molecules at a time; each batch is moved
    to the model's device and passed to ``embed_simple_batch``. The per-batch
    embeddings are concatenated along the first axis and returned.
    """
    loader = DataLoader(
        mol_dataset,
        batch_size=64,
        shuffle=False,
        collate_fn=data.dataloader.collate_batch,
    )
    embed_chunks = [
        contrastive_mpnn.embed_simple_batch(
            move_data_to_device(batch, contrastive_mpnn.device)
        )['embeds']
        for batch in tqdm(loader, total=len(loader))
    ]
    return torch.cat(embed_chunks)

In [19]:
# Embed the training molecules (source of exemplars) and the test molecules (to be scored).
train_embeds = embed_all(train_mol_dataset, contrastive_mpnn)
test_embeds = embed_all(test_mol_dataset, contrastive_mpnn)

  0%|          | 0/162 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

In [20]:
# Exemplars are the training molecules whose target exceeds 50.
# np.argwhere returns one index per row (a column of indices), hence the squeeze() calls
# after indexing to drop the extra axis.
exemplar_idxs = np.argwhere(train_mol_dataset.Y.squeeze() > 50)
exemplar_embeds = train_embeds[exemplar_idxs].squeeze()
# NOTE(review): `exemplar_targets` is not used by any later visible cell.
exemplar_targets = train_mol_dataset.Y[exemplar_idxs].squeeze()

In [None]:
with torch.no_grad():
    # The interaction is trained so that score(head, tail) is the logit of
    # "head's target is larger than tail's". To ask how likely each test molecule is to
    # out-rank the cytotoxic exemplars, the test molecules must be the heads and the
    # exemplars the tails. The reshape pins the layout to (n_test, n_exemplars) even
    # when one of the two sizes is 1 and the interaction's squeeze() drops that axis.
    pairwise_probs = (
        contrastive_mpnn.interaction(test_embeds, exemplar_embeds)
        .sigmoid()
        .reshape(test_embeds.shape[0], -1)
    )
    # One score per test molecule: the mean probability over all exemplars.
    all_preds = pairwise_probs.mean(dim=-1)
    all_preds = all_preds.detach().cpu().numpy().reshape(-1)
    all_true = test_mol_dataset.Y > 50

In [26]:
# Sanity check: both arrays must have one entry per test molecule.
all_true.shape, all_preds.shape

((1282, 1), (1282,))

In [27]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, average_precision_score, roc_auc_score, balanced_accuracy_score, recall_score

# Hard labels are obtained by thresholding the averaged probabilities at 0.5;
# the two ranking metrics use the probabilities themselves.
predicted_labels = all_preds > 0.5

final_metric_columns = ['accuracy', 'balanced_accuracy', 'f1', 'precision', 'recall', 'AUCROC', 'PRAUC']
final_metric_row = [
    accuracy_score(all_true, predicted_labels),
    balanced_accuracy_score(all_true, predicted_labels),
    f1_score(all_true, predicted_labels),
    precision_score(all_true, predicted_labels),
    recall_score(all_true, predicted_labels),
    roc_auc_score(all_true, all_preds),
    average_precision_score(all_true, all_preds),
]

# A one-row table keeps all final test metrics together in the W&B run.
wandb_logger.log_table('final_metrics', final_metric_columns, [final_metric_row])

In [28]:
# End the W&B run now that the final metrics table has been logged.
wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▆▆▆▇▇▇▇▇▇▇███
train_exemplar_sym_loss_epoch,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_exemplar_sym_loss_step,█▃▂▃▂▃▁▂▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_epoch,█▄▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁
train_loss_step,█▆▅▄▄▄▆▄▄▆▃▄▃▄▃▄▂▂▂▂▄▃▂▃▂▂▁▁▁▁▂▁▁▂▂▃▄▂▂▁
train_lr_exemplar_loss_epoch,█▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
train_lr_exemplar_loss_step,█▃▃▆▆▄▇▃▅▄▁▄▅▃▃▃▃▃▃▃▃▄▃▃▃▂▁▂▃▅▁▃▂▃▂▃▂▁▃▃
train_lr_random_loss_epoch,█▇▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▁▁▁
train_lr_random_loss_step,▇▆▄▆▅▆▅▄█▃▄▆▄▄▆▄▃▃▃▄▄▃▆▃▄▃▃▅▃▃▃▄▂▂▅▁▂▂▂▃
train_random_sym_loss_epoch,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,19.0
train_exemplar_sym_loss_epoch,0.00447
train_exemplar_sym_loss_step,0.00392
train_loss_epoch,0.21961
train_loss_step,0.20238
train_lr_exemplar_loss_epoch,0.17656
train_lr_exemplar_loss_step,0.27012
train_lr_random_loss_epoch,0.47426
train_lr_random_loss_step,0.32941
train_random_sym_loss_epoch,0.01077


In [None]:
from tqdm.auto import tqdm

# NOTE(review): this repeats the `embed_all` definition from an earlier cell of this
# notebook; when both cells run, this later definition is the one that stays bound.
@torch.no_grad()
def embed_all(mol_dataset: data.datasets.MoleculeDataset, contrastive_mpnn):
    """Return the embeddings of all molecules in ``mol_dataset``, in dataset order.

    Molecules are collated 64 at a time, moved to the model's device and
    embedded with ``contrastive_mpnn.embed_simple_batch``; the per-batch
    results are concatenated along the first axis.
    """
    batch_loader = DataLoader(
        mol_dataset,
        batch_size=64,
        shuffle=False,
        collate_fn=data.dataloader.collate_batch,
    )
    collected = []
    for raw_batch in tqdm(batch_loader, total=len(batch_loader)):
        on_device = move_data_to_device(raw_batch, contrastive_mpnn.device)
        collected.append(contrastive_mpnn.embed_simple_batch(on_device)['embeds'])

    return torch.cat(collected)

@torch.no_grad()
def mine_hard_and_rand_negatives_for_anchor(anchor_idx, all_embeds: torch.Tensor, all_targets: torch.Tensor, clf: torch.nn.Linear):
    anchor = all_embeds[anchor_idx]
    anchor = anchor.view(1, -1).expand_as(all_embeds)
    logits = clf(torch.cat([anchor, all_embeds], dim=-1)).squeeze()
    anchor_target = all_targets[anchor_idx]

    negative_mask = torch.logical_xor(logits > 0.5, anchor_target > all_targets)
    logits[~negative_mask] = float("-inf")

    _, hard_neg_idxs = torch.topk(logits, 3, largest=True)

    negative_mask[hard_neg_idxs] = False

    rand_neg_slection_prob = torch.where(negative_mask, 1.0, float("-inf")).softmax(dim=-1)
    rand_neg_idxs = torch.multinomial(rand_neg_slection_prob, 4)
