In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_participants_into_folds
from hmpai.pytorch.training import train_and_test
from hmpai.pytorch.utilities import set_global_seed
from hmpai.pytorch.generators import MultiXArrayProbaDataset
from hmpai.data import SAT_CLASSES_ACCURACY
from hmpai.pytorch.normalization import *
from torchvision.transforms import Compose
from hmpai.pytorch.transforms import *
from hmpai.pytorch.mamba import *
from hmpai.behaviour.sat2 import SAT2_SPLITS

import json
import os
import re
import time
from copy import deepcopy

import numpy as np
import pandas as pd
import torch

# Root directory of the datasets, taken from the environment so no
# machine-specific path is hard-coded in the notebook. Fail early with a clear
# message instead of the opaque TypeError that Path(None) would raise.
_data_root = os.getenv("DATA_PATH")
if _data_root is None:
    raise RuntimeError(
        "The DATA_PATH environment variable is not set; "
        "point it at the dataset root directory."
    )
DATA_PATH = Path(_data_root)

#### Visualization
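Run this section only after the ablations in "Run ablations" below have finished; it reads the per-fold JSON results they write to `../logs/model_ablation`.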

In [None]:
logs_path = Path("../logs/model_ablation")

# Display labels indexed by ablation number i (the `ablation{i}` directory).
# The order must match `configs` in the "Run ablations" section.
ablations = [
    "Linear",
    "PointConv",
    "No Spatial",
    "2 Conv",
    "3 Conv",
    "No Pos Enc",
    "LSTM",
]

# test_model writes each fold's result to ablation{i}/ablation{i}_fold{k}.json.
_RESULT_NAME_RE = re.compile(r"ablation(\d+)_fold(\d+)")


def load_ablation_results(logs_path: Path, ablations: list) -> pd.DataFrame:
    """Collect the per-fold ablation results under ``logs_path`` into a DataFrame.

    Only files matching ``*/ablation{i}_fold{k}.json`` are read; any other JSON
    file in the log directories is ignored. Files are processed in sorted
    order so the row order is deterministic.

    Args:
        logs_path: root directory containing one sub-directory per ablation.
        ablations: display labels, indexed by the ablation number ``i``.

    Returns:
        One row per result file, with columns ``ablation`` (label),
        ``fold`` (int), ``result`` (mean test KL divergence), ``runtime``
        (seconds) and ``n_parameters``.
    """
    records = []
    for result_file in sorted(logs_path.glob("*/*.json")):
        match = _RESULT_NAME_RE.fullmatch(result_file.stem)
        if match is None:
            continue
        abl_idx, fold = (int(group) for group in match.groups())
        with open(result_file, "r") as fh:
            result = json.load(fh)
        records.append(
            {
                "ablation": ablations[abl_idx],
                "fold": fold,
                "result": result["test_kldiv_mean"],
                "runtime": result["runtime"],
                "n_parameters": result["n_parameters"],
            }
        )
    return pd.DataFrame(records)


data = load_ablation_results(logs_path, ablations)

In [None]:
# Per-fold results: one row per ablation{i}_fold{k}.json result file.
data

In [None]:
# Table display names keyed by the `ablation` label. Mapping by label (rather
# than assigning a positional list) keeps names correct regardless of groupby
# ordering, and does not break when some ablations have no results yet.
ABLATION_DISPLAY_NAMES = {
    "2 Conv": "2 convolutional layers",
    "3 Conv": "3 convolutional layers",
    "LSTM": "LSTM",
    "Linear": "Linear",
    "No Pos Enc": "No positional encoding",
    "No Spatial": "No spatial",
    "PointConv": "1-D convolution",
}

metric_cols = ["result", "runtime", "n_parameters"]
grouped = data.groupby("ablation")[metric_cols]
means = grouped.mean()
stds = grouped.std()
# Unknown labels fall back to the raw label instead of NaN.
means["name"] = means.index.map(lambda label: ABLATION_DISPLAY_NAMES.get(label, label))

# One LaTeX table row per ablation: mean (std across folds) of the KL result and
# runtime, plus the mean parameter count.
for ablation, row in means.iterrows():
    std = stds.loc[ablation]
    table_str = (
        f"- & {row['name']} & {row['result']:.2f} ({std['result']:.2f}) "
        f"& {row['runtime']:.2f} ({std['runtime']:.2f}) & {row['n_parameters']:.0f} \\\\"
    )
    print(table_str)

#### Run ablations

In [None]:
# Seed every RNG before anything stochastic runs.
set_global_seed(42)

# Cross-validation uses one fold per entry of the first SAT2 split
# (presumably one participant per fold — confirm against SAT2_SPLITS).
N_FOLDS = len(SAT2_SPLITS[0])

data_paths = [DATA_PATH / "sat2/stage_data_250hz.nc"]
labels = SAT_CLASSES_ACCURACY
# NOTE(review): info_to_keep is not referenced by any visible cell.
info_to_keep = ["event_name", "participant", "epochs", "rt"]

# Dataset options shared by every ablation run.
whole_epoch, subset_cond, add_negative = True, None, True
skip_samples = cut_samples = 0

In [None]:
base_config = {"n_channels": 64, "n_classes": len(labels), "n_mamba_layers": 5}


def make_ablation_config(
    spatial_fe=None,
    spatial_feature_dim=128,
    kernel_sizes=(3,),
    use_pos_enc=True,
    use_lstm=False,
):
    """Build one ablation config dict for ``build_mamba``.

    Every config uses concatenated convolutions (``conv_concat=True``), one per
    entry of ``kernel_sizes``. Each maps ``spatial_feature_dim`` input channels
    to 256 output channels.

    Args:
        spatial_fe: spatial feature extractor name (``"linear"`` or
            ``"pointconv"``); sets the ``use_<name>_fe`` flag. ``None`` omits
            the flag entirely.
        spatial_feature_dim: number of input channels of each convolution.
        kernel_sizes: kernel size of each convolution.
        use_pos_enc: value of the ``use_pos_enc`` flag.
        use_lstm: if True, add ``use_lstm=True``.

    Returns:
        A fresh dict (no lists shared between configs). ``base_config`` is
        applied last, so it overrides any overlapping key.
    """
    config = {}
    if spatial_fe is not None:
        config[f"use_{spatial_fe}_fe"] = True
    config.update(
        spatial_feature_dim=spatial_feature_dim,
        use_conv=True,
        conv_kernel_sizes=list(kernel_sizes),
        conv_in_channels=[spatial_feature_dim] * len(kernel_sizes),
        conv_out_channels=[256] * len(kernel_sizes),
        conv_concat=True,
        use_pos_enc=use_pos_enc,
    )
    if use_lstm:
        config["use_lstm"] = True
    config.update(base_config)
    return config


# Index i in this list becomes the `ablation{i}` log directory, and must match
# the `ablations` labels in the visualization section.
configs = [
    # SPATIAL
    make_ablation_config("linear"),
    # Also serves as the single-convolution baseline of the TEMPORAL ablation.
    make_ablation_config("pointconv"),
    # No spatial feature extractor: the convolutions take n_channels (64) inputs.
    make_ablation_config(spatial_feature_dim=64),
    # TEMPORAL
    make_ablation_config("pointconv", kernel_sizes=(3, 9)),
    make_ablation_config("pointconv", kernel_sizes=(3, 9, 27)),
    # POSITIONAL ENCODING
    make_ablation_config("pointconv", kernel_sizes=(3, 9), use_pos_enc=False),
    # LSTM
    make_ablation_config("pointconv", kernel_sizes=(3, 9), use_lstm=True),
]

In [None]:
def count_parameters(model):
    """Return the total number of elements across all trainable parameters.

    Parameters with ``requires_grad=False`` (frozen weights) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


def _build_fold_datasets(train_fold, test_fold, norm_fn, add_pe):
    """Create the training and test/validation datasets for one CV fold.

    Dataset options (``data_paths``, ``whole_epoch``, ``labels``,
    ``subset_cond``, ``add_negative``, ``skip_samples``, ``cut_samples``) are
    read from the notebook's configuration cell, so every ablation uses the
    same data pipeline.

    Args:
        train_fold: participants used for training.
        test_fold: held-out participants used for testing/validation.
        norm_fn: normalization function applied to both datasets.
        add_pe: forwarded as ``add_pe`` to both datasets (follows the config's
            ``use_pos_enc`` flag).

    Returns:
        ``(train_data, testval_data)``. The test/validation set is normalized
        with variables derived from the training set's statistics.
    """
    train_data = MultiXArrayProbaDataset(
        data_paths,
        participants_to_keep=train_fold,
        normalization_fn=norm_fn,
        whole_epoch=whole_epoch,
        labels=labels,
        subset_cond=subset_cond,
        add_negative=add_negative,
        # Start/end jitter augmentation is applied to training data only.
        # NOTE(review): 62/63 look like maximum jitter sizes — confirm.
        transform=Compose(
            [StartJitterTransform(62, 1.0), EndJitterTransform(63, 1.0)]
        ),
        skip_samples=skip_samples,
        cut_samples=cut_samples,
        add_pe=add_pe,
    )

    norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
    testval_data = MultiXArrayProbaDataset(
        data_paths,
        participants_to_keep=test_fold,
        normalization_fn=norm_fn,
        norm_vars=norm_vars,
        whole_epoch=whole_epoch,
        labels=labels,
        subset_cond=subset_cond,
        add_negative=add_negative,
        transform=None,
        skip_samples=skip_samples,
        cut_samples=cut_samples,
        add_pe=add_pe,
    )
    return train_data, testval_data


def test_model(config: dict, i: int, from_fold: int = -1) -> None:
    """Train and evaluate one ablation configuration with k-fold cross-validation.

    For every fold, a fresh model is built with ``build_mamba(config)`` and
    trained and tested via ``train_and_test``. The result is written to
    ``../logs/model_ablation/ablation{i}/ablation{i}_fold{k}.json``, holding
    the config plus ``runtime`` (seconds), ``test_kldiv_mean``,
    ``test_kldiv_list`` and ``n_parameters``.

    Args:
        config: keyword configuration for ``build_mamba``. It is not modified.
        i: ablation index; determines the log directory and run names.
        from_fold: folds with index ``<= from_fold`` are skipped, so an
            interrupted run can be resumed. The default ``-1`` runs every fold.
    """
    logs_path = Path("../logs/model_ablation")
    ablation_name = f"ablation{i}"
    ablation_dir = logs_path / ablation_name
    norm_fn = norm_mad_zscore

    set_global_seed(42)
    folds = split_participants_into_folds(
        data_paths, N_FOLDS, participants_to_use=SAT2_SPLITS[0], shuffle=False
    )
    torch.cuda.empty_cache()
    print(f"Model {i}")
    for i_fold in range(len(folds)):
        if i_fold <= from_fold:
            continue
        # Re-seed per fold so a resumed run (from_fold >= 0) gives the same
        # per-fold results as an uninterrupted run.
        set_global_seed(42)

        test_fold = folds[i_fold]
        train_fold = np.concatenate(folds[:i_fold] + folds[i_fold + 1 :], axis=0)
        print(f"\tFold {i_fold + 1}: test fold: {test_fold}")

        run_name = f"{ablation_name}_fold{i_fold}"
        train_data, testval_data = _build_fold_datasets(
            train_fold, test_fold, norm_fn, add_pe=config["use_pos_enc"]
        )
        class_weights = train_data.statistics["class_weights"]

        model = build_mamba(config)
        n_parameters = count_parameters(model)
        start_time = time.perf_counter()
        # The test set doubles as the validation set.
        test_result = train_and_test(
            model,
            train_data,
            testval_data,
            testval_data,
            logs_path=ablation_dir,
            workers=12,
            batch_size=32,
            labels=labels,
            lr=0.0001,
            use_class_weights=False,
            class_weights=class_weights,
            whole_epoch=whole_epoch,
            epochs=50,
            additional_name=run_name,
        )
        runtime = time.perf_counter() - start_time

        # Store results in a per-fold copy. Mutating `config` would leak this
        # fold's result keys into build_mamba() on the next fold.
        record = {
            **config,
            "runtime": runtime,
            "test_kldiv_mean": test_result[0]["test_kldiv_mean"],
            "test_kldiv_list": test_result[0]["test_kldiv_list"],
            "n_parameters": n_parameters,
        }
        ablation_dir.mkdir(parents=True, exist_ok=True)
        with open(ablation_dir / f"{run_name}.json", "w") as out:
            json.dump(record, out)

In [None]:
# Train and evaluate every ablation config in order. The enumeration index
# becomes the `ablation{i}` log directory read by the visualization section.
for ablation_idx, ablation_config in enumerate(configs):
    test_model(ablation_config, ablation_idx)