In [1]:
import sys
import os
import gc
import copy
import yaml
import pickle
import random
import joblib
import shutil
from time import time
import typing as tp
from pathlib import Path

import numpy as np
import pandas as pd
import scipy

from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold
from sklearn.metrics import average_precision_score as APS
import duckdb


import torch
import torchvision
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.cuda import amp
from torch.nn import BCELoss


import timm
from mamba_ssm import Mamba
from transformers import AutoModel, AutoTokenizer

import albumentations as A
from albumentations.pytorch import ToTensorV2


# use one device only
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
con = duckdb.connect()

In [2]:
class CFG:
    NUM = 20000
    TRAIN_BIND_PER = 0.1
    TEST_No = 1
    TEST_NUM = int(1674896/32 * TEST_No)
    TEST_OFFSET = int(TEST_NUM * (TEST_No-1))
    LR = 0.001
    WD = 1e-4
    NBR_FOLDS = 5
    SELECTED_FOLDS = [0, 1, 2, 3, 4]
    TRAIN_ENC_PATH = Path('../../data/external/train_enc.parquet')
    TEST_ENC_PATH = Path('../../data/external/test_enc.parquet')
    TRAIN_PATH = Path('../../data/raw/train.parquet')
    TEST_PATH = Path('../../data/raw/test.parquet')
    OUTPUT_PATH = Path(f'../../data/processed/{NUM}_50per_CLM.parquet')
    TEST_OUTPUT_PATH = Path(f'../../data/processed/test_CLM_{TEST_OFFSET}_to_{TEST_NUM}.parquet')
    BATCH_SIZE = 128
    EPOCHS = 5
    PATIENCE = 5
    REDUCE_LR_PATIENCE = 3
    REDUCE_LR_FACTOR = 0.5
    is_test = True

In [3]:
if not CFG.is_test:
    train = con.query(f"""(SELECT *
                            FROM parquet_scan('{CFG.TRAIN_PATH}')
                            WHERE binds = 0
                            ORDER BY random()
                            LIMIT {int(CFG.NUM/2)})
                            UNION ALL
                            (SELECT *
                            FROM parquet_scan('{CFG.TRAIN_PATH}')
                            WHERE binds = 1
                            ORDER BY random()
                            LIMIT {int(CFG.NUM/2)})""").df()
else:
    test = con.query(f"""(SELECT *
                        FROM parquet_scan('{CFG.TEST_PATH}')
                        LIMIT {CFG.TEST_NUM}
                        OFFSET {CFG.TEST_OFFSET}
                        )""").df()



In [4]:
if not CFG.is_test:
    display(train.head())
    display(train.tail())
else:
    display(test.head())
    display(test.tail())

Unnamed: 0,id,buildingblock1_smiles,buildingblock2_smiles,buildingblock3_smiles,molecule_smiles,protein_name
0,295246830,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,BRD4
1,295246831,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,HSA
2,295246832,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,sEH
3,295246833,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,CC(O)Cn1cnc2c(N)ncnc21,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2...,BRD4
4,295246834,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,CC(O)Cn1cnc2c(N)ncnc21,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2...,HSA


