In [1]:
%load_ext autoreload
%autoreload 2
import netCDF4
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants, split_participants, split_participants_into_folds
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test, calculate_global_class_weights
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset, MultiXArrayDataset, MultiXArrayProbaDataset
from hmpai.data import SAT1_STAGES_ACCURACY, SAT_CLASSES_ACCURACY
from hmpai.visualization import plot_confusion_matrix
from hmpai.pytorch.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from hmpai.pytorch.transforms import *
from collections import defaultdict
from hmpai.pytorch.mamba import *
import os
from copy import deepcopy
import json

# Root directory of the datasets, taken from the DATA_PATH environment variable.
# Fail fast with an actionable message when it is unset: Path(None) would
# otherwise raise an opaque "expected str, bytes or os.PathLike object" TypeError.
_data_path_env = os.getenv("DATA_PATH")
if _data_path_env is None:
    raise RuntimeError(
        "Environment variable DATA_PATH is not set; "
        "point it at the directory that contains the dataset files."
    )
DATA_PATH = Path(_data_path_env)
# Registry of zero-argument model builder functions; later cells append to it.
models = []
# Number of participant folds used for cross-validation.
N_FOLDS = 6

In [2]:
def base_mamba():
    """Build the baseline architecture: one Conv1d + ReLU module and 5 Mamba layers.

    The convolution maps 64 input channels (``embed_dim``) to 128 output
    channels (``mamba_dim``) with a kernel of length 50 and 'same' padding.

    The number of classes is taken from the notebook-level ``labels`` variable
    at call time, so ``labels`` must be defined before this builder is called.

    Returns:
        A freshly constructed ``ConfigurableMamba`` instance.
    """
    n_embed = 64
    n_filters = 128
    conv_module = nn.Sequential(
        nn.Conv1d(n_embed, n_filters, kernel_size=50, stride=1, padding='same'),
        nn.ReLU(),
    )
    return ConfigurableMamba(
        embed_dim=n_embed,
        mamba_dim=n_filters,
        n_channels=19,  # presumably the number of EEG channels in the data; confirm
        n_classes=len(labels),
        n_mamba_layers=5,
        cnn_module=conv_module,
        dropout=0.1,
    )
# Register the builder so the cross-validation loop trains this architecture.
models.append(base_mamba)

In [3]:
def no_cnn():
    """Build the variant without a convolutional module.

    ``cnn_module`` is passed as ``None`` and ``mamba_dim`` is set to the same
    value as ``embed_dim`` (64); all other settings match the other builders
    in this notebook.

    The number of classes is taken from the notebook-level ``labels`` variable
    at call time, so ``labels`` must be defined before this builder is called.

    Returns:
        A freshly constructed ``ConfigurableMamba`` instance.
    """
    width = 64  # used for both the embedding and the Mamba dimension
    return ConfigurableMamba(
        embed_dim=width,
        mamba_dim=width,
        n_channels=19,  # presumably the number of EEG channels in the data; confirm
        n_classes=len(labels),
        n_mamba_layers=5,
        cnn_module=None,
        dropout=0.1,
    )
# Register the builder so the cross-validation loop trains this architecture.
models.append(no_cnn)

In [5]:
def two_cnn():
    """Build the variant with two stacked Conv1d + ReLU stages before Mamba.

    Channel widths go 64 -> 96 -> 128 with kernel lengths 50 and 24; both
    convolutions use stride 1 and 'same' padding.

    The number of classes is taken from the notebook-level ``labels`` variable
    at call time, so ``labels`` must be defined before this builder is called.

    Returns:
        A freshly constructed ``ConfigurableMamba`` instance.
    """
    n_embed, n_hidden, n_out = 64, 96, 128
    conv_stack = nn.Sequential(
        nn.Conv1d(n_embed, n_hidden, kernel_size=50, stride=1, padding='same'),
        nn.ReLU(),
        nn.Conv1d(n_hidden, n_out, kernel_size=24, stride=1, padding='same'),
        nn.ReLU(),
    )
    return ConfigurableMamba(
        embed_dim=n_embed,
        mamba_dim=n_out,
        n_channels=19,  # presumably the number of EEG channels in the data; confirm
        n_classes=len(labels),
        n_mamba_layers=5,
        cnn_module=conv_stack,
        dropout=0.1,
    )
