# Tensor Fusion Network w/ Decomposition & Cox Survival Analysis

## Initial Imports

In [1]:
import torch
from torch.autograd import Variable
from torchvision import models
import sys
import numpy as np
import torchvision
import torch.nn as nn
import torch.optim as optim
import argparse
import time
import tensorly as tl
from tensorly.decomposition import partial_tucker
from decompositions import cp_decomposition_conv_layer, tucker_decomposition_conv_layer

from VBMF import VBMF

from torch.utils.data import Dataset, DataLoader, random_split

from torch.nn.utils.rnn import pad_sequence

import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.utils.rnn as rnn_utils
import torch.nn.functional as F
from torchvision.io import read_image
from PIL import Image
import glob
import os

import matplotlib.pyplot as plt
plt.ion()   # interactive mode

import pandas as pd
import numpy as np
from glob import glob
import os, os.path
import matplotlib.pyplot as plt
from numpy import asarray

import sklearn.metrics
from sklearn.metrics import accuracy_score, mean_squared_error, confusion_matrix, roc_auc_score
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import train_test_split

from sksurv.datasets import load_gbsg2
# from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import concordance_index_censored, concordance_index_ipcw, integrated_brier_score

import itertools
from itertools import *

import datetime
import pickle

import torchtuples as tt
from pycox.models import CoxCC
from pycox.utils import kaplan_meier
from pycox.evaluation import EvalSurv

from ptflops import get_model_complexity_info
import torchprofile

#### Apply Configuration Changes

In [2]:
# Enable autograd anomaly detection for the whole session (debugging aid that
# slows training down). `torch.autograd.detect_anomaly(...)` only builds a
# context manager; unless it is entered with `with`, the mode is never switched
# on. `set_detect_anomaly` applies the setting immediately.
torch.autograd.set_detect_anomaly(True)

  torch.autograd.detect_anomaly(True)


<torch.autograd.anomaly_mode.detect_anomaly at 0x7fb564191790>

## Determine Compute Device

In [3]:
# Use a GPU if possible. GPU 1 is preferred because, at time of writing, all
# available memory on GPU 0 was in use. On hosts with a single GPU "cuda:1"
# does not exist, so fall back to GPU 0 there, and to the CPU without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Use PyTorch as Tensorly Backend
tl.set_backend('pytorch')

# Force CPU Evaluation (Not Recommended)
# device = torch.device("cpu")


# Utility Functions

