In [1]:
# MLP Model for Node-Level classification - for left and right temporal lobe nodes
import os
import numpy as np
import pandas as pd
import random
from tqdm import tqdm
from scipy.io import loadmat
from sklearn.neural_network import MLPClassifier as MLP
from sklearn.ensemble import RandomForestClassifier as RF
from sklearn.svm import SVC as SVM

from sklearn.metrics import confusion_matrix, balanced_accuracy_score
from matplotlib import pyplot as plt

import xgboost
from xgboost import XGBClassifier as xgb

import torch
import torch.nn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.utils.data import DataLoader, random_split

In [2]:
def calculate_metrics(y_pred, y_true):

    # conf_mat = confusion_matrix(y_true, y_pred)

    # print('Confusion Matrix : \n', conf_mat)

    # total1=sum(sum(conf_mat))
    
    # accuracy1=(conf_mat[0,0]+conf_mat[1,1])/total1
    # # print('Accuracy : ', accuracy1)

    # sensitivity1 = conf_mat[0,0]/(conf_mat[0,0]+conf_mat[0,1])
    # # print('Sensitivity : ', sensitivity1 )

    # specificity1 = conf_mat[1,1]/(conf_mat[1,0]+conf_mat[1,1])
    # # print('Specificity : ', specificity1)

    bal_acc = balanced_accuracy_score(y_true, y_pred)
    # print(f'Balanced Accuracy: {bal_acc:.4f}')

    return bal_acc

In [3]:
# Root Folder
root='/home/neil/Lab_work/Jeong_Lab_Multi_Modal_MRI/Right_Temporal_Lobe/'

In [4]:
def load_train_data(root: str, node_num: str):

    train_path = os.path.join(root, 'Node_'+node_num, 'Aug_Train_Data', 'ALL_Patients')  
    x_file = f"X_train_aug"
    y_file = f"Y_train_aug"
    x_mat_name = "X_aug_train"
    y_mat_name = "Y_aug_train"  

    raw_path_x = os.path.join(train_path, f"{x_file}.mat")
    raw_path_y = os.path.join(train_path, f"{y_file}.mat")

    # Load the data from .mat files
    X_mat_l = loadmat(raw_path_x)
    X_mat = X_mat_l[x_mat_name]

    Y_mat_l = loadmat(raw_path_y)
    Y_mat = Y_mat_l[y_mat_name]
    Y_mat = Y_mat.reshape(Y_mat.shape[1],)

    X_multi_modal = X_mat
    Y_label = Y_mat

    # Load the 1D vectors (images) and binary labels
    X_multi_modal: torch.Tensor = torch.from_numpy(X_multi_modal) 
    Y_label: torch.Tensor = torch.from_numpy(Y_label) # for CrossEntropyLoss
    Y_label = Y_label.squeeze(dim=0).long()
    X_multi_modal = X_multi_modal
    X_multi_modal = X_multi_modal.squeeze(dim=0)

    return X_multi_modal, Y_label

In [5]:
def load_test_data(root: str, node_num: str):

    val_path = os.path.join(root, 'Node_'+node_num, 'Orig_Val_Data', 'ALL_Patients')  
    x_file = f"X_valid_orig"
    y_file = f"Y_valid_orig"
    x_mat_name = "X_orig_valid"
    y_mat_name = "Y_orig_valid"  

    raw_path_x = os.path.join(val_path, f"{x_file}.mat")
    raw_path_y = os.path.join(val_path, f"{y_file}.mat")

    # Load the data from .mat files
    X_mat_l = loadmat(raw_path_x)
    X_mat = X_mat_l[x_mat_name]

    Y_mat_l = loadmat(raw_path_y)
    Y_mat = Y_mat_l[y_mat_name]
    Y_mat = Y_mat.reshape(Y_mat.shape[1],)

    X_multi_modal = X_mat
    Y_label = Y_mat

    # Load the 1D vectors (images) and binary labels
    X_multi_modal: torch.Tensor = torch.from_numpy(X_multi_modal) 
    Y_label: torch.Tensor = torch.from_numpy(Y_label) # for CrossEntropyLoss
    Y_label = Y_label.squeeze(dim=0).long()
    X_multi_modal = X_multi_modal
    X_multi_modal = X_multi_modal.squeeze(dim=0)

    return X_multi_modal, Y_label