# Register the builder so the cross-validation loop trains this architecture.
models.append(two_cnn)

In [6]:
def three_cnn():
    """Build the variant with three stacked Conv1d + ReLU stages before Mamba.

    Channel widths go 64 -> 85 -> 106 -> 128 with kernel lengths 50, 24 and 12;
    every convolution uses stride 1 and 'same' padding.

    The number of classes is taken from the notebook-level ``labels`` variable
    at call time, so ``labels`` must be defined before this builder is called.

    Returns:
        A freshly constructed ``ConfigurableMamba`` instance.
    """
    widths = (64, 85, 106, 128)  # channel count before/after each convolution
    kernel_lengths = (50, 24, 12)

    # Layers are created front to back, i.e. in the same order in which they
    # appear in the resulting nn.Sequential.
    stages = []
    for c_in, c_out, k_len in zip(widths[:-1], widths[1:], kernel_lengths):
        stages.append(
            nn.Conv1d(c_in, c_out, kernel_size=k_len, stride=1, padding='same')
        )
        stages.append(nn.ReLU())
    conv_stack = nn.Sequential(*stages)

    return ConfigurableMamba(
        embed_dim=widths[0],
        mamba_dim=widths[-1],
        n_channels=19,  # presumably the number of EEG channels in the data; confirm
        n_classes=len(labels),
        n_mamba_layers=5,
        cnn_module=conv_stack,
        dropout=0.1,
    )
# Register the builder so the cross-validation loop trains this architecture.
models.append(three_cnn)

In [2]:
def twostream_cnn():
    """Build the two-branch variant with separate 'space' and 'time' Conv2d modules.

    Both branches are a single ``Conv2d(1 -> 128)`` followed by ReLU, with
    stride 1 and 'same' padding; they differ only in kernel size:
    (1, 1) for the space branch and (50, 1) for the time branch.
    ``mamba_dim`` is set to ``2 * 128``, presumably because the two branch
    outputs are combined inside ``ConfigurableMamba`` (confirm there).

    The number of classes is taken from the notebook-level ``labels`` variable
    at call time, so ``labels`` must be defined before this builder is called.

    Returns:
        A freshly constructed ``ConfigurableMamba`` instance.
    """
    # TODO: Rethink this design. Putting features into the channel/time
    # dimension is probably not a good idea; channels are of course a feature
    # as well, but this is unresolved.
    n_embed = 64
    n_filters = 128

    def _conv2d_branch(kernel_size):
        # One single-input-channel Conv2d followed by ReLU.
        return nn.Sequential(
            nn.Conv2d(1, n_filters, kernel_size=kernel_size, stride=1, padding='same'),
            nn.ReLU(),
        )

    # NOTE(review): this branch is meant to convolve over the channel dimension
    # (n_channels), but a (1, 1) kernel is pointwise and spans no neighbouring
    # positions along either axis; confirm the intended kernel size.
    space_branch = _conv2d_branch((1, 1))
    time_branch = _conv2d_branch((50, 1))

    return ConfigurableMamba(
        embed_dim=n_embed,
        mamba_dim=2 * n_filters,
        n_channels=19,  # presumably the number of EEG channels in the data; confirm
        n_classes=len(labels),
        n_mamba_layers=5,
        space_cnn_module=space_branch,
        time_cnn_module=time_branch,
        dropout=0.1,
    )
# Register the builder so the cross-validation loop trains this architecture.
models.append(twostream_cnn)

In [3]:
# Cross-validate every architecture registered in `models`: for each participant
# fold, train each model on the remaining folds, evaluate it on the held-out
# fold, and finally write the collected results to one JSON file per model.
data_path_1 = DATA_PATH / "sat2/stage_data_proba_250hz_part1.nc"
data_path_2 = DATA_PATH / "sat2/stage_data_proba_250hz_part2.nc"
data_paths = [data_path_1, data_path_2]

