In [1]:
%load_ext autoreload
%autoreload 2

# Minimal TSMNet demo notebook for inter-session/-subject source-free (SF) offline and online unsupervised domain adaptation (UDA)

In [1]:
import torch
import sklearn
from sklearn.metrics import balanced_accuracy_score
from sklearn.manifold import TSNE
import seaborn as sns

import numpy as np
from imblearn.under_sampling import RandomUnderSampler

import pandas as pd
import numpy as np
from copy import deepcopy

from moabb.datasets.bnci import BNCI2015001, BNCI2014001
from moabb.paradigms import MotorImagery, CVEP

# from library.utils.torch import StratifiedDomainDataLoader
from spdnets.utils.data import StratifiedDomainDataLoader, DomainDataset
from spdnets.models import TSMNet, SPDSMNet, SPDSMNet2, DSBNSPDBNNet, SPDSMNet_visu, DSBNSPDBNNet_Visu
import spdnets.batchnorm as bn
import spdnets.functionals as fn

from spdnets.trainer import Trainer, VisuTrainer
from spdnets.callbacks import MomentumBatchNormScheduler, EarlyStopping

import sys
from pyriemann.tangentspace import TangentSpace
from pyriemann.estimation import XdawnCovariances
import matplotlib.pyplot as plt
sys.path.insert(0,"C:\\Users\\s.velut\\Documents\\These\\Protheus_PHD\\Scripts")
sys.path.insert(0,"C:\\Users\\s.velut\\Documents\\These\\Protheus_PHD")
from utils import prepare_data,get_BVEP_data,balance,get_y_pred, get_TSNE_visu
from _utils import make_preds_accumul_aggresive, make_preds_pvalue
sys.path.insert(0,"C:\\Users\\s.velut\\Documents\\These\\moabb\\moabb\\datasets")
from castillos2023 import CasitllosCVEP100,CasitllosCVEP40,CasitllosBurstVEP100,CasitllosBurstVEP40






c:\Users\s.velut\AppData\Local\Programs\Python\Python311\Lib\site-packages\moabb\pipelines\__init__.py:26: ModuleNotFoundError: Tensorflow is not installed. You won't be able to use these MOABB pipelines if you attempt to do so.
  warn(


In [2]:
# network and training configuration
cfg_org = dict(
    epochs = 30,
    batch_size_train = 64,
    domains_per_batch = 4,
    validation_size = 0.2,
    evaluation = 'inter-subject', # or 'inter-subject'
    dtype = torch.float32,
    # parameters for the TSMNet model
    mdl_kwargs = dict(
        temporal_filters=4,
        nclasses = 2,
        spatial_filters=24,
        subspacedims=20, 
        bnorm_dispersion=bn.BatchNormDispersion.SCALAR,
        spd_device='cpu',
        spd_dtype=torch.double,
        domain_adaptation=True
    )
)

cfg_spd = dict(
    epochs = 20,
    batch_size_train = 64,
    domains_per_batch = 4,
    validation_size = 0.2,
    evaluation = 'inter-subject', # or 'inter-subject'
    dtype = torch.double,
    spd_device='cuda',
    # parameters for the TSMNet model
    mdl_kwargs = dict(
        temporal_filters=2,
        nclasses = 2,
        spatial_filters=32,
        bimap_dims = [32,28,14,7],
        subspacedims=32, 
        bnorm_dispersion=bn.BatchNormDispersion.SCALAR,
        spd_device='cuda',
        spd_dtype=torch.double,
        domain_adaptation=True
    )
)

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## load a MOABB dataset

In [3]:
# moabb_ds = BNCI2015001()
# n_classes = 2
# moabb_paradigm = MotorImagery(n_classes=n_classes, events=['right_hand', 'feet'], fmin=4, fmax=36, tmin=0.0, tmax=4.0, resample=256)

moabb_ds = CasitllosBurstVEP100()
n_classes = 2
moabb_paradigm = CVEP()

# moabb_ds = BNCI2014001()
# n_classes = 4
# moabb_paradigm = MotorImagery(n_classes=n_classes, events=["left_hand", "right_hand", "feet", "tongue"], fmin=4, fmax=36, tmin=0.5, tmax=3.496, resample=None)

Choosing the first None classes from all possible events.


## SF offline UDA

In [4]:
def sfuda_offline(dataset : DomainDataset, model : TSMNet, cfg):
    model.eval()
    model.domainadapt_finetune(dataset.features.to(dtype=cfg['dtype'], device=device), dataset.labels.to(device=device), dataset.domains, None)

## SF online UDA

In [5]:
def _sfuda_online_init_other_spd(domain_bn : torch.nn.Module, model : TSMNet, cfg):

    # initialize new mean with the grand average of the other domains
    M = []
    s = []
    for _, val in model.spdbnorm_layers[0].batchnorm.items():
        M += [val.running_mean_test]
        s += [val.running_var_test]
    M = torch.cat(M, dim=0)
    s = torch.cat(s, dim=0)
    prev_mean = fn.spd_mean_kracher_flow(M, dim=0, return_dist=False)
    prev_var = fn.spd_mean_kracher_flow(s[...,None], dim=0, return_dist=False).squeeze(-1)

    # because we use precomputed means advance the observation number
    # so that we start with slower adaptation
    domain_bn.adapt_observation = 5 if cfg['evaluation'] == 'inter-subject' else 15
    domain_bn.running_mean_test.data = prev_mean.data.clone()
    domain_bn.running_var_test = prev_var.clone()


def sfuda_online_init_spd(target_domains : torch.LongTensor, model : TSMNet, cfg, strategy : str = 'other'):

    assert target_domains.ndim == 1
    # assert isinstance(model.spdbnorm, bn.BaseDomainBatchNorm)
    for target_domain in target_domains:
        for i in range(len(model.bimap_dims[1:])):
            domain_key = model.spdbnorm_layers[i].domain_to_key(target_domain)
            # add domain if not yet in the model
            if domain_key not in model.spdbnorm_layers[i].batchnorm:
                bncls = model.spdbnorm_layers[i].domain_bn_cls
                domain_bn = bncls(
                    shape=model.spdbnorm_layers[i].mean.shape,
                    batchdim=model.spdbnorm_layers[i].batchdim, 
                    learn_mean=model.spdbnorm_layers[i].learn_mean,
                    learn_std=model.spdbnorm_layers[i].learn_std,
                    dispersion=model.spdbnorm_layers[i].dispersion,
                    mean=model.spdbnorm_layers[i].mean,
                    std=model.spdbnorm_layers[i].std,
                    eta=model.spdbnorm_layers[i].eta,
                    eta_test=model.spdbnorm_layers[i].eta_test
                )
            else:
                domain_bn = model.spdbnorm_layers[i].batchnorm[domain_key]

        if strategy == 'other':
            _sfuda_online_init_other_spd(domain_bn, model, cfg)
        else:
            raise NotImplementedError()

        # change the BN mode to perform online adaptation (for each batch)
        domain_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.ADAPT)
    
    for target_domain in target_domains:    
        for i in range(len(model.bimap_dims[1:])):
            if domain_key not in model.spdbnorm_layers[i].batchnorm:
                model.spdbnorm_layers[i].batchnorm[domain_key] = domain_bn

def sfuda_online_step(inputs : torch.Tensor, domains : torch.LongTensor, model : TSMNet, cfg):

    model = model.to(device='cpu')
    # the model needs to be in evaluation mode so that the batch norm statistics for testing will be adapted
    model.eval()

    activations = model(inputs.to(dtype=cfg['dtype'],device=cfg['spd_device']), domains)
    # return class probabilities
    return activations