In [6]:
# Training and testing/validation dataloaders
def load_train_loader(X_train, Y_train):
   train_data = []
   for i in range(len(X_train)):
      train_data.append([X_train[i], Y_train[i]])

   train_loader = DataLoader(train_data, shuffle=True, drop_last=False, batch_size=4)  # type:ignore
   return train_loader

def load_test_loader(X_test, Y_test):
   test_data = []
   for i in range(len(X_test)):
      test_data.append([X_test[i], Y_test[i]])
   test_loader = DataLoader(test_data, shuffle=False, drop_last=False, batch_size=4)  # type:ignore
   return test_loader

In [7]:
# Define the model
class MLP_ms(nn.Module):
  def __init__(self):
    super().__init__()
    self.layers = nn.Sequential(
      nn.Linear(1899, 512),
      nn.Dropout(p=0.2),
      nn.ReLU(),
      nn.Linear(512, 512),
      nn.Dropout(p=0.2),
      nn.ReLU()
    )

    self.layers_combined = nn.Sequential(
      nn.Linear(512, 2)  # for CrossEntropyLoss
    )

  def forward(self, x):
    x_out = self.layers(x)

    x_final = self.layers_combined(x_out)

    return x_final

In [8]:
def train_one_epoch(model, train_loader, optimizer, loss_fn, device=None):    
    # Enumerate over the data
    all_preds = []
    all_labels = []
    running_loss = 0.0
    step = 0

    for _, (x,y) in enumerate(tqdm(train_loader)):
        
        # Use GPU
        x.to(device) 

        # Reset gradients
        optimizer.zero_grad()

        # Passing the node features and the connection info        
        pred = model(x)   
            
        # Calculating the loss and gradients
        loss = loss_fn(pred, y)

        loss.backward() 

        optimizer.step() 

        # Update tracking
        running_loss += loss.item()

        step += 1

        all_preds.append((pred.argmax(dim=-1)).cpu().detach().numpy()) # for CrossEntropyLoss
        all_labels.append(y.cpu().detach().numpy())

    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()

    return running_loss/step

In [13]:
def test(model, test_loader, loss_fn, device=None):
    all_preds = []
    preds_for_auc = []
    all_labels = []
    running_loss = 0.0
    step = 0

    for (x,y) in test_loader:

        x.to(device) 

        pred = model(x) 

        loss = loss_fn(pred, y)

        # Update tracking
        running_loss += loss.item()
        step += 1

        all_preds.append((pred.argmax(dim=-1)).cpu().detach().numpy())  # for CrossEntropyLoss
        all_labels.append(y.cpu().detach().numpy())
    
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()

    bal_acc = calculate_metrics(all_preds, all_labels)

    return running_loss/step, bal_acc, all_preds, all_labels

