In [1]:
import random
import time
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy as sk
import pickle
import mne
import tqdm
from tqdm import trange
import os

DEVICE='cuda'
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

from torch import nn

In [2]:
def extract_data(sub_range,verbose=False,feature_func=None,**kwargs):
    first=sub_range[0]
    last=sub_range[-1]

    for sub in trange(first,last+1): 
        #read data
        if sub <=9:
            path=f'ds005540\derivatives\sub-0{sub}\ses-vid\eeg\sub-0{sub}_ses-vid_task-emotion_reorder.npy'
        else:
            path=f'ds005540\derivatives\sub-{sub}\ses-vid\eeg\sub-{sub}_ses-vid_task-emotion_reorder.npy'
        datatemp=np.load(path,allow_pickle=True)
        datatemp=np.permute_dims(datatemp,[1,0,2])

        #if we want to convert features
        if not feature_func==None:
            datatemp=feature_func(datatemp,info=info)

        #cat data along new dimension
        datatemp=torch.tensor(datatemp).unsqueeze(0)
        if sub==first:
            data=datatemp
        else:
            data=torch.cat([data,datatemp],axis=0)

        if verbose:
            print(f'subject {sub} data extracted, vector size: {data.shape}')
    return data

def shuffle_and_split_data(X,y):
    """
    Helper function to shuffle and split data

    Args:
        X: torch.tensor
        Input data
        y: torch.tensor
        Corresponding target variables
        seed: int
        Set seed for reproducibility

    Returns:
        X_test: torch.tensor
        Test data [20% of X]
        y_test: torch.tensor
        Labels corresponding to above mentioned test data
        X_train: torch.tensor
        Train data [80% of X]
        y_train: torch.tensor
        Labels corresponding to above mentioned train data
    """
    
    N=X.size(0)
    shuffled_indices=torch.randperm(N) #get shuffled indices
    X=X[shuffled_indices]
    y=y[shuffled_indices]

    # split by 20% into train-test set
    test_size=int(0.2*N)
    X_train=X[test_size:]
    y_train=y[test_size:]
    X_test=X[:test_size]
    y_test=y[:test_size]

    return X_test,y_test,X_train,y_train

#create labels
#sad-dis-fear-neu-joy-ten-ins correspond to 0-6 respectively, each sample has 21 trials, for (7 emotions x 3 trials)
labels=np.array([])
for i in range(0,7):
    for n in range(0,3):
        labels=np.concatenate([labels,np.array([i])],axis=0)
labels=torch.tensor(labels)

  path=f'ds005540\derivatives\sub-0{sub}\ses-vid\eeg\sub-0{sub}_ses-vid_task-emotion_reorder.npy'
  path=f'ds005540\derivatives\sub-0{sub}\ses-vid\eeg\sub-0{sub}_ses-vid_task-emotion_reorder.npy'
  path=f'ds005540\derivatives\sub-{sub}\ses-vid\eeg\sub-{sub}_ses-vid_task-emotion_reorder.npy'
  path=f'ds005540\derivatives\sub-{sub}\ses-vid\eeg\sub-{sub}_ses-vid_task-emotion_reorder.npy'


In [3]:
data=extract_data([1,40])

100%|██████████| 40/40 [00:15<00:00,  2.58it/s]


In [4]:
def process_data(data):
    '''further process into trials (1 second trials, 200 points)'''
    for trial in tqdm.tqdm(range(int(data.size(3)/200))):
        datatrial=data[:,:,:,trial*200:(trial+1)*200]
        datatrial=torch.unsqueeze(datatrial,2)
        if trial==0:
            newdata=datatrial
        else:
            newdata=torch.concatenate((newdata,datatrial),2)
    return newdata

In [5]:
data=process_data(data)
data.shape #subject,label,trials,electrodes,timepoints

100%|██████████| 30/30 [00:10<00:00,  2.89it/s]


torch.Size([40, 21, 30, 64, 200])

In [6]:
for label in tqdm.tqdm(range(21)): #combine labels and trials
    temp=data[:,label,:,:,:]
    if label==0:
        X=temp
        y=labels[label]*torch.ones(30,1)
    else:
        X=torch.concatenate((X,temp),dim=1)
        y=torch.concatenate((y,labels[label]*torch.ones(30,1)),dim=0)

#combine subjects
for subject in tqdm.tqdm(range(X.size(0))):
    temp=X[subject,:,:,:]
    if subject==0:
        X_full=temp
        y_full=y
    else:
        X_full=torch.concatenate((X_full,temp),dim=0)
        y_full=torch.concatenate((y_full,y),dim=0)

print(X_full.shape)
print(y_full.shape)

#split data
X_test,y_test,X_train,y_train=shuffle_and_split_data(X_full.unsqueeze(1),y_full.squeeze())

print(X_test.shape,y_test.shape,X_train.shape,y_train.shape)

100%|██████████| 21/21 [00:06<00:00,  3.41it/s]
100%|██████████| 40/40 [00:10<00:00,  3.64it/s]


torch.Size([25200, 64, 200])
torch.Size([25200, 1])
torch.Size([5040, 1, 64, 200]) torch.Size([5040]) torch.Size([20160, 1, 64, 200]) torch.Size([20160])


