In [1]:
import os
import sys
import torch
import pandas as pd
from tqdm.notebook import tqdm
import pickle as pk
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils.data import DataLoader

device = torch.cuda.set_device(0)

BASE_DIR = ".."
MODEL_BASE_DIR = f"{BASE_DIR}/best_models"
DATA_DIR = f"{BASE_DIR}/nbdata"
os.makedirs(MODEL_BASE_DIR,exist_ok=True)
os.makedirs(DATA_DIR,exist_ok=True)
sys.path.append(BASE_DIR)

from plm_dti import DTIDataset, molecule_protein_collate_fn

In [2]:
from tdc import utils
from tdc.benchmark_group import dti_dg_group

names = utils.retrieve_benchmark_names('DTI_DG_Group')
group = dti_dg_group(path = DATA_DIR)
benchmark = group.get('bindingdb_patent')
name = benchmark['name']
train_val, test = benchmark['train_val'], benchmark['test'] ## Natural log transformed (kd/ki/ic50??)

Found local copy...


In [3]:
all_drugs = pd.concat([train_val,test]).Drug.values
all_proteins = pd.concat([train_val,test]).Target.values

In [4]:
from mol_feats import Morgan_f, Morgan_DC_f
from prot_feats import Prose_f

mol_featurizer = Morgan_DC_f()
prot_featurizer = Prose_f()

In [5]:
to_disk_path = f"{DATA_DIR}/tdc_bindingdb_patent_train"

mol_featurizer.precompute(all_drugs,to_disk_path=to_disk_path,from_disk=True)
prot_featurizer.precompute(all_proteins,to_disk_path=to_disk_path,from_disk=True)

--- precomputing morgan_DC molecule featurizer ---
--- loading from disk ---
--- precomputing Prose protein featurizer ---
--- loading from disk ---


In [6]:
## --- train your model --- ##

In [7]:
class SimplePLMModel(nn.Module):
    def __init__(self,
                 mol_emb_size = 2048,
                 prot_emb_size = 6165,
                 hidden_dim = 512,
                 activation = nn.ReLU
                ):
        super().__init__()
        self.mol_emb_size = mol_emb_size
        self.prot_emb_size = prot_emb_size

        self.mol_projector = nn.Sequential(
            nn.Linear(self.mol_emb_size, hidden_dim),
            activation()
        )

        self.prot_projector = nn.Sequential(
            nn.Linear(self.prot_emb_size, hidden_dim),
            activation()
        )
        
        self.fc = nn.Linear(2*hidden_dim, 1)

    def forward(self, mol_emb, prot_emb):
        mol_proj = self.mol_projector(mol_emb)
        prot_proj = self.prot_projector(prot_emb)
        # print(mol_proj.shape, prot_proj.shape)
        cat_emb = torch.cat([mol_proj, prot_proj],axis=1)
        # print(cat_emb.shape)
        return self.fc(cat_emb).squeeze()

In [8]:
import wandb
import copy
from torch.autograd import Variable
from time import time
from scipy.stats import pearsonr

test_dataset = DTIDataset(
        test.Drug,
        test.Target,
        test.Y,
        mol_featurizer,
        prot_featurizer,
    )

test_dataloader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=lambda x: molecule_protein_collate_fn(x, pad=False))

best_models = {}