In [4]:
def estimate_ranks(input_tensor):
    """Estimate one rank per mode of a 3D tensor using VBMF.

    The leading singleton dimension is squeezed away, the tensor is unfolded
    along modes 0, 1 and 2, and `VBMF.EVBMF` is run on every unfolding.

    Args:
        input_tensor: tensor whose first dimension can be squeezed so that
            three modes remain. It is NOT modified: all numerical safeguards
            are applied to a private copy.

    Returns:
        list of three ints, one estimated rank per mode.
    """
    # squeeze() returns a view, so the in-place edits below used to write
    # through to the caller's tensor. Work on a detached copy instead; rank
    # estimation does not need gradients.
    input_tensor = input_tensor.detach().squeeze(0).clone()

    input_tensor[input_tensor == 0] = 1e-6  # This keeps zeros from causing numerical instability

    print("Input Tensor Shape:", input_tensor.size())
    print("Any NaNs in tensor:", torch.isnan(input_tensor).any())
    print("Any zeros in tensor:", (input_tensor == 0).any())

    # Unfold along each mode and add a small epsilon for numerical stability.
    # The addition is out-of-place: an unfolding may share memory with
    # `input_tensor`, and `+=` would then leak the offset into later unfoldings.
    epsilon = 1e-8
    unfold_0 = tl.base.unfold(input_tensor, 0) + epsilon  # Unfold along mode-0
    unfold_1 = tl.base.unfold(input_tensor, 1) + epsilon  # Unfold along mode-1
    unfold_2 = tl.base.unfold(input_tensor, 2) + epsilon  # Unfold along mode-2

    # Apply VBMF to estimate the ranks for each unfolded matrix
    _, diag_0, _, _ = VBMF.EVBMF(unfold_0)
    _, diag_1, _, _ = VBMF.EVBMF(unfold_1)
    _, diag_2, _, _ = VBMF.EVBMF(unfold_2)

    # The rank of each mode is read from the size of the returned diagonal.
    # NOTE(review): mixing shape[0] and shape[1] presumably relies on the
    # diagonal matrices being square -- confirm against VBMF.EVBMF.
    ranks = [diag_0.shape[0], diag_1.shape[1], diag_2.shape[1]]
    return ranks

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Draw a confusion matrix as an annotated heat map on the current figure.

    Args:
        cm: 2D array of counts, indexed as cm[true, predicted].
        classes: sequence of class names used for both axes' tick labels.
        normalize: when True, every row is divided by its sum so each cell
            shows the fraction of that true class (rows that sum to zero are
            left at zero). Previously this flag was accepted but ignored.
        title: kept for interface compatibility; the title call is currently
            disabled, so it is not drawn.
        cmap: matplotlib colormap for the heat map.
    """
    if normalize:
        cm = np.asarray(cm, dtype=float)
        row_totals = cm.sum(axis=1, keepdims=True)
        row_totals[row_totals == 0] = 1.0  # avoid 0/0 for classes with no samples
        cm = cm / row_totals

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    #plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Cells darker than half the maximum get white text so they stay readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        cell_text = format(cm[i, j], '.2f') if normalize else cm[i, j]
        plt.text(j, i, cell_text,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    
def calculate_flop(net, input_res=(3, 224, 224)):
    """Print the model's complexity as reported by ptflops and return it.

    Args:
        net: model handed to `get_model_complexity_info`.
        input_res: input resolution passed to ptflops (defaults to the
            3x224x224 image size used in this notebook).

    Returns:
        The MAC count string returned by ptflops (for example "15.5 GMac").
    """
    macs, params = get_model_complexity_info(net, input_res, as_strings=True, backend='aten', \
                                             print_per_layer_stat=False, verbose=False)

    # ptflops formats the count as "<value> <unit>" (unit e.g. GMac/MMac/KMac/Mac).
    # Parse both parts instead of slicing off five characters, which dropped the
    # magnitude prefix and cut into the number for a plain " Mac" unit.
    value_text, _, unit = macs.strip().partition(' ')
    flop_unit = unit.replace('Mac', 'FLOP') if unit else 'FLOP'
    # One multiply-accumulate is counted as two floating point operations.
    flops = "{} {}".format(float(value_text) * 2, flop_unit)

    print('{:<30}  {:<8}'.format('Computational Complexity: ', macs))
    print('{:<30}  {:<8}'.format('Approximate FLOP: ', flops))
    # print('{:<30}  {:<8}'.format('Number of parameters: ', params))

    return macs

# Declare TFN Constants

In [5]:
# Number of features per time step fed to the time-series autoencoder
# (used as its `input_size`).
AUTOENCODER_INPUTS = 4

## Import Data

### Deserialize AD Dataset

In [6]:
# Show every column when a DataFrame is displayed.
pd.options.display.max_columns = None

# Load the patient manifest (one row per patient, with the path to that
# patient's pickled cohort). `reset_index` returns a new frame rather than
# editing in place, so its result must be assigned to take effect.
ad_patient_df = pd.read_csv("AD_Patient_Manifest.csv").reset_index(drop=True)

ad_patient_df

Unnamed: 0,PTID,path
0,053_S_1044,patients_csv/053_S_1044.pkl
1,035_S_0204,patients_csv/035_S_0204.pkl
2,027_S_0256,patients_csv/027_S_0256.pkl
3,128_S_0230,patients_csv/128_S_0230.pkl
4,114_S_0173,patients_csv/114_S_0173.pkl
...,...,...
810,029_S_0843,patients_csv/029_S_0843.pkl
811,129_S_1204,patients_csv/129_S_1204.pkl
812,032_S_1169,patients_csv/032_S_1169.pkl
813,099_S_0060,patients_csv/099_S_0060.pkl


### Define Patient Dataset and Dataloader

#### Patient Dataset

In [7]:
# CustomDataset gets ADNI cohorts for a NN.
class Patient_Dataset(Dataset):
    """Dataset yielding one patient (image, demographics, time series) per manifest row.

    Each manifest row holds a 'path' to a pickled cohort DataFrame. The image,
    labels and demographics are taken from the cohort's first row; the time
    series uses every row.

    NOTE(review): cohorts are read with `pd.read_pickle`, which executes
    arbitrary code for untrusted files -- only use trusted manifests.
    """

    def __init__(self, dataframe, transform=None):
        """Store the manifest and the optional image transform."""
        self.transform = transform  # Callable applied to the PIL image, if given.
        self.manifest = dataframe

    def __len__(self):
        """Number of patients, i.e. manifest rows."""
        return len(self.manifest)

    def __getitem__(self, index):
        """Return `((image, demographics, time_series), (mmse, dx))` for one patient."""
        cohort_path = self.manifest.iloc[index]['path']  # Pickle of the requested patient.
        cohort = pd.read_pickle(cohort_path)
        first_row = cohort.iloc[0]

        image = Image.open(first_row['image_path'])
        if self.transform:
            image = self.transform(image)

        # Label tuple is (MMSE, DX_encoded).
        # NOTE(review): the earlier comment said the Cox model expects
        # (duration, event) -- confirm this ordering is what consumers want.
        label = (
            torch.tensor(first_row['MMSE'], dtype=torch.float32),
            torch.tensor(first_row['DX_encoded'], dtype=torch.float32),
        )

        demographics = torch.tensor(first_row['one_hot_vector'], dtype=torch.float32)

        # One row per visit, four features per row.
        feature_columns = ['Years_bl', 'ADAS11', 'ADAS13', 'ADASQ4']
        time_series_tensor = torch.tensor(cohort[feature_columns].values, dtype=torch.float32)

        return (image, demographics, time_series_tensor), label

#### Split AD Patient Dataset

In [8]:
# Split the AD patient manifest 80-20% for training and testing.
# A fixed random_state keeps the sampled training rows the same on every run.
train = ad_patient_df.sample(frac=0.8,random_state=200)
test = ad_patient_df.drop(train.index)  # Every row that was not sampled into `train`.

print(len(train))
print(len(test))
# train
# train.iloc[0]

652
163


#### Define Patient Dataloader

In [9]:
# Data augmentation and normalization for training
# Just normalization for validation
# Will need to be applied by passing in to Dataset constructor!
# Both pipelines first expand the image to 3 channels, matching the
# (3, 224, 224) input size used elsewhere in this notebook.
data_transforms = {
    'train': transforms.Compose([
        torchvision.transforms.Grayscale(num_output_channels=3),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
#         transforms.Normalize((0.5), (0.5))
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        torchvision.transforms.Grayscale(num_output_channels=3),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
#         transforms.Normalize((0.5), (0.5))
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# NOTE(review): dl_args is defined but never passed to the DataLoaders below,
# so they run with DataLoader's defaults (batch_size=1, no shuffling).
dl_args = dict(batch_size=16, num_workers=4)

# Dataloaders used to iterate through the patients. Patients split 80-20% into train-test loaders.
train_dataset = Patient_Dataset(train, data_transforms['train'])
# Redundant with the constructor argument above; assigns the same transform again.
train_dataset.transform = data_transforms['train']
train_dataloader = DataLoader(train_dataset)

test_dataset = Patient_Dataset(test, data_transforms['test'])
# Redundant with the constructor argument above; assigns the same transform again.
test_dataset.transform = data_transforms['test']
test_dataloader = DataLoader(test_dataset)

dataloaders = {'train': train_dataloader,
              'test': test_dataloader,
              }

# Preview: print the first five training entries.
i = 0
for patient in dataloaders['train']:
    i = i + 1
    if i > 5:
        break
    print(patient)

[[tensor([[[[-1.8953, -1.8782, -1.8439,  ..., -2.1179, -2.1179, -2.1179],
          [-1.3815, -1.3302, -1.2788,  ..., -2.1179, -2.1179, -2.1179],
          [-0.8507, -0.7822, -0.7137,  ..., -2.1179, -2.1179, -2.1179],
          ...,
          [-0.5253, -0.5424, -0.5938,  ..., -2.1179, -2.1179, -2.1179],
          [-0.6623, -0.6623, -0.7137,  ..., -2.1179, -2.1179, -2.1179],
          [-0.7650, -0.7822, -0.8507,  ..., -2.1179, -2.1179, -2.1179]],

         [[-1.8081, -1.7906, -1.7556,  ..., -2.0357, -2.0357, -2.0357],
          [-1.2829, -1.2304, -1.1779,  ..., -2.0357, -2.0357, -2.0357],
          [-0.7402, -0.6702, -0.6001,  ..., -2.0357, -2.0357, -2.0357],
          ...,
          [-0.4076, -0.4251, -0.4776,  ..., -2.0357, -2.0357, -2.0357],
          [-0.5476, -0.5476, -0.6001,  ..., -2.0357, -2.0357, -2.0357],
          [-0.6527, -0.6702, -0.7402,  ..., -2.0357, -2.0357, -2.0357]],

         [[-1.5779, -1.5604, -1.5256,  ..., -1.8044, -1.8044, -1.8044],
          [-1.0550, -1.0027,

### Time Series Dataset & Dataloader
!!! Used for training autoencoder only !!!

#### Custom Collate Function

In [10]:
# collate_fn handles padding on inputs.
def collate_fn(batch):
    """Collate `(sequence, label)` pairs into a padded batch.

    Sequences are zero-padded at the end to the longest one in the batch
    (batch-first layout); labels are stacked along a new leading dimension,
    so all labels in a batch must share one shape.

    Returns:
        (padded_sequences, labels)
    """
    inputs = [pair[0] for pair in batch]
    targets = [pair[1] for pair in batch]

    padded_inputs = pad_sequence(inputs, batch_first=True, padding_value=0.0)
    stacked_targets = torch.stack(targets)

    return padded_inputs, stacked_targets

#### Time Series Dataset

In [11]:
# Initialize encoders
# NOTE(review): neither encoder is referenced in the cells visible here --
# confirm they are still needed further down the notebook.
onehot_encoder = OneHotEncoder(sparse_output=False)  # Returns dense arrays rather than sparse matrices.
label_encoder = LabelEncoder()

class InHospitalMortalityDataset(Dataset):
    """Time-series-only dataset used to train the autoencoder.

    Each item is the patient's time series returned twice, as both the input
    and the reconstruction target.
    """

    # Columns of the pickled cohort that make up one time step.
    _FEATURE_COLUMNS = ['Years_bl', 'ADAS11', 'ADAS13', 'ADASQ4']

    def __init__(self, dataframe, transform=None):
        """Store the manifest; `transform` is kept but not applied by this class."""
        self.transform = transform
        self.manifest = dataframe

    def __getitem__(self, index):
        """Load the cohort pickle for row `index` and return `(series, series)`."""
        cohort = pd.read_pickle(self.manifest.iloc[index]['path'])
        series = torch.tensor(cohort[self._FEATURE_COLUMNS].values, dtype=torch.float32)
        # Input and target are the very same tensor object.
        return series, series

    def __len__(self):
        """Number of patients, i.e. manifest rows."""
        return len(self.manifest)

#### Time Series Dataloader

In [12]:
# NOTE(review): these two paths are not used in the cells visible here, and
# train_files_path is an absolute, machine-specific location -- confirm usage.
train_listfile_path = 'time_series_list.csv'
train_files_path = '/home/mason/TFN/patient_time_series'
batch_size = 1

# The autoencoder dataset is built from the full manifest (no train/test split).
time_train_dataset = InHospitalMortalityDataset(ad_patient_df, None)

time_train_loader = DataLoader(dataset=time_train_dataset, batch_size=batch_size, shuffle=True, \
                               collate_fn=collate_fn)

# Preview: print the first five (padded) time series.
i = 0
for time_series, label in time_train_loader:
    i = i + 1
    if i > 5:
        break
    print(time_series)

tensor([[[ 0.0000, 13.3300, 21.3300,  5.0000],
         [ 0.4956, 10.6700, 16.6700,  5.0000],
         [ 1.0157, 12.6700, 21.6700,  9.0000],
         [ 1.6290, 19.6700, 24.6700,  5.0000],
         [ 2.0233, 18.3300, 24.3300,  6.0000],
         [ 3.3949, 18.0000, 24.0000,  6.0000],
         [ 4.5832, 31.0000, 45.0000,  9.0000],
         [ 5.9795, 33.0000, 44.0000, 10.0000],
         [ 7.1294, 40.0000, 55.0000, 10.0000],
         [ 8.0274, 52.0000, 66.0000, 10.0000]]])
tensor([[[ 0.0000, 10.6700, 21.6700,  9.0000],
         [ 0.4600, 14.3300, 25.3300, 10.0000],
         [ 0.9172, 17.0000, 27.0000,  9.0000],
         [ 1.5825, 15.3300, 26.3300,  9.0000],
         [ 1.9932, 21.0000, 33.0000, 10.0000]]])
tensor([[[ 0.0000,  7.6700, 12.6700,  4.0000],
         [ 0.5503,  5.3300,  9.3300,  3.0000],
         [ 1.0075,  5.6700,  9.6700,  4.0000],
         [ 2.0233,  4.6700,  6.6700,  2.0000],
         [ 2.9815,  5.3300,  7.3300,  2.0000],
         [ 4.0192,  5.6700,  7.6700,  2.0000],
         

## Define Models

The Fusion Network will fuse the embeddings of three input models---X-Ray, Demographics, and Time Series Data.

### X-Ray Embedder

(VGG16)

In [13]:
class ModifiedVGG16Model(torch.nn.Module):
    """VGG16 convolutional features followed by a new embedding and classifier head.

    `features` comes from the backbone; `embedder` (two fully connected
    layers) and `classifier` (4096 -> 2 with sigmoid) are created here with
    fresh weights.

    Args:
        model: optional backbone exposing a `.features` module. When omitted,
            torchvision's VGG16 with 'IMAGENET1K_V1' weights is loaded.
            (Previously this argument was accepted but silently ignored.)
    """

    def __init__(self, model=None):
        super(ModifiedVGG16Model, self).__init__()

        if model is None:
            model = models.vgg16(weights='IMAGENET1K_V1')
        self.features = model.features

        # When embedding, we only want the output of the fully connected
        # layers below. The first layer expects the flattened feature map to
        # have exactly 25088 values per sample.
        self.embedder = nn.Sequential(
            nn.Dropout(),
            nn.Linear(25088, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True)) # This last ReLU layer may be unnecessary

        # Final layer of the tail: maps the 4096-dim embedding to 2 sigmoid outputs.
        self.classifier = nn.Sequential(
            nn.Linear(4096, 2),
            nn.Sigmoid())

    def embed(self, x):
        """Return the 4096-dimensional embedding for a batch of images."""
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten everything except the batch dimension.
        x = self.embedder(x)
        # May need nn.flatten()
        return x

    def forward(self, x):
        """Return the classifier output computed on top of `embed(x)`."""
        return self.classifier(self.embed(x))

### Demographics Embedder

Ingests:
 - Age
 - Gender
 - Marital Status
 - Ethnicity

All words will need to be reshaped to a common length (set by the longest word/string) and will then be concatenated, so that the final dimension of the input tensor (representing the concatenated words) is known.

In [14]:
class DemographicsEmbedder(torch.nn.Module):
    """1D-convolutional embedder for the demographics vector.

    The input is reshaped to `[N, 1, 30]`, passed through Conv1d + MaxPool1d
    and two fully connected layers to give a 10-dimensional embedding.
    `forward` additionally applies a linear regressor to produce one value.

    The input may hold one sample (30 values) or a batch (`N * 30` values);
    the batch size was previously hard-coded to 1.

    Args:
        model: unused; kept for interface compatibility.
    """

    # Length of one demographics vector. The layer sizes depend on it:
    # Conv1d(kernel 3) gives 28 positions, MaxPool1d(2) gives 14, and
    # 10 channels x 14 positions = the 140 inputs of the first Linear layer.
    INPUT_LENGTH = 30

    def __init__(self, model=None):
        super(DemographicsEmbedder, self).__init__()

#         wordvec_model = gensim.models.KeyedVectors.load_word2vec_format('path/to/file')
#         wordvec_weights = torch.FloatTensor(wordvecmodel.vectors) # formerly syn0, which is soon deprecated

        self.features = nn.Sequential(
            nn.Conv1d(1, 10, 3), # Fix input size!
            nn.MaxPool1d(2))

        self.embedder = nn.Sequential(
            nn.Dropout(),
            nn.Flatten(),
            nn.Linear(140, 70),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(70, 10),
            nn.ReLU(inplace=True))

        # FIXME: Update classifier to handle embedding and Conv1D output.
        self.regressor = nn.Sequential(
            nn.Linear(10, 1))

        # Use one-hot encoding before embedding.
        # After can use embedder for NN.
#         self.embedder = nn.Sequential(
#             nn.Embedding.from_pretrained(wordvec_weights))

    def embed(self, demographics):
        """Return the 10-dimensional embedding, shape `[N, 10]`."""
        # Reshape to the [N, C, L] format expected by Conv1D; -1 infers N.
        x = demographics.reshape(-1, 1, self.INPUT_LENGTH)
        x = self.features(x)
        x = self.embedder(x)
        return x

    def forward(self, demographics):
        """Return the regressor output for the embedding, shape `[N, 1]`."""
        return self.regressor(self.embed(demographics))

In [15]:
# Test Demographics Embedder
def test_demographics():
    """Smoke test: print an untrained DemographicsEmbedder's output for five training patients."""
    test_demo_embedder = DemographicsEmbedder()

    for seen, entry in enumerate(dataloaders['train']):
        if seen >= 5:
            break
        # entry is (patient, label); patient is (image, demographics, time_series).
        patient, label = entry[0], entry[1]
        print(test_demo_embedder.forward(patient[1]))

