In [1]:
!pip install rdkit
!pip install torch_geometric
!pip install lightning -U
!pip install wandb -U

Collecting rdkit
  Downloading rdkit-2023.9.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)
Downloading rdkit-2023.9.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.9/34.9 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2023.9.6
Collecting torch_geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl.metadata (64 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.2/64.2 kB[0m [31m526.9 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.5.3
Collecting lightning
  Downloading l

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
import lightning as L
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from pathlib import Path
from scipy.stats import pearsonr
import sklearn.metrics as skmetrics
from collections import defaultdict
from torch_geometric.loader import DataLoader
import wandb
import os
# Print at most five files per directory under ../input, ignoring directories
# whose path mentions ".git" or "wandb" and files whose name starts with "__".
for dirname, _, filenames in os.walk('../input'):
    if ".git" in dirname or "wandb" in dirname:
        continue
    visible_files = [name for name in filenames if not name.startswith("__")]
    for name in visible_files[:5]:
        print(os.path.join(dirname, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Common

In [None]:
# Fetch the W&B API key from Kaggle secrets; any failure (e.g. running outside
# Kaggle) is printed and leaves the key unset, in which case login is skipped.
wandb_api = None
try:
    from kaggle_secrets import UserSecretsClient
    wandb_api = UserSecretsClient().get_secret("wandb_key") 
except Exception as e:
    print(e)

if wandb_api is not None:
    wandb.login(key=wandb_api)



In [3]:
# When True, the input frame is truncated to its first rows and the trainer
# limits / epoch count are reduced (see the cells that read DRAFT_MODE).
DRAFT_MODE = False #True

BATCH_SIZE = 1024
# Names of the regression targets; the model builds one regression head and
# one ranking loss per name.
target_names = ["pIC50", "pKi"]

In [4]:
list(Path("../input/cleaned-enamine").iterdir())

[PosixPath('../input/cleaned-enamine/cleaned_enamine.csv')]

In [5]:
# full_df = pd.read_csv("../input/cleaned-enamine/cleaned_enamine.csv")
# full_df.shape
                      

In [6]:
# Number of rows per chunk file; it is also part of the chunk file name below.
chunk_size = 5000000
# One-off split of the full CSV into chunk files (kept for reference, not executed):
# with pd.read_csv("../input/cleaned-enamine/cleaned_enamine.csv", chunksize=chunk_size) as reader:
#     for chunk_index, full_df in enumerate(reader):
#         full_df.to_csv(f"enamine_chunk_{chunk_index}_{chunk_size}.csv", index=None)

# Index of the chunk file processed by this run.
CHUNK = 0
full_df = pd.read_csv(f"../input/enamine-split-to-chunks/enamine_chunk_{CHUNK}_{chunk_size}.csv")
print(full_df.shape)
if DRAFT_MODE:
    # Draft runs only keep the first 2000 rows.
    full_df = full_df.head(2000)
print(full_df.shape)

full_df.head()

(5000000, 2)
(2000, 2)


Unnamed: 0,TID,SMILES
0,PV-009095522958,FCC(CF)NC(CF)CF
1,Z5067467525,FC(F)(F)C12CC(NC3CSCSC3)(CO1)C2
2,PV-009234181902,CCN(CC1C2CC3C(C2)C13)C(C)(C)C
3,Z7289498217,Br/C=C\CN1C2CC3CC1CC(C2)O3
4,PV-008660388768,CC(N[C@H]1[C@@H]2C[C@H]1N[C@@H]2C)C1C(C)(C)C1(C)C


In [7]:
import random
import torch
import os
import sys


def set_seed(seed: int = 42) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN to deterministic mode (benchmarking off), exports
    ``PYTHONHASHSEED`` and prints a confirmation message.

    Args:
        seed: value used for every generator.
    """
    # Export the hash seed for the process environment.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Python / NumPy / PyTorch (CPU and current CUDA device) generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # The cuDNN backend needs these two flags for repeatable results.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print(f"Random seed set as {seed}")


RANDOM_STATE = 2407
set_seed(RANDOM_STATE)

Random seed set as 2407


# Data

In [8]:
import pandas as pd
from pathlib import Path
import torch
import numpy as np
from torch_geometric.data import Data, InMemoryDataset, Dataset
from rdkit import Chem

from rdkit.Chem.rdchem import BondType as BT
from rdkit import Chem


# Vocabularies used by smiles2pyg_data: every categorical atom / bond property
# is encoded as its position in the corresponding list.
ATOM_LIST = list(range(1,119))  # atomic numbers 1..118 -> indices 0..117
# NOTE(review): CHI_OTHER maps to index 3, while GINet builds its chirality
# embedding with num_chirality_tag = 3 rows (valid indices 0..2); an atom
# tagged CHI_OTHER would therefore index out of range in the embedding.
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
# Bond types / directions not listed here make `.index()` raise ValueError.
BOND_LIST = [BT.SINGLE, BT.DOUBLE, BT.TRIPLE, BT.AROMATIC]
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT
]


def smiles2pyg_data(smiles, regression_target=None, classification_target=None, **kwargs):
    """Convert a SMILES string into a torch_geometric ``Data`` graph.

    The molecule is parsed with RDKit and explicit hydrogens are added. Atoms
    become nodes with two integer features (index into ATOM_LIST, index into
    CHIRALITY_LIST); every bond becomes two directed edges that share the same
    two integer features (index into BOND_LIST, index into BONDDIR_LIST).

    Args:
        smiles: SMILES string of the molecule.
        regression_target: optional value stored as a float tensor of shape
            (1, -1) under ``regression_target``.
        classification_target: optional value stored as a long tensor under
            ``classification_target``.
        **kwargs: extra attributes forwarded unchanged to ``Data``.

    Returns:
        ``Data`` with ``x`` of shape (num_atoms, 2), ``edge_index`` of shape
        (2, 2 * num_bonds) and ``edge_attr`` of shape (2 * num_bonds, 2).

    Raises:
        ValueError: if RDKit cannot parse ``smiles``, or if an atom / bond
            property is not present in the vocabulary lists.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail with a
        # message that names the offending input instead of crashing in AddHs.
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    mol = Chem.AddHs(mol)

    node_feat = [
        [
            ATOM_LIST.index(atom.GetAtomicNum()),
            CHIRALITY_LIST.index(atom.GetChiralTag()),
        ]
        for atom in mol.GetAtoms()
    ]
    # view(-1, 2) keeps the (N, 2) layout even when there are no atoms.
    x = torch.tensor(node_feat, dtype=torch.long).view(-1, 2)

    row, col, edge_feat = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        feat = [
            BOND_LIST.index(bond.GetBondType()),
            BONDDIR_LIST.index(bond.GetBondDir()),
        ]
        # Both directions of the bond carry identical features.
        row += [start, end]
        col += [end, start]
        edge_feat += [feat, feat]

    edge_index = torch.tensor([row, col], dtype=torch.long)
    # view(-1, 2) guarantees a 2-D (E, 2) tensor for molecules without bonds,
    # so that it can be concatenated with the self-loop attributes downstream.
    edge_attr = torch.tensor(edge_feat, dtype=torch.long).view(-1, 2)

    if regression_target is not None:
        kwargs['regression_target'] = torch.tensor(
            regression_target, dtype=torch.float
        ).view(1, -1)
    if classification_target is not None:
        kwargs['classification_target'] = torch.tensor(
            classification_target, dtype=torch.long
        )

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, **kwargs)


class CacheDataset(Dataset):
    """Dataset that featurises SMILES rows of a DataFrame on demand.

    Nothing is pre-processed or stored on disk: every ``get`` call reads one
    row of ``df`` and converts it with ``smiles2pyg_data``.

    Args:
        df: source ``pandas.DataFrame``.
        smiles_column: name of the column holding the SMILES strings.
        inference_mode: when False, the target columns ``acvalue_uM``,
            ``class`` and ``acname`` are read and attached to each sample.
        additional_columns: column names whose values are attached to each
            sample; names that are not columns of ``df`` are ignored.
            ``None`` (the default) means no additional columns.
        root, transform, pre_transform, pre_filter: forwarded to the
            torch_geometric ``Dataset`` constructor.
    """

    def __init__(self, df, smiles_column="smiles", inference_mode=False, additional_columns=None, root="../temp/temp_cache", transform=None, pre_transform=None, pre_filter=None):
        assert isinstance(df, pd.DataFrame)

        super().__init__(root, transform, pre_transform, pre_filter)
        # A None default avoids sharing one mutable list between instances.
        if additional_columns is None:
            additional_columns = []

        self.df = df
        self.inference_mode = inference_mode
        self.additional_columns = [
            column for column in additional_columns if column in df.columns
        ]

        # Positional index -> DataFrame label, used for .loc lookups in get().
        self.data_index = np.asarray(df.index)
        self.smiles_column = smiles_column
        self.regression_column = "acvalue_uM"
        self.classification_column = "class"
        self.regression_name_column = "acname"

    @property
    def raw_file_names(self):
        """No raw files are required."""
        return []

    @property
    def processed_file_names(self):
        """No processed files are produced."""
        return []

    def len(self):
        """Number of rows of the wrapped DataFrame."""
        return len(self.data_index)

    def get(self, idx):
        """Build the graph for the row at positional index ``idx``."""
        index = self.data_index[idx]
        smiles = self.df.loc[index, self.smiles_column]

        kwargs = {
            column: self.df.loc[index, column]
            for column in self.additional_columns
        }
        if not self.inference_mode:
            kwargs['regression_target'] = self.df.loc[index, self.regression_column]
            kwargs['classification_target'] = self.df.loc[index, self.classification_column]
            kwargs['regression_name'] = self.df.loc[index, self.regression_name_column]

        return smiles2pyg_data(smiles, **kwargs)

    def process(self):
        """Nothing to pre-process; samples are built lazily in ``get``."""
        pass


# Losses

In [11]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class FC(nn.Module):
    """Fully connected head.

    For every width in ``fc_hidden_dim`` the stack contains
    Linear -> Dropout -> LeakyReLU -> BatchNorm1d, followed by one final
    Linear layer producing ``n_tasks`` raw outputs (no output activation).

    Args:
        d_graph_layer: size of the input features.
        fc_hidden_dim: iterable with the width of each hidden layer.
        dropout: dropout probability used after every hidden Linear layer.
        n_tasks: size of the output.
    """

    def __init__(self, d_graph_layer, fc_hidden_dim, dropout, n_tasks):
        super().__init__()

        layers = []
        in_features = d_graph_layer
        for width in fc_hidden_dim:
            layers += [
                nn.Linear(in_features, width),
                nn.Dropout(dropout),
                nn.LeakyReLU(),
                nn.BatchNorm1d(width),
            ]
            in_features = width
        layers.append(nn.Linear(in_features, n_tasks))

        # Kept as a ModuleList named `predict` so parameter names stay stable.
        self.predict = nn.ModuleList(layers)

    def forward(self, h):
        """Apply the layers in order and return the raw outputs."""
        out = h
        for module in self.predict:
            out = module(out)
        return out


class RankingLoss(nn.Module):
    """Pairwise ranking loss on graph embeddings.

    Every sample of a batch is paired with the sample found ``shift`` positions
    away (the batch rolled by half its length, at least one position). A small
    MLP receives the concatenated embeddings of the pair and is trained with
    cross entropy to predict whether the first target is greater than the
    second one.

    Args:
        embedding_out_dim: size of a single embedding.
        dropout: dropout probability of the relation MLP.
        ntasks: number of output logits of the relation MLP.
    """

    def __init__(self, embedding_out_dim, dropout=0.3, ntasks=2):
        super(RankingLoss, self).__init__()
        self.loss_fn = nn.CrossEntropyLoss() # TODO: check this - was reduce=False
        self.embedding_out_dim = embedding_out_dim
        self.relation_mlp = FC(
            embedding_out_dim * 2, 
            [embedding_out_dim, embedding_out_dim // 2], 
            dropout, 
            ntasks
        )
        self.m = nn.Softmax(dim=1)

    @torch.no_grad()
    def get_rank_relation(self, y_A, y_B):
        """Return the class labels of the pairs: 0 for ``y_A <= y_B``, 1 for ``y_A > y_B``.

        ``y_A`` is expected as [batch, 1]; the result is a long tensor of
        shape [batch].
        """
        target_relation = torch.zeros(y_A.size(), dtype=torch.long, device=y_A.device)
        target_relation[(y_A - y_B) > 0.0] = 1

        # Drop only the trailing singleton dimension. A bare squeeze() would
        # also remove the batch dimension of a single-sample batch, producing
        # a 0-dim target that CrossEntropyLoss rejects for (batch, C) logits.
        if target_relation.dim() > 1 and target_relation.size(-1) == 1:
            target_relation = target_relation.squeeze(-1)
        return target_relation

    def forward(self, output_embedding, target):
        """Compute the ranking loss of a batch.

        Args:
            output_embedding: embeddings, one row per sample.
            target: regression targets aligned with ``output_embedding``.

        Returns:
            Tuple ``(ranking_loss, relation, y_pred)``: the cross-entropy
            loss, the ground-truth relation labels and the predicted labels.
        """
        batch_repeat_num = len(output_embedding)
        # Partner of sample i is sample i - shift (cyclically); at least 1 so
        # that tiny batches still get a well-defined roll.
        shift = max(batch_repeat_num // 2, 1)

        x_A, y_A = output_embedding, target
        x_B = torch.roll(output_embedding, shift, 0)
        y_B = torch.roll(target, shift, 0)

        relation = self.get_rank_relation(y_A, y_B)
        relation_pred = self.relation_mlp(torch.cat([x_A, x_B], dim=1))

        ranking_loss = self.loss_fn(relation_pred, relation)

        _, y_pred = self.m(relation_pred).max(dim=1)

        return ranking_loss, relation, y_pred

# Model

## GINet

In [12]:
import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

# Sizes of the embedding tables used by GINEConv / GINet.
num_atom_type = 119 # including the extra mask tokens
# NOTE(review): CHIRALITY_LIST used for featurisation has 4 entries (indices
# 0..3); with 3 rows here, index 3 (CHI_OTHER) would be out of range.
num_chirality_tag = 3

num_bond_type = 5 # including aromatic and self-loop edge
num_bond_direction = 3 


class GINEConv(MessagePassing):
    """Graph convolution that adds embedded edge features to neighbour messages.

    Each message is ``x_j + edge_embedding``; the aggregated messages (using
    the default aggregation of ``MessagePassing``) are passed through a
    two-layer MLP.

    Args:
        emb_dim: size of node and edge embeddings.
    """

    def __init__(self, emb_dim):
        super(GINEConv, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 2*emb_dim), 
            nn.ReLU(), 
            nn.Linear(2*emb_dim, emb_dim)
        )
        self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)

        nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def forward(self, x, edge_index, edge_attr):
        """Run one message-passing step.

        Args:
            x: node embeddings, one row per node.
            edge_index: edge list of the graph (without self loops).
            edge_attr: integer edge features; column 0 indexes the bond-type
                embedding and column 1 the bond-direction embedding.
        """
        # add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0]

        # add features corresponding to self-loop edges; allocate them directly
        # with the dtype and device of edge_attr instead of building a float
        # CPU tensor and converting/copying it on every forward pass.
        self_loop_attr = edge_attr.new_zeros((x.size(0), 2))
        self_loop_attr[:, 0] = 4 # bond type for self-loop edge
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + \
            self.edge_embedding2(edge_attr[:,1])

        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        """Message of an edge: source node embedding plus edge embedding."""
        return x_j + edge_attr

    def update(self, aggr_out):
        """Transform the aggregated messages with the MLP."""
        return self.mlp(aggr_out)


class GINet(nn.Module):
    """GIN encoder with an optional prediction head.

    Args:
        task (str): 'classification' (2 outputs), 'regression' (1 output) or
            'head' (no prediction head; ``forward`` returns only the features).
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        feat_dim (int): dimensionality of the pooled graph feature
        drop_ratio (float): dropout rate
        pool (str): graph readout, one of 'mean', 'max', 'add'
        pred_n_layer (int): number of hidden layers of the prediction head
            (values below 1 are treated as 1)
        pred_act (str): activation of the prediction head, 'relu' or 'softplus'
    Output:
        ``h`` when task is 'head', otherwise the tuple ``(h, pred_head(h))``
        where ``h`` is the graph feature of size ``feat_dim``.
    Raises:
        ValueError: for an unknown ``task``, ``pool`` or ``pred_act``.
    """
    def __init__(self, 
        task='classification', num_layer=5, emb_dim=300, feat_dim=512, 
        drop_ratio=0, pool='mean', pred_n_layer=2, pred_act='softplus'
    ):
        super(GINet, self).__init__()
        self.num_layer = num_layer
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.drop_ratio = drop_ratio
        self.task = task

        self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
        nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        # List of MLPs
        self.gnns = nn.ModuleList()
        for layer in range(num_layer):
            self.gnns.append(GINEConv(emb_dim))

        # List of batchnorms
        self.batch_norms = nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

        # Fail at construction time instead of leaving `self.pool` unset.
        pool_functions = {
            'mean': global_mean_pool,
            'max': global_max_pool,
            'add': global_add_pool,
        }
        if pool not in pool_functions:
            raise ValueError(f"Undefined pooling function: {pool!r}")
        self.pool = pool_functions[pool]
        self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)

        if self.task == "head":
            # Feature extractor only: no prediction head is created.
            return

        out_dims = {'classification': 2, 'regression': 1}
        if self.task not in out_dims:
            raise ValueError(f"Undefined task: {self.task!r}")
        out_dim = out_dims[self.task]

        self.pred_n_layer = max(1, pred_n_layer)

        if pred_act == 'relu':
            make_activation = lambda: nn.ReLU(inplace=True)
        elif pred_act == 'softplus':
            make_activation = nn.Softplus
        else:
            raise ValueError('Undefined activation function')

        hidden_dim = self.feat_dim // 2
        pred_head = [nn.Linear(self.feat_dim, hidden_dim), make_activation()]
        for _ in range(self.pred_n_layer - 1):
            pred_head.extend([nn.Linear(hidden_dim, hidden_dim), make_activation()])
        # Exactly one output layer for every activation (the 'relu' variant
        # used to receive it twice, which broke the forward pass).
        pred_head.append(nn.Linear(hidden_dim, out_dim))
        self.pred_head = nn.Sequential(*pred_head)

    def forward(self, data):
        """Encode a batch of graphs.

        ``data`` must provide ``x`` (two integer columns: atom type and
        chirality index), ``edge_index``, ``edge_attr`` and ``batch``.
        """
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        
        h = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1])

        for layer in range(self.num_layer):
            h = self.gnns[layer](h, edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                # no ReLU after the last GNN layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

        h = self.pool(h, data.batch)
        h = self.feat_lin(h)
        if self.task == "head":
            return h
        return h, self.pred_head(h)

    def load_my_state_dict(self, state_dict):
        """Copy matching entries of ``state_dict`` into this model.

        Entries whose name does not exist in this model are skipped, which
        allows loading a checkpoint that contains additional modules.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            if isinstance(param, nn.parameter.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)


## Common model-related

In [13]:
def _load_pre_trained_weights(model, ckpt_path="molclr_ckpt", device='cpu'):
    try:
        fine_tune_from = 'pretrained_gin'
        checkpoints_folder = os.path.join(ckpt_path, fine_tune_from, 'checkpoints')
        state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=device)
        # model.load_state_dict(state_dict)
        model.load_my_state_dict(state_dict)
        print("Loaded pre-trained model with success.")
    except FileNotFoundError:
        print("Pre-trained weights not found. Training from scratch.")

    return model

In [14]:
# Keyword arguments forwarded to GINet when the wrapper builds its encoder.
model_params = dict(
    num_layer=5,  # number of graph conv layers
    emb_dim=300,  # embedding dimension in graph conv layers
    feat_dim=512,  # output feature dimension
    drop_ratio= 0.3,  # dropout ratio
    pool= 'mean'  # readout pooling (i.e., mean/max/add)
)

# Model wrapper

In [15]:
class MolCLRWrapper(L.LightningModule):
    """Lightning module combining a GINet encoder with multi-task heads.

    The encoder (``self.head``) produces one embedding per graph. On top of it
    there is one regression head and one ranking loss per target name, plus a
    single two-class classification head.

    Args:
        target_names: names of the regression targets. During validation the
            dataloader with index ``i`` is evaluated against target ``i``.
        ckpt_path: optional directory with pre-trained encoder weights (see
            ``_load_pre_trained_weights``).
        device: accepted for interface compatibility; not used.
        dropout: accepted for interface compatibility; not used (the dropout
            of the heads is taken from ``model_params['drop_ratio']``).
    """

    def __init__(self, target_names=("pIC50", "pKi"), ckpt_path=None, device='cpu', dropout=0.3):
        super().__init__()
        # Own copy, so the instance never shares a mutable default / caller list.
        self.target_names = list(target_names)
        self.feature_dim = model_params['feat_dim']
        self.dropout = model_params['drop_ratio']
        self.head = GINet('head', **model_params)

        if ckpt_path is not None:
            self.head = _load_pre_trained_weights(self.head, ckpt_path=ckpt_path)

        self.ranking_losses_func = nn.ModuleDict({
            name: RankingLoss(model_params['feat_dim'])
            for name in self.target_names
        })

        self.regression_heads = nn.ModuleDict({
            name: FC(
                self.feature_dim, [self.feature_dim // 2, self.feature_dim // 4], 
                self.dropout, 1
            )
            for name in self.target_names
        })
        self.classification_head = FC(
            self.feature_dim, [self.feature_dim // 2, self.feature_dim // 4], 
            self.dropout, 2
        )

        # Per-dataloader buffers filled in validation_step and consumed (then
        # cleared) in on_validation_epoch_end.
        self.validation_loss_outputs = defaultdict(list)
        self.validation_regression_outputs = defaultdict(list)
        self.validation_regression_gt = defaultdict(list)
        self.validation_classification_outputs = defaultdict(list)
        self.validation_classification_gt = defaultdict(list)

        # Buffers filled in test_step and consumed in on_test_epoch_end.
        self.test_regression_outputs = defaultdict(list)
        self.test_regression_gt = defaultdict(list)
        self.test_classification_outputs = []
        self.test_classification_gt = []

        # Weights of the loss terms.
        self.ranking_lambda = 1.
        self.classification_lambda = 1.0
        self.regression_lambda = 0.5

    def training_step(self, batch, batch_idx, dataloader_idx=0):
        """Training is intentionally unsupported in this notebook."""
        # A bare string cannot be raised (that is itself a TypeError); raise a
        # proper exception that carries the intended message.
        raise RuntimeError("training shouldn't be run during inference")

    def forward(self, data_batch):
        """Return ``(regression_predictions, class_probabilities)`` for a batch.

        ``regression_predictions`` is a dict keyed by target name.
        """
        _, regression_predictions, _, class_probabilities = self._get_embeddings_and_predictions(data_batch)
        return regression_predictions, class_probabilities

    def _get_embeddings_and_predictions(self, data_batch):
        """Encode the batch and evaluate every head.

        Returns:
            ``(embeddings, regression_predictions, classification_logits,
            class_probabilities)``.
        """
        batch_embeddings = self.head(data_batch)
        regression_predictions = {
            k: self.regression_heads[k](batch_embeddings)
            for k in self.regression_heads
        }

        classification_logits = self.classification_head(batch_embeddings)
        class_probabilities = F.softmax(classification_logits, dim=1)
        return batch_embeddings, regression_predictions, classification_logits, class_probabilities

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Compute the weighted validation loss and buffer predictions."""
        batch_embeddings, regression_predictions, classification_logits, class_probabilities = self._get_embeddings_and_predictions(batch)
        classification_loss = F.cross_entropy(classification_logits, batch.classification_target)

        # The dataloader index selects the regression target of this batch.
        target = self.target_names[dataloader_idx]
        ranking_loss, relation, y_pred = self.ranking_losses_func[target](batch_embeddings, batch.regression_target)

        regression_loss = F.mse_loss(regression_predictions[target], batch.regression_target)
        val_loss = self.ranking_lambda * ranking_loss + self.regression_lambda * regression_loss + self.classification_lambda * classification_loss
        scores = {
            'val_loss': val_loss,
            'classification_loss': classification_loss,
            'regression_loss': regression_loss,
        }

        self.validation_loss_outputs[dataloader_idx].append(val_loss)
        # save regression info
        self.validation_regression_outputs[dataloader_idx].append(regression_predictions[target].detach().cpu())
        self.validation_regression_gt[dataloader_idx].append(batch.regression_target.detach().cpu())
        # save classification info
        self.validation_classification_outputs[dataloader_idx].append(class_probabilities.detach().cpu())
        self.validation_classification_gt[dataloader_idx].append(batch.classification_target.detach().cpu())
        # log scores and return
        self.log_dict(scores)
        return val_loss

    def on_validation_epoch_end(self):
        """Aggregate the buffered validation outputs into epoch-level metrics."""
        scores = {}

        val_losses = []
        for step_losses in self.validation_loss_outputs.values():
            val_losses.extend(step_losses)
            step_losses.clear()  # free memory
        if val_losses:
            scores["val_loss"] = torch.stack(val_losses).sum()

        # aggregate regression predictions and gt values and compute validation metrics
        for dataloader_idx, pred_chunks in self.validation_regression_outputs.items():
            gt_chunks = self.validation_regression_gt[dataloader_idx]
            if pred_chunks:
                # reshape(-1) keeps a 1-D array even for a single sample.
                outputs = torch.cat(pred_chunks, dim=0).reshape(-1)
                gts = torch.cat(gt_chunks, dim=0).reshape(-1)
                scores[f"reg_MAE_{dataloader_idx}"] = skmetrics.mean_absolute_error(gts, outputs)
                scores[f"reg_MSE_{dataloader_idx}"] = skmetrics.mean_squared_error(gts, outputs)
                scores[f"reg_R2_{dataloader_idx}"] = skmetrics.r2_score(gts, outputs)
            pred_chunks.clear()
            gt_chunks.clear()

        # aggregate class predictions and gt values and compute validation metrics
        for dataloader_idx, prob_chunks in self.validation_classification_outputs.items():
            gt_chunks = self.validation_classification_gt[dataloader_idx]
            try:
                if prob_chunks:
                    outputs = torch.cat(prob_chunks, dim=0)[:, 1]
                    gts = torch.cat(gt_chunks, dim=0)
                    labels_hat = outputs > 0.5
                    scores[f'classif_accuracy_{dataloader_idx}'] = skmetrics.accuracy_score(gts, labels_hat)
                    scores[f'classif_f1_{dataloader_idx}'] = skmetrics.f1_score(gts, labels_hat)
                    # Computed last: ROC AUC is undefined (ValueError) when the
                    # ground truth holds a single class.
                    scores[f'classif_roc_auc_{dataloader_idx}'] = skmetrics.roc_auc_score(gts, outputs)
            except ValueError as err:
                print(f"Skipping classification metrics for dataloader {dataloader_idx}: {err}")
            finally:
                # Always release the buffers, otherwise stale predictions
                # would leak into the next validation epoch.
                prob_chunks.clear()
                gt_chunks.clear()
        self.log_dict(scores)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Buffer predictions and ground truth of one test batch."""
        regression_predictions, class_probabilities = self(batch)
        regression_predictions = {
            k: v.detach().cpu().numpy() 
            for k, v in regression_predictions.items()
        }
        class_probabilities = class_probabilities.detach().cpu().numpy()
        regression_gt = batch.regression_target.detach().cpu().numpy()
        # The same ground-truth values are stored next to every target's predictions.
        for name in self.target_names:
            self.test_regression_outputs[name].append(regression_predictions[name])
            self.test_regression_gt[name].append(regression_gt)

        classification_gt = batch.classification_target.detach().cpu().numpy()
        self.test_classification_outputs.append(class_probabilities)
        self.test_classification_gt.append(classification_gt)

    def on_test_epoch_end(self):
        """Compute test metrics, log a scatter plot per target to W&B and reset the buffers."""
        scores = {}
        for target_name in self.target_names:
            pred_values = np.concatenate(self.test_regression_outputs[target_name]).flatten()
            gt_values = np.concatenate(self.test_regression_gt[target_name]).flatten()
            p = pearsonr(pred_values, gt_values)
            scores[f'pearson_{target_name}_pvalue'] = p.pvalue
            scores[f'pearson_{target_name}_stat'] = p.statistic

            data = [[x, y] for (x, y) in zip(pred_values, gt_values)]
            table = wandb.Table(data=data, columns = ["predictions", "true_values"])
            wandb.log({
                f"scatter_{target_name}": wandb.plot.scatter(
                    table, "predictions", "true_values", title="Predicted vs. GT values")
            })
            scores[f"test_reg_MAE_{target_name}"] = skmetrics.mean_absolute_error(gt_values, pred_values)
            scores[f"test_reg_MSE_{target_name}"] = skmetrics.mean_squared_error(gt_values, pred_values)
            scores[f"test_reg_R2_{target_name}"] = skmetrics.r2_score(gt_values, pred_values)
            self.test_regression_outputs[target_name].clear()
            self.test_regression_gt[target_name].clear()

        pred_values = np.concatenate(self.test_classification_outputs)
        gt_values = np.concatenate(self.test_classification_gt)
        labels_hat = np.argmax(pred_values, axis=1)

        scores['test_accuracy'] = skmetrics.accuracy_score(gt_values, labels_hat)
        scores['test_f1'] = skmetrics.f1_score(gt_values, labels_hat)
        scores['test_auc'] = skmetrics.roc_auc_score(gt_values, pred_values[:, 1])

        self.test_classification_outputs.clear()
        self.test_classification_gt.clear()
        self.log_dict(scores)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the model outputs together with the ``TID`` attribute of the batch."""
        return self(batch), batch.TID

    def configure_optimizers(self):
        """Adam (lr=1e-3) over the parameters that currently require gradients."""
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.parameters()), 
            lr=1e-3
        )
        return optimizer


