In [2]:
import numpy as np
import torch 

from braindecode.datasets.base import BaseConcatDataset
from braindecode.datasets import TUHAbnormal
from braindecode import EEGClassifier
from braindecode.models import ShallowFBCSPNet, deep4
from braindecode.preprocessing import create_fixed_length_windows

from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer

from EEGClip.clip_models import EEGClipModule

import mne
import os

# mne.set_log_level only affects this process. Worker processes started for
# n_jobs > 1 presumably re-read MNE_LOGGING_LEVEL from the environment when they
# import mne, so export it too. This is likely why "Extracting EDF parameters
# ..." messages still appear while loading. setdefault keeps any user override.
os.environ.setdefault('MNE_LOGGING_LEVEL', 'ERROR')
mne.set_log_level('ERROR')  # avoid messages everytime a window is extracted

# Worker count used when loading the TUH recordings (TUHAbnormal below).
n_jobs = 4
# Root of the TUH Abnormal EDF tree. Override it with the TUH_ABNORMAL_PATH
# environment variable instead of editing this absolute path.
data_path = os.environ.get(
    'TUH_ABNORMAL_PATH',
    '/home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal/v2.0.0/edf/',
)
# Number of recordings to load (recording_ids=range(N_SAMPLES)).
N_SAMPLES = 100
# Worker count used for windowing (create_fixed_length_windows). It deliberately
# differs from n_jobs above; later cells use both names.
N_JOBS = 8

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Lazily load (preload=False) the first N_SAMPLES recordings. Each recording's
# target is its subject id rather than the normal/abnormal label.
# add_physician_reports=True was tried and left disabled.
tuabn = TUHAbnormal(
    path=data_path,
    recording_ids=list(range(N_SAMPLES)),
    target_name='subject',
    preload=False,
    n_jobs=n_jobs,
)


Extracting EDF parameters from /home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal/v2.0.0/edf/eval/normal/01_tcp_ar/058/00005864/s001_2009_09_03/00005864_s001_t000.edf...
EDF file detected
Extracting EDF parameters from /home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal/v2.0.0/edf/train/normal/01_tcp_ar/009/00000929/s003_2009_09_04/00000929_s003_t002.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal/v2.0.0/edf/eval/normal/01_tcp_ar/058/00005851/s001_2009_09_04/00005851_s001_t001.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal/v2.0.0/edf/eval/normal/01_tcp_ar/041/00004196/s003_2009_09_03/00004196_s003_t000.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Ext

In [9]:
# Before windowing, the dataset apparently yields one time sample per item.
# x is a single column (one value per channel) and y is the recording's
# subject id (target_name='subject').
n_items = len(tuabn)
print(n_items)
x, y = tuabn[-1]
for label, value in (('x:', x), ('y:', y)):
    print(label, value)

13249100
x: [[ 2.30026551e-09]
 [ 1.75603656e-08]
 [-7.87313456e-09]
 [ 2.30026551e-09]
 [ 2.30026551e-09]
 [ 1.24736656e-08]
 [ 2.30026551e-09]
 [ 2.30026551e-09]
 [-7.87313456e-09]
 [-2.78643452e-09]
 [ 1.24736656e-08]
 [ 1.75603656e-08]
 [ 2.30026551e-09]
 [-2.78643452e-09]
 [-2.78643452e-09]
 [-2.78643452e-09]
 [ 2.30026551e-09]
 [ 2.30026551e-09]
 [ 2.30026551e-09]
 [ 2.30026551e-09]
 [-2.78643452e-09]]
y: 6523


In [4]:
# Sorted distinct subject ids among the loaded recordings.
subject_ids = tuabn.description["subject"].to_numpy()
np.unique(subject_ids)

array([ 647,  929, 1355, 1402, 2272, 2508, 2705, 2740, 2775, 2849, 2940,
       3623, 3884, 4196, 4209, 4392, 4526, 4586, 4806, 4933, 5071, 5221,
       5372, 5390, 5394, 5395, 5398, 5400, 5412, 5655, 5721, 5851, 5864,
       5897, 5909, 5921, 5928, 5931, 5942, 6005, 6038, 6056, 6061, 6076,
       6081, 6088, 6095, 6096, 6099, 6101, 6103, 6111, 6113, 6117, 6140,
       6144, 6146, 6155, 6156, 6162, 6163, 6183, 6186, 6187, 6188, 6201,
       6212, 6213, 6215, 6218, 6219, 6222, 6224, 6227, 6236, 6238, 6240,
       6290, 6297, 6311, 6312, 6316, 6317, 6319, 6322, 6328, 6422, 6488,
       6490, 6523, 6926, 6927, 6936, 6955, 6964, 6965, 7026])

In [8]:
import pandas as pd

# Report table; later cells read its SUBJECT and DESCRIPTION OF THE RECORD columns.
REPORTS_CSV = '/home/jovyan/EEGClip/data/TUH_Abnormal_EEG_rep.csv'
full_desc = pd.read_csv(REPORTS_CSV)

In [13]:
# True when every loaded subject has at least one row in the report table.
loaded_subjects = set(np.unique(tuabn.description["subject"].values))
report_subjects = set(np.unique(full_desc["SUBJECT"].values))
loaded_subjects <= report_subjects

True

In [14]:
# True marks a row whose SUBJECT already appeared in an earlier row
# (default keep='first'), i.e. subjects can have several reports.
full_desc.duplicated(subset='SUBJECT')

0       False
1       False
2       False
3       False
4       False
        ...  
2989     True
2990     True
2991    False
2992    False
2993    False
Length: 2994, dtype: bool