def sfuda_online_simulate_spd(dataset : DomainDataset, model : TSMNet, cfg, loss_fn):

    # do adaptation if the model is configured to do domain adaptation
    if not model.domain_adaptation_:
        test_loss, score  = None, None
    else:
        
        # extract the target domains from the dataset
        target_domains = dataset.domains.unique()

        sfuda_online_init_spd(target_domains, model, cfg)
        
        y_true = []
        y_hat = []

        test_loss = 0.
        # feed each observation through the network
        # to adapt and infer the target
        for i, (features, y) in enumerate(dataset):
            
            inputs = features['inputs'][None,...]
            domains = features['domains'][None,...]
            y = y

            pred = sfuda_online_step(inputs, domains, model, cfg)
            test_loss += loss_fn(pred, y[None,...]).item()
            y_true.append(y[None,...])
            y_hat.append(pred.argmax(1)[None,...])
            # p_class = torch.nn.functional.softmax(pred, dim=-1)

        # compute the overall score 
        score = balanced_accuracy_score(torch.cat(y_true).detach().cpu().numpy(), torch.cat(y_hat).detach().cpu().numpy())

        test_loss /= len(dataset)

        # stop adaptation for the target domains
        for target_domain in target_domains:
            for i in range(len(model.bimap_dims[1:])):
                trgt_bn = model.spdbnorm_layers[i].batchnorm[str(target_domain.item())]
                trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)  

    return test_loss, score 

In [6]:
def sfuda_online(dataset : DomainDataset, model : TSMNet, cfg, loss_fn):

    model = model.to(device='cpu')
    # the model needs to be in evaluation mode so that the batch norm statistics for testing will be adapted
    model.eval()

    # do adaptation if the model is configured to do domain adaptation
    if not model.domain_adaptation_:
        test_loss, score  = None, None
    else:
        
        # extract the target domains from the dataset
        target_domains = dataset.domains.unique()
        # initialize new mean with the grand average of the other domains
        M = []
        s = []
        for key, val in model.spdbnorm.batchnorm.items():
            if key in [f'dom {t}' for t in target_domains]:
                continue
            M += [val.running_mean_test]
            s += [val.running_var_test]
        M = torch.cat(M, dim=0)
        s = torch.cat(s, dim=0)
        prev_mean = fn.spd_mean_kracher_flow(M, dim=0, return_dist=False)
        prev_var = fn.spd_mean_kracher_flow(s[...,None], dim=0, return_dist=False).squeeze(-1)
        
        # assign the grand average mean and var to the target domains
        for target_domain in target_domains:

            trgt_bn = model.spdbnorm.batchnorm[str(target_domain.item())]
            # because we use precomputed means advance the observation number
            # so that we start with slower adaptation
            trgt_bn.adapt_observation = 5 if cfg['evaluation'] == 'inter-subject' else 15
            trgt_bn.running_mean_test.data = prev_mean.data.clone()
            trgt_bn.running_var_test = prev_var.clone()

            # change the BN mode to perform online adaptation (for each batch)
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.ADAPT)
        
        y_true = []
        y_hat = []

        test_loss = 0.
        # feed each observation through the network
        # to adapt and infer the target
        for i, (features, y) in enumerate(dataset):
            
            inputs = features['inputs'][None,...].to(dtype=cfg['dtype'])
            domains = features['domains'][None,...]
            y = y
            pred = model.forward(inputs=inputs, domains=domains)
            test_loss += loss_fn(pred, y[None,...]).item()
            y_true.append(y[None,...])
            y_hat.append(pred.argmax(1)[None,...])

        # compute the overall score 
        score = balanced_accuracy_score(torch.cat(y_true).detach().cpu().numpy(), torch.cat(y_hat).detach().cpu().numpy())

        test_loss /= len(dataset)

        # stop adaptation for the target domains
        for target_domain in target_domains:
            trgt_bn = model.spdbnorm.batchnorm[str(target_domain.item())]
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)  

    return test_loss, score 

def sfuda_online_spd(dataset : DomainDataset, model : TSMNet, cfg, loss_fn):

    model = model.to(device='cpu')
    # the model needs to be in evaluation mode so that the batch norm statistics for testing will be adapted
    model.eval()

    # do adaptation if the model is configured to do domain adaptation
    if not model.domain_adaptation_:
        test_loss, score  = None, None
    else:
        
        # extract the target domains from the dataset
        target_domains = dataset.domains.unique()
        # initialize new mean with the grand average of the other domains
        M = []
        s = []
        for key, val in model.spdbnorm_layers[0].batchnorm.items():
            if key in [f'dom {t}' for t in target_domains]:
                continue
            M += [val.running_mean_test]
            s += [val.running_var_test]
        M = torch.cat(M, dim=0)
        s = torch.cat(s, dim=0)
        prev_mean = fn.spd_mean_kracher_flow(M, dim=0, return_dist=False)
        prev_var = fn.spd_mean_kracher_flow(s[...,None], dim=0, return_dist=False).squeeze(-1)
        
        # assign the grand average mean and var to the target domains
        for target_domain in target_domains:
            trgt_bn = model.spdbnorm_layers[0].batchnorm[str(target_domain.item())]
            # because we use precomputed means advance the observation number
            # so that we start with slower adaptation
            trgt_bn.adapt_observation = 5 if cfg['evaluation'] == 'inter-subject' else 15
            trgt_bn.running_mean_test.data = prev_mean.data.clone()
            trgt_bn.running_var_test = prev_var.clone()

            # change the BN mode to perform online adaptation (for each batch)
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.ADAPT)
        
        y_true = []
        y_hat = []

        test_loss = 0.
        # feed each observation through the network
        # to adapt and infer the target
        for i, (features, y) in enumerate(dataset):
            
            inputs = features['inputs'][None,...].to(dtype=cfg['dtype'],device=cfg['spd_device'])
            domains = features['domains'][None,...]
            y = y
            pred = model.forward(inputs=inputs, domains=domains)
            test_loss += loss_fn(pred, y[None,...]).item()
            y_true.append(y[None,...])
            y_hat.append(pred.argmax(1)[None,...])

        # compute the overall score 
        score = balanced_accuracy_score(torch.cat(y_true).detach().cpu().numpy(), torch.cat(y_hat).detach().cpu().numpy())

        test_loss /= len(dataset)

        # stop adaptation for the target domains
        for target_domain in target_domains:
            trgt_bn = model.spdbnorm_layers[0].batchnorm[str(target_domain.item())]
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)  

    return test_loss, score 

In [7]:
def _sfuda_online_init_other(domain_bn : torch.nn.Module, model : TSMNet, cfg):

    # initialize new mean with the grand average of the other domains
    M = []
    s = []
    for _, val in model.spdbnorm.batchnorm.items():
        M += [val.running_mean_test]
        s += [val.running_var_test]
    M = torch.cat(M, dim=0)
    s = torch.cat(s, dim=0)
    prev_mean = fn.spd_mean_kracher_flow(M, dim=0, return_dist=False)
    prev_var = fn.spd_mean_kracher_flow(s[...,None], dim=0, return_dist=False).squeeze(-1)

    # because we use precomputed means advance the observation number
    # so that we start with slower adaptation
    domain_bn.adapt_observation = 5 if cfg['evaluation'] == 'inter-subject' else 15
    domain_bn.running_mean_test.data = prev_mean.data.clone()
    domain_bn.running_var_test = prev_var.clone()