Unnamed: 0,id,buildingblock1_smiles,buildingblock2_smiles,buildingblock3_smiles,molecule_smiles,protein_name
52335,295299165,CC(C)(C)OC(=O)CC(NC(=O)OCC1c2ccccc2-c2ccccc21)...,NCC1(CO)CCOC1,COC(=O)c1cscc1N.Cl,COC(=O)c1cscc1Nc1nc(NCC2(CO)CCOC2)nc(NC(CC(=O)...,BRD4
52336,295299166,CC(C)(C)OC(=O)CC(NC(=O)OCC1c2ccccc2-c2ccccc21)...,NCC1(CO)CCOC1,COC(=O)c1cscc1N.Cl,COC(=O)c1cscc1Nc1nc(NCC2(CO)CCOC2)nc(NC(CC(=O)...,HSA
52337,295299167,CC(C)(C)OC(=O)CC(NC(=O)OCC1c2ccccc2-c2ccccc21)...,NCC1(CO)CCOC1,COC(=O)c1cscc1N.Cl,COC(=O)c1cscc1Nc1nc(NCC2(CO)CCOC2)nc(NC(CC(=O)...,sEH
52338,295299168,CC(C)(C)OC(=O)CC(NC(=O)OCC1c2ccccc2-c2ccccc21)...,NCC1(CO)CCOC1,Cc1csc(N)n1,Cc1csc(Nc2nc(NCC3(CO)CCOC3)nc(NC(CC(=O)OC(C)(C...,BRD4
52339,295299169,CC(C)(C)OC(=O)CC(NC(=O)OCC1c2ccccc2-c2ccccc21)...,NCC1(CO)CCOC1,Cc1csc(N)n1,Cc1csc(Nc2nc(NCC3(CO)CCOC3)nc(NC(CC(=O)OC(C)(C...,HSA


In [5]:
if not CFG.is_test:
    smiles = train['molecule_smiles']#.unique()
    print(len(smiles))
else:
    smiles = test['molecule_smiles']#.unique()
    print(len(smiles))

52340


In [6]:
# 104681 rows take about 10 minutes.
# load pre-trained ChemBERTa model checkpoint and tokenizer
cb_tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-10M-MLM')
cb_model = AutoModel.from_pretrained('DeepChem/ChemBERTa-10M-MLM')
cb_model.eval()

# tokenize SMILES
cb_encoded_inputs = cb_tokenizer(list(smiles), padding=True, truncation=True, return_tensors="pt")

# calculate embeddings
with torch.no_grad():
    outputs = cb_model(**cb_encoded_inputs)

# extract pooled output
cb_embeddings = outputs.pooler_output

cb_embeddings_df = pd.DataFrame(cb_embeddings.numpy())
cb_embeddings_df.head()

  return self.fget.__get__(instance, owner)()
Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-10M-MLM and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,374,375,376,377,378,379,380,381,382,383
0,0.059212,0.268544,0.459275,-0.478332,0.064293,0.066137,-0.147887,0.115527,0.04611,-0.391337,...,0.155696,0.065302,0.083106,-0.26467,0.19297,-0.045991,0.291919,0.124646,-0.182428,0.36823
1,0.059212,0.268544,0.459275,-0.478332,0.064293,0.066137,-0.147887,0.115527,0.04611,-0.391337,...,0.155696,0.065302,0.083106,-0.26467,0.19297,-0.045991,0.291919,0.124646,-0.182428,0.36823
2,0.059212,0.268544,0.459275,-0.478332,0.064293,0.066137,-0.147887,0.115527,0.04611,-0.391337,...,0.155696,0.065302,0.083106,-0.26467,0.19297,-0.045991,0.291919,0.124646,-0.182428,0.36823
3,-0.048795,0.222011,0.395821,-0.379432,-0.070935,-0.064048,-0.125544,0.057238,-0.03522,-0.281463,...,0.194638,0.161062,0.121532,-0.174147,0.034281,0.092628,0.27025,0.135999,-0.095416,0.323474
4,-0.048795,0.222011,0.395821,-0.379432,-0.070935,-0.064048,-0.125544,0.057238,-0.03522,-0.281463,...,0.194638,0.161062,0.121532,-0.174147,0.034281,0.092628,0.27025,0.135999,-0.095416,0.323474


In [7]:
# df_repeated = cb_embeddings_df.loc[cb_embeddings_df.index.repeat(3)].reset_index(drop=True)

In [8]:
if not CFG.is_test:
    cb_embeddings_df = pd.concat([train['id'], cb_embeddings_df], axis=1)
    binds = train[['binds', 'protein_name']]
    binds['bind1'] = train.apply(lambda row: row['binds'] if row['protein_name'] == 'BRD4' else 0, axis=1)
    binds['bind2'] = train.apply(lambda row: row['binds'] if row['protein_name'] == 'HSA' else 0, axis=1)
    binds['bind3'] = train.apply(lambda row: row['binds'] if row['protein_name'] == 'sEH' else 0, axis=1)
    cb_embeddings_df = pd.concat([cb_embeddings_df, binds], axis=1)
else:
    cb_embeddings_df = pd.concat([test['id'], cb_embeddings_df], axis=1)
    cb_embeddings_df = pd.concat([cb_embeddings_df, test['protein_name']], axis=1)


In [9]:
display(cb_embeddings_df.head())
display(cb_embeddings_df.tail())

Unnamed: 0,id,0,1,2,3,4,5,6,7,8,...,375,376,377,378,379,380,381,382,383,protein_name
0,295246830,0.059212,0.268544,0.459275,-0.478332,0.064293,0.066137,-0.147887,0.115527,0.04611,...,0.065302,0.083106,-0.26467,0.19297,-0.045991,0.291919,0.124646,-0.182428,0.36823,BRD4
1,295246831,0.059212,0.268544,0.459275,-0.478332,0.064293,0.066137,-0.147887,0.115527,0.04611,...,0.065302,0.083106,-0.26467,0.19297,-0.045991,0.291919,0.124646,-0.182428,0.36823,HSA
2,295246832,0.059212,0.268544,0.459275,-0.478332,0.064293,0.066137,-0.147887,0.115527,0.04611,...,0.065302,0.083106,-0.26467,0.19297,-0.045991,0.291919,0.124646,-0.182428,0.36823,sEH
3,295246833,-0.048795,0.222011,0.395821,-0.379432,-0.070935,-0.064048,-0.125544,0.057238,-0.03522,...,0.161062,0.121532,-0.174147,0.034281,0.092628,0.27025,0.135999,-0.095416,0.323474,BRD4
4,295246834,-0.048795,0.222011,0.395821,-0.379432,-0.070935,-0.064048,-0.125544,0.057238,-0.03522,...,0.161062,0.121532,-0.174147,0.034281,0.092628,0.27025,0.135999,-0.095416,0.323474,HSA


Unnamed: 0,id,0,1,2,3,4,5,6,7,8,...,375,376,377,378,379,380,381,382,383,protein_name
52335,295299165,-0.372445,0.089123,0.418904,-0.146498,-0.177251,-0.06379,0.174884,-0.163991,-0.124751,...,0.079181,0.057339,-0.033592,0.156972,0.118414,0.300975,0.437202,-0.245715,0.274537,BRD4
52336,295299166,-0.372445,0.089123,0.418904,-0.146498,-0.177251,-0.06379,0.174884,-0.163991,-0.124751,...,0.079181,0.057339,-0.033592,0.156972,0.118414,0.300975,0.437202,-0.245715,0.274537,HSA
52337,295299167,-0.372445,0.089123,0.418904,-0.146498,-0.177251,-0.06379,0.174884,-0.163991,-0.124751,...,0.079181,0.057339,-0.033592,0.156972,0.118414,0.300975,0.437202,-0.245715,0.274537,sEH
52338,295299168,-0.20151,-0.133856,0.205848,0.05759,0.017345,-0.040076,-0.064621,-0.257788,-0.11316,...,0.015858,-0.154248,-0.173892,0.304074,0.00425,0.301082,0.348183,-0.131817,0.121814,BRD4
52339,295299169,-0.20151,-0.133856,0.205848,0.05759,0.017345,-0.040076,-0.064621,-0.257788,-0.11316,...,0.015858,-0.154248,-0.173892,0.304074,0.00425,0.301082,0.348183,-0.131817,0.121814,HSA


In [10]:
cb_embeddings_df.columns = cb_embeddings_df.columns.astype(str)

if not CFG.is_test:
    cb_embeddings_df.to_parquet(CFG.OUTPUT_PATH)
else:
    cb_embeddings_df.to_parquet(CFG.TEST_OUTPUT_PATH)