In [4]:
%load_ext autoreload
%autoreload 2

# Minimal TSMNet demo notebook for inter-session/-subject source-free (SF) offline and online unsupervised domain adaptation (UDA)

In [1]:
import sys
sys.path.insert(0,"C:\\Users\\s.velut\\Documents\\These\\Protheus_PHD\\Scripts")
sys.path.insert(0,"C:\\Users\\s.velut\\Documents\\These\\moabb\\moabb\\datasets")
sys.path.insert(0,"C:\\Users\\s.velut\\Documents\\These\\Protheus_PHD")

import torch
import sklearn
from sklearn.metrics import balanced_accuracy_score
import seaborn as sns

import numpy as np
from imblearn.under_sampling import RandomUnderSampler

import pandas as pd
import numpy as np
from copy import deepcopy

from moabb.datasets.bnci import BNCI2015001, BNCI2014001
from moabb.paradigms import MotorImagery, CVEP

# from library.utils.torch import StratifiedDomainDataLoader
from spdnets.utils.data import StratifiedDomainDataLoader, DomainDataset
from spdnets.models import TSMNet, SPDSMNet, SPDSMNet2, DSBNSPDBNNet, SPDSMNet_visu
import spdnets.batchnorm as bn
import spdnets.functionals as fn

from spdnets.trainer import Trainer
from spdnets.callbacks import MomentumBatchNormScheduler, EarlyStopping


from pyriemann.estimation import XdawnCovariances
import matplotlib.pyplot as plt
from Scripts.utils import prepare_data,get_BVEP_data,balance,get_y_pred
from _utils import make_preds_accumul_aggresive, make_preds_pvalue
from castillos2023 import CasitllosCVEP100,CasitllosCVEP40,CasitllosBurstVEP100,CasitllosBurstVEP40






c:\Users\s.velut\AppData\Local\Programs\Python\Python311\Lib\site-packages\moabb\pipelines\__init__.py:26: ModuleNotFoundError: Tensorflow is not installed. You won't be able to use these MOABB pipelines if you attempt to do so.
  warn(


In [2]:
# network and training configuration
cfg_org = dict(
    epochs = 20,
    batch_size_train = 64,
    domains_per_batch = 4,
    validation_size = 0.2,
    evaluation = 'inter-subject', # or 'inter-subject'
    dtype = torch.float32,
    spd_device='cpu',
    # parameters for the TSMNet model
    mdl_kwargs = dict(
        temporal_filters=2,
        nclasses = 2,
        spatial_filters=32,
        subspacedims=32, 
        bnorm_dispersion=bn.BatchNormDispersion.SCALAR,
        spd_device='cpu',
        spd_dtype=torch.double,
        domain_adaptation=True
    )
)

cfg_spd = dict(
    epochs = 20,
    batch_size_train = 64,
    domains_per_batch = 4,
    validation_size = 0.2,
    evaluation = 'inter-subject', # or 'inter-subject'
    dtype = torch.double,
    spd_device='cuda',
    # parameters for the TSMNet model
    mdl_kwargs = dict(
        temporal_filters=2,
        nclasses = 2,
        spatial_filters=32,
        bimap_dims = [32,28,14,7],
        subspacedims=32, 
        bnorm_dispersion=bn.BatchNormDispersion.SCALAR,
        spd_device='cuda',
        spd_dtype=torch.double,
        domain_adaptation=True
    )
)

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## load a MOABB dataset

In [3]:
# moabb_ds = BNCI2015001()
# n_classes = 2
# moabb_paradigm = MotorImagery(n_classes=n_classes, events=['right_hand', 'feet'], fmin=4, fmax=36, tmin=0.0, tmax=4.0, resample=256)

moabb_ds = CasitllosBurstVEP100()
n_classes = 2
moabb_paradigm = CVEP()

# moabb_ds = BNCI2014001()
# n_classes = 4
# moabb_paradigm = MotorImagery(n_classes=n_classes, events=["left_hand", "right_hand", "feet", "tongue"], fmin=4, fmax=36, tmin=0.5, tmax=3.496, resample=None)

Choosing the first None classes from all possible events.


## SF offline UDA

In [4]:
def sfuda_offline(dataset : DomainDataset, model : TSMNet, cfg):
    model.eval()
    model.domainadapt_finetune(dataset.features.to(dtype=cfg['dtype'], device=device), dataset.labels.to(device=device), dataset.domains, None)

## SF online UDA

In [5]:
def _sfuda_online_init_other_spd(domain_bn : torch.nn.Module, model : TSMNet, cfg):

    # initialize new mean with the grand average of the other domains
    M = []
    s = []
    for _, val in model.spdbnorm_layers[0].batchnorm.items():
        M += [val.running_mean_test]
        s += [val.running_var_test]
    M = torch.cat(M, dim=0)
    s = torch.cat(s, dim=0)
    prev_mean = fn.spd_mean_kracher_flow(M, dim=0, return_dist=False)
    prev_var = fn.spd_mean_kracher_flow(s[...,None], dim=0, return_dist=False).squeeze(-1)

    # because we use precomputed means advance the observation number
    # so that we start with slower adaptation
    domain_bn.adapt_observation = 5 if cfg['evaluation'] == 'inter-subject' else 15
    domain_bn.running_mean_test.data = prev_mean.data.clone()
    domain_bn.running_var_test = prev_var.clone()


def sfuda_online_init_spd(target_domains : torch.LongTensor, model : TSMNet, cfg, strategy : str = 'other'):

    assert target_domains.ndim == 1
    # assert isinstance(model.spdbnorm, bn.BaseDomainBatchNorm)
    for target_domain in target_domains:
        for i in range(len(model.bimap_dims[1:])):
            domain_key = model.spdbnorm_layers[i].domain_to_key(target_domain)
            # add domain if not yet in the model
            if domain_key not in model.spdbnorm_layers[i].batchnorm:
                bncls = model.spdbnorm_layers[i].domain_bn_cls
                domain_bn = bncls(
                    shape=model.spdbnorm_layers[i].mean.shape,
                    batchdim=model.spdbnorm_layers[i].batchdim, 
                    learn_mean=model.spdbnorm_layers[i].learn_mean,
                    learn_std=model.spdbnorm_layers[i].learn_std,
                    dispersion=model.spdbnorm_layers[i].dispersion,
                    mean=model.spdbnorm_layers[i].mean,
                    std=model.spdbnorm_layers[i].std,
                    eta=model.spdbnorm_layers[i].eta,
                    eta_test=model.spdbnorm_layers[i].eta_test
                )
            else:
                domain_bn = model.spdbnorm_layers[i].batchnorm[domain_key]

        if strategy == 'other':
            _sfuda_online_init_other_spd(domain_bn, model, cfg)
        else:
            raise NotImplementedError()

        # change the BN mode to perform online adaptation (for each batch)
        domain_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.ADAPT)
    
    for target_domain in target_domains:    
        for i in range(len(model.bimap_dims[1:])):
            if domain_key not in model.spdbnorm_layers[i].batchnorm:
                model.spdbnorm_layers[i].batchnorm[domain_key] = domain_bn

