In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.pytorch.utilities import DEVICE, set_global_seed, load_model
from hmpai.pytorch.generators import MultiXArrayProbaDataset
from hmpai.data import SAT_CLASSES_ACCURACY
from hmpai.pytorch.normalization import *
from torch.utils.data import DataLoader
import os
DATA_PATH = Path(os.getenv("DATA_PATH"))
from hmpai.visualization import *
from hmpai.behaviour.sat2 import read_behavioural_info, SAT2_SPLITS, merge_data
from hmpai.pytorch.mamba import *
from matplotlib.lines import Line2D

In [None]:
# Set base variables. Seed first so that everything constructed below is reproducible.
set_global_seed(42)
data_paths = [DATA_PATH / "sat2/stage_data_250hz.nc"]

splits, labels = SAT2_SPLITS, SAT_CLASSES_ACCURACY
info_to_keep = ['event_name', 'participant', 'epochs', 'rt', 'condition']
whole_epoch = True
subset_cond = None  # presumably: no condition filtering — confirm in MultiXArrayProbaDataset

# Sample counts; at 250 Hz (per the data file name) 62 / 63 samples ≈ 0.25 s each
skip_samples = 62
cut_samples = 63

add_negative, add_pe = True, True

In [None]:
# Create datasets. Normalisation variables are derived from the training
# participants only, then re-used for the held-out participants.
norm_fn = norm_mad_zscore

# Options shared by the train and held-out datasets
shared_dataset_kwargs = dict(
    normalization_fn=norm_fn,
    whole_epoch=whole_epoch,
    labels=labels,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    skip_samples=skip_samples,
    cut_samples=cut_samples,
    add_negative=add_negative,
    add_pe=add_pe,
)

train_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[0],
    **shared_dataset_kwargs,
)
norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
class_weights = train_data.statistics["class_weights"]

# splits[1] + splits[2]: presumably the validation and test participant lists
test_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[1] + splits[2],
    norm_vars=norm_vars,
    **shared_dataset_kwargs,
)

In [None]:
# Load in behavioural data and build the evaluation loader over test_data
behaviour_csv = DATA_PATH / "sat2/behavioural/df_full.csv"
behaviour_sat2 = read_behavioural_info(behaviour_csv)

# NOTE(review): shuffle=True is unusual for an evaluation loader; confirm that
# nothing downstream depends on sample order.
test_loader_sat2 = DataLoader(
    test_data,
    batch_size=128,
    shuffle=True,
    num_workers=0,  # load batches in the main process
    pin_memory=True,
)

In [None]:
# Architecture hyper-parameters: these must match the checkpoint, because
# load_state_dict is strict by default and rejects missing/unexpected keys.
config = {
    "n_channels": 64,  # presumably the number of EEG channels — confirm against the data
    "n_classes": len(labels),
    "n_mamba_layers": 5,
    "use_pointconv_fe": True,
    "spatial_feature_dim": 128,
    "use_conv": True,
    "conv_kernel_sizes": [3, 9],
    "conv_in_channels": [128, 128],
    "conv_out_channels": [256, 256],
    "conv_concat": True,
    "use_pos_enc": True,
}
chk_path = Path("../models/final.pt")
checkpoint = load_model(chk_path)

model = build_mamba(config)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE).eval()  # inference mode for all following cells

#### Combined visualization

In [None]:
# Pre-predict test/val set samples once and cache the merged result to CSV, so the
# plotting cell can read it back instead of re-running the model.
data = predict_with_auc(model, test_loader_sat2, info_to_keep, labels)
data = merge_data(data, behaviour_sat2)
merged_csv = Path("files/visu_merged.csv")
# DataFrame.to_csv does not create missing directories; make sure files/ exists.
merged_csv.parent.mkdir(parents=True, exist_ok=True)
data.to_csv(merged_csv, index=False)

In [None]:
def _load_merged_data(model, merged_path=None):
    """Return model predictions merged with the behavioural data.

    If ``merged_path`` is ``None``, predictions are recomputed on
    ``test_loader_sat2`` (slow). Otherwise the cached CSV is read. A path that
    does not exist as given is looked up inside ``files/``, which is where the
    pre-predict cell writes ``visu_merged.csv``.
    """
    if merged_path is None:
        predictions = predict_with_auc(model, test_loader_sat2, info_to_keep, labels)
        return merge_data(predictions, behaviour_sat2)
    merged_file = Path(merged_path)
    if not merged_file.exists():
        merged_file = Path("files") / merged_file
    return pd.read_csv(merged_file)


