For details, refer to https://www.kaggle.com/competitions/leash-BELKA/discussion/498858

Data should be prepared by **gnn_save_graph_100M.ipynb**

The code requires 256 GB of RAM to run training without OOM issues. It starts from the checkpoint produced by **gnn_50M_100M_v2_01.ipynb** and runs 1 epoch of training on the first 40M samples of the full 100M-sample training dataset.

In [None]:
import os
import gc
import math

import pandas as pd
import numpy as np

import rdkit
from rdkit import Chem

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as PyGDataLoader

from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_scatter import scatter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset

import pytorch_lightning as L
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger

from sklearn.metrics import average_precision_score

from multiprocessing import Pool

import _pickle as  cPickle
import bz2
import pgzip
from tqdm import tqdm

print('import ok!')

In [7]:
def save_compressed_pickle(file, data, type = "pgzip", n_treads = 1):
    """Pickle ``data`` into a compressed file.

    Args:
        file: Destination path.
        data: Any picklable object.
        type: Compression backend: ``"pgzip"`` (parallel gzip) or ``"bz2"``.
        n_treads: Number of compression threads; only used by ``pgzip``.

    Raises:
        ValueError: If ``type`` is not a supported backend. Raising (instead of
            printing and returning) makes a typo fail loudly rather than
            silently skipping the save.
    """
    if type == "pgzip":
        with pgzip.open(file, 'w', thread = n_treads) as f:
            cPickle.dump(data, f)
    elif type == "bz2":
        with bz2.BZ2File(file, 'w') as f:
            cPickle.dump(data, f)
    else:
        raise ValueError(f"unsupported compression type: {type!r}")

def load_decompress_pickle(file, type = "pgzip", n_treads = 8):
    """Load an object written by ``save_compressed_pickle``.

    Args:
        file: Source path.
        type: Compression backend used when saving: ``"pgzip"`` or ``"bz2"``.
        n_treads: Number of decompression threads; only used by ``pgzip``.

    Returns:
        The unpickled object.

    Raises:
        ValueError: If ``type`` is not a supported backend. Previously this
            printed "unsupported" and then failed with UnboundLocalError.

    Security: unpickling can execute code embedded in the file; only load
    files produced by this pipeline.
    """
    if type == "pgzip":
        f = pgzip.open(file, 'rb', thread = n_treads)
    elif type == "bz2":
        f = bz2.BZ2File(file, 'rb')
    else:
        raise ValueError(f"unsupported compression type: {type!r}")
    # Close the decompressor (and its underlying file handle) once loaded.
    with f:
        return cPickle.load(f)

def compute_ap(df):
    """Average precision for every (split, protein) pair.

    Args:
        df: Wide frame with ground-truth columns ``BRD4``, ``HSA``, ``sEH``,
            prediction columns ``BRD4_pred``, ``HSA_pred``, ``sEH_pred`` and a
            ``split`` column.

    Returns:
        DataFrame with columns ``split``, ``protein`` and ``AP``: one row per
        group, sorted by ``split`` then ``protein``.
    """
    proteins = ["BRD4", "HSA", "sEH"]

    # Long format: one row per (molecule, protein). Both melts stack their
    # columns in the same protein order, so the rows line up positionally.
    long_df = df.melt(
        id_vars = ["split"],
        value_vars = proteins,
        var_name = "protein",
        value_name = "y_true",
    )
    long_df["y_pred"] = df.melt(
        value_vars = [f"{p}_pred" for p in proteins],
        value_name = "y_pred",
    )["y_pred"].to_numpy()

    # Explicit loop instead of groupby.apply(..., include_groups=True): that
    # flag is deprecated in pandas 2.2 and rejected by pandas 3.
    records = [
        {
            "split": split,
            "protein": protein,
            "AP": average_precision_score(group["y_true"], group["y_pred"]),
        }
        for (split, protein), group in long_df.groupby(["split", "protein"])
    ]
    return pd.DataFrame(records, columns = ["split", "protein", "AP"])

In [8]:
# --- Input data -------------------------------------------------------------
TRAIN_FILE = "../data/train_no_test_wide_replace_dy.parquet"
VAL_FILE = "../data/test_ensemble_wide_replace_dy.parquet"
VAL_GRAPH_FILE = "../data/test_graph_2M.gz"

# --- Warm start -------------------------------------------------------------
# When True, weights are restored from CHECKPOINT_PATH before fitting.
CONTINUE_TRAIN = True
# Set your checkpoint path
CHECKPOINT_PATH = "../GNN_50M_v2/fold_0_epoch_5_step_60000_0.0115.ckpt"

