In [8]:
import numpy as np
import torch 

from braindecode.datasets.base import BaseConcatDataset
from braindecode.datasets import TUHAbnormal
from braindecode import EEGClassifier
from braindecode.models import ShallowFBCSPNet, deep4
from braindecode.preprocessing import create_fixed_length_windows

from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer

from EEGClip.clip_models import EEGClipModule

import mne
mne.set_log_level('ERROR')  # avoid messages everytime a window is extracted

n_jobs = 4
data_path = '/home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal_clip/v2.0.0/edf/'
#N_SAMPLES = 100
N_JOBS = 8 

In [9]:
tuabn = TUHAbnormal(
        path=data_path,
        preload=False,  # True
        add_physician_reports=True, 
        n_jobs=n_jobs,
        target_name = "report",
        recording_ids=range(5,11),
    )


Extracting EDF parameters from /home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal_clip/v2.0.0/edf/eval/normal/01_tcp_ar/064/00006422/s001_2009_09_10/00006422_s001_t000.edf...
EDF file detected
Extracting EDF parameters from /home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal_clip/v2.0.0/edf/train/abnormal/01_tcp_ar/059/00005928/s001_2009_09_11/00005928_s001_t000.edf...
Extracting EDF parameters from /home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal_clip/v2.0.0/edf/eval/abnormal/01_tcp_ar/045/00004526/s003_2009_09_15/00004526_s003_t001.edf...
EDF file detected
EDF file detected
Setting channel info structure...
Setting channel info structure...
Extracting EDF parameters from /home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal_clip/v2.0.0/edf/eval/normal/01_tcp_ar/062/00006201/s001_2009_09_10/00006201_s001_t000.edf...
EDF file detected
Creating raw.info structure...
Creating raw.info structure...
Setting channel info structure...
Creating raw.info structure...
Setting channel info structure...
Creating r

In [10]:
print(len(tuabn))
x, y = tuabn[-1]
print('x:', x)
print('y:', y)

793600
x: [[ 6.98874172e-08]
 [ 2.14231788e-07]
 [-1.46629139e-07]
 [-1.10543046e-07]
 [-1.82715232e-07]
 [-2.28476821e-09]
 [-1.10543046e-07]
 [-3.83708609e-08]
 [-3.83708609e-08]
 [-3.83708609e-08]
 [ 3.22490066e-07]
 [ 6.98874172e-08]
 [-2.28476821e-09]
 [-3.83708609e-08]
 [-1.10543046e-07]
 [-3.83708609e-08]
 [ 1.05973510e-07]
 [-1.82715232e-07]
 [-2.28476821e-09]
 [-3.83708609e-08]
 [-3.83708609e-08]]
y: 

HISTORY: 69 year old male with psychiatric symptoms, boarding home resident and syncope versus seizure.

MEDICATIONS: Clozapine, Metformin, Avandia, Diovan, Metformin, Amiodipine


INTRODUCTION: Digital video EEG was performed in lab/bed using standard 10-20 system of electrode placement with 1 channel of EKG.  Hyperventilation and photic simulation are preformed.  This is an awake record. The EKG lead is not functioning.

DESCRIPTION OF THE RECORD: The background EEG demonstrates rhythmic background slowing. There is a posterior dominant rhythm of 6-7 hertz, a scant amount of a

In [11]:
tuabn.description

Unnamed: 0,path,year,month,day,subject,session,segment,age,gender,report,version,train,pathological
0,/home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal...,2009,9,10,6201,1,0,54,M,CLINICAL HISTORY: 54 year old male with recurr...,v2.0.0,False,False
1,/home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal...,2009,9,10,6422,1,0,50,M,\nCLINICAL HISTORY: 49 year old left handed ...,v2.0.0,False,False
2,/home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal...,2009,9,11,5928,1,0,70,F,CLINICAL HISTORY: 71 year old woman with recu...,v2.0.0,True,True
3,/home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal...,2009,9,15,4526,3,1,71,F,\nCLINICAL HISTORY: 71 year old woman with epi...,v2.0.0,False,True
4,/home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal...,2009,9,15,5921,1,0,39,F,CLINICAL HISTORY: 38 year old right handed ...,v2.0.0,False,False
5,/home/jovyan/mne_data/TUH_PRE/tuh_eeg_abnormal...,2009,9,16,5931,1,1,69,M,\n\nHISTORY: 69 year old male with psychiatric...,v2.0.0,True,True


