In [113]:
import sys,os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from google.colab import drive, files
import pickle as pickle


In [114]:
drive.mount('/content/drive')



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [115]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if __name__=='__main__':
    print('Using device:', device)

Using device: cuda


In [116]:
DATA_PATH = '/content/drive/My Drive/BiteNetProject/data_processing/'

data = pickle.load(open(os.path.join(DATA_PATH,'data.pkl'), 'rb'))


In [117]:
## list of patient visits, where each visit is a list of medical codes 
seqs= [i[2] for i in data]


## target label of readmission 
diag = [i[3] for i in data]


## number of unique medical codes
num_codes = max(set([code for visits in seqs for visit in visits for code in visit])) + 1


print(num_codes)



assert len(seqs) == len(diag)

3874


In [118]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self, seqs, diag):
        self.x = seqs
        self.y = diag
    
    def __len__(self):
        
        return len(self.x)
    
    def __getitem__(self, index):

        return self.x[index],self.y[index]
        
data = CustomDataset(seqs, diag)


In [119]:
from torch.utils.data.dataset import random_split

train_test_split = int(len(data)*0.8)
lengths = [train_test_split, len(data) - train_test_split]
train_data, test_data = random_split(data, lengths)


train_val_split = int(len(train_data)*0.5)
lengths = [train_val_split, len(train_data) - train_val_split]
train_data, val_data = random_split(train_data, lengths)

print(train_data)
print("Length of train dataset:", len(train_data))
print("Length of val dataset:", len(val_data))
print("Length of test dataset:", len(test_data))



<torch.utils.data.dataset.Subset object at 0x7f57bf0f9050>
Length of train dataset: 2998
Length of val dataset: 2998
Length of test dataset: 1500


In [120]:
def collate_fn(data):
    """
    TODO: Collate the the list of samples into batches. For each patient, you need to pad the diagnosis
        sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation
        is stored in `mask`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking 
        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patiens) of type torch.float
        
    Note that you can obtains the list of diagnosis codes and the list of hf labels
        using: `sequences, labels = zip(*data)`
    """

    sequences, labels = zip(*data)
   

    y = torch.tensor(labels, dtype=torch.float)
    
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)

    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)

    
            
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            """
            TODO: update `x`, `rev_x`, `masks`, and `rev_masks`
            """
            codes_needed = max_num_codes - len(sequences[i_patient][j_visit])
            
            codes_padding = torch.tensor(([0] * codes_needed),dtype=torch.long)
            
            original_visits = torch.tensor(sequences[i_patient][j_visit],dtype=torch.long)
        
            x[i_patient][j_visit] = torch.cat([original_visits,codes_padding],0)
            

    
    ##Construct Masks
    
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            #print(sequences[i_patient])
            
            curr_codes = len(sequences[i_patient][j_visit])
            
            num_codes_needed = max_num_codes - curr_codes
            
            #print(sequences[i_patient][j_visit])
            #print("max_codes",max_num_codes,"curr_codes",curr_codes,"num_codes_needed",num_codes_needed,)
            
            mask_real_portion = sequences[i_patient][j_visit]
            mask_padded_portion = [0] * num_codes_needed
            
            masks_total = mask_real_portion + mask_padded_portion
            
            masks[i_patient][j_visit] = torch.Tensor(masks_total)
            
            

    
    fake_visits_map = {}
    
    for i_patient in range(len(masks)):
        for j_visit in range(len(masks[i_patient])):
            if torch.sum(masks[i_patient][j_visit]) == torch.tensor(0):
                
                fake_visits_map[i_patient] = j_visit
                break

    
    rev_x = torch.clone(x)  
    
    for i_patient in range(len(x)):
        for j_visit in range(len(x[i_patient])):
            if i_patient in fake_visits_map:
                first_fake = fake_visits_map[i_patient]
                rev_x[i_patient][:first_fake] = torch.flip(rev_x[i_patient][:first_fake],[0])
            else:
                rev_x[i_patient] = torch.flip(rev_x[i_patient],[0])
    

    
    rev_masks = torch.clone(masks)

    
    for i_patient in range(len(masks)):
        for j_visit in range(len(masks[i_patient])):
            if i_patient in fake_visits_map:
                first_fake = fake_visits_map[i_patient]
                rev_masks[i_patient][:first_fake]= torch.flip(rev_masks[i_patient][:first_fake],[0])
            else:
                rev_masks[i_patient] = torch.flip(rev_masks[i_patient],[0])   
    
        
    #print("x.dtype",x.dtype,"rev_x.dtype",rev_x.dtype)
    return x, rev_x, y

