In [1]:
import os
import argparse
import pickle
import math

import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import auc, roc_curve, precision_recall_curve

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms

from my_utils import Encoder , Decoder




In [2]:
def fill_mask(y_trn, m_trn):
  y_pos = y_trn.sum(axis=0)
  y_neg = ((1 - y_trn) * m_trn).sum(axis=0)

  y_add = np.array([[1 if (m_trn[idx,idy] == 0) and (y_pos[idy] > y_neg[idy]) else 0 for idy in range(y_trn.shape[1])] for idx in range(y_trn.shape[0])])

  y_trn = y_trn + y_add

  m_trn = np.ones(m_trn.shape)

  return y_trn, m_trn




def bin2idx(omic_bin):
  """ Transfer a binarized matrix into a index matrix (for input of embedding layer).

  omic_bin: (num_sample, num_feature), each value in {0,1}
  omic_idx: 0 is used for padding, and therefore meaningful index starts from 1.

  """

  num_max_omic = omic_bin.sum(axis=1).max() # max num of mutation in a single sample
  omic_idx = np.zeros((len(omic_bin), num_max_omic), dtype=int )
  for idx, line in enumerate(omic_bin):
    line = [idy+1 for idy, val in enumerate(line) if val == 1]
    omic_idx[idx][0:len(line)] = line

  return omic_idx

def get_ptw_ids(drug_info, tgt):

  id2pw = {id:pw for id,pw in zip(drug_info.index,drug_info['Target pathway'])}
  pws = [id2pw.get(int(c),'Unknown') for c in tgt.columns]
  pw2id = {pw:id for id,pw in enumerate(list(set(pws)))}
  ptw_ids = [pw2id[pw] for pw in pws]

  return ptw_ids

def load_dataset(input_dir="data/input", drug_id=-1, shuffle_feature=False):
    tgt = pd.read_csv(os.path.join(input_dir,'gdsc.csv'),index_col=0)
    drug_info = pd.read_csv(os.path.join(input_dir,'drug_info_gdsc.csv'),index_col=0)
    ptw_ids = get_ptw_ids(drug_info,tgt)
    
    omics_data = {'mut':None, 'cnv':None, 'exp':None, 'met':None}
    for omic in omics_data.keys():
        omics_data[omic] = pd.read_csv(
            os.path.join(input_dir,omic+'_'+'gdsc.csv'), index_col=0)
    
    common_samples = [v.index for v in omics_data.values()]
    common_samples = list( set(tgt.index).intersection(*common_samples))
    
    tgt = tgt.loc[common_samples]
    for omic in omics_data.keys():
        omics_data[omic] = omics_data[omic].loc[common_samples]

    tmr = list(tgt.index) # barcodes/names of tumors
    msk = tgt.notnull().astype(int).values # mask of target data: 1->data available, 0->nan
    tgt = tgt.fillna(0).astype(int).values # fill nan element of target with 0.

    num_sample = len(tmr)
    
    omics_data_keys = list(omics_data.keys())
    for omic in omics_data_keys:
        omic_val = omics_data.pop(omic)
        omic_val = omic_val.values
        omics_data[omic+'_bin'] = omic_val
        omics_data[omic+'_idx'] = bin2idx(omics_data[omic+'_bin'])
    
    
    omics_data['tgt'] = tgt
    omics_data['msk'] = msk
    omics_data['tmr'] = tmr
    
    return omics_data, ptw_ids

def split_dataset(dataset, ratio=0.8):

  num_sample = len(dataset["tmr"])
  num_train_sample = int(num_sample*ratio)

  train_set = {k:dataset[k][0:num_train_sample] for k in dataset.keys()}
  test_set = {k:dataset[k][num_train_sample:] for k in dataset.keys()}

  return train_set, test_set

