<h3> Training BiLSTM Model </h3>

In [1]:
import pandas as pd
import os
import numpy as np
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# BiLSTM Model Training

This notebook implements the BiLSTM model which will take TF-IDF based symptom representation and Word2Vec based symptom representation and predicts diagnoses codes.

---

- Input : symptom_disease_dict_{RUN_TAG}.json - Contains HADM_ID to Symptom text and Diagnosis mapping as json object
- Input : icd9_dict_{RUN_TAG}.json - Contains ICD9 Codes of TOP N Diagnoses
- Input : weight_i_j_norm{tag}.csv - TF-IDF weights for symptom representation

In [2]:
cwd = os.getcwd()
print(f"Current working directory : {cwd}")
# Let's define some constants that will be used below in our processing
MAX_NUMBER_OF_DISEASE = 50
MAX_SYMPTOMS = 50
BATCH_SIZE=400
RUN_TAG = "_v2.0"
data_dir = cwd + "/../../data/"
SYMPTOM_DISEASE_DICT_FILE_PATH = data_dir + f"symptom_disease_dict_{RUN_TAG}.json"
ICD9_FILE_PATH = data_dir + f"icd9_dict_{RUN_TAG}.json"
SYMPTOM_DICT_FILE_PATH = data_dir + f"symptoms_dict_{RUN_TAG}.json"
TF_IDF_WEIGHTS_FILE_PATH = data_dir + f"weight_i_j_{RUN_TAG}.csv"
TF_IDF_NORM_WEIGHTS_FILE_PATH = data_dir + f"weight_i_j_norm_{RUN_TAG}.csv"

Current working directory : /Users/vijaymi/Studies/CS-598-DL4Health/Project/135-Disease-Inference-Method/disease_pred_using_bilstm/source


## 1. Loading all the required data

In [3]:
icd9_dict = None
with open(ICD9_FILE_PATH, 'r') as f:
            icd9_dict = json.load(f)
        
symptom_disease_dict = None
with open(SYMPTOM_DISEASE_DICT_FILE_PATH, 'r') as f:
            symptom_disease_dict = json.load(f)

symptoms_dict = None
with open(SYMPTOM_DICT_FILE_PATH, 'r') as f:
    symptoms_dict = json.load(f)
    
tfidf_weights = pd.read_csv(TF_IDF_NORM_WEIGHTS_FILE_PATH)
tfidf_weights.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,3980,3981,3982,3983,3984,3985,3986,3987,3988,3989
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.817708,...,0.090909,0.666667,0.125,0.833333,0.666667,0.714286,0.5,1.0,0.2,1.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.802083,...,0.272727,0.0,0.625,0.5,0.111111,0.571429,0.333333,0.5,0.4,0.5
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.545455,0.666667,0.375,0.5,0.666667,0.714286,0.833333,0.875,0.6,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.739583,...,0.090909,0.166667,0.5,0.333333,0.222222,0.285714,0.166667,0.25,0.0,0.5
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.630208,...,0.0,0.166667,0.125,0.5,0.111111,0.428571,0.5,0.875,0.2,0.625


## 2. Symptom to Disease BiLSTM Model (for TF-IDF based Symptom representation)

### 2.1 Define Custom Loader for Symtom to Diagnoses based BiLSTM Model

In [4]:
from torch.utils.data import Dataset