# --- Optimisation -----------------------------------------------------------
# Per-target positive-class weights for BCEWithLogitsLoss (all ones).
POS_WEIGHT = torch.ones(3, dtype = torch.int64)
LR = 1e-4
BATCH_SIZE = 5000
ACCUMULATE_GRAD_BATCHES = 1
EPOCHS = 1
SEED = 42

# --- Output -----------------------------------------------------------------
OUT_DIR = "../GNN_100M_v2_first_half"
os.makedirs(OUT_DIR, exist_ok = True)

In [9]:
df = pd.read_parquet(TRAIN_FILE, columns = ["molecule_smiles", "BRD4", "HSA", "sEH"])
# Every training row belongs to split 0 and gets a sequential id 0..n-1.
df["split"] = 0
df["id"] = np.arange(len(df), dtype = np.int64)
df

Unnamed: 0,molecule_smiles,BRD4,HSA,sEH,split,id
0,C#CC[C@@H](CC(=O)NC)Nc1nc(NCCCN2C(=O)NC(C)(C)C...,0,0,0,0,0
1,C#CC[C@@H](CC(=O)NC)Nc1nc(NCc2ncc(C)o2)nc(NCc2...,0,0,0,0,1
2,C#CC[C@@H](CC(=O)NC)Nc1nc(NCc2cnc(Cl)s2)nc(Nc2...,0,0,0,0,2
3,C#CC[C@@H](CC(=O)NC)Nc1nc(NCCOc2cccnc2)nc(NCc2...,0,0,0,0,3
4,C#CC[C@@H](CC(=O)NC)Nc1nc(NCc2cnc(Cl)s2)nc(NCc...,0,0,0,0,4
...,...,...,...,...,...,...
96415605,CNC(=O)[C@H]1Cc2ccccc2CN1c1nc(Nc2cc(S(C)(=O)=O...,0,0,0,0,96415605
96415606,CNC(=O)[C@H]1Cc2ccccc2CN1c1nc(Nc2nc(Cl)c3[nH]c...,0,0,0,0,96415606
96415607,CNC(=O)[C@H]1Cc2ccccc2CN1c1nc(Nc2nc(OCc3ccccc3...,0,0,0,0,96415607
96415608,CNC(=O)[C@H]1Cc2ccccc2CN1c1nc(Nc2nc3c(Br)cccc3...,0,0,0,0,96415608


In [10]:
# Train on the first 40M rows only ("first half" run).
df = df.head(40_000_000)

In [11]:
# Hold-out set: a seeded 500k-row random sample of the validation file.
df_val = pd.read_parquet(VAL_FILE, columns = ["molecule_smiles", "BRD4", "HSA", "sEH"])
df_val = df_val.sample(n = 500_000, random_state = 42)

In [12]:
# Validation rows are split 1 (compute_ap groups by this column).
df_val = df_val.assign(split = 1)

In [14]:
# The training graphs are stored in four shards. SmilesDataset pairs graph[i]
# with df.iloc[i], so the shards are assumed to follow TRAIN_FILE's row order.
# extend() avoids building the intermediate lists that chained `+` creates.
train_graph = []
for shard in range(1, 5):
    train_graph.extend(load_decompress_pickle(f"../data/train_graph_100M_{shard}.gz"))

# df was truncated to its first rows above, and the dataset only indexes
# positions < len(df). Drop the remaining graphs now instead of keeping tens
# of millions of unused graphs resident in RAM (and in forked workers).
if len(train_graph) < len(df):
    raise ValueError(f"expected at least {len(df)} training graphs, found {len(train_graph)}")
del train_graph[len(df):]

In [15]:
# Graphs for the whole validation file (pgzip is the loader's default backend).
val_graph = load_decompress_pickle(VAL_GRAPH_FILE)

In [None]:
# After .sample(), df_val keeps its original row labels; use them to select
# the matching graphs. This assumes VAL_FILE has a default RangeIndex aligned
# with VAL_GRAPH_FILE (TODO confirm). Then give the sample fresh ids 0..n-1.
val_graph = list(map(val_graph.__getitem__, df_val.index))
df_val["id"] = np.arange(len(df_val), dtype = np.int64)
len(val_graph)

In [None]:
# Run a full garbage collection after the large graph/DataFrame loads above.
gc.collect()

In [None]:
# helper
# torch version of np unpackbits
# https://gist.github.com/vadimkantorov/30ea6d278bc492abf6ad328c6965613a

def tensor_dim_slice(tensor, dim, dim_slice):
    """Index ``tensor`` with ``dim_slice`` along axis ``dim`` (negative dims allowed)."""
    axis = dim + tensor.dim() if dim < 0 else dim
    index = (slice(None),) * axis + (dim_slice,)
    return tensor[index]

# @torch.jit.script
def packshape(shape, dim: int = -1, mask: int = 0b00000001, dtype=torch.uint8, pack=True):
    """Shape bookkeeping for bit packing / unpacking along one axis.

    Args:
        shape: Shape tuple (or ``torch.Size``) of the tensor being converted.
        dim: Axis that is packed/unpacked; negative values count from the end.
        mask: Field mask; 0b1, 0b11, 0b1111 and 0b11111111 select 1-, 2-, 4-
            and 8-bit fields.
        dtype: Storage dtype of the packed tensor (uint8/int16/int32/int64).
        pack: True for the packed shape, False for the unpacked shape.

    Returns:
        ``(new_shape, nibbles, nibble)``: ``nibble`` is the field width in bits
        and ``nibbles`` is how many fields fit in one element of ``dtype``.
    """
    axis = dim + len(shape) if dim < 0 else dim

    # bits = torch.iinfo(dtype).bits would be simpler but does not JIT compile.
    if dtype is torch.uint8:
        bits = 8
    elif dtype is torch.int16:
        bits = 16
    elif dtype is torch.int32:
        bits = 32
    elif dtype is torch.int64:
        bits = 64
    else:
        bits = 0

    if mask == 0b00000001:
        nibble = 1
    elif mask == 0b00000011:
        nibble = 2
    elif mask == 0b00001111:
        nibble = 4
    elif mask == 0b11111111:
        nibble = 8
    else:
        nibble = 0

    assert nibble <= bits and bits % nibble == 0
    nibbles = bits // nibble

    if pack:
        new_len = int(math.ceil(shape[axis] / nibbles))
    else:
        new_len = shape[axis] * nibbles
    return shape[:axis] + (new_len,) + shape[axis + 1:], nibbles, nibble

# @torch.jit.script
def F_unpackbits(tensor, dim: int = -1, mask: int = 0b00000001, shape=None, out=None, dtype=torch.uint8):
    """Torch counterpart of ``np.unpackbits``.

    Expands every element of ``tensor`` into its bit fields along ``dim``.

    Args:
        tensor: Packed integer tensor (e.g. uint8 bytes produced by ``np.packbits``).
        dim: Axis to unpack; negative values count from the end.
        mask: Field mask forwarded to ``packshape`` (0b1 -> one bit per field).
        shape: Optional explicit output shape; defaults to the fully unpacked shape.
        out: Optional preallocated output; must have exactly ``shape``.
        dtype: dtype of the output allocated when ``out`` is None.

    Returns:
        ``out`` filled with the unpacked field values (each in ``[0, mask]``).
    """
    dim = dim if dim >= 0 else dim + tensor.dim()
    shape_, nibbles, nibble = packshape(tensor.shape, dim=dim, mask=mask, dtype=tensor.dtype, pack=False)
    shape = shape if shape is not None else shape_
    out = out if out is not None else torch.empty(shape, device=tensor.device, dtype=dtype)
    assert out.shape == shape

    # With the default ``shape`` the unpacked length is always a multiple of
    # ``nibbles``, so this fast path is the one normally taken.
    if shape[dim] % nibbles == 0:
        # Descending shifts ((nibbles-1)*nibble, ..., nibble, 0): the most
        # significant field comes first, like np.packbits' default bit order.
        shift = torch.arange((nibbles - 1) * nibble, -1, -nibble, dtype=torch.uint8, device=tensor.device)
        # Broadcast the shifts against a new axis inserted right after ``dim``.
        shift = shift.view(nibbles, *((1,) * (tensor.dim() - dim - 1)))
        return torch.bitwise_and((tensor.unsqueeze(1 + dim) >> shift).view_as(out), mask, out=out)

    else:
        # Slow path for a caller-supplied ``shape`` not divisible by
        # ``nibbles``: output positions i, i+nibbles, ... receive field i of
        # successive input elements.
        # NOTE(review): here field i uses shift ``nibble * i`` (least
        # significant first), the reverse of the fast path's order -- confirm.
        for i in range(nibbles):
            shift = nibble * i
            sliced_output = tensor_dim_slice(out, dim, slice(i, None, nibbles))
            sliced_input = tensor.narrow(dim, 0, sliced_output.shape[dim])
            torch.bitwise_and(sliced_input >> shift, mask, out=sliced_output)
    return out

class dotdict(dict):
    """A ``dict`` whose keys can also be read, set and deleted as attributes.

    Reading a missing key as an attribute raises ``AttributeError`` (so
    ``hasattr``/``getattr(obj, name, default)`` work); item access still
    raises ``KeyError``.
    """

    def __getattr__(self, name):
        if name in self:
            return self[name]
        raise AttributeError(name)

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


print('helper ok!')

In [19]:
# mol to graph adopted from
# https://github.com/LiZhang30/GPCNDTA/blob/main/utils/DrugGraph.py

# NOTE(review): PACK_NODE_DIM (9) does not match the 5 packed bytes that
# get_atom_feature produces, and it is not used by the code visible in this
# notebook -- confirm before relying on it.
PACK_NODE_DIM = 9
PACK_EDGE_DIM = 1
# Unpacked feature widths fed to the model: atom features are 35 bits, which
# np.packbits pads to 5 bytes = 40 bits; bond features are 6 bits -> 1 byte.
NODE_DIM = 40
EDGE_DIM = 8 * PACK_EDGE_DIM

def one_of_k_encoding(x, allowable_set, allow_unk=False):
    """One-hot encode ``x`` against the ordered sequence ``allowable_set``.

    Args:
        x: Value to encode.
        allowable_set: Ordered candidate values. When ``allow_unk`` is True,
            its last entry doubles as the "unknown" bucket.
        allow_unk: Map values not in ``allowable_set`` onto its last entry
            instead of raising.

    Returns:
        list[bool] with one entry per element of ``allowable_set``.

    Raises:
        ValueError: If ``x`` is not allowed and ``allow_unk`` is False. This is
            a subclass of ``Exception``, so existing handlers still catch it.
    """
    if x not in allowable_set:
        if not allow_unk:
            raise ValueError(f'input {x} not in allowable set{allowable_set}!!!')
        x = allowable_set[-1]
    return [x == s for s in allowable_set]


# Atom features built by get_atom_feature (one-hot blocks, in this order):
#   1. atom element (ATOM_SYMBOL):                 10 bits
#   2. hybridization (HYBRIDIZATION_TYPE):          5 bits
#   3. degree, 0..5:                                6 bits
#   4. total number of H bound to the atom, 0..5:   6 bits
#   5. implicit valence, 0..5:                      6 bits
#   6. whether the atom is in a ring:               1 bit
#   7. whether the atom is aromatic:                1 bit
#   Total: 35 bits, packed by np.packbits into 5 bytes (40 bits unpacked).
# Values outside these sets raise in one_of_k_encoding (allow_unk is not used).

ATOM_SYMBOL = ['B', 'Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'S', 'Si']

HYBRIDIZATION_TYPE = [
    Chem.rdchem.HybridizationType.S,
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D
]

def get_atom_feature(atom):
    """Encode an RDKit atom as a packed bit vector.

    Concatenates the one-hot blocks for element, hybridization, degree, total
    H count and implicit valence with the in-ring and aromatic flags (35 bits),
    then packs them with ``np.packbits``.

    Returns:
        np.ndarray of dtype uint8 with 5 bytes.
    """
    bits = []
    bits += one_of_k_encoding(atom.GetSymbol(), ATOM_SYMBOL)
    bits += one_of_k_encoding(atom.GetHybridization(), HYBRIDIZATION_TYPE)
    small_counts = (atom.GetDegree(), atom.GetTotalNumHs(), atom.GetImplicitValence())
    for count in small_counts:
        bits += one_of_k_encoding(count, [0, 1, 2, 3, 4, 5])
    bits.append(atom.IsInRing())
    bits.append(atom.GetIsAromatic())
    return np.packbits(bits)


def get_bond_feature(bond):
    """Encode an RDKit bond as a packed bit vector.

    Bits, in order:
      1-4. bond type one-hot: single / double / triple / aromatic
      5.   whether the bond is conjugated
      6.   whether the bond is in a ring
    The 6 bits are packed with ``np.packbits`` into 1 byte (8 bits unpacked).

    Returns:
        np.ndarray of dtype uint8 with 1 byte.
    """
    bond_type = bond.GetBondType()
    feature = [
        bond_type == Chem.rdchem.BondType.SINGLE,
        bond_type == Chem.rdchem.BondType.DOUBLE,
        bond_type == Chem.rdchem.BondType.TRIPLE,
        bond_type == Chem.rdchem.BondType.AROMATIC,
        bond.GetIsConjugated(),
        bond.IsInRing()
    ]
    return np.packbits(feature)


def smile_to_graph(smiles):
    """Convert a SMILES string into packed graph arrays.

    Args:
        smiles: SMILES string.

    Returns:
        Tuple ``(N, edge, node_feature, edge_feature)``:
          * N: number of atoms.
          * edge: (E, 2) array of atom-index pairs. Every bond appears once in
            each direction, ordered by source atom, then by target atom. The
            dtype is uint8 when indices fit (N <= 256), int32 otherwise.
          * node_feature: (N, 5) uint8 packed atom features (get_atom_feature).
          * edge_feature: (E, 1) uint8 packed bond features (get_bond_feature);
            shape (0, 1) for a molecule without bonds.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    N = mol.GetNumAtoms()

    node_feature = []
    edge_feature = []
    edge = []
    for i in range(N):
        atom_i = mol.GetAtomWithIdx(i)
        node_feature.append(get_atom_feature(atom_i))

        # Visit only real neighbours (O(E)) instead of probing all N*N atom
        # pairs; sorting keeps the same (i, ascending j) edge order as before.
        for j in sorted(nbr.GetIdx() for nbr in atom_i.GetNeighbors()):
            edge.append([i, j])
            edge_feature.append(get_bond_feature(mol.GetBondBetweenAtoms(i, j)))

    node_feature = np.stack(node_feature)
    # np.stack cannot handle an empty list (bond-free molecule such as "C").
    edge_feature = np.stack(edge_feature) if edge_feature else np.zeros((0, 1), dtype=np.uint8)
    # uint8 only holds indices up to 255; larger molecules need a wider type.
    edge_dtype = np.uint8 if N <= 256 else np.int32
    edge = np.array(edge, dtype=edge_dtype).reshape(-1, 2)
    return N, edge, node_feature, edge_feature

def to_pyg_format(N, edge, node_feature, edge_feature):
    """Wrap packed graph arrays from ``smile_to_graph`` into a PyG ``Data``.

    ``N`` is accepted so the ``smile_to_graph`` tuple can be splatted in, but
    it is unused. Edge indices are stored as int32 (the model casts them to
    long); node and edge features stay bit-packed as uint8. ``idx`` is always
    set to -1.
    """
    edge_index = torch.from_numpy(edge.T).int()
    x = torch.from_numpy(node_feature).byte()
    edge_attr = torch.from_numpy(edge_feature).byte()
    return Data(idx=-1, edge_index=edge_index, x=x, edge_attr=edge_attr)


In [None]:
# Debug: featurize one example molecule and print the resulting PyG graph.
example_smiles = "C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC#C)CC(=O)NC)n2)cc1"
g = to_pyg_format(*smile_to_graph(smiles = example_smiles))
print(g)
print('[Dy] is replaced by C !!')
print('smile_to_graph() ok!')

In [21]:
class SmilesDataset(Dataset):
    """Dataset that featurizes SMILES on the fly.

    NOTE: this definition is shadowed by the graph-backed ``SmilesDataset``
    defined immediately after it, which is the one the training code uses.

    Items are ``(graph, targets, id, split)``, or only the graph when
    ``mode == "test"``.
    """

    def __init__(
        self,
        df,
        target_cols = ["BRD4", "HSA", "sEH"],
        mode = "train"
    ):
        self.df = df
        self.target_cols = target_cols
        self.mode = mode

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the sample at position ``index``."""
        return self.__data_generation(index)

    def __data_generation(self, index):
        record = self.df.iloc[index]
        # id and split are read for every mode, including "test".
        record_id = torch.tensor(record.id, dtype = torch.float32)
        record_split = torch.tensor(record.split, dtype = torch.float32)

        is_test = self.mode == "test"
        if not is_test:
            labels = record[self.target_cols].to_numpy().astype(np.float32)
            labels = torch.tensor(labels, dtype = torch.float32)

        graph = to_pyg_format(*smile_to_graph(record.molecule_smiles))
        if is_test:
            return graph
        return graph, labels, record_id, record_split
class SmilesDataset(Dataset):
    """Dataset over precomputed molecule graphs.

    ``graph[i]`` must be the ``smile_to_graph`` tuple for ``df.iloc[i]``
    (positional alignment). Items are ``(graph, targets, id, split)``, or only
    the graph when ``mode == "test"``.

    Args:
        df: DataFrame with ``id`` and ``split`` columns, plus ``target_cols``
            unless ``mode == "test"``.
        graph: Sequence of ``(N, edge, node_feature, edge_feature)`` tuples.
        target_cols: Target column names.
        mode: ``"test"`` returns graphs only; any other value also returns labels.
    """

    def __init__(
        self,
        df,
        graph,
        target_cols = ["BRD4", "HSA", "sEH"],
        mode = "train"
    ):
        self.df = df
        self.graph = graph
        self.target_cols = target_cols
        self.mode = mode
        # Extract the per-row values once as plain numpy arrays. A per-sample
        # df.iloc[i] builds a new object-dtype Series on every call, which is
        # slow across tens of millions of rows.
        if mode != "test":
            self._targets = df[list(target_cols)].to_numpy(dtype = np.float32)
            self._ids = df["id"].to_numpy()
            self._splits = df["split"].to_numpy()

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the sample at position ``index``."""
        X = to_pyg_format(*self.graph[index])
        if self.mode == "test":
            return X

        y = torch.tensor(self._targets[index], dtype = torch.float32)
        # NOTE: float32 is exact for integers only up to 2**24 (~16.7M), so
        # larger ids are rounded. Kept as float32 for compatibility with
        # existing consumers (Net casts label ids to float32 anyway).
        label_id = torch.tensor(self._ids[index], dtype = torch.float32)
        split = torch.tensor(self._splits[index], dtype = torch.float32)
        return X, y, label_id, split

In [22]:
# MODEL: simple MPNNModel
# from https://github.com/chaitjo/geometric-gnn-dojo/blob/main/geometric_gnn_101.ipynb
# The original notebook's comments were removed to keep this compact; see the
# link above for a fully annotated MPNNModel.
class MPNNLayer(MessagePassing):
    """One message-passing layer with MLP message and update functions.

    Each edge message is ``mlp_msg([h_i, h_j, edge_attr])``. Messages are
    reduced per node with ``aggr`` and the node update is
    ``mlp_upd([h, aggregated])``.

    Args:
        emb_dim: Node embedding width (both input and output).
        edge_dim: Width of ``edge_attr``.
        aggr: Reduction used by ``MessagePassing`` and ``scatter``.
    """

    def __init__(self, emb_dim=128, edge_dim=4, aggr='add'):
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim
        self.mlp_msg = nn.Sequential(
            nn.Linear(2 * emb_dim + edge_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
            nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU()
        )
        self.mlp_upd = nn.Sequential(
            nn.Linear(2 * emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
            nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU()
        )

    def forward(self, h, edge_index, edge_attr):
        """Run one round of message passing; returns updated node embeddings."""
        return self.propagate(edge_index, h=h, edge_attr=edge_attr)

    def message(self, h_i, h_j, edge_attr):
        msg = torch.cat([h_i, h_j, edge_attr], dim=-1)
        return self.mlp_msg(msg)

    def aggregate(self, inputs, index, dim_size=None):
        # Pass dim_size (number of nodes, supplied by MessagePassing) so nodes
        # without incoming edges still get a zero row. Without it, an isolated
        # node with the highest index in the batch shortens the output, and
        # update() fails when concatenating with h.
        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)

    def update(self, aggr_out, h):
        upd_out = torch.cat([h, aggr_out], dim=-1)
        return self.mlp_upd(upd_out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')


class MPNNModel(nn.Module):
    """Graph encoder: linear node embedding, residual MPNN layers, mean pooling.

    Node and edge features arrive bit-packed (uint8) and are expanded with
    ``F_unpackbits`` before use.

    Args:
        num_layers: Number of stacked ``MPNNLayer``s.
        emb_dim: Hidden width.
        in_dim: Width of the unpacked node features.
        edge_dim: Width of the unpacked edge features.
        out_dim: Unused; kept for signature compatibility.
    """

    def __init__(self, num_layers=4, emb_dim=128, in_dim=11, edge_dim=4, out_dim=1):
        super().__init__()
        self.lin_in = nn.Linear(in_dim, emb_dim)
        self.convs = torch.nn.ModuleList(
            [MPNNLayer(emb_dim, edge_dim, aggr='add') for _ in range(num_layers)]
        )
        self.pool = global_mean_pool

    def forward(self, data):
        """Embed a PyG batch and mean-pool the node embeddings per graph."""
        node_bits = F_unpackbits(data.x, -1).float()
        edge_bits = F_unpackbits(data.edge_attr, -1).float()
        edge_index = data.edge_index.long()

        h = self.lin_in(node_bits)
        for conv in self.convs:
            h = h + conv(h, edge_index, edge_bits)  # residual update, (n, d) -> (n, d)
        return self.pool(h, data.batch)

In [23]:
class EMATracker:
    """Exponential moving average of a stream of values.

    The first value seeds the average. Each later value ``v`` moves it to
    ``alpha * v + (1 - alpha) * previous``.
    """

    def __init__(self, alpha: float = 0.05):
        super().__init__()
        self.alpha = alpha
        self._value = None

    def update(self, new_value):
        """Fold ``new_value`` into the running average."""
        previous = self._value
        if previous is None:
            self._value = new_value
            return
        self._value = new_value * self.alpha + previous * (1 - self.alpha)

    @property
    def value(self):
        """Current average, or None before the first update."""
        return self._value

class Net(L.LightningModule):
    """LightningModule: MPNN graph encoder followed by an MLP head with 3 logits.

    The three outputs correspond to the BRD4, HSA and sEH targets and are
    trained with ``BCEWithLogitsLoss`` using ``POS_WEIGHT``.

    Args:
        lr: AdamW learning rate (also the OneCycleLR peak when scheduling).
        epochs: Number of training epochs; sizes the OneCycleLR schedule.
        steps_per_epoch: Batches per epoch; sizes the OneCycleLR schedule.
        lr_scheduler: If True, use a per-step OneCycleLR schedule.
        fold: Fold index; only used in output file names.
    """

    def __init__(self, lr, epochs, steps_per_epoch, lr_scheduler = False, fold = 0):
        super().__init__()
        self.lr = lr
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.lr_scheduler = lr_scheduler
        self.fold = fold

        self.output_type = ['infer', 'loss']

        self.graph_dim = 96

        self.smile_encoder = MPNNModel(
            in_dim=NODE_DIM,
            edge_dim=EDGE_DIM,
            emb_dim=self.graph_dim,
            num_layers=4
        )

        self.bind = nn.Sequential(
            nn.Linear(self.graph_dim, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(512, 3)
        )

        # Clone so load_state_dict (which copies into buffers in place) cannot
        # overwrite the module-level POS_WEIGHT tensor.
        self.loss = nn.BCEWithLogitsLoss(pos_weight = POS_WEIGHT.clone())
        self.validation_step_outputs = []

    def forward(self, x):
        """Return (batch, 3) logits for a PyG batch ``x``."""
        return self.bind(self.smile_encoder(x))

    def training_step(self, batch, batch_idx):
        x, y, label_ids, split = batch
        loss = self.loss(self(x), y)
        # Use the real batch size so the epoch mean weights a short last batch correctly.
        self.log("train_loss", loss, on_step = True, on_epoch = True,
                 prog_bar = True, batch_size = y.size(0))
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y, label_ids, split = batch
        out = self(x)
        val_loss = self.loss(out, y)
        self.log("val_loss", val_loss, on_step = False, on_epoch = True,
                 logger = True, prog_bar = True, batch_size = y.size(0))
        self.validation_step_outputs.append(
            {"val_loss": val_loss, "predictions": out, "targets": y, "label_ids": label_ids, "split": split}
        )
        return {"val_loss": val_loss}

    def on_validation_end(self):
        """Compute per-target AP on the gathered validation outputs and save them.

        Appends per-(split, protein) AP rows to ``OUT_DIR/metrics.csv`` and
        writes the full prediction table to ``OUT_DIR/val_df_fold_*.csv``.
        Output from Lightning's pre-training sanity-check pass is discarded.
        """
        outputs = self.validation_step_outputs
        self.validation_step_outputs = []
        if not outputs or self.trainer.sanity_checking:
            return

        def gather(key):
            # Concatenate one field across batches into a float32 numpy array.
            return torch.cat([o[key] for o in outputs], dim = 0).detach().to(torch.float32).cpu().numpy()

        # Upcast the (possibly bf16) logits before the sigmoid so the
        # probabilities are not rounded to bf16 a second time.
        logits = torch.cat([o["predictions"] for o in outputs], dim = 0).detach().to(torch.float32)
        output_val = torch.sigmoid(logits).cpu().numpy()
        target_val = gather("targets")
        label_ids = gather("label_ids")
        split = gather("split")

        TARGETS = ["BRD4", "HSA", "sEH"]
        PREDS = [f"{t}_pred" for t in TARGETS]
        val_df = pd.concat(
            [pd.DataFrame(target_val, columns = TARGETS), pd.DataFrame(output_val, columns = PREDS)],
            axis = 1
        )
        val_df["label_id"] = label_ids
        val_df["split"] = split

        ap = compute_ap(val_df)
        ap["epoch"] = self.current_epoch
        ap["step"] = self.global_step
        # Writing the string "val" into a numeric column is deprecated in
        # pandas 2.x and an error in pandas 3, so switch to object first.
        ap["split"] = ap["split"].astype(object)
        ap.loc[ap.split == 1, "split"] = "val"
        metrics_path = f"{OUT_DIR}/metrics.csv"
        ap.to_csv(metrics_path, index = False, mode = "a",
                  header = not os.path.exists(metrics_path))

        mean_ap = ap.AP.mean()
        print(f"validation MAP: {mean_ap}")

        val_df.to_csv(
            f"{OUT_DIR}/val_df_fold_{self.fold}_epoch_{self.current_epoch}_step_{self.global_step}_{mean_ap:.4f}.csv",
            index = False
        )

    def predict_step(self, batch, batch_idx, dataloader_idx = 0):
        # float32 output so callers can call .numpy() (numpy has no bfloat16).
        return torch.sigmoid(self(batch).to(torch.float32))

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr = self.lr,
            betas = (0.9, 0.999),
            eps = 1e-08,
            weight_decay = 0.01
        )
        if not self.lr_scheduler:
            return optimizer

        # Local imports: these names were previously used without being
        # imported, so lr_scheduler=True failed with NameError.
        import math
        from torch.optim.lr_scheduler import OneCycleLR

        # Lightning steps the optimizer once per full accumulation window and
        # also on the last (partial) window of each epoch, so round per epoch
        # with ceil. floor() could undercount, making OneCycleLR raise on its
        # final steps.
        total_steps = self.epochs * math.ceil(self.steps_per_epoch / ACCUMULATE_GRAD_BATCHES)
        scheduler_dict = {
            "scheduler": OneCycleLR(optimizer, self.lr, total_steps = total_steps),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

In [None]:
# Seed Python / NumPy / torch (and dataloader workers) so the shuffled order
# and dropout masks are reproducible. SEED was defined but never used, so
# deterministic=True alone did not make runs repeatable.
L.seed_everything(SEED, workers = True)

train_ds = SmilesDataset(df, train_graph)
train_loader = PyGDataLoader(
    train_ds,
    shuffle = True,
    batch_size = BATCH_SIZE,
    num_workers = 3
)

valid_ds = SmilesDataset(df_val, val_graph)
valid_loader = PyGDataLoader(
    valid_ds,
    shuffle = False,
    batch_size = BATCH_SIZE
)

logger = CSVLogger(save_dir = OUT_DIR, prefix = f"fold_{0}")

checkpoint_callback = ModelCheckpoint(
    dirpath = OUT_DIR,
    monitor = "val_loss",
    save_top_k = -1,
    save_last = False,
    save_weights_only = True,
    filename = f"fold_{0}" + "_epoch_{epoch}_step_{step}_{val_loss:.4f}",
    save_on_train_epoch_end = False,
    verbose = True,
    auto_insert_metric_name = False,
    mode = "min"
)

# Defined for convenience; not passed to the Trainer below.
early_stop_callback = EarlyStopping(
    monitor = "val_loss",
    min_delta = 0.00,
    patience = 10,
    verbose = False,
    mode = "min"
)

net_kwargs = dict(
    lr = LR,
    epochs = EPOCHS,
    steps_per_epoch = len(train_loader),
    lr_scheduler = False
)
if CONTINUE_TRAIN:
    # Warm start from the previous run's checkpoint (no throwaway model is
    # built first).
    model = Net.load_from_checkpoint(CHECKPOINT_PATH, **net_kwargs)
else:
    model = Net(**net_kwargs)

trainer = L.Trainer(
    max_epochs = EPOCHS,
    deterministic = True,
    accumulate_grad_batches = ACCUMULATE_GRAD_BATCHES,
    logger = logger,
    val_check_interval = 1/5,
    log_every_n_steps = 100,
    callbacks = [checkpoint_callback],
    devices = 1,
    precision = "bf16-mixed"
)

trainer.fit(
    model = model,
    train_dataloaders = train_loader,
    val_dataloaders = valid_loader
)