def _stage_legend(stage_labels):
    """Build (handles, labels) for the row-a legend.

    The legend has one coloured line per stage, a blank spacer entry, and then
    gray line styles labelled 'S4' (solid) and 'HMP' (dashed).
    """
    # zip stops at the shorter input, so no explicit palette slicing is needed
    stage_colors = dict(zip(stage_labels, sns.color_palette()))
    stage_handles = [
        Line2D([0], [0], color=color, linewidth=1, label=label)
        for label, color in stage_colors.items()
    ]
    spacer = Line2D([], [], linestyle='', alpha=0)
    # NOTE(review): the model is built with build_mamba; confirm that 'S4' is
    # still the intended legend label for its predictions.
    model_handles = [
        Line2D([0], [0], color='gray', linestyle='-', linewidth=1, label='S4'),
        Line2D([0], [0], color='gray', linestyle='--', linewidth=1, label='HMP'),
    ]
    handles = stage_handles + [spacer] + model_handles
    legend_labels = list(stage_colors.keys()) + [''] + ['S4', 'HMP']
    return handles, legend_labels


def plot_all(model, merged_path=None, peak_path=None,
             trial_indices=(11696, 16002), save_path="../img/results.svg"):
    """Draw and save the combined 3x2 results figure.

    Rows:
        a) per-sample stage probabilities of two example trials from
           ``test_data`` (columns titled "Accuracy" and "Speed");
        b) predicted vs. true peak timing, drawn by ``plot_peak_timing``;
        c) tertile split on ``'confirmation_ratio'`` for the 'accuracy' and
           'speed' conditions, with the EMG tertile split drawn on the same axes.

    Relies on the notebook globals ``test_data``, ``test_loader_sat2``,
    ``info_to_keep``, ``labels`` and ``behaviour_sat2``.

    Args:
        model: Trained model used for predictions.
        merged_path: Cached CSV of merged predictions + behaviour. If it does not
            exist as given, ``files/<merged_path>`` is used. ``None`` recomputes
            the predictions (slow).
        peak_path: Forwarded as ``path`` to ``plot_peak_timing``. ``None`` makes
            that function recompute.
        trial_indices: ``(accuracy_idx, speed_idx)`` indices into ``test_data``
            for the row-a example trials.
        save_path: Output file for the figure. Parent directories are created.

    Returns:
        The matplotlib Figure.
    """
    data = _load_merged_data(model, merged_path)

    # Apply the style before creating the axes. Matplotlib reads style rcParams
    # (facecolor, grid, spines) when an Axes is created. This assumes
    # set_seaborn_style works through rcParams.
    set_seaborn_style()
    fig, axs = plt.subplots(3, 2, dpi=300, figsize=(7.09, 7.09), sharey='row')

    # Row a) single example trials
    accuracy_idx, speed_idx = trial_indices
    plot_single_epoch(test_data[accuracy_idx], labels, model, axs[0, 0])
    plot_single_epoch(test_data[speed_idx], labels, model, axs[0, 1])
    axs[0, 0].set_xlabel("Time (samples)")
    axs[0, 0].set_ylabel("Probability")
    axs[0, 0].set_title("Accuracy")
    axs[0, 1].set_xlabel("Time (samples)")
    axs[0, 1].set_title("Speed")
    axs[0, 1].set_xlim(axs[0, 0].get_xlim())
    # labels[0] is left out of the legend (presumably the non-stage class — confirm)
    legend_handles, legend_labels = _stage_legend(labels[1:])
    axs[0, 0].legend(legend_handles, legend_labels, loc='best')

    # Row b) predicted vs. true peak timing
    plot_peak_timing(model, test_loader_sat2, labels, axs[1, 0], axs[1, 1], path=peak_path)
    axs[1, 0].set_ylabel("True peak timing (normalized)")
    axs[1, 0].set_xlabel("Predicted peak timing (normalized)")
    axs[1, 1].set_ylabel("")
    axs[1, 1].set_xlabel("Predicted peak timing (normalized)")

    # Row c) confirmation-ratio tertiles: performance, then EMG on the same axes
    row_c = [axs[2, 0], axs[2, 1]]
    plot_tertile_split_single(
        data, 'confirmation_ratio', ['accuracy', 'speed'],
        calc_tertile_over_condition=False, normalize='time', axes=row_c,
    )
    axs[2, 0].set_ylabel("Probability")
    for ax in row_c:
        ax.set_xlabel("Average confirmation probability")
        ax.set_ylim((0.0, 1.0))
    plot_emg_tertile_split(data, row_c, ['accuracy', 'speed'])
    # Several artists can share a label; keep one handle per label (the last one)
    handles, leg_labels = axs[2, 0].get_legend_handles_labels()
    unique = dict(zip(leg_labels, handles))
    axs[2, 0].legend(unique.values(), unique.keys(), title="", loc="center left")

    # Panel letters, in figure coordinates
    for letter, y in (("a)", 0.96), ("b)", 0.64), ("c)", 0.32)):
        fig.text(0.02, y, letter, weight='bold')
    fig.tight_layout()

    # savefig does not create missing directories
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path)
    return fig

In [None]:
# merged_path / peak_path point to cached results from earlier runs. Drop either
# argument (it defaults to None) if that file has not been created yet; the
# figure will then take longer because the results are recomputed.
fig = plot_all(
    model,
    merged_path="visu_merged.csv",
    peak_path="visu_peak.csv",
)