# EEG-TEA: A Novel Application of Test-time Energy Adaptation in Electroencephalogram Decoding

In [None]:
from eeg_otta.utils.config_setup import setup_config
from eeg_otta.utils.get_accuracy import calculate_accuracy
from eeg.analysis.plots import plot_energy_accuracy_loss, plot_accuracy
from eeg_otta.utils.embedding_eval import plot_embeddings

### 1. Run adaptation
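Set the parameters in the next cell to one of the two experiments below.

Leave-one-subject-out (LOSO) experiment:
- dataset_name: 2a
- dataset_setup: loso
- corruption_level: None

Corruption experiment (run once per corruption level):
- dataset_name: 2b
- dataset_setup: within
- corruption_level: 1, 2, 3, 4, 5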

In [None]:
# Experiment settings; choose one of the two experiments described above.
# The current values select the LOSO experiment, which runs without corruption.
dataset_name = '2a'  # ['2a', '2b']
dataset_setup = 'loso'  # ['within', 'loso']
corruption_level = None  # [None, 1, 2, 3, 4, 5] (None = no corruption; use 1-5 with '2b' / 'within')

In [None]:
# Evaluate every adaptation method for every seed.
# For each run, keep:
#   - the value returned by calculate_accuracy (test_accuracies),
#   - the run's save directory (save_dir, used by the plotting cells),
#   - for 'tea' only, the model dict returned by calculate_accuracy
#     (tea_model_dict, used by the embedding cell).
# `datamodule` is intentionally left bound to the last iteration's datamodule;
# the embedding cell below reuses it.
seeds = [0]
save_dir = {}
tea_model_dict = {}
test_accuracies = {}  # (seed, adaptation_method) -> value returned by calculate_accuracy
adaptation_methods = ['source', 'entropy_minimization', 'tea']

for seed in seeds:
    for adaptation_method in adaptation_methods:
        print(f"Evaluating adaptation method: {adaptation_method}")
        model_cls, tta_cls, datamodule, config = setup_config(dataset_name, dataset_setup, adaptation_method, seed, corruption_level)
        test_accuracy, model_dict = calculate_accuracy(model_cls, tta_cls, datamodule, config, get_model_dict=True)
        # Keep the result instead of discarding it; keyed per seed so runs with
        # several seeds do not overwrite each other.
        test_accuracies[(seed, adaptation_method)] = test_accuracy

        if adaptation_method == 'tea':
            tea_model_dict = model_dict
        # NOTE: save_dir and tea_model_dict are keyed by method only, so with
        # several seeds the last seed's entries win.
        save_dir[adaptation_method] = config['tta_config']['save_dir']

### 2. Plot results

In [None]:
# Save directory recorded for the 'tea' run in step 1
# (taken from config['tta_config']['save_dir']).
tea_run_dir = save_dir['tea']
plot_energy_accuracy_loss(tea_run_dir)

In [None]:
# save_dir maps each adaptation method name from step 1 to its run's save directory.
method_save_dirs = save_dir
plot_accuracy(method_save_dirs)

### 3. Plot PCA embeddings

In [None]:
# Plot train/test embeddings for one subject. `adapted` selects either:
#   - the model dict captured for the 'tea' run in step 1, or
#   - whatever setup_config returns for a fresh 'tea' configuration.
SUBJECT_ID = 1
corruption_level = None
adapted = False

if not adapted:
    # NOTE(review): step 1 unpacks setup_config's first return value as
    # `model_cls`. Confirm plot_embeddings accepts it as-is, or
    # instantiate/load the source model here before plotting.
    model, _, _, _ = setup_config(dataset_name, dataset_setup, 'tea', 0, corruption_level)
else:
    # Presumably keyed by subject id -- verify against calculate_accuracy.
    model = tea_model_dict[SUBJECT_ID]

# Reuse the datamodule from the last step-1 iteration, pointed at SUBJECT_ID.
# Set the corruption level *before* prepare_data()/setup() so that any data
# preparation reading it sees the requested level.
datamodule.subject_id = SUBJECT_ID
datamodule.corruption_level = corruption_level
datamodule.prepare_data()
datamodule.setup()

# Plot the model selected by `adapted` above.
plot_embeddings(model, datamodule.train_dataloader(), datamodule.test_dataloader())