In [None]:
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
import warnings
# Install a filter that ignores every DeprecationWarning so cell output stays readable.
warnings.filterwarnings("ignore", category=DeprecationWarning) 

In [None]:
import sys

from google.colab import drive

# Mount Google Drive inside the Colab filesystem, then put the project folder on
# the module search path so modules stored there can be imported by later cells.
drive.mount('/content/gdrive/')
sys.path.append('/content/gdrive/My Drive/bitnet')


Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

if __name__ == '__main__':
    print('Using device:', device)

Using device: cpu


In [None]:
# !cat '/content/gdrive/My Drive/bitnet/processed/patients_mimic3_full.json'

In [None]:
import os
import data_proc
from data_proc import Dataset

In [None]:
# Build the project's dataset wrapper and load the patient records.
# The class is reached through its module so this cell keeps working after a
# later cell rebinds the bare name `Dataset` to torch.utils.data.Dataset.
dataset = data_proc.Dataset()
data = dataset.load_data()
len(data)

Number of second level category:  170
Length of reverse dictionary  3875


7496

In [None]:
# Display three size attributes of the loaded dataset object as a tuple.
dataset.max_len_visit, dataset.vocabulary_size, dataset.digit3_size

(39, 3875, 170)

In [None]:
# Peek at the first record: field 1 is the one collected as `intervals` in the
# next cell, field 4 the one collected as the `readmission` label.
first_record = data[0]
print(first_record[1])
print(first_record[4])

[1, 4, 31, 56, 61, 90, 105, 109, 114, 146]
1


In [None]:
# Unpack the per-patient records into parallel lists, one list per field.
pids = [record[0] for record in data]
intervals = [record[1] for record in data]
seqs = [record[2] for record in data]

diag = [record[3] for record in data]
readmission = [record[4] for record in data]

# Number of distinct codes that occur anywhere in the visit sequences
# (seqs is iterated as patient -> visit -> code).
num_codes = len({code for visits in seqs for visit in visits for code in visit})

print("num_codes",num_codes)

# All field lists must line up one-to-one; `diag` is checked too so that a
# mismatch in the second label list is caught here rather than during training.
assert len(pids) == len(seqs) == len(intervals) == len(readmission) == len(diag)

num_codes 3434