logs_path = Path("../../logs/architecture_validation")
# Make sure the output directory exists up front; otherwise the JSON dump at
# the end of this cell fails with FileNotFoundError after all training is done.
logs_path.mkdir(parents=True, exist_ok=True)

set_global_seed(42)
folds = split_participants_into_folds(data_paths, N_FOLDS)

# Maps model builder name -> list of result entries accumulated over all folds.
results = defaultdict(list)
torch.cuda.empty_cache()

# Dataset/training configuration. The values are the same for every fold, so
# they are assigned once instead of on each iteration. `labels` is also read as
# a global by the model builder functions (via len(labels)), so it has to be
# defined before any builder is called.
labels = SAT_CLASSES_ACCURACY
whole_epoch = True
# Maybe 'accuracy'? probably not necessary
subset_cond = None
add_negative = True
skip_samples = 0
norm_fn = norm_mad_zscore

for i_fold in range(len(folds)):
    # Hold out fold `i_fold`; the participants of all other folds form the training set.
    train_folds = deepcopy(folds)
    test_fold = train_folds.pop(i_fold)
    train_fold = np.concatenate(train_folds, axis=0)
    print(f"Fold {i_fold + 1}: test fold: {test_fold}")

    train_data = MultiXArrayProbaDataset(
        data_paths,
        participants_to_keep=train_fold,
        normalization_fn=norm_fn,
        whole_epoch=whole_epoch,
        labels=labels,
        subset_cond=subset_cond,
        add_negative=add_negative,
    )
    # Normalization variables are derived from the training set's statistics and
    # passed to the held-out dataset below.
    norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
    class_weights = train_data.statistics["class_weights"]
    testval_data = MultiXArrayProbaDataset(
        data_paths,
        participants_to_keep=test_fold,
        normalization_fn=norm_fn,
        norm_vars=norm_vars,
        whole_epoch=whole_epoch,
        labels=labels,
        subset_cond=subset_cond,
        add_negative=add_negative,
    )

    for model_fn in models:
        # Build a fresh, untrained model for every (fold, architecture) pair.
        model = model_fn()
        # NOTE(review): the same held-out dataset is passed twice, presumably as
        # both the test and the validation set (confirm against the signature of
        # train_and_test). If so, any validation-based model selection sees the
        # test data.
        # NOTE(review): class_weights is supplied while use_class_weights=False.
        test_result = train_and_test(
            model,
            train_data,
            testval_data,
            testval_data,
            logs_path=logs_path,
            workers=0,
            batch_size=64,
            labels=labels,
            lr=0.0001,
            use_class_weights=False,
            class_weights=class_weights,
            whole_epoch=whole_epoch,
            epochs=20,
            additional_name=f"model_fn-{model_fn.__name__}_fold-{i_fold}",
        )
        print(f"Fold {i_fold + 1}, model: {model_fn.__name__}: KLDivLoss: {test_result[0]['KLDivLoss']}")
        # Keep every entry returned for this fold under the builder's name.
        results[model_fn.__name__].extend(test_result)

# One JSON file per architecture, containing its results from all folds.
for model_fn in models:
    with open(logs_path / f"results_{model_fn.__name__}.json", "w") as f:
        json.dump(results[model_fn.__name__], f, indent=4)

Fold 1: test fold: ['S1' 'S10' 'S18']


  0%|          | 0/499 [00:00<?, ? batch/s]

  return F.conv2d(input, weight, bias, self.stride,


In [5]:
! tensorboard --logdir ../../logs/architecture_validation

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.16.2 at http://localhost:6006/ (Press CTRL+C to quit)
E0928 14:07:20.577132 140497037293120 _internal.py:97] Error on request:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/werkzeug/serving.py", line 370, in run_wsgi
    execute(self.server.app)
  File "/opt/conda/lib/python3.10/site-packages/werkzeug/serving.py", line 331, in execute
    application_iter = app(environ, start_response)
  File "/opt/conda/lib/python3.10/site-packages/tensorboard/backend/application.py", line 528, in __call__
    return self._app(environ, start_response)
  File "/opt/conda/lib/python3.10/site-packages/tens