In [6]:
import os
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

DATA_PATH = "./resource"


In [3]:
import pickle
import numpy as np

def save_pkl(path, obj):
  with open(path, 'wb') as f:
    pickle.dump(obj, f)
    print(" [*] save %s" % path)

def load_pkl(path):
  with open(path,'rb') as f:
    obj = pickle.load(f)
    print(" [*] load %s" % path)
    return obj

def save_npy(path, obj):
  np.save(path, obj)
  print(" [*] save %s" % path)

def load_npy(path):
  obj = np.load(path)
  print(" [*] load %s" % path)
  return obj

In [15]:
vocab = load_pkl(DATA_PATH + '/vocab.pkl')
TOTAL_NUM_CODES = len(vocab)
TOTAL_NUM_CODES

 [*] load ./resource/vocab.pkl


490

In [4]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    
    def __init__(self):
        self._data = load_pkl(DATA_PATH + '/data.pkl')
        self._label = load_pkl(DATA_PATH + '/label.pkl')


    
    def __len__(self):
        """ return the number of samples (i.e. patients). """
        return len(self._data)
    
    def __getitem__(self, index):
        data = self._data[index]
        label = self._label[index]
        return data, label

In [7]:
dataset = CustomDataset()
print('Size of dataset:', len(dataset))

 [*] load ./resource/data.pkl
 [*] load ./resource/label.pkl
Size of dataset: 3000


In [8]:
def collate_fn(data):

    sequences, labels = zip(*data)

    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)

    y = torch.zeros((num_patients, max_num_visits), dtype=torch.float)

    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    l = torch.zeros((num_patients), dtype=torch.long)
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            x[i_patient,j_visit,0:len(visit)] = torch.Tensor(visit)
            masks[i_patient,j_visit,0:len(visit)] = torch.ones(len(visit))
            rev_j = len(patient) - j_visit - 1
            rev_x[i_patient,rev_j,0:len(visit)] = torch.Tensor(visit)
            rev_masks[i_patient,rev_j,0:len(visit)] = torch.ones(len(visit))
            y[i_patient,j_visit] = labels[i_patient][j_visit]
        l[i_patient] = len(patient)

    
    return x, masks, rev_x, rev_masks, y, l

In [9]:
from torch.utils.data.dataset import random_split

split = int(len(dataset)*0.8)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 2400
Length of val dataset: 600


In [136]:
from torch.utils.data import DataLoader
def load_data(train_dataset, val_dataset, collate_fn):
    
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)
    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

In [42]:
class AlphaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        alpha = torch.softmax(self.a_att(g),dim=1)
        return alpha
    
class BetaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        self.b_att = nn.Linear(hidden_dim, hidden_dim)


    def forward(self, h):
        beta = torch.tanh(self.b_att(h))
        return beta
def apply_attention(alpha, beta, rev_v, rev_masks):
    rev_masks = torch.sum(rev_masks,dim=2)
    rev_masks = rev_masks > 0
    rev_masks = rev_masks.unsqueeze(-1)
    rev_v = rev_v * rev_masks
    #c = torch.sum(beta * alpha * rev_v, dim =1)
    c = beta * alpha * rev_v
    return c
def sum_embeddings_with_mask(x, masks):
    x = x * masks.unsqueeze(-1)
    x = torch.sum(x, dim = -2)
    return x
    

In [135]:
EMBEDDING_DIM=256


class RETAIN(nn.Module):
    
    def __init__(self, num_codes, embedding_dim=EMBEDDING_DIM):
        super().__init__()
        # Define the embedding layer using `nn.Embedding`. Set `embDimSize` to 128.
        self.embedding = nn.Embedding(num_codes, embedding_dim)
        # Define the RNN-alpha using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
        self.rnn_a = nn.GRU(input_size=embedding_dim,hidden_size=embedding_dim, batch_first=True)
        # Define the RNN-beta using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
        self.rnn_b = nn.GRU(input_size=embedding_dim,hidden_size=embedding_dim, batch_first=True)
        # Define the alpha-attention using `AlphaAttention()`;
        self.att_a = AlphaAttention(embedding_dim)
        # Define the beta-attention using `BetaAttention()`;
        self.att_b = BetaAttention(embedding_dim)
        # Define the linear layers using `nn.Linear()`;
        self.fc = nn.Linear(embedding_dim, 1)
        # Define the final activation layer using `nn.Sigmoid().
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, masks, rev_x, rev_masks):
        """
        Arguments:
            rev_x: the diagnosis sequence in reversed time of shape (# visits, batch_size, # diagnosis codes)
            rev_masks: the padding masks in reversed time of shape (# visits, batch_size, # diagnosis codes)

        Outputs:
            probs: probabilities of shape (batch_size)
        """
        # 1. Pass the reversed sequence through the embedding layer;
        #print(rev_x.shape)
        rev_x = self.embedding(rev_x)
        # 2. Sum the reversed embeddings for each diagnosis code up for a visit of a patient.
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        #print(rev_x.shape)
        # 3. Pass the reversed embegginds through the RNN-alpha and RNN-beta layer separately;
        g, _ = self.rnn_a(rev_x)
        h, _ = self.rnn_b(rev_x)
        #print(g.shape)
        #print(h.shape)
        # 4. Obtain the alpha and beta attentions using `AlphaAttention()` and `BetaAttention()`;
        alpha = self.att_a(g)
        beta = self.att_b(h)
        #print(alpha.shape)
        #print(beta.shape)
        # 5. Sum the attention up using `attention_sum()`;
        c = apply_attention(alpha, beta, rev_x, rev_masks)
        #print(c.shape)
        # 6. Pass the context vector through the linear and activation layers.
        logits = self.fc(c)
        #print(logits.shape)
        probs = self.sigmoid(logits)
        #probs = torch.softmax(logits,dim=1)
        #print(probs.shape)
        return probs.squeeze()
    