class SymptomToDiagnosesBiLSTMDataset(Dataset):
    
    def __init__(self, filename):        
        self.hadm_id_map = {}
        
        # TF-IDF Weights
        self.tfidf_weights = pd.read_csv(TF_IDF_NORM_WEIGHTS_FILE_PATH)
        
        # Symptom dictionary
        with open(SYMPTOM_DICT_FILE_PATH, 'r') as f:
            self.symptom_dict = json.load(f)
        
        with open(ICD9_FILE_PATH, 'r') as f:
            self.icd9_dict = json.load(f)
        
        # read in the data files
        self.hadm_list = self.process_raw_data(filename)
        
        
    def process_raw_data(self, filename):
        symptom_disease_dict = None
        with open(filename, 'r') as f:
            symptom_disease_dict = json.load(f)
        hadm_list = []

        # Collecting all records for one admission as a tuple of symptoms list and diagnoses list
        for hadm_id in symptom_disease_dict.keys():            
            symp_list, icd9_list = symptom_disease_dict[hadm_id]
            symp_vec = self.create_symptom_vector(symp_list)
            diag_vec = self.create_diagnosis_vector(icd9_list)
            
            hadm_list.append((symp_vec, diag_vec))
        
        return hadm_list
        
    def __len__(self):
        return len(self.hadm_list)
    
    def __getitem__(self, index):
        """
            Output:
            symptpm_vector : max_number_of_symptoms (50) x number_of_diagnoses (50)
            diagnoses_vector = number_of_diagnoses
            symptom_count = number of symtoms for current record
        """
        symptom_list, diagnosis_list = self.hadm_list[index]
        
        # Create symptom vector for this admission record
        symptom_vector = np.zeros((MAX_SYMPTOMS, MAX_NUMBER_OF_DISEASE))
        
        # Create Diagnosis vector to keep true labels
        diag_vector = np.zeros((MAX_NUMBER_OF_DISEASE))

        # Populate Symptom Vector by getting corresponding embeddings from TF-IDF vector
        for index, symptom_idx in enumerate(symptom_list):
            # print(f"Symptom vector index: {index}, symptom index : {symptom_idx} \n {self.tfidf_weights.iloc[:,symptom_idx]}")
            symptom_vector[index] = self.tfidf_weights.iloc[:,symptom_idx]
            
        # Populate disease vector   
        for diagnosis_index in diagnosis_list:
            diag_vector[diagnosis_index] = 1
                # print(f"icd code : {icd_code}, diagnosis_index : {diagnosis_index}")
        return torch.tensor(symptom_vector.T, dtype=torch.float), torch.tensor(diag_vector, dtype=torch.float), len(symptom_list)
    
    def create_symptom_vector(self, symp_list):
        symp_index_list = []  
        # only consider notes with symptoms count more than 1
        if len(symp_list) > 1:
            for symptom in symp_list:
                if symptom in self.symptom_dict:
                    symp_index_list.append(self.symptom_dict[symptom])
        # print(f"symp_index_list[:MAX_SYMPTOMS] -- {symp_index_list[:MAX_SYMPTOMS]}")
        return symp_index_list[:MAX_SYMPTOMS]
    
    def create_diagnosis_vector(self, diagnoses_list):
        diag_index_list = []  
        # only consider notes with symptoms count more than 1
        for diagnoses in diagnoses_list:
            if diagnoses in self.icd9_dict:
                diag_index_list.append(self.icd9_dict[diagnoses])
        return diag_index_list   


In [5]:
dataset = SymptomToDiagnosesBiLSTMDataset(SYMPTOM_DISEASE_DICT_FILE_PATH)
# train_size = int(len(dataset)*0.8)
# test_size = int(len(dataset)*0.2)
train_size = 20000
test_size = 2000
validation_size = len(dataset)  - (train_size + test_size)
train_dataset, test_dataset, validation_dataset = torch.utils.data.random_split(dataset, [train_size, test_size, validation_size])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

**Validate dataset created is with right size**

In [6]:
print(f"Dataset size : {len(dataset)}")
symptom_item, diag_item, symptom_len = dataset[8]

print(f"symptom_item.shape : {symptom_item.shape}")
print(f"diag_item.shape : {diag_item}")
assert symptom_item.shape == (MAX_SYMPTOMS, MAX_NUMBER_OF_DISEASE), "Incorrect Symptom representation shape."
assert diag_item.shape == (MAX_NUMBER_OF_DISEASE, ), "Incorrect diagnoses labels"
symptom_item

Dataset size : 49352
symptom_item.shape : torch.Size([50, 50])
diag_item.shape : tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.])


