For details, refer to https://www.kaggle.com/competitions/leash-BELKA/discussion/498858

In [None]:
# Notebook shell magics: install the third-party packages imported in the next cell.
!pip install pgzip
!pip install rdkit
!pip install torch_geometric
!pip install torch-scatter

In [None]:
import os

import pandas as pd
import numpy as np

import rdkit
from rdkit import Chem

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as PyGDataLoader

from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_scatter import scatter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from sklearn.metrics import average_precision_score

from multiprocessing import Pool

import _pickle as  cPickle
import bz2
import pgzip
from tqdm import tqdm

print('import ok!')

In [3]:
def save_compressed_pickle(file, data, type = "pgzip", n_treads = 1):
    """Pickle ``data`` into ``file`` through a compressing writer.

    Args:
        file: Destination path handed to the compressor.
        data: Any picklable object.
        type: ``"pgzip"`` or ``"bz2"``; any other value prints
            ``unsupported`` and writes nothing.
        n_treads: Thread count forwarded to ``pgzip.open`` (ignored for bz2).

    Returns:
        None.
    """
    if type not in ("pgzip", "bz2"):
        print("unsupported")
        return

    # Pick the compressing stream first so the dump logic is written once.
    if type == "pgzip":
        stream = pgzip.open(file, 'w', thread = n_treads)
    else:
        stream = bz2.BZ2File(file, 'w')

    with stream as sink:
        cPickle.dump(data, sink)

def load_decompress_pickle(file, type = "pgzip", n_treads = 4):
    """Read a pickled object back from a compressed file.

    Args:
        file: Path of the compressed pickle.
        type: ``"pgzip"`` or ``"bz2"``.
        n_treads: Thread count forwarded to ``pgzip.open`` (ignored for bz2).

    Returns:
        The unpickled object.

    Raises:
        ValueError: If ``type`` is not one of the supported compressors.

    NOTE: unpickling executes arbitrary code embedded in the file; only load
    files produced by a trusted run of this notebook.
    """
    if type == "pgzip":
        stream = pgzip.open(file, thread = n_treads)
    elif type == "bz2":
        stream = bz2.BZ2File(file, 'rb')
    else:
        # Previously this printed "unsupported" and then failed on an
        # unbound local; fail with an explicit, descriptive error instead.
        raise ValueError(f"unsupported compression type: {type!r}")

    # The context manager closes the file handle even if unpickling fails.
    with stream as source:
        return cPickle.load(source)

def compute_ap(df):
    """Compute average precision per (split, protein) group.

    ``df`` must provide the label columns ``BRD4``, ``HSA``, ``sEH``, the
    matching ``*_pred`` columns, plus ``label_id`` and ``split``.

    Labels and predictions are melted to long format separately and glued
    together column-wise, so the pairing relies on both melts emitting their
    rows in the same order.

    Returns:
        DataFrame with columns ``split``, ``protein`` and ``AP``.
    """
    id_cols = ["label_id", "split"]

    truth_long = df.loc[:, ["BRD4", "HSA", "sEH"] + id_cols].melt(
        id_vars = id_cols,
        var_name = "protein",
        value_name = "y_true",
    )
    pred_long = df.loc[:, ["BRD4_pred", "HSA_pred", "sEH_pred"] + id_cols].melt(
        id_vars = id_cols,
        var_name = "_",
        value_name = "y_pred",
    )
    # "split" already comes from the label frame; drop the duplicate here.
    pred_long = pred_long.drop("split", axis = 1)

    paired = pd.concat([truth_long, pred_long], axis = 1)

    def _group_ap(group):
        return average_precision_score(group.y_true, group.y_pred)

    per_group = paired.groupby(["split", "protein"]).apply(_group_ap)
    return per_group.reset_index(name = "AP")

In [17]:
# Number of worker processes used for the multiprocessing pool that builds
# the training graphs.
N_THREADS = 96

# Input tables (parquet) read with pandas further down.
TRAIN_FILE = "../data/train_no_test_wide_replace_dy.parquet"
TEST_FILE = "../data/test_ensemble_wide_replace_dy.parquet"
SUBMIT_FILE = "../data/submit_replace_dy.parquet" 

