In [40]:
import random
import time
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy as sk
import pickle
import mne
import tqdm
from tqdm import trange

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import torch_directml
DEVICE = torch_directml.device()
from torch import nn

In [41]:
data=np.load('/home/pigmaster96/openneuroNMAdata/ds005540-download/derivatives/sub-01/ses-vid/eeg/sub-01_ses-vid_task-emotion_reorder.npy',allow_pickle=True)
data=np.permute_dims(data,[1,0,2])
data.shape


(21, 64, 6000)

In [42]:
#feature conversion from preprocessed signal to power spectrum, welch method with default params
data=np.load('/home/pigmaster96/openneuroNMAdata/ds005540-download/derivatives/sub-01/ses-vid/eeg/sub-01_ses-vid_task-emotion_reorder.npy',allow_pickle=True)
data=np.permute_dims(data,[1,0,2])
data.shape

cnames=['Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3','AF4','AF8', 'F7', 'F5','F3','F1','Fz', 'F2', 'F4', 'F6', 'F8',
'FT7', 'FC5', 'FC3', 'FC1','FCz','FC2','FC4', 'FC6', 'FT8', 'T7','C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'T8',
'TP7', 'CP5', 'CP3', 'CP1','CPz','CP2', 'CP4','CP6', 'TP8', 'P7','P5', 'P3', 'P1', 'Pz','P2', 'P4', 'P6', 'P8',
'PO7', 'PO3','POz', 'PO4','PO8', 'O1','Oz','O2', 'F9', 'F10', 'TP9', 'TP10']
info=mne.create_info(cnames,200,'eeg')

def eeg_to_spectrum(eeg,info):
    epochs=mne.EpochsArray(data,info)
    out=epochs.compute_psd(method='welch',fmin=0.1,fmax=47,verbose=None).get_data()
    return out

temp=eeg_to_spectrum(data,info=info)
temp.shape

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


(21, 64, 480)

In [43]:
cnames=['Fp1', 'Fpz', 'Fp2', 'AF7', 'AF3','AF4','AF8', 'F7', 'F5','F3','F1','Fz', 'F2', 'F4', 'F6', 'F8',
'FT7', 'FC5', 'FC3', 'FC1','FCz','FC2','FC4', 'FC6', 'FT8', 'T7','C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'T8',
'TP7', 'CP5', 'CP3', 'CP1','CPz','CP2', 'CP4','CP6', 'TP8', 'P7','P5', 'P3', 'P1', 'Pz','P2', 'P4', 'P6', 'P8',
'PO7', 'PO3','POz', 'PO4','PO8', 'O1','Oz','O2', 'F9', 'F10', 'TP9', 'TP10']
info=mne.create_info(cnames,200,'eeg')

#feature conversion from preprocessed signal to power spectrum, welch method with default params
def eeg_to_spectrum(eeg,info):
    epochs=mne.EpochsArray(eeg,info)
    out=epochs.compute_psd(method='welch',fmin=0.1,fmax=47,verbose=None).get_data()
    return out


#now create input vector
#14/7/25: Kernel keeps crashing when iterating from subject 1 to 54 in one go,
#this doesn't happen when iterating from 1-25 or 30-54... maybe the vector is too large?
def extract_data(sub_range,verbose=False,feature_func=None,**kwargs):
    first=sub_range[0]
    last=sub_range[-1]

    for sub in trange(first,last+1): 
        #read data
        if sub <=9:
            path=f'/home/pigmaster96/openneuroNMAdata/ds005540-download/derivatives/sub-0{sub}/ses-vid/eeg/sub-0{sub}_ses-vid_task-emotion_reorder.npy'
        else:
            path=f'/home/pigmaster96/openneuroNMAdata/ds005540-download/derivatives/sub-{sub}/ses-vid/eeg/sub-{sub}_ses-vid_task-emotion_reorder.npy'
        datatemp=np.load(path,allow_pickle=True)
        datatemp=np.permute_dims(datatemp,[1,0,2])

        #if we want to convert features
        if not feature_func==None:
            datatemp=feature_func(datatemp,info=info)

        #cat data along new dimension
        datatemp=torch.tensor(datatemp).unsqueeze(0)
        if sub==first:
            data=datatemp
        else:
            data=torch.cat([data,datatemp],axis=0)

        if verbose:
            print(f'subject {sub} data extracted, vector size: {data.shape}')
    return data


#create labels
#sad-dis-fear-neu-joy-ten-ins correspond to 1-7 respectively, each sample has 21 trials, for (7 emotions x 3 trials)
labels=np.array([])
for i in range(1,8):
    for n in range(0,3):
        labels=np.concatenate([labels,np.array([i])],axis=0)


#extract one CNN input:

