<a href="https://colab.research.google.com/github/seongcho1/mnetest/blob/main/eeg_motor_imagery_001.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://github.com/mne-tools/mne-torch/blob/master/demo_eeg_csp.py

In [1]:
import sklearn.model_selection
import sklearn.ensemble
import scipy.stats
import numpy as np

In [3]:
!pip install mne

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mne
  Downloading mne-1.3.0-py3-none-any.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m33.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: mne
Successfully installed mne-1.3.0


In [4]:
from sklearn.model_selection import ShuffleSplit
from mne import Epochs, pick_types, events_from_annotations
from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci

In [9]:
def get_data():
    tmin, tmax = -1., 4.
    event_id = dict(hands=2, feet=3)
    subject = 1
    runs = [6, 10, 14]  # motor imagery: hands vs feet

    raw_fnames = eegbci.load_data(subject, runs)
    raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])

    # strip channel names of "." characters
    raw.rename_channels(lambda x: x.strip('.'))

    # Apply band-pass filter
    raw.filter(7., 30., fir_design='firwin', skip_by_annotation='edge')

    events, _ = events_from_annotations(raw)

    picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                       exclude='bads')

    # Read epochs (train will be done only between 1 and 2s)
    # Testing will be done with a running classifier
    epochs = Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks,
                    baseline=None, preload=True)
    epochs.crop(tmin=1., tmax=None)
    labels = epochs.events[:, 2] - 2
    return epochs.get_data()[:, :, :256], labels


epochs_data, labels = get_data()
print(epochs_data.shape, labels.shape)

Extracting EDF parameters from /root/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R06.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /root/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R10.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /root/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R14.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Filtering raw data in 3 contiguous segments
Setting up band-pass filter from 7 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passb

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  64 out of  64 | elapsed:    0.2s finished


Used Annotations descriptions: ['T0', 'T1', 'T2']
Not setting metadata
45 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 45 events and 801 original time points ...
0 bad epochs dropped
(45, 64, 256) (45,)


In [14]:
import copy

import numpy as np

from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset  # noqa


class ConcatDataset(_ConcatDataset):
    """
    Same as torch.utils.data.dataset.ConcatDataset, but exposes an extra
    method for querying the group structure (index if dataset
    each sample comes from)
    """
    def get_groups(self):
        """Return the group index of each sample
        Returns
        -------
        groups : array of int, shape (n_samples,)
            The group indices.
        """
        groups = [k * np.ones(len(d)) for k, d in enumerate(self.datasets)]
        return np.concatenate(groups)


class EpochsDataset(Dataset):
    """Class to expose an MNE Epochs object as PyTorch dataset
    Parameters
    ----------
    epochs_data : 3d array, shape (n_epochs, n_channels, n_times)
        The epochs data.
    epochs_labels : array of int, shape (n_epochs,)
        The epochs labels.
    transform : callable | None
        The function to eventually apply to each epoch
        for preprocessing (e.g. scaling). Defaults to None.
    """
    def __init__(self, epochs_data, epochs_labels, transform=None):
        assert len(epochs_data) == len(epochs_labels)
        self.epochs_data = epochs_data
        self.epochs_labels = epochs_labels
        self.transform = transform

    def __len__(self):
        return len(self.epochs_labels)

    def __getitem__(self, idx):
        X, y = self.epochs_data[idx], self.epochs_labels[idx]
        if self.transform is not None:
            X = self.transform(X)
        X = torch.as_tensor(X[None, ...])
        return X, y


def _do_train(model, loader, optimizer, criterion, device):
    # training loop
    model.train()
    pbar = tqdm(loader)
    train_loss = np.zeros(len(loader))
    for idx_batch, (batch_x, batch_y) in enumerate(pbar):
        optimizer.zero_grad()
        batch_x = batch_x.to(device=device, dtype=torch.float32)
        batch_y = batch_y.to(device=device, dtype=torch.int64)

        output = model(batch_x)
        loss = criterion(output, batch_y)

        loss.backward()
        optimizer.step()

        train_loss[idx_batch] = loss.item()
        pbar.set_description(
            desc="avg train loss: {:.4f}".format(
                np.mean(train_loss[:idx_batch + 1])))


def _validate(model, loader, criterion, device):
    # validation loop
    pbar = tqdm(loader)
    val_loss = np.zeros(len(loader))
    accuracy = 0.
    with torch.no_grad():
        model.eval()

        for idx_batch, (batch_x, batch_y) in enumerate(pbar):
            batch_x = batch_x.to(device=device, dtype=torch.float32)
            batch_y = batch_y.to(device=device, dtype=torch.int64)
            output = model.forward(batch_x)

            loss = criterion(output, batch_y)
            val_loss[idx_batch] = loss.item()

            _, top_class = output.topk(1, dim=1)
            top_class = top_class.flatten()
            # print(top_class.shape, batch_y.shape)
            accuracy += \
                torch.sum((batch_y == top_class).to(torch.float32))

            pbar.set_description(
                desc="avg val loss: {:.4f}".format(
                    np.mean(val_loss[:idx_batch + 1])))

    accuracy = accuracy / len(loader.dataset)
    print("---  Accuracy : %s" % accuracy.item(), "\n")
    return np.mean(val_loss)


