In [1]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from hmpai.training import split_participants, split_participants_custom
from hmpai.pytorch.training import train_and_test
from hmpai.pytorch.utilities import DEVICE, set_global_seed, load_model
from hmpai.pytorch.generators import MultiXArrayProbaDataset
from hmpai.data import SAT_CLASSES_ACCURACY
from hmpai.pytorch.normalization import *
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from hmpai.pytorch.transforms import *
from hmpai.pytorch.mamba import *
from hmpai.behaviour.sat2 import SAT2_SPLITS
from hmpai.visualization import predict_with_auc, set_seaborn_style, plot_peak_timing
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import mne
from pymer4.models import Lmer
# Root data directory, taken from the DATA_PATH environment variable.
# Fail early with an explicit message: Path(None) would otherwise raise an
# opaque TypeError when the variable is not set.
_data_path_env = os.getenv("DATA_PATH")
if _data_path_env is None:
    raise RuntimeError(
        "Environment variable DATA_PATH is not set; "
        "point it at the root data directory before running this notebook."
    )
DATA_PATH = Path(_data_path_env)

# Class labels per task; index 0 is the "negative" class in both lists.
labels_t1 = ["negative", "t1_1", "t1_2", "t1_3"]
labels_t2 = ["negative", "t2_1", "t2_2", "t2_3"]