# Output locations for the compressed, pickled graph lists.
# TRAIN_GRAPH_FILE is a prefix: the export loop appends "_<chunk>.gz" to it.
TRAIN_GRAPH_FILE = "../data/train_graph_100M"
TEST_GRAPH_FILE = "../data/test_graph_2M.gz"
SUBMIT_GRAPH_FILE = "../data/submit_graph.gz"

In [None]:
# Load the training table; the bare `df` on the last line makes the notebook
# render it as the cell output.
df = pd.read_parquet(TRAIN_FILE)
df

In [None]:
# helper
# torch version of np unpackbits
#https://gist.github.com/vadimkantorov/30ea6d278bc492abf6ad328c6965613a

def tensor_dim_slice(tensor, dim, dim_slice):
    """Index ``tensor`` with ``dim_slice`` along ``dim``, keeping earlier dims whole.

    Negative ``dim`` values are resolved against ``tensor.dim()``.
    """
    axis = dim if dim >= 0 else dim + tensor.dim()
    # One full slice per leading dimension, then the requested slice.
    index = (slice(None),) * axis + (dim_slice,)
    return tensor[index]

# @torch.jit.script
def packshape(shape, dim: int = -1, mask: int = 0b00000001, dtype=torch.uint8, pack=True):
    """Compute the shape that results from packing or unpacking bit-fields.

    Args:
        shape: Shape (tuple-like) of the tensor being converted.
        dim: Dimension along which fields are packed; negative values count
            from the end.
        mask: Bit mask of a single field; one of ``0b1``, ``0b11``,
            ``0b1111`` or ``0b11111111`` (1, 2, 4 or 8 bits per field).
        dtype: Storage dtype of the packed tensor (``torch.uint8``,
            ``int16``, ``int32`` or ``int64``).
        pack: If True return the packed shape (``dim`` shrinks, rounded up);
            if False return the unpacked shape (``dim`` grows).

    Returns:
        ``(shape, nibbles, nibble)`` where ``nibble`` is the number of bits
        per field and ``nibbles`` the number of fields per storage element.

    Raises:
        AssertionError: If ``dtype`` or ``mask`` is not one of the supported
            values, or a field does not evenly divide the storage width.
    """
    dim = dim if dim >= 0 else dim + len(shape)
    bits = {torch.uint8: 8, torch.int16: 16, torch.int32: 32, torch.int64: 64}.get(dtype, 0)
    nibble = {0b00000001: 1, 0b00000011: 2, 0b00001111: 4, 0b11111111: 8}.get(mask, 0)
    # bits = torch.iinfo(dtype).bits # does not JIT compile
    # `0 < nibble` guards the modulo below against an unsupported mask.
    assert 0 < nibble <= bits and bits % nibble == 0
    nibbles = bits // nibble
    if pack:
        # Integer ceil-division: no dependency on `math` (which this file
        # never imports) and no float rounding for very large sizes.
        new_size = (shape[dim] + nibbles - 1) // nibbles
    else:
        new_size = shape[dim] * nibbles
    shape = shape[:dim] + (new_size,) + shape[1 + dim:]
    return shape, nibbles, nibble