# load the model here
retain = RETAIN(num_codes = TOTAL_NUM_CODES)
retain

RETAIN(
  (embedding): Embedding(490, 256)
  (rnn_a): GRU(256, 256, batch_first=True)
  (rnn_b): GRU(256, 256, batch_first=True)
  (att_a): AlphaAttention(
    (a_att): Linear(in_features=256, out_features=1, bias=True)
  )
  (att_b): BetaAttention(
    (b_att): Linear(in_features=256, out_features=256, bias=True)
  )
  (fc): Linear(in_features=256, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [130]:
from sklearn.metrics import *


#input: Y_score,Y_pred,Y_true
#output: accuracy, auc, precision, recall, f1-score
def classification_metrics(Y_score, Y_pred, Y_true):
    acc, auc, precision, recall, f1score = accuracy_score(Y_true, Y_pred), \
                                           roc_auc_score(Y_true, Y_score), \
                                           precision_score(Y_true, Y_pred), \
                                           recall_score(Y_true, Y_pred), \
                                           f1_score(Y_true, Y_pred)
    return acc, auc, precision, recall, f1score



def eval(model, val_loader):
    
    """
    Evaluate the model.
    
    Arguments:
        model: the RNN model
        val_loader: validation dataloader
        
    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
        
    REFERENCE: checkout https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
    """
    
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    model.eval()
    for x, masks, rev_x, rev_masks, y, l in val_loader:
        y_logit = model(x, masks, rev_x, rev_masks)
#        print(y_logit)
        y_hat = (y_logit > 0.4).int()
#        print(sum(y_hat))
        for i in range(y.shape[0]):
            y_score = torch.cat((y_score, y_logit[i,:l[i]].detach().to('cpu').flatten()), dim=0)
            y_pred = torch.cat((y_pred,  y_hat[i,:l[i]].detach().to('cpu').flatten()), dim=0)
            y_true = torch.cat((y_true,  y[i,:l[i]].detach().to('cpu').flatten()), dim=0)
    
    #p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    acc, roc_auc, p, r, f = classification_metrics(y_score.detach().numpy(), 
                                                             y_pred.detach().numpy(), 
                                                             y_true.detach().numpy())

    #roc_auc = roc_auc_score(y_true, y_score)
    return p, r, f, roc_auc, acc

In [100]:
def train(model, train_loader, val_loader, n_epochs):
    """
    Train the model.
    
    Arguments:
        model: the RNN model
        train_loader: training dataloder
        val_loader: validation dataloader
        n_epochs: total number of epochs
    """
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x, masks, rev_x, rev_masks, y,l in train_loader:
            optimizer.zero_grad()
            y_hat = model(x, masks, rev_x, rev_masks)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        #evaluate(model, val_loader)
        p, r, f, roc_auc, acc = eval(model, val_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, roc_auc: {:.2f}, acc: {:.2f}'.format(epoch+1, p, r, f, roc_auc, acc))

    return round(roc_auc, 2)

In [137]:
# load the model
retain = RETAIN(num_codes = TOTAL_NUM_CODES)

# load the loss function
criterion = nn.BCELoss()
# load the optimizer
optimizer = torch.optim.Adam(retain.parameters(), lr=1e-3)

n_epochs = 6
train(retain, train_loader, val_loader, n_epochs)

Epoch: 1 	 Training Loss: 0.684651
Epoch: 1 	 Validation p: 0.23, r:0.68, f: 0.34, roc_auc: 0.49, acc: 0.40
Epoch: 2 	 Training Loss: 0.640114
Epoch: 2 	 Validation p: 0.23, r:0.42, f: 0.30, roc_auc: 0.51, acc: 0.56
Epoch: 3 	 Training Loss: 0.601998
Epoch: 3 	 Validation p: 0.25, r:0.22, f: 0.23, roc_auc: 0.56, acc: 0.67
Epoch: 4 	 Training Loss: 0.571388
Epoch: 4 	 Validation p: 0.34, r:0.21, f: 0.26, roc_auc: 0.60, acc: 0.73
Epoch: 5 	 Training Loss: 0.546833
Epoch: 5 	 Validation p: 0.38, r:0.26, f: 0.31, roc_auc: 0.63, acc: 0.74
Epoch: 6 	 Training Loss: 0.525029
Epoch: 6 	 Validation p: 0.40, r:0.17, f: 0.24, roc_auc: 0.62, acc: 0.75


0.62