In [1]:
%load_ext autoreload
%autoreload 2

# Minimal TSMNet demo notebook for inter-session/-subject source-free (SF) offline and online unsupervised domain adaptation (UDA)

In [1]:
import torch
import sklearn
from sklearn.metrics import balanced_accuracy_score
import seaborn as sns

import numpy as np
from imblearn.under_sampling import RandomUnderSampler

import pandas as pd
from copy import deepcopy

from moabb.datasets.bnci import BNCI2015001, BNCI2014001
from moabb.paradigms import MotorImagery, CVEP

# from library.utils.torch import StratifiedDomainDataLoader
from spdnets.utils.data import StratifiedDomainDataLoader, DomainDataset
from spdnets.models import TSMNet
import spdnets.batchnorm as bn
import spdnets.functionals as fn

from spdnets.trainer import Trainer
from spdnets.callbacks import MomentumBatchNormScheduler, EarlyStopping

import sys
import matplotlib.pyplot as plt
sys.path.insert(0,"C:\\Users\\s.velut\\Documents\\These\\Protheus_PHD\\Scripts")
sys.path.insert(0,"C:\\Users\\s.velut\\Documents\\These\\Protheus_PHD")
from utils import prepare_data,get_BVEP_data,balance,get_y_pred
from _utils import make_preds_accumul_aggresive, make_preds_pvalue
sys.path.insert(0,"C:\\Users\\s.velut\\Documents\\These\\moabb\\moabb\\datasets")
from castillos2023 import CasitllosCVEP100,CasitllosCVEP40,CasitllosBurstVEP100,CasitllosBurstVEP40






c:\Users\s.velut\AppData\Local\Programs\Python\Python311\Lib\site-packages\moabb\pipelines\__init__.py:26: ModuleNotFoundError: Tensorflow is not installed. You won't be able to use these MOABB pipelines if you attempt to do so.
  warn(


In [2]:
# network and training configuration
cfg = dict(
    epochs = 20,
    batch_size_train = 64,
    domains_per_batch = 4,
    validation_size = 0.2,
    evaluation = 'inter-subject', # or 'inter-subject'
    dtype = torch.float32,
    # parameters for the TSMNet model
    mdl_kwargs = dict(
        temporal_filters=2,
        spatial_filters=30,
        subspacedims=20, 
        bnorm_dispersion=bn.BatchNormDispersion.SCALAR,
        spd_device='cpu',
        spd_dtype=torch.double,
        domain_adaptation=True
    )
)

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## load a MOABB dataset

In [3]:
# moabb_ds = BNCI2015001()
# n_classes = 2
# moabb_paradigm = MotorImagery(n_classes=n_classes, events=['right_hand', 'feet'], fmin=4, fmax=36, tmin=0.0, tmax=4.0, resample=256)

moabb_ds = CasitllosBurstVEP100()
n_classes = 2
moabb_paradigm = CVEP()

# moabb_ds = BNCI2014001()
# n_classes = 4
# moabb_paradigm = MotorImagery(n_classes=n_classes, events=["left_hand", "right_hand", "feet", "tongue"], fmin=4, fmax=36, tmin=0.5, tmax=3.496, resample=None)

Choosing the first None classes from all possible events.


## SF offline UDA

In [4]:
def sfuda_offline(dataset : DomainDataset, model : TSMNet):
    model.eval()
    model.domainadapt_finetune(dataset.features.to(dtype=cfg['dtype'], device=device), dataset.labels.to(device=device), dataset.domains, None)

## SF online UDA

In [5]:
def _sfuda_online_init_other(domain_bn : torch.nn.Module, model : TSMNet, cfg):

    # initialize new mean with the grand average of the other domains
    M = []
    s = []
    for _, val in model.spdbnorm.batchnorm.items():
        M += [val.running_mean_test]
        s += [val.running_var_test]
    M = torch.cat(M, dim=0)
    s = torch.cat(s, dim=0)
    prev_mean = fn.spd_mean_kracher_flow(M, dim=0, return_dist=False)
    prev_var = fn.spd_mean_kracher_flow(s[...,None], dim=0, return_dist=False).squeeze(-1)

    # because we use precomputed means advance the observation number
    # so that we start with slower adaptation
    domain_bn.adapt_observation = 5 if cfg['evaluation'] == 'inter-subject' else 15
    domain_bn.running_mean_test.data = prev_mean.data.clone()
    domain_bn.running_var_test = prev_var.clone()


def sfuda_online_init(target_domains : torch.LongTensor, model : TSMNet, cfg, strategy : str = 'other'):

    assert target_domains.ndim == 1
    assert isinstance(model.spdbnorm, bn.BaseDomainBatchNorm)
    for target_domain in target_domains:
        domain_key = model.spdbnorm.domain_to_key(target_domain)
        # add domain if not yet in the model
        if domain_key not in model.spdbnorm.batchnorm:
            bncls = model.spdbnorm.domain_bn_cls
            domain_bn = bncls(
                shape=model.spdbnorm.mean.shape,
                batchdim=model.spdbnorm.batchdim, 
                learn_mean=model.spdbnorm.learn_mean,
                learn_std=model.spdbnorm.learn_std,
                dispersion=model.spdbnorm.dispersion,
                mean=model.spdbnorm.mean,
                std=model.spdbnorm.std,
                eta=model.spdbnorm.eta,
                eta_test=model.spdbnorm.eta_test
            )
        else:
            domain_bn = model.spdbnorm.batchnorm[domain_key]

        if strategy == 'other':
            _sfuda_online_init_other(domain_bn, model, cfg)
        else:
            raise NotImplementedError()

        # change the BN mode to perform online adaptation (for each batch)
        domain_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.ADAPT)
    
    for target_domain in target_domains:    
        if domain_key not in model.spdbnorm.batchnorm:
            model.spdbnorm.batchnorm[domain_key] = domain_bn

