In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_participants
from hmpai.pytorch.utilities import DEVICE, set_global_seed, load_model
from hmpai.pytorch.generators import MultiXArrayProbaDataset
from hmpai.data import SAT_CLASSES_ACCURACY
from hmpai.pytorch.normalization import *
from torch.utils.data import DataLoader
import os
# Root directory of the datasets, taken from the DATA_PATH environment variable.
# Fail early with an actionable message when it is unset or empty: Path(None)
# would raise an opaque TypeError, and Path("") would silently resolve to the
# current working directory.
_data_path_env = os.getenv("DATA_PATH")
if not _data_path_env:
    raise RuntimeError(
        "Environment variable DATA_PATH is not set; "
        "point it at the directory that contains the 'sat1' data folder."
    )
DATA_PATH = Path(_data_path_env)
from hmpai.visualization import *
from hmpai.pytorch.mamba import *
from matplotlib.lines import Line2D

In [None]:
# Set base variables
# Seed first, so that anything random below (presumably including the
# participant split — confirm in hmpai) is reproducible on Restart & Run All.
set_global_seed(42)

data_paths = [DATA_PATH / "sat1/stage_data_250hz.nc"]
# The first three splits are concatenated straight away, so this notebook uses
# their union rather than a single partition.
# NOTE(review): presumably these are train/val/test and together cover every
# participant, in which case train_percentage only affects ordering — confirm.
splits = split_participants(data_paths, train_percentage=50)
all_participants = splits[0] + splits[1] + splits[2]

# Class labels; the first entry is skipped when the legend is built in plot_all.
labels = SAT_CLASSES_ACCURACY

# Options below are forwarded unchanged to MultiXArrayProbaDataset.
whole_epoch = True
# Per-trial metadata requested from the dataset and passed to predict_with_auc;
# 'cue', 'movement' and 'resp' are read again when plotting.
info_to_keep = ['participant', 'epochs', 'RT', 'cue', 'movement', 'resp']
subset_cond = None
# NOTE(review): 62/63 samples is roughly 0.25 s at the 250 Hz suggested by the
# file name — confirm what exactly is skipped and cut.
skip_samples = 62
cut_samples = 63
add_negative = True
add_pe = True

In [None]:
# Build the dataset from the configuration defined above.
# Note: despite its name, `test_data` is built from `all_participants`
# (the concatenation of all three splits), not from a held-out split only.
norm_fn = norm_mad_zscore

dataset_options = {
    "participants_to_keep": all_participants,
    "normalization_fn": norm_fn,
    "whole_epoch": whole_epoch,
    "labels": labels,
    "info_to_keep": info_to_keep,
    "subset_cond": subset_cond,
    "skip_samples": skip_samples,
    "cut_samples": cut_samples,
    "add_negative": add_negative,
    "add_pe": add_pe,
}
test_data = MultiXArrayProbaDataset(data_paths, **dataset_options)