def train(model, loader_train, loader_valid, optimizer, n_epochs, patience,
          device):
    """Training function
    Parameters
    ----------
    model : instance of nn.Module
        The model.
    loader_train : instance of Sampler
        The generator of EEG samples the model has to train on.
        It contains n_train samples
    loader_valid : instance of Sampler
        The generator of EEG samples the model has to validate on.
        It contains n_val samples. The validation samples are used to
        monitor the training process and to perform early stopping
    optimizer : instance of optimizer
        The optimizer to use for training.
    n_epochs : int
        The maximum of epochs to run.
    patience : int
        The patience parameter, i.e. how long to wait for the
        validation error to go down.
    device : str | instance of torch.device
        The device to train the model on.
    Returns
    -------
    best_model : instance of nn.Module
        The model that lead to the best prediction on the validation
        dataset.
    """
    # put model on cuda if not already
    device = torch.device(device)
    # model.to(device)

    # define criterion
    criterion = F.nll_loss

    best_val_loss = + np.infty
    best_model = copy.deepcopy(model)
    waiting = 0

    for epoch in range(n_epochs):
        print("\nStarting epoch {} / {}".format(epoch + 1, n_epochs))
        _do_train(model, loader_train, optimizer, criterion, device)
        val_loss = _validate(model, loader_valid, criterion, device)

        # model saving
        if np.mean(val_loss) < best_val_loss:
            print("\nbest val loss {:.4f} -> {:.4f}".format(
                best_val_loss, np.mean(val_loss)))
            best_val_loss = np.mean(val_loss)
            best_model = copy.deepcopy(model)
            waiting = 0
        else:
            print("Waiting += 1")
            waiting += 1

        # model early stopping
        if waiting >= patience:
            print("Stop training at epoch {}".format(epoch + 1))
            print("Best val loss : {:.4f}".format(best_val_loss))
            break

    return best_model

In [17]:
# Classification with PyTorch CSP like model

import torch   # noqa
import torch.optim as optim  # noqa
from torch.utils.data import Dataset, DataLoader  # noqa
from torch.utils.data import Subset  # noqa
from torch import nn  # noqa
import torch.nn.functional as F  # noqa
from torch.utils.data import RandomSampler  # noqa
from torch.utils.data import SequentialSampler  # noqa

#from common import EpochsDataset  # noqa

cv = ShuffleSplit(10, test_size=0.2, random_state=42)
cv_split = cv.split(epochs_data)
train_idx, test_idx = next(cv_split)


def scale(X):
    """Standard scaling of data along the last dimention.
    Parameters
    ----------
    X : array, shape (n_channels, n_times)
        The input signals.
    Returns
    -------
    X_t : array, shape (n_channels, n_times)
        The scaled signals.
    """
    return X / 2e-5

dataset = EpochsDataset(epochs_data, labels, transform=scale)

ds_train, ds_valid = Subset(dataset, train_idx), Subset(dataset, test_idx)

batch_size_train = len(ds_train)
batch_size_valid = len(ds_valid)
sampler_train = RandomSampler(ds_train)
sampler_valid = SequentialSampler(ds_valid)

# create loaders
num_workers = 0
loader_train = \
    DataLoader(ds_train, batch_size=batch_size_train,
               num_workers=num_workers, sampler=sampler_train)
loader_valid = \
    DataLoader(ds_valid, batch_size=batch_size_valid,
               num_workers=num_workers, sampler=sampler_valid)

In [18]:
# Define the model


class CommonSpatialFilterModel(nn.Module):
    """The model implements a CSP-like network for BCI applications
    Parameters
    ----------
    spatial_dim : int
        Number of channels
    n_components : int
        The number of spatial filters.
    """
    def __init__(self, spatial_dim, n_components=5):
        super().__init__()
        self.spatial_dim = spatial_dim
        self.n_components = n_components

        # define model architecture
        self.spatial_filtering = nn.Conv2d(
            1, self.n_components, (self.spatial_dim, 1), bias=False)

        self.classifier = nn.Sequential(
            nn.Linear(n_components, 2),
        )

    def forward(self, x):
        x = self.spatial_filtering(x)
        x = torch.sum(x ** 2, dim=3)
        x = torch.log(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        x = F.log_softmax(x, dim=1)
        return x

# device = 'cuda'
device = 'cpu'
n_components = 30
model = CommonSpatialFilterModel(spatial_dim=epochs_data.shape[1],
                                 n_components=n_components)

# Test model works:
n_samples_test = 10
y_test = torch.randint(0, 2, (n_samples_test,))
y_pred = model.forward(torch.randn(n_samples_test, 1, *epochs_data.shape[1:]))
output = F.nll_loss(y_pred, y_test)
_, top_class = y_pred.topk(1, dim=1)


In [19]:

lr = 1e-3
n_epochs = 300
patience = 100

model.to(device=device)  # move to device before creating the optimizer
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9,
                      weight_decay=1e-4)

