In [1]:
# Load IPython's autoreload extension and enable mode 2, so that edited
# project modules are re-imported before each cell is executed.
%load_ext autoreload
%autoreload 2

## Data Preparation First Part

In [2]:
import os
import glob

import mne
from joblib import Parallel, delayed
# Target sampling frequency in Hz: recordings are resampled to this rate and
# the value is embedded in the names of the resampled files.
sfreq = 100

In [3]:

# Worker count. NOTE(review): the resampling loop in this part of the notebook
# runs serially and does not read this value.
n_jobs = 30

# Collect only the original recordings. The resampled copies written by `do`
# live in the same directory and also end in '_raw.fif.gz', so without this
# filter a re-run of the notebook would resample them a second time and write
# files such as '*_100_hz_100_hz_raw.fif.gz'.
files = sorted(
    path
    for path in glob.glob('/work/dlclarge2/schirrmr-eeg-age-competition/training/*_raw.fif.gz')
    if not path.endswith('_hz_raw.fif.gz')
)

In [None]:
def do(file, sfreq):
    """Resample one recording and store the result next to the original.

    Parameters
    ----------
    file : str
        Path of a recording whose name ends in ``_raw.fif.gz``.
    sfreq : int | float
        Target sampling frequency handed to ``raw.resample``.

    Notes
    -----
    The output path is ``file`` with ``_raw.fif.gz`` replaced by
    ``_{sfreq}_hz_raw.fif.gz``; a file already present there is overwritten.
    """
    target_path = file.replace('_raw.fif.gz', f'_{sfreq}_hz_raw.fif.gz')
    recording = mne.io.read_raw_fif(file, verbose='error')
    recording.resample(sfreq=sfreq)
    recording.save(target_path, overwrite=True)

In [None]:
# tqdm was used here without ever being imported in this notebook, which
# raises NameError on a fresh kernel.
from tqdm import tqdm

# Resample every collected recording, one after the other, with a progress bar.
for path in tqdm(files):
    do(path, sfreq)

In [None]:
# tqdm was used in this cell without ever being imported in this notebook.
from tqdm import tqdm

# download data https://filesender.renater.fr/?s=download&token=e1de0ec4-09bc-4194-b85b-59830cb04af3
# download test data from https://codalab.lisn.upsaclay.fr/competitions/8336

# Path to training data
train_path = "/work/dlclarge2/schirrmr-eeg-age-competition/lukas/data/training/"
# Path to testing data (public test set)
test_path = "/work/dlclarge2/schirrmr-eeg-age-competition/lukas/data/testing/"
train_subj = 1200  # number of training subjects to load (lower, e.g. to 10, for a quick demonstration)
test_subj = 400  # number of testing subjects to load (lower, e.g. to 10, for a quick demonstration)

# Subject ids do not depend on the condition, so they are built once.
# `test_subjs` is read again further down when the test metadata is assembled.
train_subjs = list(range(1, train_subj + 1))
test_subjs = list(range(1201, 1201 + test_subj))

# One list of lazily opened (preload=False) recordings per condition,
# ordered by subject id.
train_raws, test_raws = {}, {}
for condition in ["EC", "EO"]:
    train_raws[condition] = []
    for s in tqdm(train_subjs):
        fname = f"subj{s:04}_{condition}_{sfreq}_hz_raw.fif.gz"
        raw = mne.io.read_raw(train_path + fname, preload=False, verbose='error')
        train_raws[condition].append(raw)

    test_raws[condition] = []
    for s in tqdm(test_subjs):
        fname = f"subj{s:04}_{condition}_{sfreq}_hz_raw.fif.gz"
        raw = mne.io.read_raw(test_path + fname, preload=False, verbose='error')
        test_raws[condition].append(raw)


In [None]:
import pandas as pd

In [None]:
# Subject table of the training set; its rows are repeated once per condition
# so that they line up with the EC recordings followed by the EO recordings.
subjects = pd.read_csv(train_path + "train_subjects.csv", index_col=0)
n_train_ec = len(train_raws['EC'])
n_train_eo = len(train_raws['EO'])

meta = pd.concat([subjects, subjects])
meta['condition'] = ['EC'] * n_train_ec + ['EO'] * n_train_eo

# Flatten the per-condition dict into one list in the same EC-then-EO order.
# This rebinds `train_raws` from a dict to a list.
train_raws = train_raws['EC'] + train_raws['EO']
len(train_raws), len(meta)


