In [40]:
# --- Imports: stdlib, then third-party ---------------------------------------
import os
import pickle as pk
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# --- Paths -------------------------------------------------------------------
# All locations are relative to the repository root, one level above this notebook.
BASE_DIR = ".."
MODEL_BASE_DIR = f"{BASE_DIR}/best_models"
DATA_DIR = f"{BASE_DIR}/nbdata"
os.makedirs(DATA_DIR, exist_ok=True)

# Extend the import path BEFORE importing project modules, so that a fresh
# kernel (Restart & Run All) can resolve them. Guarded so that re-running the
# cell does not keep appending the same entry.
# NOTE(review): presumably plm_dti / mol_feats / prot_feats live in BASE_DIR -- confirm.
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)

from plm_dti import DTIDataset, molecule_protein_collate_fn

# --- Device ------------------------------------------------------------------
# torch.cuda.set_device() returns None, so its result must not be bound to
# `device`; select the GPU first, then build an explicit device handle.
torch.cuda.set_device(0)
device = torch.device("cuda", 0)

In [3]:
from tdc import utils
from tdc.benchmark_group import dti_dg_group

# Benchmark names available in the TDC DTI domain-generalisation group.
names = utils.retrieve_benchmark_names('DTI_DG_Group')

# Instantiate the benchmark group, pointing it at the notebook data directory.
group = dti_dg_group(path=DATA_DIR)

# Fetch the BindingDB patent benchmark and unpack its name and the two splits.
benchmark = group.get('bindingdb_patent')
name = benchmark['name']
train_val = benchmark['train_val']
# NOTE(review): labels are presumably natural-log transformed (Kd/Ki/IC50?) -- confirm.
test = benchmark['test']

In [7]:
# Collect every drug and target entry across train_val and test (in that order)
# as plain arrays, one array per column.
all_drugs, all_proteins = (
    pd.concat([train_val, test])[column].values
    for column in ("Drug", "Target")
)

In [10]:
from mol_feats import Morgan_f, Morgan_DC_f
from prot_feats import Prose_f

# Featurizer objects for the two input modalities. The same two instances are
# passed to every DTIDataset built later in the notebook, and their
# `precompute` method is called in the next cell.
# NOTE(review): presumably Morgan fingerprints for molecules and ProSE
# embeddings for proteins, judging by the class names -- confirm in
# mol_feats / prot_feats.
mol_featurizer = Morgan_DC_f()
prot_featurizer = Prose_f()

In [12]:
# Path prefix handed to both featurizers for their on-disk feature files.
# Although the prefix says "_train", the inputs below cover train_val AND test.
to_disk_path = f"{DATA_DIR}/tdc_bindingdb_patent_train"

# Precompute features for every drug and every protein up front.
# NOTE(review): `from_disk=True` presumably reuses a previously saved file when
# one exists -- confirm against the featurizer implementation.
mol_featurizer.precompute(all_drugs,to_disk_path=to_disk_path,from_disk=True)
prot_featurizer.precompute(all_proteins,to_disk_path=to_disk_path,from_disk=True)

  0%|                                                                                        | 30/232458 [00:00<13:04, 296.35it/s]

--- precomputing morgan_DC molecule featurizer ---


  3%|██▌                                                                                   | 6862/232458 [00:18<10:49, 347.48it/s]RDKit ERROR: [09:18:59] Explicit valence for atom # 20 N, 4, is greater than permitted
Failed to featurize datapoint 0, None. Appending empty array
[09:18:59] Explicit valence for atom # 20 N, 4, is greater than permitted
Exception message: Python argument types in
    rdkit.Chem.rdmolfiles.CanonicalRankAtoms(NoneType)
did not match C++ signature:
    CanonicalRankAtoms(RDKit::ROMol mol, bool breakTies=True, bool includeChirality=True, bool includeIsotopes=True)
RDKit ERROR: [09:18:59] Explicit valence for atom # 20 N, 4, is greater than permitted
Failed to featurize datapoint 0, None. Appending empty array
[09:18:59] Explicit valence for atom # 20 N, 4, is greater than permitted
Exception message: Python argument types in
    rdkit.Chem.rdmolfiles.CanonicalRankAtoms(NoneType)