In [14]:
def train_all_epochs_per_node(model, train_loader, test_loader, epochs=100, patience=10, 
save_dir='', exp_id='', loss_fn=None, optimizer=None, device=None):
    # Start training
    best_loss = float('inf')
    best_bal_acc = 0.0
    early_stopping_counter = 0
    best_epoch = 1
    all_best_epoch_nums = []

    for epoch in range(1, epochs+1): 
        if early_stopping_counter <= patience:

            # Training
            model.train()
            loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device=device)
            print(f"Epoch {epoch} | Train Loss {loss:.4f}")

            # Validation
            model.eval()

            loss, bal_acc, _, _ = test(model, test_loader, loss_fn, device=device)
            print(f"Epoch {epoch} | Test Loss {loss:.4f} | Bal. Acc. {bal_acc:.4f}")

            # Update best (highest) Balanced Accuracy
            if float(bal_acc) > best_bal_acc:

                print(f'Validation Balanced Accuracy increases from {best_bal_acc:.4f} to {bal_acc:.4f}')

                best_bal_acc = bal_acc
                
                print('Saving best model with highest Balanced Accuracy ...')
                best_model_wts = model.state_dict()
                torch.save(best_model_wts, save_dir + "exp_node_" + str(exp_id) + "_best_model.pth")
                best_epoch = epoch
                all_best_epoch_nums.append(epoch)

                early_stopping_counter = 0
            else:
                early_stopping_counter += 1

        else:
            print("Early stopping due to no improvement.")
        
            print(f"Finishing training with highest Balanced Accuracy: {best_bal_acc}")
            break


    # Final Test with trained model
    print('\nTesting with the best model with highest Balanced Accuracy on test set ...\n')
    model.load_state_dict(torch.load(save_dir + "exp_node_" + str(exp_id) + "_best_model.pth"))

    loss, bal_acc, y_pred, y_true = test(model, test_loader, loss_fn)

    print("Printing Test Set Evaluation Metrics ....\n")
    print(f"Bal. Acc. {bal_acc:.4f}")

    print(f"\nBest scores occur at {best_epoch} epoch number")

    return bal_acc, y_pred, y_true

In [15]:
def get_list_of_node_nums():
    node_numbers_with_smote = [
        "948"
    ]

    return node_numbers_with_smote

In [16]:
save_dir = "./checkpoints/"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [17]:
# Main loop to train and test the model over all the nodes
total_nodes = 1

# choose model
model_name = 'MLP_ms'

# Define the model
if model_name == 'MLP_ms':
    # Initialize the model
    model = MLP_ms()
    model.double()
else:
    raise NotImplementedError("Unknown Model.")

# Train and test the model with these settings
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')  # type: ignore

epochs = 30
patience = 30 # needed for early stopping

# Define training loss function and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

count = 0

node_num_list = [] # stores the node numbers for all nodes
bal_acc_list = [] # stores the balanced accuracies for all nodes
sen_list = [] # stores the sensitivity for all nodes
spec_list = [] # stores the specificity for all nodes

# list of all 94 nodes
node_numbers_with_smote = get_list_of_node_nums()

for node_num in node_numbers_with_smote:    

    exp_id = node_num # Experiment ID

    print(f'Node num as string: {node_num}')

    # load the data for the given node
    X_train, Y_train = load_train_data(root, node_num)
    train_loader = load_train_loader(X_train, Y_train)

    X_test, Y_test = load_test_data(root, node_num)
    test_loader = load_test_loader(X_test, Y_test)

    # Train the model
    print(f'Training and Evaluating on Node number: {node_num}')           

    # Evaluate Trained Model with evaluation metrics
    bal_acc, y_pred, y_true = train_all_epochs_per_node(model, train_loader, test_loader, epochs=epochs, 
    patience=patience, save_dir=save_dir, exp_id=exp_id, loss_fn=loss_fn, optimizer=optimizer, device=device)

    print(f"y-prediction: {y_pred}")
    print(f"y-true: {y_true}")
    raise ValueError("Stop here")

    # for saving balanced accuracy and confusion matrix of nodes with SMOTE in a csv file
    node_num_list.append(node_num)
    bal_acc_list.append(bal_acc)
    

Node num as string: 948
Training and Evaluating on Node number: 948


100%|██████████| 30/30 [00:01<00:00, 23.66it/s]


Epoch 1 | Train Loss 3.2640
Epoch 1 | Test Loss 0.5920 | Bal. Acc. 0.5000
Validation Balanced Accuracy increases from 0.0000 to 0.5000
Saving best model with highest Balanced Accuracy ...


100%|██████████| 30/30 [00:00<00:00, 140.97it/s]


Epoch 2 | Train Loss 0.8427
Epoch 2 | Test Loss 0.6509 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 139.20it/s]


Epoch 3 | Train Loss 0.7468
Epoch 3 | Test Loss 0.7210 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 135.24it/s]