In [None]:
# Condition label for every test recording: all EC recordings first, then all EO.
test_conditions = ['EC'] * len(test_raws['EC']) + ['EO'] * len(test_raws['EO'])
test_meta = pd.DataFrame({'condition': test_conditions})

# Flatten the per-condition dict into one list in the same EC-then-EO order.
# This rebinds `test_raws` from a dict to a list.
test_raws = test_raws['EC'] + test_raws['EO']
len(test_raws), len(test_meta)


In [None]:
from braindecode.datasets import BaseConcatDataset, BaseDataset

In [None]:
# Name of the target handed to every training BaseDataset built below.
target_name = 'age'

In [None]:
# Wrap each training recording in a BaseDataset and concatenate them,
# keeping the order of `train_raws`.
train_datasets = [BaseDataset(raw, target_name=target_name) for raw in train_raws]
train = BaseConcatDataset(train_datasets)

# Attach the per-recording metadata.
# NOTE(review): assumes the subjects CSV still exposes an 'id' column after
# being read with index_col=0 - confirm against the file.
meta['subject'] = meta['id']
train.set_description(meta)

# Also remember the file each recording was read from.
train_file_paths = [ds.raw.filenames[0] for ds in train.datasets]
train.set_description({'path': train_file_paths})


In [None]:
import pickle

# Store the concatenated training dataset next to the training data.
train_pickle_path = train_path + f'train_{sfreq}_hz.pkl'
with open(train_pickle_path, 'wb') as pickle_file:
    pickle.dump(train, pickle_file)

In [None]:
# Wrap each test recording in a BaseDataset (no target) and concatenate them,
# keeping the order of `test_raws`.
test_datasets = [BaseDataset(raw) for raw in test_raws]
test = BaseConcatDataset(test_datasets)

# Subject ids repeat once per condition, matching the EC-then-EO order.
test_meta['subject'] = test_subjs + test_subjs
test.set_description(test_meta)

# Also remember the file each recording was read from.
test_file_paths = [ds.raw.filenames[0] for ds in test.datasets]
test.set_description({'path': test_file_paths})


In [None]:
# Store the concatenated test dataset next to the testing data.
test_pickle_path = test_path + f'test_{sfreq}_hz.pkl'
with open(test_pickle_path, 'wb') as pickle_file:
    pickle.dump(test, pickle_file)

## Data Preparation Second Part

In [None]:
import os
from datetime import datetime

import pandas as pd

from decode_tueg import decode_tueg

# Timestamp identifying this experiment; it is stored under 'date' below.
# `datetime` is imported once, as the class. The previous version imported the
# module, used it, and then rebound the same name to the class a few lines
# later, which left the meaning of `datetime` dependent on the position in the cell.
exp_date = datetime.now().isoformat()

base_dir = '/work/dlclarge2/schirrmr-eeg-age-competition/results/'

# Run configuration. Every entry is a one-element list holding the value used
# for this run; the trailing comments list alternative values.
params = {
    'model_name': ['deep'],  # 'shallow', 'deep', 'tcn'
    'subset': ['normal'],  # 'normal', 'abnormal', 'mixed'
    'target_name': ['age'],  # age, gender, pathological, age_clf

    'valid_set_i': [0],  # 0, 1, 2, 3, 4
    'n_epochs': [35],  # 35, 105, 210
    'n_restarts': [0],  # 0, 2, 5
    'augment': ['0'],  # dropout, flipfb, fliplr, noise, mask, reverse, shuffle, sign, random, identity, '0'
    'fast_mode': [1],
    'loss': ['mae'],  # mse, mae, log_cosh, huber, nll

    'condition': ['all'],  # 'all', 'EC', 'EO', TODO: implement using both, prevent subject leakage in both sets
    'n_train_recordings': [-1],  # -1: None
    'tmax': [-1],  # 4*60 done, 6*60 done, 11*60 done; -1: None
    'min_age': [-1],
    'max_age': [-1],
    # 'data_path': ['/home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal/v2.0.0/edf/'],
    'data_path': ['/work/dlclarge2/schirrmr-eeg-age-competition/lukas/data/training/'],
    'squash_outs': [1],  # force output to be in [0, 1] through sigmoid

    'final_eval': [0],
    'debug': [1],
    'seed': [20221116],  # default 20220429
    'date': [exp_date],  # sometimes, need to restart some of the cv runs, due to cluster failure. do not reset exp date then
    'intuitive_training_scores': [1],  # 1: add slow callbacks that track age decoding loss intuitively as mae
    'out_dir': [os.path.join(base_dir, 'competition/results/')],
    'n_jobs': [2],  # faster than 1, 3, and 4 on tmax=2*60, n_recordings=-1, subset=normal, n_epochs=5, preload=0
    'preload': [1],

    'batch_size': [64],  # 64. does CroppedTrialEpochStoring increase GPU memory consumption? 256 works fine in notebook but fails as pipeline. 128 works with shallow fails with deep
    'tmin': [-1],
    'standardize_data': [0],  # TODO: needs to be implemented. scaling to microvolts is done anyway
    'standardize_targets': [1],
    'window_size_samples': [1500],  # EC condition is ~40s, EO only ~20s
    'shuffle_data_before_split': [0],
}

