## !!! Deep Patient Similarity using CNN Softmax using a balanced dataset !!!

In [609]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

import random
import os

In [610]:
seed = 98
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

### Loading data (balanced)

* Re-using the saved pre-processed data of patients grouped by visits. 
* In the pre-processing each patient is clipped to have 40 visits and 42 features. 
* Picking only 2000 patients for target label N180 to make it balanced with the other targets

In [611]:
X = torch.load("p_x.pt")
Y = torch.load("p_y.pt")

In [612]:
Xb = []
Yb = []
n180_c = 0
for x, y in zip(X, Y):
    if y == "N180":
        if n180_c < 2000:
            n180_c = n180_c + 1
            Xb.append(x)
            Yb.append(y)
    else: 
        Xb.append(x)
        Yb.append(y)

In [613]:
X = Xb
Y = Yb

In [614]:
print(len(Y))
print(len(X))
print(set(Y))

7514
7514
{'E142', 'I10', 'N088', 'N083', 'N180', 'I120', 'N188', 'E102', 'N039', 'N189'}


### Number of patients in each target class

In [615]:
for l in set(Y):
    print(f"{l}:", len(np.where(np.array(Y) == l)[0]))

E142: 164
I10: 367
N088: 130
N083: 310
N180: 2000
I120: 1204
N188: 299
E102: 102
N039: 296
N189: 2642


### Converting target lables to one-hot encoding 

In [616]:
y_labels = list(set(Y))
Y_oh = np.zeros((len(Y), len(y_labels)))
for idx, y in enumerate(Y):
    Y_oh[idx][y_labels.index(y)] = 1

In [617]:
print(Y_oh.shape)
Y_oh

(7514, 10)


array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 1.],
       [0., 0., 0., ..., 0., 0., 1.]])

### Converting patient data (X) and target one-hot encoded data (Y_oh) as tensors to build the model

In [618]:
Xt = torch.tensor(X)
Yt = torch.tensor(Y_oh)
Xt = Xt.type(torch.FloatTensor)
Yt = Yt.type(torch.LongTensor)
print("Xt shape:", Xt.shape)
print("Yt shape:", Yt.shape)

Xt shape: torch.Size([7514, 40, 42])
Yt shape: torch.Size([7514, 10])


### A custom dataset to load pairwise data. We are just picking adjacent patient data as pair to learn patient similiarity. Data is shuffled by dataloader for training 

In [619]:
from torch.utils.data import Dataset
from torch.utils.data import random_split

class PairwiseDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.n = len(X)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x1 = self.X[idx]
        y1 = self.y[idx]
        if idx+1 == self.n:
            x2 = self.X[idx]
            y2 = self.y[idx]
        else:
            x2 = self.X[idx+1]
            y2 = self.y[idx+1]
        
        y = int(np.array_equal(y1, y2))
        
        return [x1, x2, np.asarray([y]), y1, y2]

    def get_splits(self, n_test=0.2):
        test_size = round(n_test * len(self.X))
        train_size = len(self.X) - test_size
        
        return random_split(self, [train_size, test_size])

In [620]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32 

dataset = PairwiseDataset(Xt, Yt)
train, test = dataset.get_splits()
train_dl = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)
print("# of train batches:", len(train_dl))
print("# of test batches:", len(test_dl))

# of train batches: 188
# of test batches: 47


In [621]:
train_iter = iter(train_dl)
x1, x2, y, y1, y2 = next(train_iter)

print('Shape of a batch x1:', x1.shape)
print('Shape of a batch y1:', y1.shape)
print('Shape of a batch x2:', x2.shape)
print('Shape of a batch y2:', y2.shape)
print('Shape of a batch y:', y.shape)

Shape of a batch x1: torch.Size([32, 40, 42])
Shape of a batch y1: torch.Size([32, 10])
Shape of a batch x2: torch.Size([32, 40, 42])
Shape of a batch y2: torch.Size([32, 10])
Shape of a batch y: torch.Size([32, 1])


### CNN_softmax model

In [622]:
class CNNSoftMax(nn.Module):
    def __init__(self):
        super(CNNSoftMax, self).__init__()
        self.conv1 = nn.Conv1d(42, 84, kernel_size=2, stride=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2)
        
        self.conv2 = nn.Conv1d(84, 128, kernel_size=2, stride=1)
        
        self.fc1 = nn.Linear(1152, 32)
    
        self.fc2 = nn.Linear(65, 10)
        
        self.softmax = nn.Softmax(dim=1)
        
    def _forward(self, x):
        x = np.transpose(x, axes=[0, 2, 1])
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        return x
        
    def forward(self, x1, x2):
        x1 = self._forward(x1)
        x2 = self._forward(x2)
        
        f_out1 = x1.flatten(start_dim=1)
        f_out2 = x2.flatten(start_dim=1)
        
        f_out1 = self.fc1(f_out1)
        f_out2 = self.fc1(f_out2)
        
        distance = torch.pairwise_distance(f_out1, f_out2, p=2)
        distance = distance.reshape(distance.shape[0], 1)

        x_f = torch.cat((distance, f_out1, f_out2), 1)
        
        y_f = self.fc2(x_f)
        y_hat = self.softmax(y_f)
 
        return distance, y_hat

In [623]:
model = CNNSoftMax()
model

