In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from datetime import datetime
import time
import warnings
warnings.filterwarnings('ignore')
import optuna

import sklearn
from sklearn import metrics
from sklearn.metrics import confusion_matrix, f1_score
from sklearn.utils import resample

import random, os, json
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

from joblib import Parallel, delayed
import multiprocessing

import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

In [None]:
### RESET Pytorch ###
def reset_pytorch(seed=42):
    """Function to ensure that results from PyTorch models
    are consistent and reproducible across different runs (CPU only)"""
    
    # 1. Set PYTHONHASHSEED environment variable at a fixed value
    os.environ['PYTHONHASHSEED'] = str(seed)
    
    # 2. Set python built-in pseudo-random generator at a fixed value
    random.seed(seed)
    
    # 3. Set numpy pseudo-random generator at a fixed value
    np.random.seed(seed)
    
    # 4. Set torch pseudo-random generator at a fixed value
    torch.manual_seed(seed)
    
    # 5. Ensure deterministic operations in PyTorch 
    torch.backends.cudnn.deterministic = True  # Ensures deterministic results
    torch.backends.cudnn.benchmark = False  # Disables cudnn optimizations for reproducibility
    torch.use_deterministic_algorithms(True)

In [None]:
def mamba_chunk_scan_combined(x, dt_expanded, A_expanded, B_expanded, C_expanded, z_expanded, initial_states, n_heads):
    x_expanded = x.unsqueeze(1).expand(-1, n_heads, x.size(-2), -1)
    
    initial_states_expanded = initial_states if initial_states.shape[0] == x.shape[0] else repeat(initial_states, 'b h d -> b h d', b=x.shape[0])
    
    x_sum = x_expanded.sum(dim=2, keepdim=True)
    initial_states_expanded = initial_states_expanded.unsqueeze(2)
    evolved_states = initial_states_expanded + (x_sum * A_expanded) * dt_expanded
    
    evolved_states = evolved_states + B_expanded + C_expanded * z_expanded
    return evolved_states


def causal_conv1d(x, weight, bias=None):
    padding = weight.size(2) - 1
    x = F.pad(x, (padding, 0), mode='constant', value=0)
    return F.conv1d(x, weight, bias=bias)

def silu(x):
    return x * torch.sigmoid(x)

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super(RMSNorm, self).__init__()
        self.d_model = d_model
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        norm = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return self.weight * x / norm 
        