test_demographics()

tensor([[-0.0095]], grad_fn=<AddmmBackward0>)
tensor([[0.0092]], grad_fn=<AddmmBackward0>)
tensor([[-0.1246]], grad_fn=<AddmmBackward0>)
tensor([[-0.1113]], grad_fn=<AddmmBackward0>)
tensor([[-0.0929]], grad_fn=<AddmmBackward0>)


### Time Series Data Autoencoder

#### Define Autoencoder

In [16]:
class Time_Series_Autoencoder(nn.Module):
    """LSTM autoencoder for padded, variable-length time series.

    The encoder LSTM's final hidden state is projected to a latent vector.
    The decoder LSTM starts from a hidden state derived from that latent
    vector (zero cell state), reads the same packed input, and a linear layer
    maps its outputs back to the input feature size.
    """

    def __init__(self, input_size, hidden_size, latent_size):
        super(Time_Series_Autoencoder, self).__init__()
        self.encoder = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.hidden_to_latent = nn.Linear(hidden_size, latent_size)
        self.latent_to_hidden = nn.Linear(latent_size, hidden_size)
        self.decoder = nn.LSTM(input_size, hidden_size, batch_first=True)
#         self.decoder = nn.LSTM(hidden_size, input_size, batch_first=True)
        self.output_layer = nn.Linear(hidden_size, input_size)  # Additional final linear layer

    def _pack_and_encode(self, x, lengths):
        """Pack the padded batch and return `(packed_input, latent)`."""
        packed_input = rnn_utils.pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False)
        _, (final_hidden, _) = self.encoder(packed_input)
        # final_hidden[-1] is the hidden state of the last LSTM layer.
        return packed_input, self.hidden_to_latent(final_hidden[-1])

    def forward(self, x, lengths):
        """Return `(latent, decoded)` for a padded batch `x` with true `lengths`."""
        packed_input, latent = self._pack_and_encode(x, lengths)

        # Decoder initial state: hidden from the latent vector, cell set to zeros.
        initial_hidden = self.latent_to_hidden(latent).unsqueeze(0)
        initial_state = (initial_hidden, torch.zeros_like(initial_hidden))

        packed_output, _ = self.decoder(packed_input, initial_state)

        # Back to a padded tensor, then map hidden size -> input size.
        unpacked, _ = rnn_utils.pad_packed_sequence(packed_output, batch_first=True)
        return latent, self.output_layer(unpacked)

    def encode(self, x, lengths):
        """Return only the latent vector, computed without tracking gradients."""
        with torch.no_grad():
            _, latent = self._pack_and_encode(x, lengths)
        return latent

#### Autoencoder without Training

In [17]:
# Define Autoencoder Parameters
input_size = AUTOENCODER_INPUTS  # Number of features in time series data
hidden_size = 32
latent_size = 8 # Attempt to get good bottleneck, given large length of time-series data.

test_autoencoder = Time_Series_Autoencoder(input_size, hidden_size, latent_size)

# test_autoencoder.to(device)

# Forward pass example
# Only the first batch is processed (note the `break`); this is a shape check.
for time_series, label in time_train_loader:
    # len(seq) is the padded length of each sequence in the batch; with the
    # loader's batch_size of 1 that equals the true sequence length.
    lengths = [len(seq) for seq in time_series]
    latent, reconstructed = test_autoencoder(time_series, lengths)
    print("Original Shape", time_series.shape)
    print("Latent representation shape:", latent.shape)
    print("Reconstructed shape:", reconstructed.shape)
#     print(time_series)
#     print(latent)
#     print(reconstructed)
    break

Original Shape torch.Size([1, 1, 4])
Latent representation shape: torch.Size([1, 8])
Reconstructed shape: torch.Size([1, 1, 4])


#### Training Autoencoder

