In [7]:
import random
import time
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy as sk
import pickle
import mne
import tqdm
from tqdm import trange
import os

DEVICE='cuda'
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

from torch import nn

In [8]:
def extract_data(sub_range,verbose=False,feature_func=None,**kwargs):
    first=sub_range[0]
    last=sub_range[-1]

    for sub in trange(first,last+1): 
        #read data
        if sub <=9:
            path=f'ds005540\derivatives\sub-0{sub}\ses-vid\eeg\sub-0{sub}_ses-vid_task-emotion_reorder.npy'
        else:
            path=f'ds005540\derivatives\sub-{sub}\ses-vid\eeg\sub-{sub}_ses-vid_task-emotion_reorder.npy'
        datatemp=np.load(path,allow_pickle=True)
        datatemp=np.permute_dims(datatemp,[1,0,2])

        #if we want to convert features
        if not feature_func==None:
            datatemp=feature_func(datatemp,info=info)

        #cat data along new dimension
        datatemp=torch.tensor(datatemp).unsqueeze(0)
        if sub==first:
            data=datatemp
        else:
            data=torch.cat([data,datatemp],axis=0)

        if verbose:
            print(f'subject {sub} data extracted, vector size: {data.shape}')
    return data

def shuffle_and_split_data(X,y):
    #get shuffled train/test/val in 8:1:1 ratio

    N=X.size(0)
    shuffled_indices=torch.randperm(N) #get shuffled indices
    X=X[shuffled_indices]
    y=y[shuffled_indices]

    # split 80% train set
    train_size=int(0.8*N)
    X_train=X[:train_size]
    y_train=y[:train_size]
    X_rest=X[train_size:]
    y_rest=y[train_size:]

    #split remainder by half for test and val
    splitind=int(len(X_rest)/2)
    X_val=X_rest[:splitind]
    y_val=y_rest[:splitind]
    X_test=X_rest[splitind:]
    y_test=y_rest[splitind:]

    return X_test,y_test,X_val,y_val,X_train,y_train

#create labels
#sad-dis-fear-neu-joy-ten-ins correspond to 0-6 respectively, each sample has 21 trials, for (7 emotions x 3 trials)
labels=np.array([])
for i in range(0,7):
    for n in range(0,3):
        labels=np.concatenate([labels,np.array([i])],axis=0)
labels=torch.tensor(labels)

  path=f'ds005540\derivatives\sub-0{sub}\ses-vid\eeg\sub-0{sub}_ses-vid_task-emotion_reorder.npy'
  path=f'ds005540\derivatives\sub-0{sub}\ses-vid\eeg\sub-0{sub}_ses-vid_task-emotion_reorder.npy'
  path=f'ds005540\derivatives\sub-{sub}\ses-vid\eeg\sub-{sub}_ses-vid_task-emotion_reorder.npy'
  path=f'ds005540\derivatives\sub-{sub}\ses-vid\eeg\sub-{sub}_ses-vid_task-emotion_reorder.npy'


In [9]:
data=extract_data([1,25])

100%|██████████| 25/25 [00:05<00:00,  4.72it/s]


In [10]:
def process_data(data):
    '''0-10s, 5-15s, 10-20s, 15-25, 20-30s ---> 5 samples per trial'''
    for interval in tqdm.tqdm(range(0,5)):
        datatrial=data[:,:,:,interval*1000:(interval+2)*1000]
        datatrial=torch.unsqueeze(datatrial,2)
        if interval==0:
            newdata=datatrial
        else:
            newdata=torch.concatenate((newdata,datatrial),2)
    return newdata

In [11]:
data=process_data(data)
data.shape #subject,label,trials,electrodes,timepoints

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:06<00:00,  1.20s/it]


torch.Size([25, 21, 5, 64, 2000])