train(model, loader_train, loader_valid, optimizer, n_epochs, patience, device)


Starting epoch 1 / 300


avg train loss: 1.8085: 100%|██████████| 1/1 [00:00<00:00,  4.75it/s]
avg val loss: 2.4237: 100%|██████████| 1/1 [00:00<00:00, 97.12it/s]


---  Accuracy : 0.3333333432674408 


best val loss inf -> 2.4237

Starting epoch 2 / 300


avg train loss: 1.5866: 100%|██████████| 1/1 [00:00<00:00, 27.51it/s]
avg val loss: 1.7573: 100%|██████████| 1/1 [00:00<00:00, 132.49it/s]


---  Accuracy : 0.3333333432674408 


best val loss 2.4237 -> 1.7573

Starting epoch 3 / 300


avg train loss: 1.1969: 100%|██████████| 1/1 [00:00<00:00, 30.63it/s]
avg val loss: 0.9919: 100%|██████████| 1/1 [00:00<00:00, 119.60it/s]


---  Accuracy : 0.3333333432674408 


best val loss 1.7573 -> 0.9919

Starting epoch 4 / 300


avg train loss: 0.7999: 100%|██████████| 1/1 [00:00<00:00, 24.39it/s]
avg val loss: 0.5936: 100%|██████████| 1/1 [00:00<00:00, 110.86it/s]


---  Accuracy : 0.6666666865348816 


best val loss 0.9919 -> 0.5936

Starting epoch 5 / 300


