In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

import os
import time
import tqdm
import pandas as pd
from copy import deepcopy
from typing import Dict

from sklearn.metrics import confusion_matrix

import math
import pickle

In [18]:
all_pat = pickle.load(open('all_pat.pkl', 'rb'))
all_visits = pickle.load(open('all_visits.pkl', 'rb'))
all_diags = pickle.load(open('all_diags.pkl', 'rb'))
mortality = pickle.load(open('mortality.pkl', 'rb'))
icd_dic = pickle.load(open('icd_dic.pkl', 'rb'))

In [19]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, seqs, hfs):
        
        """
        TODO: Store `seqs`. to `self.x` and `hfs` to `self.y`.
        
        Note that you DO NOT need to covert them to tensor as we will do this later.
        Do NOT permute the data.
        """
        
        # your code here
        self.x = seqs
        self.y = hfs
#         raise NotImplementedError
    
    def __len__(self):
        
        """
        TODO: Return the number of samples (i.e. patients).
        """
        
        # your code here
        return len(self.y)
#         raise NotImplementedError
    
    def __getitem__(self, index):
        
        """
        TODO: Generates one sample of data.
        
        Note that you DO NOT need to covert them to tensor as we will do this later.
        """
        
        # your code here
        seq = self.x[index]
        hf = self.y[index]

        return seq, hf
#         raise NotImplementedError
        

dataset = CustomDataset(all_diags, mortality)

In [20]:
def collate_fn(data):


    sequences, labels = zip(*data)
    
    y = torch.tensor(labels, dtype=torch.float)
#     print(y)
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
#     print("MAX", max_num_visits, max_num_codes)
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    for i_patient, patient in enumerate(sequences):
#         print("AT First", patient)
        for j_visit, visit in enumerate(patient):
            """
            TODO: update `x`, `rev_x`, `masks`, and `rev_masks`
            """
            # your code here
            
            for k_diag, diag in enumerate(visit):
                if diag != diag:
                    continue
                else:   
                    x[i_patient, j_visit,k_diag ] = icd_dic[diag]
                    
                    masks[i_patient, j_visit,k_diag ] = 1
        
        
        for j_visit, visit in enumerate(reversed(patient)):
#             print(visit)
            for k_diag, diag in enumerate(visit):
                if diag != diag:
                    continue
                else: 
                    rev_x[i_patient, j_visit,k_diag ] = icd_dic[diag]
                    rev_masks[i_patient, j_visit,k_diag ] = 1
        for i in range(num_visits[i_patient], max_num_visits):
            rev_masks[i_patient, i, :] = torch.zeros(max_num_codes)
    
    return x, masks, rev_x, rev_masks, y

In [21]:
from torch.utils.data.dataset import random_split

split = int(len(dataset)*0.7)

lengths = [split, len(dataset) - split]
train_dataset, val_test_dataset = random_split(dataset, lengths)


split = int(len(val_test_dataset)*0.7)
lengths = [split, len(val_test_dataset) - split]
val_dataset, test_dataset = random_split(val_test_dataset, lengths)


print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))
print("Length of test dataset:", len(test_dataset))

Length of train dataset: 32563
Length of val dataset: 9769
Length of test dataset: 4188


In [22]:
# from torch.utils.data.dataset import random_split

# split = int(len(dataset)*0.8)

# lengths = [split, len(dataset) - split]
# train_dataset, val_dataset = random_split(dataset, lengths)

# print("Length of train dataset:", len(train_dataset))
# print("Length of val dataset:", len(val_dataset))

In [23]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset,test_dataset, collate_fn):
    
    '''
    TODO: Implement this function to return the data loader for  train and validation dataset. 
    Set batchsize to 32. Set `shuffle=True` only for train dataloader.
    
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, val_loader: train and validation dataloaders
    
    Note that you need to pass the collate function to the data loader `collate_fn()`.
    '''
    
    batch_size = 128
    # your code here
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)
#     raise NotImplementedError
    
    return train_loader, val_loader, test_loader


train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset, collate_fn)

