In [1]:
%load_ext autoreload
%autoreload 2
import netCDF4
import xarray as xr
from pathlib import Path
from hmpai.pytorch.models import *
from hmpai.training import split_data_on_participants, split_participants
from hmpai.pytorch.training import train, validate, calculate_class_weights, train_and_test, k_fold_cross_validate, test, calculate_global_class_weights
from hmpai.pytorch.utilities import DEVICE, set_global_seed, get_summary_str, save_model, load_model
from hmpai.pytorch.generators import SAT1Dataset, MultiXArrayDataset, MultiXArrayProbaDataset
from hmpai.data import SAT1_STAGES_ACCURACY, SAT_CLASSES_ACCURACY
from hmpai.visualization import plot_confusion_matrix
from hmpai.pytorch.normalization import *
from torchinfo import summary
from hmpai.utilities import print_results, CHANNELS_2D, AR_SAT1_CHANNELS
from torch.utils.data import DataLoader
# from braindecode.models.eegconformer import EEGConformer
from mne.io import read_info
import os
from ray import train as ray_train, tune
from ray.tune.schedulers import ASHAScheduler
from ray.train import Checkpoint
import tempfile
from ray.tune.tune_config import TuneConfig
from ray.train import ScalingConfig
DATA_PATH = Path(os.getenv("DATA_PATH"))

In [2]:
set_global_seed(42)
data_path_1 = DATA_PATH / "sat2/stage_data_proba_250hz_part1.nc"
data_path_2 = DATA_PATH / "sat2/stage_data_proba_250hz_part2.nc"
data_paths = [data_path_1, data_path_2]
# train_percentage=100 makes test and val 100 as well
splits = split_participants(data_paths, train_percentage=60)
labels = SAT_CLASSES_ACCURACY
info_to_keep = ["event_name", "rt"]
whole_epoch = True
subset_cond = None
batch_size = 64

In [3]:
norm_fn = norm_mad_zscore
train_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[0],
    normalization_fn=norm_fn,
    whole_epoch=whole_epoch,
    labels=labels,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
)
norm_vars = get_norm_vars_from_global_statistics(train_data.statistics, norm_fn)
class_weights = train_data.statistics["class_weights"]
test_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[1],
    normalization_fn=norm_fn,
    norm_vars=norm_vars,
    whole_epoch=whole_epoch,
    labels=labels,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
)
val_data = MultiXArrayProbaDataset(
    data_paths,
    participants_to_keep=splits[2],
    normalization_fn=norm_fn,
    norm_vars=norm_vars,
    whole_epoch=whole_epoch,
    labels=labels,
    info_to_keep=info_to_keep,
    subset_cond=subset_cond,
)

In [20]:
torch.cuda.empty_cache()

In [7]:
def tune_sat2(config):
    train_loader = DataLoader(
        train_data, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True
    )
    val_loader = DataLoader(
        val_data, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True
    )
    test_loader = DataLoader(
        test_data, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True
    )

    model = MambaModel(config["embed_dim"], 19, len(labels), config["n_layers"], global_pool=False, dropout=config["dropout"])
    model = model.to(DEVICE)

    loss_fn = torch.nn.KLDivLoss(reduction='batchmean', log_target=False)

    opt = torch.optim.NAdam(model.parameters(), weight_decay=config["weight_decay"], lr=config["lr"])

    # Epochs
    for i in range(20):
        batch_losses = train(model, train_loader, opt, loss_fn, whole_epoch=True)
        loss = np.mean(batch_losses)

        val_losses, val_accuracy = validate(model, val_loader, loss_fn, whole_epoch=True)
        val_loss = np.mean(val_losses)

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint = None
            if (i + 1) % 5 == 0:
                torch.save(
                    model.state_dict(),
        	        os.path.join(temp_checkpoint_dir, "model.pth")
                )
                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            
            ray_train.report({"loss": loss, "val_loss": val_loss}, checkpoint=checkpoint)


In [11]:
torch.cuda.empty_cache()
search_space = {
    "embed_dim": tune.choice([32, 64, 128, 256, 512]),
    "n_layers": tune.randint(2, 9),
    "dropout": tune.choice([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]),
    "weight_decay": tune.choice([0.0, 0.1, 0.01, 0.001]),
    "lr": tune.choice([0.1, 0.01, 0.001, 0.0001, 0.0001]),
}
scheduler = ASHAScheduler(
        max_t=20,
        grace_period=2,
        reduction_factor=2)

tuner = tune.Tuner(
    tune.with_resources(
        tune.with_parameters(tune_sat2),
        resources={"cpu": 12, "gpu": 1}
    ),
    tune_config=tune.TuneConfig(
        metric="val_loss",
        mode="min",
        scheduler=scheduler,
        num_samples=20
    ),
    param_space=search_space
)

results = tuner.fit()

0,1
Current time:,2024-09-13 12:02:11
Running for:,00:38:31.10
Memory:,7.8/94.3 GiB

Trial name,# failures,error file
tune_sat2_dd2ac_00013,1,"/tmp/ray/session_2024-09-13_11-17-18_694579_894671/artifacts/2024-09-13_11-23-40/tune_sat2_2024-09-13_11-23-40/driver_artifacts/tune_sat2_dd2ac_00013_13_dropout=0.5000,embed_dim=512,lr=0.0010,n_layers=7,weight_decay=0.0000_2024-09-13_11-23-41/error.txt"

Trial name,status,loc,dropout,embed_dim,lr,n_layers,weight_decay,iter,total time (s),loss,val_loss
tune_sat2_dd2ac_00000,TERMINATED,172.18.0.2:907835,0.2,64,0.0001,8,0.001,20.0,154.648,4.87775,4.04353
tune_sat2_dd2ac_00001,TERMINATED,172.18.0.2:911132,0.2,64,0.0001,7,0.001,2.0,14.8014,175.073,24.9735
tune_sat2_dd2ac_00002,TERMINATED,172.18.0.2:911719,0.1,32,0.01,2,0.0,20.0,116.501,1.7104,2.63804
tune_sat2_dd2ac_00003,TERMINATED,172.18.0.2:914537,0.5,64,0.1,2,0.0,20.0,116.022,,
tune_sat2_dd2ac_00004,TERMINATED,172.18.0.2:917337,0.0,128,0.01,5,0.01,20.0,167.85,,
tune_sat2_dd2ac_00005,TERMINATED,172.18.0.2:920661,0.0,512,0.0001,2,0.001,20.0,278.233,1.63706,2.48537
tune_sat2_dd2ac_00006,TERMINATED,172.18.0.2:925249,0.4,512,0.1,3,0.001,2.0,40.2864,21.1414,10.6765
tune_sat2_dd2ac_00007,TERMINATED,172.18.0.2:926066,0.2,256,0.1,7,0.001,20.0,386.534,,
tune_sat2_dd2ac_00008,TERMINATED,172.18.0.2:931587,0.3,64,0.1,8,0.01,20.0,154.014,,
tune_sat2_dd2ac_00009,TERMINATED,172.18.0.2:934747,0.0,512,0.0001,6,0.0,4.0,156.04,2.80493,3.98093


2024-09-13 11:53:34,856	ERROR tune_controller.py:1331 -- Trial task failed for trial tune_sat2_dd2ac_00013
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/opt/conda/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/ray/_private/worker.py", line 2661, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/opt/conda/lib/python3.10/site-packages/ray/_private/worker.py", line 871, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(OutOfMemoryError): [36mray::ImplicitFunc.train()[39m (pid=939978, ip=172.18.0.2, actor_