def sfuda_online_step(inputs : torch.Tensor, domains : torch.LongTensor, model : TSMNet, cfg):

    model = model.to(device='cpu')
    # the model needs to be in evaluation mode so that the batch norm statistics for testing will be adapted
    model.eval()

    activations = model(inputs.to(dtype=cfg['dtype']), domains)
    # return class probabilities
    return activations


def sfuda_online_simulate(dataset : DomainDataset, model : TSMNet, cfg, loss_fn):

    # do adaptation if the model is configured to do domain adaptation
    if not model.domain_adaptation_:
        test_loss, score  = None, None
    else:
        
        # extract the target domains from the dataset
        target_domains = dataset.domains.unique()

        sfuda_online_init(target_domains, model, cfg)
        
        y_true = []
        y_hat = []

        test_loss = 0.
        # feed each observation through the network
        # to adapt and infer the target
        for i, (features, y) in enumerate(dataset):
            
            inputs = features['inputs'][None,...]
            domains = features['domains'][None,...]
            y = y

            pred = sfuda_online_step(inputs, domains, model, cfg)
            test_loss += loss_fn(pred, y[None,...]).item()
            y_true.append(y[None,...])
            y_hat.append(pred.argmax(1)[None,...])
            # p_class = torch.nn.functional.softmax(pred, dim=-1)

        # compute the overall score 
        score = balanced_accuracy_score(torch.cat(y_true).detach().cpu().numpy(), torch.cat(y_hat).detach().cpu().numpy())

        test_loss /= len(dataset)

        # stop adaptation for the target domains
        for target_domain in target_domains:
            trgt_bn = model.spdbnorm.batchnorm[str(target_domain.item())]
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)  

    return test_loss, score 

In [6]:
def sfuda_online(dataset : DomainDataset, model : TSMNet, cfg, loss_fn):

    model = model.to(device='cpu')
    # the model needs to be in evaluation mode so that the batch norm statistics for testing will be adapted
    model.eval()

    # do adaptation if the model is configured to do domain adaptation
    if not model.domain_adaptation_:
        test_loss, score  = None, None
    else:
        
        # extract the target domains from the dataset
        target_domains = dataset.domains.unique()
        # initialize new mean with the grand average of the other domains
        M = []
        s = []
        for key, val in model.spdbnorm.batchnorm.items():
            if key in [f'dom {t}' for t in target_domains]:
                continue
            M += [val.running_mean_test]
            s += [val.running_var_test]
        M = torch.cat(M, dim=0)
        s = torch.cat(s, dim=0)
        prev_mean = fn.spd_mean_kracher_flow(M, dim=0, return_dist=False)
        prev_var = fn.spd_mean_kracher_flow(s[...,None], dim=0, return_dist=False).squeeze(-1)
        
        # assign the grand average mean and var to the target domains
        for target_domain in target_domains:

            trgt_bn = model.spdbnorm.batchnorm[str(target_domain.item())]
            # because we use precomputed means advance the observation number
            # so that we start with slower adaptation
            trgt_bn.adapt_observation = 5 if cfg['evaluation'] == 'inter-subject' else 15
            trgt_bn.running_mean_test.data = prev_mean.data.clone()
            trgt_bn.running_var_test = prev_var.clone()

            # change the BN mode to perform online adaptation (for each batch)
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.ADAPT)
        
        y_true = []
        y_hat = []

        test_loss = 0.
        # feed each observation through the network
        # to adapt and infer the target
        for i, (features, y) in enumerate(dataset):
            
            inputs = features['inputs'][None,...].to(dtype=cfg['dtype'])
            domains = features['domains'][None,...]
            y = y
            pred = model.forward(inputs=inputs, domains=domains)
            test_loss += loss_fn(pred, y[None,...]).item()
            y_true.append(y[None,...])
            y_hat.append(pred.argmax(1)[None,...])

        # compute the overall score 
        score = balanced_accuracy_score(torch.cat(y_true).detach().cpu().numpy(), torch.cat(y_hat).detach().cpu().numpy())

        test_loss /= len(dataset)

        # stop adaptation for the target domains
        for target_domain in target_domains:
            trgt_bn = model.spdbnorm.batchnorm[str(target_domain.item())]
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)  

    return test_loss, score 