def sfuda_online_step(inputs : torch.Tensor, domains : torch.LongTensor, model : TSMNet, cfg):

    model = model.to(device='cpu')
    # the model needs to be in evaluation mode so that the batch norm statistics for testing will be adapted
    model.eval()

    activations = model(inputs.to(dtype=cfg['dtype'],device=cfg['spd_device']), domains)
    # return class probabilities
    return activations


def sfuda_online_simulate_spd(dataset : DomainDataset, model : TSMNet, cfg, loss_fn):

    # do adaptation if the model is configured to do domain adaptation
    if not model.domain_adaptation_:
        test_loss, score  = None, None
    else:
        
        # extract the target domains from the dataset
        target_domains = dataset.domains.unique()

        sfuda_online_init_spd(target_domains, model, cfg)
        
        y_true = []
        y_hat = []

        test_loss = 0.
        # feed each observation through the network
        # to adapt and infer the target
        for i, (features, y) in enumerate(dataset):
            
            inputs = features['inputs'][None,...]
            domains = features['domains'][None,...]
            y = y

            pred = sfuda_online_step(inputs, domains, model, cfg)
            test_loss += loss_fn(pred, y[None,...]).item()
            y_true.append(y[None,...])
            y_hat.append(pred.argmax(1)[None,...])
            # p_class = torch.nn.functional.softmax(pred, dim=-1)

        # compute the overall score 
        score = balanced_accuracy_score(torch.cat(y_true).detach().cpu().numpy(), torch.cat(y_hat).detach().cpu().numpy())

        test_loss /= len(dataset)

        # stop adaptation for the target domains
        for target_domain in target_domains:
            for i in range(len(model.bimap_dims[1:])):
                trgt_bn = model.spdbnorm_layers[i].batchnorm[str(target_domain.item())]
                trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)  

    return test_loss, score 

In [6]:
def sfuda_online(dataset : DomainDataset, model : TSMNet, cfg, loss_fn):

    model = model.to(device='cpu')
    # the model needs to be in evaluation mode so that the batch norm statistics for testing will be adapted
    model.eval()

    # do adaptation if the model is configured to do domain adaptation
    if not model.domain_adaptation_:
        test_loss, score  = None, None
    else:
        
        # extract the target domains from the dataset
        target_domains = dataset.domains.unique()
        # initialize new mean with the grand average of the other domains
        M = []
        s = []
        for key, val in model.spdbnorm.batchnorm.items():
            if key in [f'dom {t}' for t in target_domains]:
                continue
            M += [val.running_mean_test]
            s += [val.running_var_test]
        M = torch.cat(M, dim=0)
        s = torch.cat(s, dim=0)
        prev_mean = fn.spd_mean_kracher_flow(M, dim=0, return_dist=False)
        prev_var = fn.spd_mean_kracher_flow(s[...,None], dim=0, return_dist=False).squeeze(-1)
        
        # assign the grand average mean and var to the target domains
        for target_domain in target_domains:

            trgt_bn = model.spdbnorm.batchnorm[str(target_domain.item())]
            # because we use precomputed means advance the observation number
            # so that we start with slower adaptation
            trgt_bn.adapt_observation = 5 if cfg['evaluation'] == 'inter-subject' else 15
            trgt_bn.running_mean_test.data = prev_mean.data.clone()
            trgt_bn.running_var_test = prev_var.clone()

            # change the BN mode to perform online adaptation (for each batch)
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.ADAPT)
        
        y_true = []
        y_hat = []

        test_loss = 0.
        # feed each observation through the network
        # to adapt and infer the target
        for i, (features, y) in enumerate(dataset):
            
            inputs = features['inputs'][None,...].to(dtype=cfg['dtype'])
            domains = features['domains'][None,...]
            y = y
            pred = model.forward(inputs=inputs, domains=domains)
            test_loss += loss_fn(pred, y[None,...]).item()
            y_true.append(y[None,...])
            y_hat.append(pred.argmax(1)[None,...])

        # compute the overall score 
        score = balanced_accuracy_score(torch.cat(y_true).detach().cpu().numpy(), torch.cat(y_hat).detach().cpu().numpy())

        test_loss /= len(dataset)

        # stop adaptation for the target domains
        for target_domain in target_domains:
            trgt_bn = model.spdbnorm.batchnorm[str(target_domain.item())]
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)  

    return test_loss, score 

def sfuda_online_spd(dataset : DomainDataset, model : TSMNet, cfg, loss_fn):

    model = model.to(device='cpu')
    # the model needs to be in evaluation mode so that the batch norm statistics for testing will be adapted
    model.eval()

    # do adaptation if the model is configured to do domain adaptation
    if not model.domain_adaptation_:
        test_loss, score  = None, None
    else:
        
        # extract the target domains from the dataset
        target_domains = dataset.domains.unique()
        # initialize new mean with the grand average of the other domains
        M = []
        s = []
        for key, val in model.spdbnorm_layers[0].batchnorm.items():
            if key in [f'dom {t}' for t in target_domains]:
                continue
            M += [val.running_mean_test]
            s += [val.running_var_test]
        M = torch.cat(M, dim=0)
        s = torch.cat(s, dim=0)
        prev_mean = fn.spd_mean_kracher_flow(M, dim=0, return_dist=False)
        prev_var = fn.spd_mean_kracher_flow(s[...,None], dim=0, return_dist=False).squeeze(-1)
        
        # assign the grand average mean and var to the target domains
        for target_domain in target_domains:
            trgt_bn = model.spdbnorm_layers[0].batchnorm[str(target_domain.item())]
            # because we use precomputed means advance the observation number
            # so that we start with slower adaptation
            trgt_bn.adapt_observation = 5 if cfg['evaluation'] == 'inter-subject' else 15
            trgt_bn.running_mean_test.data = prev_mean.data.clone()
            trgt_bn.running_var_test = prev_var.clone()

            # change the BN mode to perform online adaptation (for each batch)
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.ADAPT)
        
        y_true = []
        y_hat = []

        test_loss = 0.
        # feed each observation through the network
        # to adapt and infer the target
        for i, (features, y) in enumerate(dataset):
            
            inputs = features['inputs'][None,...].to(dtype=cfg['dtype'],device=cfg['spd_device'])
            domains = features['domains'][None,...]
            y = y
            pred = model.forward(inputs=inputs, domains=domains)
            test_loss += loss_fn(pred, y[None,...]).item()
            y_true.append(y[None,...])
            y_hat.append(pred.argmax(1)[None,...])

        # compute the overall score 
        score = balanced_accuracy_score(torch.cat(y_true).detach().cpu().numpy(), torch.cat(y_hat).detach().cpu().numpy())

        test_loss /= len(dataset)

        # stop adaptation for the target domains
        for target_domain in target_domains:
            trgt_bn = model.spdbnorm_layers[0].batchnorm[str(target_domain.item())]
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)  

    return test_loss, score 