In [12]:
for label in tqdm.tqdm(range(21)): #combine labels and trials
    temp=data[:,label,:,:,:]
    if label==0:
        X=temp
        y=labels[label]*torch.ones(5,1)
    else:
        X=torch.concatenate((X,temp),dim=1)
        y=torch.concatenate((y,labels[label]*torch.ones(5,1)),dim=0)

#combine subjects
for subject in tqdm.tqdm(range(X.size(0))):
    temp=X[subject,:,:,:]
    if subject==0:
        X_full=temp
        y_full=y
    else:
        X_full=torch.concatenate((X_full,temp),dim=0)
        y_full=torch.concatenate((y_full,y),dim=0)

print(X_full.shape)
print(y_full.shape)

#split data
X_test,y_test,X_val,y_val,X_train,y_train=shuffle_and_split_data(X_full.unsqueeze(1),y_full.squeeze())

print(X_test.shape,y_test.shape,
      X_val.shape,y_val.shape,
      X_train.shape,y_train.shape)

100%|██████████| 21/21 [00:06<00:00,  3.37it/s]
100%|██████████| 25/25 [00:08<00:00,  2.88it/s]


torch.Size([2625, 64, 2000])
torch.Size([2625, 1])
torch.Size([263, 1, 64, 2000]) torch.Size([263]) torch.Size([262, 1, 64, 2000]) torch.Size([262]) torch.Size([2100, 1, 64, 2000]) torch.Size([2100])