In [36]:
# Pathology label of each loaded recording belonging to subjects 4933 and 6523.
subject_mask = tuabn.description["subject"].isin([4933, 6523])
tuabn.description.loc[subject_mask, "pathological"].tolist()

[False, False]

In [30]:
# `subject_id` was never defined in the notebook (leftover kernel state), so
# this cell raised NameError on a fresh kernel. Define it explicitly here.
# NOTE(review): the id that produced the saved output is unknown. 6523 is the
# subject of the last recording printed earlier (y: 6523); confirm it.
subject_id = 6523

subject_reports = full_desc.loc[
    full_desc["SUBJECT"] == subject_id, "DESCRIPTION OF THE RECORD"
]
if subject_reports.empty:
    raise KeyError(f"no report found for SUBJECT {subject_id}")
# SUBJECT can repeat across rows; show the first report only.
subject_reports.iloc[0]

'In wakefulness, the background EEG is relatively low voltage, but there is a discernible 10.5 Hz alpha rhythm in a background of generous beta. Much of the record includes drowsiness with attenuation of the alpha rhythm. Hyperventilation produces an increase in amplitude of the record. Photic stimulation does elicit a driving response.'

In [37]:
sfreq  = 100  # NOTE(review): unused in this cell; kept in case later cells read it
n_minutes = 20  # NOTE(review): unused in this cell; kept in case later cells read it
TRAIN_FRACTION = 0.75  # share of subjects assigned to the training split

# Group recordings by subject so that no subject lands in both splits.
subject_datasets = tuabn.split('subject')
n_subjects = len(subject_datasets)

# Subjects are assigned in the key order returned by split(); they are not shuffled.
n_split = int(np.round(n_subjects * TRAIN_FRACTION))
keys = list(subject_datasets.keys())
# An empty split would only fail later, inside BaseConcatDataset, with an unclear error.
if not 0 < n_split < n_subjects:
    raise ValueError(
        f"split of {n_subjects} subjects at fraction {TRAIN_FRACTION} "
        f"leaves an empty train or valid set (n_split={n_split})"
    )

train_sets = [d for key in keys[:n_split] for d in subject_datasets[key].datasets]
train_set = BaseConcatDataset(train_sets)
valid_sets = [d for key in keys[n_split:] for d in subject_datasets[key].datasets]
valid_set = BaseConcatDataset(valid_sets)



In [39]:
window_size_samples = 1000
window_stride_samples = 1000


def _to_windows(concat_ds, drop_last_window):
    """Cut `concat_ds` into fixed-length windows (stride equals window size)."""
    return create_fixed_length_windows(
        concat_ds,
        window_size_samples=window_size_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=drop_last_window,
        n_jobs=N_JOBS,
    )


# The only difference between the two splits is drop_last_window:
# True for train, False for valid.
window_train_set = _to_windows(train_set, drop_last_window=True)
window_valid_set = _to_windows(valid_set, drop_last_window=False)


Loading data for 120 events and 1000 original time points ...
Loading data for 140 events and 1000 original time points ...
Loading data for 120 events and 1000 original time points ...
Loading data for 120 events and 1000 original time points ...
Loading data for 125 events and 1000 original time points ...
Loading data for 120 events and 1000 original time points ...
Loading data for 118 events and 1000 original time points ...
Loading data for 125 events and 1000 original time points ...
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
Loading data for 138 events and 1000 original time points ...
Loading data for 122 events and 1000 original time points ...
Loading data for 119 events and 1000 original time points ...
Loading data for 121 events and 1000 original time points ...
Loading data for 120 events and 1000 original time points ...
Loading data for 115 event

In [40]:

batch_size = 32
num_workers = 32
n_epochs = 50

# Options shared by both loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)

# Training: reshuffle each epoch and drop the final incomplete batch.
train_loader = torch.utils.data.DataLoader(
    window_train_set, shuffle=True, drop_last=True, **loader_kwargs)
# Validation: fixed order, keep every window.
valid_loader = torch.utils.data.DataLoader(
    window_valid_set, shuffle=False, drop_last=False, **loader_kwargs)





In [41]:


# Output size of the EEG encoder. Despite the name, this is presumably the
# embedding dimension used by the CLIP objective, not a number of classes.
# TODO confirm against EEGClipModule.
n_classes = 128
# Take the number of channels and time steps from one window. Index the dataset
# only once: with preload=False each access presumably reads from disk.
first_window = window_train_set[0][0]
n_chans, input_window_samples = first_window.shape[:2]

eeg_classifier_model = ShallowFBCSPNet(
    n_chans,
    n_classes,
    input_window_samples=input_window_samples,
    final_conv_length='auto',
)


# These values we found good for shallow network:
lr = 0.0625 * 0.01
# NOTE(review): weight_decay is never passed to EEGClipModule below, so it has
# no effect here.
weight_decay = 0

# For deep4 they should be:
# lr = 1 * 0.01
# weight_decay = 0.5 * 0.001


logger = TensorBoardLogger("results/tb_logs", name="EEG_Clip")

trainer = Trainer(
    devices=1,
    accelerator="gpu",
    max_epochs=n_epochs,
    logger=logger,
    # The "advanced" profiler runs cProfile on every hook, which is useful for
    # diagnosing but slows training. Remove it for full runs.
    profiler="advanced",
)

# valid_loader comes from the DataLoader cell, which must run first. The saved
# output shows NameError because the cells were run out of order.
clip_module = EEGClipModule(eeg_classifier_model=eeg_classifier_model, lr=lr)
trainer.fit(clip_module, train_loader, valid_loader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


NameError: name 'valid_loader' is not defined