tensor([[0.1053, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.3684, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0789, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0526, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.1579, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

### 2.2 Define Symptom to Disease to BiLSTM Model (which uses TF-IDF based represenation)

In [7]:
class SymptomToDiseaseBiLstm(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, num_classes, batch_size):
        super(SymptomToDiseaseBiLstm, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.batch_size = batch_size
        
        self.bilstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True, dropout=0.8)
        self.fc = nn.Linear(hidden_size*2, num_classes)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x, symp_length):
        h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size)

        # x = torch.nn.utils.rnn.pack_padded_sequence(x, symp_length, batch_first=True, enforce_sorted=False)
        
        out, (ht, ct) = self.bilstm(x, (h0, c0))
        
        # out, _ = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        out = self.fc(out[:,-1,:])
        out = self.sigmoid(out)
        
        return out

sym_disease_model = SymptomToDiseaseBiLstm(MAX_SYMPTOMS, 100, 2, MAX_NUMBER_OF_DISEASE, BATCH_SIZE)

### 2.3 Training and inferencing

In [9]:
import torch.optim as optim

optimizer = optim.Adam(sym_disease_model.parameters(), lr=0.001)
criterion = nn.BCELoss()

In [None]:
def get_non_padded_pred_and_true_labels(y_pred, y_true, symptom_length_vector):
    
    print(f"y_pred.shape: {y_pred.shape}, y_true.shape: {y_true.shape}, symptom_length_vector: {symptom_length_vector.shape}")
    # Create a mask which will have all padded field to be zero
    
    mask_vector = np.ones(y_pred.shape)
    idx = 0
    for symptom_length in symptom_length_vector:
        mask_vector[idx,symptom_length:] = 0
        idx += 1
        
    mask = torch.tensor(mask_vector)
    mask_1 = mask.view(-1)
    mask_1 = mask_1.ge(1)
    
    y_pred_1 = y_pred.view(-1)
    y_true_1 = y_true.view(-1)
    
    y_pred_final = torch.masked_select(y_pred_1, mask_1)
    y_true_final = torch.masked_select(y_true_1, mask_1)
    
    print(f"y_pred_final.shape: {y_pred_final.shape}, y_true_final.shape: {y_true_final.shape}")
    return y_pred_final, y_true_final

In [12]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
DISEASE_THRESHOLD = 0.20

def eval(model, test_loader):
    
    """    
    INPUT:
        model: model
        test_loader: dataloader
        
    OUTPUT:
        precision: overall micro precision score
        recall: overall micro recall score
        f1: overall micro f1 score
        
    REFERENCE: checkout https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
    """

    model.eval()
    y_pred = torch.LongTensor()
    y_true = torch.LongTensor()
    for sequences, labels, symp_len in test_loader:
        y_prob = model(sequences, symp_len)
        y_hat = (y_prob > DISEASE_THRESHOLD).int()
#         print(f"y_prob: {y_hat}")
#         print(f"labels: {labels}")
        #y_hat, labels = get_non_padded_pred_and_true_labels(y_hat.detach().to('cpu'), labels.detach().to('cpu'), symp_len)
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, labels.detach().to('cpu')), dim=0)
        
    
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='micro')
    auc = roc_auc_score(y_true, y_pred, average='micro')
    return p, r, f, auc