In [121]:
from torch.utils.data import DataLoader

def load_data(train_data, val_data, test_data, collate_fn):
    
    batch_size = 32
    ## iter will get a batch of size 32 [10 visits x 39 codes ] 

    train_loader = DataLoader(dataset = train_data, batch_size = 32, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(dataset = val_data, batch_size = 32, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(dataset = test_data, batch_size = 32, shuffle=False, collate_fn=collate_fn)

    
    return train_loader, val_loader, test_loader


train_loader, val_loader, test_loader = load_data(train_data, val_data, test_data, collate_fn)






In [122]:
def get_last_visit_state(x,hidden_states):
  ## for each patient in batch, get index of last visit
  ## input x 32,10,39  hidden states 32 10 128
  ## output 32 1

  x = torch.sum(x,2)


  index = torch.zeros(x.shape,dtype=torch.int64)
  index[x!=0]=1
  index = torch.sum(index,1)
  index = torch.add(index,-1)

  last_visit_state = hidden_states[range(len(hidden_states)),index,:]
  
  return last_visit_state





In [123]:
def mask_sum(new_x,original_x):
  ## originalx 32 10 39 
  ## newx 32 10 39 128
  ## output 32 10 128

  mask = torch.ones(new_x.shape,dtype=torch.int64)

  for i in range(new_x.shape[0]):
    for j in range(new_x.shape[1]):
      if torch.sum(original_x[i,j,:]) == 0:
        mask[i,j] = torch.zeros(new_x.shape[2],new_x.shape[3])

  new_x = torch.mul(new_x,mask)

  new_x = torch.sum(new_x,2)

  return new_x


In [124]:
class BRNN(nn.Module):
    
    def __init__(self, num_codes, dropout = 0.1):
        super().__init__()

        self.embedding_medcode = nn.Embedding(num_embeddings = num_codes, embedding_dim = 128)
        self.rnn_medcode = nn.GRU(128, hidden_size = 128, dropout = dropout, bidirectional = False, batch_first = True)
        self.rev_rnn_medcode = nn.GRU(128, hidden_size = 128, dropout = dropout, bidirectional = False, batch_first = True)
        self.fc = nn.Linear(256,170)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x, rev_x): 

        batch_size = x.shape[0]
        
        original_x = x         ##print("start x",x.shape) ## 32,10,39 ##Each of the 39s is a medical code
        x = self.embedding_medcode(x) ##print("post embedding",x.shape) ## 32, 10, 39, 128
        x = mask_sum(x,original_x) ##print("after sum",x.shape) ## 32 10 128
        rnn_medcode_output, last_h_n = self.rnn_medcode(x) #print("after rnn output",rnn_medcode_output.shape )  ## 32 10 128   
        rnn_medcode_last_hs = get_last_visit_state(original_x,rnn_medcode_output) ##True last hidden state 32 128
        #print(rnn_medcode_last_hs.shape)

        original_rev_x = rev_x
        rev_x = self.embedding_medcode(rev_x)
        rev_x = mask_sum(rev_x,original_rev_x)
        rev_rnn_medcode_output, rev_last_h_n = self.rev_rnn_medcode(rev_x) #print("after rnn output",rnn_medcode_output.shape )  ## 32 10 128  
        rev_rnn_medcode_last_hs = rev_rnn_medcode_output[:,0,:]
        

        
        logits = self.fc(torch.cat([rnn_medcode_last_hs,rev_rnn_medcode_last_hs],1))     #print("after linear layer shape",logits)    ## 32 1 
        
        probs = self.sigmoid(logits)


        return probs


model = BRNN(num_codes = num_codes)
model


  "num_layers={}".format(dropout, num_layers))


