In [2]:
import numpy as np
import matplotlib.pyplot as plt
import mne
import pandas as pd
from braindecode.datasets import TUHAbnormal
from braindecode.preprocessing import (
    preprocess, Preprocessor, create_fixed_length_windows, create_windows_from_events, scale as multiply)
import torch
from braindecode.util import set_random_seeds
from braindecode.models import ShallowFBCSPNet

mne.set_log_level('ERROR')  # avoid messages everytime a window is extracted

TUH_PATH = '/home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2.0.0'
N_JOBS = 8  # specify the number of jobs for loading and windowing
N_SAMPLES = 10

tuh = TUHAbnormal(
    path=TUH_PATH,
    recording_ids=list(range(N_SAMPLES)),
    target_name=('report'),#'pathological'),
    preload=False,
    add_physician_reports=True,
    n_jobs=N_JOBS,  # Mock dataset can't
    # be loaded in parallel
)

print("length of dataset : ", len(tuh))

#show last example 
x, y = tuh[-1]
print('x:', x)
print('y:', y[y.find('DESCRIPTION OF THE RECORD:'):])



Extracting EDF parameters from /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2.0.0/edf/eval/normal/01_tcp_ar/058/00005864/s001_2009_09_03/00005864_s001_t000.edf...
EDF file detected
Extracting EDF parameters from /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2.0.0/edf/eval/normal/01_tcp_ar/041/00004196/s003_2009_09_03/00004196_s003_t000.edf...
EDF file detected
Extracting EDF parameters from /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2.0.0/edf/train/normal/01_tcp_ar/009/00000929/s003_2009_09_04/00000929_s003_t002.edf...
EDF file detected
Extracting EDF parameters from /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2.0.0/edf/eval/normal/01_tcp_ar/062/00006201/s001_2009_09_10/00006201_s001_t000.edf...
EDF file detected
Extracting EDF parameters from /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2.0.0/edf/eval/normal/01_tcp_ar/058/00005851/s001_2009_09_04/00005851_s001_t001.edf...
Setting channel info structure...
EDF file detected
Extracting EDF parameters from /home/jovyan/mne_data/TUH/tuh_eeg_

In [16]:
print(tuh.description)

                                                path  year  month  day  \
0  /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2....  2009      9    3   
1  /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2....  2009      9    3   
2  /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2....  2009      9    4   
3  /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2....  2009      9    4   
4  /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2....  2009      9    9   
5  /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2....  2009      9   10   
6  /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2....  2009      9   10   
7  /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2....  2009      9   11   
8  /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2....  2009      9   15   
9  /home/jovyan/mne_data/TUH/tuh_eeg_abnormal/v2....  2009      9   15   

   subject  session  segment  age gender  \
0     4196        3        0   53      F   
1     5864        1        0   30      M   
2      929        3        2   39      F   
3     585

In [17]:
# create windows

window_size_samples = 1000
window_stride_samples = 1000
tuh_windows = create_fixed_length_windows(
    tuh,
    window_size_samples=window_size_samples,
    window_stride_samples=window_stride_samples,
    drop_last_window=False,
    n_jobs=N_JOBS,

)


Loading data for 351 events and 1000 original time points ...
Loading data for 297 events and 1000 original time points ...
Loading data for 287 events and 1000 original time points ...
Loading data for 756 events and 1000 original time points ...
Loading data for 753 events and 1000 original time points ...
Loading data for 366 events and 1000 original time points ...
Loading data for 384 events and 1000 original time points ...
Loading data for 312 events and 1000 original time points ...
0 bad epochs dropped
Loading data for 322 events and 1000 original time points ...
0 bad epochs dropped
Loading data for 301 events and 1000 original time points ...
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped
0 bad epochs dropped


In [18]:
len(tuh_windows)

4129

In [30]:

splitted = tuh_windows.split("train")
train_set = splitted['True']
valid_set = splitted['False']