# Actual training

In [16]:
import lightning as L
import torch.nn as nn
import torch.optim as optim
from lightning import seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, BaseFinetuning, EarlyStopping, LambdaCallback
from lightning.pytorch.loggers import WandbLogger
import wandb

In [17]:
class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
    """Fine-tuning callback for the encoder of the wrapped model.

    The two atom embeddings and the GNN layers of ``pl_module.head`` are
    frozen before training and handed back to the optimizer once the epoch
    given by ``unfreeze_at_epoch`` is reached.
    """

    def __init__(self, unfreeze_at_epoch=5):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    @staticmethod
    def _encoder_parts(pl_module):
        """Sub-modules of the encoder controlled by this callback."""
        encoder = pl_module.head
        return [encoder.x_embedding1, encoder.x_embedding2, encoder.gnns]

    def freeze_before_training(self, pl_module):
        for part in self._encoder_parts(pl_module):
            self.freeze(part)

    def finetune_function(self, pl_module, current_epoch, optimizer):
        if current_epoch != self._unfreeze_at_epoch:
            return
        self.unfreeze_and_add_param_group(
            modules=self._encoder_parts(pl_module),
            optimizer=optimizer,
            train_bn=True,
        )

In [31]:
# Keyword arguments later unpacked into L.Trainer(**trainer_kwargs).
trainer_kwargs = {
    "max_epochs": 10,
    "log_every_n_steps": 2,
    "limit_train_batches": 500,
    "limit_val_batches": 100,
    "limit_test_batches": 100,
    "check_val_every_n_epoch": 1, 
    "deterministic": True,
    #"logger": wandb_logger,
}
# Draft runs use smaller batch limits and fewer epochs.
if DRAFT_MODE:
    trainer_kwargs['log_every_n_steps'] = 10
    trainer_kwargs['limit_train_batches'] = 20
    trainer_kwargs['limit_val_batches'] = 10
    trainer_kwargs['limit_test_batches'] = 10
    trainer_kwargs['max_epochs'] = 2
    