# Unwrap the one-element lists into plain values.
params = {name: choices[0] for name, choices in params.items()}
# Keep a snapshot of all settings as a Series. It is built before the 'config'
# key is added, so it does not contain itself.
params['config'] = pd.Series(params)

In [None]:
# Expose every configuration entry as a notebook-level variable.
# globals() is used instead of locals(): writing through the dict returned by
# locals() is documented as unreliable, while at the top level of a cell both
# refer to the same namespace, so the effect here is unchanged.
# Note that this overwrites earlier variables of the same name, e.g. n_jobs
# (30 -> 2) and target_name.
globals().update(params)

In [None]:
from decode_tueg import add_file_logger, check_input_args

In [None]:
import math
import json
import glob
import pickle
import logging
import warnings
import argparse
from datetime import datetime
from functools import partial
from collections import OrderedDict
from io import StringIO as StringBuffer

import mne
mne.set_log_level('ERROR')
#mne.set_config("MNE_LOGGING_LEVEL", "ERROR")
import torch
import numpy as np
import pandas as pd
import seaborn as sns
sns.set_color_codes('deep')
import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed its bundled 'seaborn' style to 'seaborn-v0_8' and
# later releases dropped the old name. Use the old name where it is still
# listed and fall back to the renamed style otherwise, so the cell runs on
# both old and new Matplotlib versions.
plt.style.use('seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8')
plt.set_loglevel('ERROR')
from sklearn.metrics import mean_absolute_error, balanced_accuracy_score, mean_absolute_percentage_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.compose import TransformedTargetRegressor
from skorch.helper import predefined_split
from skorch.callbacks import LRScheduler, Checkpoint, TrainEndCheckpoint, ProgressBar, BatchScoring
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from skorch.utils import valid_loss_score, noop

from braindecode.datasets import BaseDataset, BaseConcatDataset
from braindecode.datasets.tuh import TUHAbnormal
from braindecode.preprocessing import Preprocessor, preprocess
from braindecode.preprocessing.windowers import create_fixed_length_windows
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet, to_dense_prediction_model, Deep4Net, TCN
from braindecode.models.modules import Expression
from braindecode.regressor import EEGRegressor
from braindecode.classifier import EEGClassifier
from braindecode.training import CroppedLoss, CroppedTrialEpochScoring

In [None]:
import logging

# Configure the root logger: timestamped records, everything from DEBUG upwards.
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s : %(message)s")

# The root logger is used as the notebook-wide logger.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

In [None]:
# Results of this run go to <out_dir>/<date>/<seed>/<valid_set_i>.
# NOTE(review): out_dir is rebuilt from its own previous value, so re-running
# this cell without re-running the configuration cells nests the path one
# level deeper instead of hitting the "already exists" error.
out_dir = os.path.join(out_dir, date, str(seed), str(valid_set_i))
try:
    # Create-or-fail in a single step. A separate "exists?" check followed by
    # makedirs lets two runs started at the same time both pass the check and
    # then share one results directory.
    os.makedirs(out_dir)
except FileExistsError as err:
    raise RuntimeError(f'Directory already exists {out_dir}') from err
add_file_logger(
    logger=logger,
    out_dir=out_dir,
)

# Suppress selected warnings, matched by the start of their message.
warnings.filterwarnings("ignore", message="'pathological' not in description.")
warnings.filterwarnings("ignore", message="torch.backends.cudnn.benchmark was set to True which")
warnings.filterwarnings("ignore", message="You are using an callback that overrides on_batch_begin or on_batc")
warnings.filterwarnings("ignore", message="This function was designed to predict trials from cropped datasets")
#warnings.filterwarnings("ignore", message="UserWarning: y_pred contains classes not in y_true")