In [None]:
def save_peak_timing(model, loader, labels, path: Path, cue_var="condition"):
    """Extract per-trial peak positions for every event class and cache them as CSV.

    If ``path`` already exists, the cached CSV is read and returned and the
    model is not run. Otherwise every batch of ``loader`` is passed through
    ``model``; for each class except the first one, the index along dim 1 with
    the highest softmax probability (prediction) and the highest target value
    (ground truth) is recorded, together with the input values at the predicted
    peak. The collected rows are written to ``path``.

    Args:
        model: Callable applied to ``batch[0]`` after moving it to ``DEVICE``.
            Its output is soft-maxed over dim 2 and arg-maxed over dim 1, i.e.
            it is assumed to be laid out as (batch, time, classes) -- TODO confirm.
        loader: Iterable of batches where ``batch[0]`` is the model input,
            ``batch[1]`` the targets and ``batch[2][0]`` a mapping with
            per-trial info (``cue_var``, ``trial_index``, ``participant`` and
            optionally ``interval``).
        labels: Class labels. ``labels[0]`` is skipped; every following label
            produces ``<label>_pred``, ``<label>_true`` and
            ``<label>_peak_values`` columns.
        path: Location of the CSV cache. Missing parent directories are created.
        cue_var: Key of the info mapping that is stored in the ``condition``
            column.

    Returns:
        pandas.DataFrame with one row per trial. Note that a frame read back
        from the cache holds whatever ``pd.read_csv`` parses (e.g. the
        ``*_peak_values`` lists come back as strings).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    if path.exists():
        # Cached result available: skip inference entirely.
        return pd.read_csv(path)

    frames = []
    torch.cuda.empty_cache()

    with torch.no_grad():
        for batch in loader:
            inputs, true = batch[0], batch[1]
            info = batch[2][0]  # Contains RT

            pred = torch.softmax(model(inputs.to(DEVICE)), dim=2).to("cpu")

            # Stored as "rt_samples"; presumably the unpadded length of each
            # trial in samples -- verify against get_masking_indices.
            lengths = get_masking_indices(inputs)

            # Position (along dim 1) of the maximum for every class but the first.
            # argmax already yields integer indices, usable directly for indexing.
            pred_peak_idx = pred[..., 1:].argmax(dim=1)
            true_peak_idx = true[..., 1:].argmax(dim=1)

            # Input values at the predicted peak positions:
            # [batch_size, classes, channels]
            batch_idx = (
                torch.arange(inputs.shape[0])
                .unsqueeze(1)
                .expand(-1, pred_peak_idx.shape[-1])
            )
            peak_values = inputs[batch_idx, pred_peak_idx, :]

            # Peaks are kept as absolute indices (not divided by `lengths`).
            data = {
                "condition": info[cue_var],
                "epoch": info["trial_index"],
                "participant": info["participant"],
                "interval": info["interval"] if "interval" in info else None,
                "rt_samples": lengths,
            }
            for k, label in enumerate(labels[1:]):
                # Stored as floats (as before) and converted to NumPy explicitly
                # so the DataFrame columns are plain numeric arrays.
                data[f"{label}_pred"] = pred_peak_idx[:, k].float().numpy()
                data[f"{label}_true"] = true_peak_idx[:, k].float().numpy()
                data[f"{label}_peak_values"] = peak_values[:, k].tolist()
            frames.append(pd.DataFrame(data))

    if not frames:
        raise ValueError("loader yielded no batches; nothing to save")

    df = pd.concat(frames, ignore_index=True)
    # Create the target directory only once the results exist, so that a long
    # inference run is not lost to a missing folder.
    path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(path, index=False)
    return df

### Analysis Task 1

In [13]:
# Seed before anything else in this cell runs.
set_global_seed(42)

# Task 1 stage data and the participant split derived from it.
# Alternative split: split_participants(data_paths, train_percentage=60)
# (train_percentage=100 makes test and val 100 as well)
data_paths = [DATA_PATH.joinpath("prp/stage_data_250hz_t1.nc")]
splits = split_participants_custom(data_paths, 0)

# Per-trial info fields carried along with every sample.
# Previous selection: ['event_name', 'participant', 'epochs', 'rt']  # TODO: Might not be present
info_to_keep = ["rt", "participant", "epochs", "condition", "trial_index", "interval"]

# No condition filter; e.g. ('condition', 'equal', 'long') would restrict the data.
subset_cond = None

# Boolean switches.
whole_epoch, add_negative, add_pe = True, True, True

# Sample counts handed to the dataset as skip_samples / cut_samples.
skip_samples, cut_samples = 62, 63

In [14]:
# Normalisation function handed to the dataset.
norm_fn = norm_mad_zscore

# Dataset options gathered in one mapping and unpacked into the constructor.
dataset_kwargs = dict(
    participants_to_keep=splits[0],
    normalization_fn=norm_fn,
    labels=labels_t1,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    add_negative=add_negative,
    skip_samples=skip_samples,
    cut_samples=cut_samples,
    add_pe=add_pe,
)
test_data = MultiXArrayProbaDataset(data_paths, **dataset_kwargs)

In [15]:
# Inference-only loader: keep the dataset order (shuffle=False) so the rows
# produced by the peak extraction come out in a deterministic order that does
# not depend on the global RNG state.
test_loader = DataLoader(
    test_data, batch_size=128, shuffle=False, num_workers=0, pin_memory=True
)

In [16]:
# Load one epoched recording and attach the "biosemi64" montage to it; the
# resulting measurement info is kept in `positions`
# (presumably for channel positions in later plots -- TODO confirm).
# info_path = DATA_PATH / "sat2/preprocessed_500hz/preprocessed_S1_raw.fif"
info_path = DATA_PATH / "prp/epoched/VP1-t1-epo.fif"

epoch = mne.read_epochs(info_path)
epoch.set_montage("biosemi64")
positions = epoch.info

Reading /workspace/data_local/prp/epoched/VP1-t1-epo.fif ...


    Found the data of interest:
        t =    -250.00 ...    2000.00 ms
        0 CTF compensation matrices available
Adding metadata with 4 columns
771 matching events found
No baseline correction applied
0 projection items activated


In [17]:
# Checkpoint for task 1.
chk_path = Path("../models/t1_pe.pt")
checkpoint = load_model(chk_path)

# Architecture settings passed to build_mamba; presumably these have to match
# the configuration the checkpoint was trained with -- TODO confirm.
config = dict(
    n_channels=64,
    n_classes=len(labels_t1),
    n_mamba_layers=5,
    use_pointconv_fe=True,
    spatial_feature_dim=128,
    use_conv=True,
    conv_kernel_sizes=[3, 9],
    conv_in_channels=[128, 128],
    conv_out_channels=[256, 256],
    conv_concat=True,
    use_pos_enc=add_pe,
)

# Build the network, restore its weights, move it to DEVICE and switch to eval mode.
model = build_mamba(config)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE)
model.eval();

In [19]:
# Compute (or reuse the cached) task-1 peak timings.
# The trailing semicolon keeps the cell output empty regardless of what the call returns.
save_peak_timing(
    model,
    test_loader,
    labels_t1,
    Path("files/t1_peaks.csv"),
    cue_var="condition",
);

### Analysis Task 2

In [27]:
# Seed before anything else in this cell runs.
set_global_seed(42)

# Task 2 stage data and the participant split derived from it.
# Alternative split: split_participants(data_paths, train_percentage=60)
# (train_percentage=100 makes test and val 100 as well)
data_paths = [DATA_PATH.joinpath("prp/stage_data_250hz_t2.nc")]
splits = split_participants_custom(data_paths, 0)

# Per-trial info fields carried along with every sample (no 'interval' here).
# Previous selection: ['event_name', 'participant', 'epochs', 'rt']  # TODO: Might not be present
info_to_keep = ["rt", "participant", "epochs", "condition", "trial_index"]

# No condition filter; e.g. ('condition', 'equal', 'long') would restrict the data.
subset_cond = None

# Boolean switches.
whole_epoch, add_negative, add_pe = True, True, True

# Sample counts handed to the dataset as skip_samples / cut_samples.
skip_samples, cut_samples = 62, 63

In [28]:
# Normalisation function handed to the dataset.
norm_fn = norm_mad_zscore

# Dataset options gathered in one mapping and unpacked into the constructor.
dataset_kwargs = dict(
    participants_to_keep=splits[0],
    normalization_fn=norm_fn,
    labels=labels_t2,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    add_negative=add_negative,
    skip_samples=skip_samples,
    cut_samples=cut_samples,
    add_pe=add_pe,
)
test_data = MultiXArrayProbaDataset(data_paths, **dataset_kwargs)

In [29]:
# Inference-only loader: keep the dataset order (shuffle=False) so the rows
# produced by the peak extraction come out in a deterministic order that does
# not depend on the global RNG state.
test_loader = DataLoader(
    test_data, batch_size=128, shuffle=False, num_workers=0, pin_memory=True
)

In [30]:
# Checkpoint for task 2.
chk_path = Path("../models/t2_pe.pt")
checkpoint = load_model(chk_path)

# Architecture settings passed to build_mamba; presumably these have to match
# the configuration the checkpoint was trained with -- TODO confirm.
config = dict(
    n_channels=64,
    n_classes=len(labels_t2),
    n_mamba_layers=5,
    use_pointconv_fe=True,
    spatial_feature_dim=128,
    use_conv=True,
    conv_kernel_sizes=[3, 9],
    conv_in_channels=[128, 128],
    conv_out_channels=[256, 256],
    conv_concat=True,
    use_pos_enc=add_pe,
)

# Build the network, restore its weights, move it to DEVICE and switch to eval mode.
model = build_mamba(config)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE)
model.eval();

In [31]:
# Load one epoched recording and attach the "biosemi64" montage to it; the
# resulting measurement info is kept in `positions`
# (presumably for channel positions in later plots -- TODO confirm).
# NOTE(review): this reads the t1 epochs file inside the Task 2 section; fine if
# only the channel layout is needed, but confirm that this is intentional.
# info_path = DATA_PATH / "sat2/preprocessed_500hz/preprocessed_S1_raw.fif"
info_path = DATA_PATH / "prp/epoched/VP1-t1-epo.fif"

epoch = mne.read_epochs(info_path)
epoch.set_montage("biosemi64")
positions = epoch.info

Reading /workspace/data_local/prp/epoched/VP1-t1-epo.fif ...
    Found the data of interest:
        t =    -250.00 ...    2000.00 ms
        0 CTF compensation matrices available
Adding metadata with 4 columns
771 matching events found
No baseline correction applied
0 projection items activated


In [36]:
# Compute (or reuse the cached) task-2 peak timings.
# The trailing semicolon keeps the cell output empty regardless of what the call returns.
save_peak_timing(
    model,
    test_loader,
    labels_t2,
    Path("files/t2_peaks.csv"),
    cue_var="condition",
);