In [18]:
def train_autoencoder(model, dataloader, num_epochs, learning_rate):
    """Train a time-series autoencoder with Adam on the MSE reconstruction loss.

    Args:
        model: called as `model(time_series, lengths)`, returning
            `(latent, reconstructed)`.
        dataloader: yields `(time_series, label)` pairs; only the series is used.
        num_epochs: number of passes over `dataloader`.
        learning_rate: Adam learning rate.

    The model is put in training mode for the run and left in eval mode.
    Progress is printed every 100 entries and once per epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.MSELoss()

    print("Beginning Training")
    print("Total Entries:", len(dataloader))

    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        for entry_number, (time_series, label) in enumerate(dataloader, start=1):
            # Per-sequence lengths are taken from the (padded) batch tensor.
            lengths = [len(seq) for seq in time_series]

            optimizer.zero_grad()
            latent, reconstructed = model(time_series, lengths)
            loss = criterion(reconstructed, time_series)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if entry_number % 100 == 0:
                print("Current Entry:", entry_number)
                print("Current Loss:", loss.item())

        avg_loss = running_loss / len(dataloader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

    model.eval()

In [19]:
# Hyperparameters for the time-series autoencoder training run.
num_epochs = 6
learning_rate = 0.001

# Training and saving are disabled; a later cell loads the saved "TFN_AEC_R1" file.
# train_autoencoder(test_autoencoder, time_train_loader, num_epochs, learning_rate)

# torch.save(test_autoencoder, "TFN_AEC_R1")

#### Autoencoder after training

In [20]:
# 'TFN_AEC_R1' holds a whole pickled module (saved with torch.save(model, ...)),
# not a state_dict. Newer PyTorch releases default torch.load to
# weights_only=True, which rejects such files, so request full unpickling
# explicitly. Only load files you trust: unpickling can execute arbitrary code.
# map_location='cpu' keeps the model on the CPU, where this cell's inputs are.
completed_autoencoder = torch.load('TFN_AEC_R1', map_location='cpu', weights_only=False)

# completed_autoencoder.to(device)

# Only the first batch is processed (note the `break`); this is a shape check.
for time_series, label in time_train_loader:
    lengths = [len(seq) for seq in time_series]
    latent, reconstructed = completed_autoencoder(time_series, lengths)
    print("Original Shape", time_series.shape)
    print("Latent representation shape:", latent.shape)
    print("Reconstructed shape:", reconstructed.shape)
#     print(time_series)
#     print(latent)
#     print(reconstructed)
    break

Original Shape torch.Size([1, 3, 4])
Latent representation shape: torch.Size([1, 8])
Reconstructed shape: torch.Size([1, 3, 4])


# Medical Fusion Network


VGG-16 Feature Vector Autoencoder

In [21]:
class VGG16Autoencoder(nn.Module):
    """Fully connected autoencoder that compresses a 4096-dim feature vector to 16 dims."""

    def __init__(self):
        super(VGG16Autoencoder, self).__init__()
        feature_dim, hidden_dim, latent_dim = 4096, 1024, 16

        # Encoder: 4096 -> 1024 -> 16, with a ReLU after each layer
        # (so the latent code is non-negative).
        self.encoder = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU(inplace=True))

        # Decoder: 16 -> 1024 -> 4096, no activation on the output.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, feature_dim))

    def forward(self, x):
        """Return `(encoded, decoded)`: the latent code and its reconstruction."""
        latent_code = self.encoder(x)
        reconstruction = self.decoder(latent_code)
        return latent_code, reconstruction

Train VGG Autoencoder

In [22]:
# Instantiate both networks first, then move them to the compute device.
vgg_autoencoder = VGG16Autoencoder()
vgg_train_model = ModifiedVGG16Model()

for _module in (vgg_autoencoder, vgg_train_model):
    _module.to(device)

def test_vgg_autoencoder():
    """Smoke test: print the latent code of the first training patient's image."""
    for position, entry in enumerate(dataloaders['train']):
        if position >= 1:
            break
        # entry is (patient, label); patient[0] is the image tensor.
        x_ray = entry[0][0].to(device)
        encoded, decoded = vgg_autoencoder(vgg_train_model.embed(x_ray))
        print(encoded)

test_vgg_autoencoder()

tensor([[0.0114, 0.0000, 0.0000, 0.0186, 0.0221, 0.0872, 0.0617, 0.0901, 0.0004,
         0.0394, 0.1041, 0.0320, 0.0596, 0.0733, 0.0000, 0.0116]],
       device='cuda:1', grad_fn=<ReluBackward0>)


In [23]:
def train_vgg_autoencoder(model, vgg_embedder, dataloader, num_epochs, learning_rate):
    """Train the feature autoencoder to reconstruct embeddings from `vgg_embedder`.

    Args:
        model: autoencoder called as `model(features)`, returning
            `(latent, reconstructed)`. Only its parameters are optimised.
        vgg_embedder: module exposing `embed(x_ray)`; used as a fixed feature
            extractor and never updated here.
        dataloader: yields `(patient, label)` entries where `patient[0]` is
            the image tensor.
        num_epochs: number of passes over `dataloader`.
        learning_rate: Adam learning rate.

    The model is put in training mode for the run and left in eval mode. The
    average loss is printed once per epoch.
    """
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Send images to wherever the embedder's weights live; fall back to the
    # notebook-level `device` only for an embedder without parameters.
    first_param = next(vgg_embedder.parameters(), None)
    input_device = first_param.device if first_param is not None else device

    print("Beginning Training")
    print("Total Entries:", len(dataloader))

    model.train()

    num_batches = len(dataloader)
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for entry in dataloader:
            patient = entry[0]
            x_ray = patient[0].to(input_device)

            # The embedder is not trained here. Without no_grad every step
            # backpropagated through it, accumulating gradients on its
            # parameters that no optimizer ever used or cleared, and the
            # reconstruction target stayed attached to the graph.
            with torch.no_grad():
                visual_features = vgg_embedder.embed(x_ray)

            optimizer.zero_grad()

            latent, reconstructed = model(visual_features)

            loss = criterion(reconstructed, visual_features)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        # max(..., 1) avoids a ZeroDivisionError on an empty dataloader.
        avg_loss = epoch_loss / max(num_batches, 1)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

    model.eval()

In [24]:
# Hyperparameters for the VGG feature autoencoder training run.
# These reassign the `num_epochs` / `learning_rate` globals set for the
# time-series autoencoder earlier in the notebook.
num_epochs = 5
learning_rate = 0.0001

# Training and saving are disabled; MedicalFusionNetwork loads the saved "VGG_AEC_R1" file.
# train_vgg_autoencoder(vgg_autoencoder, vgg_train_model, dataloaders['train'], num_epochs, learning_rate)

# torch.save(vgg_autoencoder, "VGG_AEC_R1")

### Complete Model - Fusion

(Contains the X-Ray Embedder, the Demographics Embedder, and the pretrained Time Series and VGG feature autoencoders)

In [25]:
class MedicalFusionNetwork(torch.nn.Module):
    """Tensor-fusion network over image, demographics and time-series inputs.

    Pipeline: each modality is embedded, a constant 1 is prepended to every
    embedding, the three vectors are fused by an outer product, and the fused
    tensor is compressed with a Tucker decomposition whose flattened core
    (1 * 6 * 6 * 16 = 576 values) feeds a regression and a classification head.

    Args:
        model: unused; kept for interface compatibility.
    """

    def __init__(self, model=None):
        super(MedicalFusionNetwork, self).__init__()

        self.visual_embedder = ModifiedVGG16Model()
        # Both files hold whole pickled modules (saved with torch.save(model, ...)),
        # so full unpickling is requested explicitly: newer PyTorch releases
        # default to weights_only=True, which rejects them. Only load trusted
        # files. map_location places them on the notebook's compute device
        # regardless of where they were saved.
        self.visual_autoencoder = torch.load("VGG_AEC_R1", map_location=device, weights_only=False)
        self.demographics_embedder = DemographicsEmbedder()
        self.autoencoder = torch.load('TFN_AEC_R1', map_location=device, weights_only=False)

        # Regression head: 576 -> 288 -> 144 -> 1
        self.regression = nn.Sequential(
            nn.Dropout(),
            nn.Linear(576, 288),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(288, 144),
            nn.ReLU(inplace=True),
            nn.Linear(144, 1))

        # Classification head: 576 -> 1152 -> 576 -> 288 -> 144 -> 3 (logits)
        self.classification = nn.Sequential(
            nn.Dropout(),
            nn.Linear(576, 1152),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(1152, 576),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(576, 288),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(288, 144),
            nn.ReLU(inplace=True),
            nn.Linear(144, 3))

        # Not applied by any method of this class; callers may use it on the logits.
        self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _prepend_one(features):
        """Prepend a column of ones to a `[batch, dim]` feature matrix."""
        # new_ones matches the features' batch size, dtype and device, instead
        # of assuming batch size 1 and the notebook-level `device`.
        ones = features.new_ones(features.size(0), 1)
        return torch.cat((ones, features), dim=1)

    def tensor_fusion(self, imagery, demographics, time_series):
        """Fuse the three modality embeddings into `[batch, demo+1, time+1, visual+1]`."""
        # Get feature outputs from model subnets.
        visual_features = self.visual_embedder.embed(imagery)
        # Only the autoencoder's latent code is used as the visual feature.
        visual_latent, _ = self.visual_autoencoder(visual_features)
        demographic_features = self.demographics_embedder.embed(demographics)
        lengths = [len(seq) for seq in time_series]
        latent, _ = self.autoencoder(time_series, lengths)

        # Concat 1's onto feature vectors to prepare for Tensor Fusion
        visual_h = self._prepend_one(visual_latent)
        demographics_h = self._prepend_one(demographic_features)
        time_series_h = self._prepend_one(latent)

        # Perform Tensor Fusion (batched outer product of the three vectors)
        outer_xy = torch.einsum('bi,bj->bij', demographics_h, time_series_h)  # Shape: (batch_size, x_dim, y_dim)
        outer_xyz = torch.einsum('bij,bk->bijk', outer_xy, visual_h)  # Shape: (batch_size, x_dim, y_dim, z_dim)

        # Prevent tensor values from becoming extreme.
        outer_xyz = torch.clamp(outer_xyz, min=-1e6, max=1e6)

        return outer_xyz

    def tucker_feature_extraction(self, patient):
        """Return the flattened Tucker core of the fused tensor (on the CPU).

        `patient` is `(imagery, demographics, time_series)`. The decomposition
        runs under `torch.no_grad()`, so the returned features carry no
        gradient back to the embedders.
        """
        imagery = patient[0]
        demographics = patient[1]
        time_series = patient[2]

        # Tensor Fusion
        outer_xyz = self.tensor_fusion(imagery, demographics, time_series).to('cpu')

        # Tucker Decomposition
        with torch.no_grad():
            (core, factors), rec_error = partial_tucker(outer_xyz, modes=[0, 1, 2, 3], rank=[1, 6, 6, 16])
            # README: Make table comparing rank output performance

        # Flatten the resulting tensor for use in FC layer
        outer_xyz_flattened = core.reshape(core.size(0), -1)