# @torch.jit.script
def F_unpackbits(tensor, dim: int = -1, mask: int = 0b00000001, shape=None, out=None, dtype=torch.uint8):
    """Unpack the bit-fields stored in ``tensor`` along ``dim``.

    Args:
        tensor: Packed input tensor.
        dim: Dimension to expand; negative values count from the end.
        mask: Bit mask selecting one field (passed on to ``packshape``).
        shape: Optional explicit output shape; defaults to the unpacked
            shape computed by ``packshape(..., pack=False)``.
        out: Optional pre-allocated output tensor; must have shape ``shape``.
        dtype: dtype of the output tensor when ``out`` is not supplied.

    Returns:
        The tensor holding one field per element (``out`` if it was given).
    """
    dim = dim if dim >= 0 else dim + tensor.dim()
    shape_, nibbles, nibble = packshape(tensor.shape, dim=dim, mask=mask, dtype=tensor.dtype, pack=False)
    shape = shape if shape is not None else shape_
    out = out if out is not None else torch.empty(shape, device=tensor.device, dtype=dtype)
    assert out.shape == shape

    if shape[dim] % nibbles == 0:
        # Fast path: output length is a whole number of storage elements.
        # Shifts run from the highest field down to 0, so the most
        # significant field of each element comes first in the output.
        shift = torch.arange((nibbles - 1) * nibble, -1, -nibble, dtype=torch.uint8, device=tensor.device)
        # Give `shift` trailing singleton dims so it broadcasts against the
        # axis inserted right after `dim`.
        shift = shift.view(nibbles, *((1,) * (tensor.dim() - dim - 1)))
        return torch.bitwise_and((tensor.unsqueeze(1 + dim) >> shift).view_as(out), mask, out=out)

    else:
        # Ragged path (custom `shape` whose `dim` is not a multiple of
        # `nibbles`): fill every `nibbles`-th output position per field.
        # NOTE(review): here field i uses shift `nibble * i`, i.e. the least
        # significant field first — the opposite order from the fast path
        # above. Confirm which order callers expect.
        for i in range(nibbles):
            shift = nibble * i
            sliced_output = tensor_dim_slice(out, dim, slice(i, None, nibbles))
            sliced_input = tensor.narrow(dim, 0, sliced_output.shape[dim])
            torch.bitwise_and(sliced_input >> shift, mask, out=sliced_output)
    return out

class dotdict(dict):
    """A ``dict`` whose items can also be read, set and deleted as attributes."""

    def __getattr__(self, name):
        # Reached only when regular attribute lookup fails; missing keys are
        # reported as AttributeError so getattr()/hasattr() behave normally.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        return value

    def __setattr__(self, name, value):
        # Attribute assignment stores a dictionary item.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        # Attribute deletion removes the item (KeyError if it is absent).
        dict.__delitem__(self, name)


# Progress marker: printed once every helper in this cell has been defined.
print('helper ok!')

In [8]:
# mol-to-graph conversion adapted from
# https://github.com/LiZhang30/GPCNDTA/blob/main/utils/DrugGraph.py

# Presumably the byte widths of the bit-packed node / edge feature rows
# (EDGE_DIM below is PACK_EDGE_DIM * 8, i.e. bytes -> bits).
# NOTE(review): get_atom_feature packs 35 flags into 5 bytes and NODE_DIM is
# 40 (= 5 * 8), so PACK_NODE_DIM = 9 looks stale — confirm before relying on it.
PACK_NODE_DIM = 9
PACK_EDGE_DIM = 1
# Feature widths after unpacking the bits.
NODE_DIM = 40 #PACK_NODE_DIM * 8
EDGE_DIM = PACK_EDGE_DIM * 8

def one_of_k_encoding(x, allowable_set, allow_unk=False):
    """One-hot encode ``x`` against ``allowable_set``.

    Args:
        x: Value to encode.
        allowable_set: Sequence of admissible values.
        allow_unk: If True, a value outside the set is mapped to the last
            entry of ``allowable_set``; otherwise it is an error.

    Returns:
        List of booleans, one per entry of ``allowable_set``.

    Raises:
        Exception: If ``x`` is not in the set and ``allow_unk`` is False.
    """
    if x not in allowable_set:
        if not allow_unk:
            raise Exception(f'input {x} not in allowable set{allowable_set}!!!')
        x = allowable_set[-1]
    return [x == candidate for candidate in allowable_set]


#Get features of an atom (one-hot encoding)
'''
    1.atom element: 10 dimensions (see ATOM_SYMBOL)
    2.the atom's hybridization: 5 dimensions
    3.degree of atom: 6 dimensions
    4.total number of H bound to atom: 6 dimensions
    5.implicit valence of atom: 6 dimensions
    6.whether the atom is on ring: 1 dimension
    7.whether the atom is aromatic: 1 dimension
    Total: 35 dimensions (bit-packed into 5 bytes)
'''