In [44]:
#MLP with power spectrum and channels
class Net(nn.Module):
    def __init__(self, actv, input_feature_num, hidden_unit_nums, output_feature_num):
        """
        Initialize MLP Network parameters

        Args:
        actv: string
            Activation function
        input_feature_num: int
            Number of input features
        hidden_unit_nums: list
            Number of units per hidden layer. List of integers
        output_feature_num: int
            Number of output features

        Returns:
        Nothing
        """
        super(Net,self).__init__()
        self.input_feature_num=input_feature_num
        self.mlp=nn.Sequential()

        in_num=input_feature_num #input at any given point in time
        for i in range(len(hidden_unit_nums)):
            out_num=hidden_unit_nums[i]
            layer=nn.Linear(in_num,out_num) #define linear layer
            in_num=out_num #update input to next layer
            self.mlp.add_module(f"Linear_{i}",layer) #add to sequential

            actv_layer=eval(f"nn.{actv}")
            self.mlp.add_module(f"Activation_{i}",actv_layer) #append activation

        out_layer=nn.Linear(in_num,output_feature_num) #create output layer
        self.mlp.add_module(f"Output_linear", out_layer) #append output layer

    def forward(self,x):
        """
        Simulate forward pass of MLP Network

        Args:
        x: torch.tensor
            Input data

        Returns:
        logits: Instance of MLP
            Forward pass of MLP
        """
        logits=self.mlp(x)
        return logits


def train_test_classification(net, criterion, optimizer, train_loader,
                              test_loader, num_epochs=1, verbose=True,
                              training_plot=False, device='cpu'):
    """
    Accumulate training loss/Evaluate performance

    Args:
        net: Instance of Net class
        Describes the model with ReLU activation, batch size 128
        criterion: torch.nn type
        Criterion combines LogSoftmax and NLLLoss in one single class.
        optimizer: torch.optim type
        Implements Adam algorithm.
        train_loader: torch.utils.data type
        Combines the train dataset and sampler, and provides an iterable over the given dataset.
        test_loader: torch.utils.data type
        Combines the test dataset and sampler, and provides an iterable over the given dataset.
        num_epochs: int
        Number of epochs [default: 1]
        verbose: boolean
        If True, print statistics
        training_plot=False
        If True, display training plot
        device: string
        CUDA/GPU if available, CPU otherwise

    Returns:
        Nothing
    """
    net.to(device)
    net.train()
    training_losses=[]
    for epoch in trange(num_epochs): #loop over epochs
        for i, data in enumerate(train_loader, start=0):
            # Get the inputs; data is a list of [inputs, labels]
            inputs,labels=data
            inputs=inputs.to(device).float()
            labels=labels.to(device).long()

            #zero gradients
            optimizer.zero_grad()

            #forward pass, backprop, optimize
            outputs=net.forward(inputs)
            
            loss=criterion(outputs,labels)
            loss.backward()
            optimizer.step()

            # Print statistics
            if verbose:
                training_losses += [loss.item()]
    net.eval()

    def test(data_loader):
        """
        Function to gauge network performance

        Args:
        data_loader: torch.utils.data type
            Combines the test dataset and sampler, and provides an iterable over the given dataset.

        Returns:
        acc: float
            Performance of the network
        total: int
            Number of datapoints in the dataloader
        """
        correct=0
        total=0
        for data in data_loader:
            inputs,labels=data
            inputs = inputs.to(device).float()
            labels = labels.to(device).long()

            outputs=net(inputs)
            _,predicted=torch.max(outputs,1)
            total+=labels.size(0)
            correct+=(predicted==labels).sum()
        
        acc=100*correct/total
        return total,acc
    
    train_total,train_acc=test(train_loader)
    test_total,test_acc=test(test_loader)

    if verbose:
        print(f'\nAccuracy on the {train_total} training samples: {train_acc:0.2f}')
        print(f'Accuracy on the {test_total} testing samples: {test_acc:0.2f}\n')

    if training_plot:
        plt.plot(training_losses)
        plt.xlabel('Batch')
        plt.ylabel('Training loss')
        plt.show()
    return train_acc,test_acc

def shuffle_and_split_data(X,y,seed):
    """
    Helper function to shuffle and split data

    Args:
        X: torch.tensor
        Input data
        y: torch.tensor
        Corresponding target variables
        seed: int
        Set seed for reproducibility

    Returns:
        X_test: torch.tensor
        Test data [20% of X]
        y_test: torch.tensor
        Labels corresponding to above mentioned test data
        X_train: torch.tensor
        Train data [80% of X]
        y_train: torch.tensor
        Labels corresponding to above mentioned train data
    """
    torch.manual_seed(seed)
    N=X.size(0)
    shuffled_indices=torch.randperm(N) #get shuffled indices
    X=X[shuffled_indices]
    y=y[shuffled_indices]

    # split by 20% into train-test set
    test_size=int(0.2*N)
    X_train=X[test_size:]
    y_train=y[test_size:]
    X_test=X[:test_size]
    y_test=y[:test_size]

    return X_test,y_test,X_train,y_train