# Use the first GPU when one is available.
if torch.cuda.is_available():
    trainer_kwargs["accelerator"] = "gpu"
    trainer_kwargs['devices'] = [0]

# callbacks:
# Keeps the 5 best checkpoints under "saves", ranked by the "val_loss" metric.
checkpoint_callback = ModelCheckpoint(
    dirpath="saves", 
    save_top_k=5, 
    monitor="val_loss",
    filename='{epoch}-{step}-{val_loss:.4f}-{regression_loss/dataloader_idx_0:.4f}-{classif_roc_auc_0:.4f}'
)
# NOTE(review): this monitors 'val_loss/dataloader_idx_0', whereas the
# checkpoint callback above monitors 'val_loss' - confirm both keys are logged.
early_stopping = EarlyStopping(
    'val_loss/dataloader_idx_0',
    #'val_loss/dataloader_idx_0', 
    patience=50
)
# Encoder layers are unfrozen once epoch 10 is reached.
finetuning_callback = FeatureExtractorFreezeUnfreeze(10)
def change_loss_on_train_start(trainer, pl_module):
    """Epoch hook that switches the loss weights of ``pl_module``.

    Prints a marker at epoch 2. At epoch 200 the ranking and regression
    weights are set to zero so that only the classification loss remains.
    """
    epoch = trainer.current_epoch
    if epoch == 2:
        print('current_epoch is 2')
    if epoch != 200:
        return
    # From here on only the classification term contributes to the loss.
    pl_module.ranking_lambda = 0
    pl_module.regression_lambda = 0
    pl_module.classification_lambda = 1.