CNNSoftMax(
  (conv1): Conv1d(42, 84, kernel_size=(2,), stride=(1,))
  (relu): ReLU()
  (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(84, 128, kernel_size=(2,), stride=(1,))
  (fc1): Linear(in_features=1152, out_features=32, bias=True)
  (fc2): Linear(in_features=65, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
)

In [624]:
from sklearn.metrics import accuracy_score, \
precision_recall_fscore_support, recall_score, f1_score
import warnings
warnings.filterwarnings("ignore")

def model_metrics(y_test, y_pred):
    acc = accuracy_score(y_test, y_pred)
    p, r, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='macro')
    return acc, p, r, f1

### Contrastive loss to distinguish the pairs

https://ieeexplore.ieee.org/document/1640964

In [625]:
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1):
        super().__init__()
        self.margin = margin

    def forward(self, distance, y):
        loss = (1-y) * distance**2 + y * (torch.max(torch.zeros_like(distance), self.margin - distance)**2)
        return torch.mean(loss, dtype=torch.float)

### Using two loss functions to minimize the error to distinguish the pairs and to predict the target class. 

In [670]:
from torch.optim import SGD, Adam
from torch.nn import CrossEntropyLoss, BCELoss
from pytorch_metric_learning import losses

contrastive_loss = ContrastiveLoss()
cross_entropy_loss = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.01)

### Padding the batches so that all the batches have same number of records

In [627]:
import torch.nn.functional as F
def pad_batch(x1, x2, y0, y1, y2):
    b_size = x1.shape[0]
    pad_len = BATCH_SIZE - b_size
    if pad_len > 0:
        x1 = F.pad(x1, (0,0,0,0,0,pad_len))
        x2 = F.pad(x2, (0,0,0,0,0,pad_len))
        y0 = F.pad(y0, (0,0,0,pad_len))
        y1 = F.pad(y1, (0,0,0,pad_len))
        y2 = F.pad(y2, (0,0,0,pad_len))
    return x1, x2, y0, y1, y2

In [628]:
from numpy import vstack
from numpy import argmax

def evaluate(model, dl):
    model.eval()
    all_y_pred, all_y_true = list(), list()
    
    for x1, x2, y0, y1, y2 in dl:
        x1, x2, y0, y1, y2 = pad_batch(x1, x2, y0, y1, y2)
        distance, y_hat = model(x1, x2)
        
        y = torch.bitwise_and(y1.type(torch.IntTensor), y2.type(torch.IntTensor))
        
        y_true = y.type(torch.FloatTensor)
        y_pred = y_hat
        
        y_pred = (y_pred > 0.5).type(torch.FloatTensor)

        all_y_pred.append(y_pred)
        all_y_true.append(y_true)
        
    all_y_pred, all_y_true = vstack(all_y_pred), vstack(all_y_true)
    acc, p, r, f1 = model_metrics(all_y_true.flatten(), all_y_pred.flatten())
    print(f"acc: {acc:.4f}, precision: {p:.4f}, recall: {r:.4f}, f1: {f1:.4f}")  

In [629]:
def train():
    n_epochs = 10
    model.train()
    train_loss_arr = []
    for epoch in range(n_epochs):
        train_loss = 0
        for x1, x2, y0, y1, y2 in train_dl:
            optimizer.zero_grad()
            
            x1, x2, y0, y1, y2 = pad_batch(x1, x2, y0, y1, y2)
            distance, y_hat = model(x1, x2)
            
            y = torch.bitwise_and(y1.type(torch.IntTensor), y2.type(torch.IntTensor))
            y_true = y.type(torch.FloatTensor)
            y_pred = y_hat
            
            loss1 = cross_entropy_loss(y_pred, y_true)
            loss2 = contrastive_loss(distance, y0)
            loss = loss1 + loss2

            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_dl)
        print('Epoch: {} \tTraining Loss: {:.4f}'.format(epoch, train_loss))
    evaluate(model, test_dl)

### Training and Evaluating the model

In [630]:
%%time
train()

Epoch: 0 	Training Loss: 1.4470
acc: 0.9353, precision: 0.7315, recall: 0.8692, f1: 0.7803
Epoch: 1 	Training Loss: 1.3312
Epoch: 2 	Training Loss: 1.3037
Epoch: 3 	Training Loss: 1.2976
Epoch: 4 	Training Loss: 1.2820
Epoch: 5 	Training Loss: 1.2705
Epoch: 6 	Training Loss: 1.2696
Epoch: 7 	Training Loss: 1.2676
Epoch: 8 	Training Loss: 1.2618
Epoch: 9 	Training Loss: 1.2594
Epoch: 10 	Training Loss: 1.2617
acc: 0.9348, precision: 0.7300, recall: 0.8668, f1: 0.7785
Epoch: 11 	Training Loss: 1.2579
Epoch: 12 	Training Loss: 1.2602
Epoch: 13 	Training Loss: 1.2549
Epoch: 14 	Training Loss: 1.2604
Epoch: 15 	Training Loss: 1.2582
Epoch: 16 	Training Loss: 1.2509
Epoch: 17 	Training Loss: 1.2480
Epoch: 18 	Training Loss: 1.2436
Epoch: 19 	Training Loss: 1.2505
acc: 0.9361, precision: 0.7337, recall: 0.8727, f1: 0.7830
CPU times: user 52.4 s, sys: 5.2 s, total: 57.6 s
Wall time: 44.5 s
