In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from hmpai.training import split_participants
from hmpai.pytorch.models import *
from hmpai.pytorch.utilities import DEVICE, set_global_seed, load_model
from hmpai.pytorch.generators import MultiXArrayProbaDataset
from hmpai.data import SAT_CLASSES_ACCURACY
from hmpai.pytorch.normalization import *
from hmpai.utilities import calc_ratio, format_stats_latex
from torch.utils.data import DataLoader
import os
# Root folder of the datasets, taken from the DATA_PATH environment variable.
# NOTE(review): if the variable is unset, os.getenv returns None and Path(None) raises a TypeError.
DATA_PATH = Path(os.getenv("DATA_PATH"))
from hmpai.visualization import *
from hmpai.pytorch.mamba import *
from pymer4.models import Lmer

In [None]:
# Fix the global seed before anything else runs.
# NOTE(review): presumably this also makes the participant split below reproducible — confirm in hmpai.
set_global_seed(42)

data_paths = [DATA_PATH / "sat1/stage_data_250hz.nc"]
# The split is only used to enumerate participants: its first three parts are joined
# again right below, so train_percentage does not restrict which participants are used
# (assuming the parts are list-like, so that `+` concatenates — TODO confirm).
splits = split_participants(data_paths, train_percentage=50)
all_participants = splits[0] + splits[1] + splits[2]

# Class labels handed to the dataset; len(labels) is also used as n_classes for the model.
labels = SAT_CLASSES_ACCURACY

# Options below are forwarded unchanged to MultiXArrayProbaDataset in the next cell.
whole_epoch = True
# Trial-level fields requested from the dataset and later passed to predict_with_auc.
info_to_keep = ['participant', 'epochs', 'RT', 'cue', 'movement', 'resp']
subset_cond = None
# NOTE(review): meaning/units of skip_samples and cut_samples are defined by the dataset
# class; the file name suggests 250 Hz data, i.e. roughly a quarter of a second each — confirm.
skip_samples = 62
cut_samples = 63
add_negative = True
add_pe = True

In [3]:
# Create dataset
# Normalisation function handed to the dataset (name resolved through one of the star imports above).
norm_fn = norm_mad_zscore

# Despite its name, test_data is built from all_participants, i.e. every participant
# returned by the split, not from a held-out test subset.
test_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=all_participants,
    normalization_fn=norm_fn,
    whole_epoch=whole_epoch,
    labels=labels,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
    skip_samples=skip_samples,
    cut_samples=cut_samples,
    add_negative=add_negative,
    add_pe=add_pe,
)

In [4]:
# Wrap the dataset in a DataLoader so the model can be run over it in batches of 128.
# NOTE(review): shuffle=True randomises the order in which trials are visited; confirm this is
# intended for an evaluation-only pass.
test_loader_sat1 = DataLoader(
    test_data, batch_size=128, shuffle=True, num_workers=0, pin_memory=True
)

In [None]:
# Restore the trained network: read the checkpoint, rebuild the architecture from
# `config`, then copy the saved weights into it.
chk_path = Path("../models/boehm.pt")
checkpoint = load_model(chk_path)
# NOTE(review): these hyperparameters presumably have to match the ones the checkpoint was
# trained with, otherwise load_state_dict below cannot map the saved weights — confirm
# against the training configuration.
config = {
    "n_channels": 30,
    "n_classes": len(labels),
    "n_mamba_layers": 5,
    "use_pointconv_fe": True,
    "spatial_feature_dim": 128,
    "use_conv": True,
    "conv_kernel_sizes": [3, 9],
    "conv_in_channels": [128, 128],
    "conv_out_channels": [256, 256],
    "conv_concat": True,
    "use_pos_enc": True,
}
model = build_mamba(config)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(DEVICE)
# Switch to evaluation mode; the trailing semicolon keeps the notebook from echoing the return value.
model.eval();