In [None]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Map-style dataset over four parallel per-sample collections.

    Item ``i`` is the tuple ``(seqs[i], intervals[i], readmission[i], diag[i])``.
    """

    def __init__(self, seqs, intervals, readmission, diag):
        self.seqs, self.intervals = seqs, intervals
        # The two label collections are stored under the names y1 and y2.
        self.y1, self.y2 = readmission, diag

    def __len__(self):
        """Return the number of samples, taken from the first label collection."""
        return len(self.y1)

    def __getitem__(self, index):
        """Return the four fields stored at ``index`` as one tuple."""
        item = (
            self.seqs[index],
            self.intervals[index],
            self.y1[index],
            self.y2[index],
        )
        return item
# NOTE: this rebinds `data` from the loaded record list to the CustomDataset wrapper.
# Cells that index `data` as raw records must be run before this one.
data = CustomDataset(seqs, intervals, readmission, diag)
print(len(data))

7496


In [None]:
from torch.utils.data.dataset import random_split

# Fixed seed so the train/val/test membership is identical on every run.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

# First hold out 10% of all samples as the test set.
train_test_split = int(len(data)*0.9)
lengths = [train_test_split, len(data) - train_test_split]
train_data, test_data = random_split(data, lengths, generator=split_generator)

# Then carve the validation set (11% of what is left) out of the training portion.
train_val_split = int(len(train_data)*0.89)
lengths = [train_val_split, len(train_data) - train_val_split]
train_data, val_data = random_split(train_data, lengths, generator=split_generator)


print("Length of train dataset:", len(train_data))
print("Length of val dataset:", len(val_data))
print("Length of test dataset:", len(test_data))


Length of train dataset: 6003
Length of val dataset: 743
Length of test dataset: 750


In [None]:
def collate_fn(data):
  """Collate a list of dataset items into one batch.

  :param data: list of ``(sequence, interval, label1, label2)`` tuples
  :return: ``(sequences, intervals, y1, y2)`` where ``sequences`` and
      ``intervals`` are tuples holding the per-sample values unchanged and
      ``y1`` / ``y2`` are float tensors built from the two label fields
  """
  sequences, intervals, labels1, labels2 = zip(*data)

  # Only the labels are turned into tensors; sequences and intervals pass through as-is.
  y1 = torch.tensor(labels1, dtype=torch.float)
  y2 = torch.tensor(labels2, dtype=torch.float)

  return sequences, intervals, y1, y2

In [None]:
from torch.utils.data import DataLoader



def load_data(train_data, val_data, test_data, collate_fn, batch_size=32):
    """Wrap the three dataset splits in DataLoaders.

    :param train_data: training split
    :param val_data: validation split
    :param test_data: test split
    :param collate_fn: function used by every loader to assemble a batch
    :param batch_size: number of samples per batch for all three loaders
    :return: ``(train_loader, val_loader, test_loader)``
    """
    # Only the training loader is shuffled; the evaluation loaders keep the
    # dataset order so repeated evaluations visit samples in the same sequence.
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    return train_loader, val_loader, test_loader


# Build the three loaders from the train/val/test splits using the custom collate function.
train_loader, val_loader, test_loader = load_data(train_data, val_data, test_data, collate_fn)

print(num_codes)

3434


In [None]:
import math
class PositionalEncoding(nn.Module):
    """Scale embeddings by sqrt(d_model) and add fixed sinusoidal position codes.

    For position ``pos`` and pair index ``k`` the table holds
    ``sin(pos / 10000^(2k / d_model))`` in column ``2k`` and the cosine of the
    same angle in column ``2k + 1``, so both members of a pair share one frequency.

    :param d_model: embedding width (size of the last dimension of the input)
    :param seq_len: number of positions to precompute; inputs may be this long at most
    """
    def __init__(self, d_model, seq_len) -> None:
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model

        position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)    # [seq_len, 1]
        # One inverse frequency per sin/cos pair: 1 / 10000^(2k / d_model).
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float)
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / d_model))
        angles = position * inv_freq                                        # [seq_len, ceil(d_model / 2)]

        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model the last pair has no cosine column, so it is trimmed.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)    # [1, seq_len, d_model], broadcast over the batch
        # Registered as a buffer: moves with the module but is not a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x) -> torch.Tensor:
        """
        :param x: [N, seq_len, d_model]
        :return: [N, seq_len, d_model]
        """
        seq_len = x.shape[1]
        x = math.sqrt(self.d_model) * x
        x = x + self.pe[:, :seq_len]
        return x


class ResidualBlock(nn.Module):
    """Wrap a sub-layer with dropout, a skip connection and layer normalisation.

    The wrapped module's output goes through dropout, is added back to the
    input, and the sum is layer-normalised. When the wrapped module is an
    ``nn.MultiheadAttention`` it is called as self-attention and the attention
    weights of the most recent call are kept in ``attn_weights``.
    """

    def __init__(self, layer: nn.Module, embed_dim: int, p=0.1) -> None:
        super().__init__()
        self.layer = layer
        self.dropout = nn.Dropout(p=p)
        self.norm = nn.LayerNorm(embed_dim)
        self.attn_weights = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: [N, seq_len, features]
        :return: [N, seq_len, features]
        """
        if not isinstance(self.layer, nn.MultiheadAttention):
            branch = self.layer(x)
        else:
            # The attention layer is fed sequence-first, so swap dims 0 and 1 around the call.
            seq_first = x.transpose(0, 1)       # [seq_len, N, features]
            attended, self.attn_weights = self.layer(seq_first, seq_first, seq_first)
            branch = attended.transpose(0, 1)   # [N, seq_len, features]

        return self.norm(x + self.dropout(branch))


