Find training transformations for the base model that maximize the difference between its predictions and its shuffled predictions

In [1]:
%load_ext autoreload
%autoreload 2
import netCDF4
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants, split_participants, split_participants_into_folds
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test, calculate_global_class_weights
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset, MultiXArrayDataset, MultiXArrayProbaDataset
from hmpai.data import SAT1_STAGES_ACCURACY, SAT_CLASSES_ACCURACY
from hmpai.visualization import plot_confusion_matrix
from hmpai.pytorch.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from hmpai.pytorch.transforms import *
from collections import defaultdict
from hmpai.pytorch.mamba import *
import os
from copy import deepcopy
import json

# Root directory of all datasets, read from the DATA_PATH environment variable.
_data_path_env = os.getenv("DATA_PATH")
if not _data_path_env:
    # Path(None) would fail with an opaque TypeError and Path("") would silently
    # resolve to the current directory; fail early with an actionable message.
    raise RuntimeError("The DATA_PATH environment variable must point to the data root directory.")
DATA_PATH = Path(_data_path_env)
models = []
# Number of participant folds passed to split_participants_into_folds.
N_FOLDS = 2

In [2]:
def base_mamba(
    n_classes=None,
    n_channels=19,
    embed_dim=64,
    out_channels=128,
    kernel_size=50,
    n_mamba_layers=5,
    dropout=0.1,
):
    """Build the baseline ConfigurableMamba model used for every transform run.

    A single Conv1d + ReLU front-end (``embed_dim`` -> ``out_channels`` channels,
    stride 1, 'same' padding) is passed to ``ConfigurableMamba`` as its
    ``cnn_module``. All defaults reproduce the original hard-coded configuration.

    Args:
        n_classes: Number of output classes. If ``None``, falls back to
            ``len(labels)`` using the notebook-global ``labels`` (which must then
            already be defined, e.g. by the training cell).
        n_channels: Number of input channels given to the model.
        embed_dim: Input channels of the Conv1d front-end and the model's embed dim.
        out_channels: Output channels of the Conv1d front-end, used as mamba_dim.
        kernel_size: Kernel size of the Conv1d front-end.
        n_mamba_layers: Number of Mamba layers.
        dropout: Dropout probability passed to the model.

    Returns:
        A freshly constructed ``ConfigurableMamba`` instance.
    """
    if n_classes is None:
        # Backward-compatible fallback to the notebook-global label list.
        n_classes = len(labels)
    base_cnn = nn.Sequential(
        nn.Conv1d(
            in_channels=embed_dim,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding='same',
        ),
        nn.ReLU(),
    )
    model_kwargs = {
        "embed_dim": embed_dim,
        "mamba_dim": out_channels,
        "n_channels": n_channels,
        "n_classes": n_classes,
        "n_mamba_layers": n_mamba_layers,
        "cnn_module": base_cnn,
        "dropout": dropout,
    }
    return ConfigurableMamba(**model_kwargs)


Define transform configurations separately for the train and test/validation sets. Most transforms are probably unwanted for test/validation, but they should remain configurable.

The datasets can be created once per fold; the transform is then swapped by assigning `train_data.transform = transform`.

In [3]:
# Transform configurations to compare. Each entry is a
# (train_transform, testval_transform) pair; None means "no transform".
# Transforms are first evaluated one at a time against the untransformed
# baseline, assuming their effects combine roughly additively.
#
# Candidates still to try, each as (Compose([<transform>]), None):
#   EndJitterTransform(63), ReverseTimeTransform(), GaussianNoise(),
#   TimeMaskTransform(), TimeDropoutTransform(), ChannelsDropout(),
#   ConcatenateTransform(0.5)
# Candidate combinations for later:
#   Compose([StartJitterTransform(62), EndJitterTransform(63)])
#   Compose([StartJitterTransform(62), EndJitterTransform(63), ReverseTimeTransform()])
transforms = [(None, None)]  # untransformed baseline
transforms.append((Compose([StartJitterTransform(62)]), None))

In [14]:
# Quick check: class name of the first transform in the second configuration.
type(transforms[1][0].transforms[0]).__name__

'StartJitterTransform'

In [15]:
from hmpai.pytorch.utilities import save_tensor


def transform_name(transform):
    """Return a short, filesystem-safe name for a train transform.

    ``None`` maps to ``'None'``. A ``Compose`` maps to the class names of the
    transforms it wraps, joined by ``'+'`` (just the class name for a single
    transform); an empty ``Compose`` (identity) maps to ``'Identity'``.
    """
    if transform is None:
        return 'None'
    return '+'.join(type(t).__name__ for t in transform.transforms) or 'Identity'


data_path_1 = DATA_PATH / "sat2/stage_data_proba_250hz_part1.nc"
data_path_2 = DATA_PATH / "sat2/stage_data_proba_250hz_part2.nc"
# data_paths = [data_path_1, data_path_2]
data_paths = [data_path_1]  # TODO: Both paths

logs_path = Path("../../logs/transformation_validation")
# Results are written into this directory below; make sure it exists.
logs_path.mkdir(parents=True, exist_ok=True)