did not match C++ signature:
    CanonicalRankAtoms(RDKit::ROMol mol, bool breakTi

--- saving morgans to ../nbdata/tdc_bindingdb_patent_train_Morgan_DC_MOLECULES.pk ---


In [15]:
## --- train your model --- ##

In [16]:
class SimplePLMModel(nn.Module):
    """Two-tower regressor over precomputed molecule and protein embeddings.

    Each embedding is projected to ``hidden_dim`` by its own linear layer
    followed by ``activation``; the two projections are concatenated and mapped
    to a single scalar per sample by a final linear layer.

    Args:
        mol_emb_size: Width of the molecule embedding fed to ``forward``.
        prot_emb_size: Width of the protein embedding fed to ``forward``.
        hidden_dim: Width of each projected representation.
        activation: Zero-argument callable returning an ``nn.Module``
            (e.g. the ``nn.ReLU`` class); instantiated once per projector.
    """

    def __init__(self,
                 mol_emb_size = 2048,
                 prot_emb_size = 6165,
                 hidden_dim = 512,
                 activation = nn.ReLU
                ):
        super().__init__()
        self.mol_emb_size = mol_emb_size
        self.prot_emb_size = prot_emb_size

        self.mol_projector = nn.Sequential(
            nn.Linear(self.mol_emb_size, hidden_dim),
            activation()
        )

        self.prot_projector = nn.Sequential(
            nn.Linear(self.prot_emb_size, hidden_dim),
            activation()
        )

        # Regression head over the concatenated [molecule | protein] projections.
        self.fc = nn.Linear(2*hidden_dim, 1)

    def forward(self, mol_emb, prot_emb):
        """Predict one score per (molecule, protein) pair.

        Args:
            mol_emb: Tensor of shape ``(batch, mol_emb_size)``.
            prot_emb: Tensor of shape ``(batch, prot_emb_size)``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        mol_proj = self.mol_projector(mol_emb)
        prot_proj = self.prot_projector(prot_emb)
        cat_emb = torch.cat([mol_proj, prot_proj], dim=1)
        # Drop only the trailing singleton dimension. A bare ``squeeze()`` would
        # also remove the batch dimension when batch == 1, yielding a 0-d tensor
        # that cannot be iterated or matched against a (1,)-shaped label.
        return self.fc(cat_emb).squeeze(-1)

In [47]:
import wandb
import copy
from torch.autograd import Variable
from time import time
from scipy.stats import pearsonr

# --- Training configuration ---------------------------------------------------
BATCH_SIZE = 32
n_epo = 10
# Validate on epochs 1, 6, ... -- the same cadence as the previously hard-coded
# `epo % 5 == 0`, now driven by the variable that was defined but never used.
every_n_val = 5


def _collate(batch):
    """Collate a list of dataset items without padding."""
    return molecule_protein_collate_fn(batch, pad=False)


def _build_loader(frame, shuffle):
    """Wrap a split's Drug/Target/Y columns in a DTIDataset and a DataLoader.

    Args:
        frame: Split with ``Drug``, ``Target`` and ``Y`` columns.
        shuffle: Whether the loader reshuffles each epoch.

    Returns:
        ``(dataset, dataloader)`` for the split.
    """
    dataset = DTIDataset(
        frame.Drug,
        frame.Target,
        frame.Y,
        mol_featurizer,
        prot_featurizer,
    )
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=_collate,
    )
    return dataset, loader


test_dataset, test_dataloader = _build_loader(test, shuffle=False)

# seed -> (best model snapshot, its validation PCC)
best_models = {}

# Loop-invariant setup: the loss module is stateless, so build it once.
loss_fct = torch.nn.MSELoss()
torch.backends.cudnn.benchmark = True

for seed in range(5):
    # Seed torch as well as the data split so weight initialisation and batch
    # shuffling are reproducible for a given seed.
    torch.manual_seed(seed)

    train, valid = group.get_train_valid_split(benchmark = name, split_type = 'default', seed = seed)

    train_dataset, train_dataloader = _build_loader(train, shuffle=True)
    # Validation order does not affect the correlation, so no shuffling is needed.
    valid_dataset, valid_dataloader = _build_loader(valid, shuffle=False)

    # wandb.init(
    #         project=args.wandb_proj,
    #         name=config.experiment_id,
    #         config=flatten(config),
    #     )
    # wandb.watch(model, log_freq=100)

    # Best-checkpoint tracking. Reset for every seed: starting from -inf / None
    # guarantees the first validation always stores a snapshot, so a seed can
    # never inherit `model_max` from the previous seed (or hit a NameError).
    max_pcc = float("-inf")
    model_max = None

    model = SimplePLMModel().cuda()
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_history = []

    tg_len = len(train_dataloader)
    start_time = time()
    for epo in tqdm(range(n_epo)):
        model.train()
        epoch_time_start = time()
        for i, (d, p, label) in enumerate(train_dataloader):

            score = model(d.cuda(), p.cuda())
            # Plain tensors carry autograd state; the Variable wrapper is not needed.
            label = torch.from_numpy(np.array(label)).float().cuda()

            loss = loss_fct(score, label)
            loss_value = loss.item()
            loss_history.append((epo, i, loss_value))
            # wandb.log({"train/loss": loss, "epoch": epo,
            #                "step": epo*tg_len*args.batch_size + i*args.batch_size
            #           })

            opt.zero_grad()
            loss.backward()
            opt.step()

            if (i % 1000 == 0):
                print(f'[{seed}] Training at Epoch {epo+1} iteration {i} with loss {loss_value}')

        epoch_time_end = time()
        if epo % every_n_val == 0:
            model.eval()
            pred_list = []
            lab_list = []
            with torch.no_grad():
                for d, p, label in valid_dataloader:
                    score = model(d.cuda(), p.cuda())
                    # atleast_1d keeps a singleton final batch iterable even if
                    # the model returns a 0-d tensor for it.
                    pred_list.extend(np.atleast_1d(score.detach().cpu().numpy()))
                    lab_list.extend(np.atleast_1d(label.detach().cpu().numpy()))

            pred_list = np.asarray(pred_list)
            lab_list = np.asarray(lab_list)
            val_pcc = pearsonr(pred_list, lab_list)[0]
            # wandb.log({"val/loss": val_loss, "epoch": epo,
            #            "val/pcc": float(val_pcc),
            #            "Charts/epoch_time": (epoch_time_end - epoch_time_start)/config.training.every_n_val
            #   })
            if model_max is None or val_pcc > max_pcc:
                model_max = copy.deepcopy(model)
                max_pcc = val_pcc
            print(f'[{seed}] Validation at Epoch {epo+1}: PCC={val_pcc}')
        end_time = time()

    best_models[seed] = (model_max, max_pcc)

generating training, validation splits...


  0%|          | 0/10 [00:00<?, ?it/s]

[0] Training at Epoch 1 iteration 0 with loss 24.837581634521484
[0] Training at Epoch 1 iteration 1000 with loss 4.105448246002197
[0] Training at Epoch 1 iteration 2000 with loss 1.938743233680725
[0] Training at Epoch 1 iteration 3000 with loss 3.6975269317626953
[0] Training at Epoch 1 iteration 4000 with loss 3.6933109760284424
[0] Validation at Epoch 1: PCC=0.8213020147755926
[0] Training at Epoch 2 iteration 0 with loss 3.1155881881713867
[0] Training at Epoch 2 iteration 1000 with loss 1.8293713331222534
[0] Training at Epoch 2 iteration 2000 with loss 3.300053119659424
[0] Training at Epoch 2 iteration 3000 with loss 2.6379446983337402
[0] Training at Epoch 2 iteration 4000 with loss 2.459156036376953
[0] Training at Epoch 3 iteration 0 with loss 1.514702558517456
[0] Training at Epoch 3 iteration 1000 with loss 1.8828380107879639
[0] Training at Epoch 3 iteration 2000 with loss 3.55964994430542
[0] Training at Epoch 3 iteration 3000 with loss 1.8052459955215454
[0] Training a

generating training, validation splits...


  0%|          | 0/10 [00:00<?, ?it/s]

[1] Training at Epoch 1 iteration 0 with loss 26.527992248535156
[1] Training at Epoch 1 iteration 1000 with loss 2.4547810554504395
[1] Training at Epoch 1 iteration 2000 with loss 4.502969741821289
[1] Training at Epoch 1 iteration 3000 with loss 2.5985445976257324
[1] Training at Epoch 1 iteration 4000 with loss 2.6350390911102295
[1] Validation at Epoch 1: PCC=0.8205621955981985
[1] Training at Epoch 2 iteration 0 with loss 2.5370070934295654
[1] Training at Epoch 2 iteration 1000 with loss 2.0154314041137695
[1] Training at Epoch 2 iteration 2000 with loss 3.0500972270965576
[1] Training at Epoch 2 iteration 3000 with loss 6.313941478729248
[1] Training at Epoch 2 iteration 4000 with loss 2.433480739593506
[1] Training at Epoch 3 iteration 0 with loss 2.987537384033203
[1] Training at Epoch 3 iteration 1000 with loss 2.1784253120422363
[1] Training at Epoch 3 iteration 2000 with loss 2.2773399353027344
[1] Training at Epoch 3 iteration 3000 with loss 3.6700279712677
[1] Training a

generating training, validation splits...


  0%|          | 0/10 [00:00<?, ?it/s]

[2] Training at Epoch 1 iteration 0 with loss 32.689414978027344
[2] Training at Epoch 1 iteration 1000 with loss 4.101603031158447
[2] Training at Epoch 1 iteration 2000 with loss 4.630967140197754
[2] Training at Epoch 1 iteration 3000 with loss 2.3897604942321777
[2] Training at Epoch 1 iteration 4000 with loss 4.140097618103027
[2] Validation at Epoch 1: PCC=0.8188091082201606
[2] Training at Epoch 2 iteration 0 with loss 2.570394515991211
[2] Training at Epoch 2 iteration 1000 with loss 2.503743886947632
[2] Training at Epoch 2 iteration 2000 with loss 3.3873023986816406
[2] Training at Epoch 2 iteration 3000 with loss 2.65244460105896
[2] Training at Epoch 2 iteration 4000 with loss 2.5011792182922363
[2] Training at Epoch 3 iteration 0 with loss 1.3105204105377197
[2] Training at Epoch 3 iteration 1000 with loss 2.874603748321533
[2] Training at Epoch 3 iteration 2000 with loss 2.587315082550049
[2] Training at Epoch 3 iteration 3000 with loss 3.262030601501465
[2] Training at E

generating training, validation splits...


  0%|          | 0/10 [00:00<?, ?it/s]

[3] Training at Epoch 1 iteration 0 with loss 33.45869445800781
[3] Training at Epoch 1 iteration 1000 with loss 2.508051872253418
[3] Training at Epoch 1 iteration 2000 with loss 4.68975305557251
[3] Training at Epoch 1 iteration 3000 with loss 1.769756555557251
[3] Training at Epoch 1 iteration 4000 with loss 2.248765468597412
[3] Validation at Epoch 1: PCC=0.8163294222879756
[3] Training at Epoch 2 iteration 0 with loss 2.4014763832092285
[3] Training at Epoch 2 iteration 1000 with loss 4.003084182739258
[3] Training at Epoch 2 iteration 2000 with loss 2.613588809967041
[3] Training at Epoch 2 iteration 3000 with loss 1.7472437620162964
[3] Training at Epoch 2 iteration 4000 with loss 2.11207914352417
[3] Training at Epoch 3 iteration 0 with loss 1.2026317119598389
[3] Training at Epoch 3 iteration 1000 with loss 3.0194787979125977
[3] Training at Epoch 3 iteration 2000 with loss 3.788757085800171
[3] Training at Epoch 3 iteration 3000 with loss 2.5604248046875
[3] Training at Epoch

generating training, validation splits...


  0%|          | 0/10 [00:00<?, ?it/s]

[4] Training at Epoch 1 iteration 0 with loss 18.42188262939453
[4] Training at Epoch 1 iteration 1000 with loss 4.198205947875977
[4] Training at Epoch 1 iteration 2000 with loss 4.158092975616455
[4] Training at Epoch 1 iteration 3000 with loss 3.1352334022521973
[4] Training at Epoch 1 iteration 4000 with loss 4.455635070800781
[4] Validation at Epoch 1: PCC=0.8173426161171405
[4] Training at Epoch 2 iteration 0 with loss 2.2666702270507812
[4] Training at Epoch 2 iteration 1000 with loss 2.360943078994751
[4] Training at Epoch 2 iteration 2000 with loss 1.7159645557403564
[4] Training at Epoch 2 iteration 3000 with loss 2.3912618160247803
[4] Training at Epoch 2 iteration 4000 with loss 4.862433433532715
[4] Training at Epoch 3 iteration 0 with loss 1.2773594856262207
[4] Training at Epoch 3 iteration 1000 with loss 3.0193135738372803
[4] Training at Epoch 3 iteration 2000 with loss 2.3110647201538086
[4] Training at Epoch 3 iteration 3000 with loss 1.8578425645828247
[4] Training 

In [49]:
# Evaluate each seed's best validation snapshot on the held-out test split and
# collect the benchmark group's metrics per seed.
pcc_seed = {}

# Iterate over the seeds that were actually trained instead of repeating the
# seed count here.
for seed in best_models:
    pred_list = []

    best_mod_ev = best_models[seed][0]
    best_mod_ev.eval()
    with torch.no_grad():
        for d, p, _label in test_dataloader:
            score = best_mod_ev(d.cuda(), p.cuda())
            # atleast_1d keeps a singleton final batch iterable even if the
            # model returns a 0-d tensor for it (extend() would raise otherwise).
            pred_list.extend(np.atleast_1d(score.detach().cpu().numpy()))

    pred_list = np.array(pred_list)
    # group.evaluate receives a mapping from benchmark name to predictions;
    # test_dataloader is unshuffled, so predictions follow the order of `test`.
    predictions = {name: pred_list}

    out = group.evaluate(predictions)
    pcc_seed[seed] = out
    print(out)

{'bindingdb_patent': {'pcc': 0.518}}
{'bindingdb_patent': {'pcc': 0.519}}
{'bindingdb_patent': {'pcc': 0.516}}
{'bindingdb_patent': {'pcc': 0.519}}
{'bindingdb_patent': {'pcc': 0.513}}