# Run the loss-weight switch at the start of every training epoch.
change_loss_parameters_callback = LambdaCallback(on_train_epoch_start=change_loss_on_train_start)
# trainer_kwargs['callbacks'] = [
#     #finetuning_callback,
#     checkpoint_callback,
#     early_stopping,
#     change_loss_parameters_callback
# ]
# The callbacks built above are currently not attached: the trainer receives an empty list.
trainer_kwargs['callbacks'] = [] 




In [32]:
# wandb.finish()

In [20]:
# Name of the W&B project used by the logger below.
wandb_project_name = f"CACHE5"
print(wandb_project_name)
#if DRAFT_MODE:
#    wandb_logger = WandbLogger(project=wandb_project_name, offline=True)  # , log_model='all')
#else:
# Tag the run with the mode so that draft and full runs can be told apart.
mode = "DRAFT_MODE" if DRAFT_MODE else "FULL_MODE"
wandb_logger = WandbLogger(
    project=wandb_project_name,
    tags=["GIN", "MLP", mode]
)  # , log_model='all')

# trainer_kwargs['logger'] = wandb_logger

CACHE5


In [21]:
SAVEDIR = Path("../input/cache5-molclr-baseline-notebook/saves/")
if SAVEDIR.exists():
    paths = list(SAVEDIR.glob("*ckpt"))
    paths = sorted(paths, key=lambda x: float(x.stem.split("=")[-1]), reverse=True)
    for p in paths:
        print(p)

    best_checkpoint = paths[0]