In [13]:
def train(model, train_loader, test_loader, n_epochs):
    """    
    INPUT:
        model: the model
        train_loader: dataloder
        val_loader: dataloader
        n_epochs: total number of epochs
    """
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for sequences, y_true, symp_len in train_loader:
            optimizer.zero_grad()
            y_hat = model(sequences, symp_len)
            
            loss = criterion(y_hat, y_true)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f, auc = eval(model, test_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, auc: {:.2f}'.format(epoch+1, p, r, f, auc))

    
# number of epochs to train the model
n_epochs = 5

train(sym_disease_model, train_loader, test_loader, n_epochs)

Epoch: 1 	 Training Loss: 0.358892
Epoch: 1 	 Validation p: 0.28, r:0.44, f: 0.34, auc: 0.64
Epoch: 2 	 Training Loss: 0.358633
Epoch: 2 	 Validation p: 0.28, r:0.44, f: 0.34, auc: 0.64
Epoch: 3 	 Training Loss: 0.358447
Epoch: 3 	 Validation p: 0.29, r:0.43, f: 0.35, auc: 0.64
Epoch: 4 	 Training Loss: 0.357412
Epoch: 4 	 Validation p: 0.28, r:0.44, f: 0.34, auc: 0.64
Epoch: 5 	 Training Loss: 0.354182
Epoch: 5 	 Validation p: 0.30, r:0.43, f: 0.36, auc: 0.64




_100_v3 - 10 Epochs - Using non-normalized weights. lr=0.001
<code>
Epoch: 1 	 Training Loss: 0.236686
Epoch: 1 	 Validation p: 0.34, r:0.39, f: 0.36, auc: 0.66
Epoch: 2 	 Training Loss: 0.235472
Epoch: 2 	 Validation p: 0.34, r:0.40, f: 0.37, auc: 0.66
Epoch: 3 	 Training Loss: 0.233751
Epoch: 3 	 Validation p: 0.35, r:0.39, f: 0.37, auc: 0.66
Epoch: 4 	 Training Loss: 0.232628
Epoch: 4 	 Validation p: 0.34, r:0.39, f: 0.37, auc: 0.66
Epoch: 5 	 Training Loss: 0.232200
Epoch: 5 	 Validation p: 0.35, r:0.39, f: 0.37, auc: 0.66
</code>

----------------

n_epochs = 2, learning rate = 0.001 <br>
<code>
Epoch: 1 	 Training Loss: 0.332956
Epoch: 1 	 Validation p: 0.35, r:0.51, f: 0.42, auc: 0.69
Epoch: 2 	 Training Loss: 0.303654
Epoch: 2 	 Validation p: 0.40, r:0.51, f: 0.45, auc: 0.70
</code>
         <code>
Epoch: 1 	 Training Loss: 0.298307
Epoch: 1 	 Validation p: 0.40, r:0.53, f: 0.45, auc: 0.71
Epoch: 2 	 Training Loss: 0.294352
Epoch: 2 	 Validation p: 0.38, r:0.57, f: 0.46, auc: 0.72
Epoch: 3 	 Training Loss: 0.290162
Epoch: 3 	 Validation p: 0.40, r:0.55, f: 0.46, auc: 0.72
Epoch: 4 	 Training Loss: 0.286585
Epoch: 4 	 Validation p: 0.39, r:0.57, f: 0.47, auc: 0.72
Epoch: 5 	 Training Loss: 0.283719
Epoch: 5 	 Validation p: 0.40, r:0.58, f: 0.47, auc: 0.73      
    </code>

<h5>n_epochs=5, learning_rate = 0.005</h5>
<code>
Epoch: 1 	 Training Loss: 0.292470
Epoch: 1 	 Validation p: 0.36, r:0.59, f: 0.45, auc: 0.72
Epoch: 2 	 Training Loss: 0.288039
Epoch: 2 	 Validation p: 0.37, r:0.62, f: 0.46, auc: 0.73
Epoch: 3 	 Training Loss: 0.284519
Epoch: 3 	 Validation p: 0.38, r:0.60, f: 0.47, auc: 0.73
Epoch: 4 	 Training Loss: 0.282809
Epoch: 4 	 Validation p: 0.37, r:0.63, f: 0.46, auc: 0.74
Epoch: 5 	 Training Loss: 0.281407
Epoch: 5 	 Validation p: 0.38, r:0.61, f: 0.47, auc: 0.74
</code>

With Batch size of 400, starting fresh!
<code>
Epoch: 1 	 Training Loss: 0.400608
Epoch: 1 	 Validation p: 0.28, r:0.40, f: 0.33, auc: 0.63
Epoch: 2 	 Training Loss: 0.347110
Epoch: 2 	 Validation p: 0.35, r:0.36, f: 0.35, auc: 0.63
Epoch: 3 	 Training Loss: 0.326768
Epoch: 3 	 Validation p: 0.37, r:0.44, f: 0.40, auc: 0.67
Epoch: 4 	 Training Loss: 0.319100
Epoch: 4 	 Validation p: 0.37, r:0.46, f: 0.41, auc: 0.67
Epoch: 5 	 Training Loss: 0.312817
Epoch: 5 	 Validation p: 0.36, r:0.54, f: 0.43, auc: 0.70
    
Epoch: 1 	 Training Loss: 0.304074
Epoch: 1 	 Validation p: 0.38, r:0.52, f: 0.44, auc: 0.70
Epoch: 2 	 Training Loss: 0.300410
Epoch: 2 	 Validation p: 0.39, r:0.51, f: 0.44, auc: 0.70
Epoch: 3 	 Training Loss: 0.299053
Epoch: 3 	 Validation p: 0.38, r:0.54, f: 0.45, auc: 0.71
Epoch: 4 	 Training Loss: 0.297172
Epoch: 4 	 Validation p: 0.40, r:0.52, f: 0.45, auc: 0.70
Epoch: 5 	 Training Loss: 0.294421
Epoch: 5 	 Validation p: 0.38, r:0.56, f: 0.45, auc: 0.71
</code>

Number of layers in BiLSRT as 4. Very slow as well!
<code>
Epoch: 1 	 Training Loss: 0.691968
Epoch: 1 	 Validation p: 0.13, r:1.00, f: 0.23, auc: 0.50
Epoch: 2 	 Training Loss: 0.691965
Epoch: 2 	 Validation p: 0.13, r:1.00, f: 0.23, auc: 0.50
</code>

Number of layers in BiLSRT as 3. Very slow as well!
<code>
Epoch: 1 	 Training Loss: 0.329327
Epoch: 1 	 Validation p: 0.30, r:0.56, f: 0.39, auc: 0.69
Epoch: 2 	 Training Loss: 0.320431
Epoch: 2 	 Validation p: 0.33, r:0.53, f: 0.41, auc: 0.69
Epoch: 3 	 Training Loss: 0.317443
Epoch: 3 	 Validation p: 0.35, r:0.49, f: 0.41, auc: 0.68
Epoch: 4 	 Training Loss: 0.315925
Epoch: 4 	 Validation p: 0.34, r:0.52, f: 0.41, auc: 0.69
Epoch: 5 	 Training Loss: 0.311251
Epoch: 5 	 Validation p: 0.38, r:0.49, f: 0.43, auc: 0.69
</code>

10 Epoch run with 50 disease
<code>
Epoch: 1 	 Training Loss: 0.403777
Epoch: 1 	 Validation p: 0.28, r:0.44, f: 0.35, auc: 0.64
Epoch: 2 	 Training Loss: 0.354655
Epoch: 2 	 Validation p: 0.33, r:0.42, f: 0.37, auc: 0.65
Epoch: 3 	 Training Loss: 0.343463
Epoch: 3 	 Validation p: 0.34, r:0.46, f: 0.39, auc: 0.66
Epoch: 4 	 Training Loss: 0.337650
Epoch: 4 	 Validation p: 0.33, r:0.49, f: 0.39, auc: 0.67
Epoch: 5 	 Training Loss: 0.331578
Epoch: 5 	 Validation p: 0.36, r:0.48, f: 0.41, auc: 0.68
Epoch: 6 	 Training Loss: 0.323138
Epoch: 6 	 Validation p: 0.36, r:0.50, f: 0.42, auc: 0.69
Epoch: 7 	 Training Loss: 0.319433
Epoch: 7 	 Validation p: 0.36, r:0.51, f: 0.43, auc: 0.69
Epoch: 8 	 Training Loss: 0.317392
Epoch: 8 	 Validation p: 0.39, r:0.48, f: 0.43, auc: 0.68
Epoch: 9 	 Training Loss: 0.315443
Epoch: 9 	 Validation p: 0.37, r:0.52, f: 0.43, auc: 0.69
Epoch: 10 	 Training Loss: 0.313061
Epoch: 10 	 Validation p: 0.38, r:0.52, f: 0.44, auc: 0.70
<code>

## 3.0 Symptom to Symptom BiLSTM Model (for Word2Vec based symptom representation)