In [1]:

from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from tqdm import tqdm

criterion = nn.BCEWithLogitsLoss()
def calculate_roc_auc_prc(model, data_loader):
    model.eval()
    all_probabilities = []
    all_labels = []
    total_loss = 0
    
    with torch.no_grad():
        for inputs in tqdm(data_loader, leave=False):
            outputs = model(inputs['encoder_input'], inputs['encoder_mask'])
            labels = inputs['label']
            logits = torch.sigmoid(outputs)
            loss = criterion(outputs, labels.float().view(-1,1))
            total_loss += loss.item()
            all_probabilities.append(logits.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    logits_all = np.concatenate(all_probabilities)
    labels_all = np.concatenate(all_labels)
    total_loss = total_loss/len(data_loader)
    
    roc_auc = roc_auc_score(labels_all, logits_all)
    auc_prc = average_precision_score(labels_all, logits_all)
    return roc_auc, auc_prc, total_loss

In [6]:
MAX_LEN = 256
batch_size = 32
d_model = 50
num_heads = 4
N = 2
num_variables = 18 
num_variables += 1 #for no variable embedding while doing padding
d_ff = 100
epochs = 75
learning_rate = 1e-5
drop_out = 0.2
sinusoidal = False
th_val_roc = 0.84
th_val_pr = 0.48
Uniform = True
import torch
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

import pandas as pd
import numpy as np
pd.set_option('future.no_silent_downcasting',True)

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import functional as F

from tqdm import tqdm
from normalizer import Normalizer
from categorizer import Categorizer

train_data_path_inhospital = "/data/datasets/mimic3_18var/root/in-hospital-mortality/train_listfile.csv"
val_data_path_inhospital = "/data/datasets/mimic3_18var/root/in-hospital-mortality/val_listfile.csv"

train_data_path_phenotyping = "/data/datasets/mimic3_18var/root/phenotyping/train_listfile.csv"
val_data_path_phenotyping = "/data/datasets/mimic3_18var/root/phenotyping/val_listfile.csv"

train_data_path_decompensation = "/data/datasets/mimic3_18var/root/decompensation/train_listfile.csv"
val_data_path_decompensation = "/data/datasets/mimic3_18var/root/decompensation/val_listfile.csv"

data_dir_inhospital = "/data/datasets/mimic3_18var/root/in-hospital-mortality/train/"
data_dir_phenotyping = "/data/datasets/mimic3_18var/root/phenotyping/train/"
data_dir_decompensation = "/data/datasets/mimic3_18var/root/decompensation/train/"

import pickle

with open('normalizer.pkl', 'rb') as file:
    normalizer = pickle.load(file)

with open('categorizer.pkl', 'rb') as file:
    categorizer = pickle.load(file)
    

mean_variance = normalizer.mean_var_dict
cat_dict = categorizer.category_dict


train_ds_inhospital = MaskedMimicDataSetInHospitalMortality(data_dir_inhospital, train_data_path_inhospital, mean_variance, cat_dict, 'training', MAX_LEN)
val_ds_inhospital = MaskedMimicDataSetInHospitalMortality(data_dir_inhospital, val_data_path_inhospital, mean_variance, cat_dict, 'validation', MAX_LEN)

train_ds_phenotyping = MaskedMimicDataSetInHospitalMortality(data_dir_phenotyping, train_data_path_phenotyping, mean_variance, cat_dict, 'training', MAX_LEN)
val_ds_phenotyping = MaskedMimicDataSetInHospitalMortality(data_dir_phenotyping, val_data_path_phenotyping, mean_variance, cat_dict, 'validation', MAX_LEN)

train_ds_decompensation = MaskedMimicDataSetInHospitalMortality(data_dir_decompensation, train_data_path_decompensation, mean_variance, cat_dict, 'training', MAX_LEN)
val_ds_decompensation = MaskedMimicDataSetInHospitalMortality(data_dir_decompensation, val_data_path_decompensation, mean_variance, cat_dict, 'validation', MAX_LEN)


train_dataloader_inhospital = DataLoader(train_ds_inhospital, batch_size = batch_size, shuffle=True)
val_dataloader_inhospital = DataLoader(val_ds_inhospital, batch_size = 1)

train_dataloader_phenotyping = DataLoader(train_ds_phenotyping, batch_size = batch_size, shuffle=True)
val_dataloader_phenotyping = DataLoader(val_ds_phenotyping, batch_size = 1)

train_dataloader_decompensation = DataLoader(train_ds_decompensation, batch_size = batch_size, shuffle=True)
val_dataloader_decompensation = DataLoader(val_ds_decompensation, batch_size = 1)



In [7]:
import torch
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import functional as F
import math

def t2v(tau, f, out_features, w, b, w0, b0, arg=None):
    if arg:
        v1 = f(torch.matmul(tau, w) + b, arg)
    else:
        v1 = f(torch.matmul(tau, w) + b)
    v2 = torch.matmul(tau, w0) + b0
    return torch.cat([v1, v2], -1)

class SineActivation(nn.Module):
    def __init__(self, in_features, out_features):
        super(SineActivation, self).__init__()
        self.out_features = out_features
        self.w0 = nn.parameter.Parameter(torch.randn(in_features, 1))
        self.b0 = nn.parameter.Parameter(torch.randn(1))
        self.w = nn.parameter.Parameter(torch.randn(in_features, out_features-1))
        self.b = nn.parameter.Parameter(torch.randn(out_features-1))
        self.f = torch.sin

    def forward(self, tau):
        return t2v(tau.unsqueeze(-1), self.f, self.out_features, self.w, self.b, self.w0, self.b0)

class CosineActivation(nn.Module):
    def __init__(self, in_features, out_features):
        super(CosineActivation, self).__init__()
        self.out_features = out_features
        self.w0 = nn.parameter.Parameter(torch.randn(in_features, 1))
        self.b0 = nn.parameter.Parameter(torch.randn(1))
        self.w = nn.parameter.Parameter(torch.randn(in_features, out_features-1))
        self.b = nn.parameter.Parameter(torch.randn(out_features-1))
        self.f = torch.cos

    def forward(self, tau):
        return t2v(tau.unsqueeze(-1), self.f, self.out_features, self.w, self.b, self.w0, self.b0)


class ContinuousValueEmbedding(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.W = nn.Linear(1, d_model*2)
        self.U = nn.Linear(d_model*2, d_model)
        self.tanh = nn.Tanh()
    def forward(self, x):
        out = self.W(x.unsqueeze(2))
        out = self.tanh(out)
        out = self.U(out)
        return out


class VariableEmbedding(nn.Module):
    def __init__(self, d_model, num_variables):
        super().__init__()
        self.embedding = nn.Embedding(num_variables+1, d_model)
        
    def forward(self, x):
        return self.embedding(x)
    
    

class Embedding(nn.Module):
    def __init__(self, d_model, num_variables, sinusoidal):
        super().__init__()
        self.sinusoidal = sinusoidal
        self.cvs_value = ContinuousValueEmbedding(d_model)
        if sinusoidal:
            self.cvs_time = SineActivation(1, d_model)
        if sinusoidal == "both":
            self.cvs_time = ContinuousValueEmbedding(d_model)
            self.sin_time = SineActivation(1, d_model)
        else:
            self.cvs_time = ContinuousValueEmbedding(d_model)
        self.var_embed = VariableEmbedding(d_model, num_variables)
    def forward(self, encoder_input):
        time = encoder_input[0]
        variable = encoder_input[1]
        value = encoder_input[2]
        if self.sinusoidal == "both":
            time_embed = self.cvs_time(time) + self.sin_time(time)
        else:
            time_embed = self.cvs_time(time)
        embed = time_embed + self.cvs_value(value) + self.var_embed(variable)
        return embed

class Attention(nn.Module):
    def __init__(self, d_model, d, dropout=0.2):
        super().__init__()
        self.d_model = d_model
        self.d = d
        self.Q = nn.Linear(d_model, d)
        self.K = nn.Linear(d_model, d)
        self.V = nn.Linear(d_model, d)
        self.dropout = nn.Dropout(dropout)
    def forward(self,x, mask): 
        q = self.Q(x) 
        k = self.K(x)
        v = self.V(x) 
        weights = q@k.transpose(-2,-1)*k.shape[-1]**(-0.5) 
        weights = weights.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(weights, dim = -1) 
        self.dropout(weights)
        out = weights @ v
        return out 

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout = 0.2):
        super().__init__()
        self.heads = nn.ModuleList([Attention(d_model, d_model//n_heads) for _ in range(n_heads)])
        self.proj = nn.Linear(n_heads*(d_model//n_heads), d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        out = torch.cat([h(x, mask) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedForwardBlock(nn.Module):
    def __init__(self, d_model, d_ff, dropout = 0.2):
        super().__init__()
        self.dropout = nn.Dropout(0.2)
        self.W1 = nn.Linear(d_model, d_ff)
        self.W2 = nn.Linear(d_ff, d_model)
    def forward(self, x):
        out = self.W1(x)
        out = F.relu(out)
        out = self.dropout(self.W2(out))
        return out

class EncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.multi_attention = MultiHeadAttention(d_model, n_heads)
        self.ffb = FeedForwardBlock(d_model, d_ff)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
    
    def forward(self, x, mask):
        out = self.multi_attention(x, mask)
        out1 = x + self.ln2(out)
        out2 = self.ffb(out1)
        out = out1 + self.ln2(out2)
        return out

class Encoder(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, num_variables , N, sinusoidal):
        super().__init__()
        self.embedding = Embedding(d_model, num_variables, sinusoidal)
        self.encoder_blocks = nn.ModuleList([EncoderBlock(d_model, n_heads, d_ff) for _ in range(N)])
        self.N = N
    
    def forward(self, encoder_input, mask):
        time = encoder_input[0]
        variable = encoder_input[1]
        value = encoder_input[2]
        x = self.embedding((time, variable, value))
        for block in self.encoder_blocks:
            x = block(x, mask)
        return x

class FusionSelfAttention(nn.Module):
    def __init__(self, d_model, dropout = 0.2):
        super().__init__()
        self.Wa = nn.Linear(d_model, d_model)
        self.Ua = nn.Linear(d_model, d_model)
        self.Va = nn.Linear(d_model, 1)
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, out, mask):
        q = out.unsqueeze(2) 
        k = out.unsqueeze(1) 
        v = out 
        a = F.tanh(self.Wa(q) + self.Ua(k)) 
        wei = self.Va(self.dropout(a)).squeeze()
        wei = wei.masked_fill(mask == 0, float('-inf'))
        wei = F.softmax(wei, dim = -1)
        wei = self.dropout(wei)
        out = wei@v
        return out
        
class Model(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, num_variables, N, sinusoidal = False):
        super().__init__()
        self.encoder = Encoder(d_model, n_heads, d_ff, num_variables, N, sinusoidal)
        self.fsa = FusionSelfAttention(d_model)
        self.proj = nn.Linear(d_model, 1)
    
    def forward(self, x, mask):
        out = self.encoder(x, mask)
        out = self.fsa(out, mask)
        # out = out.masked_fill(mask.transpose(-2,-1)==0, 0)
        # out = out.sum(dim = 1)
        out = self.proj(out)
        return out.squeeze(-1)

In [11]:
model = Model(d_model, num_heads, d_ff, num_variables, N, sinusoidal).to(DEVICE)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

total_params = sum(p.numel() for p in model.parameters())
print(f'Total number of parameters: {total_params}')

best_val_loss = float('inf')
early_stopping_counter = 0
patience = 10 

def calculate_loss(model, data_loader):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for inputs in data_loader:
            outputs = model(inputs['encoder_input'], inputs['encoder_mask'])
            outputs = torch.where(torch.logical_or(pretraining_mask==0,pretraining_mask==-1), torch.tensor(0.0), outputs)
            labels = torch.where(torch.logical_or(pretraining_mask==0,pretraining_mask==-1), torch.tensor(0.0), batch['labels'])
            loss = criterion(outputs, labels)
            total_loss += loss.item()
    return total_loss/len(data_loader)

for epoch in range(epochs):
    total_loss = 0
    model.train()
    n = 0
    for batch in tqdm(train_dataloader_inhospital, desc=f'Epoch {epoch + 1}/{epochs}', leave=False, mininterval=1):
        inp = batch['encoder_input']
        mask = batch['encoder_mask']
        pretraining_mask = batch['pretraining_mask']
        outputs = model(inp, mask)
        outputs = torch.where(torch.logical_or(pretraining_mask==0,pretraining_mask==-1), torch.tensor(0.0), outputs)
        labels = torch.where(torch.logical_or(pretraining_mask==0,pretraining_mask==-1), torch.tensor(0.0), batch['labels'])
        loss = criterion(outputs, labels)
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        n+=1
        if n%500 == 0:
            val_loss = calculate_loss(model, val_dataloader_inhospital)
            print(f'Epoch {epoch + 1}/{epochs} batches {n}, Validation Loss: {val_loss:.3f}', end='\r')
    val_loss = calculate_loss(model, val_dataloader_inhospital)
    print(f'Epoch {epoch + 1}/{epochs}, Training Loss: {total_loss/len(train_dataloader_inhospital):.3f}')
    print(f'Epoch {epoch + 1}/{epochs}, Validation Loss: {val_loss:.3f}')
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1

    if early_stopping_counter >= patience:
        print(f"Early stopping after {epoch + 1} epochs.")
        break

Total number of parameters: 56990


                                                                                

Epoch 1/75, Training Loss: 0.363
Epoch 1/75, Validation Loss: 0.371


                                                                                

Epoch 2/75, Training Loss: 0.286
Epoch 2/75, Validation Loss: 0.271


                                                                                

Epoch 3/75, Training Loss: 0.263
Epoch 3/75, Validation Loss: 0.312


                                                                                

Epoch 4/75, Training Loss: 0.139
Epoch 4/75, Validation Loss: 0.588


                                                                                

Epoch 5/75, Training Loss: 0.065
Epoch 5/75, Validation Loss: 0.588


                                                                                

Epoch 6/75, Training Loss: 0.058
Epoch 6/75, Validation Loss: 0.637


                                                                                

Epoch 7/75, Training Loss: 0.056
Epoch 7/75, Validation Loss: 0.635


                                                                                

Epoch 8/75, Training Loss: 0.060
Epoch 8/75, Validation Loss: 0.543


                                                                                

Epoch 9/75, Training Loss: 0.051
Epoch 9/75, Validation Loss: 0.543


                                                                                

Epoch 10/75, Training Loss: 0.051
Epoch 10/75, Validation Loss: 0.463


                                                                                

Epoch 11/75, Training Loss: 0.045
Epoch 11/75, Validation Loss: 0.596


                                                                                

Epoch 12/75, Training Loss: 0.043
Epoch 12/75, Validation Loss: 0.526
Early stopping after 12 epochs.


In [None]:
model = Model(d_model, num_heads, d_ff, num_variables, N).to(DEVICE)

In [None]:
out = model(batch['encoder_input'], batch['encoder_mask'])

In [None]:
out.shape

In [None]:
mask = batch['pretraining_mask']

In [None]:
batch['pretraining_mask'].shape

In [None]:
mask = mask.to(DEVICE)

In [None]:
batch['encoder_input'][2].shape

In [None]:
out = torch.where(torch.logical_or(mask==0,mask==-1), torch.tensor(0.0), out)
# output = torch.where((mask==0) or (mask==-1), torch.tensor(0.0), output)

In [None]:
mask

In [None]:
out.shape

In [None]:
torch.logical_or(mask==0, mask==1).shape

In [None]:
out.shape

In [None]:
torch.rand(100) < 0.1

In [None]:
import torch

def mask()