In [5]:
len([45, 26, 37, 40, 69, 53, 71, 67, 40, 28, 28, 67, 28, 40, 28, 22, 62, 81, 53, 44, 37, 24, 71, 21, 46, 71, 62, 46, 38, 71, 69, 49, 28, 49, 42, 81, 28, 37, 26, 62, 42, 71, 67, 67, 26, 62, 34, 40, 42, 24, 71, 30, 38, 55, 67, 42, 71, 45, 45, 42, 28, 34, 37, 53, 89, 37, 44, 42, 39, 69, 67, 53, 37, 89, 28, 55, 89, 67, 54, 71, 53, 34, 37, 54, 49, 48, 48, 28, 40, 46, 62, 30, 81, 53, 67, 34, 54, 59, 89, 37, 89, 38, 45, 80, 24, 37, 26, 67, 53, 28, 71, 28, 57, 22, 69, 28, 45, 37, 24, 53, 21, 46, 37, 69, 67, 39, 42, 53, 49, 28, 71, 44, 44, 24, 81, 54, 38, 42, 71, 44, 55, 22, 48, 49, 80, 40, 54, 37, 42, 67, 42, 49, 48, 81, 47, 53, 38, 24, 53, 37, 45, 24, 48, 24, 42, 67, 30, 26, 80, 89, 40, 62, 28, 53, 46, 62, 38, 71, 46, 67, 62, 40, 28, 38, 38, 71, 71, 47, 69, 62, 26, 37, 39, 37, 89, 42, 42, 24, 69, 54, 38, 71, 42, 22, 28, 71, 30, 71, 30, 57, 38, 40, 57, 45, 54, 62, 39, 54, 44, 71, 69, 44, 42, 22, 44, 22, 69, 81, 71, 55, 62, 37, 54, 81, 28, 54, 59, 26, 28, 57, 21, 80, 81, 34, 81, 38, 57, 46, 53, 39, 67, 24, 22, 81, 67, 28, 57, 57, 45, 38, 62, 37, 48, 89, 54, 21, 54, 21, 49, 71, 37, 57, 80, 30, 21, 34, 28, 89, 38, 44, 89, 54, 69, 40, 80, 22, 42, 49, 38, 24, 37, 54, 55, 47, 69, 24, 71, 22, 71, 28, 37, 47, 47, 28, 67, 28, 46, 38, 44, 57, 89, 62, 55, 55, 54, 55, 44, 53, 71, 46, 28, 22, 28, 42, 28, 69, 40, 62, 28, 59, 69, 54, 71, 67, 22, 26, 28, 38, 55, 42, 44, 53, 40, 54, 26, 81, 22, 53, 57, 45, 40, 62, 69, 22, 42, 28, 46, 57, 37, 46, 28, 42, 42, 28, 71, 28, 34, 28, 81, 54, 38, 59, 30, 69, 28, 57, 44, 54, 71, 53, 62, 69, 53, 21, 39, 22, 69, 44, 62, 28, 54, 44, 40, 42, 37, 62, 26, 38, 62, 54, 37, 55, 71, 30, 47, 62, 28, 22, 89, 55, 28, 69, 80, 22, 46, 55, 71, 69, 28, 37, 54, 37, 30, 80, 39, 21, 22, 81, 48, 53, 49, 40, 53, 42, 22, 37, 39, 37, 49, 62, 54, 69, 28, 37, 47, 37, 28, 69, 22, 55, 28, 42, 22, 48, 22, 71, 53, 39, 54, 49, 67, 57, 24, 37, 38, 46, 22, 81, 26, 48, 71, 47, 67, 81, 45, 69, 62, 62, 38, 37, 71, 42, 44, 28, 39, 46, 39, 42, 28, 28, 62, 71, 39, 28, 24, 54, 28, 39, 71, 53, 71, 71, 26, 44, 62, 69, 46, 47, 53, 67, 45, 54, 71, 28, 34, 57, 21, 37, 42, 28, 69, 71, 22, 39, 49, 22, 26, 47, 54, 53, 38, 57, 42, 28, 54, 26, 30, 40, 53, 26, 28, 34, 28, 71, 80, 54, 57, 42, 42, 46, 39, 48, 55, 34, 28, 71, 62, 21, 81, 55, 53, 89, 71, 24, 42, 21, 89, 42, 28, 54, 40, 24, 30, 38, 42, 80, 37, 55, 53, 62, 81, 54, 44, 30, 22, 44, 89, 37, 62, 28, 55, 22, 67, 21, 71, 42, 81, 34, 53, 37, 81, 62, 34, 40, 34, 26, 30, 69, 37, 54, 71, 38, 28, 71, 46, 28, 67, 38, 48, 39, 54, 53, 67, 42, 49, 71, 22, 53, 53, 28, 34, 24, 44, 69, 37, 22, 47, 71, 47, 71, 81, 71, 38, 55, 59, 22, 42, 40, 37, 28, 46, 42, 38, 38, 57, 38, 28, 54, 38, 42, 22, 40, 28, 42, 28, 80, 39, 47, 71, 28, 22, 38, 26, 89, 57, 89, 71, 34, 71, 38, 53, 54, 39, 22, 37, 53, 89, 53, 22, 21, 67, 62, 69, 40, 48, 59, 53, 45, 44, 34, 59, 53, 46, 59, 42, 55, 26, 71, 42, 46, 49, 22, 69, 67, 53, 69, 69, 34, 53, 71, 69, 67, 28, 28, 54, 71, 80, 71])