# Dataset settings shared by every fold and transform.
# `labels` is also read as a global by base_mamba() when called without n_classes.
labels = SAT_CLASSES_ACCURACY
whole_epoch = True
subset_cond = 'accuracy'
add_negative = True
norm_fn = norm_mad_zscore

set_global_seed(42)
folds = split_participants_into_folds(data_paths, N_FOLDS)

# results[i_t] collects every result returned by train_and_test for transform i_t, across folds.
results = defaultdict(list)
torch.cuda.empty_cache()

for i_fold, test_fold in enumerate(folds):
    # Every fold except the current one is used for training.
    train_fold = np.concatenate(
        [fold for j, fold in enumerate(folds) if j != i_fold], axis=0
    )
    print(f"Fold {i_fold + 1}: test fold: {test_fold}")

    train_data = MultiXArrayProbaDataset(
        data_paths,
        participants_to_keep=train_fold,
        normalization_fn=norm_fn,
        whole_epoch=whole_epoch,
        labels=labels,
        subset_cond=subset_cond,
        add_negative=add_negative,
    )
    # The test fold is normalized with statistics computed on the training participants.
    norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
    class_weights = train_data.statistics["class_weights"]
    testval_data = MultiXArrayProbaDataset(
        data_paths,
        participants_to_keep=test_fold,
        normalization_fn=norm_fn,
        norm_vars=norm_vars,
        whole_epoch=whole_epoch,
        labels=labels,
        subset_cond=subset_cond,
        add_negative=add_negative,
    )

    for i_t, (t_train, t_test) in enumerate(transforms):
        # Swap transforms on the existing datasets instead of rebuilding them.
        train_data.transform = t_train
        testval_data.transform = t_test

        name = transform_name(t_train)
        model = base_mamba()
        # NOTE(review): the test fold is passed as both validation and test set.
        test_result = train_and_test(
            model,
            train_data,
            testval_data,
            testval_data,
            logs_path=logs_path,
            workers=0,
            batch_size=64,
            labels=labels,
            lr=0.001,  # 0.0001
            use_class_weights=False,
            class_weights=class_weights,
            whole_epoch=whole_epoch,
            epochs=1,
            # Include the fold index for every run, the baseline included, so
            # runs get distinct names (presumably used for logging by train_and_test).
            additional_name=f"transform-{name}_fold-{i_fold}",
            do_test_shuffled=True,
        )
        print(f"Fold {i_fold + 1}, transform: {name}: EMD: {test_result[0]['EMD']}, EMD_raw: {test_result[0]['EMD_raw'].shape}")
        results[i_t].extend(test_result)
        # Drop the trained model before the next run to limit GPU memory growth.
        del model
        torch.cuda.empty_cache()

for i_t, (t_train, _) in enumerate(transforms):
    name = transform_name(t_train)
    transform_results = results[i_t]
    has_raw_emd = bool(transform_results) and all(
        isinstance(result, dict) and "EMD_raw" in result for result in transform_results
    )
    if has_raw_emd:
        # One CSV of raw EMD values per collected result.
        # NOTE(review): the index equals the fold index only when train_and_test
        # returns a single result per run.
        for i_result, result in enumerate(transform_results):
            save_tensor(result["EMD_raw"], logs_path / f"results_{name}_fold_{i_result}.csv")
    else:
        # Results without raw EMD tensors are dumped as JSON; non-JSON values
        # (e.g. tensors) are stringified rather than aborting the export.
        with open(logs_path / f"results_{name}.json", "w") as f:
            json.dump(transform_results, f, indent=4, default=str)


Fold 1: test fold: ['S17' 'S10' 'S15' 'S1' 'S18']


  0%|          | 0/70 [00:00<?, ? batch/s]

Fold 1, transform: None: EMD: tensor([-1.1785e-04, -1.0405e-04, -1.3804e-04, -2.0141e-05],
       dtype=torch.float64), EMD_raw: torch.Size([5790, 4])


  0%|          | 0/70 [00:00<?, ? batch/s]

Fold 1, transform: Compose(
    <hmpai.pytorch.transforms.StartJitterTransform object at 0x7f4929c87f10>
): EMD: tensor([-0.0001, -0.0002, -0.0002, -0.0001], dtype=torch.float64), EMD_raw: torch.Size([5790, 4])
Fold 2: test fold: ['S11' 'S13' 'S12' 'S16']


  0%|          | 0/91 [00:00<?, ? batch/s]

Fold 2, transform: None: EMD: tensor([-4.7935e-05, -7.0328e-05, -2.3023e-04, -2.0602e-04],
       dtype=torch.float64), EMD_raw: torch.Size([4435, 4])


  0%|          | 0/91 [00:00<?, ? batch/s]

Fold 2, transform: Compose(
    <hmpai.pytorch.transforms.StartJitterTransform object at 0x7f4929c87f10>
): EMD: tensor([-2.4292e-04, -2.7995e-04, -3.8946e-05,  1.3507e-05],
       dtype=torch.float64), EMD_raw: torch.Size([4435, 4])