# Validate the run configuration before any expensive work starts.
check_input_args(
    batch_size, condition, config, data_path, debug, final_eval, intuitive_training_scores,
    max_age, min_age, model_name, n_epochs, n_jobs, n_restarts, n_train_recordings,
    out_dir, preload, seed, shuffle_data_before_split, squash_outs,
    standardize_data, standardize_targets, subset, target_name, tmax, tmin,
    valid_set_i, window_size_samples, augment, loss, logger,
)


In [None]:
#log_capture_string = get_log_capturer(logger, debug)
# Verbosity follows the `debug` flag: DEBUG when it equals 1, INFO otherwise.
if debug == 1:
    level = logging.DEBUG
else:
    level = logging.INFO
logger.setLevel(level)

# Record the complete run configuration, sorted by parameter name.
logger.info(f'\n{config.sort_index()}')

In [None]:
# check if GPU is available, if True chooses to use it
cuda = torch.cuda.is_available()
if not cuda:
    raise RuntimeError('no gpu found')
torch.backends.cudnn.benchmark = True
logger.debug(f"cuda: {cuda}")
cropped = True
logger.debug(f"cropped: {cropped}")

In [None]:
from decode_tueg import get_competition_datasets
from decode_tueg import test_name

In [None]:
# Build the train/validation datasets for the configured run. All arguments
# are the notebook-level variables injected from `params`.
competition_splits = get_competition_datasets(
    data_path, target_name, subset, n_train_recordings,
    tmin, tmax, n_jobs, final_eval, valid_set_i, seed,
    min_age, max_age, condition,
)
tuabn_train, tuabn_valid, mapping, valid_rest, valid_rest_name = competition_splits

In [None]:
from decode_tueg import save_input

# Persist the configuration together with the descriptions of both splits
# in the run's output directory.
save_input(config, out_dir, tuabn_train.description, tuabn_valid.description, test_name(final_eval))

In [None]:
from decode_tueg import get_model

# Channel names and sampling rate are taken from the first training recording.
# Note that this rebinds `sfreq` to the value stored in that recording's info.
first_train_raw = tuabn_train.datasets[0].raw
ch_names = first_train_raw.ch_names
sfreq = first_train_raw.info['sfreq']
n_channels = len(ch_names)

# get_model returns the network together with its learning rate and weight decay.
model, lr, weight_decay = get_model(
    n_channels, seed, cuda, target_name,
    model_name, cropped, window_size_samples, squash_outs,
)

In [None]:
from decode_tueg import create_windows
from decode_tueg import get_n_preds_per_input

In [None]:
%%time
# Obtained from the model via get_n_preds_per_input and handed on to the
# windowing helper below.
n_preds_per_input = get_n_preds_per_input(model, n_channels, window_size_samples)

# Replace both splits by their windowed versions.
windowed_splits = create_windows(
    mapping, tuabn_train, tuabn_valid, window_size_samples,
    n_jobs, preload, n_preds_per_input, test_name(final_eval),
)
tuabn_train, tuabn_valid = windowed_splits

In [None]:
from decode_tueg import standardize

# Apply the configured data/target standardization; both splits are rebound
# to whatever `standardize` returns.
standardized_splits = standardize(
    standardize_data, standardize_targets, tuabn_train, tuabn_valid, target_name,
)
tuabn_train, tuabn_valid = standardized_splits

In [None]:
import pickle

# Cache both splits in the current working directory. The files are opened in
# `with` blocks so they are flushed and closed as soon as the dump finishes;
# the bare open() calls used before never closed their handles explicitly.
with open('tuabn_train.pkl', 'wb') as train_file:
    pickle.dump(tuabn_train, train_file)
with open('tuabn_valid.pkl', 'wb') as valid_file:
    pickle.dump(tuabn_valid, valid_file)

In [None]:
# The test pickle sits in the sibling 'testing' directory of the training data;
# its name carries the sampling rate as an integer.
test_dir = data_path.replace('training', 'testing')
test_pickle_path = os.path.join(test_dir, f'test_{int(sfreq):d}_hz.pkl')
with open(test_pickle_path, 'rb') as pickle_file:
    tuabn_eval = pickle.load(pickle_file)

In [None]:
from decode_tueg import _create_windows

# Window the evaluation set with the same window size, prediction count and
# mapping that were used for the training and validation splits.
tuabn_eval = _create_windows(
    tuabn_eval, window_size_samples, n_jobs, preload, n_preds_per_input, mapping,
)

In [None]:
# Cache the windowed evaluation set in the current working directory. The
# `with` block closes the file once the dump finishes; the bare open() call
# used before never closed its handle explicitly.
with open('tuabn_eval.pkl', 'wb') as eval_file:
    pickle.dump(tuabn_eval, eval_file)