else:
    best_checkpoint = "./epoch=64-step=195-val_loss=4.5169-classif_roc_auc_0=0.6545.ckpt"
print("\nbest:", best_checkpoint)
# "/kaggle/working/saves/epoch=64-step=195-val_loss=4.5169-classif_roc_auc_0=0.6545.ckpt"

../input/cache5-molclr-baseline-notebook/saves/epoch=64-step=195-val_loss=4.5169-classif_roc_auc_0=0.6545.ckpt
../input/cache5-molclr-baseline-notebook/saves/epoch=65-step=198-val_loss=4.6764-classif_roc_auc_0=0.6167.ckpt
../input/cache5-molclr-baseline-notebook/saves/epoch=63-step=192-val_loss=4.5763-classif_roc_auc_0=0.5891.ckpt
../input/cache5-molclr-baseline-notebook/saves/epoch=66-step=201-val_loss=4.8053-classif_roc_auc_0=0.4975.ckpt
../input/cache5-molclr-baseline-notebook/saves/epoch=67-step=204-val_loss=4.8630-classif_roc_auc_0=0.4766.ckpt

best: ../input/cache5-molclr-baseline-notebook/saves/epoch=64-step=195-val_loss=4.5169-classif_roc_auc_0=0.6545.ckpt


In [33]:
# Restore the trained wrapper from the selected checkpoint; the weights are
# mapped to the CPU while loading.
model_wrapper = MolCLRWrapper.load_from_checkpoint(
    checkpoint_path=best_checkpoint,
    map_location=torch.device('cpu'),
)
# The trainer is only built here. Fitting is disabled in this notebook; to
# train, call trainer.fit(model=model_wrapper,
# train_dataloaders=train_dataloaders, val_dataloaders=test_dataloaders).
trainer = L.Trainer(**trainer_kwargs)


INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs


# Compute metrics on test

In [34]:
# trainer.test(
#      model=model_wrapper, 
# #     ckpt_path="best", 
#     dataloaders=test_dataloaders
# );

Columns of the prediction file:

- `ID`: the compound identifier (e.g. ID00015)
- `cls_prb`: the classification probability
- `pKi`: the predicted pKi value
- `pIC50`: the predicted pIC50 value


In [24]:
# test_datasets = [
#     CacheDataset(test_df[test_df.acname == acname]) 
#     for acname in target_names
# ]

In [35]:
# Preview the first rows of the inference table (the output below shows the
# TID and SMILES columns).
# NOTE(review): `full_df` is not created in the neighbouring cells -- confirm
# it is defined earlier so the notebook survives Restart & Run All.
full_df.head()

Unnamed: 0,TID,SMILES
0,PV-009095522958,FCC(CF)NC(CF)CF
1,Z5067467525,FC(F)(F)C12CC(NC3CSCSC3)(CO1)C2
2,PV-009234181902,CCN(CC1C2CC3C(C2)C13)C(C)(C)C
3,Z7289498217,Br/C=C\CN1C2CC3CC1CC(C2)O3
4,PV-008660388768,CC(N[C@H]1[C@@H]2C[C@H]1N[C@@H]2C)C1C(C)(C)C1(C)C