In [3]:
dataset, ptw_ids = load_dataset(input_dir='data/input', drug_id=-1)
train_set, test_set = split_dataset(dataset, ratio=0.8)
train_set['tgt'],train_set['msk'] = fill_mask(train_set['tgt'],train_set['msk'])

In [4]:
omc_size =dataset['exp_bin'].shape[1]
drg_size =dataset['tgt'].shape[1]
emb_dim = 200
train_size = len(train_set['tmr'])
test_size = len(test_set['tmr'])

In [5]:
class CadreDataset(data.Dataset):
    def __init__(self,dataset_split, phase='train'):
        
        self.dataset = dataset_split
        self.phase = phase
        
        
    
    def __len__(self):
        
        return len(self.dataset['tmr'])
    
    def __getitem__(self,index):
        
        exp_idx = self.dataset['exp_idx'][index]
        
        labels = [self.dataset['tgt'][index], self.dataset['msk'][index]]

        return exp_idx,labels
        
            
        

#dataset --> trainset testset
#dataloader --> trainloader, testloader
# dataloaders_dict = {'Train': -----, 'test':------}

In [6]:
train_dataset = CadreDataset(train_set,phase='train')
test_dataset = CadreDataset(test_set,phase='test')

In [7]:
batch_size=8
train_dataloader = data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_dataloader = data.DataLoader(test_dataset, batch_size=batch_size,shuffle=False)

dataloaders_dict = {'train':train_dataloader, 'test':test_dataloader}

In [14]:
# network define 
class CadreNet(nn.Module):
    def __init__(self,ptw_ids,drg_size,omc_size,emb_dim,device):
        super(CadreNet, self).__init__()
        
        self.ptw_ids = ptw_ids
        self.drg_size = drg_size
        
        self.drg_ids = np.array([list(range(self.drg_size))])
        self.drg_ids = torch.LongTensor(self.drg_ids)
        self.drg_ids = self.drg_ids.to(device)
        self.ptw_ids = torch.LongTensor([self.ptw_ids])
        self.ptw_ids = self.ptw_ids.to(device)
        
        
        self.encoder = Encoder(omc_size,ptw_ids)
        self.decoder = Decoder(emb_dim,drg_size)
        
        
        
    def forward(self,inputs):
        hid_omc= self.encoder(inputs,self.ptw_ids)
        logit_drg=self.decoder(hid_omc,self.drg_ids)
        
        return logit_drg
        
    
    

In [15]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device:',device)
net = CadreNet(ptw_ids, drg_size, omc_size, emb_dim,device)
net.to(device)

device: cuda:0


CadreNet(
  (encoder): Encoder(
    (layer_emb): Embedding(3001, 200, padding_idx=0)
    (layer_dropout_0): Dropout(p=0.5, inplace=False)
    (layer_w_0): Linear(in_features=200, out_features=128, bias=True)
    (layer_beta): Linear(in_features=128, out_features=8, bias=True)
    (layer_emb_ptw): Embedding(25, 128)
  )
  (decoder): Decoder(
    (layer_emb_drg): Embedding(260, 200)
  )
)

In [16]:
criterion = nn.BCEWithLogitsLoss(reduction="none")
epsilon = 1e-5
optimizer = optim.SGD(params = net.parameters(), lr = 0.001, momentum =0.9)

In [17]:
def loss_cross_entropy(lgt_drg,tgts,msks):
    loss = torch.sum(
        torch.mul(criterion(lgt_drg,tgts),msks)
        )/ (torch.sum(msks)+epsilon)
    
    return loss