728

In [6]:
tuabn.description["path"].duplicated().sum()

0

In [7]:
import pandas as pd
full_desc = pd.read_csv('/home/jovyan/EEGClip/data/TUH_Abnormal_EEG_rep.csv')

In [8]:
set(np.unique(tuabn.description["subject"].values)).issubset(set(np.unique(full_desc["SUBJECT"].values)))

True

In [9]:
full_desc.duplicated(subset=['SUBJECT'])

0       False
1       False
2       False
3       False
4       False
        ...  
2989     True
2990     True
2991    False
2992    False
2993    False
Length: 2994, dtype: bool

In [10]:
tuabn.description[tuabn.description.subject.isin([4933,6523])]["pathological"].tolist()

[]

In [11]:
full_desc[full_desc["SUBJECT"]==subject_id].iloc[0]["DESCRIPTION OF THE RECORD"]

NameError: name 'subject_id' is not defined

In [None]:
sfreq  = 100
n_minutes = 20

subject_datasets = tuabn.split('subject')
n_subjects = len(subject_datasets)

n_split = int(np.round(n_subjects * 0.75))
keys = list(subject_datasets.keys())
train_sets = [d for i in range(n_split) for d in subject_datasets[keys[i]].datasets]
train_set = BaseConcatDataset(train_sets)
valid_sets = [d for i in range(n_split, n_subjects) for d in subject_datasets[keys[i]].datasets]
valid_set = BaseConcatDataset(valid_sets)



In [None]:
window_size_samples = 1000
window_stride_samples = 1000

window_train_set = create_fixed_length_windows(
    train_set,
    window_size_samples=window_size_samples,
    window_stride_samples=window_stride_samples,
    drop_last_window=True,
    n_jobs=N_JOBS,

)
window_valid_set = create_fixed_length_windows(
    valid_set,
    window_size_samples=window_size_samples,
    window_stride_samples=window_stride_samples,
    drop_last_window=False,
    n_jobs=N_JOBS,

)


Loading data for 120 events and 1000 original time points ...
Loading data for 140 events and 1000 original time points ...
Loading data for 120 events and 1000 original time points ...
Loading data for 120 events and 1000 original time points ...
Loading data for 125 events and 1000 original time points ...
Loading data for 120 events and 1000 original time points ...
Loading data for 118 events and 1000 original time points ...
Loading data for 125 events and 1000 original time points ...
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
Loading data for 138 events and 1000 original time points ...
Loading data for 122 events and 1000 original time points ...
Loading data for 119 events and 1000 original time points ...
Loading data for 121 events and 1000 original time points ...
Loading data for 120 events and 1000 original time points ...
Loading data for 115 event

In [None]:

batch_size = 32
num_workers = 32
n_epochs = 50

train_loader = torch.utils.data.DataLoader(
    window_train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True)
valid_loader = torch.utils.data.DataLoader(
    window_valid_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=False)





In [None]:


n_classes = 128
# Extract number of chans and time steps from dataset
n_chans = window_train_set[0][0].shape[0]
input_window_samples = window_train_set[0][0].shape[1]

eeg_classifier_model = ShallowFBCSPNet(
    n_chans,
    n_classes,
    input_window_samples=input_window_samples,
    final_conv_length='auto',
)


# These values we found good for shallow network:
lr = 0.0625 * 0.01
weight_decay = 0

# For deep4 they should be:
# lr = 1 * 0.01
# weight_decay = 0.5 * 0.001


logger = TensorBoardLogger("results/tb_logs", name="EEG_Clip")

trainer = Trainer(
    devices=1,
    accelerator="gpu",
    max_epochs=n_epochs,
    #callbacks=[TQDMProgressBar(refresh_rate=20)],
    logger=logger,
    profiler="advanced"
)

trainer.fit(EEGClipModule(eeg_classifier_model=eeg_classifier_model, lr = lr), train_loader, valid_loader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


NameError: name 'valid_loader' is not defined