BRNN(
  (embedding_medcode): Embedding(3874, 128)
  (rnn_medcode): GRU(128, 128, batch_first=True, dropout=0.1)
  (rev_rnn_medcode): GRU(128, 128, batch_first=True, dropout=0.1)
  (fc): Linear(in_features=256, out_features=170, bias=True)
  (sigmoid): Sigmoid()
)

In [125]:
import torch.optim as optim

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)



In [126]:
from sklearn.utils.validation import indexable
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import precision_recall_curve, auc
from sklearn.metrics import precision_score

def eval_model(model, loader):
    model.eval()
    y_pred = torch.LongTensor()
    y_true = torch.LongTensor()
    model.eval()
    for x, rev_x, y in loader:
        y_hat = model(x,rev_x)
        y_hat = y_hat > 0.5
        y_hat = y_hat.int()
        y = y.int()
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    
    precision = precision_score(y_true,y_pred,average="micro")


    return precision

In [127]:
def train(model, train_loader, val_loader, n_epochs, print_train_results=True):
    for epoch in range(n_epochs):
      model.train()
      train_loss = 0
      for x, rev_x, y in train_loader:
        loss = None
        optimizer.zero_grad()
        y_hat = model(x,rev_x)
        
        
        loss = criterion(y_hat,y)
        loss.backward()
        optimizer.step()
        # your code here
        
        train_loss += loss.item()
      train_loss = train_loss / len(train_loader)
      if print_train_results==True:
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
      precision = eval_model(model, val_loader)
      if print_train_results==True:
        print('Epoch: {} \t Validation precision:{:.3f}'
              .format(epoch+1,precision))
      


In [128]:
n_epochs = 10
train(model, train_loader, val_loader, n_epochs)



Epoch: 1 	 Training Loss: 0.200485
Epoch: 1 	 Validation precision:0.610
Epoch: 2 	 Training Loss: 0.157627
Epoch: 2 	 Validation precision:0.583
Epoch: 3 	 Training Loss: 0.156384
Epoch: 3 	 Validation precision:0.659
Epoch: 4 	 Training Loss: 0.155538
Epoch: 4 	 Validation precision:0.608
Epoch: 5 	 Training Loss: 0.154000
Epoch: 5 	 Validation precision:0.675
Epoch: 6 	 Training Loss: 0.152877
Epoch: 6 	 Validation precision:0.652
Epoch: 7 	 Training Loss: 0.151930
Epoch: 7 	 Validation precision:0.655
Epoch: 8 	 Training Loss: 0.151269
Epoch: 8 	 Validation precision:0.666
Epoch: 9 	 Training Loss: 0.150357
Epoch: 9 	 Validation precision:0.643
Epoch: 10 	 Training Loss: 0.149620
Epoch: 10 	 Validation precision:0.666


In [129]:
def test(model, data, test_number):
      precision = eval_model(model, test_loader)
      print('Test number: {} \t test pr_auc:{:.3f}'
            .format(test_number+1,precision))
      


In [130]:
test_number = 3
for i in range(test_number):
  train_test_split = int(len(data)*0.8)
  lengths = [train_test_split, len(data) - train_test_split]
  train_data, test_data = random_split(data, lengths)


  train_val_split = int(len(train_data)*0.5)
  lengths = [train_val_split, len(train_data) - train_val_split]
  train_data, val_data = random_split(train_data, lengths)

  train_loader, val_loader, test_loader = load_data(train_data, val_data, test_data, collate_fn)

  newmodel = BRNN(num_codes = num_codes)
  criterion = nn.BCELoss()
  optimizer = optim.Adam(newmodel.parameters(), lr=0.001)

  n_epochs = 10
  train(newmodel, train_loader, val_loader, n_epochs,print_train_results=False)
  test(newmodel, test_loader, i)

  "num_layers={}".format(dropout, num_layers))


Test number: 1 	 test pr_auc:0.660


  "num_layers={}".format(dropout, num_layers))


Test number: 2 	 test pr_auc:0.655


  "num_layers={}".format(dropout, num_layers))


Test number: 3 	 test pr_auc:0.674
