<a href="https://colab.research.google.com/github/qinalan10/cs598-patient-phenotyping/blob/master/Code/DLH_Patient_Phenotyping_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Loading a Pre-Trained CNN Model for Patient Phenotyping Based on Patient Notes

This notebook loads a pre-trained CNN model and the held-out test data, then evaluates the model on that data.

*   Read the file containing the test data only.

*   Define the CNN classifier.
*   Load the trained model.
*   Test model performance using exact-match accuracy, a per-class classification report, and per-class ROC AUC.











In [None]:
# Imports for data loading, the model definition, and evaluation.
# NOTE(review): Dataset, summary and classification_report are each imported twice below,
# and several names (files, pd, get_tokenizer, build_vocab_from_iterator, re, stopwords,
# tqdm, roc_curve, RocCurveDisplay) are not used in the cells shown here.
from google.colab import files
from google.colab import drive
import pandas as pd 
from torchsummary import summary
import torch
from torch.utils.data import Dataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import re 
import numpy as np 
from sklearn.metrics import classification_report
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.metrics import roc_auc_score   
from torchsummary import summary
from nltk.corpus import stopwords
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc, RocCurveDisplay
import pickle
# Mount Google Drive so the dataset and model paths under /content/drive are readable.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Set the paths to the pickled test data and the trained model.

In [None]:
# File locations on the mounted Google Drive.
dataset_path, model_path = (
    # Pickled dict; the "test_dataset" entry is read below.
    "/content/drive/MyDrive/Data/dataset_dict.pickle",
    # Trained model, loaded below with torch.load.
    '/content/drive/MyDrive/Data/trained_model.pth',
)

In [None]:
def collate_fn(batch):
    """Merge a list of dataset items into a single batch dict.

    Each item must provide a ``'text'`` token list and a ``'labels'`` tensor.
    Token lists are right-padded with the string '<pad>' to the length of the
    longest list in the batch; label tensors are stacked with ``torch.stack``.

    Returns:
        dict with ``'text'`` (list of equal-length token lists) and
        ``'labels'`` (the stacked label tensor).
    """
    token_lists, label_tensors = [], []
    for item in batch:
        token_lists.append(item['text'])
        label_tensors.append(item['labels'])

    target_len = max(map(len, token_lists))
    padded = [tokens + ['<pad>'] * (target_len - len(tokens)) for tokens in token_lists]

    return {'text': padded, 'labels': torch.stack(label_tensors)}

Load the test data.

In [None]:
# Only the held-out test split is needed for evaluation.
# NOTE(review): pickle.load can execute arbitrary code; only load a trusted file.
with open(dataset_path, 'rb') as pickle_file:
    dataset_dict = pickle.load(pickle_file)
test_dataset = dataset_dict["test_dataset"]

batch_size = 64
# shuffle=False keeps the batch order deterministic across runs.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

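Define the CNN classifier. This must be the same class definition that was used to train the model, because `torch.load` restores the saved model as an instance of this class.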

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random 

class CNNClassifier(nn.Module):
    """Multi-label text CNN over pre-trained word embeddings.

    Four parallel 1-D convolutions (kernel sizes ``filter_sizes[0..3]``) run over
    the embedded token sequence, each followed by ReLU and max pooling over the
    whole sequence. The pooled features are concatenated, passed through dropout
    and a linear layer, and squashed with a sigmoid, giving one independent
    probability per output label.

    Args:
        embedding_dim: Width of each embedding vector (Conv1d input channels).
        num_filters: Output channels of each convolution.
        filter_sizes: Kernel sizes. Only the first four entries build convolutions,
            but ``len(filter_sizes)`` sets the input width of the final linear
            layer, so exactly four sizes are expected.
        embeddings: Pre-trained embedding matrix passed to
            ``nn.Embedding.from_pretrained``.
        output_size: Number of output labels.
        w2v_dictionary: Mapping from token to row index in ``embeddings``.

    NOTE: keep the submodule/attribute names (``embedding``, ``conv1``..``conv4``,
    ``dropout``, ``fc``, ``dictionary``) unchanged; a model restored with
    ``torch.load`` carries these attributes and ``forward`` reads them by name.
    """

    def __init__(self, embedding_dim, num_filters, filter_sizes, embeddings, output_size, w2v_dictionary):
        super(CNNClassifier, self).__init__()
        self.dictionary = w2v_dictionary
        self.embedding = nn.Embedding.from_pretrained(embeddings)
        self.conv1 = nn.Conv1d(embedding_dim, num_filters, filter_sizes[0])
        self.conv2 = nn.Conv1d(embedding_dim, num_filters, filter_sizes[1])
        self.conv3 = nn.Conv1d(embedding_dim, num_filters, filter_sizes[2])
        self.conv4 = nn.Conv1d(embedding_dim, num_filters, filter_sizes[3])
        self.dropout = nn.Dropout(p = 0.1)
        self.fc = nn.Linear(len(filter_sizes) * num_filters, output_size)

    def forward(self, x):
        """Score a batch of tokenized documents.

        Args:
            x: List of token lists, all of the same length (e.g. padded by a
                collate function).

        Returns:
            Tensor of shape (len(x), output_size) with per-label sigmoid probabilities.
        """
        # Put the index tensor on whatever device the embedding weights live on,
        # instead of depending on a notebook-global ``device`` variable.
        device = self.embedding.weight.device

        # Map tokens to embedding row indices.
        vocab_keys = None  # built lazily, at most once per call
        batch_indices = []
        for tokens in x:
            token_indices = []
            for word in tokens:
                if word not in self.dictionary.keys():
                    # Out-of-vocabulary tokens (including '<pad>' if it is not in the
                    # dictionary) are replaced by a random vocabulary entry, as in the
                    # original implementation. Predictions are therefore
                    # non-deterministic unless ``random`` is seeded.
                    if vocab_keys is None:
                        vocab_keys = list(self.dictionary.keys())
                    word = random.choice(vocab_keys)
                token_indices.append(self.dictionary[word])
            batch_indices.append(token_indices)

        x = torch.tensor(batch_indices, dtype=torch.long, device=device)
        x = self.embedding(x)
        x = x.permute(0, 2, 1)  # (batch_size, embedding_dim, sequence_length) for Conv1d

        convs = (self.conv1, self.conv2, self.conv3, self.conv4)

        # Every convolution needs at least kernel_size time steps; zero-pad short
        # batches on the right so they are scored instead of raising.
        min_len = max(conv.kernel_size[0] for conv in convs)
        if x.size(2) < min_len:
            x = F.pad(x, (0, min_len - x.size(2)))

        # Convolution -> ReLU -> max over the whole time axis, for each kernel size.
        pooled = []
        for conv in convs:
            features = F.relu(conv(x))
            pooled.append(F.max_pool1d(features, features.size(2)).squeeze(2))

        out = torch.cat(pooled, 1)
        out = self.dropout(out)
        out = self.fc(out)
        return torch.sigmoid(out)

Load the Pre-Trained Model

In [None]:
# Load the model
trained_model = torch.load(model_path)
print(trained_model)

CNNClassifier(
  (embedding): Embedding(10316, 300)
  (conv1): Conv1d(300, 256, kernel_size=(1,), stride=(1,))
  (conv2): Conv1d(300, 256, kernel_size=(2,), stride=(1,))
  (conv3): Conv1d(300, 256, kernel_size=(3,), stride=(1,))
  (conv4): Conv1d(300, 256, kernel_size=(5,), stride=(1,))
  (dropout): Dropout(p=0.1, inplace=False)
  (fc): Linear(in_features=1024, out_features=14, bias=True)
)


Define the model evaluation function.

In [None]:
def eval_model(model, dataloader, threshold=0.5, return_scores=False):
    """Run ``model`` over ``dataloader`` and collect thresholded predictions.

    Args:
        model: Module called as ``model(batch['text'])``. It is expected to
            return per-label probabilities (e.g. sigmoid outputs).
        dataloader: Iterable of dicts with ``'text'`` and ``'labels'`` entries,
            such as the batches built by ``collate_fn``.
        threshold: A label is predicted as 1 when its probability is strictly
            greater than this value. The default of 0.5 matches rounding
            probabilities in [0, 1] with ``np.round`` (which maps 0.5 to 0).
        return_scores: If True, also return the raw probabilities. Use these,
            not the thresholded predictions, for ranking metrics such as ROC AUC.

    Returns:
        ``(Y_pred, Y_true)`` as NumPy arrays concatenated over all batches, or
        ``(Y_pred, Y_true, Y_score)`` when ``return_scores`` is True.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    score_batches = []
    label_batches = []
    with torch.no_grad():
        for batch in dataloader:
            score_batches.append(model(batch['text']).cpu().numpy())
            label_batches.append(batch['labels'].cpu().numpy())

    if not score_batches:
        raise ValueError("eval_model: dataloader yielded no batches")

    Y_score = np.concatenate(score_batches, axis=0)
    Y_true = np.concatenate(label_batches, axis=0)
    # Hard 0/1 predictions with the same dtype as the scores.
    Y_pred = (Y_score > threshold).astype(Y_score.dtype)

    if return_scores:
        return Y_pred, Y_true, Y_score
    return Y_pred, Y_true

Make predictions on the test data.

In [None]:
# Hard 0/1 predictions and ground-truth label arrays for the whole test set.
y_pred, y_true = eval_model(model=trained_model, dataloader=test_loader)

Evaluating the Model Performance

In [None]:
# Output label names. Assumed to be in the same column order as the model
# outputs and the label arrays (14 names, matching the fc layer's 14 outputs).
target_classes =['Obesity', 'Non.Adherence', 'Developmental.Delay.Retardation',
               'Advanced.Heart.Disease', 'Advanced.Lung.Disease',
               'Schizophrenia.and.other.Psychiatric.Disorders', 'Alcohol.Abuse',
               'Other.Substance.Abuse', 'Chronic.Pain.Fibromyalgia',
               'Chronic.Neurological.Dystrophies', 'Advanced.Cancer', 'Depression',
               'Dementia', 'Unsure']

# Fail loudly rather than silently mislabel report rows.
assert y_true.shape[1] == len(target_classes), (
    f"expected {len(target_classes)} label columns, got {y_true.shape[1]}"
)

# On multi-label indicator arrays accuracy_score is subset (exact-match)
# accuracy: a note counts as correct only if every label is predicted right.
print("Test Accuracy : {}".format(accuracy_score(y_true, y_pred)))
print("\nClassification Report : ")
# zero_division=0: a class the model never predicts gets precision 0.00
# instead of a misleading 1.00.
print(classification_report(y_true, y_pred, target_names=target_classes, zero_division=0))

Test Accuracy : 0.20710059171597633

Classification Report : 
                                               precision    recall  f1-score   support

                                      Obesity       1.00      0.00      0.00        11
                                Non.Adherence       0.33      0.14      0.20        14
              Developmental.Delay.Retardation       1.00      0.00      0.00         3
                       Advanced.Heart.Disease       0.61      0.39      0.47        36
                        Advanced.Lung.Disease       0.75      0.18      0.29        17
Schizophrenia.and.other.Psychiatric.Disorders       0.88      0.18      0.30        38
                                Alcohol.Abuse       0.82      0.67      0.74        21
                        Other.Substance.Abuse       0.67      0.36      0.47        11
                    Chronic.Pain.Fibromyalgia       0.57      0.46      0.51        28
             Chronic.Neurological.Dystrophies       1.00      0.11 

AUC Score

In [None]:
# ROC AUC must be computed from continuous scores. The thresholded 0/1 y_pred
# collapses each ROC curve to a single operating point, so re-run the model to
# collect its sigmoid probabilities. test_loader uses shuffle=False, so the
# rows line up with y_true.
# NOTE: CNNClassifier.forward replaces out-of-vocabulary tokens with random
# vocabulary entries, so scores can vary slightly between runs.
trained_model.eval()
with torch.no_grad():
    y_score = np.concatenate(
        [trained_model(batch['text']).cpu().numpy() for batch in test_loader],
        axis=0,
    )

for class_idx, class_name in enumerate(target_classes):
    class_true = y_true[:, class_idx]
    # roc_auc_score raises when only one label value is present in y_true.
    if np.unique(class_true).size < 2:
        print(f'{class_name} auc score: undefined (only one class present in y_true)')
        continue
    # Named class_auc so it does not shadow sklearn.metrics.auc.
    class_auc = roc_auc_score(class_true, y_score[:, class_idx])
    print(f'{class_name} auc score: {class_auc:.2f}')

Obesity auc score: 0.50
Non.Adherence auc score: 0.56
Developmental.Delay.Retardation auc score: 0.50
Advanced.Heart.Disease auc score: 0.66
Advanced.Lung.Disease auc score: 0.58
Schizophrenia.and.other.Psychiatric.Disorders auc score: 0.59
Alcohol.Abuse auc score: 0.82
Other.Substance.Abuse auc score: 0.68
Chronic.Pain.Fibromyalgia auc score: 0.70
Chronic.Neurological.Dystrophies auc score: 0.56
Advanced.Cancer auc score: 0.67
Depression auc score: 0.61
Dementia auc score: 0.50
Unsure auc score: 0.49