In [7]:
#now the model
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1=nn.Conv2d(1,16,3,stride=1,padding=1)
        self.conv2=nn.Conv2d(16,12,3,stride=1,padding=1)
        self.conv3=nn.Conv2d(12,2,3,stride=1,padding=1)
        self.pool1=nn.MaxPool2d(4)
        self.pool2=nn.MaxPool2d(2)
        self.fc1=nn.Linear(400,200)
        self.fc2=nn.Linear(200,128)
        self.fc3=nn.Linear(128,7)
        self.dropout1=nn.Dropout2d(0.3)
        self.dropout2=nn.Dropout(0.2)
        self.batchnorm1=nn.BatchNorm2d(16)
        self.batchnorm2=nn.BatchNorm2d(12)

    def forward(self,x):
        x=self.conv1(x)
        x=self.batchnorm1(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.conv2(x)
        self.batchnorm2(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.pool1(x)
        x=self.conv3(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.pool2(x)
        x=torch.flatten(x,1)
        x=self.fc1(x)
        self.dropout2(x)
        x=nn.functional.relu(x)
        x=self.fc2(x)
        x=self.dropout2(x)
        x=nn.functional.relu(x)
        x=self.fc3(x)
        x=nn.functional.relu(x)
        return x
    


def train_test(net,epochs,train_loader,test_loader,device):
    criterion=nn.CrossEntropyLoss()
    optimizer=optim.Adam(net.parameters(),lr=3e-4)
    train_acc=[]
    train_loss=[]
    test_acc=[]
    test_loss=[]
    net.to(device)
    for epoch in tqdm.tqdm(range(epochs)):
        net.train()
        running_loss=0.0
        correct,total=0,0
        for i,data in enumerate(train_loader,start=0):
            inputs,labels=data
            inputs=inputs.to(device).float()
            labels=labels.to(device).long()
            
            #train
            optimizer.zero_grad()
            outputs=net.forward(inputs)
            loss=criterion(outputs,labels)
            loss.backward()
            optimizer.step()

            running_loss+=loss.item()
            #training accuracy
            _,predicted=torch.max(outputs,1)
            total+=labels.size(0)
            correct+=(predicted==labels).sum()
        train_loss.append(running_loss/len(train_loader))
        train_acc.append(correct/total)
        print(f"epoch {epoch} --> TRAIN loss: {running_loss/len(train_loader):.2f}, TRAIN accuracy: {correct/total:.2f}")

        #eval on test
        net.eval()
        running_loss=0.0
        correct,total=0,0
        for inputs,labels in test_loader:
            inputs,labels=inputs.to(device).float(),labels.to(device).long()
            outputs=net.forward(inputs)
            loss=criterion(outputs,labels)
            running_loss+=loss.item()

            #test acc
            _,predicted=torch.max(outputs,1)
            total+=labels.size(0)
            correct+=(predicted==labels).sum()
        test_loss.append(running_loss/len(test_loader))
        test_acc.append(correct/total)
        print(f"epoch {epoch} --> TEST loss: {running_loss/len(train_loader):.2f}, TEST accuracy: {correct/total:.2f}")

    return train_loss,train_acc,test_loss,test_acc


batch_size=150
test_data = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_data, batch_size=batch_size,
                         shuffle=False
                         )

train_data = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_data,
                          batch_size=batch_size,
                          drop_last=False,
                          shuffle=True
                          )

In [8]:
net=Net()
train_loss,train_acc,test_loss,test_acc=train_test(net,250,train_loader=train_loader,test_loader=test_loader,device=DEVICE)

  0%|          | 0/250 [00:00<?, ?it/s]

epoch 0 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  0%|          | 1/250 [00:06<28:21,  6.83s/it]

epoch 0 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 1 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  1%|          | 2/250 [00:12<24:20,  5.89s/it]

epoch 1 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 2 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  1%|          | 3/250 [00:17<22:52,  5.56s/it]

epoch 2 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 3 --> TRAIN loss: 1.95, TRAIN accuracy: 0.15


  2%|▏         | 4/250 [00:22<22:04,  5.38s/it]

epoch 3 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 4 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  2%|▏         | 5/250 [00:27<21:31,  5.27s/it]

epoch 4 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 5 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  2%|▏         | 6/250 [00:32<21:07,  5.19s/it]

epoch 5 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 6 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  3%|▎         | 7/250 [00:37<20:49,  5.14s/it]

epoch 6 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 7 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  3%|▎         | 8/250 [00:42<20:31,  5.09s/it]

epoch 7 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 8 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  4%|▎         | 9/250 [00:47<20:18,  5.06s/it]

epoch 8 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 9 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  4%|▍         | 10/250 [00:52<20:05,  5.02s/it]

epoch 9 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 10 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  4%|▍         | 11/250 [00:57<20:14,  5.08s/it]

epoch 10 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 11 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  5%|▍         | 12/250 [01:03<20:39,  5.21s/it]

epoch 11 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 12 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  5%|▌         | 13/250 [01:08<20:42,  5.24s/it]

epoch 12 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 13 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  6%|▌         | 14/250 [01:13<20:32,  5.22s/it]

epoch 13 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 14 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  6%|▌         | 15/250 [01:18<20:20,  5.19s/it]

epoch 14 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 15 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  6%|▋         | 16/250 [01:23<20:16,  5.20s/it]

epoch 15 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 16 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  7%|▋         | 17/250 [01:29<20:04,  5.17s/it]

epoch 16 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 17 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  7%|▋         | 18/250 [01:34<19:50,  5.13s/it]

epoch 17 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 18 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  8%|▊         | 19/250 [01:39<19:42,  5.12s/it]

epoch 18 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 19 --> TRAIN loss: 1.95, TRAIN accuracy: 0.15


  8%|▊         | 20/250 [01:44<19:33,  5.10s/it]

epoch 19 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 20 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  8%|▊         | 21/250 [01:49<19:41,  5.16s/it]

epoch 20 --> TEST loss: 0.49, TEST accuracy: 0.14
epoch 21 --> TRAIN loss: 1.95, TRAIN accuracy: 0.14


  9%|▉         | 22/250 [01:54<19:34,  5.15s/it]

epoch 21 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 22 --> TRAIN loss: 1.95, TRAIN accuracy: 0.15


  9%|▉         | 23/250 [01:59<19:26,  5.14s/it]

epoch 22 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 23 --> TRAIN loss: 1.95, TRAIN accuracy: 0.15


 10%|▉         | 24/250 [02:04<19:15,  5.11s/it]

epoch 23 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 24 --> TRAIN loss: 1.95, TRAIN accuracy: 0.15


 10%|█         | 25/250 [02:09<19:14,  5.13s/it]

epoch 24 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 25 --> TRAIN loss: 1.94, TRAIN accuracy: 0.15


 10%|█         | 26/250 [02:15<19:15,  5.16s/it]

epoch 25 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 26 --> TRAIN loss: 1.94, TRAIN accuracy: 0.15


 11%|█         | 27/250 [02:20<19:08,  5.15s/it]

epoch 26 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 27 --> TRAIN loss: 1.94, TRAIN accuracy: 0.15


 11%|█         | 28/250 [02:25<19:08,  5.17s/it]

epoch 27 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 28 --> TRAIN loss: 1.94, TRAIN accuracy: 0.15


 12%|█▏        | 29/250 [02:30<18:57,  5.15s/it]

epoch 28 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 29 --> TRAIN loss: 1.94, TRAIN accuracy: 0.15


 12%|█▏        | 30/250 [02:35<18:53,  5.15s/it]

epoch 29 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 30 --> TRAIN loss: 1.94, TRAIN accuracy: 0.15


 12%|█▏        | 31/250 [02:41<19:12,  5.26s/it]

epoch 30 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 31 --> TRAIN loss: 1.94, TRAIN accuracy: 0.15


 13%|█▎        | 32/250 [02:47<19:47,  5.45s/it]

epoch 31 --> TEST loss: 0.49, TEST accuracy: 0.15
epoch 32 --> TRAIN loss: 1.94, TRAIN accuracy: 0.15


 13%|█▎        | 33/250 [02:53<20:15,  5.60s/it]

epoch 32 --> TEST loss: 0.49, TEST accuracy: 0.16
epoch 33 --> TRAIN loss: 1.94, TRAIN accuracy: 0.16


 14%|█▎        | 34/250 [02:58<20:10,  5.60s/it]

epoch 33 --> TEST loss: 0.49, TEST accuracy: 0.16
epoch 34 --> TRAIN loss: 1.94, TRAIN accuracy: 0.16


 14%|█▍        | 35/250 [03:03<19:32,  5.45s/it]

epoch 34 --> TEST loss: 0.49, TEST accuracy: 0.17
epoch 35 --> TRAIN loss: 1.94, TRAIN accuracy: 0.17


 14%|█▍        | 36/250 [03:09<19:23,  5.44s/it]

epoch 35 --> TEST loss: 0.49, TEST accuracy: 0.17
epoch 36 --> TRAIN loss: 1.94, TRAIN accuracy: 0.17


 15%|█▍        | 37/250 [03:14<19:11,  5.41s/it]

epoch 36 --> TEST loss: 0.49, TEST accuracy: 0.18
epoch 37 --> TRAIN loss: 1.93, TRAIN accuracy: 0.18


 15%|█▌        | 38/250 [03:19<18:44,  5.31s/it]

epoch 37 --> TEST loss: 0.49, TEST accuracy: 0.17
epoch 38 --> TRAIN loss: 1.93, TRAIN accuracy: 0.19


 16%|█▌        | 39/250 [03:24<18:23,  5.23s/it]

epoch 38 --> TEST loss: 0.49, TEST accuracy: 0.18
epoch 39 --> TRAIN loss: 1.93, TRAIN accuracy: 0.19


 16%|█▌        | 40/250 [03:29<18:13,  5.21s/it]

epoch 39 --> TEST loss: 0.49, TEST accuracy: 0.19
epoch 40 --> TRAIN loss: 1.92, TRAIN accuracy: 0.19


 16%|█▋        | 41/250 [03:35<18:05,  5.19s/it]

epoch 40 --> TEST loss: 0.49, TEST accuracy: 0.18
epoch 41 --> TRAIN loss: 1.92, TRAIN accuracy: 0.20


 17%|█▋        | 42/250 [03:40<17:49,  5.14s/it]

epoch 41 --> TEST loss: 0.48, TEST accuracy: 0.19
epoch 42 --> TRAIN loss: 1.91, TRAIN accuracy: 0.20


 17%|█▋        | 43/250 [03:45<17:44,  5.14s/it]

epoch 42 --> TEST loss: 0.48, TEST accuracy: 0.19
epoch 43 --> TRAIN loss: 1.91, TRAIN accuracy: 0.21


 18%|█▊        | 44/250 [03:50<17:37,  5.14s/it]

epoch 43 --> TEST loss: 0.48, TEST accuracy: 0.18
epoch 44 --> TRAIN loss: 1.91, TRAIN accuracy: 0.21


 18%|█▊        | 45/250 [03:55<17:32,  5.14s/it]

epoch 44 --> TEST loss: 0.48, TEST accuracy: 0.19
epoch 45 --> TRAIN loss: 1.90, TRAIN accuracy: 0.21


 18%|█▊        | 46/250 [04:00<17:30,  5.15s/it]

epoch 45 --> TEST loss: 0.48, TEST accuracy: 0.19
epoch 46 --> TRAIN loss: 1.89, TRAIN accuracy: 0.22


 19%|█▉        | 47/250 [04:05<17:21,  5.13s/it]

epoch 46 --> TEST loss: 0.48, TEST accuracy: 0.20
epoch 47 --> TRAIN loss: 1.89, TRAIN accuracy: 0.22


 19%|█▉        | 48/250 [04:10<17:10,  5.10s/it]

epoch 47 --> TEST loss: 0.48, TEST accuracy: 0.20
epoch 48 --> TRAIN loss: 1.89, TRAIN accuracy: 0.22


 20%|█▉        | 49/250 [04:15<17:05,  5.10s/it]

epoch 48 --> TEST loss: 0.48, TEST accuracy: 0.20
epoch 49 --> TRAIN loss: 1.88, TRAIN accuracy: 0.23


 20%|██        | 50/250 [04:21<17:01,  5.11s/it]

epoch 49 --> TEST loss: 0.48, TEST accuracy: 0.20
epoch 50 --> TRAIN loss: 1.88, TRAIN accuracy: 0.23


 20%|██        | 51/250 [04:26<16:53,  5.09s/it]

epoch 50 --> TEST loss: 0.48, TEST accuracy: 0.20
epoch 51 --> TRAIN loss: 1.87, TRAIN accuracy: 0.23


 21%|██        | 52/250 [04:31<16:55,  5.13s/it]

epoch 51 --> TEST loss: 0.48, TEST accuracy: 0.20
epoch 52 --> TRAIN loss: 1.87, TRAIN accuracy: 0.24


 21%|██        | 53/250 [04:36<17:05,  5.20s/it]

epoch 52 --> TEST loss: 0.48, TEST accuracy: 0.20
epoch 53 --> TRAIN loss: 1.86, TRAIN accuracy: 0.24


 22%|██▏       | 54/250 [04:42<17:09,  5.25s/it]

epoch 53 --> TEST loss: 0.48, TEST accuracy: 0.21
epoch 54 --> TRAIN loss: 1.86, TRAIN accuracy: 0.25


 22%|██▏       | 55/250 [04:47<16:56,  5.21s/it]

epoch 54 --> TEST loss: 0.48, TEST accuracy: 0.20
epoch 55 --> TRAIN loss: 1.85, TRAIN accuracy: 0.25


 22%|██▏       | 56/250 [04:52<16:49,  5.21s/it]

epoch 55 --> TEST loss: 0.48, TEST accuracy: 0.21
epoch 56 --> TRAIN loss: 1.84, TRAIN accuracy: 0.25


 23%|██▎       | 57/250 [04:57<16:41,  5.19s/it]

epoch 56 --> TEST loss: 0.48, TEST accuracy: 0.21
epoch 57 --> TRAIN loss: 1.84, TRAIN accuracy: 0.26


 23%|██▎       | 58/250 [05:02<16:36,  5.19s/it]

epoch 57 --> TEST loss: 0.48, TEST accuracy: 0.20
epoch 58 --> TRAIN loss: 1.84, TRAIN accuracy: 0.26


 24%|██▎       | 59/250 [05:07<16:34,  5.21s/it]

epoch 58 --> TEST loss: 0.48, TEST accuracy: 0.21
epoch 59 --> TRAIN loss: 1.83, TRAIN accuracy: 0.27


 24%|██▍       | 60/250 [05:13<16:26,  5.19s/it]

epoch 59 --> TEST loss: 0.48, TEST accuracy: 0.21
epoch 60 --> TRAIN loss: 1.83, TRAIN accuracy: 0.27


 24%|██▍       | 61/250 [05:18<16:13,  5.15s/it]

epoch 60 --> TEST loss: 0.48, TEST accuracy: 0.21
epoch 61 --> TRAIN loss: 1.82, TRAIN accuracy: 0.27


 25%|██▍       | 62/250 [05:23<16:10,  5.16s/it]

epoch 61 --> TEST loss: 0.48, TEST accuracy: 0.21
epoch 62 --> TRAIN loss: 1.81, TRAIN accuracy: 0.27


 25%|██▌       | 63/250 [05:28<16:03,  5.15s/it]

epoch 62 --> TEST loss: 0.48, TEST accuracy: 0.20
epoch 63 --> TRAIN loss: 1.80, TRAIN accuracy: 0.28


 26%|██▌       | 64/250 [05:33<16:00,  5.17s/it]

epoch 63 --> TEST loss: 0.49, TEST accuracy: 0.20
epoch 64 --> TRAIN loss: 1.80, TRAIN accuracy: 0.28


 26%|██▌       | 65/250 [05:38<15:51,  5.15s/it]

epoch 64 --> TEST loss: 0.49, TEST accuracy: 0.21
epoch 65 --> TRAIN loss: 1.80, TRAIN accuracy: 0.28


 26%|██▋       | 66/250 [05:43<15:43,  5.13s/it]

epoch 65 --> TEST loss: 0.48, TEST accuracy: 0.21
epoch 66 --> TRAIN loss: 1.79, TRAIN accuracy: 0.28


 27%|██▋       | 67/250 [05:48<15:34,  5.11s/it]

epoch 66 --> TEST loss: 0.48, TEST accuracy: 0.20
epoch 67 --> TRAIN loss: 1.78, TRAIN accuracy: 0.29


 27%|██▋       | 68/250 [05:54<15:33,  5.13s/it]

epoch 67 --> TEST loss: 0.49, TEST accuracy: 0.21
epoch 68 --> TRAIN loss: 1.78, TRAIN accuracy: 0.29


 28%|██▊       | 69/250 [05:59<15:25,  5.11s/it]

epoch 68 --> TEST loss: 0.48, TEST accuracy: 0.21
epoch 69 --> TRAIN loss: 1.77, TRAIN accuracy: 0.29


 28%|██▊       | 70/250 [06:04<15:15,  5.09s/it]

epoch 69 --> TEST loss: 0.48, TEST accuracy: 0.21
epoch 70 --> TRAIN loss: 1.77, TRAIN accuracy: 0.29


 28%|██▊       | 71/250 [06:09<15:06,  5.07s/it]

epoch 70 --> TEST loss: 0.49, TEST accuracy: 0.22
epoch 71 --> TRAIN loss: 1.76, TRAIN accuracy: 0.30


 29%|██▉       | 72/250 [06:14<15:01,  5.06s/it]

epoch 71 --> TEST loss: 0.48, TEST accuracy: 0.21
epoch 72 --> TRAIN loss: 1.75, TRAIN accuracy: 0.30


 29%|██▉       | 73/250 [06:19<14:52,  5.04s/it]

epoch 72 --> TEST loss: 0.49, TEST accuracy: 0.22
epoch 73 --> TRAIN loss: 1.75, TRAIN accuracy: 0.30


 30%|██▉       | 74/250 [06:24<14:46,  5.04s/it]

epoch 73 --> TEST loss: 0.48, TEST accuracy: 0.21
epoch 74 --> TRAIN loss: 1.74, TRAIN accuracy: 0.31


 30%|███       | 75/250 [06:29<14:47,  5.07s/it]

epoch 74 --> TEST loss: 0.49, TEST accuracy: 0.21
epoch 75 --> TRAIN loss: 1.74, TRAIN accuracy: 0.31


 30%|███       | 76/250 [06:34<14:40,  5.06s/it]

epoch 75 --> TEST loss: 0.49, TEST accuracy: 0.21
epoch 76 --> TRAIN loss: 1.73, TRAIN accuracy: 0.31


 31%|███       | 77/250 [06:39<14:40,  5.09s/it]

epoch 76 --> TEST loss: 0.49, TEST accuracy: 0.22
epoch 77 --> TRAIN loss: 1.72, TRAIN accuracy: 0.32


 31%|███       | 78/250 [06:44<14:35,  5.09s/it]

epoch 77 --> TEST loss: 0.49, TEST accuracy: 0.21
epoch 78 --> TRAIN loss: 1.72, TRAIN accuracy: 0.32


 32%|███▏      | 79/250 [06:49<14:34,  5.11s/it]

epoch 78 --> TEST loss: 0.49, TEST accuracy: 0.22
epoch 79 --> TRAIN loss: 1.72, TRAIN accuracy: 0.32


 32%|███▏      | 80/250 [06:55<14:33,  5.14s/it]

epoch 79 --> TEST loss: 0.49, TEST accuracy: 0.21
epoch 80 --> TRAIN loss: 1.71, TRAIN accuracy: 0.32


 32%|███▏      | 81/250 [07:00<14:33,  5.17s/it]

epoch 80 --> TEST loss: 0.49, TEST accuracy: 0.22
epoch 81 --> TRAIN loss: 1.70, TRAIN accuracy: 0.32


 33%|███▎      | 82/250 [07:05<14:28,  5.17s/it]

epoch 81 --> TEST loss: 0.49, TEST accuracy: 0.22
epoch 82 --> TRAIN loss: 1.70, TRAIN accuracy: 0.33


 33%|███▎      | 83/250 [07:10<14:26,  5.19s/it]

epoch 82 --> TEST loss: 0.50, TEST accuracy: 0.21
epoch 83 --> TRAIN loss: 1.70, TRAIN accuracy: 0.33


 34%|███▎      | 84/250 [07:15<14:18,  5.17s/it]

epoch 83 --> TEST loss: 0.49, TEST accuracy: 0.21
epoch 84 --> TRAIN loss: 1.68, TRAIN accuracy: 0.34


 34%|███▍      | 85/250 [07:21<14:14,  5.18s/it]

epoch 84 --> TEST loss: 0.50, TEST accuracy: 0.22
epoch 85 --> TRAIN loss: 1.68, TRAIN accuracy: 0.33


 34%|███▍      | 86/250 [07:26<14:07,  5.17s/it]

epoch 85 --> TEST loss: 0.50, TEST accuracy: 0.21
epoch 86 --> TRAIN loss: 1.68, TRAIN accuracy: 0.33


 35%|███▍      | 87/250 [07:31<14:01,  5.16s/it]

epoch 86 --> TEST loss: 0.49, TEST accuracy: 0.22
epoch 87 --> TRAIN loss: 1.67, TRAIN accuracy: 0.34


 35%|███▌      | 88/250 [07:36<13:55,  5.16s/it]

epoch 87 --> TEST loss: 0.50, TEST accuracy: 0.22
epoch 88 --> TRAIN loss: 1.67, TRAIN accuracy: 0.34


 36%|███▌      | 89/250 [07:41<13:46,  5.13s/it]

epoch 88 --> TEST loss: 0.50, TEST accuracy: 0.22
epoch 89 --> TRAIN loss: 1.67, TRAIN accuracy: 0.34


 36%|███▌      | 90/250 [07:46<13:39,  5.12s/it]

epoch 89 --> TEST loss: 0.50, TEST accuracy: 0.22
epoch 90 --> TRAIN loss: 1.66, TRAIN accuracy: 0.34


 36%|███▋      | 91/250 [07:52<14:04,  5.31s/it]

epoch 90 --> TEST loss: 0.49, TEST accuracy: 0.21
epoch 91 --> TRAIN loss: 1.65, TRAIN accuracy: 0.34


 37%|███▋      | 92/250 [07:57<13:58,  5.30s/it]

epoch 91 --> TEST loss: 0.50, TEST accuracy: 0.21
epoch 92 --> TRAIN loss: 1.64, TRAIN accuracy: 0.35


 37%|███▋      | 93/250 [08:02<13:47,  5.27s/it]

epoch 92 --> TEST loss: 0.50, TEST accuracy: 0.22
epoch 93 --> TRAIN loss: 1.64, TRAIN accuracy: 0.35


 38%|███▊      | 94/250 [08:07<13:33,  5.21s/it]

epoch 93 --> TEST loss: 0.50, TEST accuracy: 0.22
epoch 94 --> TRAIN loss: 1.64, TRAIN accuracy: 0.35


 38%|███▊      | 95/250 [08:13<13:23,  5.18s/it]

epoch 94 --> TEST loss: 0.50, TEST accuracy: 0.22
epoch 95 --> TRAIN loss: 1.63, TRAIN accuracy: 0.35


 38%|███▊      | 96/250 [08:18<13:17,  5.18s/it]

epoch 95 --> TEST loss: 0.50, TEST accuracy: 0.22
epoch 96 --> TRAIN loss: 1.63, TRAIN accuracy: 0.36


 39%|███▉      | 97/250 [08:23<13:04,  5.13s/it]

epoch 96 --> TEST loss: 0.50, TEST accuracy: 0.22
epoch 97 --> TRAIN loss: 1.63, TRAIN accuracy: 0.35


 39%|███▉      | 98/250 [08:28<13:03,  5.15s/it]

epoch 97 --> TEST loss: 0.50, TEST accuracy: 0.22
epoch 98 --> TRAIN loss: 1.62, TRAIN accuracy: 0.36


 40%|███▉      | 99/250 [08:33<13:00,  5.17s/it]

epoch 98 --> TEST loss: 0.50, TEST accuracy: 0.21
epoch 99 --> TRAIN loss: 1.62, TRAIN accuracy: 0.36


 40%|████      | 100/250 [08:38<12:58,  5.19s/it]

epoch 99 --> TEST loss: 0.50, TEST accuracy: 0.21
epoch 100 --> TRAIN loss: 1.60, TRAIN accuracy: 0.37


 40%|████      | 101/250 [08:44<12:53,  5.19s/it]

epoch 100 --> TEST loss: 0.50, TEST accuracy: 0.21
epoch 101 --> TRAIN loss: 1.61, TRAIN accuracy: 0.37


 41%|████      | 102/250 [08:49<12:44,  5.16s/it]

epoch 101 --> TEST loss: 0.51, TEST accuracy: 0.22
epoch 102 --> TRAIN loss: 1.60, TRAIN accuracy: 0.37


 41%|████      | 103/250 [08:54<12:39,  5.17s/it]

epoch 102 --> TEST loss: 0.51, TEST accuracy: 0.22
epoch 103 --> TRAIN loss: 1.59, TRAIN accuracy: 0.37


 42%|████▏     | 104/250 [08:59<12:37,  5.19s/it]

epoch 103 --> TEST loss: 0.51, TEST accuracy: 0.21
epoch 104 --> TRAIN loss: 1.59, TRAIN accuracy: 0.37


 42%|████▏     | 105/250 [09:04<12:27,  5.15s/it]

epoch 104 --> TEST loss: 0.51, TEST accuracy: 0.21
epoch 105 --> TRAIN loss: 1.59, TRAIN accuracy: 0.37


 42%|████▏     | 106/250 [09:09<12:19,  5.14s/it]

epoch 105 --> TEST loss: 0.51, TEST accuracy: 0.21
epoch 106 --> TRAIN loss: 1.58, TRAIN accuracy: 0.37


 43%|████▎     | 107/250 [09:14<12:10,  5.11s/it]

epoch 106 --> TEST loss: 0.52, TEST accuracy: 0.22
epoch 107 --> TRAIN loss: 1.57, TRAIN accuracy: 0.38


 43%|████▎     | 108/250 [09:19<12:03,  5.09s/it]

epoch 107 --> TEST loss: 0.51, TEST accuracy: 0.21
epoch 108 --> TRAIN loss: 1.58, TRAIN accuracy: 0.37


 44%|████▎     | 109/250 [09:25<12:02,  5.12s/it]

epoch 108 --> TEST loss: 0.51, TEST accuracy: 0.21
epoch 109 --> TRAIN loss: 1.58, TRAIN accuracy: 0.37


 44%|████▍     | 110/250 [09:30<12:04,  5.17s/it]

epoch 109 --> TEST loss: 0.51, TEST accuracy: 0.22
epoch 110 --> TRAIN loss: 1.58, TRAIN accuracy: 0.37


 44%|████▍     | 111/250 [09:35<11:58,  5.17s/it]

epoch 110 --> TEST loss: 0.51, TEST accuracy: 0.21
epoch 111 --> TRAIN loss: 1.56, TRAIN accuracy: 0.38


 45%|████▍     | 112/250 [09:40<11:49,  5.14s/it]

epoch 111 --> TEST loss: 0.51, TEST accuracy: 0.22
epoch 112 --> TRAIN loss: 1.56, TRAIN accuracy: 0.38


 45%|████▌     | 113/250 [09:45<11:40,  5.11s/it]

epoch 112 --> TEST loss: 0.51, TEST accuracy: 0.22
epoch 113 --> TRAIN loss: 1.56, TRAIN accuracy: 0.38


 46%|████▌     | 114/250 [09:50<11:33,  5.10s/it]

epoch 113 --> TEST loss: 0.52, TEST accuracy: 0.21
epoch 114 --> TRAIN loss: 1.55, TRAIN accuracy: 0.38


 46%|████▌     | 115/250 [09:55<11:28,  5.10s/it]

epoch 114 --> TEST loss: 0.52, TEST accuracy: 0.22
epoch 115 --> TRAIN loss: 1.54, TRAIN accuracy: 0.39


 46%|████▋     | 116/250 [10:00<11:23,  5.10s/it]

epoch 115 --> TEST loss: 0.52, TEST accuracy: 0.21
epoch 116 --> TRAIN loss: 1.55, TRAIN accuracy: 0.38


 47%|████▋     | 117/250 [10:06<11:18,  5.10s/it]

epoch 116 --> TEST loss: 0.52, TEST accuracy: 0.21
epoch 117 --> TRAIN loss: 1.54, TRAIN accuracy: 0.38


 47%|████▋     | 118/250 [10:11<11:17,  5.13s/it]

epoch 117 --> TEST loss: 0.52, TEST accuracy: 0.21
epoch 118 --> TRAIN loss: 1.53, TRAIN accuracy: 0.39


 48%|████▊     | 119/250 [10:16<11:11,  5.13s/it]

epoch 118 --> TEST loss: 0.52, TEST accuracy: 0.21
epoch 119 --> TRAIN loss: 1.54, TRAIN accuracy: 0.38


 48%|████▊     | 120/250 [10:21<11:05,  5.12s/it]

epoch 119 --> TEST loss: 0.53, TEST accuracy: 0.21
epoch 120 --> TRAIN loss: 1.53, TRAIN accuracy: 0.39


 48%|████▊     | 121/250 [10:27<11:22,  5.29s/it]

epoch 120 --> TEST loss: 0.53, TEST accuracy: 0.22
epoch 121 --> TRAIN loss: 1.53, TRAIN accuracy: 0.39


 49%|████▉     | 122/250 [10:32<11:31,  5.40s/it]

epoch 121 --> TEST loss: 0.53, TEST accuracy: 0.21
epoch 122 --> TRAIN loss: 1.52, TRAIN accuracy: 0.39


 49%|████▉     | 123/250 [10:38<11:35,  5.48s/it]

epoch 122 --> TEST loss: 0.54, TEST accuracy: 0.21
epoch 123 --> TRAIN loss: 1.52, TRAIN accuracy: 0.39


 50%|████▉     | 124/250 [10:44<11:37,  5.54s/it]

epoch 123 --> TEST loss: 0.53, TEST accuracy: 0.21
epoch 124 --> TRAIN loss: 1.51, TRAIN accuracy: 0.39


 50%|█████     | 125/250 [10:49<11:39,  5.59s/it]

epoch 124 --> TEST loss: 0.53, TEST accuracy: 0.21
epoch 125 --> TRAIN loss: 1.51, TRAIN accuracy: 0.40


 50%|█████     | 126/250 [10:55<11:49,  5.72s/it]

epoch 125 --> TEST loss: 0.53, TEST accuracy: 0.21
epoch 126 --> TRAIN loss: 1.51, TRAIN accuracy: 0.40


 51%|█████     | 127/250 [11:01<11:46,  5.74s/it]

epoch 126 --> TEST loss: 0.53, TEST accuracy: 0.21
epoch 127 --> TRAIN loss: 1.50, TRAIN accuracy: 0.40


 51%|█████     | 128/250 [11:07<11:38,  5.72s/it]

epoch 127 --> TEST loss: 0.53, TEST accuracy: 0.22
epoch 128 --> TRAIN loss: 1.52, TRAIN accuracy: 0.39


 52%|█████▏    | 129/250 [11:13<11:31,  5.71s/it]

epoch 128 --> TEST loss: 0.54, TEST accuracy: 0.21
epoch 129 --> TRAIN loss: 1.50, TRAIN accuracy: 0.40


 52%|█████▏    | 130/250 [11:19<11:35,  5.79s/it]

epoch 129 --> TEST loss: 0.54, TEST accuracy: 0.21
epoch 130 --> TRAIN loss: 1.50, TRAIN accuracy: 0.40


 52%|█████▏    | 131/250 [11:24<11:24,  5.75s/it]

epoch 130 --> TEST loss: 0.54, TEST accuracy: 0.21
epoch 131 --> TRAIN loss: 1.49, TRAIN accuracy: 0.40


 53%|█████▎    | 132/250 [11:30<11:15,  5.73s/it]

epoch 131 --> TEST loss: 0.54, TEST accuracy: 0.21
epoch 132 --> TRAIN loss: 1.49, TRAIN accuracy: 0.41


 53%|█████▎    | 133/250 [11:36<11:07,  5.71s/it]

epoch 132 --> TEST loss: 0.54, TEST accuracy: 0.21
epoch 133 --> TRAIN loss: 1.49, TRAIN accuracy: 0.40


 54%|█████▎    | 134/250 [11:41<11:01,  5.70s/it]

epoch 133 --> TEST loss: 0.55, TEST accuracy: 0.21
epoch 134 --> TRAIN loss: 1.48, TRAIN accuracy: 0.41


 54%|█████▍    | 135/250 [11:47<10:51,  5.67s/it]

epoch 134 --> TEST loss: 0.54, TEST accuracy: 0.21
epoch 135 --> TRAIN loss: 1.49, TRAIN accuracy: 0.40


 54%|█████▍    | 136/250 [11:52<10:43,  5.65s/it]

epoch 135 --> TEST loss: 0.54, TEST accuracy: 0.21
epoch 136 --> TRAIN loss: 1.48, TRAIN accuracy: 0.40


 55%|█████▍    | 137/250 [11:58<10:42,  5.69s/it]

epoch 136 --> TEST loss: 0.55, TEST accuracy: 0.21
epoch 137 --> TRAIN loss: 1.48, TRAIN accuracy: 0.41


 55%|█████▌    | 138/250 [12:04<10:50,  5.81s/it]

epoch 137 --> TEST loss: 0.54, TEST accuracy: 0.21
epoch 138 --> TRAIN loss: 1.48, TRAIN accuracy: 0.40


 56%|█████▌    | 139/250 [12:10<10:39,  5.76s/it]

epoch 138 --> TEST loss: 0.55, TEST accuracy: 0.21
epoch 139 --> TRAIN loss: 1.47, TRAIN accuracy: 0.41


 56%|█████▌    | 140/250 [12:16<10:29,  5.73s/it]

epoch 139 --> TEST loss: 0.57, TEST accuracy: 0.21
epoch 140 --> TRAIN loss: 1.48, TRAIN accuracy: 0.41


 56%|█████▋    | 141/250 [12:21<10:19,  5.68s/it]

epoch 140 --> TEST loss: 0.56, TEST accuracy: 0.21
epoch 141 --> TRAIN loss: 1.47, TRAIN accuracy: 0.41


 57%|█████▋    | 142/250 [12:27<10:12,  5.67s/it]

epoch 141 --> TEST loss: 0.55, TEST accuracy: 0.21
epoch 142 --> TRAIN loss: 1.47, TRAIN accuracy: 0.41


 57%|█████▋    | 143/250 [12:33<10:12,  5.73s/it]

epoch 142 --> TEST loss: 0.54, TEST accuracy: 0.21
epoch 143 --> TRAIN loss: 1.46, TRAIN accuracy: 0.42


 58%|█████▊    | 144/250 [12:39<10:11,  5.77s/it]

epoch 143 --> TEST loss: 0.56, TEST accuracy: 0.21
epoch 144 --> TRAIN loss: 1.46, TRAIN accuracy: 0.42


 58%|█████▊    | 145/250 [12:44<10:05,  5.76s/it]

epoch 144 --> TEST loss: 0.56, TEST accuracy: 0.21
epoch 145 --> TRAIN loss: 1.45, TRAIN accuracy: 0.41


 58%|█████▊    | 146/250 [12:50<10:01,  5.79s/it]

epoch 145 --> TEST loss: 0.56, TEST accuracy: 0.21
epoch 146 --> TRAIN loss: 1.45, TRAIN accuracy: 0.42


 59%|█████▉    | 147/250 [12:56<10:01,  5.84s/it]

epoch 146 --> TEST loss: 0.55, TEST accuracy: 0.21
epoch 147 --> TRAIN loss: 1.46, TRAIN accuracy: 0.41


 59%|█████▉    | 148/250 [13:02<10:10,  5.99s/it]

epoch 147 --> TEST loss: 0.55, TEST accuracy: 0.21
epoch 148 --> TRAIN loss: 1.45, TRAIN accuracy: 0.42


 60%|█████▉    | 149/250 [13:09<10:07,  6.01s/it]

epoch 148 --> TEST loss: 0.56, TEST accuracy: 0.21
epoch 149 --> TRAIN loss: 1.45, TRAIN accuracy: 0.42


 60%|██████    | 150/250 [13:14<09:48,  5.89s/it]

epoch 149 --> TEST loss: 0.57, TEST accuracy: 0.21
epoch 150 --> TRAIN loss: 1.44, TRAIN accuracy: 0.42


 60%|██████    | 151/250 [13:20<09:52,  5.99s/it]

epoch 150 --> TEST loss: 0.56, TEST accuracy: 0.21
epoch 151 --> TRAIN loss: 1.45, TRAIN accuracy: 0.42


 61%|██████    | 152/250 [13:26<09:51,  6.03s/it]

epoch 151 --> TEST loss: 0.57, TEST accuracy: 0.21
epoch 152 --> TRAIN loss: 1.45, TRAIN accuracy: 0.42


 61%|██████    | 153/250 [13:32<09:44,  6.03s/it]

epoch 152 --> TEST loss: 0.56, TEST accuracy: 0.21
epoch 153 --> TRAIN loss: 1.43, TRAIN accuracy: 0.43


 62%|██████▏   | 154/250 [13:38<09:30,  5.94s/it]

epoch 153 --> TEST loss: 0.57, TEST accuracy: 0.21
epoch 154 --> TRAIN loss: 1.43, TRAIN accuracy: 0.42


 62%|██████▏   | 155/250 [13:44<09:19,  5.88s/it]

epoch 154 --> TEST loss: 0.56, TEST accuracy: 0.21
epoch 155 --> TRAIN loss: 1.44, TRAIN accuracy: 0.42


 62%|██████▏   | 156/250 [13:50<09:15,  5.91s/it]

epoch 155 --> TEST loss: 0.55, TEST accuracy: 0.21
epoch 156 --> TRAIN loss: 1.43, TRAIN accuracy: 0.42


 63%|██████▎   | 157/250 [13:56<09:05,  5.86s/it]

epoch 156 --> TEST loss: 0.57, TEST accuracy: 0.21
epoch 157 --> TRAIN loss: 1.42, TRAIN accuracy: 0.43


 63%|██████▎   | 158/250 [14:01<08:57,  5.84s/it]

epoch 157 --> TEST loss: 0.57, TEST accuracy: 0.21
epoch 158 --> TRAIN loss: 1.43, TRAIN accuracy: 0.42


 64%|██████▎   | 159/250 [14:08<09:01,  5.95s/it]

epoch 158 --> TEST loss: 0.57, TEST accuracy: 0.21
epoch 159 --> TRAIN loss: 1.43, TRAIN accuracy: 0.42


 64%|██████▍   | 160/250 [14:14<08:54,  5.94s/it]

epoch 159 --> TEST loss: 0.56, TEST accuracy: 0.21
epoch 160 --> TRAIN loss: 1.43, TRAIN accuracy: 0.43


 64%|██████▍   | 161/250 [14:19<08:42,  5.87s/it]

epoch 160 --> TEST loss: 0.58, TEST accuracy: 0.21
epoch 161 --> TRAIN loss: 1.42, TRAIN accuracy: 0.43


 65%|██████▍   | 162/250 [14:25<08:29,  5.79s/it]

epoch 161 --> TEST loss: 0.58, TEST accuracy: 0.21
epoch 162 --> TRAIN loss: 1.42, TRAIN accuracy: 0.43


 65%|██████▌   | 163/250 [14:31<08:20,  5.75s/it]

epoch 162 --> TEST loss: 0.56, TEST accuracy: 0.21
epoch 163 --> TRAIN loss: 1.41, TRAIN accuracy: 0.43


 66%|██████▌   | 164/250 [14:36<08:18,  5.80s/it]

epoch 163 --> TEST loss: 0.58, TEST accuracy: 0.21
epoch 164 --> TRAIN loss: 1.42, TRAIN accuracy: 0.43


 66%|██████▌   | 165/250 [14:42<08:12,  5.79s/it]

epoch 164 --> TEST loss: 0.57, TEST accuracy: 0.21
epoch 165 --> TRAIN loss: 1.41, TRAIN accuracy: 0.43


 66%|██████▋   | 166/250 [14:48<08:06,  5.80s/it]

epoch 165 --> TEST loss: 0.57, TEST accuracy: 0.21
epoch 166 --> TRAIN loss: 1.41, TRAIN accuracy: 0.43


 67%|██████▋   | 167/250 [14:53<07:44,  5.59s/it]

epoch 166 --> TEST loss: 0.58, TEST accuracy: 0.21
epoch 167 --> TRAIN loss: 1.40, TRAIN accuracy: 0.43


 67%|██████▋   | 168/250 [14:58<07:25,  5.43s/it]

epoch 167 --> TEST loss: 0.58, TEST accuracy: 0.21
epoch 168 --> TRAIN loss: 1.41, TRAIN accuracy: 0.43


 68%|██████▊   | 169/250 [15:03<07:14,  5.36s/it]

epoch 168 --> TEST loss: 0.57, TEST accuracy: 0.21
epoch 169 --> TRAIN loss: 1.41, TRAIN accuracy: 0.43


 68%|██████▊   | 170/250 [15:09<07:03,  5.29s/it]

epoch 169 --> TEST loss: 0.56, TEST accuracy: 0.21


 68%|██████▊   | 170/250 [15:13<07:09,  5.37s/it]


KeyboardInterrupt: 