class Mamba2Simple(nn.Module):
    def __init__(self, input_dim, d_model=128, d_state=128, d_inner=128, n_heads=14, bias=True):
        super(Mamba2Simple, self).__init__()
        self.fc1 = nn.Linear(input_dim, d_model)  # First fully connected layer
        self.conv1d = nn.Conv1d(d_model, d_inner, kernel_size=3, padding=1)  # 1D Convolution layer
        self.norm = RMSNorm(d_inner, eps=1e-5)  # RMS normalization
        self.fc2 = nn.Linear(d_inner, 1)  # Final fully connected layer for binary classification

        self.dt_bias = nn.Parameter(torch.randn(n_heads))  
        self.A_log = nn.Parameter(torch.randn(n_heads))  
        self.D = nn.Parameter(torch.ones(n_heads))
        self.n_heads = n_heads  
        self.d_state = d_state  
        self.d_model = d_model  
        self.initial_states = nn.Parameter(torch.zeros(self.n_heads, self.d_state))  
        
        # Input projection
        self.in_proj = nn.Linear(self.d_model, 2 * self.d_model + 2 * self.n_heads * self.d_state + self.n_heads, bias=bias)
        self.B_proj = nn.Linear(self.d_model, self.n_heads * self.d_state)  # Adjusted projection to match n_heads * d_state
        self.C_proj = nn.Linear(self.d_model, self.n_heads * self.d_state)
        
    def forward(self, x):
        batch_size = x.size(0)
        batch, seqlen, dim = x.shape

        # === Step 1: Create mask for non-padding tokens ===
        mask = (x != 666).all(dim=-1)  # shape: (batch, seqlen)
        x = x * mask.unsqueeze(-1)  # zero out padded inputs

        # === Step 2: First FC layer with activation ===
        x = self.fc1(x)
        x = silu(x)

        # === Step 3: In-projection ===
        zxbcdt = self.in_proj(x)
        z, x_proj, B, C, dt = torch.split(
            zxbcdt,
            [self.d_model, self.d_model, self.n_heads * self.d_state, self.n_heads * self.d_state, self.n_heads],
            dim=-1
        )

        # === Step 4: Process dt ===
        dt = dt.mean(dim=-1)
        dt = dt.mean(dim=0)
        dt = F.softplus(dt + self.dt_bias)
        dt_expanded = dt.reshape(1, self.n_heads, 1)

        # === Step 5: Convolution ===
        x_proj = x_proj.transpose(1, 2)
        x_proj = causal_conv1d(x_proj, self.conv1d.weight, self.conv1d.bias)

        # Mask out invalid positions before max pooling
        mask_conv = mask.unsqueeze(1).expand_as(x_proj)  # (B, d_inner, L)
        x_proj[~mask_conv] = float('-inf')  # prevent max from seeing padded values
        x_conv = x_proj.max(dim=2).values  # (B, d_inner)

        # === Step 6: Mamba ops ===
        A = torch.exp(self.A_log)
        A_expanded = A.view(1, self.n_heads, 1)

        B_proj = self.B_proj(x_conv).view(batch, self.n_heads, self.d_state)
        C_proj = self.C_proj(x_conv).view(batch, self.n_heads, self.d_state)
        z_proj = z.view(batch, self.n_heads, self.d_model)

        B_expanded = B_proj.unsqueeze(2).expand(batch, self.n_heads, self.n_heads, self.d_state)
        C_expanded = C_proj.unsqueeze(2).expand(batch, self.n_heads, self.n_heads, self.d_state)
        z_expanded = z_proj.unsqueeze(2).expand(batch, self.n_heads, self.n_heads, self.d_model)

        states_expanded = repeat(self.initial_states, 'h d -> b h d', b=batch)

        y = mamba_chunk_scan_combined(
            rearrange(x_conv, "b d -> b d 1"),
            dt_expanded=dt_expanded,
            A_expanded=A_expanded,
            B_expanded=B_expanded,
            C_expanded=C_expanded,
            z_expanded=z_expanded,
            initial_states=states_expanded,
            n_heads=self.n_heads
        )

        # === Step 7: Combine and normalize ===
        x_expanded = x_conv.unsqueeze(1).unsqueeze(2)
        x_expanded_for_sum = x_expanded.expand(-1, self.n_heads, self.n_heads, -1)
        y = y.squeeze(-1)
        x_combined = x_expanded_for_sum + y

        y_norm = self.norm(x_combined)
        y_pooled = y_norm.max(dim=1).values
        y_fc = self.fc2(y_pooled).max(dim=1).values

        out = torch.sigmoid(y_fc.squeeze(-1))
        return out.view(-1, 1)


In [None]:
class EarlyStopping:
    def __init__(self, patience, mindelta, restore_best_weights=True, mode="min"):
        self.patience = patience
        self.mindelta = mindelta  
        self.restore_best_weights_flag = restore_best_weights
        self.mode = mode
        self.counter = 0
        self.best_loss = np.inf if mode == "min" else -np.inf
        self.best_model_state = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.mode == "min":
            if val_loss < self.best_loss - self.mindelta:
                self.best_loss = val_loss
                self.counter = 0
                self.best_model_state = model.state_dict()
            else:
                self.counter += 1
        elif self.mode == "max":
            if val_loss > self.best_loss + self.mindelta:
                self.best_loss = val_loss
                self.counter = 0
                self.best_model_state = model.state_dict()
            else:
                self.counter += 1

        if self.counter >= self.patience:
            self.early_stop = True

    def restore_best_weights(self, model):
        if self.best_model_state is not None and self.restore_best_weights_flag:
            model.load_state_dict(self.best_model_state)