avg train loss: 0.7693: 100%|██████████| 1/1 [00:00<00:00, 26.75it/s]
avg val loss: 0.6587: 100%|██████████| 1/1 [00:00<00:00, 94.19it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 6 / 300


avg train loss: 1.0971: 100%|██████████| 1/1 [00:00<00:00, 25.16it/s]
avg val loss: 0.7543: 100%|██████████| 1/1 [00:00<00:00, 68.53it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 7 / 300


avg train loss: 1.3041: 100%|██████████| 1/1 [00:00<00:00, 26.05it/s]
avg val loss: 0.7153: 100%|██████████| 1/1 [00:00<00:00, 118.20it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 8 / 300


avg train loss: 1.2188: 100%|██████████| 1/1 [00:00<00:00, 23.34it/s]
avg val loss: 0.6030: 100%|██████████| 1/1 [00:00<00:00, 91.26it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 9 / 300


avg train loss: 0.9275: 100%|██████████| 1/1 [00:00<00:00, 29.32it/s]
avg val loss: 0.6480: 100%|██████████| 1/1 [00:00<00:00, 71.56it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 10 / 300


avg train loss: 0.7125: 100%|██████████| 1/1 [00:00<00:00, 25.10it/s]
avg val loss: 0.9845: 100%|██████████| 1/1 [00:00<00:00, 120.83it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 11 / 300


avg train loss: 0.7834: 100%|██████████| 1/1 [00:00<00:00, 26.01it/s]
avg val loss: 1.3426: 100%|██████████| 1/1 [00:00<00:00, 83.21it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 12 / 300


avg train loss: 0.9533: 100%|██████████| 1/1 [00:00<00:00, 20.61it/s]
avg val loss: 1.4904: 100%|██████████| 1/1 [00:00<00:00, 114.52it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 13 / 300


avg train loss: 1.0306: 100%|██████████| 1/1 [00:00<00:00, 25.40it/s]
avg val loss: 1.3845: 100%|██████████| 1/1 [00:00<00:00, 156.60it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 14 / 300


avg train loss: 0.9724: 100%|██████████| 1/1 [00:00<00:00, 23.63it/s]
avg val loss: 1.0913: 100%|██████████| 1/1 [00:00<00:00, 97.58it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 15 / 300


avg train loss: 0.8240: 100%|██████████| 1/1 [00:00<00:00, 28.13it/s]
avg val loss: 0.7720: 100%|██████████| 1/1 [00:00<00:00, 76.36it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 16 / 300


avg train loss: 0.7031: 100%|██████████| 1/1 [00:00<00:00, 23.98it/s]
avg val loss: 0.6065: 100%|██████████| 1/1 [00:00<00:00, 119.33it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 17 / 300


avg train loss: 0.7259: 100%|██████████| 1/1 [00:00<00:00, 25.16it/s]
avg val loss: 0.5889: 100%|██████████| 1/1 [00:00<00:00, 113.67it/s]


---  Accuracy : 0.6666666865348816 


best val loss 0.5936 -> 0.5889

Starting epoch 18 / 300


avg train loss: 0.8320: 100%|██████████| 1/1 [00:00<00:00, 29.05it/s]
avg val loss: 0.5963: 100%|██████████| 1/1 [00:00<00:00, 120.16it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 19 / 300


avg train loss: 0.8712: 100%|██████████| 1/1 [00:00<00:00, 23.18it/s]
avg val loss: 0.5866: 100%|██████████| 1/1 [00:00<00:00, 115.32it/s]


---  Accuracy : 0.6666666865348816 


best val loss 0.5889 -> 0.5866

Starting epoch 20 / 300


avg train loss: 0.8023: 100%|██████████| 1/1 [00:00<00:00, 21.90it/s]
avg val loss: 0.6155: 100%|██████████| 1/1 [00:00<00:00, 118.52it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 21 / 300


avg train loss: 0.7066: 100%|██████████| 1/1 [00:00<00:00, 26.58it/s]
avg val loss: 0.7423: 100%|██████████| 1/1 [00:00<00:00, 114.59it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 22 / 300


avg train loss: 0.6860: 100%|██████████| 1/1 [00:00<00:00, 26.43it/s]
avg val loss: 0.9144: 100%|██████████| 1/1 [00:00<00:00, 122.58it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 23 / 300


avg train loss: 0.7355: 100%|██████████| 1/1 [00:00<00:00, 23.86it/s]
avg val loss: 1.0189: 100%|██████████| 1/1 [00:00<00:00, 85.59it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 24 / 300


avg train loss: 0.7776: 100%|██████████| 1/1 [00:00<00:00, 24.73it/s]
avg val loss: 0.9993: 100%|██████████| 1/1 [00:00<00:00, 115.07it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 25 / 300


avg train loss: 0.7675: 100%|██████████| 1/1 [00:00<00:00, 27.05it/s]
avg val loss: 0.8781: 100%|██████████| 1/1 [00:00<00:00, 104.25it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 26 / 300


avg train loss: 0.7171: 100%|██████████| 1/1 [00:00<00:00, 25.65it/s]
avg val loss: 0.7309: 100%|██████████| 1/1 [00:00<00:00, 101.59it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 27 / 300


avg train loss: 0.6748: 100%|██████████| 1/1 [00:00<00:00, 24.01it/s]
avg val loss: 0.6324: 100%|██████████| 1/1 [00:00<00:00, 112.20it/s]


---  Accuracy : 0.7777777910232544 

Waiting += 1

Starting epoch 28 / 300


avg train loss: 0.6790: 100%|██████████| 1/1 [00:00<00:00, 23.22it/s]
avg val loss: 0.5955: 100%|██████████| 1/1 [00:00<00:00, 75.98it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 29 / 300


avg train loss: 0.7114: 100%|██████████| 1/1 [00:00<00:00, 29.55it/s]
avg val loss: 0.5894: 100%|██████████| 1/1 [00:00<00:00, 93.35it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 30 / 300


avg train loss: 0.7238: 100%|██████████| 1/1 [00:00<00:00, 23.35it/s]
avg val loss: 0.5987: 100%|██████████| 1/1 [00:00<00:00, 99.45it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 31 / 300


avg train loss: 0.7001: 100%|██████████| 1/1 [00:00<00:00, 23.33it/s]
avg val loss: 0.6363: 100%|██████████| 1/1 [00:00<00:00, 94.05it/s]


---  Accuracy : 0.7777777910232544 

Waiting += 1

Starting epoch 32 / 300


avg train loss: 0.6685: 100%|██████████| 1/1 [00:00<00:00, 32.64it/s]
avg val loss: 0.7096: 100%|██████████| 1/1 [00:00<00:00, 96.87it/s]


---  Accuracy : 0.4444444477558136 

Waiting += 1

Starting epoch 33 / 300


avg train loss: 0.6611: 100%|██████████| 1/1 [00:00<00:00, 18.21it/s]
avg val loss: 0.7905: 100%|██████████| 1/1 [00:00<00:00, 106.79it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 34 / 300


avg train loss: 0.6761: 100%|██████████| 1/1 [00:00<00:00, 19.65it/s]
avg val loss: 0.8348: 100%|██████████| 1/1 [00:00<00:00, 115.80it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 35 / 300


avg train loss: 0.6883: 100%|██████████| 1/1 [00:00<00:00, 19.39it/s]
avg val loss: 0.8203: 100%|██████████| 1/1 [00:00<00:00, 93.93it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 36 / 300


avg train loss: 0.6821: 100%|██████████| 1/1 [00:00<00:00, 20.68it/s]
avg val loss: 0.7602: 100%|██████████| 1/1 [00:00<00:00, 76.50it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 37 / 300


avg train loss: 0.6635: 100%|██████████| 1/1 [00:00<00:00, 28.47it/s]
avg val loss: 0.6895: 100%|██████████| 1/1 [00:00<00:00, 158.21it/s]


---  Accuracy : 0.4444444477558136 

Waiting += 1

Starting epoch 38 / 300


avg train loss: 0.6508: 100%|██████████| 1/1 [00:00<00:00, 23.79it/s]
avg val loss: 0.6380: 100%|██████████| 1/1 [00:00<00:00, 129.22it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 39 / 300


avg train loss: 0.6536: 100%|██████████| 1/1 [00:00<00:00, 23.64it/s]
avg val loss: 0.6133: 100%|██████████| 1/1 [00:00<00:00, 132.99it/s]


---  Accuracy : 0.8888888955116272 

Waiting += 1

Starting epoch 40 / 300


avg train loss: 0.6628: 100%|██████████| 1/1 [00:00<00:00, 24.38it/s]
avg val loss: 0.6092: 100%|██████████| 1/1 [00:00<00:00, 119.17it/s]


---  Accuracy : 0.7777777910232544 

Waiting += 1

Starting epoch 41 / 300


avg train loss: 0.6636: 100%|██████████| 1/1 [00:00<00:00, 20.19it/s]
avg val loss: 0.6220: 100%|██████████| 1/1 [00:00<00:00, 82.83it/s]


---  Accuracy : 0.7777777910232544 

Waiting += 1

Starting epoch 42 / 300


avg train loss: 0.6536: 100%|██████████| 1/1 [00:00<00:00, 28.58it/s]
avg val loss: 0.6525: 100%|██████████| 1/1 [00:00<00:00, 102.57it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 43 / 300


avg train loss: 0.6433: 100%|██████████| 1/1 [00:00<00:00, 22.84it/s]
avg val loss: 0.6950: 100%|██████████| 1/1 [00:00<00:00, 84.84it/s]


---  Accuracy : 0.4444444477558136 

Waiting += 1

Starting epoch 44 / 300


avg train loss: 0.6415: 100%|██████████| 1/1 [00:00<00:00, 27.92it/s]
avg val loss: 0.7328: 100%|██████████| 1/1 [00:00<00:00, 101.93it/s]


---  Accuracy : 0.4444444477558136 

Waiting += 1

Starting epoch 45 / 300


avg train loss: 0.6459: 100%|██████████| 1/1 [00:00<00:00, 24.76it/s]
avg val loss: 0.7482: 100%|██████████| 1/1 [00:00<00:00, 95.07it/s]


---  Accuracy : 0.3333333432674408 

Waiting += 1

Starting epoch 46 / 300


avg train loss: 0.6480: 100%|██████████| 1/1 [00:00<00:00, 25.57it/s]
avg val loss: 0.7350: 100%|██████████| 1/1 [00:00<00:00, 86.38it/s]


---  Accuracy : 0.4444444477558136 

Waiting += 1

Starting epoch 47 / 300


avg train loss: 0.6437: 100%|██████████| 1/1 [00:00<00:00, 27.05it/s]
avg val loss: 0.7021: 100%|██████████| 1/1 [00:00<00:00, 80.10it/s]


---  Accuracy : 0.4444444477558136 

Waiting += 1

Starting epoch 48 / 300


avg train loss: 0.6366: 100%|██████████| 1/1 [00:00<00:00, 19.85it/s]
avg val loss: 0.6657: 100%|██████████| 1/1 [00:00<00:00, 99.37it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 49 / 300


avg train loss: 0.6327: 100%|██████████| 1/1 [00:00<00:00, 23.64it/s]
avg val loss: 0.6385: 100%|██████████| 1/1 [00:00<00:00, 131.01it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 50 / 300


avg train loss: 0.6336: 100%|██████████| 1/1 [00:00<00:00, 18.76it/s]
avg val loss: 0.6249: 100%|██████████| 1/1 [00:00<00:00, 85.87it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 51 / 300


avg train loss: 0.6354: 100%|██████████| 1/1 [00:00<00:00, 27.84it/s]
avg val loss: 0.6246: 100%|██████████| 1/1 [00:00<00:00, 110.57it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 52 / 300


avg train loss: 0.6337: 100%|██████████| 1/1 [00:00<00:00, 21.45it/s]
avg val loss: 0.6358: 100%|██████████| 1/1 [00:00<00:00, 126.72it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 53 / 300


avg train loss: 0.6292: 100%|██████████| 1/1 [00:00<00:00, 23.55it/s]
avg val loss: 0.6559: 100%|██████████| 1/1 [00:00<00:00, 100.21it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 54 / 300


avg train loss: 0.6255: 100%|██████████| 1/1 [00:00<00:00, 27.43it/s]
avg val loss: 0.6788: 100%|██████████| 1/1 [00:00<00:00, 107.75it/s]


---  Accuracy : 0.4444444477558136 

Waiting += 1

Starting epoch 55 / 300


avg train loss: 0.6247: 100%|██████████| 1/1 [00:00<00:00, 18.81it/s]
avg val loss: 0.6954: 100%|██████████| 1/1 [00:00<00:00, 119.70it/s]


---  Accuracy : 0.4444444477558136 

Waiting += 1

Starting epoch 56 / 300


avg train loss: 0.6253: 100%|██████████| 1/1 [00:00<00:00, 25.18it/s]
avg val loss: 0.6989: 100%|██████████| 1/1 [00:00<00:00, 123.84it/s]


---  Accuracy : 0.4444444477558136 

Waiting += 1

Starting epoch 57 / 300


avg train loss: 0.6245: 100%|██████████| 1/1 [00:00<00:00, 24.74it/s]
avg val loss: 0.6884: 100%|██████████| 1/1 [00:00<00:00, 110.61it/s]


---  Accuracy : 0.4444444477558136 

Waiting += 1

Starting epoch 58 / 300


avg train loss: 0.6218: 100%|██████████| 1/1 [00:00<00:00, 23.18it/s]
avg val loss: 0.6694: 100%|██████████| 1/1 [00:00<00:00, 87.63it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 59 / 300


avg train loss: 0.6187: 100%|██████████| 1/1 [00:00<00:00, 23.23it/s]
avg val loss: 0.6497: 100%|██████████| 1/1 [00:00<00:00, 97.35it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 60 / 300


avg train loss: 0.6170: 100%|██████████| 1/1 [00:00<00:00, 27.59it/s]
avg val loss: 0.6353: 100%|██████████| 1/1 [00:00<00:00, 98.15it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 61 / 300


avg train loss: 0.6167: 100%|██████████| 1/1 [00:00<00:00, 22.76it/s]
avg val loss: 0.6287: 100%|██████████| 1/1 [00:00<00:00, 152.27it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 62 / 300


avg train loss: 0.6161: 100%|██████████| 1/1 [00:00<00:00, 18.67it/s]
avg val loss: 0.6300: 100%|██████████| 1/1 [00:00<00:00, 133.12it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 63 / 300


avg train loss: 0.6144: 100%|██████████| 1/1 [00:00<00:00, 21.93it/s]
avg val loss: 0.6380: 100%|██████████| 1/1 [00:00<00:00, 114.23it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 64 / 300


avg train loss: 0.6120: 100%|██████████| 1/1 [00:00<00:00, 23.03it/s]
avg val loss: 0.6498: 100%|██████████| 1/1 [00:00<00:00, 111.40it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 65 / 300


avg train loss: 0.6101: 100%|██████████| 1/1 [00:00<00:00, 23.26it/s]
avg val loss: 0.6611: 100%|██████████| 1/1 [00:00<00:00, 129.27it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 66 / 300


avg train loss: 0.6091: 100%|██████████| 1/1 [00:00<00:00, 25.74it/s]
avg val loss: 0.6676: 100%|██████████| 1/1 [00:00<00:00, 98.25it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 67 / 300


avg train loss: 0.6084: 100%|██████████| 1/1 [00:00<00:00, 24.56it/s]
avg val loss: 0.6666: 100%|██████████| 1/1 [00:00<00:00, 103.13it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 68 / 300


avg train loss: 0.6071: 100%|██████████| 1/1 [00:00<00:00, 23.82it/s]
avg val loss: 0.6588: 100%|██████████| 1/1 [00:00<00:00, 93.02it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 69 / 300


avg train loss: 0.6052: 100%|██████████| 1/1 [00:00<00:00, 18.27it/s]
avg val loss: 0.6476: 100%|██████████| 1/1 [00:00<00:00, 74.12it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 70 / 300


avg train loss: 0.6033: 100%|██████████| 1/1 [00:00<00:00, 23.01it/s]
avg val loss: 0.6367: 100%|██████████| 1/1 [00:00<00:00, 153.54it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 71 / 300


avg train loss: 0.6021: 100%|██████████| 1/1 [00:00<00:00, 23.16it/s]
avg val loss: 0.6292: 100%|██████████| 1/1 [00:00<00:00, 138.19it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 72 / 300


avg train loss: 0.6011: 100%|██████████| 1/1 [00:00<00:00, 25.97it/s]
avg val loss: 0.6262: 100%|██████████| 1/1 [00:00<00:00, 94.39it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 73 / 300


avg train loss: 0.5999: 100%|██████████| 1/1 [00:00<00:00, 25.16it/s]
avg val loss: 0.6278: 100%|██████████| 1/1 [00:00<00:00, 115.51it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 74 / 300


avg train loss: 0.5983: 100%|██████████| 1/1 [00:00<00:00, 24.02it/s]
avg val loss: 0.6328: 100%|██████████| 1/1 [00:00<00:00, 132.10it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 75 / 300


avg train loss: 0.5966: 100%|██████████| 1/1 [00:00<00:00, 23.05it/s]
avg val loss: 0.6391: 100%|██████████| 1/1 [00:00<00:00, 130.54it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 76 / 300


avg train loss: 0.5952: 100%|██████████| 1/1 [00:00<00:00, 23.94it/s]
avg val loss: 0.6442: 100%|██████████| 1/1 [00:00<00:00, 107.89it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 77 / 300


avg train loss: 0.5940: 100%|██████████| 1/1 [00:00<00:00, 21.84it/s]
avg val loss: 0.6460: 100%|██████████| 1/1 [00:00<00:00, 121.63it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 78 / 300


avg train loss: 0.5928: 100%|██████████| 1/1 [00:00<00:00, 24.58it/s]
avg val loss: 0.6438: 100%|██████████| 1/1 [00:00<00:00, 132.88it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 79 / 300


avg train loss: 0.5914: 100%|██████████| 1/1 [00:00<00:00, 24.90it/s]
avg val loss: 0.6384: 100%|██████████| 1/1 [00:00<00:00, 144.88it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 80 / 300


avg train loss: 0.5899: 100%|██████████| 1/1 [00:00<00:00, 27.99it/s]
avg val loss: 0.6317: 100%|██████████| 1/1 [00:00<00:00, 118.58it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 81 / 300


avg train loss: 0.5884: 100%|██████████| 1/1 [00:00<00:00, 24.75it/s]
avg val loss: 0.6257: 100%|██████████| 1/1 [00:00<00:00, 115.83it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 82 / 300


avg train loss: 0.5871: 100%|██████████| 1/1 [00:00<00:00, 25.48it/s]
avg val loss: 0.6218: 100%|██████████| 1/1 [00:00<00:00, 109.29it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 83 / 300


avg train loss: 0.5859: 100%|██████████| 1/1 [00:00<00:00, 25.66it/s]
avg val loss: 0.6206: 100%|██████████| 1/1 [00:00<00:00, 103.87it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 84 / 300


avg train loss: 0.5846: 100%|██████████| 1/1 [00:00<00:00, 24.68it/s]
avg val loss: 0.6219: 100%|██████████| 1/1 [00:00<00:00, 134.77it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 85 / 300


avg train loss: 0.5831: 100%|██████████| 1/1 [00:00<00:00, 19.78it/s]
avg val loss: 0.6248: 100%|██████████| 1/1 [00:00<00:00, 101.85it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 86 / 300


avg train loss: 0.5817: 100%|██████████| 1/1 [00:00<00:00, 24.00it/s]
avg val loss: 0.6278: 100%|██████████| 1/1 [00:00<00:00, 123.64it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 87 / 300


avg train loss: 0.5804: 100%|██████████| 1/1 [00:00<00:00, 23.95it/s]
avg val loss: 0.6297: 100%|██████████| 1/1 [00:00<00:00, 137.31it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 88 / 300


avg train loss: 0.5791: 100%|██████████| 1/1 [00:00<00:00, 22.09it/s]
avg val loss: 0.6296: 100%|██████████| 1/1 [00:00<00:00, 127.69it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 89 / 300


avg train loss: 0.5779: 100%|██████████| 1/1 [00:00<00:00, 25.69it/s]
avg val loss: 0.6274: 100%|██████████| 1/1 [00:00<00:00, 108.44it/s]


---  Accuracy : 0.5555555820465088 

Waiting += 1

Starting epoch 90 / 300


avg train loss: 0.5765: 100%|██████████| 1/1 [00:00<00:00, 27.57it/s]
avg val loss: 0.6237: 100%|██████████| 1/1 [00:00<00:00, 109.88it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 91 / 300


avg train loss: 0.5751: 100%|██████████| 1/1 [00:00<00:00, 24.27it/s]
avg val loss: 0.6196: 100%|██████████| 1/1 [00:00<00:00, 143.44it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 92 / 300


avg train loss: 0.5738: 100%|██████████| 1/1 [00:00<00:00, 25.80it/s]
avg val loss: 0.6163: 100%|██████████| 1/1 [00:00<00:00, 133.90it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 93 / 300


avg train loss: 0.5726: 100%|██████████| 1/1 [00:00<00:00, 25.26it/s]
avg val loss: 0.6143: 100%|██████████| 1/1 [00:00<00:00, 139.82it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 94 / 300


avg train loss: 0.5713: 100%|██████████| 1/1 [00:00<00:00, 22.53it/s]
avg val loss: 0.6138: 100%|██████████| 1/1 [00:00<00:00, 143.59it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 95 / 300


avg train loss: 0.5700: 100%|██████████| 1/1 [00:00<00:00, 24.18it/s]
avg val loss: 0.6146: 100%|██████████| 1/1 [00:00<00:00, 139.16it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 96 / 300


avg train loss: 0.5687: 100%|██████████| 1/1 [00:00<00:00, 26.29it/s]
avg val loss: 0.6159: 100%|██████████| 1/1 [00:00<00:00, 115.73it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 97 / 300


avg train loss: 0.5674: 100%|██████████| 1/1 [00:00<00:00, 25.26it/s]
avg val loss: 0.6170: 100%|██████████| 1/1 [00:00<00:00, 139.70it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 98 / 300


avg train loss: 0.5662: 100%|██████████| 1/1 [00:00<00:00, 25.14it/s]
avg val loss: 0.6173: 100%|██████████| 1/1 [00:00<00:00, 136.59it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 99 / 300


avg train loss: 0.5649: 100%|██████████| 1/1 [00:00<00:00, 25.96it/s]
avg val loss: 0.6163: 100%|██████████| 1/1 [00:00<00:00, 101.81it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 100 / 300


avg train loss: 0.5637: 100%|██████████| 1/1 [00:00<00:00, 22.48it/s]
avg val loss: 0.6143: 100%|██████████| 1/1 [00:00<00:00, 122.59it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 101 / 300


avg train loss: 0.5624: 100%|██████████| 1/1 [00:00<00:00, 23.36it/s]
avg val loss: 0.6117: 100%|██████████| 1/1 [00:00<00:00, 113.16it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 102 / 300


avg train loss: 0.5611: 100%|██████████| 1/1 [00:00<00:00, 23.98it/s]
avg val loss: 0.6091: 100%|██████████| 1/1 [00:00<00:00, 111.64it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 103 / 300


avg train loss: 0.5599: 100%|██████████| 1/1 [00:00<00:00, 27.27it/s]
avg val loss: 0.6071: 100%|██████████| 1/1 [00:00<00:00, 128.53it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 104 / 300


avg train loss: 0.5587: 100%|██████████| 1/1 [00:00<00:00, 23.89it/s]
avg val loss: 0.6059: 100%|██████████| 1/1 [00:00<00:00, 100.83it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 105 / 300


avg train loss: 0.5575: 100%|██████████| 1/1 [00:00<00:00, 21.49it/s]
avg val loss: 0.6055: 100%|██████████| 1/1 [00:00<00:00, 98.19it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 106 / 300


avg train loss: 0.5562: 100%|██████████| 1/1 [00:00<00:00, 21.77it/s]
avg val loss: 0.6056: 100%|██████████| 1/1 [00:00<00:00, 136.94it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 107 / 300


avg train loss: 0.5550: 100%|██████████| 1/1 [00:00<00:00, 24.87it/s]
avg val loss: 0.6059: 100%|██████████| 1/1 [00:00<00:00, 139.94it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 108 / 300


avg train loss: 0.5538: 100%|██████████| 1/1 [00:00<00:00, 24.91it/s]
avg val loss: 0.6060: 100%|██████████| 1/1 [00:00<00:00, 106.90it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 109 / 300


avg train loss: 0.5525: 100%|██████████| 1/1 [00:00<00:00, 21.98it/s]
avg val loss: 0.6054: 100%|██████████| 1/1 [00:00<00:00, 134.84it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 110 / 300


avg train loss: 0.5513: 100%|██████████| 1/1 [00:00<00:00, 24.31it/s]
avg val loss: 0.6042: 100%|██████████| 1/1 [00:00<00:00, 136.30it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 111 / 300


avg train loss: 0.5501: 100%|██████████| 1/1 [00:00<00:00, 23.82it/s]
avg val loss: 0.6025: 100%|██████████| 1/1 [00:00<00:00, 106.88it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 112 / 300


avg train loss: 0.5489: 100%|██████████| 1/1 [00:00<00:00, 24.76it/s]
avg val loss: 0.6005: 100%|██████████| 1/1 [00:00<00:00, 116.25it/s]


---  Accuracy : 0.6666666865348816 

Waiting += 1

Starting epoch 113 / 300


avg train loss: 0.5477: 100%|██████████| 1/1 [00:00<00:00, 24.82it/s]
avg val loss: 0.5987: 100%|██████████| 1/1 [00:00<00:00, 84.14it/s]


---  Accuracy : 0.7777777910232544 

Waiting += 1

Starting epoch 114 / 300


avg train loss: 0.5465: 100%|██████████| 1/1 [00:00<00:00, 26.19it/s]
avg val loss: 0.5973: 100%|██████████| 1/1 [00:00<00:00, 97.52it/s]


---  Accuracy : 0.7777777910232544 

Waiting += 1

Starting epoch 115 / 300


avg train loss: 0.5453: 100%|██████████| 1/1 [00:00<00:00, 20.87it/s]
avg val loss: 0.5963: 100%|██████████| 1/1 [00:00<00:00, 131.53it/s]


---  Accuracy : 0.7777777910232544 

Waiting += 1

Starting epoch 116 / 300


avg train loss: 0.5441: 100%|██████████| 1/1 [00:00<00:00, 22.47it/s]
avg val loss: 0.5958: 100%|██████████| 1/1 [00:00<00:00, 117.86it/s]


---  Accuracy : 0.7777777910232544 

Waiting += 1

Starting epoch 117 / 300


avg train loss: 0.5429: 100%|██████████| 1/1 [00:00<00:00, 23.92it/s]
avg val loss: 0.5955: 100%|██████████| 1/1 [00:00<00:00, 108.76it/s]


---  Accuracy : 0.7777777910232544 

Waiting += 1

Starting epoch 118 / 300


avg train loss: 0.5417: 100%|██████████| 1/1 [00:00<00:00, 24.34it/s]
avg val loss: 0.5952: 100%|██████████| 1/1 [00:00<00:00, 123.32it/s]


---  Accuracy : 0.7777777910232544 

Waiting += 1

Starting epoch 119 / 300


avg train loss: 0.5405: 100%|██████████| 1/1 [00:00<00:00, 22.24it/s]
avg val loss: 0.5946: 100%|██████████| 1/1 [00:00<00:00, 112.10it/s]

---  Accuracy : 0.7777777910232544 

Waiting += 1
Stop training at epoch 119
Best val loss : 0.5866





CommonSpatialFilterModel(
  (spatial_filtering): Conv2d(1, 30, kernel_size=(64, 1), stride=(1, 1), bias=False)
  (classifier): Sequential(
    (0): Linear(in_features=30, out_features=2, bias=True)
  )
)