In [None]:
# Batched access to the dataset for inference and plotting. Loading happens in
# the main process (num_workers=0); batches are shuffled and use pinned memory.
test_loader_sat1 = DataLoader(
    dataset=test_data,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

In [None]:
# Restore the trained network: read the checkpoint, rebuild the architecture
# from the configuration it is expected to match, then load the saved weights.
chk_path = Path("../models/boehm.pt")
checkpoint = load_model(chk_path)

# NOTE(review): these hyperparameters must match the ones used to train the
# checkpoint, otherwise load_state_dict fails on mismatched keys or shapes.
config = dict(
    n_channels=30,
    n_classes=len(labels),
    n_mamba_layers=5,
    use_pointconv_fe=True,
    spatial_feature_dim=128,
    use_conv=True,
    conv_kernel_sizes=[3, 9],
    conv_in_channels=[128, 128],
    conv_out_channels=[256, 256],
    conv_concat=True,
    use_pos_enc=True,
)

model = build_mamba(config)
model.load_state_dict(checkpoint["model_state_dict"])
# Move to the target device and switch to evaluation mode; both calls return
# the module itself, so they can be chained.
model = model.to(DEVICE).eval()

In [None]:
def _response_correctness(data):
    """Return per-trial response correctness as an array of 1.0 / 0.0 / NaN.

    A trial counts as correct when the last four characters of its
    'movement' and 'resp' entries are equal. Trials where either entry is the
    empty string cannot be compared and get NaN. NumPy promotes the boolean
    comparison and NaN to a common float dtype, hence 1.0 / 0.0 rather than
    True / False.
    """
    has_movement = data['movement'] != ''
    has_resp = data['resp'] != ''

    # Only take the suffix where there is a value; empty entries become NaN.
    movement_last4 = np.where(has_movement, data['movement'].str[-4:], np.nan)
    resp_last4 = np.where(has_resp, data['resp'].str[-4:], np.nan)

    return np.where(has_movement & has_resp, movement_last4 == resp_last4, np.nan)


def plot_all(model, accuracy_idx=1109, speed_idx=1000, out_path="../img/results_sat1.svg"):
    """Draw, save and show the 3x2 results figure for ``model``.

    Rows of the figure:
      a) single-epoch predictions for two example epochs,
      b) true vs. predicted peak timing,
      c) probability of a correct response against average confirmation
         probability (tertile split), one column per cue ('AC', 'SP').

    Relies on the notebook-level names ``test_loader_sat1``, ``test_data``,
    ``info_to_keep`` and ``labels`` defined in earlier cells.

    Args:
        model: Trained model passed to the prediction and plotting helpers.
        accuracy_idx: Index into ``test_data`` of the epoch drawn in the panel
            titled "Accuracy".
        speed_idx: Index into ``test_data`` of the epoch drawn in the panel
            titled "Speed".
        out_path: Destination of the saved figure; missing parent directories
            are created.

    Returns:
        None. The figure is written to ``out_path`` and displayed.
    """
    data = predict_with_auc(model, test_loader_sat1, info_to_keep, labels)

    # Apply the style BEFORE creating the figure so the axes are styled on the
    # first run of a fresh kernel, not only on a re-run.
    # NOTE(review): assumes set_seaborn_style() sets matplotlib rcParams — confirm.
    set_seaborn_style()
    fig, axs = plt.subplots(3, 2, dpi=300, figsize=(7.09, 7.09), sharey='row')

    # --- Row a: single-trial examples ---------------------------------------
    plot_single_epoch(test_data[accuracy_idx], labels, model, axs[0, 0])
    plot_single_epoch(test_data[speed_idx], labels, model, axs[0, 1])
    axs[0, 0].set_xlabel("Time (samples)")
    axs[0, 0].set_ylabel("Probability")
    axs[0, 0].set_title("Accuracy")
    axs[0, 1].set_xlabel("Time (samples)")
    axs[0, 1].set_title("Speed")
    # Same time axis in both example panels.
    axs[0, 1].set_xlim(axs[0, 0].get_xlim())

    # Custom legend: one colour per stage (the first label is skipped), then
    # an invisible spacer entry, then the two line styles.
    stage_colors = dict(zip(labels[1:], sns.color_palette()))
    stage_handles = [
        Line2D([0], [0], color=color, linewidth=1, label=label)
        for label, color in stage_colors.items()
    ]
    subtitle = Line2D([], [], linestyle='', alpha=0)
    model_handles = [
        Line2D([0], [0], color='gray', linestyle='-', linewidth=1, label='S4'),
        Line2D([0], [0], color='gray', linestyle='--', linewidth=1, label='HMP'),
    ]
    combined_handles = stage_handles + [subtitle] + model_handles
    combined_labels = list(stage_colors) + [''] + ['S4', 'HMP']
    axs[0, 0].legend(combined_handles, combined_labels, loc='best')

    # --- Row b: peak timing --------------------------------------------------
    plot_peak_timing(model, test_loader_sat1, labels, axs[1, 0], axs[1, 1], cue_var='cue', sample=False)
    axs[1, 0].set_ylabel("True peak timing (normalized)")
    axs[1, 0].set_xlabel("Predicted peak timing (normalized)")
    axs[1, 1].set_ylabel("")
    axs[1, 1].set_xlabel("Predicted peak timing (normalized)")

    # --- Row c: confirmation-probability tertiles vs. performance ------------
    data['response'] = _response_correctness(data)
    plot_tertile_split_single(
        data,
        'confirmation_ratio',
        ['AC', 'SP'],
        calc_tertile_over_condition=False,
        normalize='time',
        axes=[axs[2, 0], axs[2, 1]],
        cue_var='cue',
    )
    axs[2, 0].set_ylabel("Probability of correct response")
    axs[2, 0].set_xlabel("Average confirmation probability")
    axs[2, 1].set_xlabel("Average confirmation probability")
    axs[2, 0].set_ylim((0.0, 1.0))
    axs[2, 1].set_ylim((0.0, 1.0))

    # Panel letters, positioned in figure coordinates next to each row.
    fig.text(0.015, 0.32, "c)", weight='bold')
    fig.text(0.015, 0.64, "b)", weight='bold')
    fig.text(0.015, 0.96, "a)", weight='bold')
    fig.tight_layout()

    # Create the output directory up front so a missing folder does not throw
    # away the (expensive) figure at the very last step.
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
    plt.show()

In [None]:
# Build, save and display the results figure for the model loaded above.
plot_all(model)