def run_network(X_train, X_val, y_train, y_val, hyperparameters, seed):
    input_dim = X_train.shape[2]
    model = Mamba2Simple(
        input_dim=input_dim,
        d_model=hyperparameters['d_model'],
        d_state=hyperparameters['d_state'],
        d_inner=hyperparameters['d_inner'],
        n_heads=hyperparameters['n_heads'],
        bias=True
    )

    loss_fn = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=hyperparameters['lr_scheduler'])

    # Convert data to tensors
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train.values.reshape(-1, 1), dtype=torch.float32)

    X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
    y_val_tensor = torch.tensor(y_val.values.reshape(-1, 1), dtype=torch.float32)

    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)

    train_loader = DataLoader(train_dataset, batch_size=hyperparameters['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=hyperparameters['batch_size'], shuffle=False)

    n_epochs_max = hyperparameters['n_epochs_max']
    patience = hyperparameters['patience']
    verbose = hyperparameters['verbose']
    mindelta = hyperparameters['mindelta']

    earlystopping = EarlyStopping(patience=patience, mindelta=mindelta, restore_best_weights=True, mode="min")

    train_loss_history = []
    val_loss_history = []

    for epoch in range(n_epochs_max):
        model.train()
        epoch_train_loss = 0.0

        for batch_X, batch_y in train_loader:
            optimizer.zero_grad()
            y_pred = model(batch_X)

            # Ensure y_pred shape is (batch_size, 1)
            y_pred = y_pred.view(-1, 1)
            assert y_pred.shape == batch_y.shape, f"Shape mismatch: {y_pred.shape} vs {batch_y.shape}"

            loss = loss_fn(y_pred, batch_y)
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item()

        epoch_train_loss /= len(train_loader)
        train_loss_history.append(epoch_train_loss)

        model.eval()
        epoch_val_loss = 0.0
        with torch.no_grad():
            for batch_X, batch_y in val_loader:
                y_val_pred = model(batch_X)
                y_val_pred = y_val_pred.view(-1, 1)
                val_loss = loss_fn(y_val_pred, batch_y)
                epoch_val_loss += val_loss.item()

        epoch_val_loss /= len(val_loader)
        val_loss_history.append(epoch_val_loss)

        earlystopping(epoch_val_loss, model)

        if earlystopping.early_stop:
            print(f"Early stopping at epoch {epoch + 1}")
            earlystopping.restore_best_weights(model)
            break

        if verbose:
            print(f"Epoch {epoch + 1}/{n_epochs_max} - Train Loss: {epoch_train_loss:.4f} - Val Loss: {epoch_val_loss:.4f}")

    history = {
        'loss': train_loss_history,
        'val_loss': val_loss_history
    }

    return model, history, earlystopping


In [None]:
def objective(trial, hyperparameters, seed, X_train, y_train, X_val, y_val, split, norm, n_time_steps):
    """
    Objective function for hyperparameter optimization with Optuna.    
    """

    print(f"Trial {trial.number} started")
    hyperparameters_copy = hyperparameters.copy()

    hyperparameters_copy["dropout"] = trial.suggest_float('dropout', 0.0, 0.3)
    hyperparameters_copy["lr_scheduler"] = trial.suggest_loguniform('lr_scheduler', 1e-3, 1e-1)
    hyperparameters_copy['activation'] = trial.suggest_categorical("activation", ['LeakyReLU', 'tanh'])
    hyperparameters_copy['patience'] = trial.suggest_int('patience', 1, 50)
    hyperparameters_copy['mindelta'] = trial.suggest_loguniform('mindelta', 1e-10, 1e-3)
    hyperparameters_copy['weight_decay'] = trial.suggest_loguniform('weight_decay', 1e-5, 0)
    common_value = trial.suggest_categorical('d_common', [32, 64])
    hyperparameters_copy['d_model'] = common_value
    hyperparameters_copy['d_state'] = common_value
    hyperparameters_copy['d_inner'] = common_value
    hyperparameters_copy['batch_size'] = hyperparameters['batch_size']
    hyperparameters_copy['n_epochs_max'] = hyperparameters['n_epochs_max']
    
    v_val_loss = []

    model, hist, earlystopping = run_network(
            X_train, X_val,
            y_train,
            y_val,
            hyperparameters_copy,
            seed
    )

    v_val_loss.append(np.min(hist["val_loss"]))

    metric_dev = np.mean(v_val_loss)
    return metric_dev

def optuna_study(hyperparameters, seed, X_train, y_train, X_val, y_val, split, norm, n_time_steps):
    """
    Find the best hyperparameters.
    """
    
    study = optuna.create_study(direction='minimize') 
    study.optimize(lambda trial: objective(trial, hyperparameters, seed, X_train, y_train , X_val, y_val, split, norm, n_time_steps), n_trials=30)
    
    best_params = study.best_params
    best_metric = study.best_value
    
    best_hyperparameters = {
                'dropout': best_params['dropout'],
                'lr_scheduler': best_params['lr_scheduler'],
                'activation': best_params['activation'],
                'batch_size': hyperparameters['batch_size'],
                'n_epochs_max': hyperparameters['n_epochs_max'],
                'patience': best_params['patience'],
                'mindelta': best_params['mindelta'],
                'd_common': best_params['d_common'],
        
                'weight_decay': best_params['weight_decay']
            }
    print(f"Best Hyperparameters: {best_params}")
    print(f"Best Validation Metric: {best_metric}")
    
    return best_hyperparameters

In [None]:
seeds = [42, 76, 124]
input_dim = 80  
n_time_steps = 14
batch_size = 32
n_epochs_max = 1000

norm = "robustNorm" 

# Mamba 
d_model = 64 
d_state = 64  
d_inner = 64  
n_heads = 14  
patience = 15  # Epochs for Early Stopping
monitor = "val_loss"  

hyperparameters = {
    "n_time_steps": n_time_steps,
    "batch_size": batch_size,
    "n_epochs_max": n_epochs_max,
    "patience": patience,
    "monitor": monitor,
    "mindelta": 0,  # Minimum delta for Early Stopping
    "dropout": 0.0, 
    "verbose": 0, 
    "d_model": d_model,
    "d_state": d_state,
    "d_inner": d_inner,
    "n_heads": n_heads
}

In [None]:
run_model = True
debug = True
tab='\t'
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    roc_auc_score,
    average_precision_score,
)

if run_model:
    import time

    loss_train = []
    loss_dev = []
    v_models = []
    training_times = []
    optimization_times = []
    inference_times = []

    v_accuracy_test = []
    v_specificity = []
    v_precision = []
    v_recall = []
    v_f1score = []
    v_roc = []
    v_aucpr = []

    bestHyperparameters_bySplit = {}
    y_pred_by_split = {}
    results = ""

    for i in [1, 2, 3]:
        print("====================>", i)


        X_train = np.load(f"../../DATA/s{i}/X_train_tensor_robustNorm.npy")
        X_val = np.load(f"../../DATA/s{i}/X_val_tensor_robustNorm.npy")

        y_train = pd.read_csv(f"../../DATA/s{i}/y_train_robustNorm.csv")[['individualMRGerm_stac']]
        y_train = y_train.iloc[0:y_train.shape[0]:hyperparameters["n_time_steps"]].reset_index(drop=True)

        y_val = pd.read_csv(f"../../DATA/s{i}/y_val_robustNorm.csv")[['individualMRGerm_stac']]
        y_val = y_val.iloc[0:y_val.shape[0]:hyperparameters["n_time_steps"]].reset_index(drop=True)

        X_test = np.load(f"../../DATA/s{i}/X_test_tensor_robustNorm.npy")
        
        y_test = pd.read_csv(f"../../DATA/s{i}/y_test_robustNorm.csv")[['individualMRGerm_stac']]
        y_test = y_test.iloc[0:y_test.shape[0]:hyperparameters["n_time_steps"]].reset_index(drop=True)
        y_test = y_test.values.reshape(-1, 1)

        start_opt = time.time()
        bestHyperparameters = optuna_study(
            hyperparameters,
            seeds[i-1],
            X_train, y_train,  
            X_val, y_val,
            f"s{i}",
            norm,
            n_time_steps
        )
        end_opt = time.time()
        optimization_times.append(end_opt - start_opt)

        bestHyperparameters_bySplit[str(i)] = bestHyperparameters


        split_directory = f'./Results_Mamba_optuna/split_{i}'
        os.makedirs(split_directory, exist_ok=True)
        with open(os.path.join(split_directory, f"bestHyperparameters_split_{i}.pkl"), 'wb') as f:
            pickle.dump(bestHyperparameters, f)

        hyperparameters.update({
            "dropout": bestHyperparameters["dropout"],
            "lr_scheduler": bestHyperparameters["lr_scheduler"],
            "batch_size": bestHyperparameters["batch_size"],
            "n_epochs_max": bestHyperparameters["n_epochs_max"],
            "patience": bestHyperparameters["patience"],
            'd_model': bestHyperparameters['d_common'],
            'd_state': bestHyperparameters['d_common'],
            'd_inner': bestHyperparameters['d_common'],
        })


        reset_pytorch()
        print(hyperparameters)


        X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
        y_test_tensor = torch.tensor(y_test, dtype=torch.float32)

        t0_train = time.time()

        model, history, earlystopping = run_network(
            X_train, X_val,
            y_train, y_val,
            hyperparameters,
            seeds[i - 1]
        )

        t1_train = time.time()
        training_times.append(t1_train - t0_train)

        v_models.append(model)
        loss_train.append(history['loss'])
        loss_dev.append(history['val_loss'])

        # ----- INFERENCE -----
        t0_inf = time.time()

        y_pred = model(X_test_tensor).detach().cpu().numpy()

        t1_inf = time.time()
        inference_times.append(t1_inf - t0_inf)

        y_pred_by_split[str(i)] = y_pred

        with open(os.path.join(split_directory, f"y_pred_split_{i}.pkl"), 'wb') as f:
            pickle.dump(y_pred, f)

        model_filename = os.path.join(split_directory, f"model_split_{i}.h5")
        torch.save(model, model_filename)

        # ----- METRICS -----
        accuracy_test = accuracy_score(y_test, np.round(y_pred))
        tn, fp, fn, tp = confusion_matrix(y_test, np.round(y_pred)).ravel()
        roc = roc_auc_score(y_test, y_pred)
        aucpr = average_precision_score(y_test, y_pred)

        v_accuracy_test.append(accuracy_test)
        v_specificity.append(tn / (tn + fp))
        v_precision.append(tp / (tp + fp))
        v_recall.append(tp / (tp + fn))
        v_f1score.append((2 * v_recall[i - 1] * v_precision[i - 1]) / (v_recall[i - 1] + v_precision[i - 1]))
        v_roc.append(roc)
        v_aucpr.append(aucpr)

        if debug:
            results += tab + f"Split {i} - Timing (s):\n"
            results += tab + f"{tab}Optimization: {optimization_times[-1]:.2f}\n"
            results += tab + f"{tab}Training: {training_times[-1]:.2f}\n"
            results += tab + f"{tab}Inference: {inference_times[-1]:.2f}\n"
            results += tab + f"\tTP: {tp} | FP: {fp} | TN: {tn} | FN: {fn}\n"
            results += tab + f"\tAccuracy: {accuracy_test:.4f} | ROC-AUC: {roc:.4f} | AUC-PR: {aucpr:.4f}\n"

    # SAVE
    directory = './Results_Mamba_optuna'
    os.makedirs(directory, exist_ok=True)
    summary_df = pd.DataFrame({
        "Split": [i for i in range(1, len(v_accuracy_test) + 1)],
        "OptimizationTime": optimization_times,
        "TrainingTime": training_times,
        "InferenceTime": inference_times,
        "Accuracy": v_accuracy_test,
        "Specificity": v_specificity,
        "Precision": v_precision,
        "Recall": v_recall,
        "F1Score": v_f1score,
        "ROC_AUC": v_roc,
        "AUC_PR": v_aucpr
    })

    summary_path = os.path.join(directory, "summary_metrics.csv")
    summary_df.to_csv(summary_path, index=False)

    if debug:
        print("\n--- SUMMARY ---")
        print(summary_df)


In [2]:
directory = './Results_Mamba_optuna'
summary_path = os.path.join(directory, "summary_metrics.csv")
summary_df = pd.read_csv(summary_path)


def calculateKPI(parameter):
    """
    This function calculate the mean and deviation of a set of values of
    a given performance indicator.
    
    Returns: Mean and std (float)
    """
    mean = round(np.mean(parameter)*100, 2)
    deviation = round(np.sqrt(np.sum(np.power(parameter - np.mean(parameter), 2) / len(parameter)))*100, 2)
    return mean, deviation

def format_metric_line(metric_name, mean_value, deviation_value):
    return f"{metric_name}: {mean_value:.2f} +- {deviation_value:.2f}\n"

mean_test, deviation_test = calculateKPI(summary_df["Accuracy"])
mean_specificity, deviation_specificity = calculateKPI(summary_df["Specificity"])
mean_recall, deviation_recall = calculateKPI(summary_df["Recall"])
mean_f1, deviation_f1 = calculateKPI(summary_df["F1Score"])
mean_precision, deviation_precision = calculateKPI(summary_df["Precision"])
mean_roc, deviation_roc = calculateKPI(summary_df["ROC_AUC"])
mean_aucpr, deviation_aucpr = calculateKPI(summary_df["AUC_PR"])  

results = ""
results += format_metric_line("Test Accuracy", mean_test, deviation_test)
results += format_metric_line("Specificity", mean_specificity, deviation_specificity)
results += format_metric_line("Sensitivity", mean_recall, deviation_recall)
results += format_metric_line("Precision", mean_precision, deviation_precision)
results += format_metric_line("F1-score", mean_f1, deviation_f1)
results += format_metric_line("ROC-AUC", mean_roc, deviation_roc)
results += format_metric_line("AUC-PR", mean_aucpr, deviation_aucpr) 

final_results = (
    f"Sensitivity: {mean_recall:.2f} +- {deviation_recall:.2f}\n"
    f"Specificity: {mean_specificity:.2f} +- {deviation_specificity:.2f}\n"
    f"Precision: {mean_precision:.2f} +- {deviation_precision:.2f}\n"
    f"F1-score: {mean_f1:.2f} +- {deviation_f1:.2f}\n"
    f"ROC-AUC: {mean_roc:.2f} +- {deviation_roc:.2f}\n"
    f"AUC-PR: {mean_aucpr:.2f} +- {deviation_aucpr:.2f}\n" 
    f"Test Accuracy: {mean_test:.2f} +- {deviation_test:.2f}\n"
)

print(final_results)

Sensitivity: 64.15 +- 4.08
Specificity: 87.77 +- 0.16
Precision: 48.30 +- 1.69
F1-score: 55.09 +- 2.59
ROC-AUC: 79.40 +- 3.58
AUC-PR: 49.09 +- 3.61
Test Accuracy: 84.19 +- 0.67