def sfuda_online_init(target_domains : torch.LongTensor, model : TSMNet, cfg, strategy : str = 'other'):

    assert target_domains.ndim == 1
    assert isinstance(model.spdbnorm, bn.BaseDomainBatchNorm)
    for target_domain in target_domains:
        domain_key = model.spdbnorm.domain_to_key(target_domain)
        # add domain if not yet in the model
        if domain_key not in model.spdbnorm.batchnorm:
            bncls = model.spdbnorm.domain_bn_cls
            domain_bn = bncls(
                shape=model.spdbnorm.mean.shape,
                batchdim=model.spdbnorm.batchdim, 
                learn_mean=model.spdbnorm.learn_mean,
                learn_std=model.spdbnorm.learn_std,
                dispersion=model.spdbnorm.dispersion,
                mean=model.spdbnorm.mean,
                std=model.spdbnorm.std,
                eta=model.spdbnorm.eta,
                eta_test=model.spdbnorm.eta_test
            )
        else:
            domain_bn = model.spdbnorm.batchnorm[domain_key]

        if strategy == 'other':
            _sfuda_online_init_other(domain_bn, model, cfg)
        else:
            raise NotImplementedError()

        # change the BN mode to perform online adaptation (for each batch)
        domain_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.ADAPT)
    
    for target_domain in target_domains:    
        if domain_key not in model.spdbnorm.batchnorm:
            model.spdbnorm.batchnorm[domain_key] = domain_bn

def sfuda_online_step(inputs : torch.Tensor, domains : torch.LongTensor, model : TSMNet, cfg):

    model = model.to(device='cpu')
    # the model needs to be in evaluation mode so that the batch norm statistics for testing will be adapted
    model.eval()

    activations = model(inputs.to(dtype=cfg['dtype']), domains)
    # return class probabilities
    return activations


def sfuda_online_simulate(dataset : DomainDataset, model : TSMNet, cfg, loss_fn):

    # do adaptation if the model is configured to do domain adaptation
    if not model.domain_adaptation_:
        test_loss, score  = None, None
    else:
        
        # extract the target domains from the dataset
        target_domains = dataset.domains.unique()

        sfuda_online_init(target_domains, model, cfg)
        
        y_true = []
        y_hat = []

        test_loss = 0.
        # feed each observation through the network
        # to adapt and infer the target
        for i, (features, y) in enumerate(dataset):
            
            inputs = features['inputs'][None,...]
            domains = features['domains'][None,...]
            y = y

            pred = sfuda_online_step(inputs, domains, model, cfg)
            test_loss += loss_fn(pred, y[None,...]).item()
            y_true.append(y[None,...])
            y_hat.append(pred.argmax(1)[None,...])
            # p_class = torch.nn.functional.softmax(pred, dim=-1)

        # compute the overall score 
        score = balanced_accuracy_score(torch.cat(y_true).detach().cpu().numpy(), torch.cat(y_hat).detach().cpu().numpy())

        test_loss /= len(dataset)

        # stop adaptation for the target domains
        for target_domain in target_domains:
            trgt_bn = model.spdbnorm.batchnorm[str(target_domain.item())]
            trgt_bn.set_test_stats_mode(bn.BatchNormTestStatsMode.BUFFER)  

    return test_loss, score 

## fit and evaluat the model for all domains

In [8]:
%matplotlib Qt

In [9]:
# subjects = [1,2,3]
subjects = [1,2,3,4,5,6,7,8,9,10,11,12]
n_channels = 32
on_frame = True
if on_frame:
    freq = 60
else:
    freq = 500

raw_data,labels,codes,labels_codes = get_BVEP_data(subjects,on_frame)

Choosing the first None classes from all possible events.


None
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad ep

  raw = mne.io.read_raw_eeglab(file_path_list[0], preload=True, verbose=False)


C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs 

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped


## original TSMNET

In [10]:
records = []