In [26]:

# selection_ids = data_df.DataSAIL_10f.isin([f"Fold_{i}" for i in range(10)])
# inference_df = data_df.loc[selection_ids, ["ID", "smiles"]]

# dataset = CacheDataset(
#     full_df, 
#     smiles_column="SMILES", 
#     inference_mode=True, 
#     additional_columns=["TID"]
# )
# inference_loader = DataLoader(
#     dataset, 
#     batch_size=BATCH_SIZE, # num_workers=2
# )
# len(dataset)

In [36]:
from tqdm import tqdm
import csv

def predict_for_data_file(
        model_wrapper, trainer,
        data_file,
        save_file, 
        batch_size=BATCH_SIZE, 
        num_workers=3, 
        draft_mode=DRAFT_MODE
        ):
    """Run inference on a CSV of molecules and write the predictions to CSV.

    Args:
        model_wrapper: model passed to ``trainer.predict``. Each prediction
            batch is unpacked as ``((regression, classification), tids)``,
            where ``regression`` is a mapping of tensors and column 1 of
            ``classification`` is stored as ``cls_prb``.
            NOTE(review): the regression keys are presumably "pKi" and
            "pIC50" -- confirm against the model's predict step.
        trainer: Lightning trainer used to run the prediction loop.
        data_file: path (``str`` or ``Path``) of the input CSV; the columns
            "SMILES" and "TID" are read from it.
        save_file: path of the CSV file to write.
        batch_size: batch size of the inference loader.
        num_workers: number of loader worker processes.
        draft_mode: when true, only the first 2000 rows are processed.

    Returns:
        ``save_file`` on success, or ``None`` when ``data_file`` does not exist.
    """
    data_file = Path(data_file)
    if not data_file.exists():
        print("Enamine file is not available")
        return None
    full_df = pd.read_csv(data_file)
    if draft_mode:
        full_df = full_df.head(2000)
    dataset = CacheDataset(
        full_df, 
        smiles_column="SMILES", 
        inference_mode=True, 
        additional_columns=["TID"]
    )
    print("mode:", draft_mode, len(dataset))
    inference_loader = DataLoader(
        dataset, 
        # Honour the caller's batch size (the module-level BATCH_SIZE is only the default).
        batch_size=batch_size, 
        num_workers=num_workers,
        shuffle=False
    )

    columns_order = ["TID", "cls_prb", "pKi", "pIC50"]
    model_wrapper.eval()

    # Predict before touching the output file, so that a failed prediction run
    # does not leave a truncated, header-only file behind.
    prediction_batches = trainer.predict(model_wrapper, inference_loader)

    with open(save_file, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=columns_order)

        writer.writeheader()

        for (regression_predictions, classification_predictions), tids in tqdm(prediction_batches):
            # Keep only column 1 of the classification output.
            classification_predictions = classification_predictions[:, 1]
            data = {k: v.numpy().flatten() for k, v in regression_predictions.items()}
            data['cls_prb'] = classification_predictions
            data['TID'] = tids
            # One dict per molecule, as expected by DictWriter.writerows.
            data = pd.DataFrame.from_dict(data).to_dict('records')
            writer.writerows(data)

    return save_file