#         print("Decomposed Size:", outer_xyz_flattened.size())

        return outer_xyz_flattened

    def regression_classification(self, patient):
        """Return `(regression_output, class_logits)` for one patient.

        The regression output has one value and the logits have three; no
        softmax is applied.
        """
        fusion = self.tucker_feature_extraction(patient).to(device)

        los = self.regression(fusion)
        mortality = self.classification(fusion)
#         mortality = self.softmax(mortality)
        out = (los, mortality)
        return out

    def forward(self, patient):
        """Return only the flattened fusion features (no prediction heads)."""
        fusion = self.tucker_feature_extraction(patient)

        return fusion

## Testing Complete Model

Untrained Fusion

In [26]:
MFN = MedicalFusionNetwork()
MFN.to(device)

def test_medical_net():
    """Smoke test: print labels and untrained predictions for five training patients."""
    for seen, entry in enumerate(dataloaders['train']):
        if seen >= 5:
            break
        # entry is (patient, label); move every tensor to the compute device.
        image, demographics, time_series = entry[0]
        patient = (image.to(device), demographics.to(device), time_series.to(device))
        first_label, second_label = entry[1]
        label = (first_label.to(device), second_label.to(device))
        pred = MFN.regression_classification(patient)
        print("Label:")
        print(label)
        print("Prediction:")
        print(pred)

test_medical_net()

Label:
(tensor([30.], device='cuda:1'), tensor([0.], device='cuda:1'))
Prediction:
(tensor([[-0.0576]], device='cuda:1', grad_fn=<AddmmBackward0>), tensor([[ 0.0146,  0.0054, -0.0234]], device='cuda:1',
       grad_fn=<AddmmBackward0>))
Label:
(tensor([25.], device='cuda:1'), tensor([2.], device='cuda:1'))
Prediction:
(tensor([[-0.1732]], device='cuda:1', grad_fn=<AddmmBackward0>), tensor([[-0.0002,  0.0100, -0.0032]], device='cuda:1',
       grad_fn=<AddmmBackward0>))
Label:
(tensor([29.], device='cuda:1'), tensor([0.], device='cuda:1'))
Prediction:
(tensor([[-0.1079]], device='cuda:1', grad_fn=<AddmmBackward0>), tensor([[ 0.0192,  0.0074, -0.0177]], device='cuda:1',
       grad_fn=<AddmmBackward0>))
Label:
(tensor([29.], device='cuda:1'), tensor([2.], device='cuda:1'))
Prediction:
(tensor([[-0.0766]], device='cuda:1', grad_fn=<AddmmBackward0>), tensor([[ 0.0218,  0.0071, -0.0121]], device='cuda:1',
       grad_fn=<AddmmBackward0>))