# Element symbols accepted by get_atom_feature; any other symbol makes
# one_of_k_encoding raise.
ATOM_SYMBOL = ['B', 'Br', 'C', 'Cl', 'F', 'I', 'N', 'O', 'S', 'Si']
# len(ATOM_SYMBOL) == 10

# Hybridization states accepted by get_atom_feature, in one-hot order.
HYBRIDIZATION_TYPE = [
    Chem.rdchem.HybridizationType.S,
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D
]

def get_atom_feature(atom):
    """Encode one RDKit atom as a bit-packed uint8 feature vector.

    The flags are, in order: element symbol, hybridization, degree, total
    hydrogen count, implicit valence (each one-hot), then the in-ring and
    aromatic booleans. ``np.packbits`` zero-pads the final byte.

    Raises:
        Exception: From ``one_of_k_encoding`` when any property falls
            outside its allowed set.
    """
    count_bins = [0, 1, 2, 3, 4, 5]
    groups = [
        one_of_k_encoding(atom.GetSymbol(), ATOM_SYMBOL),
        one_of_k_encoding(atom.GetHybridization(), HYBRIDIZATION_TYPE),
        one_of_k_encoding(atom.GetDegree(), count_bins),
        one_of_k_encoding(atom.GetTotalNumHs(), count_bins),
        one_of_k_encoding(atom.GetImplicitValence(), count_bins),
        [atom.IsInRing(), atom.GetIsAromatic()],
    ]
    # Flatten the per-property groups into a single flag list.
    flags = [flag for group in groups for flag in group]
    return np.packbits(flags)


#Get features of an edge (one-hot encoding)
'''
    1.single/double/triple/aromatic: 4 dimensions
    2.whether the bond is conjugated: 1 dimension
    3.whether the bond is on ring: 1 dimension
    Total: 6 dimensions
'''

def get_bond_feature(bond):
    """Encode one RDKit bond as a bit-packed uint8 feature vector.

    Six flags: bond type is single / double / triple / aromatic, bond is
    conjugated, bond is in a ring. ``np.packbits`` zero-pads them to a byte.
    """
    kind = bond.GetBondType()
    bond_kinds = Chem.rdchem.BondType
    type_flags = [
        kind == candidate
        for candidate in (
            bond_kinds.SINGLE,
            bond_kinds.DOUBLE,
            bond_kinds.TRIPLE,
            bond_kinds.AROMATIC,
        )
    ]
    flags = type_flags + [bond.GetIsConjugated(), bond.IsInRing()]
    return np.packbits(flags)