cuda = False #torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
device = 'cuda' if cuda else 'cpu'
if cuda:
    torch.backends.cudnn.benchmark = True
seed = 20200220
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 2
# Extract number of chans and time steps from dataset
n_chans = train_set[0][0].shape[0]
input_window_samples = train_set[0][0].shape[1]

model = ShallowFBCSPNet(
    n_chans,
    n_classes,
    input_window_samples=input_window_samples,
    final_conv_length='auto',
)

# Send model to GPU
if cuda:
    model.cuda()


# These values we found good for shallow network:
lr = 0.0625 * 0.01
weight_decay = 0

# For deep4 they should be:
# lr = 1 * 0.01
# weight_decay = 0.5 * 0.001

batch_size = 64
n_epochs = 50


In [20]:
#Use EEGClassifier class for training
"""

from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split

from braindecode import EEGClassifier
clf = EEGClassifier(
    model,
    criterion=torch.nn.NLLLoss,
    optimizer=torch.optim.AdamW,
    train_split=predefined_split(valid_set),  # using valid_set for validation
    optimizer__lr=lr,
    optimizer__weight_decay=weight_decay,
    batch_size=batch_size,
    callbacks=[
        "accuracy", ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
    ],
    device=device,
)
# Model training for a specified number of epochs. `y` is None as it is already supplied
# in the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)
"""

'\n\nfrom skorch.callbacks import LRScheduler\nfrom skorch.helper import predefined_split\n\nfrom braindecode import EEGClassifier\nclf = EEGClassifier(\n    model,\n    criterion=torch.nn.NLLLoss,\n    optimizer=torch.optim.AdamW,\n    train_split=predefined_split(valid_set),  # using valid_set for validation\n    optimizer__lr=lr,\n    optimizer__weight_decay=weight_decay,\n    batch_size=batch_size,\n    callbacks=[\n        "accuracy", ("lr_scheduler", LRScheduler(\'CosineAnnealingLR\', T_max=n_epochs - 1)),\n    ],\n    device=device,\n)\n# Model training for a specified number of epochs. `y` is None as it is already supplied\n# in the dataset.\nclf.fit(train_set, y=None, epochs=n_epochs)\n'

In [21]:
## use pytorch 

In [42]:
import torch

dl = torch.utils.data.DataLoader(
    dataset=tuh_windows,
    batch_size=4,
)

loss_fn = torch.nn.NLLLoss()
optimizer = torch.optim.AdamW(model.parameters())

In [44]:
for x, y, z in dl:
    #print("x : ",x)
    x.to(device)
    y.to(device)
    y_hat = model(x)
    print(y.shape, y_hat.shape)
    loss = loss_fn(y_hat, y.long())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(loss)


torch.Size([4]) torch.Size([4, 2])
tensor(33.8681, grad_fn=<NllLossBackward0>)
torch.Size([4]) torch.Size([4, 2])
tensor(0., grad_fn=<NllLossBackward0>)
torch.Size([4]) torch.Size([4, 2])
tensor(0., grad_fn=<NllLossBackward0>)
torch.Size([4]) torch.Size([4, 2])
tensor(0., grad_fn=<NllLossBackward0>)
torch.Size([4]) torch.Size([4, 2])
tensor(0., grad_fn=<NllLossBackward0>)
torch.Size([4]) torch.Size([4, 2])
tensor(0., grad_fn=<NllLossBackward0>)
torch.Size([4]) torch.Size([4, 2])
tensor(7.1526e-07, grad_fn=<NllLossBackward0>)
torch.Size([4]) torch.Size([4, 2])
tensor(1.1206e-05, grad_fn=<NllLossBackward0>)
torch.Size([4]) torch.Size([4, 2])
tensor(0.0010, grad_fn=<NllLossBackward0>)
torch.Size([4]) torch.Size([4, 2])
tensor(0.6835, grad_fn=<NllLossBackward0>)
torch.Size([4]) torch.Size([4, 2])


KeyboardInterrupt: 