In [28]:
# run = wandb.init(
#     project=wandb_project_name,
#     id="enamine_data5678",
#     resume="allow"
# )

[34m[1mwandb[0m: Currently logged in as: [33mlacemaker[0m ([33msquirrel-writes-her-phd[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [37]:
# wandb.finish()

In [None]:
# Chunk of the Enamine library handled by this run. The log messages are
# derived from the same value so they cannot drift from the file names.
chunk_number = 5
chunk_size = 5_000_000
print(f"start processing chunk {chunk_number}")

datadir = Path("../input/enamine-split-to-chunks")
if datadir.exists():
    data_file = datadir / f"enamine_chunk_{chunk_number}_{chunk_size}.csv"
    save_file = f"enamine_predictions_{chunk_number}.csv"
else:
    # Chunked dataset is not available: fall back to a single local file.
    data_file = "cleaned_enamine.csv"
    save_file = "cleaned_enamine_predictions.csv"
chunk_filename = predict_for_data_file(
    model_wrapper,
    trainer,
    data_file,
    save_file,
    batch_size=BATCH_SIZE,
    draft_mode=DRAFT_MODE
)
# predict_for_data_file returns None (after printing a message) when the input
# file is missing; in that case there is nothing to compress.
if chunk_filename is not None:
    df = pd.read_csv(chunk_filename)
    df.to_csv(chunk_filename + ".gz", compression='gzip', index=False)
    print(df['TID'].values[:4])

    # artifact = wandb.Artifact(name=chunk_filename + ".gz", type="data")
    # artifact.add_file(chunk_filename + ".gz")
    # run.log_artifact(artifact)
print(f"end processing chunk {chunk_number}")

start processing chunk 5


Processing...
Done!


mode: False 5000000


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |          | 0/? [00:00<?, ?it/s]