def smile_to_graph(smiles):
    """Convert a SMILES string into packed graph arrays.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        Tuple ``(N, edge, node_feature, edge_feature)``:
        ``N`` is the atom count; ``edge`` is an ``(E, 2)`` integer array of
        directed ``[i, j]`` pairs (each bond appears in both directions,
        ordered by ``i`` then ``j``); ``node_feature`` stacks one
        ``get_atom_feature`` row per atom; ``edge_feature`` stacks one
        ``get_bond_feature`` row per directed edge.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` or the molecule has no
            atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None.
        raise ValueError(f"could not parse SMILES: {smiles!r}")

    N = mol.GetNumAtoms()
    if N == 0:
        raise ValueError(f"SMILES has no atoms: {smiles!r}")

    node_feature = []
    edge_feature = []
    edge = []
    for i in range(N):
        atom_i = mol.GetAtomWithIdx(i)
        node_feature.append(get_atom_feature(atom_i))

        # Visit only bonded neighbours instead of probing all N candidates;
        # sorting keeps the (i, j) edge order of the exhaustive scan.
        for j in sorted(nbr.GetIdx() for nbr in atom_i.GetNeighbors()):
            edge.append([i, j])
            edge_feature.append(get_bond_feature(mol.GetBondBetweenAtoms(i, j)))

    node_feature = np.stack(node_feature)
    if edge_feature:
        edge_feature = np.stack(edge_feature)
    else:
        # Bond-less molecule (e.g. a single atom): np.stack rejects an empty
        # list, so build the empty array explicitly.
        edge_feature = np.zeros((0, PACK_EDGE_DIM), dtype=np.uint8)

    # uint8 keeps the pickles compact but only holds indices up to 255;
    # switch to int32 for larger molecules instead of overflowing.
    index_dtype = np.uint8 if N <= 256 else np.int32
    edge = np.array(edge, dtype=index_dtype).reshape(-1, 2)
    return N, edge, node_feature, edge_feature

def to_pyg_format(N,edge,node_feature,edge_feature):
    """Wrap the arrays produced by ``smile_to_graph`` in a PyG ``Data`` object.

    ``edge`` is transposed and cast to int32 for ``edge_index``; the packed
    node and edge features are stored as uint8 tensors. ``N`` is accepted so
    the function can be called with the unpacked ``smile_to_graph`` result
    but is not used. ``idx`` is set to ``-1``.
    """
    edge_index = torch.from_numpy(edge.T).int()
    x = torch.from_numpy(node_feature).byte()
    edge_attr = torch.from_numpy(edge_feature).byte()
    return Data(idx=-1, edge_index=edge_index, x=x, edge_attr=edge_attr)


In [None]:
# Debug: push a single molecule through the SMILES -> PyG pipeline.
debug_smiles = "C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC#C)CC(=O)NC)n2)cc1"
g = to_pyg_format(*smile_to_graph(smiles=debug_smiles))
print(g)
print('[Dy] is replaced by C !!')
print('smile_to_graph() ok!')

In [None]:
# Convert the training SMILES to graphs in fixed-size chunks and write one
# compressed pickle per chunk ("<TRAIN_GRAPH_FILE>_<k>.gz", k starting at 1).
TRAIN_CHUNK_ROWS = 10_000_000
n_train_rows = df.shape[0]
# Always emit at least the original 10 files (a loader may expect _1.._10),
# but add more chunks when the table is longer than 100M rows so that no
# rows are silently dropped.
n_train_chunks = max(10, -(-n_train_rows // TRAIN_CHUNK_ROWS))

for i in range(n_train_chunks):
    start = TRAIN_CHUNK_ROWS * i
    end = min(start + TRAIN_CHUNK_ROWS, n_train_rows)
    chunk_smiles = df.molecule_smiles.iloc[start:end]
    # A fresh pool per chunk, as before; `chunksize` batches the tasks sent
    # to each worker, which matters with millions of tiny jobs. imap still
    # yields results in input order.
    with Pool(processes = N_THREADS) as pool:
        train_graph = list(
            tqdm(pool.imap(smile_to_graph, chunk_smiles, chunksize = 1024),
                 total = len(chunk_smiles))
        )
    save_compressed_pickle(f"{TRAIN_GRAPH_FILE}_{i+1}.gz", train_graph,
                           type = "pgzip", n_treads = 16)

In [None]:
# Build graphs for the test table and store them as one compressed pickle.
df_test = pd.read_parquet(TEST_FILE)
# Use the shared N_THREADS setting rather than a second hard-coded 96;
# `chunksize` batches tasks per worker while imap keeps the input order.
with Pool(processes = N_THREADS) as pool:
    test_graph = list(
        tqdm(pool.imap(smile_to_graph, df_test.molecule_smiles, chunksize = 1024),
             total = len(df_test))
    )
save_compressed_pickle(TEST_GRAPH_FILE, test_graph,
                       type = "pgzip", n_treads = 16)

In [None]:
# Build graphs for the unique submission molecules and store them as one
# compressed pickle.
df_submit = pd.read_parquet(SUBMIT_FILE)
df_submit = df_submit[["molecule_smiles"]].drop_duplicates()
# Use the shared N_THREADS setting rather than a second hard-coded 96;
# `chunksize` batches tasks per worker while imap keeps the input order.
with Pool(processes = N_THREADS) as pool:
    submit_graph = list(
        tqdm(pool.imap(smile_to_graph, df_submit.molecule_smiles, chunksize = 1024),
             total = len(df_submit))
    )
save_compressed_pickle(SUBMIT_GRAPH_FILE, submit_graph,
                       type = "pgzip", n_treads = 16)