In [12]:
%load_ext autoreload
%autoreload 2
import netCDF4
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants, split_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test, calculate_global_class_weights
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset, MultiXArrayDataset, MultiXArrayProbaDataset
from hmpai.data import SAT1_STAGES_ACCURACY, SAT_CLASSES_ACCURACY
from hmpai.visualization import plot_confusion_matrix
from hmpai.pytorch.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
# from braindecode.models.eegconformer import EEGConformer
from mne.io import read_info
import os
DATA_PATH = Path(os.getenv("DATA_PATH"))
from hmpai.visualization import plot_predictions_on_epoch, predict_with_auc, show_lmer, set_seaborn_style
import pandas as pd
from hmpai.behaviour.sat2 import read_behavioural_info
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind
from tqdm.notebook import tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


for each combination of dimensions:
	predict on test and validation data
	predict on test and validation data, but shuffle the data of each sample (ensure order is the same)
	for model_pred, model_pred_shuffled in zip(model_preds, model_preds_shuffled)
		corr_model = corr(model_pred, hmp_probas)
		corr_null = corr(model_pred_shuffled, hmp_probas)
		difference = corr_model - corr_null
		save difference

In [2]:
set_global_seed(42)
data_path_1 = DATA_PATH / "sat2/stage_data_proba_250hz_part1.nc"
data_path_2 = DATA_PATH / "sat2/stage_data_proba_250hz_part2.nc"
data_paths = [data_path_1, data_path_2]
# train_percentage=100 makes test and val 100 as well
splits = split_participants(data_paths, train_percentage=60)
labels = SAT_CLASSES_ACCURACY
whole_epoch = True
info_to_keep = ['rt', 'participant', 'epochs']
subset_cond = None # 'speed'|'accuracy'|None
skip_samples = 62

In [3]:
norm_fn = norm_mad_zscore
# Shortcut so they dont have to be recalculated each time
statistics = {
    "global_min": -0.00014557216,
    "global_max": 0.00014740844,
    "global_mean": -2.277374212336032e-18,
    "global_std": 3.3968840765876904e-06,
    "global_median": 3.4879516e-11,
    "mad_score": 3.2237037e-06,
    "class_weights": Tensor([0.0, 0.0, 0.0, 0.0, 0.0]),
}

norm_vars = get_norm_vars_from_global_statistics(statistics, norm_fn)

# train_data = MultiXArrayProbaDataset(
#     data_paths,
#     participants_to_keep=splits[0],
#     normalization_fn=norm_fn,
#     whole_epoch=whole_epoch,
#     labels=labels,
#     info_to_keep=info_to_keep,
#     subset_cond=subset_cond,
#     statistics=statistics,
# )
# class_weights = train_data.statistics["class_weights"]


In [4]:
from scipy.stats import spearmanr

def pearson_corrcoef(t1: torch.Tensor, t2: torch.Tensor):
    t1 = t1.flatten().to(DEVICE)
    t2 = t2.flatten().to(DEVICE)

    mean1 = t1.mean()
    mean2 = t2.mean()

    dev1 = t1 - mean1
    dev2 = t2 - mean2

    covariance = torch.sum(dev1 * dev2)

    std1 = torch.sqrt(torch.sum(dev1 ** 2))
    std2 = torch.sqrt(torch.sum(dev2 ** 2))

    # Case where all values in t1 or t2 are the same (no std defined)
    if std1 == 0 or std2 == 0:
        return torch.nan

    correlation = covariance / (std1 * std2)

    return correlation.item()

def spearman_corrcoef(t1: torch.Tensor, t2: torch.Tensor):
    # Assuming that t2 is hmp
    t1 = t1.flatten().to('cpu')
    t2 = t2.flatten().to('cpu')

    non_zero = t2 != 0

    spearman_corr = spearmanr(t1[non_zero], t2[non_zero])

    return spearman_corr


In [13]:
n_shuffles = 5
window_size = 50