class PositionWiseFeedForward(nn.Module):
    """Two kernel-size-1 convolutions with a ReLU in between.

    The channel width is expanded to ``2 * hidden_size`` and projected back to
    ``hidden_size``; input and output both have the feature axis last.
    """

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size

        expanded = hidden_size * 2
        self.conv = nn.Sequential(
            nn.Conv1d(hidden_size, expanded, 1),
            nn.ReLU(),
            nn.Conv1d(expanded, hidden_size, 1),
        )

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # Conv1d treats dim 1 as channels, so dims 1 and 2 are swapped before and after.
        channels_first = tensor.transpose(1, 2)
        return self.conv(channels_first).transpose(1, 2)


class EncoderBlock(nn.Module):
    """One encoder layer: a residual self-attention block followed by a residual feed-forward block."""

    def __init__(self, embed_dim: int, num_head: int, dropout_rate=0.1) -> None:
        super().__init__()
        self_attention = nn.MultiheadAttention(embed_dim, num_head)
        self.attention = ResidualBlock(self_attention, embed_dim, p=dropout_rate)

        feed_forward = PositionWiseFeedForward(embed_dim)
        self.ffn = ResidualBlock(feed_forward, embed_dim, p=dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention first, then the position-wise feed-forward network.
        return self.ffn(self.attention(x))


class DenseInterpolation(nn.Module):
    """Compress the time axis to ``factor`` steps with a fixed weight matrix.

    The weight for output step ``m`` (1-based) and time step ``t`` (1-based) is
    ``(1 - |factor * t / seq_len - m| / factor) ** 2``. The matrix is stored as
    the buffer ``W`` with shape ``[1, factor, seq_len]``.
    """

    def __init__(self, seq_len: int, factor: int) -> None:
        """
        :param seq_len: sequence length
        :param factor: factor M
        """
        super().__init__()

        def weight(m: int, t: int):
            # Scalar float32 arithmetic for one (m, t) cell of the matrix.
            s = np.array((factor * (t + 1)) / seq_len, dtype=np.float32)
            base = np.array(1 - (np.abs(s - (1 + m)) / factor), dtype=np.float32)
            return np.power(base, 2, dtype=np.float32)

        rows = [[weight(m, t) for t in range(seq_len)] for m in range(factor)]
        W = np.array(rows, dtype=np.float32).reshape(factor, seq_len)

        self.register_buffer("W", torch.tensor(W).float().unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One copy of the weight matrix per batch element, then a batched matmul.
        batch_weights = self.W.repeat(x.shape[0], 1, 1)
        pooled = torch.bmm(batch_weights, x)    # [N, factor, features]
        return pooled.transpose(1, 2)           # [N, features, factor]


class ClassificationModule(nn.Module):
    """Linear classification head applied to the flattened ``d_model * factor`` features."""

    def __init__(self, d_model: int, factor: int, num_class: int) -> None:
        super().__init__()
        self.d_model, self.factor, self.num_class = d_model, factor, num_class

        flat_features = int(d_model * factor)
        self.fc = nn.Linear(flat_features, num_class)

        # Re-initialise the head. The bias call only sets the mean to 0 and keeps
        # normal_'s default std, so the bias is drawn at random rather than zeroed.
        # NOTE(review): confirm a random bias is intended here.
        nn.init.normal_(self.fc.weight, std=0.02)
        nn.init.normal_(self.fc.bias, mean=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten everything but the batch axis into one feature vector per sample.
        flat = x.contiguous().view(-1, int(self.factor * self.d_model))
        return self.fc(flat)


class RegressionModule(nn.Module):
    """Linear regression head applied to the flattened ``d_model * factor`` features."""

    def __init__(self, d_model: int, factor: int, output_size: int) -> None:
        super().__init__()
        self.d_model, self.factor, self.output_size = d_model, factor, output_size

        flat_features = int(d_model * factor)
        self.fc = nn.Linear(flat_features, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten everything but the batch axis into one feature vector per sample.
        flat = x.contiguous().view(-1, int(self.factor * self.d_model))
        return self.fc(flat)


In [None]:
class EncoderLayerForSAnD(nn.Module):
    """Input embedding, positional encoding and a stack of encoder blocks.

    :param input_features: size of the last dimension of the input
    :param seq_len: number of positions the positional encoding is built for
    :param n_heads: attention heads per encoder block
    :param n_layers: number of stacked encoder blocks
    :param d_model: width of the embedded representation
    :param dropout_rate: dropout probability passed to every encoder block
    """
    def __init__(self, input_features, seq_len, n_heads, n_layers, d_model: int = 128, dropout_rate=0.2) -> None:
        super(EncoderLayerForSAnD, self).__init__()
        self.d_model = d_model

        # A kernel-size-1 convolution acts as a per-position linear embedding.
        self.input_embedding = nn.Conv1d(input_features, d_model, 1)
        self.positional_encoding = PositionalEncoding(d_model, seq_len)
        self.blocks = nn.ModuleList([
            EncoderBlock(d_model, n_heads, dropout_rate) for _ in range(n_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: [N, seq_len, input_features]
        :return: [N, seq_len, d_model]
        """
        # Conv1d expects channels in dim 1, so move the feature axis there and back.
        x = x.transpose(1, 2)
        x = x.type(torch.float)
        x = self.input_embedding(x)
        x = x.transpose(1, 2)

        x = self.positional_encoding(x)

        for block in self.blocks:
            x = block(x)

        return x


class SAnD(nn.Module):
    """
    Simply Attend and Diagnose model

    The Thirty-Second AAAI Conference on Artificial Intelligence (AAAI-18)

    `Attend and Diagnose: Clinical Time Series Analysis Using Attention Models <https://arxiv.org/abs/1711.03905>`_
    Huan Song, Deepta Rajan, Jayaraman J. Thiagarajan, Andreas Spanias

    :param input_features: size of the last dimension of each input sequence
    :param seq_len: sequence length the encoder and interpolation are built for
    :param n_heads: attention heads per encoder block
    :param factor: dense interpolation factor
    :param n_class: number of output classes
    :param n_layers: number of encoder blocks
    :param d_model: width of the encoder representation
    :param dropout_rate: dropout probability used inside the encoder
    """
    def __init__(
            self, input_features: int, seq_len: int, n_heads: int, factor: int,
            n_class: int, n_layers: int, d_model: int = 128, dropout_rate: float = 0.2
    ) -> None:
        super(SAnD, self).__init__()
        # Kept on the instance so forward() does not depend on a notebook-level global.
        self.n_class = n_class
        self.encoder = EncoderLayerForSAnD(input_features, seq_len, n_heads, n_layers, d_model, dropout_rate)
        self.dense_interpolation = DenseInterpolation(seq_len, factor)
        self.clf = ClassificationModule(d_model, factor, n_class)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: batch of sequences, given as a tensor or as nested lists/tuples
            that convert to one; expected layout [N, seq_len, input_features]
        :return: per-class probabilities, [N, n_class]
        """
        # Accept nested Python sequences as well as tensors of any dtype, and place
        # the batch on the same device as the model's parameters.
        x = torch.as_tensor(x, dtype=torch.long, device=next(self.parameters()).device)
        batch_size = x.shape[0]

        x = self.encoder(x)
        x = self.dense_interpolation(x)
        x = self.clf(x)

        # Independent sigmoid per class: the outputs are not normalised across classes.
        probs = self.sigmoid(x)

        return probs.view((batch_size, self.n_class))


# Hyper-parameters of the SAnD instance used by the rest of the notebook.
in_feature = 39    # same value as the max_len_visit displayed earlier (assumed source; confirm)
seq_len = 10
n_heads = 32
factor = 32
num_class = 170    # same value as the digit3_size displayed earlier (assumed source; confirm)
num_layers = 6

s_model = SAnD(
    input_features=in_feature,
    seq_len=seq_len,
    n_heads=n_heads,
    factor=factor,
    n_class=num_class,
    n_layers=num_layers,
)

In [None]:
import torch.optim as optim

# Binary cross-entropy on the model's sigmoid outputs, optimised with Adam.
criterion = nn.BCELoss()
optimizer = optim.Adam(params=s_model.parameters(), lr=1e-3)


In [None]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, precision_recall_curve, auc
from sklearn.metrics import top_k_accuracy_score
def eval_model(model, val_loader):
    """Evaluate ``model`` on ``val_loader`` with micro-averaged metrics.

    Predictions are the model outputs thresholded at 0.5 and are compared with
    the last element of every batch.

    :param model: module called as ``model(x0)`` on the first element of each batch
    :param val_loader: iterable yielding 4-tuples ``(x0, x1, y0, y1)``
    :return: ``(precision, recall, f_score)``
    :raises ValueError: if ``val_loader`` yields no batches
    """
    model.eval()
    pred_batches, true_batches = [], []

    # Inference only: no autograd graph is needed while scoring.
    with torch.no_grad():
        for x0, x1, y0, y1 in val_loader:
            y_hat = model(x0)
            # Move everything to the CPU so batches concatenate regardless of the
            # model's device and can be handed to scikit-learn.
            pred_batches.append((y_hat > 0.5).int().cpu())
            true_batches.append(y1.detach().cpu())

    if not pred_batches:
        raise ValueError("val_loader produced no batches; nothing to evaluate")

    y_pred = torch.cat(pred_batches, dim=0)
    y_true = torch.cat(true_batches, dim=0)

    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average="micro")
    return p, r, f

In [None]:
def train(model, train_loader, val_loader, n_epochs):
    """Train ``model`` and report validation metrics after every epoch.

    Relies on the notebook-level ``optimizer`` and ``criterion`` objects and on
    ``eval_model``. The loss target is the last element of each batch.

    :param model: module called as ``model(x0)`` on the first element of each batch
    :param train_loader: iterable yielding 4-tuples ``(x0, x1, y0, y1)``
    :param val_loader: loader passed to ``eval_model`` after each epoch
    :param n_epochs: number of passes over ``train_loader``
    """
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x0, x1, y0, y1 in train_loader:
            optimizer.zero_grad()
            y_pred = model(x0)
            # Match the target's shape explicitly; a bare squeeze() would also drop
            # the batch dimension whenever a batch holds a single sample.
            loss = criterion(y_pred.reshape(y1.shape), y1)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        # Mean loss per batch; the max() guard avoids dividing by zero on an empty loader.
        train_loss = train_loss / max(len(train_loader), 1)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f = eval_model(model, val_loader)
        print('Epoch: {} \t Validation p: {:.4f}, r:{:.4f}, f: {:.4f}'
              .format(epoch+1, p, r, f))

In [None]:
# Run the training loop for a fixed number of epochs.
n_epochs = 30
train(model=s_model, train_loader=train_loader, val_loader=val_loader, n_epochs=n_epochs)

  if __name__ == '__main__':


Epoch: 1 	 Training Loss: 0.235545
Epoch: 1 	 Validation p: 0.5615, r:0.2045, f: 0.2797
Epoch: 2 	 Training Loss: 0.170191
Epoch: 2 	 Validation p: 0.5550, r:0.2457, f: 0.3080
Epoch: 3 	 Training Loss: 0.167281
Epoch: 3 	 Validation p: 0.6416, r:0.2216, f: 0.3300
Epoch: 4 	 Training Loss: 0.165589
Epoch: 4 	 Validation p: 0.5279, r:0.2335, f: 0.3546
Epoch: 5 	 Training Loss: 0.164378
Epoch: 5 	 Validation p: 0.6244, r:0.1671, f: 0.2340
Epoch: 6 	 Training Loss: 0.163228
Epoch: 6 	 Validation p: 0.5682, r:0.2907, f: 0.3489
Epoch: 7 	 Training Loss: 0.162273
Epoch: 7 	 Validation p: 0.6050, r:0.1558, f: 0.2807
Epoch: 8 	 Training Loss: 0.161062
Epoch: 8 	 Validation p: 0.6006, r:0.1722, f: 0.2578
Epoch: 9 	 Training Loss: 0.160461
Epoch: 9 	 Validation p: 0.6039, r:0.2143, f: 0.3324
Epoch: 10 	 Training Loss: 0.159784
Epoch: 10 	 Validation p: 0.5402, r:0.2898, f: 0.3799
Epoch: 11 	 Training Loss: 0.158937
Epoch: 11 	 Validation p: 0.6575, r:0.2149, f: 0.3064
Epoch: 12 	 Training Loss: 0