In [24]:
def sum_embeddings_with_mask(x, masks):
    """
    TODO: mask select the embeddings for true visits (not padding visits) and then
        sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
        
    NOTE: Do NOT use for loop.

    """
    
    # your code here
    
    masks = torch.unsqueeze(masks, 3)
    masks[:,:,:,0] = 1
    masks = masks.expand(-1,-1,-1, x.shape[3])
#     print(x.shape, masks.shape)
    return torch.sum (x * masks, 2)
#     raise NotImplementedError

In [25]:
def get_last_visit(hidden_states, masks):
    """
    TODO: obtain the hidden state for the last true visit (not padding visits)

    Arguments:
        hidden_states: the hidden states of each visit of shape (batch_size, # visits, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        last_hidden_state: the hidden state for the last true visit of shape (batch_size, embedding_dim)
        
    NOTE: DO NOT use for loop.
    
    HINT: First convert the mask to a vector of shape (batch_size,) containing the true visit length; 
          and then use this length vector as index to select the last visit.
    """
    
    # your code here
    b, v, e = hidden_states.shape
    count = torch.count_nonzero(masks, 2)
    pad = torch.zeros(b, v+1)
    pad[:, :pad.shape[1]-1] = count + pad[:, :pad.shape[1]-1]
    num_visits = torch.argmin(pad,1)
    return hidden_states[range(b), num_visits-1,:]
#     raise NotImplementedError

In [26]:
class NaiveRNN(nn.Module):
    
    """
    TODO: implement the naive RNN model above.
    """
    
    def __init__(self, num_codes):
        super().__init__()
        """
        TODO: 
            1. Define the embedding layer using `nn.Embedding`. Set `embDimSize` to 128.
            2. Define the RNN using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
            2. Define the RNN for the reverse direction using `nn.GRU()`;
               Set `hidden_size` to 128. Set `batch_first` to True.
            3. Define the linear layers using `nn.Linear()`; Set `in_features` to 256, and `out_features` to 1.
            4. Define the final activation layer using `nn.Sigmoid().

        Arguments:
            num_codes: total number of diagnosis codes
        """

        # your code here
        self.embedding = nn.Embedding(embedding_dim = 128, num_embeddings=num_codes)
        self.rnn = nn.GRU(128, hidden_size = 128, batch_first=True)
        self.rev_rnn = nn.GRU(128, hidden_size=128, batch_first = True)
        self.fc = nn.Linear(256, 1)
        self.sigmoid = nn.Sigmoid()
#         raise NotImplementedError
    
    def forward(self, x, masks, rev_x, rev_masks):
        """
        Arguments:
            x: the diagnosis sequence of shape (batch_size, # visits, # diagnosis codes)
            masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

        Outputs:
            probs: probabilities of shape (batch_size)
        """
        
        batch_size = x.shape[0]
        
        # 1. Pass the sequence through the embedding layer;
        x = self.embedding(x)
        # 2. Sum the embeddings for each diagnosis code up for a visit of a patient.
        x = sum_embeddings_with_mask(x, masks)
        
        # 3. Pass the embegginds through the RNN layer;
        output, _ = self.rnn(x)
        # 4. Obtain the hidden state at the last visit.
        true_h_n = get_last_visit(output, masks)
        
        """
        TODO:
            5. Do the step 1-4 again for the reverse order (rev_x), and concatenate the hidden
               states for both directions;
        """
        true_h_n_rev = None
        # your code here
        rev_x = self.embedding(rev_x)
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        output, _ = self.rnn(rev_x)
        true_h_n_rev = get_last_visit(output, rev_masks)
#         raise NotImplementedError
        
        # 6. Pass the hidden state through the linear and activation layers.
        logits = self.fc(torch.cat([true_h_n, true_h_n_rev], 1))        
        probs = self.sigmoid(logits)
        return probs.view(batch_size)
    
# load the model here
naive_rnn = NaiveRNN(num_codes = len(icd_dic))
naive_rnn

