In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.pytorch.utilities import DEVICE, set_global_seed, load_model
from hmpai.pytorch.generators import MultiXArrayProbaDataset
from hmpai.data import SAT_CLASSES_ACCURACY
from hmpai.pytorch.normalization import *
from hmpai.utilities import calc_ratio, format_stats_latex
from torch.utils.data import DataLoader
import os
# Root directory of the datasets, read from the DATA_PATH environment variable.
# os.getenv returns None when the variable is unset, in which case Path(None)
# raises a TypeError, so DATA_PATH must be exported before starting the kernel.
DATA_PATH = Path(os.getenv("DATA_PATH"))
from hmpai.visualization import *
from hmpai.behaviour.sat2 import read_behavioural_info, SAT2_SPLITS, merge_data
from hmpai.pytorch.mamba import *
from pymer4.models import Lmer


In [None]:
# Fix the global seed before anything else in the notebook runs.
set_global_seed(42)

# Input file plus the participant splits and label set used by every cell below.
data_paths = [DATA_PATH / "sat2/stage_data_250hz.nc"]
splits, labels = SAT2_SPLITS, SAT_CLASSES_ACCURACY

# Options forwarded unchanged to both MultiXArrayProbaDataset instances.
info_to_keep = ['event_name', 'participant', 'epochs', 'rt', 'condition']
whole_epoch = add_negative = add_pe = True
subset_cond = None
# Each is roughly 0.25 s (0.25 s is 62.5 samples if the data is sampled at
# 250 Hz, as the file name suggests).
skip_samples, cut_samples = 62, 63

In [None]:
# Build the training and evaluation datasets.
norm_fn = norm_mad_zscore

# Keyword arguments that are identical for both datasets.
shared_dataset_kwargs = dict(
    normalization_fn=norm_fn,
    whole_epoch=whole_epoch,
    labels=labels,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    skip_samples=skip_samples,
    cut_samples=cut_samples,
    add_negative=add_negative,
    add_pe=add_pe,
)

# The first split is built without norm_vars; its statistics provide the
# normalisation variables and class weights derived below.
train_data = MultiXArrayProbaDataset(
    data_paths, participants_to_keep=splits[0], **shared_dataset_kwargs
)
norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
class_weights = train_data.statistics["class_weights"]

# The second and third splits are combined into one evaluation set that is
# normalised with the variables computed from the first split.
test_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[1] + splits[2],
    norm_vars=norm_vars,
    **shared_dataset_kwargs,
)

In [None]:
# Behavioural records that are later merged with the model predictions.
behaviour_csv = DATA_PATH / "sat2/behavioural/df_full.csv"
behaviour_sat2 = read_behavioural_info(behaviour_csv)

# Batches of the evaluation dataset for the prediction cell.
# NOTE(review): shuffle=True is not needed for inference; confirm nothing
# depends on the row order of the exported CSV before switching it off.
loader_options = dict(batch_size=128, shuffle=True, num_workers=0, pin_memory=True)
test_loader_sat2 = DataLoader(test_data, **loader_options)

In [None]:
# Architecture settings handed to build_mamba; load_state_dict below requires
# the resulting module to match the parameters stored in the checkpoint.
config = {
    "n_channels": 64,
    "n_classes": len(labels),
    "n_mamba_layers": 5,
    "use_pointconv_fe": True,
    "spatial_feature_dim": 128,
    "use_conv": True,
    "conv_kernel_sizes": [3, 9],
    "conv_in_channels": [128, 128],
    "conv_out_channels": [256, 256],
    "conv_concat": True,
    "use_pos_enc": True,
}

# Read the checkpoint and restore its weights into a freshly built network.
chk_path = Path("../models/final.pt")
checkpoint = load_model(chk_path)
trained_weights = checkpoint["model_state_dict"]

model = build_mamba(config)
model.load_state_dict(trained_weights)
model = model.to(DEVICE)
# Switch to evaluation mode; the trailing semicolon hides the cell output.
model.eval();

In [None]:
# Generate the merged prediction/behaviour table only when it is not cached on
# disk yet, so Restart & Run All does not repeat the expensive prediction pass.
# The following cell loads the table from this CSV; delete the file to force
# the predictions to be regenerated (e.g. after changing the checkpoint).
merged_csv = Path("files/visu_merged.csv")
if not merged_csv.exists():
    # to_csv does not create missing directories, so make sure it exists.
    merged_csv.parent.mkdir(parents=True, exist_ok=True)
    data = predict_with_auc(model, test_loader_sat2, info_to_keep, labels)
    data = merge_data(data, behaviour_sat2)
    data.to_csv(merged_csv, index=False)

In [None]:
# Load the merged prediction/behaviour table that was exported to disk.
data = pd.read_csv(filepath_or_buffer="files/visu_merged.csv")

### Tertiles

In [None]:
def get_tertiles(data: pd.DataFrame, column: str, conditions: list[str], rt_col: str='rt_x', cue_var='SAT'):
    """Print summary statistics of the per-participant tertile boundaries.

    A copy of ``data`` is passed to ``calc_ratio`` and the column
    ``<column>_ratio`` is then read from the result. For every entry of
    ``conditions`` the rows where ``cue_var`` equals that entry are grouped by
    ``participant`` and the 1/3 and 2/3 quantiles of the ratio are computed
    per participant. The mean and standard deviation of each boundary across
    participants are printed.

    Args:
        data: Trial-level frame containing ``participant``, ``cue_var`` and
            the columns ``calc_ratio`` needs.
        column: Name of the column the ratio is derived from.
        conditions: Values of ``cue_var`` to report on, one after another.
        rt_col: Column name forwarded to ``calc_ratio``.
        cue_var: Column holding the condition of each trial.

    Returns:
        None; the statistics are only printed.
    """
    with_ratio = calc_ratio(data.copy(), column, rt_col)
    ratio_column = column + '_ratio'

    for condition in conditions:
        in_condition = with_ratio[cue_var] == condition
        grouped_ratio = with_ratio[in_condition].groupby('participant')[ratio_column]
        # One row per participant, one column per requested quantile.
        quantile_values = grouped_ratio.quantile([1/3, 2/3]).unstack()
        low_tertiles, high_tertiles = (
            quantile_values.iloc[:, position] for position in (0, 1)
        )
        print(f'{condition}, low. mean: {low_tertiles.mean():.2f}, std: {low_tertiles.std():.2f}')
        print(f'{condition}, high. mean: {high_tertiles.mean():.2f}, std: {high_tertiles.std():.2f}')