# paths = [f"../models/paper1_m{i}.pt" for i in range(1, 11)]
# model_labels = [
#     "negative_both",
#     "negative_ac",
#     "no_negative_both",
#     "no_negative_ac",
#     "sliding_window_categorical_both",
#     "sliding_window_categorical_ac",
#     "sliding_window_proba_both",
#     "sliding_window_proba_ac",
#     "sliding_window_proba_negative_both",
#     "sliding_window_proba_negative_ac",
# ]
# paths = ["../models/mamba_prestim_jitter.pt", "../models/mamba_prestim_jitter_ac.pt"]
# model_labels = ["prestim-jitter", "prestim-jitter-ac"]
paths = ["../models/mamba_crop_jitter.pt"]
model_labels = ["crop_jitter"]
model_kwargs = {
    "embed_dim": 256,
    "n_channels": 19,
    "n_classes": len(labels),
    "n_layers": 5,
    "global_pool": False,
    "dropout": 0.1,
}
corr_fun = spearman_corrcoef
# TODO: Save all results
# TODO: Check if hmp values are actually 0
for path, label in zip(paths, model_labels):
    if path is None:
        continue

    use_sliding_window = 'crop' in path
    add_negative = 'negative' in path

    testval_data = MultiXArrayProbaDataset(
        data_paths,
        participants_to_keep=splits[1] + splits[2],
        normalization_fn=norm_fn,
        norm_vars=norm_vars,
        whole_epoch=whole_epoch,
        labels=labels,
        info_to_keep=info_to_keep,
        subset_cond=subset_cond,
        add_negative=add_negative,
        skip_samples=skip_samples
    )

    print(f'Now testing model: {label}')
    # Load model
    model = MambaModel(**model_kwargs)
    checkpoint = load_model(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(DEVICE)
    model.pretraining = False
    model.global_pool = False
    model.eval()

    # Create new loader so start is the same for each model
    test_loader = DataLoader(
        testval_data, batch_size=128, shuffle=False, num_workers=0, pin_memory=True
    )
    torch.cuda.empty_cache()
    torch.set_grad_enabled(False)
    corr_diffs = []
    for batch in tqdm(test_loader):
        hmp_probas = batch[1]
        # Predict on data
        data_batch = batch[0]
        lengths = torch.sum(data_batch[:, :, 0] != MASKING_VALUE, dim=1)
        if use_sliding_window:
            all_window_preds = []
            
            for trial_idx in range(data_batch.shape[0]):
                length = lengths[trial_idx].item()
                if length < window_size:
                    valid_data = data_batch[trial_idx, :length, :].unsqueeze(0)
                    window_pred = model(valid_data.to(DEVICE))
                    window_pred = torch.nn.Softmax(dim=2)(window_pred)
                    all_window_preds.append(window_pred)
                else:
                    trial_preds_unfolded = torch.full((length, length, len(labels)), float('nan'))
                    for start in range(0, length - window_size + 1):
                        end = start + window_size
                        window_data = data_batch[trial_idx, start:end, :].unsqueeze(0)
                        window_pred = model(window_data.to(DEVICE))
                        window_pred = torch.nn.Softmax(dim=2)(window_pred)
                        trial_preds_unfolded[start, start:end, :] = window_pred
                    trial_preds_unfolded = torch.nanmean(trial_preds_unfolded, dim=0)
                    trial_preds_unfolded = torch.nn.functional.pad(trial_preds_unfolded, (0, 0, 0, data_batch.shape[1] - trial_preds_unfolded.shape[0]))
                    all_window_preds.append(trial_preds_unfolded)
                    model_pred = torch.stack(all_window_preds, dim=0)
        else:
            model_pred = model(batch[0].to(DEVICE))
            model_pred = torch.nn.Softmax(dim=2)(model_pred)
            pred_shape = model_pred.shape
    
        # Create (batch_size, n_shuffles, seq_len, n_classes) shape for storing shuffled predictions
        shuffled_preds = torch.zeros((pred_shape[0], n_shuffles, pred_shape[1], pred_shape[2]))

        for i in range(n_shuffles):
            # Shuffle data up to rt_idx, predict again (n times?)
            batch_data = batch[0].clone()

            for trial in range(batch_data.shape[0]):
                length = lengths[trial]
                shuffle_section = batch_data[trial, :length]
                shuffle_section = shuffle_section[torch.randperm(length)]
                batch_data[trial, :length] = shuffle_section
            
            shuffled_pred = model(batch_data.to(DEVICE))
            shuffled_pred = torch.nn.Softmax(dim=2)(shuffled_pred)
            shuffled_preds[:, i, ...] = shuffled_pred
        
        # Take mean over n_shuffles dim
        shuffled_preds = torch.mean(shuffled_preds, dim=1)

        model_corrs = []
        shuffled_corrs = []
        for i in range(len(labels)):
            model_corrs.append(corr_fun(model_pred[..., i], hmp_probas[..., i]))
            shuffled_corrs.append(corr_fun(shuffled_preds[..., i], hmp_probas[..., i]))
        corr_diffs.append(torch.Tensor(model_corrs) - torch.Tensor(shuffled_corrs))
    results = torch.stack(corr_diffs).nanmean(dim=0)
    print(f'{label} achieved: {results}')



Now testing model: crop_jitter


  0%|          | 0/141 [00:00<?, ?it/s]

RuntimeError: The expanded size of the tensor (128) must match the existing size (10) at non-singleton dimension 0.  Target sizes: [128, 572, 5].  Tensor sizes: [10, 572, 5]

In [9]:
shuffled_corrs

[0.5652056932449341,
 0.8104925155639648,
 0.7749770283699036,
 0.5627233982086182,
 0.19457444548606873]