NaiveRNN(
  (embedding): Embedding(6985, 128)
  (rnn): GRU(128, 128, batch_first=True)
  (rev_rnn): GRU(128, 128, batch_first=True)
  (fc): Linear(in_features=256, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [27]:

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(naive_rnn.parameters(), lr=0.001 )

In [28]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval_model(model, val_loader):
    
    """
    TODO: evaluate the model.
    
    Arguments:
        model: the RNN model
        val_loader: validation dataloader
        
    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
        
    Note that please pass all four arguments to the model so that we can use this function for both 
    models. (Use `model(x, masks, rev_x, rev_masks)`.)
        
    HINT: checkout https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
    """
    
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    model.eval()
    for x, masks, rev_x, rev_masks, y in val_loader:
        y_hat = model(x, masks, rev_x, rev_masks)
        y_score = torch.cat((y_score,  y_hat.detach().to('cpu')), dim=0)
        y_hat = (y_hat > 0.5).int()
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
        
    # your code here
    p, r, f,_ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    roc_auc = roc_auc_score(y_true, y_score)
#     raise NotImplementedError
    return p, r, f, roc_auc

In [29]:
def train(model, train_loader, val_loader, n_epochs):
    """
    TODO: train the model.
    
    Arguments:
        model: the RNN model
        train_loader: training dataloder
        val_loader: validation dataloader
        n_epochs: total number of epochs
        
    You need to call `eval_model()` at the end of each training epoch to see how well the model performs 
    on validation data.
        
    Note that please pass all four arguments to the model so that we can use this function for both 
    models. (Use `model(x, masks, rev_x, rev_masks)`.)
    """
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x, masks, rev_x, rev_masks, y in train_loader:
            """
            TODO:
                1. zero grad
                2. model forward
                3. calculate loss
                4. loss backward
                5. optimizer step
            """
            loss = None
            # your code here
            optimizer.zero_grad()
            outputs = model.forward(x, masks, rev_x, rev_masks)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
#             raise NotImplementedError
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f, roc_auc = eval_model(model, val_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, roc_auc: {:.2f}'
              .format(epoch+1, p, r, f, roc_auc))
        experiment.log_metric("precision", p)
        experiment.log_metric("ROC", roc_auc)

In [30]:
from comet_ml import Experiment

# Initialize Comet.ml with your API key
experiment = Experiment(api_key="dGXMo2OT0jVdMZPp2KVui30tD", project_name="general")

n_epochs = 5
train(naive_rnn, train_loader, val_loader, n_epochs)

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.com/sayan3/general/c825e7c9251e49f9b5a75ae5ae2c1a8b
COMET INFO:   Uploads:
COMET INFO:     environment details : 1
COMET INFO:     filename            : 1
COMET INFO:     installed packages  : 1
COMET INFO:     notebook            : 1
COMET INFO:     source_code         : 1
COMET INFO: ---------------------------
COMET INFO: Couldn't find a Git repository in '/Users/AIUDD75/Library/CloudStorage/OneDrive-AmwayCorp/MCS/DL598/Final Project' nor in any parent directory. You can override where Comet is looking for a Git Patch by setting the configuration `COMET_GIT_DIRECTORY`
COMET INFO: Experiment is live on comet.com https://www.comet.com/sayan3/general/81632d2605cd4c6db74a234ecd91843f



Epoch: 1 	 Training Loss: 0.358707
Epoch: 1 	 Validation p: 0.58, r:0.01, f: 0.02, roc_auc: 0.81
Epoch: 2 	 Training Loss: 0.284397
Epoch: 2 	 Validation p: 0.64, r:0.26, f: 0.37, roc_auc: 0.88
Epoch: 3 	 Training Loss: 0.241885
Epoch: 3 	 Validation p: 0.62, r:0.51, f: 0.56, roc_auc: 0.90
Epoch: 4 	 Training Loss: 0.219077
Epoch: 4 	 Validation p: 0.61, r:0.52, f: 0.56, roc_auc: 0.90
Epoch: 5 	 Training Loss: 0.204792
Epoch: 5 	 Validation p: 0.62, r:0.55, f: 0.58, roc_auc: 0.91


In [32]:
p, r, f, roc_auc = eval_model(naive_rnn, test_loader)
print(roc_auc)

0.8959743268134557