## fit and evaluat the model for all domains

In [27]:
X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=[1,2], return_epochs=False)

C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\


KeyboardInterrupt: 

In [28]:
%matplotlib Qt

In [8]:
subjects = [1,2,3,4,5,6,7,8,9,10,11,12]
n_channels = 32
on_frame = True
if on_frame:
    freq = 60
else:
    freq = 500

raw_data,labels,codes,labels_codes = get_BVEP_data(subjects,on_frame)

Choosing the first None classes from all possible events.


None
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad ep

  raw = mne.io.read_raw_eeglab(file_path_list[0], preload=True, verbose=False)


C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs 

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...
0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 p

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Using data from preloaded Raw for 7020 events and 131 original time points ...
0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...
0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 p

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...
0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 p

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped


In [9]:
records = []

if 'inter-session' in cfg['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = TSMNet(**mdl_kwargs).to(device=device, dtype=cfg['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg['epochs']-10,
            bs0=cfg['batch_size_train'],
            bs=cfg['batch_size_train']/cfg['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg['epochs'],
            min_epochs=cfg['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = TSMNet(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_net)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        sfuda_online_net = TSMNet(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online(ds_test, sfuda_online_net, cfg, trainer.loss_fn)
        records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # create a new model and perform online SF UDA
        sfuda_online_net = TSMNet(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg, trainer.loss_fn)
        records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))


C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs 

  raw = mne.io.read_raw_eeglab(file_path_list[0], preload=True, verbose=False)


C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs 

  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.43
Perform UDA offline
test domain=['2/0']
epoch=  0 gd-step=  327 trn_loss= 0.6896 trn_score=0.5410 val_loss= 0.6920 val_score=0.5152 
epoch= 10 gd-step= 3569 trn_loss= 0.6490 trn_score=0.6235 val_loss= 0.6836 val_score=0.5884 
epoch= 19 gd-step= 6489 trn_loss= 0.6420 trn_score=0.6244 val_loss= 0.6957 val_score=0.5661 
ES best epoch=5
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.28
Perform UDA offline
test domain=['3/0']
epoch=  0 gd-step=  327 trn_loss= 0.6861 trn_score=0.5608 val_loss= 0.6889 val_score=0.5439 
epoch= 10 gd-step= 3578 trn_loss= 0.6388 trn_score=0.6326 val_loss= 0.6712 val_score=0.5873 


In [48]:
# report the results
resdf = pd.DataFrame(records)

resdf.groupby(['mode']).agg(['mean', 'std']).round(2)

  resdf.groupby(['mode']).agg(['mean', 'std']).round(2)


Unnamed: 0_level_0,loss,loss,score,score
Unnamed: 0_level_1,mean,std,mean,std
mode,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
test(SFUDA),0.71,0.02,0.54,0.03
test(noUDA),5.59,3.64,0.51,0.01
test(noUDA)_code,,,0.29,0.07
test(onlineSFUDA),0.7,0.03,0.54,0.03
test(online_sim_SFUDA),0.7,0.03,0.54,0.03
train,0.65,0.03,0.62,0.03
validation,0.68,0.0,0.57,0.02


In [50]:
resdf.to_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/TSMNet_code.csv")
# resdf

In [20]:
df = pd.read_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/TSMNet_code.csv")
print()
df_score_code = np.array(df["score"])[df["mode"]=="test(noUDA)_code"]
df_score = np.array(df["score"])[df["mode"]=="test(noUDA)"]

np.save("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/TSMNet_score_code",df_score_code)
np.save("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/TSMNet_score",df_score)




In [12]:
fit_df = pd.concat(fit_records)
fit_df = fit_df.groupby(['subset', 'fold', 'epoch']).mean()
fit_df.columns = fit_df.columns.str.split('_', expand=True)
fit_df.columns.names = ['set', 'metric']
fit_df = fit_df.stack('set').reset_index()


sns.relplot(data=fit_df, x='epoch', y='loss', hue='set', col='subset', col_wrap=5, kind='line', n_boot=100)

<seaborn.axisgrid.FacetGrid at 0x16dbc746e90>