In [7]:
def _sfuda_online_init_other(domain_bn : torch.nn.Module, model : TSMNet, cfg):

    # initialize new mean with the grand average of the other domains
    M = []
    s = []
    for _, val in model.spdbnorm.batchnorm.items():
        M += [val.running_mean_test]
        s += [val.running_var_test]
    M = torch.cat(M, dim=0)
    s = torch.cat(s, dim=0)
    prev_mean = fn.spd_mean_kracher_flow(M, dim=0, return_dist=False)
    prev_var = fn.spd_mean_kracher_flow(s[...,None], dim=0, return_dist=False).squeeze(-1)

    # because we use precomputed means advance the observation number
    # so that we start with slower adaptation
    domain_bn.adapt_observation = 5 if cfg['evaluation'] == 'inter-subject' else 15
    domain_bn.running_mean_test.data = prev_mean.data.clone()
    domain_bn.running_var_test = prev_var.clone()


def sfuda_online_init(target_domains : torch.LongTensor, model : TSMNet, cfg, strategy : str = 'other'):

    assert target_domains.ndim == 1
    assert isinstance(model.spdbnorm, bn.BaseDomainBatchNorm)
    for target_domain in target_domains:
        domain_key = model.spdbnorm.domain_to_key(target_domain)
        # add domain if not yet in the model
        if domain_key not in model.spdbnorm.batchnorm:
            bncls = model.spdbnorm.domain_bn_cls
            domain_bn = bncls(
                shape=model.spdbnorm.mean.shape,
                batchdim=model.spdbnorm.batchdim, 
                learn_mean=model.spdbnorm.learn_mean,
                learn_std=model.spdbnorm.learn_std,
                dispersion=model.spdbnorm.dispersion,
                mean=model.spdbnorm.mean,
                std=model.spdbnorm.std,
                eta=model.spdbnorm.eta,
                eta_test=model.spdbnorm.eta_test
            )
        else:
            domain_bn = model.spdbnorm.batchnorm[domain_key]

        if strategy == 'other':
            _sfuda_online_init_other(domain_bn, model, cfg)
        else:
            raise NotImplementedError()

        # change the BN mode to perform online adaptation (for each batch)
        domain_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.ADAPT)
    
    for target_domain in target_domains:    
        if domain_key not in model.spdbnorm.batchnorm:
            model.spdbnorm.batchnorm[domain_key] = domain_bn

def sfuda_online_step(inputs : torch.Tensor, domains : torch.LongTensor, model : TSMNet, cfg):

    model = model.to(device='cpu')
    # the model needs to be in evaluation mode so that the batch norm statistics for testing will be adapted
    model.eval()

    activations = model(inputs.to(dtype=cfg['dtype']), domains)
    # return class probabilities
    return activations


def sfuda_online_simulate(dataset : DomainDataset, model : TSMNet, cfg, loss_fn):

    # do adaptation if the model is configured to do domain adaptation
    if not model.domain_adaptation_:
        test_loss, score  = None, None
    else:
        
        # extract the target domains from the dataset
        target_domains = dataset.domains.unique()

        sfuda_online_init(target_domains, model, cfg)
        
        y_true = []
        y_hat = []

        test_loss = 0.
        # feed each observation through the network
        # to adapt and infer the target
        for i, (features, y) in enumerate(dataset):
            
            inputs = features['inputs'][None,...]
            domains = features['domains'][None,...]
            y = y

            pred = sfuda_online_step(inputs, domains, model, cfg)
            test_loss += loss_fn(pred, y[None,...]).item()
            y_true.append(y[None,...])
            y_hat.append(pred.argmax(1)[None,...])
            # p_class = torch.nn.functional.softmax(pred, dim=-1)

        # compute the overall score 
        score = balanced_accuracy_score(torch.cat(y_true).detach().cpu().numpy(), torch.cat(y_hat).detach().cpu().numpy())

        test_loss /= len(dataset)

        # stop adaptation for the target domains
        for target_domain in target_domains:
            trgt_bn = model.spdbnorm.batchnorm[str(target_domain.item())]
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)  

    return test_loss, score 

## fit and evaluat the model for all domains

In [8]:
%matplotlib Qt

In [10]:
# subjects = [1,2,3]
subjects = [1,2,3,4,5,6,7,8,9,10,11,12]
n_channels = 32
on_frame = True
if on_frame:
    freq = 60
else:
    freq = 500

raw_data,labels,codes,labels_codes = get_BVEP_data(subjects,on_frame)
X_parent, Y_parent, domains_parent = prepare_data(subjects, raw_data, labels, on_frame,True,codes)
metadata_parent = pd.DataFrame({"subject":np.repeat(list(map(str,subjects)),X_parent[0].shape[0]),"session":["0"]*len(subjects)*X_parent[0].shape[0]})

Choosing the first None classes from all possible events.


None
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad ep

  raw = mne.io.read_raw_eeglab(file_path_list[0], preload=True, verbose=False)


C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs 

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped


## original TSMNET

In [57]:
records = []