In [None]:
# Collect model outputs for the loader together with the fields listed in info_to_keep.
# NOTE(review): the exact columns of `data` are defined by predict_with_auc — later cells rely
# on 'participant', 'cue', 'RT', 'movement', 'resp' and 'confirmation' being present.
data = predict_with_auc(model, test_loader_sat1, info_to_keep, labels)
# Rescale RT by a factor of 1/1000 (presumably milliseconds -> seconds — confirm units).
data['RT'] = data['RT'] / 1000

### Tertiles

In [8]:
def get_tertiles(data: pd.DataFrame, column: str, conditions: list[str], rt_col: str='rt_x', cue_var='SAT'):
    """Print summary statistics of the per-participant tertile boundaries.

    A ``<column>_ratio`` column is obtained through ``calc_ratio(data, column, rt_col)``.
    For every requested condition, the 1/3 and 2/3 quantiles of that ratio are
    computed per participant, and the mean and standard deviation of those
    boundaries across participants are printed (two lines per condition).

    Args:
        data: Trial-level frame containing ``column``, ``rt_col``, ``cue_var`` and
            ``'participant'``. A copy is handed to ``calc_ratio``, so the caller's
            frame is not modified here.
        column: Name of the column the ratio is derived from.
        conditions: Values of ``cue_var`` to report on, in printing order.
        rt_col: Column name forwarded to ``calc_ratio``.
        cue_var: Column holding the condition of each trial.

    Returns:
        None. The statistics are only printed.
    """
    with_ratio = calc_ratio(data.copy(), column, rt_col)
    ratio_name = column + '_ratio'
    cut_points = [1/3, 2/3]

    for condition in conditions:
        in_condition = with_ratio[cue_var] == condition
        ratios_by_participant = with_ratio[in_condition].groupby('participant')[ratio_name]
        # One row per participant; columns are the 1/3 and 2/3 quantiles, in that order.
        boundaries = ratios_by_participant.quantile(cut_points).unstack()
        lower, upper = boundaries.iloc[:, 0], boundaries.iloc[:, 1]
        print(f'{condition}, low. mean: {lower.mean():.2f}, std: {lower.std():.2f}')
        print(f'{condition}, high. mean: {upper.mean():.2f}, std: {upper.std():.2f}')


def add_tertiles(data: pd.DataFrame, column: str, conditions: list[str], rt_col: str='rt_x', cue_var='SAT'):
    """Label each trial 'Low', 'Medium' or 'High' according to its ``<column>_ratio``.

    The ratio column is obtained through ``calc_ratio(data, column, rt_col)``.
    Tertiles are computed with ``pd.qcut`` separately for every participant within
    every requested condition, so a label is relative to that participant's own
    ratio distribution in that condition.

    Args:
        data: Trial-level frame containing ``column``, ``rt_col``, ``cue_var`` and
            ``'participant'``. It is copied before any processing, so the caller's
            frame is never modified.
        column: Name of the column the ratio is derived from.
        conditions: Values of ``cue_var`` to label. Trials of any other condition
            keep ``pd.NA`` in the ``'tertile'`` column.
        rt_col: Column name forwarded to ``calc_ratio``.
        cue_var: Column holding the condition of each trial.

    Returns:
        A new DataFrame: the output of ``calc_ratio`` with an added ``'tertile'`` column.
    """
    # Copy up front (as get_tertiles does): the 'tertile' column assigned below, and
    # whatever calc_ratio does internally, must not leak into the caller's frame.
    data = calc_ratio(data.copy(), column, rt_col)
    ratio_column = column + '_ratio'

    # Start fully unlabelled; only trials of the requested conditions are filled in.
    data['tertile'] = pd.NA

    for condition in conditions:
        condition_mask = data[cue_var] == condition
        if not condition_mask.any():
            # Nothing to label for a condition that does not occur in the data.
            continue

        # transform() keeps the original row index, so the labels align on assignment.
        data.loc[condition_mask, 'tertile'] = (
            data.loc[condition_mask]
            .groupby('participant')[ratio_column]
            .transform(lambda x: pd.qcut(x, q=3, labels=['Low', 'Medium', 'High']))
        )

    return data

