In [39]:
import sys,os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from google.colab import drive, files
import pickle as pickle


In [40]:
drive.mount('/content/drive')



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [41]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if __name__=='__main__':
    print('Using device:', device)

Using device: cuda


In [42]:
DATA_PATH = '/content/drive/My Drive/BiteNetProject/data_processing/'

data = pickle.load(open(os.path.join(DATA_PATH,'data.pkl'), 'rb'))


In [43]:
## list of patient visits, where each visit is a list of medical codes 
seqs= [i[2] for i in data]

## remove 0s from visits and patients - we will pad with mask later
for patient_i, patient in enumerate(seqs):
  seqs[patient_i] = [visit for visit in patient if sum(visit)>0]

for patient_i, patient in enumerate(seqs):
  for visit_j, visit in enumerate(patient):
    patient[visit_j] = [medcode for medcode in visit if medcode>0]

## number of unique medical codes
num_codes = max(set([code for visits in seqs for visit in visits for code in visit])) + 1

## target label of diagnosis
diagnosis = [i[3] for i in data]


assert len(seqs) == len(diagnosis)

In [44]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, seqs, diagnosis):
        self.x = seqs
        self.y = diagnosis
    
    def __len__(self):
        
        return len(self.x)
    
    def __getitem__(self, index):

        return self.x[index],self.y[index]
        
data = CustomDataset(seqs, diagnosis)


In [45]:
from torch.utils.data.dataset import random_split

train_test_split = int(len(data)*0.8)
lengths = [train_test_split, len(data) - train_test_split]
train_data, test_data = random_split(data, lengths)


train_val_split = int(len(train_data)*0.5)
lengths = [train_val_split, len(train_data) - train_val_split]
train_data, val_data = random_split(train_data, lengths)

print(train_data)
print("Length of train dataset:", len(train_data))
print("Length of val dataset:", len(val_data))
print("Length of test dataset:", len(test_data))



<torch.utils.data.dataset.Subset object at 0x7f775f630890>
Length of train dataset: 2998
Length of val dataset: 2998
Length of test dataset: 1500


In [46]:
def collate_fn(data):
    sequences, labels = zip(*data)
   
    y = torch.tensor(labels, dtype=torch.float)
    
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]

    max_num_visits = max(num_visits) ##10
    max_num_codes = max(num_codes) ##39
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)

    #Pad visits 
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            codes_needed = max_num_codes - len(sequences[i_patient][j_visit])
            codes_padding = torch.tensor(([0] * codes_needed),dtype=torch.long)
            original_visits = torch.tensor(sequences[i_patient][j_visit],dtype=torch.long)
            x[i_patient][j_visit] = torch.cat([original_visits,codes_padding],0)
            
    #Pad codes within visits
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            curr_codes = len(sequences[i_patient][j_visit])
            num_codes_needed = max_num_codes - curr_codes
            mask_real_portion = sequences[i_patient][j_visit]
            mask_padded_portion = [0] * num_codes_needed
            masks_total = mask_real_portion + mask_padded_portion
            masks[i_patient][j_visit] = torch.Tensor(masks_total)
    
    #Create mask     
    fake_visits_map = {}
    
    for i_patient in range(len(masks)):
        for j_visit in range(len(masks[i_patient])):
            if torch.sum(masks[i_patient][j_visit]) == torch.tensor(0):
                fake_visits_map[i_patient] = j_visit
                break

    #Create rev_x
    rev_x = torch.clone(x)  
    
    for i_patient in range(len(x)):
        for j_visit in range(len(x[i_patient])):
            if i_patient in fake_visits_map:
                first_fake = fake_visits_map[i_patient]
                rev_x[i_patient][:first_fake] = torch.flip(rev_x[i_patient][:first_fake],[0])
            else:
                rev_x[i_patient] = torch.flip(rev_x[i_patient],[0])
    
    #Create rev_mask for rev_x
    rev_masks = torch.clone(masks)

    for i_patient in range(len(masks)):
        for j_visit in range(len(masks[i_patient])):
            if i_patient in fake_visits_map:
                first_fake = fake_visits_map[i_patient]
                rev_masks[i_patient][:first_fake]= torch.flip(rev_masks[i_patient][:first_fake],[0])
            else:
                rev_masks[i_patient] = torch.flip(rev_masks[i_patient],[0])   
    
        
    #print("x.dtype",x.dtype,"rev_x.dtype",rev_x.dtype)
    return x, masks, rev_x, rev_masks, y

In [47]:
from torch.utils.data import DataLoader