if 'inter-session' in cfg_org['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_org['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_org['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_org['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_org['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_org['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_org['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = TSMNet(**mdl_kwargs).to(device=device, dtype=cfg_org['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_org['epochs']-10,
            bs0=cfg_org['batch_size_train'],
            bs=cfg_org['batch_size_train']/cfg_org['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg_org['epochs'],
            min_epochs=cfg_org['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_org['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = TSMNet(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_net, cfg_org)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        sfuda_online_net = TSMNet(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_org, trainer.loss_fn)
        records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # create a new model and perform online SF UDA
        sfuda_online_net = TSMNet(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_org, trainer.loss_fn)
        records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))

resdf = pd.DataFrame(records)    
resdf.to_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/TSMNet_code.csv")



C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs 

  raw = mne.io.read_raw_eeglab(file_path_list[0], preload=True, verbose=False)


C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs 

  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.25
Perform UDA offline
test domain=['2/0']
epoch=  0 gd-step=  327 trn_loss= 0.6905 trn_score=0.5393 val_loss= 0.6925 val_score=0.5216 
epoch= 10 gd-step= 3569 trn_loss= 0.6463 trn_score=0.6260 val_loss= 0.6880 val_score=0.5642 
epoch= 19 gd-step= 6496 trn_loss= 0.6474 trn_score=0.6249 val_loss= 0.7108 val_score=0.5589 
ES best epoch=10
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.25
Perform UDA offline
test domain=['3/0']
epoch=  0 gd-step=  323 trn_loss= 0.6883 trn_score=0.5603 val_loss= 0.6902 val_score=0.5275 
epoch= 10 gd-step= 3559 trn_loss= 0.6540 trn_score=0.6103 val_loss= 0.6934 val_score=0.5530 


## Modified TSMNET

In [22]:
records = []

if 'inter-session' in cfg_spd['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_spd['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    # X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=[1,2,3], return_epochs=False)
    # X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)
    X = np.concatenate(X_parent)
    labels = np.concatenate(Y_parent)
    domains = np.concatenate(domains_parent)
    metadata = metadata_parent.copy()

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    # xdawncov = XdawnCovariances(estimator="lwf",xdawn_estimator="lwf",nfilter=8)
    # X = xdawncov.fit_transform(X,labels)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_spd['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_spd['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_spd['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_spd['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_spd['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = SPDSMNet(**mdl_kwargs).to(device=device, dtype=cfg_spd['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_spd['epochs']-10,
            bs0=cfg_spd['batch_size_train'],
            bs=cfg_spd['batch_size_train']/cfg_spd['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg_spd['epochs'],
            min_epochs=cfg_spd['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_spd['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = SPDSMNet(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_net, cfg_spd)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        # sfuda_online_net = SPDSMNet(**mdl_kwargs)
        # sfuda_online_net.load_state_dict(state_dict)
        # loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        # records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # # # create a new model and perform online SF UDA
        # sfuda_online_net = SPDSMNet(**mdl_kwargs)
        # sfuda_online_net.load_state_dict(state_dict)
        # loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        # records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))

resdf = pd.DataFrame(records)    
resdf.to_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet_test/SPDSMNet_code.csv")

test domain=['1/0']
epoch=  0 gd-step=  326 trn_loss= 0.6625 trn_score=0.6032 val_loss= 0.6650 val_score=0.5987 
epoch= 10 gd-step= 3564 trn_loss= 0.5822 trn_score=0.6939 val_loss= 0.5883 val_score=0.6900 
epoch= 19 gd-step= 6485 trn_loss= 0.5801 trn_score=0.6950 val_loss= 0.5842 val_score=0.6934 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.55
Perform UDA offline
test domain=['10/0']
epoch=  0 gd-step=  325 trn_loss= 0.6639 trn_score=0.6122 val_loss= 0.6670 val_score=0.6036 
epoch= 10 gd-step= 3559 trn_loss= 0.6235 trn_score=0.6567 val_loss= 0.6295 val_score=0.6468 
epoch= 19 gd-step= 6463 trn_loss= 0.6223 trn_score=0.6554 val_loss= 0.6292 val_score=0.6491 
ES best epoch=18
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.32
Perform UDA offline
test domain=['11/0']
epoch=  0 gd-step=  323 trn_loss= 0.6764 trn_score=0.5828 val_loss= 0.6770 val_score=0.5777 
epoch= 10 gd-step= 3559 trn_loss= 0.6032 trn_score=0.6771 val_loss= 0.6020 val_score=0.6756 
epoch= 19 gd-step= 6466 trn_loss= 0.6016 trn_score=0.6773 val_loss= 0.6011 val_score=0.6756 
ES best epoch=15
evaluate the estimator
 accuracy score of the participant 0.37
Perform UDA offline
test domain=['12/0']
epoch=  0 gd-step=  324 trn_loss= 0.6678 trn_score=0.5927 val_loss= 0.6686 val_score=0.5979 
epoch= 10 gd-step= 3570 trn_loss= 0.6265 trn_score=0.6455 val_loss= 0.6304 val_score=0.6403 
epoch= 19 gd-step= 6480 trn_loss= 0.6255 trn_score=0.6468 val_loss= 0.6298 val_score=0.6453 
ES best epoch=19
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.22
Perform UDA offline
test domain=['2/0']
epoch=  0 gd-step=  323 trn_loss= 0.6675 trn_score=0.5947 val_loss= 0.6700 val_score=0.5907 
epoch= 10 gd-step= 3563 trn_loss= 0.6440 trn_score=0.6260 val_loss= 0.6513 val_score=0.6180 
epoch= 19 gd-step= 6486 trn_loss= 0.6425 trn_score=0.6256 val_loss= 0.6509 val_score=0.6176 
ES best epoch=17
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.3
Perform UDA offline
test domain=['3/0']
epoch=  0 gd-step=  319 trn_loss= 0.6874 trn_score=0.5446 val_loss= 0.6866 val_score=0.5466 
epoch= 10 gd-step= 3559 trn_loss= 0.6268 trn_score=0.6439 val_loss= 0.6316 val_score=0.6411 
epoch= 19 gd-step= 6478 trn_loss= 0.6262 trn_score=0.6403 val_loss= 0.6296 val_score=0.6411 
ES best epoch=18
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.22
Perform UDA offline
test domain=['4/0']
epoch=  0 gd-step=  324 trn_loss= 0.6445 trn_score=0.6443 val_loss= 0.6466 val_score=0.6403 
epoch= 10 gd-step= 3560 trn_loss= 0.5941 trn_score=0.6811 val_loss= 0.5914 val_score=0.6875 
epoch= 19 gd-step= 6474 trn_loss= 0.5931 trn_score=0.6814 val_loss= 0.5903 val_score=0.6892 
ES best epoch=19
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.32
Perform UDA offline
test domain=['5/0']
epoch=  0 gd-step=  327 trn_loss= 0.6684 trn_score=0.5968 val_loss= 0.6723 val_score=0.5845 
epoch= 10 gd-step= 3558 trn_loss= 0.5925 trn_score=0.6861 val_loss= 0.5910 val_score=0.6886 
epoch= 19 gd-step= 6468 trn_loss= 0.5890 trn_score=0.6864 val_loss= 0.5851 val_score=0.6926 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.25
Perform UDA offline
test domain=['6/0']
epoch=  0 gd-step=  320 trn_loss= 0.6920 trn_score=0.5264 val_loss= 0.6906 val_score=0.5312 
epoch= 10 gd-step= 3549 trn_loss= 0.6557 trn_score=0.6102 val_loss= 0.6571 val_score=0.6089 
epoch= 19 gd-step= 6464 trn_loss= 0.6540 trn_score=0.6117 val_loss= 0.6560 val_score=0.6087 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.33
Perform UDA offline
test domain=['7/0']
epoch=  0 gd-step=  327 trn_loss= 0.6512 trn_score=0.6292 val_loss= 0.6499 val_score=0.6398 
epoch= 10 gd-step= 3583 trn_l

  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.23
Perform UDA offline
test domain=['8/0']
epoch=  0 gd-step=  325 trn_loss= 0.6648 trn_score=0.5995 val_loss= 0.6683 val_score=0.5892 
epoch= 10 gd-step= 3564 trn_loss= 0.6186 trn_score=0.6518 val_loss= 0.6302 val_score=0.6339 
epoch= 19 gd-step= 6482 trn_loss= 0.6202 trn_score=0.6520 val_loss= 0.6300 val_score=0.6339 
ES best epoch=13
evaluate the estimator
 accuracy score of the participant 0.23
Perform UDA offline
test domain=['9/0']
epoch=  0 gd-step=  326 trn_loss= 0.6475 trn_score=0.6288 val_loss= 0.6440 val_score=0.6326 
epoch= 10 gd-step= 3575 trn_loss= 0.6214 trn_score=0.6528 val_loss= 0.6157 val_score=0.6610 
epoch= 19 gd-step= 6498 trn_loss= 0.6226 trn_score=0.6549 val_loss= 0.6155 val_score=0.6572 
ES best epoch=13
evaluate the estimator
 accuracy score of the participant 0.23
Perform UDA offline


## Modified TSMNET2

In [23]:
records = []

if 'inter-session' in cfg_spd['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_spd['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    # X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=[1,2,3], return_epochs=False)
    # X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)
    X = np.concatenate(X_parent)
    labels = np.concatenate(Y_parent)
    domains = np.concatenate(domains_parent)
    metadata = metadata_parent.copy()

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    # xdawncov = XdawnCovariances(estimator="lwf",xdawn_estimator="lwf",nfilter=8)
    # X = xdawncov.fit_transform(X,labels)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_spd['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_spd['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_spd['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_spd['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_spd['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = SPDSMNet2(**mdl_kwargs).to(device=device, dtype=cfg_spd['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_spd['epochs']-10,
            bs0=cfg_spd['batch_size_train'],
            bs=cfg_spd['batch_size_train']/cfg_spd['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg_spd['epochs'],
            min_epochs=cfg_spd['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_spd['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = SPDSMNet2(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_net, cfg_spd)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        # sfuda_online_net = SPDSMNet2(**mdl_kwargs)
        # sfuda_online_net.load_state_dict(state_dict)
        # loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        # records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # # # create a new model and perform online SF UDA
        # sfuda_online_net = SPDSMNet2(**mdl_kwargs)
        # sfuda_online_net.load_state_dict(state_dict)
        # loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        # records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))

resdf = pd.DataFrame(records)    
resdf.to_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet_test/SPDSMNet2_code.csv")

test domain=['1/0']
epoch=  0 gd-step=  326 trn_loss= 0.6477 trn_score=0.6318 val_loss= 0.6517 val_score=0.6229 
epoch= 10 gd-step= 3574 trn_loss= 0.6114 trn_score=0.6700 val_loss= 0.6119 val_score=0.6661 
epoch= 19 gd-step= 6487 trn_loss= 0.6068 trn_score=0.6766 val_loss= 0.6075 val_score=0.6737 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.73
Perform UDA offline
test domain=['10/0']
epoch=  0 gd-step=  323 trn_loss= 0.6787 trn_score=0.5710 val_loss= 0.6812 val_score=0.5610 
epoch= 10 gd-step= 3564 trn_loss= 0.6455 trn_score=0.6279 val_loss= 0.6393 val_score=0.6265 
epoch= 19 gd-step= 6495 trn_loss= 0.6435 trn_score=0.6291 val_loss= 0.6377 val_score=0.6275 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.33
Perform UDA offline
test domain=['11/0']
epoch=  0 gd-step=  326 trn_loss= 0.6604 trn_score=0.6170 val_loss= 0.6618 val_score=0.6091 
epoch= 10 gd-step= 3571 trn_loss= 0.5994 trn_score=0.6797 val_loss= 0.6061 val_score=0.

  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.25
Perform UDA offline
test domain=['2/0']
epoch=  0 gd-step=  325 trn_loss= 0.6323 trn_score=0.6631 val_loss= 0.6327 val_score=0.6545 
epoch= 10 gd-step= 3573 trn_loss= 0.5318 trn_score=0.7306 val_loss= 0.5372 val_score=0.7252 
epoch= 19 gd-step= 6485 trn_loss= 0.5240 trn_score=0.7356 val_loss= 0.5321 val_score=0.7286 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.28
Perform UDA offline
test domain=['3/0']
epoch=  0 gd-step=  327 trn_loss= 0.6899 trn_score=0.5408 val_loss= 0.6913 val_score=0.5362 
epoch= 10 gd-step= 3560 trn_loss= 0.6221 trn_score=0.6457 val_loss= 0.6129 val_score=0.6532 
epoch= 19 gd-step= 6479 trn_loss= 0.6201 trn_score=0.6505 val_loss= 0.6111 val_score=0.6540 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.23
Perform UDA offline
test domain=['4/0']
epoch=  0 gd-step=  320 trn_loss= 0.6503 trn_score=0.6243 val_loss= 0.6492 val_score=0.6301 
epoch= 10 gd-step= 3562 trn_l

  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.28
Perform UDA offline
test domain=['6/0']
epoch=  0 gd-step=  325 trn_loss= 0.6214 trn_score=0.6635 val_loss= 0.6188 val_score=0.6557 
epoch= 10 gd-step= 3562 trn_loss= 0.5271 trn_score=0.7364 val_loss= 0.5259 val_score=0.7295 
epoch= 19 gd-step= 6483 trn_loss= 0.5247 trn_score=0.7375 val_loss= 0.5208 val_score=0.7341 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.27
Perform UDA offline
test domain=['7/0']
epoch=  0 gd-step=  322 trn_loss= 0.6552 trn_score=0.6158 val_loss= 0.6569 val_score=0.6131 
epoch= 10 gd-step= 3561 trn_loss= 0.5958 trn_score=0.6755 val_loss= 0.5948 val_score=0.6765 
epoch= 19 gd-step= 6484 trn_loss= 0.5923 trn_score=0.6797 val_loss= 0.5922 val_score=0.6816 
ES best epoch=19
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.22
Perform UDA offline
test domain=['8/0']
epoch=  0 gd-step=  325 trn_loss= 0.5936 trn_score=0.6975 val_loss= 0.5958 val_score=0.6937 
epoch= 10 gd-step= 3571 trn_loss= 0.5447 trn_score=0.7237 val_loss= 0.5450 val_score=0.7250 
epoch= 19 gd-step= 6478 trn_loss= 0.5407 trn_score=0.7268 val_loss= 0.5429 val_score=0.7284 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.23
Perform UDA offline
test domain=['9/0']
epoch=  0 gd-step=  326 trn_loss= 0.6372 trn_score=0.6533 val_loss= 0.6442 val_score=0.6458 
epoch= 10 gd-step= 3569 trn_loss= 0.6053 trn_score=0.6785 val_loss= 0.6156 val_score=0.6727 
epoch= 19 gd-step= 6485 trn_loss= 0.6044 trn_score=0.6801 val_loss= 0.6150 val_score=0.6742 
ES best epoch=16
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.23
Perform UDA offline


## DSBNSDPBNNet

In [24]:
records = []

if 'inter-session' in cfg_spd['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_spd['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    # X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)
    X = np.concatenate(X_parent)
    labels = np.concatenate(Y_parent)
    domains = np.concatenate(domains_parent)
    metadata = metadata_parent.copy()

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    # xdawncov = XdawnCovariances(estimator="lwf",xdawn_estimator="lwf",nfilter=8)
    # X = xdawncov.fit_transform(X,labels)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_spd['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_spd['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_spd['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_spd['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_spd['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = DSBNSPDBNNet(**mdl_kwargs).to(device=device, dtype=cfg_spd['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_spd['epochs']-10,
            bs0=cfg_spd['batch_size_train'],
            bs=cfg_spd['batch_size_train']/cfg_spd['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg_spd['epochs'],
            min_epochs=cfg_spd['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_spd['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # # create new model and perform offline SF UDA
        # print("Perform UDA offline")
        # sfuda_offline_net = SPDSMNet2(**mdl_kwargs).to(device=device)
        # sfuda_offline_net.load_state_dict(state_dict)
        # sfuda_offline(ds_test, sfuda_offline_net)
        # res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        # records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        # sfuda_online_net = SPDSMNet2(**mdl_kwargs)
        # sfuda_online_net.load_state_dict(state_dict)
        # loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        # records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # # # create a new model and perform online SF UDA
        # sfuda_online_net = SPDSMNet2(**mdl_kwargs)
        # sfuda_online_net.load_state_dict(state_dict)
        # loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        # records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))

resdf = pd.DataFrame(records)    
resdf.to_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet_test/DSBMSDPBNNet_code.csv")

test domain=['1/0']
epoch=  0 gd-step=  324 trn_loss= 0.6895 trn_score=0.5462 val_loss= 0.6894 val_score=0.5449 
epoch= 10 gd-step= 3578 trn_loss= 0.6559 trn_score=0.6079 val_loss= 0.6557 val_score=0.6076 
epoch= 19 gd-step= 6494 trn_loss= 0.6534 trn_score=0.6117 val_loss= 0.6537 val_score=0.6148 
ES best epoch=18
evaluate the estimator
 accuracy score of the participant 0.25
test domain=['10/0']
epoch=  0 gd-step=  327 trn_loss= 0.6830 trn_score=0.5864 val_loss= 0.6832 val_score=0.5903 
epoch= 10 gd-step= 3572 trn_loss= 0.6293 trn_score=0.6390 val_loss= 0.6308 val_score=0.6434 
epoch= 19 gd-step= 6486 trn_loss= 0.6338 trn_score=0.6372 val_loss= 0.6336 val_score=0.6371 
ES best epoch=9
evaluate the estimator
 accuracy score of the participant 0.2
test domain=['11/0']
epoch=  0 gd-step=  324 trn_loss= 0.6846 trn_score=0.5709 val_loss= 0.6842 val_score=0.5617 
epoch= 10 gd-step= 3570 trn_loss= 0.6645 trn_score=0.5923 val_loss= 0.6646 val_score=0.5968 
epoch= 19 gd-step= 6497 trn_loss= 0.

  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.27
test domain=['2/0']
epoch=  0 gd-step=  324 trn_loss= 0.6901 trn_score=0.5422 val_loss= 0.6897 val_score=0.5538 
epoch= 10 gd-step= 3579 trn_loss= 0.6374 trn_score=0.6346 val_loss= 0.6436 val_score=0.6348 
epoch= 19 gd-step= 6505 trn_loss= 0.6347 trn_score=0.6418 val_loss= 0.6410 val_score=0.6394 
ES best epoch=13
evaluate the estimator
 accuracy score of the participant 0.15
test domain=['3/0']
epoch=  0 gd-step=  321 trn_loss= 0.6753 trn_score=0.6090 val_loss= 0.6756 val_score=0.6055 
epoch= 10 gd-step= 3564 trn_loss= 0.6297 trn_score=0.6390 val_loss= 0.6322 val_score=0.6388 
epoch= 19 gd-step= 6485 trn_loss= 0.6309 trn_score=0.6406 val_loss= 0.6334 val_score=0.6422 
ES best epoch=7
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.28
test domain=['4/0']
epoch=  0 gd-step=  320 trn_loss= 0.6884 trn_score=0.5522 val_loss= 0.6878 val_score=0.5572 
epoch= 10 gd-step= 3570 trn_loss= 0.6452 trn_score=0.6224 val_loss= 0.6418 val_score=0.6305 
epoch= 19 gd-step= 6483 trn_loss= 0.6440 trn_score=0.6312 val_loss= 0.6423 val_score=0.6307 
ES best epoch=12
evaluate the estimator
 accuracy score of the participant 0.18
test domain=['5/0']
epoch=  0 gd-step=  324 trn_loss= 0.6746 trn_score=0.6172 val_loss= 0.6743 val_score=0.6218 
epoch= 10 gd-step= 3561 trn_loss= 0.6257 trn_score=0.6507 val_loss= 0.6253 val_score=0.6576 
epoch= 19 gd-step= 6470 trn_loss= 0.6256 trn_score=0.6530 val_loss= 0.6255 val_score=0.6598 
ES best epoch=14
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.23
test domain=['6/0']
epoch=  0 gd-step=  322 trn_loss= 0.6762 trn_score=0.6006 val_loss= 0.6747 val_score=0.6076 
epoch= 10 gd-step= 3564 trn_loss= 0.6342 trn_score=0.6353 val_loss= 0.6281 val_score=0.6428 
epoch= 19 gd-step= 6479 trn_loss= 0.6397 trn_score=0.6322 val_loss= 0.6374 val_score=0.6299 
ES best epoch=9
evaluate the estimator
 accuracy score of the participant 0.22
test domain=['7/0']
epoch=  0 gd-step=  326 trn_loss= 0.6676 trn_score=0.6237 val_loss= 0.6689 val_score=0.6152 
epoch= 10 gd-step= 3560 trn_loss= 0.6229 trn_score=0.6509 val_loss= 0.6262 val_score=0.6477 
epoch= 19 gd-step= 6474 trn_loss= 0.6223 trn_score=0.6495 val_loss= 0.6208 val_score=0.6553 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.18
test domain=['8/0']
epoch=  0 gd-step=  328 trn_loss= 0.6873 trn_score=0.5591 val_loss= 0.6869 val_score=0.5612 
epoch= 10 gd-step= 3571 trn_loss= 0.6598 trn_score=0.6049 val_loss= 0.6605 val_score=0.603

In [12]:
X.shape

torch.Size([21060, 32, 32])

## Modified TSMNET_visu

In [None]:
records = []

if 'inter-session' in cfg_spd['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_spd['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    xdawncov = XdawnCovariances(estimator="lwf",xdawn_estimator="lwf",nfilter=8)
    X = xdawncov.fit_transform(X,labels)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_spd['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_spd['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_spd['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_spd['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_spd['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = SPDSMNet_visu(**mdl_kwargs).to(device=device, dtype=cfg_spd['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_spd['epochs']-10,
            bs0=cfg_spd['batch_size_train'],
            bs=cfg_spd['batch_size_train']/cfg_spd['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = VisuTrainer(
            max_epochs=cfg_spd['epochs'],
            min_epochs=cfg_spd['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_spd['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred, inter_layer_out = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = SPDSMNet2(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_net)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        sfuda_online_net = SPDSMNet2(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # # create a new model and perform online SF UDA
        sfuda_online_net = SPDSMNet2(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))


## original TSMNET on samples

In [17]:
subjects = [1,2,3]
# subjects = [1,2,3,4,5,6,7,8,9,10,11,12]
n_channels = 32
on_frame = False
if on_frame:
    freq = 60
else:
    freq = 500

raw_data,labels,codes,labels_codes = get_BVEP_data(subjects,on_frame)
X_parent, labels_parent, domains_parent = prepare_data(subjects,raw_data, labels, on_frame,False,codes)
metadata = pd.DataFrame({"subject":np.repeat(list(map(str,subjects)),X_parent[0].shape[0]),"session":["0"]*len(subjects)*X_parent[0].shape[0]})

Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 50 - 50 Hz

IIR filter parameters
---------------------
Butterworth bandstop zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 49.90, 50.10 Hz: -6.02, -6.02 dB

Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 50 - 50 Hz

IIR filter parameters
---------------------
Butterworth bandstop zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 49.90, 50.10 Hz: -6.02, -6.02 dB

Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 50 - 50 Hz

IIR filter parameters
---------------------
Butterworth bandstop zero-phase (two-pas

In [18]:
records = []

if 'inter-session' in cfg_org['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_org['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    # X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)
    X = np.concatenate(X_parent)
    labels = np.concatenate(labels_parent)
    domains = np.concatenate(domains_parent)

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_org['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_org['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_org['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_org['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_org['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = TSMNet(**mdl_kwargs).to(device=device, dtype=cfg_org['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_org['epochs']-10,
            bs0=cfg_org['batch_size_train'],
            bs=cfg_org['batch_size_train']/cfg_org['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg_org['epochs'],
            min_epochs=cfg_org['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_org['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = TSMNet(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_net)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        sfuda_online_net = TSMNet(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_org, trainer.loss_fn)
        records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # create a new model and perform online SF UDA
        sfuda_online_net = TSMNet(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_org, trainer.loss_fn)
        records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))


test domain=['1/0']
epoch=  0 gd-step=  498 trn_loss= 0.6509 trn_score=0.6513 val_loss= 0.6569 val_score=0.6383 
epoch= 10 gd-step= 5478 trn_loss= 0.4250 trn_score=0.8138 val_loss= 0.4748 val_score=0.7852 
epoch= 19 gd-step= 9960 trn_loss= 0.4401 trn_score=0.7983 val_loss= 0.4919 val_score=0.7658 
ES best epoch=18
evaluate the estimator


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.49 GiB. GPU 0 has a total capacty of 6.00 GiB of which 0 bytes is free. Of the allocated memory 5.15 GiB is allocated by PyTorch, and 32.46 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## Modified TSMNET on samples

In [19]:
subjects = [1,2,3]
# subjects = [1,2,3,4,5,6,7,8,9,10,11,12]
n_channels = 32
on_frame = False
if on_frame:
    freq = 60
else:
    freq = 500

raw_data,labels,codes,labels_codes = get_BVEP_data(subjects,on_frame)
X_parent, labels_parent, domains_parent = prepare_data(subjects,raw_data, labels, on_frame,True,codes)
metadata = pd.DataFrame({"subject":np.repeat(list(map(str,subjects)),X_parent[0].shape[0]),"session":["0"]*len(subjects)*X_parent[0].shape[0]})

Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 50 - 50 Hz

IIR filter parameters
---------------------
Butterworth bandstop zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 49.90, 50.10 Hz: -6.02, -6.02 dB

Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 50 - 50 Hz

IIR filter parameters
---------------------
Butterworth bandstop zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 49.90, 50.10 Hz: -6.02, -6.02 dB

Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 50 - 50 Hz

IIR filter parameters
---------------------
Butterworth bandstop zero-phase (two-pas

In [20]:
# metadata = pd.DataFrame({"subject":np.repeat(list(map(str,subjects)),X[0].shape[0]),"session":["0"]*len(subjects)*X[0].shape[0]})
metadata

Unnamed: 0,subject,session
0,1,0
1,1,0
2,1,0
3,1,0
4,1,0
...,...,...
175495,3,0
175496,3,0
175497,3,0
175498,3,0


In [21]:
records = []

if 'inter-session' in cfg_spd['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_spd['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    # X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)
    X = np.concatenate(X_parent)
    labels = np.concatenate(labels_parent)
    domains = np.concatenate(domains_parent)

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    # xdawncov = XdawnCovariances(estimator="lwf",xdawn_estimator="lwf",nfilter=8)
    # X = xdawncov.fit_transform(X,labels)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_spd['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_spd['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_spd['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_spd['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_spd['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = SPDSMNet(**mdl_kwargs).to(device=device, dtype=cfg_spd['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_spd['epochs']-10,
            bs0=cfg_spd['batch_size_train'],
            bs=cfg_spd['batch_size_train']/cfg_spd['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg_spd['epochs'],
            min_epochs=cfg_spd['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_spd['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = SPDSMNet(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_net)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        # sfuda_online_net = SPDSMNet(**mdl_kwargs)
        # sfuda_online_net.load_state_dict(state_dict)
        # loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        # records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # # create a new model and perform online SF UDA
        # sfuda_online_net = SPDSMNet(**mdl_kwargs)
        # sfuda_online_net.load_state_dict(state_dict)
        # loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        # records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))


test domain=['1/0']


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)

## original TSMNET train/test

In [9]:
# subjects = [1,2,3]
subjects = [1,2,3,4,5,6,7,8,9,10,11,12]
n_channels = 32
on_frame = True
if on_frame:
    freq = 60
else:
    freq = 500

raw_data,labels,codes,labels_codes = get_BVEP_data(subjects,on_frame)
X_parent, labels_parent, domains_parent = prepare_data(subjects,raw_data, labels, on_frame,False,codes)

Choosing the first None classes from all possible events.


None
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad ep

  raw = mne.io.read_raw_eeglab(file_path_list[0], preload=True, verbose=False)


C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs 

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped


In [30]:
records = []
n_cal = 7
nb_frame = int(4*n_cal*(2.2-0.25)*60)
nb_frame_train = int(4*(n_cal-1)*(2.2-0.25)*60)
nb_frame_val = int(4*1*(2.2-0.25)*60)


fit_records = []

# iterate over subject groups
for ix_sub, sub in enumerate(subjects):

    # get the data from the MOABB paradigm/dataset
    # X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)
    X = X_parent[ix_sub]
    labels = labels_parent[ix_sub]
    domains = domains_parent[ix_sub]
    metadata = pd.DataFrame({"subject":np.repeat(list(map(str,[sub])),X.shape[0]),"session":["0"]*X.shape[0]})

    train = np.arange(0,nb_frame_train,1)
    val = np.arange(nb_frame_train,nb_frame,1)
    test = np.arange(nb_frame, len(X),1)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(domains)
    
    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # Balancing the data
    rus = RandomUnderSampler()
    counter=np.array(range(0,len(y[train]))).reshape(-1,1)
    index,_ = rus.fit_resample(counter,y[train][:])
    index = np.sort(index,axis=0)
    # X_train = np.squeeze(X[train][index,:,:], axis=1)
    # y_train = np.squeeze(y[train][index])
    # domain_train = np.squeeze(domain[train][index])
    # metadata_train = metadata.loc[train].iloc[np.concatenate(index)]
    train = np.squeeze(train[index], axis=1)

    X_std = X[train].std(axis=0)
    X[train] /= X_std + 1e-8
    X_std = X[test].std(axis=0)
    X[test] /= X_std + 1e-8

    # adjust number of 
    du = np.unique(domains[train])
    if cfg_org['domains_per_batch'] > len(du):
        domains_per_batch = len(du)
    else:
        domains_per_batch = cfg_org['domains_per_batch']

    
    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_org['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # split entire dataset into train/validation/test
    ds_train = DomainDataset(X[train], y[train], domain[train], metadata.iloc[train])
    ds_val = DomainDataset(X[val], y[val], domain[val], metadata.iloc[val])
    ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

    # create dataloaders
    # for training use specific loader/sampler so taht 
    # batches contain a specific number of domains with equal observations per domain
    # and stratified labels
    loader_train = StratifiedDomainDataLoader(ds_train, cfg_org['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
    loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
    loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

    # extract domains in the test dataset
    test_domain = np.unique(domains[test])

    # create the model
    net = TSMNet(**mdl_kwargs).to(device=device, dtype=cfg_org['dtype'])

    # create the momentum scheduler    
    bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_org['epochs']-10,
            bs0=cfg_org['batch_size_train'],
            bs=cfg_org['batch_size_train']/cfg_org['domains_per_batch'], 
            tau0=0.85
    )

    es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
    # create the trainer
    trainer = Trainer(
            max_epochs=cfg_org['epochs'],
            min_epochs=cfg_org['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_org['dtype']
    )

    # fit the model
    print(f"test domain={test_domain}")
    trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

    print(f'ES best epoch={es.best_epoch}')

    fit_df = pd.DataFrame(trainer.records)
    fit_df['fold'] = 1
    fit_df['subset'] = 1
    fit_records.append(fit_df)

    # evaluation
    print("evaluate the estimator")
    res = trainer.test(net, dataloader=loader_train)
    records.append(dict(mode='train', domain=test_domain, **res))
    res = trainer.test(net, dataloader=loader_test)
    records.append(dict(mode='test(noUDA)', domain=test_domain, **res))

    y_pred = trainer.pred(net,dataloader=loader_test)
    labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
        y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
    )
    accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_sub][4*n_cal:][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
    print(" accuracy score of the participant",accuracy_code)
    records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))

resdf = pd.DataFrame(records)
resdf.to_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet_test/SPDSMNet_traintest_code.csv")



test domain=['1/0']
epoch=  0 gd-step=   17 trn_loss= 0.6866 trn_score=0.5558 val_loss= 0.6869 val_score=0.5527 
epoch= 10 gd-step=  187 trn_loss= 0.5855 trn_score=0.8646 val_loss= 0.5852 val_score=0.8661 
epoch= 19 gd-step=  340 trn_loss= 0.4530 trn_score=0.9173 val_loss= 0.4537 val_score=0.9170 
ES best epoch=19
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.22
test domain=['2/0']
epoch=  0 gd-step=   17 trn_loss= 0.6868 trn_score=0.6206 val_loss= 0.6868 val_score=0.6241 
epoch= 10 gd-step=  187 trn_loss= 0.6020 trn_score=0.8336 val_loss= 0.6028 val_score=0.8313 
epoch= 19 gd-step=  340 trn_loss= 0.4903 trn_score=0.8796 val_loss= 0.4903 val_score=0.8795 
ES best epoch=19
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.28
test domain=['3/0']
epoch=  0 gd-step=   17 trn_loss= 0.6862 trn_score=0.6344 val_loss= 0.6863 val_score=0.6321 
epoch= 10 gd-step=  187 trn_loss= 0.5969 trn_score=0.8125 val_loss= 0.5967 val_score=0.8152 


KeyboardInterrupt: 

In [27]:
y[train][:50]

tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
        1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1,
        1, 0])

In [56]:
y_pred[0].cpu().numpy().shape

(3744,)

## Get the results

In [54]:
# report the results
resdf = pd.DataFrame(records)

resdf.groupby(['mode']).agg(['mean', 'std']).round(2)

KeyError: 'mode'

In [36]:
resdf.to_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/SPDSMNet_code.csv")
# resdf

In [37]:
df = pd.read_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/SPDSMNet_code.csv")
print()
df_score_code = np.array(df["score"])[df["mode"]=="test(noUDA)_code"]
df_score = np.array(df["score"])[df["mode"]=="test(noUDA)"]

np.save("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/SPDSMNet_score_code",df_score_code)
np.save("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/SPDSMNet_score",df_score)




In [38]:
fit_df = pd.concat(fit_records)
fit_df = fit_df.groupby(['subset', 'fold', 'epoch']).mean()
fit_df.columns = fit_df.columns.str.split('_', expand=True)
fit_df.columns.names = ['set', 'metric']
fit_df = fit_df.stack('set').reset_index()


sns.relplot(data=fit_df, x='epoch', y='loss', hue='set', col='subset', col_wrap=5, kind='line', n_boot=100)

<seaborn.axisgrid.FacetGrid at 0x23432e9ee10>