for seed in range(5):
    train, valid = group.get_train_valid_split(benchmark = name, split_type = 'default', seed = seed)

    train_dataset = DTIDataset(
        train.Drug,
        train.Target,
        train.Y,
        mol_featurizer,
        prot_featurizer,
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=lambda x: molecule_protein_collate_fn(x, pad=False))

    valid_dataset = DTIDataset(
        valid.Drug,
        valid.Target,
        valid.Y,
        mol_featurizer,
        prot_featurizer,
    )

    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=lambda x: molecule_protein_collate_fn(x, pad=False))
    
    # wandb.init(
    #         project=args.wandb_proj,
    #         name=config.experiment_id,
    #         config=flatten(config),
    #     )
    # wandb.watch(model, log_freq=100)

    # early stopping
    max_pcc = 0

    model = SimplePLMModel().cuda()
    torch.backends.cudnn.benchmark = True
    n_epo = 10
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    every_n_val = 1
    loss_history = []

    tg_len = len(train_dataloader)
    start_time = time()
    for epo in tqdm(range(n_epo)):
        model.train()
        epoch_time_start = time()
        for i, (d, p, label) in enumerate(train_dataloader):

            score = model(d.cuda(), p.cuda())
            label = Variable(torch.from_numpy(np.array(label)).float()).cuda()

            loss_fct = torch.nn.MSELoss()

            loss = loss_fct(score, label)
            loss_history.append((epo, i, float(loss.cpu().detach().numpy())))
            # wandb.log({"train/loss": loss, "epoch": epo,
            #                "step": epo*tg_len*args.batch_size + i*args.batch_size
            #           })

            opt.zero_grad()
            loss.backward()
            opt.step()

            if (i % 1000 == 0):
                print(f'[{seed}] Training at Epoch {epo+1} iteration {i} with loss {loss.cpu().detach().numpy()}')

        epoch_time_end = time()
        if epo % 5 == 0:
            with torch.set_grad_enabled(False):
                pred_list = []
                lab_list = []
                model.eval()
                for i, (d, p, label) in enumerate(valid_dataloader):
                    score = model(d.cuda(), p.cuda())
                    score = score.detach().cpu().numpy()
                    label = label.detach().cpu().numpy()
                    pred_list.extend(score)
                    lab_list.extend(label)

                pred_list = torch.tensor(pred_list)
                lab_list = torch.tensor(lab_list)
                val_pcc = pearsonr(pred_list, lab_list)[0]
                # wandb.log({"val/loss": val_loss, "epoch": epo,
                #            "val/pcc": float(val_pcc),
                #            "Charts/epoch_time": (epoch_time_end - epoch_time_start)/config.training.every_n_val
                #   })
                if val_pcc > max_pcc:
                    model_max = copy.deepcopy(model)
                    max_pcc = val_pcc
                print(f'[seed{seed}] Validation at Epoch {epo+1}: PCC={val_pcc}')
        end_time = time()
        
    best_models[seed] = (model_max, max_pcc)
    torch.save(model_max, f"best_models/TDC_DTI_DG_seed{seed}_best_model.sav")

generating training, validation splits...


  0%|          | 0/10 [00:00<?, ?it/s]

[0] Training at Epoch 1 iteration 0 with loss 24.035785675048828
[0] Training at Epoch 1 iteration 1000 with loss 5.604084491729736
[0] Training at Epoch 1 iteration 2000 with loss 3.7022926807403564
[0] Training at Epoch 1 iteration 3000 with loss 1.740458607673645
[0] Training at Epoch 1 iteration 4000 with loss 4.7416229248046875
[0] Validation at Epoch 1: PCC=0.8190084527118858
[0] Training at Epoch 2 iteration 0 with loss 2.08184814453125
[0] Training at Epoch 2 iteration 1000 with loss 2.543416976928711
[0] Training at Epoch 2 iteration 2000 with loss 1.7525396347045898
[0] Training at Epoch 2 iteration 3000 with loss 1.8970952033996582
[0] Training at Epoch 2 iteration 4000 with loss 1.8240851163864136
[0] Training at Epoch 3 iteration 0 with loss 1.7656350135803223
[0] Training at Epoch 3 iteration 1000 with loss 3.5992610454559326
[0] Training at Epoch 3 iteration 2000 with loss 2.9975829124450684
[0] Training at Epoch 3 iteration 3000 with loss 2.7799134254455566
[0] Training

generating training, validation splits...


  0%|          | 0/10 [00:00<?, ?it/s]

[1] Training at Epoch 1 iteration 0 with loss 36.12716293334961
[1] Training at Epoch 1 iteration 1000 with loss 4.674874305725098
[1] Training at Epoch 1 iteration 2000 with loss 3.9652135372161865
[1] Training at Epoch 1 iteration 3000 with loss 4.16023063659668
[1] Training at Epoch 1 iteration 4000 with loss 2.8760318756103516
[1] Validation at Epoch 1: PCC=0.8236950453030591
[1] Training at Epoch 2 iteration 0 with loss 2.695178985595703
[1] Training at Epoch 2 iteration 1000 with loss 2.9995810985565186
[1] Training at Epoch 2 iteration 2000 with loss 2.0711474418640137
[1] Training at Epoch 2 iteration 3000 with loss 2.877331256866455
[1] Training at Epoch 2 iteration 4000 with loss 3.4949283599853516
[1] Training at Epoch 3 iteration 0 with loss 3.7171449661254883
[1] Training at Epoch 3 iteration 1000 with loss 2.073960304260254
[1] Training at Epoch 3 iteration 2000 with loss 1.2705814838409424
[1] Training at Epoch 3 iteration 3000 with loss 3.230710506439209
[1] Training at

generating training, validation splits...


  0%|          | 0/10 [00:00<?, ?it/s]