Epoch 4 | Train Loss 0.7416
Epoch 4 | Test Loss 0.5620 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 134.04it/s]


Epoch 5 | Train Loss 0.7529
Epoch 5 | Test Loss 0.6886 | Bal. Acc. 0.5952
Validation Balanced Accuracy increases from 0.5000 to 0.5952
Saving best model with highest Balanced Accuracy ...


100%|██████████| 30/30 [00:00<00:00, 130.48it/s]


Epoch 6 | Train Loss 0.7366
Epoch 6 | Test Loss 0.7059 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 138.02it/s]


Epoch 7 | Train Loss 0.7151
Epoch 7 | Test Loss 0.8566 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 138.77it/s]


Epoch 8 | Train Loss 0.8345
Epoch 8 | Test Loss 0.6046 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 130.13it/s]


Epoch 9 | Train Loss 0.7393
Epoch 9 | Test Loss 0.6719 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 137.60it/s]


Epoch 10 | Train Loss 0.7277
Epoch 10 | Test Loss 0.6781 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 141.22it/s]


Epoch 11 | Train Loss 0.7286
Epoch 11 | Test Loss 0.6585 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 142.24it/s]


Epoch 12 | Train Loss 0.7080
Epoch 12 | Test Loss 0.6769 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 120.64it/s]


Epoch 13 | Train Loss 0.7542
Epoch 13 | Test Loss 0.7024 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 120.92it/s]


Epoch 14 | Train Loss 0.7323
Epoch 14 | Test Loss 0.6818 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 131.63it/s]


Epoch 15 | Train Loss 0.7074
Epoch 15 | Test Loss 2.0010 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 143.43it/s]


Epoch 16 | Train Loss 0.7454
Epoch 16 | Test Loss 0.6524 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 105.79it/s]


Epoch 17 | Train Loss 0.7310
Epoch 17 | Test Loss 0.6852 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 105.64it/s]


Epoch 18 | Train Loss 0.6988
Epoch 18 | Test Loss 0.6789 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 138.29it/s]


Epoch 19 | Train Loss 0.6975
Epoch 19 | Test Loss 0.6961 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 136.53it/s]


Epoch 20 | Train Loss 0.6960
Epoch 20 | Test Loss 0.6876 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 143.15it/s]


Epoch 21 | Train Loss 0.7033
Epoch 21 | Test Loss 0.6829 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 139.27it/s]


Epoch 22 | Train Loss 0.7185
Epoch 22 | Test Loss 0.6262 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 125.36it/s]


Epoch 23 | Train Loss 0.8177
Epoch 23 | Test Loss 0.6857 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 128.15it/s]


Epoch 24 | Train Loss 0.7321
Epoch 24 | Test Loss 0.6755 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 149.16it/s]


Epoch 25 | Train Loss 0.6955
Epoch 25 | Test Loss 0.6778 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 138.95it/s]


Epoch 26 | Train Loss 0.7791
Epoch 26 | Test Loss 0.7006 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 134.98it/s]


Epoch 27 | Train Loss 0.6982
Epoch 27 | Test Loss 0.6702 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 129.89it/s]


Epoch 28 | Train Loss 0.7253
Epoch 28 | Test Loss 0.6885 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 128.54it/s]


Epoch 29 | Train Loss 0.6993
Epoch 29 | Test Loss 0.6968 | Bal. Acc. 0.5000


100%|██████████| 30/30 [00:00<00:00, 140.85it/s]


Epoch 30 | Train Loss 0.7167
Epoch 30 | Test Loss 0.6986 | Bal. Acc. 0.5000

Testing with the best model with highest Balanced Accuracy on test set ...

Printing Test Set Evaluation Metrics ....

Bal. Acc. 0.5952

Best scores occur at 5 epoch number
y-prediction: [0 0 0 1 1 0 0 0 0 0]
y-true: [0 1 1 0 1 0 0 0 0 0]


ValueError: Stop here