def load_data(train_data, val_data, test_data, collate_fn):
    
    batch_size = 32
    ## iter will get a batch of size 32 [10 visits x 39 codes ] 

    train_loader = DataLoader(dataset = train_data, batch_size = 32, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(dataset = val_data, batch_size = 32, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(dataset = test_data, batch_size = 32, shuffle=False, collate_fn=collate_fn)

    
    return train_loader, val_loader, test_loader


train_loader, val_loader, test_loader = load_data(train_data, val_data, test_data, collate_fn)






In [48]:
def sum_embeddings_with_mask(x, masks):

    x = x * masks.unsqueeze(-1)
    x = torch.sum(x, dim = -2)
    return x

In [49]:
class AlphaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        linear_a_att = self.a_att(g)
        return torch.softmax(linear_a_att, dim = 1)

In [50]:
class BetaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        self.b_att = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, h):
        linear_b_att = self.b_att(h)
        return torch.tanh(linear_b_att)
    

In [51]:
def attention_sum(alpha, beta, rev_v, rev_masks):

    rev_v_embed_dim = rev_v.shape[-1]
    beta_hidden_dim = beta.shape[-1]
    
    rev_masks = torch.sum(rev_masks,dim=2)
    rev_masks = rev_masks.unsqueeze(2)

    rev_masks = rev_masks.expand(-1,-1,rev_v_embed_dim)

    rev_masks = rev_masks.apply_(lambda x: min(x,1))
    
    true_visits = torch.mul(rev_masks,rev_v)


    alpha = alpha.expand(-1,-1,beta_hidden_dim)
    attention = torch.mul(beta,alpha)

    
    attention = torch.mul(attention,true_visits) 
    
    attention_sum = torch.sum(attention,1)
    


    return attention_sum

    

In [52]:
class RETAIN(nn.Module):
    
    def __init__(self, num_codes, embedding_dim=128):
        super().__init__()

        self.embedding = nn.Embedding(num_codes, embedding_dim)
   
        self.rnn_a = nn.GRU(embedding_dim, embedding_dim, batch_first=True)

        self.rnn_b = nn.GRU(embedding_dim, embedding_dim, batch_first=True)

        self.att_a = AlphaAttention(embedding_dim)

        self.att_b = BetaAttention(embedding_dim)

        self.fc = nn.Linear(embedding_dim, 170)

        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, masks, rev_x, rev_masks):


  
        rev_x = self.embedding(rev_x)

        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)

        g, _ = self.rnn_a(rev_x)
        h, _ = self.rnn_b(rev_x)

        alpha = self.att_a(g)
        beta = self.att_b(h)

        c = attention_sum(alpha, beta, rev_x, rev_masks)

        logits = self.fc(c)
        probs = self.sigmoid(logits)
        return probs.squeeze()
    

# load the model here
model = RETAIN(num_codes = num_codes)
model

RETAIN(
  (embedding): Embedding(3874, 128)
  (rnn_a): GRU(128, 128, batch_first=True)
  (rnn_b): GRU(128, 128, batch_first=True)
  (att_a): AlphaAttention(
    (a_att): Linear(in_features=128, out_features=1, bias=True)
  )
  (att_b): BetaAttention(
    (b_att): Linear(in_features=128, out_features=128, bias=True)
  )
  (fc): Linear(in_features=128, out_features=170, bias=True)
  (sigmoid): Sigmoid()
)

In [53]:
import torch.optim as optim

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)



In [54]:
from sklearn.utils.validation import indexable
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import precision_recall_curve, auc
from sklearn.metrics import precision_score

def eval_model(model, val_loader):
    
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    model.eval()
    for x, masks, rev_x, rev_masks, y in val_loader:
        y_hat = model(x,masks,rev_x,rev_masks)
        y_hat = y_hat > 0.5
        y_hat = y_hat.int()
        y = y.int()
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    
    precision = precision_score(y_true,y_pred,average="micro")


    return precision


In [55]:
def train(model, train_loader, val_loader, n_epochs, print_train_results=True):
    for epoch in range(n_epochs):
      model.train()
      train_loss = 0
      for x, masks, rev_x, rev_masks, y in train_loader:
        loss = None
        optimizer.zero_grad()
        y_hat = model(x,masks,rev_x,rev_masks)

        loss = criterion(y_hat,y)
        loss.backward()
        optimizer.step()
        # your code here
        
        train_loss += loss.item()
      train_loss = train_loss / len(train_loader)
      if print_train_results==True:
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
      precision = eval_model(model, val_loader)
      if print_train_results==True:
        print('Epoch: {} \t Validation overall precision p:{:.3f}'.format(epoch+1,precision))
      

In [56]:
n_epochs = 10
train(model, train_loader, val_loader, n_epochs)

Epoch: 1 	 Training Loss: 0.348740
Epoch: 1 	 Validation overall precision p:0.493
Epoch: 2 	 Training Loss: 0.171251
Epoch: 2 	 Validation overall precision p:0.572
Epoch: 3 	 Training Loss: 0.151813
Epoch: 3 	 Validation overall precision p:0.615
Epoch: 4 	 Training Loss: 0.140846
Epoch: 4 	 Validation overall precision p:0.615
Epoch: 5 	 Training Loss: 0.132778
Epoch: 5 	 Validation overall precision p:0.613
Epoch: 6 	 Training Loss: 0.125954
Epoch: 6 	 Validation overall precision p:0.601
Epoch: 7 	 Training Loss: 0.120030
Epoch: 7 	 Validation overall precision p:0.586
Epoch: 8 	 Training Loss: 0.114157
Epoch: 8 	 Validation overall precision p:0.568
Epoch: 9 	 Training Loss: 0.108900
Epoch: 9 	 Validation overall precision p:0.557
Epoch: 10 	 Training Loss: 0.103407
Epoch: 10 	 Validation overall precision p:0.550


In [57]:
def test(model, test_loader, test_number):
      precision = eval_model(model, test_loader)
      
      
      print('Test: test_number{} \t Test precision :{:.3f}'
              .format(test_number+1,precision))

In [58]:
test_number = 3

for i in range(test_number):
  train_test_split = int(len(data)*0.8)
  lengths = [train_test_split, len(data) - train_test_split]
  train_data, test_data = random_split(data, lengths)


  train_val_split = int(len(train_data)*0.5)
  lengths = [train_val_split, len(train_data) - train_val_split]
  train_data, val_data = random_split(train_data, lengths)

  train_loader, val_loader, test_loader = load_data(train_data, val_data, test_data, collate_fn)

  newmodel = RETAIN(num_codes = num_codes)
  criterion = nn.BCELoss()
  optimizer = optim.Adam(newmodel.parameters(), lr=0.001)

  n_epochs = 10
  train(newmodel, train_loader, val_loader, n_epochs,print_train_results=False)
  test(newmodel, test_loader, i)

Test: test_number1 	 Test precision :0.545
Test: test_number2 	 Test precision :0.516
Test: test_number3 	 Test precision :0.547