if 'inter-session' in cfg_org['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_org['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_org['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        X_std = X_fit.std(axis=0)
        X_fit /= X_std + 1e-8
        X_std = X[test].std(axis=0)
        X[test] /= X_std + 1e-8

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_org['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_org['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_org['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_org['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = TSMNet(**mdl_kwargs).to(device=device, dtype=cfg_org['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_org['epochs']-10,
            bs0=cfg_org['batch_size_train'],
            bs=cfg_org['batch_size_train']/cfg_org['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg_org['epochs'],
            min_epochs=cfg_org['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_org['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = TSMNet(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_net, cfg_org)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        sfuda_online_net = TSMNet(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_org, trainer.loss_fn)
        records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # create a new model and perform online SF UDA
        sfuda_online_net = TSMNet(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_org, trainer.loss_fn)
        records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))

resdf = pd.DataFrame(records)    
resdf.to_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/TSMNet_code.csv")



C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs 

  raw = mne.io.read_raw_eeglab(file_path_list[0], preload=True, verbose=False)


C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs 

  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.33
Perform UDA offline
test domain=['2/0']
epoch=  0 gd-step=  323 trn_loss= 0.6854 trn_score=0.5622 val_loss= 0.6885 val_score=0.5383 
epoch= 10 gd-step= 3565 trn_loss= 0.6096 trn_score=0.6624 val_loss= 0.6847 val_score=0.5824 
epoch= 20 gd-step= 6794 trn_loss= 0.5827 trn_score=0.6920 val_loss= 0.6955 val_score=0.5941 
epoch= 29 gd-step= 9709 trn_loss= 0.5753 trn_score=0.6991 val_loss= 0.7049 val_score=0.5989 
ES best epoch=6
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.28
Perform UDA offline
test domain=['3/0']
epoch=  0 gd-step=  326 trn_loss= 0.6877 trn_score=0.5515 val_loss= 0.6905 val_score=0.5322 
epoch= 10 gd-step= 3570 trn_loss= 0.6239 trn_score=0.6456 val_loss= 0.6939 val_score=0.5659 
epoch= 20 gd-step= 6800 trn_loss= 0.5898 trn_score=0.6845 val_loss= 0.7055 val_score=0.5900 
epoch= 29 gd-step= 9715 trn_loss= 0.5783 trn_score=0.6936 val_loss= 0.7161 val_score=0.5907 
ES best epoch=4
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.32
Perform UDA offline
test domain=['4/0']
epoch=  0 gd-step=  327 trn_loss= 0.6839 trn_score=0.5736 val_loss= 0.6878 val_score=0.5509 
epoch= 10 gd-step= 3570 trn_loss= 0.6151 trn_score=0.6584 val_loss= 0.6939 val_score=0.5644 
epoch= 20 gd-step= 6803 trn_loss= 0.5927 trn_score=0.6794 val_loss= 0.7104 val_score=0.5769 
epoch= 29 gd-step= 9712 trn_loss= 0.5780 trn_score=0.6904 val_loss= 0.7184 val_score=0.5712 
ES best epoch=4
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.33
Perform UDA offline
test domain=['5/0']
epoch=  0 gd-step=  325 trn_loss= 0.6878 trn_score=0.5477 val_loss= 0.6922 val_score=0.5167 
epoch= 10 gd-step= 3560 trn_loss= 0.6257 trn_score=0.6542 val_loss= 0.6966 val_score=0.5621 
epoch= 20 gd-step= 6801 trn_loss= 0.5970 trn_score=0.6787 val_loss= 0.7057 val_score=0.5742 
epoch= 29 gd-step= 9731 trn_loss= 0.5836 trn_score=0.6891 val_loss= 0.7038 val_score=0.5773 
ES best epoch=2
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.35
Perform UDA offline
test domain=['6/0']
epoch=  0 gd-step=  323 trn_loss= 0.6872 trn_score=0.5529 val_loss= 0.6912 val_score=0.5222 
epoch= 10 gd-step= 3562 trn_loss= 0.6234 trn_score=0.6485 val_loss= 0.6905 val_score=0.5756 
epoch= 20 gd-step= 6801 trn_loss= 0.5936 trn_score=0.6833 val_loss= 0.7067 val_score=0.5824 
epoch= 29 gd-step= 9712 trn_loss= 0.5824 trn_score=0.6890 val_loss= 0.7210 val_score=0.5831 
ES best epoch=6
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.18
Perform UDA offline
test domain=['7/0']
epoch=  0 gd-step=  327 trn_loss= 0.6915 trn_score=0.5311 val_loss= 0.6946 val_score=0.5174 
epoch= 10 gd-step= 3562 trn_loss= 0.6246 trn_score=0.6521 val_loss= 0.6833 val_score=0.5811 
epoch= 20 gd-step= 6800 trn_loss= 0.5924 trn_score=0.6843 val_loss= 0.6980 val_score=0.5926 
epoch= 29 gd-step= 9717 trn_loss= 0.5818 trn_score=0.6890 val_loss= 0.7065 val_score=0.5867 
ES best epoch=8
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.25
Perform UDA offline
test domain=['8/0']
epoch=  0 gd-step=  323 trn_loss= 0.6879 trn_score=0.5501 val_loss= 0.6898 val_score=0.5362 
epoch= 10 gd-step= 3553 trn_loss= 0.6262 trn_score=0.6471 val_loss= 0.6863 val_score=0.5725 
epoch= 20 gd-step= 6790 trn_loss= 0.6017 trn_score=0.6684 val_loss= 0.7035 val_score=0.5835 
epoch= 29 gd-step= 9711 trn_loss= 0.5892 trn_score=0.6788 val_loss= 0.7104 val_score=0.5799 
ES best epoch=8
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.25
Perform UDA offline
test domain=['9/0']
epoch=  0 gd-step=  324 trn_loss= 0.6866 trn_score=0.5577 val_loss= 0.6921 val_score=0.5284 
epoch= 10 gd-step= 3561 trn_loss= 0.6149 trn_score=0.6609 val_loss= 0.7027 val_score=0.5670 
epoch= 20 gd-step= 6783 trn_loss= 0.5962 trn_score=0.6786 val_loss= 0.7268 val_score=0.5670 
epoch= 29 gd-step= 9704 trn_loss= 0.5882 trn_score=0.6896 val_loss= 0.7394 val_score=0.5693 
ES best epoch=1
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.25
Perform UDA offline
test domain=['10/0']
epoch=  0 gd-step=  324 trn_loss= 0.6886 trn_score=0.5571 val_loss= 0.6916 val_score=0.5301 
epoch= 10 gd-step= 3547 trn_loss= 0.6216 trn_score=0.6521 val_loss= 0.6860 val_score=0.5811 
epoch= 20 gd-step= 6790 trn_loss= 0.6010 trn_score=0.6754 val_loss= 0.7055 val_score=0.5877 
epoch= 29 gd-step= 9723 trn_loss= 0.5853 trn_score=0.6900 val_loss= 0.7157 val_score=0.5801 
ES best epoch=7
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.5
Perform UDA offline
test domain=['11/0']
epoch=  0 gd-step=  320 trn_loss= 0.6887 trn_score=0.5373 val_loss= 0.6927 val_score=0.5170 
epoch= 10 gd-step= 3554 trn_loss= 0.6355 trn_score=0.6349 val_loss= 0.7099 val_score=0.5583 
epoch= 20 gd-step= 6781 trn_loss= 0.5948 trn_score=0.6805 val_loss= 0.7220 val_score=0.5786 
epoch= 29 gd-step= 9699 trn_loss= 0.5914 trn_score=0.6843 val_loss= 0.7381 val_score=0.5693 
ES best epoch=2
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.23
Perform UDA offline
test domain=['12/0']
epoch=  0 gd-step=  325 trn_loss= 0.6847 trn_score=0.5557 val_loss= 0.6892 val_score=0.5339 
epoch= 10 gd-step= 3549 trn_loss= 0.6166 trn_score=0.6602 val_loss= 0.7021 val_score=0.5758 
epoch= 20 gd-step= 6783 trn_loss= 0.5892 trn_score=0.6825 val_loss= 0.7130 val_score=0.5705 
epoch= 29 gd-step= 9704 trn_loss= 0.5868 trn_score=0.6895 val_loss= 0.7358 val_score=0.5661 
ES best epoch=2
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.23
Perform UDA offline


## Modified TSMNET

In [34]:
records = []

if 'inter-session' in cfg_spd['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_spd['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    xdawncov = XdawnCovariances(estimator="lwf",xdawn_estimator="lwf",nfilter=8)
    X = xdawncov.fit_transform(X,labels)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_spd['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_spd['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_spd['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_spd['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_spd['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = SPDSMNet(**mdl_kwargs).to(device=device, dtype=cfg_spd['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_spd['epochs']-10,
            bs0=cfg_spd['batch_size_train'],
            bs=cfg_spd['batch_size_train']/cfg_spd['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg_spd['epochs'],
            min_epochs=cfg_spd['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_spd['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = SPDSMNet(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_net, cfg_spd)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        sfuda_online_net = SPDSMNet(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # # create a new model and perform online SF UDA
        sfuda_online_net = SPDSMNet(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))

resdf = pd.DataFrame(records)    
resdf.to_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/SDPSMNet_code.csv")

C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs 

  raw = mne.io.read_raw_eeglab(file_path_list[0], preload=True, verbose=False)


C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs 

  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.83
Perform UDA offline
test domain=['11/0']
epoch=  0 gd-step=  322 trn_loss= 0.6593 trn_score=0.6129 val_loss= 0.6609 val_score=0.6074 
epoch= 10 gd-step= 3561 trn_loss= 0.5942 trn_score=0.6832 val_loss= 0.6035 val_score=0.6718 
epoch= 19 gd-step= 6465 trn_loss= 0.5906 trn_score=0.6882 val_loss= 0.6013 val_score=0.6727 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.7
Perform UDA offline
test domain=['12/0']
epoch=  0 gd-step=  324 trn_loss= 0.6697 trn_score=0.6019 val_loss= 0.6736 val_score=0.5947 
epoch= 10 gd-step= 3564 trn_loss= 0.6437 trn_score=0.6329 val_loss= 0.6536 val_score=0.6229 
epoch= 19 gd-step= 6484 trn_loss= 0.6416 trn_score=0.6366 val_loss= 0.6520 val_score=0.6271 
ES best epoch=19
evaluate the estimator


  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.6
Perform UDA offline


## Modified TSMNET2

In [53]:
records = []

if 'inter-session' in cfg_spd['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_spd['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    xdawncov = XdawnCovariances(estimator="lwf",xdawn_estimator="lwf",nfilter=8)
    X = xdawncov.fit_transform(X,labels)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_spd['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_spd['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_spd['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_spd['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_spd['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = SPDSMNet2(**mdl_kwargs).to(device=device, dtype=cfg_spd['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_spd['epochs']-10,
            bs0=cfg_spd['batch_size_train'],
            bs=cfg_spd['batch_size_train']/cfg_spd['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg_spd['epochs'],
            min_epochs=cfg_spd['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_spd['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = SPDSMNet2(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_ne, cfg_spdt)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        sfuda_online_net = SPDSMNet2(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # # create a new model and perform online SF UDA
        sfuda_online_net = SPDSMNet2(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))

resdf = pd.DataFrame(records)    
resdf.to_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/SDPSMNet2_code.csv")

C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\


KeyboardInterrupt: 

## DSBNSDPBNNet

In [None]:
records = []

if 'inter-session' in cfg_spd['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_spd['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    xdawncov = XdawnCovariances(estimator="lwf",xdawn_estimator="lwf",nfilter=8)
    X = xdawncov.fit_transform(X,labels)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_spd['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_spd['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_spd['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_spd['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_spd['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = DSBNSPDBNNet(**mdl_kwargs).to(device=device, dtype=cfg_spd['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_spd['epochs']-10,
            bs0=cfg_spd['batch_size_train'],
            bs=cfg_spd['batch_size_train']/cfg_spd['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg_spd['epochs'],
            min_epochs=cfg_spd['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_spd['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = SPDSMNet2(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_net, cfg_spd)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        sfuda_online_net = SPDSMNet2(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # # create a new model and perform online SF UDA
        sfuda_online_net = SPDSMNet2(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))

resdf = pd.DataFrame(records)    
resdf.to_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/DSBMSDPBNNet_code.csv")

## Modified TSMNET_visu

In [17]:
records = []

if 'inter-session' in cfg_spd['evaluation']:
    # subset_iter = iter([[s] for s in moabb_ds.subject_list])
    subset_iter = iter([[s] for s in [1,2,3]])
    groupvarname = 'session'
elif 'inter-subject' in cfg_spd['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=[1,2,3], return_epochs=False)

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    xdawncov = XdawnCovariances(estimator="lwf",xdawn_estimator="lwf",nfilter=8)
    X = xdawncov.fit_transform(X,labels)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_spd['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()
    
    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):
        visualisation_layers = []
        visualisation_layers_val = []

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_spd['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_spd['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_spd['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_spd['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = SPDSMNet_visu(**mdl_kwargs).to(device=device, dtype=cfg_spd['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_spd['epochs']-10,
            bs0=cfg_spd['batch_size_train'],
            bs=cfg_spd['batch_size_train']/cfg_spd['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = VisuTrainer(
            max_epochs=cfg_spd['epochs'],
            min_epochs=cfg_spd['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_spd['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred, inter_layer_out = trainer.pred(net,dataloader=loader_test, return_interbn=True)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))

        # Get visual of the data between the layers
        visu_layer = []
        ind_1 = np.random.choice(np.where(y[test].numpy()==1)[0],size=400,replace=False)
        ind_0 = np.random.choice(np.where(y[test].numpy()==0)[0],size=400,replace=False)
        for l in inter_layer_out[0]:
            cov1 = np.array([l[i].cpu().numpy() for i in ind_1])
            cov0 = np.array([l[i].cpu().numpy() for i in ind_0])
            visu_layer.append(get_TSNE_visu(cov0,400))
            visu_layer.append(get_TSNE_visu(cov1,400))

        visualisation_layers.append(visu_layer)
        # extract model parameters
        val_pred, val_inter_layer_out = trainer.pred(net,dataloader=loader_val, return_interbn=True)
        # Get visual of the data between the layers
        visu_layer_val = []
        ind_1 = np.where(y_fit[val].numpy()==1)[0]
        ind_0 = np.where(y_fit[val].numpy()==0)[0]
        for l in val_inter_layer_out[0]:
            for d in np.unique(domain_fit[train].numpy()):
                ind_dom1 = np.random.choice(np.where(domain_fit[val][ind_1].numpy()==d)[0],size=200,replace=False)
                ind_dom0 = np.random.choice(np.where(domain_fit[val][ind_0].numpy()==d)[0],size=200,replace=False)
                cov1 = np.array([l[i].cpu().numpy() for i in ind_dom1])
                cov0 = np.array([l[i].cpu().numpy() for i in ind_dom0])
                visu_layer_val.append(get_TSNE_visu(cov0,200))
                visu_layer_val.append(get_TSNE_visu(cov1,200))

        visualisation_layers_val.append(visu_layer_val)
        

        # # create new model and perform offline SF UDA
        # print("Perform UDA offline")
        # sfuda_offline_net = SPDSMNet2(**mdl_kwargs).to(device=device)
        # sfuda_offline_net.load_state_dict(state_dict)
        # sfuda_offline(ds_test, sfuda_offline_net, cfg_spd)
        # res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        # records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # # create a new model and perform online SF UDA
        # sfuda_online_net = SPDSMNet2(**mdl_kwargs)
        # sfuda_online_net.load_state_dict(state_dict)
        # loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        # records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # # # create a new model and perform online SF UDA
        # sfuda_online_net = SPDSMNet2(**mdl_kwargs)
        # sfuda_online_net.load_state_dict(state_dict)
        # loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        # records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))

        np.save("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet_test/SDPSMNet_visu_{}.npy".format(ix_fold), np.array(visualisation_layers))
        np.save("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet_test/SDPSMNet_val_visu_{}.npy".format(ix_fold), np.array(visualisation_layers_val))

C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\


KeyboardInterrupt: 

## Modified DSBNSPDBNNet_visu

In [28]:
records = []

if 'inter-session' in cfg_spd['evaluation']:
    # subset_iter = iter([[s] for s in moabb_ds.subject_list])
    subset_iter = iter([[s] for s in [[4,5,61,2,3]]])
    groupvarname = 'session'
elif 'inter-subject' in cfg_spd['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    _,_, metadata = moabb_paradigm.get_data(moabb_ds, subjects=[1,2,3], return_epochs=False)
    X = []
    labels = []
    for s in [1,2,3]:
        X_temp, labels_temp, _ = moabb_paradigm.get_data(moabb_ds, subjects=[s], return_epochs=False)

        X_std = X_temp.std(axis=0)
        X_temp /= X_std + 1e-8

        xdawncov = XdawnCovariances(estimator="lwf",xdawn_estimator="lwf",nfilter=8)
        X_temp = xdawncov.fit_transform(X_temp,labels_temp)
        X.append(X_temp)
        labels.append(labels_temp)

    X = np.concatenate(np.array(X))
    labels = np.concatenate(np.array(labels))

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_spd['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()
    
    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):
        visualisation_layers = []
        visualisation_layers_val = []

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_spd['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_spd['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_spd['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_spd['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = DSBNSPDBNNet_Visu(**mdl_kwargs).to(device=device, dtype=cfg_spd['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_spd['epochs']-10,
            bs0=cfg_spd['batch_size_train'],
            bs=cfg_spd['batch_size_train']/cfg_spd['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = VisuTrainer(
            max_epochs=cfg_spd['epochs'],
            min_epochs=cfg_spd['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_spd['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred, inter_layer_out = trainer.pred(net,dataloader=loader_test, return_interbn=True)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))

        # Get visual of the data between the layers
        val_pred, val_inter_layer_out = trainer.pred(net,dataloader=loader_val, return_interbn=True)
        train_pred, train_inter_layer_out = trainer.pred(net,dataloader=torch.utils.data.DataLoader(ds_train, batch_size=len(ds_train),shuffle=False),
                                                          return_interbn=True)
        visu_layer = []
        ind_1 = np.random.choice(np.where(y[test].numpy()==1)[0],size=150,replace=False)
        ind_0 = np.random.choice(np.where(y[test].numpy()==0)[0],size=150,replace=False)
        ind_1_val = np.where(y_fit[val].numpy()==1)[0]
        ind_0_val = np.where(y_fit[val].numpy()==0)[0]
        ind_1_train = np.where(y_fit[train].numpy()==1)[0]
        ind_0_train = np.where(y_fit[train].numpy()==0)[0]
        for l, l_val,l_train in zip(inter_layer_out[0],val_inter_layer_out[0],train_inter_layer_out[0]):
            visu_layer = []
            for d in np.unique(domain_fit[val].numpy()):
                ind_dom1 = np.random.choice(np.where(domain_fit[val][ind_1_val].numpy()==d)[0],size=150,replace=False)
                ind_dom0 = np.random.choice(np.where(domain_fit[val][ind_0_val].numpy()==d)[0],size=150,replace=False)
                cov1_val = np.array([l_val[ind_1_val][i].cpu().numpy() for i in ind_dom1])
                cov0_val = np.array([l_val[ind_0_val][i].cpu().numpy() for i in ind_dom0])
                visu_layer.append(cov1_val)
                visu_layer.append(cov0_val)
            for d in np.unique(domain_fit[train].numpy()):
                ind_dom1 = np.random.choice(np.where(domain_fit[train][ind_1_train].numpy()==d)[0],size=150,replace=False)
                ind_dom0 = np.random.choice(np.where(domain_fit[train][ind_0_train].numpy()==d)[0],size=150,replace=False)
                cov1_train = np.array([l_train[ind_1_train][i].cpu().numpy() for i in ind_dom1])
                cov0_train = np.array([l_train[ind_0_train][i].cpu().numpy() for i in ind_dom0])
                visu_layer.append(cov1_train)
                visu_layer.append(cov0_train)
            cov1 = np.array([l[i].cpu().numpy() for i in ind_1])
            cov0 = np.array([l[i].cpu().numpy() for i in ind_0])
            visu_layer.append(cov0)
            visu_layer.append(cov1)

            visu_layer = np.concatenate(visu_layer)
            to_visualize = TangentSpace().fit_transform(visu_layer)
            X_embedded = TSNE(n_components=2, learning_rate='auto',init='random', perplexity=300).fit_transform(to_visualize)
            visualisation_layers.append(X_embedded)
        # visualisation_layers.append(get_TSNE_visu(visu_layer,visu_layer.shape[0]))
        

        np.save("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet_test/DSBNSPDBNNet_visu_{}.npy".format(ix_fold), np.array(visualisation_layers))

C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped


  raw = mne.io.read_raw_eeglab(file_path_list[0], preload=True, verbose=False)


C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped


  raw = mne.io.read_raw_eeglab(file_path_list[0], preload=True, verbose=False)


test domain=['4/0']
epoch=  0 gd-step=   59 trn_loss= 0.6849 trn_score=0.5871 val_loss= 0.6865 val_score=0.5771 
epoch= 10 gd-step=  649 trn_loss= 0.6194 trn_score=0.6925 val_loss= 0.6254 val_score=0.6937 
epoch= 19 gd-step= 1180 trn_loss= 0.5842 trn_score=0.6992 val_loss= 0.5924 val_score=0.6958 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.25
test domain=['5/0']
epoch=  0 gd-step=   59 trn_loss= 0.6912 trn_score=0.5265 val_loss= 0.6926 val_score=0.5083 
epoch= 10 gd-step=  649 trn_loss= 0.6660 trn_score=0.6176 val_loss= 0.6714 val_score=0.5865 
epoch= 19 gd-step= 1180 trn_loss= 0.6472 trn_score=0.6297 val_loss= 0.6546 val_score=0.6094 
ES best epoch=17
evaluate the estimator
 accuracy score of the participant 0.3
test domain=['6/0']
epoch=  0 gd-step=   59 trn_loss= 0.6900 trn_score=0.5339 val_loss= 0.6910 val_score=0.5198 
epoch= 10 gd-step=  649 trn_loss= 0.6707 trn_score=0.6017 val_loss= 0.6759 val_score=0.5833 
epoch= 19 gd-step= 1180 trn_loss= 0.6

  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.25


In [13]:
fit

array([    0,     1,     2, ..., 14037, 14038, 14039])

## Modified DSBNSPDBNNet_visu LOOA

In [12]:
subjects = [1,2,3]
# subjects = [1,2,3,4,5,6,7,8,9,10,11,12]
n_channels = 32
on_frame = True
if on_frame:
    freq = 60
else:
    freq = 500

raw_data,labels,codes,labels_codes = get_BVEP_data(subjects,on_frame)
X_parent, labels_parent, domains_parent = prepare_data(subjects,raw_data, labels, on_frame,True,False,codes)
metadata = pd.DataFrame({"subject":np.repeat(list(map(str,subjects)),X_parent[0].shape[0]),"session":["0"]*len(subjects)*X_parent[0].shape[0]})

Choosing the first None classes from all possible events.


None
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
C:\Users\s.velut\mne_data\MNE-4class-vep-data\records\8255618\files\
Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
[1, 2, 3]
[1, 2, 3]
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandw

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 25 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 25.00 Hz
- Upper transition bandwidth: 6.25 Hz (-6 dB cutoff frequency: 28.12 Hz)
- Filter length: 1651 samples (3.302 s)

7020 events found
Event IDs: [100 101]
Not setting metadata
7020 matching events found
Setting baseline interval to [-0.01, 0.0] s
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 7020 events and 131 original time points ...


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


0 bad epochs dropped


In [27]:
records = []
n_cal = 5

if 'inter-session' in cfg_spd['evaluation']:
    # subset_iter = iter([[s] for s in moabb_ds.subject_list])
    subset_iter = iter([[s] for s in [4,5,6]])
    groupvarname = 'session'
elif 'inter-subject' in cfg_spd['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    X = np.concatenate(X_parent)
    labels = np.concatenate(labels_parent)
    domains = np.concatenate(domains_parent)

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_spd['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()


    
    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):
        fit = np.concatenate([fit,test[:int(1.95*n_cal*4*60)]])
        test = test[int(1.95*n_cal*4*60):]
        visualisation_layers = []
        visualisation_layers_val = []

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_spd['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_spd['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_spd['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_spd['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = DSBNSPDBNNet_Visu(**mdl_kwargs).to(device=device, dtype=cfg_spd['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_spd['epochs']-10,
            bs0=cfg_spd['batch_size_train'],
            bs=cfg_spd['batch_size_train']/cfg_spd['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = VisuTrainer(
            max_epochs=cfg_spd['epochs'],
            min_epochs=cfg_spd['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_spd['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred, inter_layer_out = trainer.pred(net,dataloader=loader_test, return_interbn=True)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][n_cal*4:][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))

        # Get visual of the data between the layers
        val_pred, val_inter_layer_out = trainer.pred(net,dataloader=loader_val, return_interbn=True)
        train_pred, train_inter_layer_out = trainer.pred(net,dataloader=torch.utils.data.DataLoader(ds_train, batch_size=len(ds_train),shuffle=False),
                                                          return_interbn=True)
        visu_layer = []
        ind_1 = np.random.choice(np.where(y[test].numpy()==1)[0],size=50,replace=False)
        ind_0 = np.random.choice(np.where(y[test].numpy()==0)[0],size=50,replace=False)
        ind_1_val = np.where(y_fit[val].numpy()==1)[0]
        ind_0_val = np.where(y_fit[val].numpy()==0)[0]
        ind_1_train = np.where(y_fit[train].numpy()==1)[0]
        ind_0_train = np.where(y_fit[train].numpy()==0)[0]
        for l, l_val,l_train in zip(inter_layer_out[0],val_inter_layer_out[0],train_inter_layer_out[0]):
            visu_layer = []
            for d in np.unique(domain_fit[val].numpy()):
                ind_dom1 = np.random.choice(np.where(domain_fit[val][ind_1_val].numpy()==d)[0],size=50,replace=False)
                ind_dom0 = np.random.choice(np.where(domain_fit[val][ind_0_val].numpy()==d)[0],size=50,replace=False)
                cov1_val = np.array([l_val[ind_1_val][i].cpu().numpy() for i in ind_dom1])
                cov0_val = np.array([l_val[ind_0_val][i].cpu().numpy() for i in ind_dom0])
                visu_layer.append(cov1_val)
                visu_layer.append(cov0_val)
            for d in np.unique(domain_fit[train].numpy()):
                ind_dom1 = np.random.choice(np.where(domain_fit[train][ind_1_train].numpy()==d)[0],size=50,replace=False)
                ind_dom0 = np.random.choice(np.where(domain_fit[train][ind_0_train].numpy()==d)[0],size=50,replace=False)
                cov1_train = np.array([l_train[ind_1_train][i].cpu().numpy() for i in ind_dom1])
                cov0_train = np.array([l_train[ind_0_train][i].cpu().numpy() for i in ind_dom0])
                visu_layer.append(cov1_train)
                visu_layer.append(cov0_train)
            cov1 = np.array([l[i].cpu().numpy() for i in ind_1])
            cov0 = np.array([l[i].cpu().numpy() for i in ind_0])
            visu_layer.append(cov0)
            visu_layer.append(cov1)

            visu_layer = np.concatenate(visu_layer)
            to_visualize = TangentSpace().fit_transform(visu_layer)
            X_embedded = TSNE(n_components=2, learning_rate='auto',init='random', perplexity=100).fit_transform(to_visualize)
            visualisation_layers.append(X_embedded)
        # visualisation_layers.append(get_TSNE_visu(visu_layer,visu_layer.shape[0]))
        

        np.save("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet_test/DSBNSPDBNNet_visu_{}.npy".format(ix_fold), np.array(visualisation_layers))

test domain=['1/0']
epoch=  0 gd-step=   30 trn_loss= 0.6901 trn_score=0.5439 val_loss= 0.6901 val_score=0.5509 
epoch= 10 gd-step=  330 trn_loss= 0.6803 trn_score=0.5931 val_loss= 0.6835 val_score=0.5670 
epoch= 19 gd-step=  600 trn_loss= 0.6712 trn_score=0.6153 val_loss= 0.6799 val_score=0.5527 
ES best epoch=16
evaluate the estimator
 accuracy score of the participant 0.4
test domain=['2/0']
epoch=  0 gd-step=   29 trn_loss= 0.6932 trn_score=0.5041 val_loss= 0.6945 val_score=0.4821 
epoch= 10 gd-step=  319 trn_loss= 0.6732 trn_score=0.6196 val_loss= 0.6751 val_score=0.6161 
epoch= 19 gd-step=  580 trn_loss= 0.6570 trn_score=0.6244 val_loss= 0.6549 val_score=0.6366 
ES best epoch=19
evaluate the estimator
 accuracy score of the participant 0.6
test domain=['3/0']
epoch=  0 gd-step=   30 trn_loss=    nan trn_score=0.5000 val_loss=    nan val_score=0.5000 
epoch= 10 gd-step=  330 trn_loss=    nan trn_score=0.5000 val_loss=    nan val_score=0.5000 
epoch= 19 gd-step=  600 trn_loss=    n

  c /= stddev[:, None]
  c /= stddev[None, :]


 accuracy score of the participant 0.25


ValueError: Matrices must be positive definite. Add regularization to avoid this error.

In [None]:
np.unique(domain_fit[val].numpy())

(29,)

## original TSMNET on samples

In [17]:
subjects = [1,2,3]
# subjects = [1,2,3,4,5,6,7,8,9,10,11,12]
n_channels = 32
on_frame = False
if on_frame:
    freq = 60
else:
    freq = 500

raw_data,labels,codes,labels_codes = get_BVEP_data(subjects,on_frame)
X_parent, labels_parent, domains_parent = prepare_data(subjects,raw_data, labels, on_frame,False,codes)
metadata = pd.DataFrame({"subject":np.repeat(list(map(str,subjects)),X_parent[0].shape[0]),"session":["0"]*len(subjects)*X_parent[0].shape[0]})

Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 50 - 50 Hz

IIR filter parameters
---------------------
Butterworth bandstop zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 49.90, 50.10 Hz: -6.02, -6.02 dB

Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 50 - 50 Hz

IIR filter parameters
---------------------
Butterworth bandstop zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 49.90, 50.10 Hz: -6.02, -6.02 dB

Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 50 - 50 Hz

IIR filter parameters
---------------------
Butterworth bandstop zero-phase (two-pas

In [18]:
records = []

if 'inter-session' in cfg_org['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_org['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    # X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)
    X = np.concatenate(X_parent)
    labels = np.concatenate(labels_parent)
    domains = np.concatenate(domains_parent)

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_org['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_org['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_org['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_org['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_org['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = TSMNet(**mdl_kwargs).to(device=device, dtype=cfg_org['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_org['epochs']-10,
            bs0=cfg_org['batch_size_train'],
            bs=cfg_org['batch_size_train']/cfg_org['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg_org['epochs'],
            min_epochs=cfg_org['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_org['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = TSMNet(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_net)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        sfuda_online_net = TSMNet(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_org, trainer.loss_fn)
        records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # create a new model and perform online SF UDA
        sfuda_online_net = TSMNet(**mdl_kwargs)
        sfuda_online_net.load_state_dict(state_dict)
        loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_org, trainer.loss_fn)
        records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))


test domain=['1/0']
epoch=  0 gd-step=  498 trn_loss= 0.6509 trn_score=0.6513 val_loss= 0.6569 val_score=0.6383 
epoch= 10 gd-step= 5478 trn_loss= 0.4250 trn_score=0.8138 val_loss= 0.4748 val_score=0.7852 
epoch= 19 gd-step= 9960 trn_loss= 0.4401 trn_score=0.7983 val_loss= 0.4919 val_score=0.7658 
ES best epoch=18
evaluate the estimator


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.49 GiB. GPU 0 has a total capacty of 6.00 GiB of which 0 bytes is free. Of the allocated memory 5.15 GiB is allocated by PyTorch, and 32.46 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## Modified TSMNET on samples

In [19]:
subjects = [1,2,3]
# subjects = [1,2,3,4,5,6,7,8,9,10,11,12]
n_channels = 32
on_frame = False
if on_frame:
    freq = 60
else:
    freq = 500

raw_data,labels,codes,labels_codes = get_BVEP_data(subjects,on_frame)
X_parent, labels_parent, domains_parent = prepare_data(subjects,raw_data, labels, on_frame,True,codes)
metadata = pd.DataFrame({"subject":np.repeat(list(map(str,subjects)),X_parent[0].shape[0]),"session":["0"]*len(subjects)*X_parent[0].shape[0]})

Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 50 - 50 Hz

IIR filter parameters
---------------------
Butterworth bandstop zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 49.90, 50.10 Hz: -6.02, -6.02 dB

Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 50 - 50 Hz

IIR filter parameters
---------------------
Butterworth bandstop zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 49.90, 50.10 Hz: -6.02, -6.02 dB

Using data from preloaded Raw for 60 events and 1101 original time points ...
0 bad epochs dropped
Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 50 - 50 Hz

IIR filter parameters
---------------------
Butterworth bandstop zero-phase (two-pas

In [20]:
# metadata = pd.DataFrame({"subject":np.repeat(list(map(str,subjects)),X[0].shape[0]),"session":["0"]*len(subjects)*X[0].shape[0]})
metadata

Unnamed: 0,subject,session
0,1,0
1,1,0
2,1,0
3,1,0
4,1,0
...,...,...
175495,3,0
175496,3,0
175497,3,0
175498,3,0


In [21]:
records = []

if 'inter-session' in cfg_spd['evaluation']:
    subset_iter = iter([[s] for s in moabb_ds.subject_list])
    groupvarname = 'session'
elif 'inter-subject' in cfg_spd['evaluation']:
    subset_iter = iter([None])
    groupvarname = 'subject'
else:
    raise NotImplementedError()

fit_records = []

# iterate over subject groups
for ix_subset, subjects in enumerate(subset_iter):

    # get the data from the MOABB paradigm/dataset
    # X, labels, metadata = moabb_paradigm.get_data(moabb_ds, subjects=subjects, return_epochs=False)
    X = np.concatenate(X_parent)
    labels = np.concatenate(labels_parent)
    domains = np.concatenate(domains_parent)

    X_std = X.std(axis=0)
    X /= X_std + 1e-8

    # xdawncov = XdawnCovariances(estimator="lwf",xdawn_estimator="lwf",nfilter=8)
    # X = xdawncov.fit_transform(X,labels)

    # extract domains = subject/session
    metadata['label'] = labels
    metadata['domain'] = metadata.apply(lambda row: f'{row.subject}/{row.session}',  axis=1)
    domain = sklearn.preprocessing.LabelEncoder().fit_transform(metadata['domain'])

    # convert to torch tensors
    domain = torch.from_numpy(domain)
    X = torch.from_numpy(X)
    y = sklearn.preprocessing.LabelEncoder().fit_transform(labels)
    y = torch.from_numpy(y)

    # leave one subject or session out
    cv_outer = sklearn.model_selection.LeaveOneGroupOut()
    cv_outer_group = metadata[groupvarname]

    # train/validation split stratified across domains and labels
    cv_inner_group = metadata.apply(lambda row: f'{row.domain}/{row.label}',  axis=1)
    cv_inner_group = sklearn.preprocessing.LabelEncoder().fit_transform(cv_inner_group)

    # add datadependen model kwargs
    mdl_kwargs = deepcopy(cfg_spd['mdl_kwargs'])
    mdl_kwargs['nclasses'] = n_classes
    mdl_kwargs['nchannels'] = X.shape[1]
    mdl_kwargs['nsamples'] = X.shape[2]
    mdl_kwargs['domains'] = domain.unique()

    # perform outer CV
    for ix_fold, (fit, test) in enumerate(cv_outer.split(X, y, cv_outer_group)):

        # Balancing the data
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(y[fit]))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,y[fit][:])
        index = np.sort(index,axis=0)
        X_fit = np.squeeze(X[fit][index,:,:], axis=1)
        y_fit = np.squeeze(y[fit][index])
        domain_fit = np.squeeze(domain[fit][index])
        metadata_fit = metadata.loc[fit].iloc[np.concatenate(index)]

        # split fitting data into train and validation 
        cv_inner = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, test_size=cfg_spd['validation_size'])
        train, val = next(cv_inner.split(X_fit, y_fit, np.squeeze(cv_inner_group[fit][index])))

        # adjust number of 
        du = domain_fit[train].unique()
        if cfg_spd['domains_per_batch'] > len(du):
            domains_per_batch = len(du)
        else:
            domains_per_batch = cfg_spd['domains_per_batch']

        # split entire dataset into train/validation/test
        ds_train = DomainDataset(X_fit[train], y_fit[train], domain_fit[train], metadata_fit.iloc[train,:])
        ds_val = DomainDataset(X_fit[val], y_fit[val], domain_fit[val], metadata_fit.iloc[val,:])
        ds_test = DomainDataset(X[test], y[test], domain[test], metadata.iloc[test,:])

        # create dataloaders
        # for training use specific loader/sampler so taht 
        # batches contain a specific number of domains with equal observations per domain
        # and stratified labels
        loader_train = StratifiedDomainDataLoader(ds_train, cfg_spd['batch_size_train'], domains_per_batch=domains_per_batch, shuffle=True)
        loader_val = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val))
        loader_test = torch.utils.data.DataLoader(ds_test, batch_size=len(ds_test))

        # extract domains in the test dataset
        test_domain = metadata['domain'].iloc[test].unique()

        # create the model
        net = SPDSMNet(**mdl_kwargs).to(device=device, dtype=cfg_spd['dtype'])

        # create the momentum scheduler
        bn_sched = MomentumBatchNormScheduler(
            epochs=cfg_spd['epochs']-10,
            bs0=cfg_spd['batch_size_train'],
            bs=cfg_spd['batch_size_train']/cfg_spd['domains_per_batch'], 
            tau0=0.85
        )

        es = EarlyStopping(metric='val_loss', higher_is_better=False, patience=15, verbose=False)
        
        # create the trainer
        trainer = Trainer(
            max_epochs=cfg_spd['epochs'],
            min_epochs=cfg_spd['epochs'],
            callbacks=[bn_sched, es],
            loss=torch.nn.CrossEntropyLoss(),
            device=device, 
            dtype=cfg_spd['dtype']
        )

        # fit the model
        print(f"test domain={test_domain}")
        trainer.fit(net, train_dataloader=loader_train, val_dataloader=loader_val)

        print(f'ES best epoch={es.best_epoch}')

        fit_df = pd.DataFrame(trainer.records)
        fit_df['fold'] = ix_fold
        fit_df['subset'] = ix_subset
        fit_records.append(fit_df)

        # evaluation
        print("evaluate the estimator")
        res = trainer.test(net, dataloader=loader_train)
        records.append(dict(mode='train', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_val)
        records.append(dict(mode='validation', domain=test_domain, **res))
        res = trainer.test(net, dataloader=loader_test)
        records.append(dict(mode='test(noUDA)', domain=test_domain, **res))


        y_pred = trainer.pred(net,dataloader=loader_test)
        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred[0].cpu(), codes, min_len=30, sfreq=60, consecutive=50, window_size=0.25
        )
        accuracy_code = np.round(balanced_accuracy_score(labels_codes[ix_fold][labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        print(" accuracy score of the participant",accuracy_code)
        records.append(dict(mode='test(noUDA)_code', domain=test_domain, score=accuracy_code, loss=None))


        # extract model parameters
        state_dict = deepcopy(net.state_dict())

        # create new model and perform offline SF UDA
        print("Perform UDA offline")
        sfuda_offline_net = SPDSMNet(**mdl_kwargs).to(device=device)
        sfuda_offline_net.load_state_dict(state_dict)
        sfuda_offline(ds_test, sfuda_offline_net)
        res = trainer.test(sfuda_offline_net, dataloader=loader_test)
        records.append(dict(mode='test(SFUDA)', domain=test_domain, **res))

        # create a new model and perform online SF UDA
        # sfuda_online_net = SPDSMNet(**mdl_kwargs)
        # sfuda_online_net.load_state_dict(state_dict)
        # loss, score = sfuda_online(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        # records.append(dict(mode='test(onlineSFUDA)', domain=test_domain, loss=loss, score=score))

        # # create a new model and perform online SF UDA
        # sfuda_online_net = SPDSMNet(**mdl_kwargs)
        # sfuda_online_net.load_state_dict(state_dict)
        # loss, score = sfuda_online_simulate(ds_test, sfuda_online_net, cfg_spd, trainer.loss_fn)
        # records.append(dict(mode='test(online_sim_SFUDA)', domain=test_domain, loss=loss, score=score))


test domain=['1/0']


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)

## Get the results

In [54]:
# report the results
resdf = pd.DataFrame(records)

resdf.groupby(['mode']).agg(['mean', 'std']).round(2)

KeyError: 'mode'

In [36]:
resdf.to_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/SPDSMNet_code.csv")
# resdf

In [37]:
df = pd.read_csv("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/SPDSMNet_code.csv")
print()
df_score_code = np.array(df["score"])[df["mode"]=="test(noUDA)_code"]
df_score = np.array(df["score"])[df["mode"]=="test(noUDA)"]

np.save("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/SPDSMNet_score_code",df_score_code)
np.save("C:/Users/s.velut/Documents/These/Protheus_PHD/results/results/Score_TF/TSMNet/SPDSMNet_score",df_score)




In [38]:
fit_df = pd.concat(fit_records)
fit_df = fit_df.groupby(['subset', 'fold', 'epoch']).mean()
fit_df.columns = fit_df.columns.str.split('_', expand=True)
fit_df.columns.names = ['set', 'metric']
fit_df = fit_df.stack('set').reset_index()


sns.relplot(data=fit_df, x='epoch', y='loss', hue='set', col='subset', col_wrap=5, kind='line', n_boot=100)

<seaborn.axisgrid.FacetGrid at 0x23432e9ee10>