Label:
(tensor([30.], device='cuda:1'), tensor([0.],

## Train Model

### Define Trainer

In [27]:
class Trainer:
    """Train and evaluate a model that exposes ``regression_classification(patient)``.

    The model call is expected to return ``(regression_output, class_logits)``.
    Regression is scored with MSE and classification with cross entropy; the two
    losses are combined as a weighted sum during training.

    NOTE(review): the ``dataloaders`` constructor argument is accepted for interface
    compatibility but is not used. Every loop reads the notebook-level ``dataloaders``
    dict (keys ``'train'`` and ``'test'``) and moves tensors to the notebook-level
    ``device``.
    """

    def __init__(self, dataloaders, model, optimizer,
                 regression_weight=0.0, classification_weight=2.0):
        """Store the model/optimizer and set up loss functions.

        Args:
            dataloaders: unused (see class note).
            model: module providing ``regression_classification``; moved to ``device``.
            optimizer: optimizer stepping ``model``'s parameters.
            regression_weight: multiplier of the MSE term in the training loss.
                The default of 0.0 keeps the previous behaviour, which means the
                regression head receives no training signal.
            classification_weight: multiplier of the cross-entropy term.
        """
        self.optimizer = optimizer

        self.model = model
        self.model.to(device)
        self.classification_criterion = nn.CrossEntropyLoss()
        self.regression_criterion = nn.MSELoss()
        self.regression_weight = regression_weight
        self.classification_weight = classification_weight
        self.model.train()
        # Per-batch training losses, reset at the start of every train() call.
        self.train_regression_loss = []
        self.train_classification_loss = []

    def test(self):
        """Evaluate on ``dataloaders['test']`` and print summary statistics.

        Returns:
            ``(total_time, classification_labels, classification_predictions)`` where
            ``total_time`` is the summed model-call wall time in seconds and the two
            lists hold integer class ids, one per sample.
        """
        self.model.eval()
        total_time = 0
        running_regression_loss = 0.0
        running_classification_loss = 0.0
        classification_predictions = []
        classification_labels = []
        num_batches = 0

        # Evaluation needs no gradients; skipping autograd saves memory and time.
        with torch.no_grad():
            for entry in dataloaders['test']:
                num_batches += 1
                # Get inputs and labels from the dataloader and send them to device.
                patient = entry[0]
                patient = (patient[0].to(device), patient[1].to(device), patient[2].to(device))
                label = entry[1]
                regression_label = label[0].to(device).unsqueeze(1)
                classification_label = label[1].to(device).long()

                # Time only the model call.
                t0 = time.time()
                output = self.model.regression_classification(patient)
                t1 = time.time()
                total_time = total_time + (t1 - t0)

                model_regression = output[0]
                model_classification = output[1].float()

                classification_labels.extend(classification_label.to("cpu").tolist())
                # Softmax is monotonic, so the arg-max of the logits is the predicted class.
                predicted_class = torch.argmax(model_classification, dim=1)
                classification_predictions.extend(predicted_class.to("cpu").tolist())

                regression_loss = self.regression_criterion(model_regression, regression_label)
                classification_loss = self.classification_criterion(model_classification, classification_label)

                running_regression_loss += regression_loss.item()
                running_classification_loss += classification_loss.item()

        # Average the per-batch losses.
        total_regression_loss = running_regression_loss / num_batches
        total_classification_loss = running_classification_loss / num_batches

        print("=== Regression Accuracy ===")
        print("Mean Squared Error:", total_regression_loss)
        print("=== Classification Accuracy ===")
        print("Accuracy Score:", accuracy_score(classification_labels, classification_predictions))
        print("Cross Entropy Loss:", total_classification_loss)
        calc_time = float(total_time) / num_batches
        print("Total Prediction Time:", total_time)
        print('Average Prediction Time: {min}m {sec}s'.format(min=calc_time // 60.0, sec=calc_time % 60.0))
        print("Total Entries Compared: ", num_batches)

        return (total_time, classification_labels, classification_predictions)

    def train(self, epoches=10):
        """Run ``epoches`` passes over the training data, evaluating after each one."""
        since = time.time()
        self.train_regression_loss = []
        self.train_classification_loss = []

        for epoch in range(epoches):
            print("Epoch: ", epoch)
            self.train_epoch()
            self.test()
            self.model.eval()

        print("Finished fine tuning.")
        time_elapsed = time.time() - since
        print(f'Training complete in  {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')

    def train_batch(self, patient, regression_label, classification_label):
        """Run one optimisation step on a single batch.

        Args:
            patient: tuple of input tensors already on ``device``.
            regression_label: float tensor shaped like the regression output.
            classification_label: class ids, shape ``[B]`` or ``[B, 1]``.
        """
        self.model.train()

        output = self.model.regression_classification(patient)
        model_regression = output[0]
        model_classification = output[1].float()
        # CrossEntropyLoss wants a 1-D tensor of class indices. reshape(-1) handles
        # any batch size, whereas squeeze(dim=0) only worked for a batch of one.
        classification_label = classification_label.reshape(-1).long()

        regression_loss = self.regression_criterion(model_regression, regression_label)
        classification_loss = self.classification_criterion(model_classification, classification_label)

        self.train_regression_loss.append(regression_loss.item())
        self.train_classification_loss.append(classification_loss.item())

        total_loss = (self.regression_weight * regression_loss) \
            + (self.classification_weight * classification_loss)

        self.model.zero_grad()
        total_loss.backward()
        # Clip the gradients of the model being trained. This previously clipped the
        # global ``MFN``, which is wrong whenever the trainer wraps another model.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

    def train_epoch(self):
        """Feed every batch of ``dataloaders['train']`` through ``train_batch``."""
        for i, entry in enumerate(dataloaders['train']):
            patient = entry[0]
            patient = (patient[0].to(device), patient[1].to(device), patient[2].to(device))
            label = entry[1]
            # Add a trailing dimension so labels are [B, 1], matching the model outputs.
            regression_label = label[0].to(device).unsqueeze(1)
            classification_label = label[1].to(device).unsqueeze(1)
            self.train_batch(patient, regression_label, classification_label)
            if i % 1000 == 0:
                print(i)

In [28]:
# Raises SystemExit with status 130. Presumably a manual stop so that "Run All"
# does not continue into the training cells below -- confirm before removing.
sys.exit(130)

SystemExit: 130

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


### Train Model

In [None]:
# Optionally resume from an earlier checkpoint instead of the freshly built MFN.
# MFN = torch.load("AD_TFN_R3")

# optimizer = torch.optim.Adam(MFN.parameters(), lr=0.00007)
optimizer = torch.optim.Adam(MFN.parameters(), lr=0.0001)
# The first argument is ignored by Trainer, which reads the global `dataloaders` dict.
trainer = Trainer(datasets, MFN, optimizer)

# _ = trainer.test()

# 15 epochs; Trainer.train evaluates on the test split after every epoch.
trainer.train(15)

# Pickles the whole module object (not just a state_dict) to this file name.
torch.save(MFN, "AD_TFN_R4")

Epoch:  0
0
=== Regression Accuracy ===
Mean Squared Error: 731.12844474184
=== Classification Accuracy ===
Accuracy Score: 0.4171779141104294
Cross Entropy Loss: 1.3394819200953092
Total Prediction Time: 3.7917418479919434
Average Prediction Time: 0.0m 0.02326221992633094s
Total Entries Compared:  163
Epoch:  1
0
=== Regression Accuracy ===
Mean Squared Error: 731.12844474184
=== Classification Accuracy ===
Accuracy Score: 0.4171779141104294
Cross Entropy Loss: 1.6651038537430434
Total Prediction Time: 3.8397159576416016
Average Prediction Time: 0.0m 0.023556539617433137s
Total Entries Compared:  163
Epoch:  2
0
=== Regression Accuracy ===
Mean Squared Error: 731.12844474184
=== Classification Accuracy ===
Accuracy Score: 0.4171779141104294
Cross Entropy Loss: 1.6530235319704023
Total Prediction Time: 3.8338310718536377
Average Prediction Time: 0.0m 0.023520436023641948s
Total Entries Compared:  163
Epoch:  3
0
=== Regression Accuracy ===
Mean Squared Error: 731.12844474184
=== Classi

=== Regression Accuracy ===
Mean Squared Error: 7.7321083134624455
=== Classification Accuracy ===
Cross Binary Entropy Loss: 1.0958607646128151
Total Prediction Time: 4.07979154586792
Average Prediction Time: 0.0m 0.025029395986919754s
Total Entries Compared:  163

In [None]:
# Trainer.test() returns (total_time, labels, predictions); the timing is discarded here.
_, labels, preds = trainer.test()

In [None]:
# Dump the true and predicted class ids from the evaluation run above.
print(labels)
print(preds)

#### Model after Training:

In [None]:
# Reload a previously trained concat-fusion checkpoint and evaluate it on the test split.
trained_concat = torch.load('6aec_concat_mfn_R2')

# `optimizer` only satisfies the constructor here; test() never steps it.
trained_concat_trainer = Trainer(datasets, trained_concat, optimizer)

# Trainer.test() returns three values (total_time, labels, predictions).
# Unpacking five raised "ValueError: not enough values to unpack".
_, _, _ = trained_concat_trainer.test()

In [None]:
# Load the trained fusion checkpoint and show its predictions next to the labels
# for the first five test entries.
trained_fusion = torch.load("6aec_fusion_mfn")

i = 0
for i, entry in enumerate(dataloaders['test'], start=1):
    if i > 5:
        break
    patient = tuple(entry[0][k].to(device) for k in range(3))
    label = tuple(entry[1][k].to(device) for k in range(2))
    pred = trained_fusion.regression_classification(patient)
    print("Label:")
    print(label)
    print("Prediction:")
    print(pred)

# Perform Decomposition on completed TFN Model.

In [None]:
# Determine whether to use CPD/PARAFAC. Uses Tucker if False.
cp = False
decompose = False
origin_model = '6aec_concat_mfn_R2'
save_name = '6aec_concat_mfn_R2_decomposed'

if decompose:
    model = torch.load(origin_model).to(device)
    model.eval()
    model.to("cpu") # FIXME: Original code moves model to GPU, the CPU. Unnecessary?
    N = len(model.visual_embedder.features._modules.keys())
    for i, key in enumerate(model.visual_embedder.features._modules.keys()):

        if i >= N - 2:
            break
        if isinstance(model.visual_embedder.features._modules[key], torch.nn.modules.conv.Conv2d):
            conv_layer = model.visual_embedder.features._modules[key]
            if cp:
                rank = max(conv_layer.weight.data.numpy().shape)//3
                decomposed = cp_decomposition_conv_layer(conv_layer, rank)
            else:
                print("Tucker Performed!")
                decomposed = tucker_decomposition_conv_layer(conv_layer)

            model.visual_embedder.features._modules[key] = decomposed
            
    M = len(model.demographics_embedder.features._modules.keys())
    for i, key in enumerate(model.demographics_embedder.features._modules.keys()):

        if i >= M - 2:
            break
        if isinstance(model.demographics_embedder.features._modules[key], torch.nn.modules.conv.Conv2d):
            conv_layer = model.demographics_embedder.features._modules[key]
            if cp:
                rank = max(conv_layer.weight.data.numpy().shape)//3
                decomposed = cp_decomposition_conv_layer(conv_layer, rank)
            else:
                print("Tucker Performed!")
                decomposed = tucker_decomposition_conv_layer(conv_layer)

            model.demographics_embedder.features._modules[key] = decomposed

    torch.save(model, save_name)
    print("Decomposed Model Saved!")
else:
    print("Skipping Decomposition, Loading Decomposed Model from File.")

## Test Decomposed Model

In [None]:
def TestDecomposedModel():
    """Evaluate the saved decomposed model and report its computational cost.

    Loads '6aec_concat_mfn_R2_decomposed', runs ``Trainer.test()`` on it, then
    profiles one training batch with ``torchprofile.profile_macs``.
    """
    concat_decomposed_model = torch.load('6aec_concat_mfn_R2_decomposed').to(device)
    # Build the optimizer and trainer around the model that was just loaded. The
    # previous code referenced an undefined `concat_medical_network` and passed the
    # unrelated global `optimizer`.
    concat_decomposed_optimizer = optim.SGD(concat_decomposed_model.parameters(), lr=0.00007, momentum=0.70)
    concat_decomposed_trainer = Trainer(datasets, concat_decomposed_model, concat_decomposed_optimizer)

    # Trainer.test() returns (total_time, labels, predictions).
    _, d_classification_labels, d_classification_out = concat_decomposed_trainer.test()

    example_input, classes = next(iter(dataloaders['train']))
    example_input = (example_input[0].to(device), example_input[1].to(device), example_input[2].to(device))

    print("")
    print("----- Model Computational Complexity -----")
    # NOTE(review): `torchprofile` is not among the imports at the top of the
    # notebook; make sure it is imported before running this cell. Going by its
    # name, profile_macs counts multiply-accumulates, not FLOPs -- confirm the label.
    flops = torchprofile.profile_macs(concat_decomposed_model, example_input)
    print(f'FLOPs: {flops}')


TestDecomposedModel()

# Preprocessing Features for Cox Survivability Analysis

#### New Patient Dataset & Dataloader containing all Patients
Subject_ID is now also returned in the label for constructing feature dataframe.

In [None]:
class ID_Patient_Dataset(Dataset):
    """Patient dataset that also returns ``subject_id`` and ``stay_id`` in the label.

    Each item is ``(patient, label)`` with
    ``patient = (image, demographics, time_series_tensor)`` and
    ``label = (los, expire_flag, subject_id, stay_id)``.
    """

    def __init__(self, dataframe, transform=None):
        """
        Args:
            dataframe: one row per patient with the columns 'image_path', 'los',
                'hospital_expire_flag', 'subject_id', 'stay_id', 'one_hot' and
                'time_series_path'.
            transform: optional callable applied to the opened image.
        """
        self.dataframe = dataframe # Reference the MIMIC-IV Dataframe.
        self.transform = transform # Apply any given transformations.

    def __getitem__(self, index):
        """Load the image, demographics, time series and labels of row ``index``."""
        row = self.dataframe.iloc[index] # Get the row (patient) we want to read.

        image = Image.open(row['image_path'])
        if self.transform:
            image = self.transform(image)

        los = torch.tensor(row['los'], dtype=torch.float32)
        expire_flag = torch.tensor(row['hospital_expire_flag'], dtype=torch.float32)
        # Identifiers must stay exact. float32 only represents integers up to 2**24
        # (16,777,216) exactly, so larger ids were silently rounded and distinct stays
        # could collapse onto the same value. int64 keeps them intact.
        subject_id = torch.tensor(row['subject_id'], dtype=torch.int64)
        stay_id = torch.tensor(row['stay_id'], dtype=torch.int64)
        label = (los, expire_flag, subject_id, stay_id)

        demographics = torch.tensor(row['one_hot'], dtype=torch.float32)

        time_series = pd.read_csv(row['time_series_path'])

        # Drop charttime, since it isn't a feature
        time_series = time_series.drop('charttime', axis=1)

        # Impute gaps: carry the last value forward, then back-fill leading gaps.
        time_series = time_series.ffill()
        time_series = time_series.bfill()

        # Convert to Tensor
        time_series_tensor = torch.tensor(time_series.values, dtype=torch.float32)

        patient = (image, demographics, time_series_tensor)

        return patient, label

    def __len__(self):
        """Number of patients (rows) in the dataframe."""
        return len(self.dataframe)

In [None]:
batch_size = 1

# The constructor already stores the transform, so the dataset is fully configured
# in one call. shuffle=False keeps the rows in dataframe order.
Total_ID_Dataset = ID_Patient_Dataset(mimic_df, transform=data_transforms['test'])
Total_ID_Dataloader = DataLoader(Total_ID_Dataset, batch_size=batch_size, shuffle=False)

# i = 0
# for patient in Total_ID_Dataloader:
#     i = i + 1
#     if i > 5:
#         break
#     print(patient)

#### Apply TFN Network for Feature Extraction

In [None]:
def extract_patient_features():
    """Run every patient through the trained fusion model and save the features.

    Each output row holds the fused feature vector followed by the label columns
    ``duration`` (length of stay), ``event`` (expire flag), ``subject_id`` and
    ``stay_id``. Rows with a repeated ``stay_id`` are dropped, keeping the first, and
    the table is written to 'fusionmfn_patient_features.csv'.

    NOTE(review): the later Cox and random-forest cells read 'patient_features.csv',
    not the file written here -- confirm which name is intended.

    Returns:
        The de-duplicated feature DataFrame.
    """
    processing_tfn_model = torch.load('6aec_concat_mfn_R2').to(device)
    # Feature extraction is inference: turn off dropout / batch-norm updates.
    processing_tfn_model.eval()

    label_columns = ['duration', 'event', 'subject_id', 'stay_id']
    feature_list = []

    with torch.no_grad():
        for entry in Total_ID_Dataloader:
            patient = entry[0]
            patient = (patient[0].to(device), patient[1].to(device), patient[2].to(device))
            los, event, subject_id, stay_id = (entry[1][k].item() for k in range(4))
            label = [los, int(event), int(subject_id), int(stay_id)]

            output = processing_tfn_model(patient)

            # The loader uses a batch size of one, so row 0 is this patient's vector.
            feature_list.append(output[0].tolist() + label)

    feature_dataframe = pd.DataFrame(feature_list)
    # Name the trailing label columns from the actual feature width instead of the
    # hard-coded positions 4122..4125, which silently stop matching if it changes.
    n_features = feature_dataframe.shape[1] - len(label_columns)
    rename_map = {n_features + offset: name for offset, name in enumerate(label_columns)}
    feature_dataframe = feature_dataframe.rename(columns=rename_map)

    feature_dataframe = feature_dataframe.drop_duplicates(subset='stay_id', keep='first')
    feature_dataframe.to_csv('fusionmfn_patient_features.csv', index=False)

    return feature_dataframe


extract_patient_features()

# Survivability Analysis utilizing Cox Survivability Model

#### Load Data into Memory

In [None]:
# Load the extracted patient features. The two id columns are not model inputs.
# NOTE(review): the extraction cell writes 'fusionmfn_patient_features.csv', while
# this cell reads 'patient_features.csv' -- confirm which file is intended.
cox_train = pd.read_csv('patient_features.csv')
cox_train.drop(['subject_id', 'stay_id'], axis=1, inplace=True)
cox_train.columns = cox_train.columns.astype(str)
# NOTE(review): the earlier comment here said 0 = survival/censored and 1 = mortality,
# "so flipping is not necessary", yet the line below swaps 0 and 1 in 'event'. After
# it runs, event == 1 marks rows that were 0 in the file. Confirm the inversion is
# deliberate; the saved net is named '6en_cox_invert_flag'.
cox_train['event'] = cox_train['event'].replace({0:1, 1:0})
cox_train

In [None]:
# Number of rows with event == 1 (after the 0/1 swap applied in the loading cell).
cox_train['event'].sum()

In [None]:
# Hold out 20% of the rows for testing, then 20% of the remainder for validation.
# A fixed random_state makes the split reproducible across kernel restarts. This
# matters because the Cox net can be loaded from disk rather than refit: an unseeded
# split would score it on a different random test set every run.
cox_test = cox_train.sample(frac=0.2, random_state=42)
cox_train = cox_train.drop(cox_test.index)
cox_val = cox_train.sample(frac=0.2, random_state=42)
cox_train = cox_train.drop(cox_val.index)

In [None]:
# No column is standardised; every feature column is passed through untouched.
cols_standardize = []
# Everything except the final two columns (duration, event) is a feature.
# Column names stay as strings on purpose -- do not convert them to ints.
cols_leave = list(cox_test.columns[:-2])

standardize = [([name], StandardScaler()) for name in cols_standardize]
leave = [(name, None) for name in cols_leave]

x_mapper = DataFrameMapper(standardize + leave)

In [None]:
# Fit the mapper on the training rows only, then apply the same mapping to the
# validation and test rows. Everything is cast to float32 for the network.
x_train = x_mapper.fit_transform(cox_train).astype('float32')
x_val = x_mapper.transform(cox_val).astype('float32')
x_test = x_mapper.transform(cox_test).astype('float32')

In [None]:
def get_target(df):
    """Return the ``(duration, event)`` value arrays of a survival frame."""
    return (df['duration'].values, df['event'].values)


y_train = get_target(cox_train)
y_val = get_target(cox_val)
durations_test, events_test = get_target(cox_test)
# Bundle validation inputs and targets into one torchtuples structure.
val = tt.tuplefy(x_val, y_val)

In [None]:
# Sanity check: shapes of the validation inputs and of the (durations, events) targets.
val.shapes()

In [None]:
# Shapes after repeating the validation tuple twice and concatenating the copies.
val.repeat(2).cat().shapes()

#### Cox-CC Input Network

In [None]:
# MLP that maps a feature vector to a single output for the Cox model.
in_features = x_train.shape[1]  # one input per feature column
num_nodes = [32, 32]            # two hidden layers of 32 units
out_features = 1
batch_norm = True
dropout = 0.1
output_bias = False

cox_net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm,
                              dropout, output_bias=output_bias)

#### Cox-CC Network

In [None]:
# Invoke Cox model by wrapping it around MLP.
cox_model = CoxCC(cox_net, optimizer=tt.optim.Adam, device=device)

In [None]:
# NOTE: this rebinds the notebook-level `batch_size`, which the data-loading cells also set.
batch_size = 1
# Run the learning-rate finder on the training data and plot its result.
lrfinder = cox_model.lr_finder(x_train, y_train, batch_size, tolerance=2)
_ = lrfinder.plot()

In [None]:
# Learning rate suggested by the finder (shown for reference; the next cell sets it by hand).
lrfinder.get_best_lr()

In [None]:
# Learning rate chosen manually rather than taken from lrfinder.get_best_lr().
cox_model.optimizer.set_lr(0.00001)

In [None]:
# epochs = 512
epochs = 120
# callbacks = [tt.callbacks.EarlyStopping()]
callbacks = []
verbose = True

In [None]:
# Set to True to (re)train the Cox net; leave False to rely on the weights loaded later.
fit_model = False

if fit_model:
    # A `%%time` cell magic is only valid as the first line of a cell, not inside an
    # `if` body, so the fit is timed explicitly instead.
    fit_started = time.time()
    log = cox_model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,
                    val_data=val.repeat(10).cat())
    print(f'Cox-CC fit wall time: {time.time() - fit_started:.1f}s')

In [None]:
# `log` only exists when the fit cell above actually ran, hence the same guard.
if fit_model:
    _ = log.plot()

In [None]:
# Mean partial log-likelihood over the validation inputs and targets.
cox_model.partial_log_likelihood(*val).mean()

## Serialize/Deserialize Cox Model

In [None]:
# True: save the current network weights to disk. False: load weights from the same file.
serialize = False

if serialize:
    cox_model.save_net('6en_cox_invert_flag')
else:
    cox_model.load_net('6en_cox_invert_flag')

## Cox-CC Predictions

In [None]:
# Call pycox's baseline-hazard computation; the returned value is discarded.
_ = cox_model.compute_baseline_hazards()

In [None]:
# Predicted survival curves for the held-out test inputs, as a DataFrame.
surv = cox_model.predict_surv_df(x_test)

In [None]:
# Plot the first ten columns of the predicted survival frame.
surv.iloc[:, :10].plot()
plt.ylabel('S(t | x)')
_ = plt.xlabel('Time')

These are the individual predictions of patients within cox_test, printed below.

In [None]:
# Display the held-out test rows that the survival curves above were predicted for.
cox_test

# Cox-CC Evaluation

In [None]:
# EvalSurv is not among the imports at the top of the notebook, so import it here to
# keep this cell runnable on a fresh kernel.
from pycox.evaluation import EvalSurv

# Evaluator over the predicted survival curves and the true test durations/events.
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

#### Concordance Score

In [None]:
# Time-dependent concordance of the predicted survival curves on the test rows.
ev.concordance_td()

#### Brier Score

In [None]:
# 100 evenly spaced time points spanning the observed test durations.
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
ev.brier_score(time_grid).plot()
plt.ylabel('Brier score')
_ = plt.xlabel('Time')
# Written to the current working directory.
plt.savefig('brier_time.png',format='png',dpi=1200,bbox_inches='tight')

#### Negative binomial log-likelihood

In [None]:
# Negative binomial log-likelihood over the same time grid as the Brier-score plot.
ev.nbll(time_grid).plot()
plt.ylabel('Negative Binomial Log-Likelihood')
_ = plt.xlabel('Time')
plt.savefig('nbll_time.png',format='png',dpi=1200,bbox_inches='tight')

#### Integrated Scores

In [None]:
# Brier score integrated over `time_grid`.
ev.integrated_brier_score(time_grid)

In [None]:
# Negative binomial log-likelihood integrated over `time_grid`.
ev.integrated_nbll(time_grid)

In [None]:
# Plot the survival curve of a single test patient (slice 2:3) as an example figure.
ev[2:3].plot_surv()
plt.ylabel('Survival')
_ = plt.xlabel('Time')

plt.savefig('example_survival_curve.png',format='png',dpi=1200,bbox_inches='tight')

# Compare Cox Model against Random Survival Forest

In [None]:
# Same feature file as the Cox section above.
patient_df = pd.read_csv('patient_features.csv')

# Select columns by name rather than by position. The file also carries 'subject_id'
# and 'stay_id' (the Cox cell drops them by name), so slicing the last two columns
# does not give duration and event, and would leave the real ones in the features.
target_columns = ['duration', 'event']
id_columns = ['subject_id', 'stay_id']
X = patient_df.drop(columns=target_columns + id_columns, errors='ignore').values
durations = patient_df['duration'].values
events = patient_df['event'].values

# NOTE(review): 'event' is used as stored in the file here, whereas the Cox cell swaps
# 0 and 1 after loading. Confirm that both models should see the same encoding.

# Create structured array for survival data: (event indicator, observed time)
y = np.array([(e, t) for e, t in zip(events, durations)], dtype=[('event', 'bool'), ('time', 'float')])

# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [None]:
# Train Random Survival Forest
rsf = RandomSurvivalForest(n_estimators=100, min_samples_split=10, min_samples_leaf=15, random_state=42)
rsf.fit(X_train, y_train)

In [None]:
# Get concordance score.
rsf.score(X_test, y_test)

In [None]:
# Predictions for evaluation
y_train_pred = rsf.predict_survival_function(X_train)
y_test_pred = rsf.predict_survival_function(X_test)

# Adjust the time grid to be within the observed durations of the test set
min_time_test = y_test['time'].min()
max_time_test = y_test['time'].max()
time_grid = np.linspace(min_time_test, max_time_test, 100)

# Convert the step functions to arrays
y_train_pred = np.asarray([[fn(t) for t in time_grid] for fn in y_train_pred])
y_test_pred = np.asarray([[fn(t) for t in time_grid] for fn in y_test_pred])

# Evaluate Performance
c_index_train = concordance_index_censored(y_train['event'], y_train['time'], rsf.predict(X_train))[0]
c_index_test = concordance_index_censored(y_test['event'], y_test['time'], rsf.predict(X_test))[0]

print(f'C-Index (Train): {c_index_train}')
print(f'C-Index (Test): {c_index_test}')

# FIXME: IBS Calculation is broken!!!
# Calculate Integrated Brier Score (IBS)
# ibs_train = integrated_brier_score(y_train, y_train, y_train_pred, time_grid)
# ibs_test = integrated_brier_score(y_train, y_test, y_test_pred, time_grid)

# print(f'Integrated Brier Score (Train): {ibs_train}')
# print(f'Integrated Brier Score (Test): {ibs_test}')

# Misc Testbenching

In [None]:
# Load a saved decomposed model for the test bench below.
# NOTE(review): the decomposition cell saves to '6aec_concat_mfn_R2_decomposed';
# this loads '6aec_concat_mfn_decomposed' -- confirm which file is intended.
eval_tfn_model = torch.load('6aec_concat_mfn_decomposed').to(device)

In [None]:
def test_model(model, dataloader):
        model.eval()
        total = 0
        total_time = 0
        preds = []
        regression_labels = []
        classification_labels = []
        regression_outputs = []
        classification_outputs = []
        
        for i, (entry) in enumerate(dataloader):
            # Get image, demographics, and label from dataloader and send to device.
            patient = entry[0]
            patient = (patient[0].to(device), patient[1].to(device), patient[2].to(device))
            label = entry[1]
            label = (label[0].to(device), label[1].to(device))
            current_regression_label = label[0].to(device)
            current_classification_label = label[1].to(device)

            # Add each label to list keeping track of all entries.
            regression_labels.extend(current_regression_label.to("cpu"))
            classification_labels.extend(current_classification_label.to("cpu"))

            # Start keeping time, and run model for output.
            t0 = time.time()
            
            output = model.regression_classification(patient)
            
            t1 = time.time()
            total_time = total_time + (t1 - t0)
            
            # Add model outputs to output lists. (For later comparison)
            for output_tensor in output[0]:
#                 print("Regression:", output_tensor)
                regression_outputs.extend(output_tensor.to("cpu").detach().numpy().tolist())
                
            for output_tensor in output[1]:
#                 print("Classification:", output_tensor)
                classification_outputs.extend(output_tensor.to("cpu").detach().numpy().tolist())
        
        # Print model training time and statistics.
        print("=== Regression Accuracy ===")
        print("Mean Squared Error:", mean_squared_error(regression_labels, regression_outputs))
        print("=== Classification Accuracy ===")
        classification_outputs_rounded = [ round(elem) for elem in classification_outputs ]
#         print(classification_outputs_rounded)
        print("Accuracy Score:", accuracy_score(classification_labels, classification_outputs_rounded))
        calc_time = float(total_time) / (i + 1)
        print("Total Prediction Time:", total_time)
        print('Average Prediction Time: {min}m {sec}s'.format(min=calc_time // 60.0, sec=calc_time % 60.0))
        print("Total Entries Compared: ", i + 1)
#         print(outputs)
        
        return (total_time, regression_labels, regression_outputs, \
                classification_labels, classification_outputs_rounded)

In [None]:
# eval_optimizer = optim.SGD(eval_tfn_model.parameters(), lr=0.00007, momentum=0.70)
# eval_trainer = Trainer(datasets, eval_tfn_model, eval_optimizer)

# Evaluate the loaded model on every patient (Total_ID_Dataloader) and keep all five results.
concat_time, concat_regression_labels, concat_regression_outputs, \
concat_classification_labels, concat_classification_outputs = test_model(eval_tfn_model, Total_ID_Dataloader)

In [None]:
# Convert the 0-d label tensors to plain ints. int() accepts both tensors and ints,
# so re-running this cell no longer fails with "'int' object has no attribute 'item'".
concat_classification_labels = [int(x) for x in concat_classification_labels]

In [None]:
# Inspect the integer class labels.
concat_classification_labels

In [None]:
# I think there may be not enough mortalities in the dataset, 
# so the model swings too far and marks every patient as a survival.
# I'll have to look at the classification again because it seems it might not be learning the patterns correctly.

In [None]:
def plot_original():
    """Draw the classification confusion matrix and save it as 'classification_cm.png'.

    Reads the notebook-level label/prediction lists and delegates the drawing to
    ``plot_confusion_matrix``.

    NOTE(review): only two label names are supplied, while the classifier elsewhere
    in the notebook emits three logits -- confirm the names match the matrix size.
    """
    label_names = ["Censored survival", "Mortality"]
    matrix = confusion_matrix(concat_classification_labels, concat_classification_outputs)
    plot_confusion_matrix(matrix, label_names)

    plt.savefig('classification_cm.png',format='png',dpi=1200,bbox_inches='tight')


plot_original()

In [None]:
# ROC curve of the classification results gathered by test_model.
fig = plt.figure()

# print(original_trainer.get_belief())

# NOTE(review): the scores passed here are hard class predictions, not probabilities,
# so the curve has very few operating points. The labels may also contain more than
# two classes -- confirm roc_curve is applicable as called.
fpr, tpr, thresholds = sklearn.metrics.roc_curve(concat_classification_labels, \
                                                 concat_classification_outputs)
roc_auc = sklearn.metrics.auc(fpr, tpr)
plt.plot(fpr, tpr, color='blue', label='Original Model (AUC = {:.2f})'.format(roc_auc))

# Disabled comparison curve for a decomposed model.
# fpr, tpr, thresholds = sklearn.metrics.roc_curve(decomposed_labels, decomposed_outputs[1::2])
# roc_auc = sklearn.metrics.auc(fpr, tpr)
# plt.plot(fpr, tpr, color='green', label='Decomposed Model (AUC = {:.2f})'.format(roc_auc))


plt.legend()
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves')
plt.show()