def set_seed(seed=None):
  """
  Function that controls randomness. NumPy and random modules must be imported.

  Args:
    seed : Integer
      A non-negative integer that defines the random state. Default is `None`.
    seed_torch : Boolean
      If `True` sets the random seed for pytorch tensors, so pytorch module
      must be imported. Default is `True`.

  Returns:
    Nothing.
  """
  if seed is None:
    seed = np.random.choice(2 ** 32)
  random.seed(seed)
  np.random.seed(seed)

  print(f'Random seed {seed} has been set.')

# In case that `DataLoader` is used
def seed_worker(worker_id):
  """
  DataLoader will reseed workers following randomness in
  multi-process data loading algorithm.

  Args:
    worker_id: integer
      ID of subprocess to seed. 0 means that
      the data will be loaded in the main process
      Refer: https://pytorch.org/docs/stable/data.html#data-loading-randomness for more details

  Returns:
    Nothing
  """
  worker_seed = torch.initial_seed() % 2**32
  np.random.seed(worker_seed)
  random.seed(worker_seed)

In [45]:
data_1=extract_data([1,40],feature_func=eeg_to_spectrum,info=info)

  0%|          | 0/40 [00:00<?, ?it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


  2%|▎         | 1/40 [00:00<00:17,  2.28it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


  5%|▌         | 2/40 [00:00<00:17,  2.15it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


  8%|▊         | 3/40 [00:01<00:18,  2.02it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 10%|█         | 4/40 [00:01<00:18,  1.99it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 12%|█▎        | 5/40 [00:02<00:17,  2.05it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 15%|█▌        | 6/40 [00:02<00:16,  2.08it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 18%|█▊        | 7/40 [00:03<00:16,  1.98it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 20%|██        | 8/40 [00:04<00:16,  1.93it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 22%|██▎       | 9/40 [00:04<00:16,  1.90it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 25%|██▌       | 10/40 [00:05<00:15,  1.90it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 28%|██▊       | 11/40 [00:05<00:16,  1.73it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 30%|███       | 12/40 [00:06<00:16,  1.72it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 32%|███▎      | 13/40 [00:06<00:15,  1.71it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 35%|███▌      | 14/40 [00:07<00:14,  1.76it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 38%|███▊      | 15/40 [00:08<00:14,  1.73it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 40%|████      | 16/40 [00:08<00:13,  1.79it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 42%|████▎     | 17/40 [00:09<00:12,  1.83it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 45%|████▌     | 18/40 [00:09<00:11,  1.85it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 48%|████▊     | 19/40 [00:10<00:11,  1.89it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 50%|█████     | 20/40 [00:10<00:10,  1.92it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 52%|█████▎    | 21/40 [00:11<00:09,  1.92it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 55%|█████▌    | 22/40 [00:11<00:09,  1.91it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 57%|█████▊    | 23/40 [00:12<00:09,  1.88it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 60%|██████    | 24/40 [00:12<00:08,  1.91it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 62%|██████▎   | 25/40 [00:13<00:07,  1.88it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 65%|██████▌   | 26/40 [00:13<00:07,  1.90it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 68%|██████▊   | 27/40 [00:14<00:06,  1.88it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 70%|███████   | 28/40 [00:14<00:06,  1.89it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 72%|███████▎  | 29/40 [00:15<00:05,  1.87it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 75%|███████▌  | 30/40 [00:15<00:05,  1.89it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 78%|███████▊  | 31/40 [00:16<00:04,  1.87it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 80%|████████  | 32/40 [00:17<00:04,  1.85it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 82%|████████▎ | 33/40 [00:17<00:03,  1.86it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 85%|████████▌ | 34/40 [00:18<00:03,  1.86it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 88%|████████▊ | 35/40 [00:18<00:02,  1.87it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 90%|█████████ | 36/40 [00:19<00:02,  1.74it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 92%|█████████▎| 37/40 [00:19<00:01,  1.75it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 95%|█████████▌| 38/40 [00:20<00:01,  1.78it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


 98%|█████████▊| 39/40 [00:20<00:00,  1.77it/s]

Not setting metadata
21 matching events found
No baseline correction applied
0 projection items activated
Effective window size : 10.240 (s)


100%|██████████| 40/40 [00:21<00:00,  1.86it/s]


In [46]:
channels=['AF3','AF4','F7','F3','F4','F8','FC5','FC6','T7','T8','P7','P8','O1','O2']
chaninds=[]
for chan in channels:
    chaninds.append(cnames.index(chan))
data_1=data_1[:,:,chaninds,:]

In [47]:
for patient in range(data_1.size(0)):
    if patient==0:
        X=data_1[patient,:,:,:]
        y=labels
    else:
        X=np.concatenate((X,data_1[patient,:,:,:]),axis=0)
        y=np.concatenate((y,labels),axis=0)

In [48]:
print(X.shape)
print(y.shape)

(840, 14, 480)
(840,)


In [49]:
#data for one patient
SEED=2025
set_seed(SEED)

X=torch.tensor(X).squeeze().flatten(1,2)
y=torch.tensor(y)
X_test, y_test, X_train, y_train = shuffle_and_split_data(X,y,seed=SEED)
print(f"{X.size(1)} inputs")
print(X_test.shape)
print(X_train.shape)

Random seed 2025 has been set.
6720 inputs
torch.Size([168, 6720])
torch.Size([672, 6720])


In [50]:
batch_size=400
g_seed = torch.Generator()
g_seed.manual_seed(SEED)

test_data = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_data, batch_size=batch_size,
                         shuffle=False, num_workers=0,
                         worker_init_fn=seed_worker,
                         generator=g_seed,
                         )

train_data = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_data,
                          batch_size=batch_size,
                          drop_last=False,
                          shuffle=True,
                          worker_init_fn=seed_worker,
                          generator=g_seed,
                          )

In [51]:
def main(epochs):
    train_acc_vec=[]
    test_acc_vec=[]
    for epoch in epochs:
        net=Net('ReLU()',6720,[6720,6720,6720],7).to(DEVICE)
        criterion=nn.CrossEntropyLoss()
        optimizer=optim.Adam(net.parameters(),lr=3e-4)

        train_acc,test_acc=train_test_classification(net, criterion, optimizer,
                                    train_loader, test_loader,
                                    num_epochs=epoch, device=DEVICE)
        train_acc_vec.append(train_acc)
        test_acc_vec.append(test_acc)
    return train_acc_vec,test_acc_vec

In [52]:
epochs_to_test=np.arange(5,200,5)
train_acc_vec,test_acc_vec=main(epochs_to_test)
print(train_acc_vec)
print(test_acc_vec)

100%|██████████| 5/5 [00:26<00:00,  5.21s/it]



Accuracy on the 672 training samples: 17.71
Accuracy on the 168 testing samples: 11.90



100%|██████████| 10/10 [00:49<00:00,  4.91s/it]



Accuracy on the 672 training samples: 37.65
Accuracy on the 168 testing samples: 8.93



100%|██████████| 15/15 [01:14<00:00,  4.97s/it]



Accuracy on the 672 training samples: 67.11
Accuracy on the 168 testing samples: 12.50



100%|██████████| 20/20 [01:41<00:00,  5.08s/it]



Accuracy on the 672 training samples: 77.53
Accuracy on the 168 testing samples: 13.69



100%|██████████| 25/25 [01:59<00:00,  4.78s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 11.90



100%|██████████| 30/30 [02:16<00:00,  4.55s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 10.12



100%|██████████| 35/35 [02:40<00:00,  4.59s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 14.88



100%|██████████| 40/40 [03:08<00:00,  4.70s/it]



Accuracy on the 672 training samples: 77.23
Accuracy on the 168 testing samples: 20.24



100%|██████████| 45/45 [03:44<00:00,  4.98s/it]



Accuracy on the 672 training samples: 84.82
Accuracy on the 168 testing samples: 14.29



100%|██████████| 50/50 [04:06<00:00,  4.93s/it]



Accuracy on the 672 training samples: 84.52
Accuracy on the 168 testing samples: 17.26



100%|██████████| 55/55 [04:37<00:00,  5.04s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 13.10



100%|██████████| 60/60 [04:48<00:00,  4.81s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 20.83



100%|██████████| 65/65 [05:15<00:00,  4.85s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 14.88



100%|██████████| 70/70 [05:29<00:00,  4.71s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 16.07



100%|██████████| 75/75 [05:48<00:00,  4.65s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 13.69



100%|██████████| 80/80 [06:12<00:00,  4.66s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 11.90



100%|██████████| 85/85 [06:38<00:00,  4.69s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 14.88



100%|██████████| 90/90 [06:52<00:00,  4.59s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 11.31



100%|██████████| 95/95 [07:08<00:00,  4.51s/it]



Accuracy on the 672 training samples: 84.23
Accuracy on the 168 testing samples: 14.88



100%|██████████| 100/100 [07:31<00:00,  4.52s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 12.50



100%|██████████| 105/105 [07:51<00:00,  4.49s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 13.69



100%|██████████| 110/110 [08:18<00:00,  4.53s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 6.55



100%|██████████| 115/115 [08:48<00:00,  4.59s/it]



Accuracy on the 672 training samples: 85.42
Accuracy on the 168 testing samples: 8.93



  1%|          | 1/120 [00:08<16:48,  8.48s/it]


KeyboardInterrupt: 