def add_tertiles(data: pd.DataFrame, column: str, conditions: list[str], rt_col: str='rt_x', cue_var='SAT'):
    """Label every trial with the tertile of its ratio within participant and condition.

    The input is copied first (as ``get_tertiles`` does), so the caller's
    frame is left untouched. ``calc_ratio`` is then applied to the copy and
    the column ``<column>_ratio`` is read from its result. For each entry of
    ``conditions`` the matching rows are grouped by ``participant`` and split
    into three equal-frequency bins with ``pd.qcut``.

    Args:
        data: Trial-level frame containing ``participant``, ``cue_var`` and
            the columns ``calc_ratio`` needs.
        column: Name of the column the ratio is derived from.
        conditions: Values of ``cue_var`` for which tertiles are assigned.
        rt_col: Column name forwarded to ``calc_ratio``.
        cue_var: Column holding the condition of each trial.

    Returns:
        A new frame with the columns produced by ``calc_ratio`` plus a
        ``tertile`` column holding 'Low', 'Medium' or 'High'. Rows whose
        condition is not listed in ``conditions`` keep ``pd.NA``.

    Raises:
        ValueError: Propagated from ``pd.qcut`` when a participant's ratios
            do not yield three distinct bin edges.
    """
    # Defensive copy: neither the ratio column nor 'tertile' may leak into the
    # frame that was passed in.
    data = data.copy()
    data = calc_ratio(data, column, rt_col)
    ratio_column = column + '_ratio'

    # Rows outside the requested conditions keep this missing-value marker.
    data['tertile'] = pd.NA

    for condition in conditions:
        condition_mask = data[cue_var] == condition

        # transform returns labels aligned on the original index, so they can
        # be written straight back into the masked rows with .loc.
        data.loc[condition_mask, 'tertile'] = data[condition_mask].groupby('participant')[ratio_column].transform(
            lambda x: pd.qcut(x, q=3, labels=['Low', 'Medium', 'High'])
        )

    return data

In [None]:
# Attach the confirmation ratio and its per-participant tertile label for both
# cue conditions.
data = add_tertiles(data, column='confirmation', conditions=['accuracy', 'speed'])

### Probability correct
Generalized Linear Mixed Effects Regression

In [None]:
# Columns used for the regression, with the grouping variables stored as
# categoricals.
lmer_columns = ["participant", "condition", "confirmation_ratio", "response", "tertile"]
categorical_columns = {name: "category" for name in ("participant", "condition", "tertile")}
data_lmer = data[lmer_columns].copy().astype(categorical_columns)

# Binomial mixed model: condition, confirmation ratio and their interaction as
# fixed effects, with a random intercept per participant.
formula = "response ~ condition * confirmation_ratio + (1|participant)"

# NOTE: this rebinds `model`, which until here referred to the network loaded
# from the checkpoint; re-run the model-loading cell before predicting again.
model = Lmer(formula, data_lmer, family="binomial")
result = model.fit()
print(model.summary())

In [None]:
# Hand the `response` mixed model fitted in the previous cell to hmpai's
# format_stats_latex (presumably renders its statistics as LaTeX; confirm
# against the helper's implementation).
format_stats_latex(model)

### EMG Sequences
Generalized Linear Mixed Effects Regression

In [None]:
# Binary coding of the EMG sequence: "IR" and "CR" map to 0, the four longer
# sequences map to 1. Sequences missing from the mapping become NaN.
group_mapping = {
    "IR": 0,
    "CR": 0,
    "ICR": 1,
    "CIR": 1,
    "CCR": 1,
    "IIR": 1,
}
emg_group = data["EMG_sequence"].map(group_mapping)

# Add the coded column and keep only trials whose sequence was mapped. The
# explicit copy makes the result an independent frame, so calc_ratio never
# writes into a slice of the previous `data` (chained-assignment hazard).
data = data.assign(EMG_group=emg_group)[emg_group.notnull()].copy()
data = calc_ratio(data, 'confirmation')

In [None]:
# Columns used for the regression, with the grouping variables stored as
# categoricals.
lmer_columns = ["participant", "condition", "confirmation_ratio", "EMG_group"]
data_lmer = data[lmer_columns].copy().astype(
    {"participant": "category", "condition": "category"}
)

# Binomial mixed model of the binary EMG group: condition, confirmation ratio
# and their interaction as fixed effects, random intercept per participant.
formula = "EMG_group ~ condition * confirmation_ratio + (1|participant)"

# NOTE: `model` is rebound here as well; it no longer refers to the network.
model = Lmer(formula, data_lmer, family="binomial")
result = model.fit()
print(model.summary())

In [None]:
# Hand the `EMG_group` mixed model fitted in the previous cell to hmpai's
# format_stats_latex (presumably renders its statistics as LaTeX; confirm
# against the helper's implementation).
format_stats_latex(model)