[2] Training at Epoch 1 iteration 0 with loss 29.845439910888672
[2] Training at Epoch 1 iteration 1000 with loss 3.7798047065734863
[2] Training at Epoch 1 iteration 2000 with loss 6.608565807342529
[2] Training at Epoch 1 iteration 3000 with loss 5.181746482849121
[2] Training at Epoch 1 iteration 4000 with loss 3.8914988040924072
[2] Validation at Epoch 1: PCC=0.8195777564916156
[2] Training at Epoch 2 iteration 0 with loss 2.6653618812561035
[2] Training at Epoch 2 iteration 1000 with loss 3.062657356262207
[2] Training at Epoch 2 iteration 2000 with loss 1.652127981185913
[2] Training at Epoch 2 iteration 3000 with loss 1.8174138069152832
[2] Training at Epoch 2 iteration 4000 with loss 1.8002684116363525
[2] Training at Epoch 3 iteration 0 with loss 1.6508482694625854
[2] Training at Epoch 3 iteration 1000 with loss 3.6090269088745117
[2] Training at Epoch 3 iteration 2000 with loss 1.437587022781372
[2] Training at Epoch 3 iteration 3000 with loss 1.9503026008605957
[2] Training

generating training, validation splits...


  0%|          | 0/10 [00:00<?, ?it/s]

[3] Training at Epoch 1 iteration 0 with loss 36.807273864746094
[3] Training at Epoch 1 iteration 1000 with loss 3.958169937133789
[3] Training at Epoch 1 iteration 2000 with loss 4.124624252319336
[3] Training at Epoch 1 iteration 3000 with loss 2.252260446548462
[3] Training at Epoch 1 iteration 4000 with loss 2.250272512435913
[3] Validation at Epoch 1: PCC=0.8158872440309023
[3] Training at Epoch 2 iteration 0 with loss 3.3046488761901855
[3] Training at Epoch 2 iteration 1000 with loss 3.400376558303833
[3] Training at Epoch 2 iteration 2000 with loss 1.839687466621399
[3] Training at Epoch 2 iteration 3000 with loss 4.518401145935059
[3] Training at Epoch 2 iteration 4000 with loss 1.9515361785888672
[3] Training at Epoch 3 iteration 0 with loss 2.83988094329834
[3] Training at Epoch 3 iteration 1000 with loss 1.7946913242340088
[3] Training at Epoch 3 iteration 2000 with loss 1.9969806671142578
[3] Training at Epoch 3 iteration 3000 with loss 4.6395158767700195
[3] Training at 

generating training, validation splits...


  0%|          | 0/10 [00:00<?, ?it/s]

[4] Training at Epoch 1 iteration 0 with loss 25.415008544921875
[4] Training at Epoch 1 iteration 1000 with loss 2.9821510314941406
[4] Training at Epoch 1 iteration 2000 with loss 4.140312194824219
[4] Training at Epoch 1 iteration 3000 with loss 2.425964832305908
[4] Training at Epoch 1 iteration 4000 with loss 1.9992005825042725
[4] Validation at Epoch 1: PCC=0.817678846344615
[4] Training at Epoch 2 iteration 0 with loss 2.670764923095703
[4] Training at Epoch 2 iteration 1000 with loss 3.197417736053467
[4] Training at Epoch 2 iteration 2000 with loss 2.474884033203125
[4] Training at Epoch 2 iteration 3000 with loss 4.401054382324219
[4] Training at Epoch 2 iteration 4000 with loss 3.0352072715759277
[4] Training at Epoch 3 iteration 0 with loss 2.166949510574341
[4] Training at Epoch 3 iteration 1000 with loss 1.7536730766296387
[4] Training at Epoch 3 iteration 2000 with loss 1.158900260925293
[4] Training at Epoch 3 iteration 3000 with loss 2.3164823055267334
[4] Training at 

In [12]:
pcc_seed = {}

for seed in range(5):
    pred_list = []

    best_mod_ev = best_models[seed][0]
    best_mod_ev.eval()
    with torch.no_grad():
        for i, (d, p, label) in enumerate(test_dataloader):
            score = best_mod_ev(d.cuda(), p.cuda())
            score = score.detach().cpu().numpy()
            pred_list.extend(score)

    pred_list = np.array(pred_list)
    predictions = {name: pred_list}
    
    out = group.evaluate(predictions)
    pcc_seed[seed] = out
    print(f'{seed}: PCC={out[name]["pcc"]}')
print(f'Average PCC: {sum([pcc_seed[s][name]["pcc"] for s in range(5)])/5}')

0: PCC=0.504
1: PCC=0.518
2: PCC=0.505
3: PCC=0.506
4: PCC=0.529
Average PCC: 0.5124000000000001