In [9]:
# Score every trial: 1.0 = correct, 0.0 = incorrect, NaN = cannot be scored.

# A trial can only be scored when both fields are present. Missing values are checked
# explicitly because `series != ''` is also True for NaN, which would otherwise turn a
# trial with a missing field into an "incorrect" response instead of an unscored one.
valid_comparison = (
    data['movement'].notna()
    & data['resp'].notna()
    & (data['movement'] != '')
    & (data['resp'] != '')
)

# The response is scored by comparing the last four characters of both fields.
movement_last4 = data['movement'].str[-4:]
resp_last4 = data['resp'].str[-4:]

# np.where promotes the boolean comparison to float so that unscored trials can hold NaN.
data['response'] = np.where(
    valid_comparison,
    movement_last4 == resp_last4,  # Actual comparison
    np.nan                         # Invalid -> NaN
)
# Adds 'confirmation_ratio' and the per-participant 'tertile' label for the AC and SP cues.
data = add_tertiles(data, 'confirmation', ['AC', 'SP'], rt_col='RT', cue_var='cue')

In [None]:
# Mixed-effects logistic regression: does the confirmation ratio relate to response
# correctness, and does that relation differ between cue conditions?
data_lmer = data[["participant", "cue", "confirmation_ratio", "response", "tertile"]].copy()
data_lmer["participant"] = data_lmer["participant"].astype("category")
data_lmer["cue"] = data_lmer["cue"].astype("category")

# Fixed effects: cue, confirmation_ratio and their interaction; random intercept per participant.
formula = "response ~ cue * confirmation_ratio + (1|participant)"

# NOTE: from here on `model` refers to the Lmer object and shadows the PyTorch network
# loaded earlier. The name is kept because the next cell passes `model` to format_stats_latex.
model = Lmer(formula, data_lmer, family="binomial")
# fit() prints the model summary itself and returns the coefficient table, so print that
# table directly; calling model.summary() again emitted the whole summary a second time.
result = model.fit()
print(result)



Linear mixed model fit by maximum likelihood  ['lmerMod']
Formula: response~cue*confirmation_ratio+(1|participant)

Family: binomial	 Inference: parametric

Number of observations: 4338	 Groups: {'participant': 25.0}

Log-likelihood: -2699.960 	 AIC: 5409.920

Random effects:

                    Name    Var    Std
participant  (Intercept)  0.083  0.288

No random effect correlations specified

Fixed effects:

Linear mixed model fit by maximum likelihood  ['lmerMod']
Formula: response~cue*confirmation_ratio+(1|participant)

Family: binomial	 Inference: parametric

Number of observations: 4338	 Groups: {'participant': 25.0}

Log-likelihood: -2699.960 	 AIC: 5409.920

Random effects:

                    Name    Var    Std
participant  (Intercept)  0.083  0.288

No random effect correlations specified

Fixed effects:

                          Estimate  2.5_ci  97.5_ci     SE     OR  OR_2.5_ci  \
(Intercept)                  0.913   0.760    1.065  0.078  2.491      2.139   
cueSP       

  ran_vars = ran_vars.applymap(


In [11]:
# `model` is the Lmer fitted in the previous cell here, not the PyTorch network loaded earlier.
format_stats_latex(model)

(Intercept)
($\beta = 0.91$, $SE = 0.08$, $z = 11.75$, $p < 0.001$, $OR = 2.49$, $95\%\,CI\,[2.14, 2.90]$)
cueSP
($\beta = -0.35$, $SE = 0.07$, $z = -4.93$, $p < 0.001$, $OR = 0.71$, $95\%\,CI\,[0.61, 0.81]$)
confirmation_ratio
($\beta = 0.14$, $SE = 0.06$, $z = 2.32$, $p < 0.05$, $OR = 1.16$, $95\%\,CI\,[1.02, 1.31]$)
cueSP:confirmation_ratio
($\beta = 0.12$, $SE = 0.08$, $z = 1.53$, $p = 0.13$, $OR = 1.12$, $95\%\,CI\,[0.97, 1.30]$)