In [13]:
#now the model
class Net(nn.Module):
    def __init__(self,dropout):
        super(Net,self).__init__()
        self.conv1=nn.Conv2d(1,32,3,stride=1,padding=1)
        self.conv2=nn.Conv2d(32,12,3,stride=1,padding=1)
        self.conv3=nn.Conv2d(12,16,5,stride=1,padding=2)
        self.conv4=nn.Conv2d(16,12,3,stride=1,padding=1)
        self.conv5=nn.Conv2d(12,4,3,stride=1,padding=1)
        self.pool1=nn.MaxPool2d(4)
        self.fc1=nn.Linear(2000,7)
        self.dropout1=nn.Dropout2d(dropout)
        self.batchnorm1=nn.BatchNorm2d(32)
        self.batchnorm2=nn.BatchNorm2d(12)
        self.batchnorm3=nn.BatchNorm2d(16)
        self.batchnorm4=nn.BatchNorm2d(4)

    def forward(self,x):
        x=self.conv1(x)
        x=self.batchnorm1(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.conv2(x)
        x=self.batchnorm2(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.pool1(x)
        x=self.conv3(x)
        x=self.batchnorm3(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.conv4(x)
        x=self.batchnorm2(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.conv5(x)
        x=self.batchnorm4(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.pool1(x)
        x=torch.flatten(x,1)
        x=self.fc1(x)
        return x
    


def train_test(net,epochs,train_loader,test_loader,device):
    criterion=nn.CrossEntropyLoss()
    optimizer=optim.Adam(net.parameters(),lr=1e-3)
    train_acc=[]
    train_loss=[]
    test_acc=[]
    test_loss=[]
    net.to(device)
    for epoch in tqdm.tqdm(range(epochs)):
        net.train()
        running_loss=0.0
        correct,total=0,0
        for i,data in enumerate(train_loader,start=0):
            inputs,labels=data
            inputs=inputs.to(device).float()
            labels=labels.to(device).long()
            
            #train
            optimizer.zero_grad()
            outputs=net.forward(inputs)
            loss=criterion(outputs,labels)
            loss.backward()
            optimizer.step()

            running_loss+=loss.item()
            #training accuracy
            _,predicted=torch.max(outputs,1)
            total+=labels.size(0)
            correct+=(predicted==labels).sum()
        train_loss.append(running_loss/len(train_loader))
        train_acc.append(correct/total)
        print(f"epoch {epoch} --> TRAIN loss: {running_loss/len(train_loader):.6f}, TRAIN accuracy: {correct/total:.2f}")

        #eval on test
        net.eval()
        running_loss=0.0
        correct,total=0,0
        for inputs,labels in test_loader:
            inputs,labels=inputs.to(device).float(),labels.to(device).long()
            outputs=net.forward(inputs)
            loss=criterion(outputs,labels)
            running_loss+=loss.item()

            #test acc
            _,predicted=torch.max(outputs,1)
            total+=labels.size(0)
            correct+=(predicted==labels).sum()
        test_loss.append(running_loss/len(test_loader))
        test_acc.append(correct/total)
        print(f"epoch {epoch} --> TEST loss: {running_loss/len(train_loader):.2f}, TEST accuracy: {correct/total:.2f}")

    return train_loss,train_acc,test_loss,test_acc


batch_size=150
test_data = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_data, batch_size=batch_size,
                         shuffle=False,num_workers=0
                         )

train_data = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_data,
                          batch_size=batch_size,
                          drop_last=False,
                          shuffle=True,num_workers=0
                          )

In [14]:
net=Net(dropout=0.4)
train_loss,train_acc,test_loss,test_acc=train_test(net,150,train_loader=train_loader,test_loader=test_loader,device=DEVICE)

  0%|          | 0/150 [00:00<?, ?it/s]

epoch 0 --> TRAIN loss: 2.271819, TRAIN accuracy: 0.14


  1%|          | 1/150 [00:09<24:00,  9.67s/it]

epoch 0 --> TEST loss: 0.28, TEST accuracy: 0.14
epoch 1 --> TRAIN loss: 2.048028, TRAIN accuracy: 0.14


  1%|▏         | 2/150 [00:56<1:17:27, 31.40s/it]

epoch 1 --> TEST loss: 0.28, TEST accuracy: 0.17
epoch 2 --> TRAIN loss: 1.981871, TRAIN accuracy: 0.16


  2%|▏         | 3/150 [01:35<1:25:31, 34.91s/it]

epoch 2 --> TEST loss: 0.28, TEST accuracy: 0.13
epoch 3 --> TRAIN loss: 1.972720, TRAIN accuracy: 0.17


  3%|▎         | 4/150 [02:14<1:28:39, 36.43s/it]

epoch 3 --> TEST loss: 0.28, TEST accuracy: 0.13
epoch 4 --> TRAIN loss: 1.973667, TRAIN accuracy: 0.16


  3%|▎         | 5/150 [02:53<1:30:29, 37.44s/it]

epoch 4 --> TEST loss: 0.28, TEST accuracy: 0.14
epoch 5 --> TRAIN loss: 1.953385, TRAIN accuracy: 0.16


  4%|▍         | 6/150 [03:32<1:31:09, 37.98s/it]

epoch 5 --> TEST loss: 0.28, TEST accuracy: 0.13
epoch 6 --> TRAIN loss: 1.953198, TRAIN accuracy: 0.16


  5%|▍         | 7/150 [04:10<1:30:59, 38.18s/it]

epoch 6 --> TEST loss: 0.28, TEST accuracy: 0.16
epoch 7 --> TRAIN loss: 1.957938, TRAIN accuracy: 0.17


  5%|▌         | 8/150 [04:49<1:30:37, 38.29s/it]

epoch 7 --> TEST loss: 0.28, TEST accuracy: 0.14
epoch 8 --> TRAIN loss: 1.949055, TRAIN accuracy: 0.16


  6%|▌         | 9/150 [05:27<1:30:04, 38.33s/it]

epoch 8 --> TEST loss: 0.28, TEST accuracy: 0.14
epoch 9 --> TRAIN loss: 1.929253, TRAIN accuracy: 0.16


  7%|▋         | 10/150 [06:06<1:29:30, 38.36s/it]

epoch 9 --> TEST loss: 0.28, TEST accuracy: 0.15
epoch 10 --> TRAIN loss: 1.946913, TRAIN accuracy: 0.16


  7%|▋         | 11/150 [06:44<1:28:56, 38.39s/it]

epoch 10 --> TEST loss: 0.28, TEST accuracy: 0.14
epoch 11 --> TRAIN loss: 1.936640, TRAIN accuracy: 0.17


  8%|▊         | 12/150 [07:23<1:28:25, 38.45s/it]

epoch 11 --> TEST loss: 0.28, TEST accuracy: 0.12
epoch 12 --> TRAIN loss: 1.915481, TRAIN accuracy: 0.18


  9%|▊         | 13/150 [08:01<1:27:30, 38.33s/it]

epoch 12 --> TEST loss: 0.28, TEST accuracy: 0.15
epoch 13 --> TRAIN loss: 1.917199, TRAIN accuracy: 0.20


  9%|▉         | 14/150 [08:40<1:27:07, 38.44s/it]

epoch 13 --> TEST loss: 0.28, TEST accuracy: 0.14
epoch 14 --> TRAIN loss: 1.929544, TRAIN accuracy: 0.18


 10%|█         | 15/150 [09:18<1:26:29, 38.44s/it]

epoch 14 --> TEST loss: 0.28, TEST accuracy: 0.15
epoch 15 --> TRAIN loss: 1.912499, TRAIN accuracy: 0.19


 11%|█         | 16/150 [09:56<1:25:49, 38.43s/it]

epoch 15 --> TEST loss: 0.28, TEST accuracy: 0.15
epoch 16 --> TRAIN loss: 1.908287, TRAIN accuracy: 0.19


 11%|█▏        | 17/150 [10:35<1:25:13, 38.44s/it]

epoch 16 --> TEST loss: 0.28, TEST accuracy: 0.14
epoch 17 --> TRAIN loss: 1.897423, TRAIN accuracy: 0.20


 12%|█▏        | 18/150 [11:14<1:25:17, 38.77s/it]

epoch 17 --> TEST loss: 0.28, TEST accuracy: 0.18
epoch 18 --> TRAIN loss: 1.891055, TRAIN accuracy: 0.21


 13%|█▎        | 19/150 [11:54<1:25:07, 38.99s/it]

epoch 18 --> TEST loss: 0.28, TEST accuracy: 0.17
epoch 19 --> TRAIN loss: 1.895894, TRAIN accuracy: 0.21


 13%|█▎        | 20/150 [12:32<1:24:03, 38.79s/it]

epoch 19 --> TEST loss: 0.28, TEST accuracy: 0.19
epoch 20 --> TRAIN loss: 1.887508, TRAIN accuracy: 0.22


 14%|█▍        | 21/150 [13:11<1:23:12, 38.70s/it]

epoch 20 --> TEST loss: 0.28, TEST accuracy: 0.15
epoch 21 --> TRAIN loss: 1.881059, TRAIN accuracy: 0.22


 15%|█▍        | 22/150 [13:50<1:22:34, 38.70s/it]

epoch 21 --> TEST loss: 0.28, TEST accuracy: 0.17
epoch 22 --> TRAIN loss: 1.888076, TRAIN accuracy: 0.21


 15%|█▌        | 23/150 [14:28<1:21:49, 38.65s/it]

epoch 22 --> TEST loss: 0.29, TEST accuracy: 0.15
epoch 23 --> TRAIN loss: 1.883987, TRAIN accuracy: 0.21


 16%|█▌        | 24/150 [15:07<1:21:05, 38.62s/it]

epoch 23 --> TEST loss: 0.28, TEST accuracy: 0.17
epoch 24 --> TRAIN loss: 1.862268, TRAIN accuracy: 0.22


 17%|█▋        | 25/150 [15:45<1:20:14, 38.51s/it]

epoch 24 --> TEST loss: 0.28, TEST accuracy: 0.20
epoch 25 --> TRAIN loss: 1.856043, TRAIN accuracy: 0.24


 17%|█▋        | 26/150 [16:23<1:19:36, 38.52s/it]

epoch 25 --> TEST loss: 0.29, TEST accuracy: 0.19
epoch 26 --> TRAIN loss: 1.864437, TRAIN accuracy: 0.23


 18%|█▊        | 27/150 [17:02<1:18:59, 38.53s/it]

epoch 26 --> TEST loss: 0.28, TEST accuracy: 0.18
epoch 27 --> TRAIN loss: 1.851648, TRAIN accuracy: 0.24


 19%|█▊        | 28/150 [17:40<1:18:15, 38.49s/it]

epoch 27 --> TEST loss: 0.29, TEST accuracy: 0.19
epoch 28 --> TRAIN loss: 1.860507, TRAIN accuracy: 0.24


 19%|█▉        | 29/150 [18:18<1:17:20, 38.35s/it]

epoch 28 --> TEST loss: 0.28, TEST accuracy: 0.17
epoch 29 --> TRAIN loss: 1.856949, TRAIN accuracy: 0.24


 20%|██        | 30/150 [18:57<1:16:47, 38.40s/it]

epoch 29 --> TEST loss: 0.28, TEST accuracy: 0.19
epoch 30 --> TRAIN loss: 1.874816, TRAIN accuracy: 0.23


 21%|██        | 31/150 [19:35<1:16:04, 38.36s/it]

epoch 30 --> TEST loss: 0.29, TEST accuracy: 0.19
epoch 31 --> TRAIN loss: 1.834100, TRAIN accuracy: 0.25


 21%|██▏       | 32/150 [20:14<1:15:43, 38.51s/it]

epoch 31 --> TEST loss: 0.29, TEST accuracy: 0.19
epoch 32 --> TRAIN loss: 1.834514, TRAIN accuracy: 0.26


 22%|██▏       | 33/150 [20:53<1:15:12, 38.57s/it]

epoch 32 --> TEST loss: 0.29, TEST accuracy: 0.18
epoch 33 --> TRAIN loss: 1.839765, TRAIN accuracy: 0.26


 23%|██▎       | 34/150 [21:32<1:15:02, 38.82s/it]

epoch 33 --> TEST loss: 0.28, TEST accuracy: 0.18
epoch 34 --> TRAIN loss: 1.843725, TRAIN accuracy: 0.25


 23%|██▎       | 35/150 [22:11<1:14:39, 38.95s/it]

epoch 34 --> TEST loss: 0.29, TEST accuracy: 0.19
epoch 35 --> TRAIN loss: 1.836202, TRAIN accuracy: 0.27


 24%|██▍       | 36/150 [22:52<1:14:45, 39.34s/it]

epoch 35 --> TEST loss: 0.29, TEST accuracy: 0.18
epoch 36 --> TRAIN loss: 1.808443, TRAIN accuracy: 0.27


 25%|██▍       | 37/150 [23:30<1:13:43, 39.15s/it]

epoch 36 --> TEST loss: 0.29, TEST accuracy: 0.19
epoch 37 --> TRAIN loss: 1.808096, TRAIN accuracy: 0.27


 25%|██▌       | 38/150 [24:09<1:12:51, 39.03s/it]

epoch 37 --> TEST loss: 0.29, TEST accuracy: 0.17
epoch 38 --> TRAIN loss: 1.812246, TRAIN accuracy: 0.29


 26%|██▌       | 39/150 [24:48<1:12:00, 38.92s/it]

epoch 38 --> TEST loss: 0.29, TEST accuracy: 0.16
epoch 39 --> TRAIN loss: 1.794885, TRAIN accuracy: 0.29


 27%|██▋       | 40/150 [25:27<1:11:21, 38.92s/it]

epoch 39 --> TEST loss: 0.29, TEST accuracy: 0.17
epoch 40 --> TRAIN loss: 1.805074, TRAIN accuracy: 0.30


 27%|██▋       | 41/150 [26:06<1:10:47, 38.97s/it]

epoch 40 --> TEST loss: 0.29, TEST accuracy: 0.18
epoch 41 --> TRAIN loss: 1.776436, TRAIN accuracy: 0.30


 28%|██▊       | 42/150 [26:44<1:09:47, 38.78s/it]

epoch 41 --> TEST loss: 0.29, TEST accuracy: 0.18
epoch 42 --> TRAIN loss: 1.781550, TRAIN accuracy: 0.30


 29%|██▊       | 43/150 [27:22<1:08:53, 38.63s/it]

epoch 42 --> TEST loss: 0.29, TEST accuracy: 0.18
epoch 43 --> TRAIN loss: 1.770055, TRAIN accuracy: 0.31


 29%|██▉       | 44/150 [28:01<1:08:05, 38.54s/it]

epoch 43 --> TEST loss: 0.29, TEST accuracy: 0.15
epoch 44 --> TRAIN loss: 1.774204, TRAIN accuracy: 0.31


 30%|███       | 45/150 [28:39<1:07:16, 38.44s/it]

epoch 44 --> TEST loss: 0.29, TEST accuracy: 0.17
epoch 45 --> TRAIN loss: 1.749674, TRAIN accuracy: 0.32


 31%|███       | 46/150 [29:17<1:06:26, 38.34s/it]

epoch 45 --> TEST loss: 0.29, TEST accuracy: 0.16
epoch 46 --> TRAIN loss: 1.748939, TRAIN accuracy: 0.30


 31%|███▏      | 47/150 [29:55<1:05:39, 38.25s/it]

epoch 46 --> TEST loss: 0.29, TEST accuracy: 0.14
epoch 47 --> TRAIN loss: 1.749330, TRAIN accuracy: 0.31


 32%|███▏      | 48/150 [30:34<1:05:14, 38.37s/it]

epoch 47 --> TEST loss: 0.29, TEST accuracy: 0.15
epoch 48 --> TRAIN loss: 1.759582, TRAIN accuracy: 0.31


 33%|███▎      | 49/150 [31:12<1:04:48, 38.50s/it]

epoch 48 --> TEST loss: 0.29, TEST accuracy: 0.16
epoch 49 --> TRAIN loss: 1.745487, TRAIN accuracy: 0.33


 33%|███▎      | 50/150 [31:52<1:04:53, 38.93s/it]

epoch 49 --> TEST loss: 0.30, TEST accuracy: 0.16
epoch 50 --> TRAIN loss: 1.726506, TRAIN accuracy: 0.32


 34%|███▍      | 51/150 [32:32<1:04:26, 39.06s/it]

epoch 50 --> TEST loss: 0.30, TEST accuracy: 0.11
epoch 51 --> TRAIN loss: 1.705094, TRAIN accuracy: 0.33


 35%|███▍      | 52/150 [33:11<1:03:52, 39.11s/it]

epoch 51 --> TEST loss: 0.30, TEST accuracy: 0.13
epoch 52 --> TRAIN loss: 1.709896, TRAIN accuracy: 0.34


 35%|███▌      | 53/150 [33:50<1:02:58, 38.96s/it]

epoch 52 --> TEST loss: 0.30, TEST accuracy: 0.15
epoch 53 --> TRAIN loss: 1.695126, TRAIN accuracy: 0.35


 36%|███▌      | 54/150 [34:28<1:02:07, 38.83s/it]

epoch 53 --> TEST loss: 0.30, TEST accuracy: 0.13
epoch 54 --> TRAIN loss: 1.682635, TRAIN accuracy: 0.35


 37%|███▋      | 55/150 [35:07<1:01:17, 38.72s/it]

epoch 54 --> TEST loss: 0.30, TEST accuracy: 0.14
epoch 55 --> TRAIN loss: 1.675242, TRAIN accuracy: 0.34


 37%|███▋      | 56/150 [35:45<1:00:42, 38.74s/it]

epoch 55 --> TEST loss: 0.30, TEST accuracy: 0.14
epoch 56 --> TRAIN loss: 1.685150, TRAIN accuracy: 0.37


 38%|███▊      | 57/150 [36:24<1:00:01, 38.73s/it]

epoch 56 --> TEST loss: 0.30, TEST accuracy: 0.13
epoch 57 --> TRAIN loss: 1.703214, TRAIN accuracy: 0.34


 39%|███▊      | 58/150 [37:03<59:16, 38.66s/it]  

epoch 57 --> TEST loss: 0.30, TEST accuracy: 0.14
epoch 58 --> TRAIN loss: 1.667397, TRAIN accuracy: 0.36


 39%|███▉      | 59/150 [37:41<58:28, 38.55s/it]

epoch 58 --> TEST loss: 0.30, TEST accuracy: 0.15
epoch 59 --> TRAIN loss: 1.657791, TRAIN accuracy: 0.36


 40%|████      | 60/150 [38:20<57:57, 38.64s/it]

epoch 59 --> TEST loss: 0.30, TEST accuracy: 0.15
epoch 60 --> TRAIN loss: 1.679106, TRAIN accuracy: 0.37


 41%|████      | 61/150 [38:58<57:19, 38.64s/it]

epoch 60 --> TEST loss: 0.31, TEST accuracy: 0.14
epoch 61 --> TRAIN loss: 1.656876, TRAIN accuracy: 0.36


 41%|████▏     | 62/150 [39:37<56:38, 38.62s/it]

epoch 61 --> TEST loss: 0.30, TEST accuracy: 0.14
epoch 62 --> TRAIN loss: 1.619414, TRAIN accuracy: 0.39


 42%|████▏     | 63/150 [40:16<56:03, 38.66s/it]

epoch 62 --> TEST loss: 0.30, TEST accuracy: 0.14
epoch 63 --> TRAIN loss: 1.638388, TRAIN accuracy: 0.37


 43%|████▎     | 64/150 [40:55<55:29, 38.72s/it]

epoch 63 --> TEST loss: 0.30, TEST accuracy: 0.14
epoch 64 --> TRAIN loss: 1.610759, TRAIN accuracy: 0.40


 43%|████▎     | 65/150 [41:33<54:50, 38.71s/it]

epoch 64 --> TEST loss: 0.30, TEST accuracy: 0.13
epoch 65 --> TRAIN loss: 1.603018, TRAIN accuracy: 0.40


 44%|████▍     | 66/150 [42:12<54:02, 38.60s/it]

epoch 65 --> TEST loss: 0.31, TEST accuracy: 0.14
epoch 66 --> TRAIN loss: 1.593655, TRAIN accuracy: 0.40


 45%|████▍     | 67/150 [42:50<53:13, 38.48s/it]

epoch 66 --> TEST loss: 0.30, TEST accuracy: 0.15
epoch 67 --> TRAIN loss: 1.587658, TRAIN accuracy: 0.41


 45%|████▌     | 68/150 [43:28<52:33, 38.46s/it]

epoch 67 --> TEST loss: 0.30, TEST accuracy: 0.12
epoch 68 --> TRAIN loss: 1.579430, TRAIN accuracy: 0.39


 46%|████▌     | 69/150 [44:06<51:50, 38.40s/it]

epoch 68 --> TEST loss: 0.31, TEST accuracy: 0.13
epoch 69 --> TRAIN loss: 1.571046, TRAIN accuracy: 0.41


 47%|████▋     | 70/150 [44:45<51:08, 38.36s/it]

epoch 69 --> TEST loss: 0.31, TEST accuracy: 0.13
epoch 70 --> TRAIN loss: 1.534807, TRAIN accuracy: 0.43


 47%|████▋     | 71/150 [45:23<50:31, 38.38s/it]

epoch 70 --> TEST loss: 0.31, TEST accuracy: 0.14
epoch 71 --> TRAIN loss: 1.553725, TRAIN accuracy: 0.42


 48%|████▊     | 72/150 [46:01<49:50, 38.34s/it]

epoch 71 --> TEST loss: 0.31, TEST accuracy: 0.12
epoch 72 --> TRAIN loss: 1.557586, TRAIN accuracy: 0.42


 49%|████▊     | 73/150 [46:40<49:26, 38.53s/it]

epoch 72 --> TEST loss: 0.31, TEST accuracy: 0.13
epoch 73 --> TRAIN loss: 1.556828, TRAIN accuracy: 0.41


 49%|████▉     | 74/150 [47:20<49:12, 38.85s/it]

epoch 73 --> TEST loss: 0.31, TEST accuracy: 0.12
epoch 74 --> TRAIN loss: 1.519343, TRAIN accuracy: 0.43


 50%|█████     | 75/150 [48:01<49:12, 39.37s/it]

epoch 74 --> TEST loss: 0.31, TEST accuracy: 0.14
epoch 75 --> TRAIN loss: 1.505902, TRAIN accuracy: 0.45


 51%|█████     | 76/150 [48:40<48:30, 39.33s/it]

epoch 75 --> TEST loss: 0.30, TEST accuracy: 0.12
epoch 76 --> TRAIN loss: 1.490651, TRAIN accuracy: 0.46


 51%|█████▏    | 77/150 [49:19<47:51, 39.33s/it]

epoch 76 --> TEST loss: 0.32, TEST accuracy: 0.12
epoch 77 --> TRAIN loss: 1.514487, TRAIN accuracy: 0.44


 52%|█████▏    | 78/150 [49:59<47:15, 39.39s/it]

epoch 77 --> TEST loss: 0.30, TEST accuracy: 0.13
epoch 78 --> TRAIN loss: 1.499596, TRAIN accuracy: 0.44


 53%|█████▎    | 79/150 [50:38<46:26, 39.25s/it]

epoch 78 --> TEST loss: 0.32, TEST accuracy: 0.13
epoch 79 --> TRAIN loss: 1.449144, TRAIN accuracy: 0.46


 53%|█████▎    | 80/150 [51:16<45:36, 39.09s/it]

epoch 79 --> TEST loss: 0.32, TEST accuracy: 0.11
epoch 80 --> TRAIN loss: 1.473734, TRAIN accuracy: 0.46


 54%|█████▍    | 81/150 [51:55<44:53, 39.04s/it]

epoch 80 --> TEST loss: 0.31, TEST accuracy: 0.12
epoch 81 --> TRAIN loss: 1.439157, TRAIN accuracy: 0.47


 55%|█████▍    | 82/150 [52:34<44:17, 39.07s/it]

epoch 81 --> TEST loss: 0.31, TEST accuracy: 0.13
epoch 82 --> TRAIN loss: 1.433481, TRAIN accuracy: 0.47


 55%|█████▌    | 83/150 [53:14<43:49, 39.25s/it]

epoch 82 --> TEST loss: 0.31, TEST accuracy: 0.13
epoch 83 --> TRAIN loss: 1.444054, TRAIN accuracy: 0.47


 56%|█████▌    | 84/150 [53:54<43:17, 39.35s/it]

epoch 83 --> TEST loss: 0.31, TEST accuracy: 0.13
epoch 84 --> TRAIN loss: 1.436960, TRAIN accuracy: 0.48


 57%|█████▋    | 85/150 [54:32<42:25, 39.17s/it]

epoch 84 --> TEST loss: 0.32, TEST accuracy: 0.13
epoch 85 --> TRAIN loss: 1.409953, TRAIN accuracy: 0.48


 57%|█████▋    | 86/150 [55:11<41:31, 38.93s/it]

epoch 85 --> TEST loss: 0.32, TEST accuracy: 0.13


 57%|█████▋    | 86/150 [55:50<41:33, 38.95s/it]


KeyboardInterrupt: 