def evaluate(labels,msks,preds,epsilon=1e-5):
    
    flat_labels = np.reshape(labels,-1)
    flat_preds_nr = np.reshape(preds,-1)
    flat_preds = np.reshape(np.around(preds),-1)
    flat_msks = np.reshape(msks,-1)

    flat_labels_msk = np.array([flat_labels[idx] for idx, val in enumerate(flat_msks) if val == 1])
    flat_preds_msk = np.array([flat_preds[idx] for idx, val in enumerate(flat_msks) if val == 1])
    flat_preds_nr_msk = np.array([flat_preds_nr[idx] for idx, val in enumerate(flat_msks) if val == 1])

    accuracy = np.mean(flat_labels_msk == flat_preds_msk)
    true_pos = np.dot(flat_labels_msk, flat_preds_msk)
    precision = 1.0*true_pos/(flat_preds_msk.sum()+epsilon)
    recall = 1.0*true_pos/(flat_labels_msk.sum()+epsilon)

    f1score = 2*precision*recall/(precision+recall+epsilon)

    # a bug fixed
    fpr, tpr, _ = roc_curve(flat_labels_msk, flat_preds_nr_msk)
    auc_val = auc(fpr, tpr)

    return precision, recall, f1score,accuracy, auc_val

In [24]:
def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):
    
    
    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch+1, num_epochs))
        print('-------------')
        
        for phase in ['train','test']:
            
                
            epoch_loss = 0.0
            epoch_corrects = 0
            
            if (epoch == 0) and (phase == 'train'):
                continue
                
            for inputs,labels in tqdm(dataloaders_dict[phase]):
                inputs=inputs.to(device)
                labels[0],labels[1] = labels[0].float().to(device) , labels[1].float().to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase=='train'):
                    outputs = net(inputs)
                    loss_ent = loss_cross_entropy(outputs,labels[0],labels[1])
                    #labels[0] = tgts , labels[1] = msks
                    loss = loss_ent
                    preds = torch.sigmoid(outputs)
                    
    
                
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                    
                epoch_loss += loss.item() * inputs.size(0)
                #epoch_corrects += correct_num(labels[0],labels[1],preds)
                preds = preds.detach().cpu().numpy()
                labels[0],labels[1] = labels[0].detach().cpu().numpy(),labels[1].detach().cpu().numpy()
                precision, recall, f1score, accuracy, auc = evaluate(labels[0], labels[1], preds)
            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            #epoch_acc = epoch_corrects/ len(dataloaders_dict[phase].dataset)
           
            print('{} Loss: {:.4f} Acc: {:.4f} auc: {:.4f} f1: {:.4f}'.format(phase,epoch_loss,accuracy,auc*100.0,f1score*100.0))

In [None]:
num_epochs = 5
train_model(net,dataloaders_dict, criterion, optimizer, num_epochs)

  5%|▍         | 1/22 [00:00<00:03,  6.06it/s]

Epoch 1/5
-------------


100%|██████████| 22/22 [00:02<00:00,  7.85it/s]
  0%|          | 0/85 [00:00<?, ?it/s]

test Loss: 5.6675 Acc: 0.5070 auc: 50.6716 f1: 51.1516
Epoch 2/5
-------------


100%|██████████| 85/85 [00:23<00:00,  3.63it/s]
  5%|▍         | 1/22 [00:00<00:02,  7.61it/s]

train Loss: 5.7456 Acc: 0.5038 auc: 50.5887 f1: 40.9606


100%|██████████| 22/22 [00:02<00:00,  7.85it/s]
  0%|          | 0/85 [00:00<?, ?it/s]

test Loss: 5.5944 Acc: 0.4814 auc: 51.0515 f1: 46.7775
Epoch 3/5
-------------


100%|██████████| 85/85 [00:23<00:00,  3.60it/s]
  5%|▍         | 1/22 [00:00<00:02,  7.46it/s]

train Loss: 5.6153 Acc: 0.5019 auc: 52.2893 f1: 45.3582


100%|██████████| 22/22 [00:02<00:00,  7.82it/s]
  0%|          | 0/85 [00:00<?, ?it/s]

test Loss: 5.4867 Acc: 0.4953 auc: 52.0017 f1: 46.9433
Epoch 4/5
-------------


 95%|█████████▌| 81/85 [00:22<00:01,  3.53it/s]