In [1]:
import os
import torch
import skimage
import pywt
import scipy.io
import scipy.signal
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from scipy import stats
from einops import reduce, rearrange, repeat
from npeet import entropy_estimators as ee
from torch.optim.lr_scheduler import StepLR
from scipy.fft import rfft, rfftfreq, ifft
from einops import rearrange
from torch_geometric.data import InMemoryDataset, Data, DataLoader
from Electrodes import Electrodes
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [2]:
class DEAPDatasetEEGFeatures(InMemoryDataset):
  def __init__(self, root, raw_dir, processed_dir, feature='de',target='participant_id', transform=None, pre_transform=None,include_edge_attr = True, undirected_graphs = True, add_global_connections=True, participant_from=1, participant_to=32, n_videos=40):
      self._raw_dir = raw_dir
      self._processed_dir = processed_dir
      self.participant_from = participant_from
      self.participant_to = participant_to
      self.n_videos = n_videos
      self.feature = feature
      self.target = target
      # Whether or not to include edge_attr in the dataset
      self.include_edge_attr = include_edge_attr
      # If true there will be 1024 links as opposed to 528
      self.undirected_graphs = undirected_graphs
      # Instantiate class to handle electrode positions
      print('Using global connections' if add_global_connections else 'Not using global connections')
      self.electrodes = Electrodes(add_global_connections, expand_3d = False)
      super(DEAPDatasetEEGFeatures, self).__init__(root, transform, pre_transform)
      self.data, self.slices = torch.load(self.processed_paths[0])
      
  @property
  def raw_dir(self):
      return f'{self.root}/{self._raw_dir}'

  @property
  def processed_dir(self):
      return f'{self.root}/{self._processed_dir}'

  @property
  def raw_file_names(self):
      raw_names = [f for f in os.listdir(self.raw_dir)]
      raw_names.sort()
      return raw_names

  @property
  def processed_file_names(self):
      if not os.path.exists(self.processed_dir):
        os.makedirs(self.processed_dir)
      file_name = f'{self.participant_from}-{self.participant_to}' if self.participant_from is not self.participant_to else f'{self.participant_from}'
      return [f'deap_processed_graph.{file_name}_{self.feature}_{self.target}.dataset']

  def process(self):
        # Number of nodes per graph
        n_nodes = len(self.electrodes.positions_3d)
        

        if self.undirected_graphs:
            source_nodes, target_nodes = np.repeat(np.arange(0,n_nodes),n_nodes), np.tile(np.arange(0,n_nodes),n_nodes)
        else:
            source_nodes, target_nodes = np.tril_indices(n_nodes,n_nodes)
        
        edge_attr = self.electrodes.adjacency_matrix[source_nodes,target_nodes]
        
        # Remove zero weight links
        mask = np.ma.masked_not_equal(edge_attr, 0).mask
        edge_attr,source_nodes,target_nodes = edge_attr[mask], source_nodes[mask], target_nodes[mask]

        edge_attr, edge_index = torch.FloatTensor(edge_attr), torch.tensor([source_nodes,target_nodes], dtype=torch.long)
        
        # Expand edge_index and edge_attr to match windows
        e_edge_index = edge_index.clone()
        e_edge_attr = edge_attr.clone()
        
        number_of_graphs = 4
        for i in range(number_of_graphs-1):
            a = edge_index + e_edge_index.max() + 1
            e_edge_index = torch.cat([e_edge_index,a],dim=1)
            e_edge_attr = torch.cat([e_edge_attr,edge_attr],dim=0)

        print(f'Number of graphs per video: {number_of_graphs}')
        # List of graphs that will be written to file
        data_list = []
        pbar = tqdm(range(self.participant_from,self.participant_to+1))
        for participant_id in pbar:
            raw_name = [e for e in self.raw_file_names if str(participant_id).zfill(2) in e][0]
            pbar.set_description(raw_name)
            # Load raw file as np array
            participant_data = scipy.io.loadmat(f'{self.raw_dir}/{raw_name}')
            signal_data = torch.FloatTensor(remove_baseline_mean(participant_data['data'][:,:32,:]))
#             signal_data = torch.FloatTensor()
            processed = []
            for i, video in enumerate(signal_data[:self.n_videos,:,:]):
                if self.feature == 'wav':
                    node_features = process_video_wavelet(video)
                else:
                    node_features = process_video(video, feature=self.feature)
                
                if self.target == 'emotion_labels':
                    target = [participant_data['labels'][i]]
                if self.target == 'participant_id':
                    target = participant_id-1
                elif self.target == 'video_id':
                    pass
                else:
                    raise 'Invalid target'
                data = Data(x=torch.FloatTensor(node_features),edge_attr=e_edge_attr,edge_index=e_edge_index, y=torch.LongTensor([target])) if self.include_edge_attr else Data(x=torch.FloatTensor(node_features), edge_index=e_edge_index, y=torch.LongTensor([target]))
                data_list.append(data) 
               
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [3]:
def calculate_de(window):
    return ee.entropy(window.reshape(-1,1), k=2)
# Input: Video with shape (32,7680)
# Output: Graph node features with shape (5*32, 59) -> 5 graphs with 32 nodes each with 59 features each
def process_video(video, feature='psd'):
    # Transform to frequency domain
    fft_vals = np.fft.rfft(video, axis=-1)
     # Get frequencies for amplitudes in Hz
    samplingFrequency = 128
    fft_freq = np.fft.rfftfreq(video.shape[-1], 1.0/samplingFrequency)
    # Delta, Theta, Alpha, Beta, Gamma
    bands = [(0,4),(4,8),(8,12),(12,30),(30,45)]
    
    band_mask = np.array([np.logical_or(fft_freq < f, fft_freq > t) for f,t in bands])
    band_mask = repeat(band_mask,'a b -> a c b', c=32)
    band_data = np.array(fft_vals)
    band_data = repeat(band_data,'a b -> c a b', c=5)
     
    band_data[band_mask] = 0
    
    band_data = np.fft.irfft(band_data)

    windows = skimage.util.view_as_windows(band_data, (5,32,128), step=128).squeeze()
    # (5, 32, 60, 128)
    windows = rearrange(windows, 'a b c d -> b c a d')
    
    if feature == 'psd':
        features = scipy.signal.periodogram(windows)[1]
        features = np.mean(features, axis=-1)
    elif feature == 'de':
        features = np.apply_along_axis(calculate_de, -1, windows)

    
    features = rearrange(features, 'a b c -> (a b) c')
    features = torch.FloatTensor(features)

    return features

In [4]:
def remove_baseline_mean(signal_data):
    # Take first three senconds of data
    signal_baseline = np.array(signal_data[:,:,:128*3]).reshape(40,32,128,-1)
    # Mean of three senconds of baseline will be deducted from all windows
    signal_noise = np.mean(signal_baseline,axis=-1)
    # Expand mask
    signal_noise = repeat(signal_noise,'a b c -> a b (d c)',d=60)
    return signal_data[:,:,128*3:] - signal_noise

In [5]:
def process_video_wavelet(video, feature='energy', time_domain=False):
    band_widths = [32,16,8,4]
    features = []
    for i in range(5):
        if i == 0:
            # Highest frequencies (64-128Hz) are not used
            cA, cD = pywt.dwt(video.numpy(), 'db4')
        else:
            cA, cD = pywt.dwt(cA, 'db4')
            
            cA_windows = skimage.util.view_as_windows(cA, (32,band_widths[i-1]*2), step=band_widths[i-1]).squeeze()
            cA_windows = np.transpose(cA_windows[:59,:,:],(1,0,2))
            if feature == 'energy':
                cA_windows = np.square(cA_windows)
                cA_windows = np.sum(cA_windows, axis=-1)
                features.append(cA_windows)
                
    if time_domain:
        features = np.transpose(features,(2,1,0))
    features = rearrange(features, 'a b c -> (a b) c')
    features = torch.FloatTensor(features)
    
    # Normalization
    m = features.mean(0, keepdim=True)
    s = features.std(0, unbiased=False, keepdim=True)
    features -= m
    features /= s
    return features

In [6]:
# Constants used to define data paths
ROOT_DIR = './'
RAW_DIR = 'data/matlabPREPROCESSED'
PROCESSED_DIR = 'data/graphProcessedData'

dataset = DEAPDatasetEEGFeatures(root= ROOT_DIR, raw_dir= RAW_DIR, processed_dir= PROCESSED_DIR, feature='wav',participant_to=32)
# dataset = dataset.shuffle()

Using global connections


In [7]:
# 880 used for training, 220 validation and 180 testing
# test_participant = 1
# 
splt_idx = 1100

dataset = dataset.shuffle()

# 85% used for train/val
train_dataset = dataset[:splt_idx]
test_dataset = dataset[splt_idx:]

len(train_dataset),len(test_dataset)

(1100, 180)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
print(device)

cuda


In [14]:
from torch_geometric.nn import GCN2Conv, GCNConv, global_max_pool as gmp
class Model(torch.nn.Module):
    def __init__(self, in_channels,n_graphs, hidden_channels=128, n_classes = 32):
        super(Model, self).__init__()
        
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        
        self.gconv1 = GCNConv(in_channels,hidden_channels*2)
        self.gconv2 = GCNConv(hidden_channels*2,hidden_channels)
        
#         self.gconv3 = GCNConv(in_channels,hidden_channels)
        
        # self.rnn = torch.nn.GRU(hidden_channels, rnn_hidden_dim, 2,dropout=0.2, batch_first=True)
        self.cnn1 = torch.nn.Conv1d(n_graphs, 1, kernel_size=1, stride=1)
        
        self.lin1 = torch.nn.Linear(32*hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, n_classes)

        self.softmax = nn.Softmax(dim=-1)

        
    def forward(self, batch):
        bs = len(torch.unique(batch.batch))
        x, edge_index, edge_attr = batch.x, batch.edge_index, batch.edge_attr
#         print(x.shape)
        x = self.gconv1(x, edge_index, edge_attr)
        x = self.gconv2(x, edge_index, edge_attr)
        x = F.dropout(x, p=0.4, training=self.training)
        x = x.relu()
#         print(x.shape)
        x = rearrange(x, '(bs g e) f -> (bs e) g f', bs=bs, e=32)
#         print(x.shape)
        x = self.cnn1(x).squeeze()
        x = x.tanh()
        x = rearrange(x, '(bs e) f -> bs (e f)', bs=bs)
#         print(x.shape)
#         x = torch.sum(x, dim=1)
        
        
        x = F.dropout(x, p=0.4, training=self.training)
#         print(x.shape)
        x = self.lin1(x)
        x = x.relu()
        x = self.lin2(x)
#         print(x.shape)
#         x = x.view(-1)
       
        x = self.softmax(x)
        return x

        

In [15]:
# %%timeit

model = Model(train_dataset[0].x.shape[1],train_dataset[0].x.shape[0]//32).to(device)  
pytorch_total_params = sum(p.numel() for p in model.parameters())
print(f'Model parameter count: {pytorch_total_params}')

# model = model.to(devic)
# optimizer = torch.optim.Adadelta(model.parameters(), lr=.1, rho=0.9, eps=1e-06, weight_decay=1e-5)
# optimizer = torch.optim.SGD(model.parameters(),lr=1e-2, weight_decay=0)
# optimizer = torch.optim.Adam(model.parameters(),lr=1e-4, weight_decay=1e-2)
# optimizer = torch.optim.Adam(model.parameters())
# optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-5)
# optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001, weight_decay=5e-4)
optimizer = torch.optim.Adagrad(model.parameters(), lr=5e-3, lr_decay=1e-4, weight_decay=0)

# Instantiate optimizer
# scheduler = StepLR(optimizer, step_size=20, gamma=0.7)

# criterion = nn.MSELoss()
criterion = nn.CrossEntropyLoss()



def train(loader, target = 0):
    model.train()
    losses = []
    right = 0
    tot = 0
    for batch in loader:
        optimizer.zero_grad()
        batch = batch.to(device)
        out = model(batch)
        loss = criterion(out,batch.y)
        loss.backward()
        losses.append(loss.item())
        optimizer.step()
        pred = torch.argmax(out,-1)
        right += torch.sum((pred == batch.y).detach().cpu())
        tot += batch.y.shape[0]
        
    return np.array(losses).mean(), right/tot

def test(loader,verbose=False, target = 0):
    model.eval()
    losses = []
    right = 0
    tot = 0
    for batch in loader:
        batch = batch.to(device)
        # y = batch.y[:,target]
        out = model(batch)
        pred = torch.argmax(out,-1)
        if verbose:
            print(pred,batch.y)
        loss = criterion(out,batch.y)
        losses.append(loss.item())
        
        right += torch.sum((pred == batch.y).detach().cpu())
        tot += batch.y.shape[0]
    return np.array(losses).mean(), right/tot

best_val_loss = np.inf
esp = 0
MAX_ESP = 50

BS = 64

target = 0 # Valence-Arousal-Dominance-Liking

splt_idx = 1000
train_data, val_data = torch.utils.data.random_split(train_dataset, [splt_idx, len(train_dataset)-splt_idx])

train_loader = DataLoader(train_data, batch_size=BS, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BS)
writer = SummaryWriter()
for epoch in range(1, 10000):    

    
    
    # Training and validation
    train_loss, train_acc = train(train_loader, target = target)
    val_loss, val_acc = test(val_loader , target = target)
    print(f'Epoch {epoch};t loss: {train_loss:.5f} ;t acc: {train_acc:.2f} ;v loss: {val_loss:.5f} ;v acc: {val_acc:.2f}')

    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Loss/test', val_loss, epoch)
    writer.add_scalar('Accuracy/train', train_acc, epoch)
    writer.add_scalar('Accuracy/test', val_acc, epoch)
    # Early stopping and checkpoint
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        esp = 0
        torch.save(model.state_dict(),'./best_params') 
    else:
        esp += 1
        if esp >= MAX_ESP:
            break
            
    if epoch % 20 == 0:
        test_loader = DataLoader(test_dataset, batch_size=2)
        loss, acc = test(test_loader, True)
        print(f'Test loss: {loss} ; Test acc: {acc}')
        
    


print('Finished training')

Model parameter count: 576805
Epoch 1;t loss: 3.45937 ;t acc: 0.05 ;v loss: 3.45612 ;v acc: 0.04
Epoch 2;t loss: 3.38985 ;t acc: 0.13 ;v loss: 3.38658 ;v acc: 0.13
Epoch 3;t loss: 3.29876 ;t acc: 0.24 ;v loss: 3.26137 ;v acc: 0.28
Epoch 4;t loss: 3.20543 ;t acc: 0.35 ;v loss: 3.14796 ;v acc: 0.38
Epoch 5;t loss: 3.12795 ;t acc: 0.43 ;v loss: 3.09179 ;v acc: 0.43
Epoch 6;t loss: 3.08142 ;t acc: 0.48 ;v loss: 3.07192 ;v acc: 0.44
Epoch 7;t loss: 3.05732 ;t acc: 0.49 ;v loss: 3.06981 ;v acc: 0.44
Epoch 8;t loss: 3.04624 ;t acc: 0.49 ;v loss: 3.06427 ;v acc: 0.45
Epoch 9;t loss: 3.04231 ;t acc: 0.50 ;v loss: 3.05344 ;v acc: 0.45
Epoch 10;t loss: 3.03507 ;t acc: 0.50 ;v loss: 3.06039 ;v acc: 0.45
Epoch 11;t loss: 3.03563 ;t acc: 0.50 ;v loss: 3.05980 ;v acc: 0.45
Epoch 12;t loss: 3.02783 ;t acc: 0.50 ;v loss: 3.05181 ;v acc: 0.46
Epoch 13;t loss: 3.02553 ;t acc: 0.50 ;v loss: 3.04355 ;v acc: 0.46
Epoch 14;t loss: 3.01835 ;t acc: 0.50 ;v loss: 3.04624 ;v acc: 0.46
Epoch 15;t loss: 3.02380 ;t

Epoch 32;t loss: 2.89741 ;t acc: 0.63 ;v loss: 2.89441 ;v acc: 0.61
Epoch 33;t loss: 2.89608 ;t acc: 0.63 ;v loss: 2.89771 ;v acc: 0.61
Epoch 34;t loss: 2.88564 ;t acc: 0.65 ;v loss: 2.88243 ;v acc: 0.63
Epoch 35;t loss: 2.86489 ;t acc: 0.67 ;v loss: 2.85548 ;v acc: 0.65
Epoch 36;t loss: 2.86213 ;t acc: 0.67 ;v loss: 2.83758 ;v acc: 0.69
Epoch 37;t loss: 2.83988 ;t acc: 0.70 ;v loss: 2.83504 ;v acc: 0.69
Epoch 38;t loss: 2.83726 ;t acc: 0.69 ;v loss: 2.82558 ;v acc: 0.69
Epoch 39;t loss: 2.83605 ;t acc: 0.70 ;v loss: 2.82800 ;v acc: 0.69
Epoch 40;t loss: 2.83897 ;t acc: 0.70 ;v loss: 2.82469 ;v acc: 0.69
tensor([4, 8], device='cuda:0') tensor([10,  8], device='cuda:0')
tensor([14,  8], device='cuda:0') tensor([14, 12], device='cuda:0')
tensor([11,  4], device='cuda:0') tensor([15, 10], device='cuda:0')
tensor([20, 19], device='cuda:0') tensor([20, 29], device='cuda:0')
tensor([17,  2], device='cuda:0') tensor([17,  2], device='cuda:0')
tensor([1, 3], device='cuda:0') tensor([29, 18], d

tensor([3, 2], device='cuda:0') tensor([18,  2], device='cuda:0')
tensor([25,  3], device='cuda:0') tensor([25, 16], device='cuda:0')
tensor([27,  6], device='cuda:0') tensor([27,  6], device='cuda:0')
tensor([ 0, 31], device='cuda:0') tensor([ 0, 10], device='cuda:0')
tensor([ 4, 31], device='cuda:0') tensor([10, 31], device='cuda:0')
tensor([4, 6], device='cuda:0') tensor([10, 26], device='cuda:0')
tensor([28,  1], device='cuda:0') tensor([28,  1], device='cuda:0')
tensor([ 4, 11], device='cuda:0') tensor([ 4, 15], device='cuda:0')
tensor([3, 1], device='cuda:0') tensor([16, 29], device='cuda:0')
tensor([ 3, 11], device='cuda:0') tensor([ 3, 11], device='cuda:0')
tensor([ 6, 21], device='cuda:0') tensor([ 6, 26], device='cuda:0')
tensor([24, 30], device='cuda:0') tensor([24, 30], device='cuda:0')
tensor([17,  3], device='cuda:0') tensor([17,  3], device='cuda:0')
tensor([14, 13], device='cuda:0') tensor([14, 13], device='cuda:0')
tensor([27,  4], device='cuda:0') tensor([27, 10], dev

Epoch 81;t loss: 2.81728 ;t acc: 0.70 ;v loss: 2.81858 ;v acc: 0.69
Epoch 82;t loss: 2.81522 ;t acc: 0.70 ;v loss: 2.81906 ;v acc: 0.69
Epoch 83;t loss: 2.81952 ;t acc: 0.70 ;v loss: 2.81884 ;v acc: 0.69
Epoch 84;t loss: 2.81528 ;t acc: 0.70 ;v loss: 2.81865 ;v acc: 0.69
Epoch 85;t loss: 2.81637 ;t acc: 0.70 ;v loss: 2.81884 ;v acc: 0.69
Epoch 86;t loss: 2.81581 ;t acc: 0.70 ;v loss: 2.81885 ;v acc: 0.69
Epoch 87;t loss: 2.81428 ;t acc: 0.70 ;v loss: 2.81885 ;v acc: 0.69
Epoch 88;t loss: 2.81579 ;t acc: 0.70 ;v loss: 2.81865 ;v acc: 0.69
Epoch 89;t loss: 2.81465 ;t acc: 0.70 ;v loss: 2.81866 ;v acc: 0.69
Epoch 90;t loss: 2.82019 ;t acc: 0.70 ;v loss: 2.81858 ;v acc: 0.69
Epoch 91;t loss: 2.81913 ;t acc: 0.70 ;v loss: 2.81858 ;v acc: 0.69
Epoch 92;t loss: 2.81672 ;t acc: 0.70 ;v loss: 2.81867 ;v acc: 0.69
Epoch 93;t loss: 2.81712 ;t acc: 0.70 ;v loss: 2.81874 ;v acc: 0.69
Epoch 94;t loss: 2.81465 ;t acc: 0.70 ;v loss: 2.81869 ;v acc: 0.69
Epoch 95;t loss: 2.81626 ;t acc: 0.70 ;v loss: 2

Epoch 112;t loss: 2.78880 ;t acc: 0.74 ;v loss: 2.79606 ;v acc: 0.72
Epoch 113;t loss: 2.79157 ;t acc: 0.73 ;v loss: 2.79555 ;v acc: 0.72
Epoch 114;t loss: 2.78906 ;t acc: 0.74 ;v loss: 2.79563 ;v acc: 0.72
Epoch 115;t loss: 2.78508 ;t acc: 0.74 ;v loss: 2.79013 ;v acc: 0.72
Epoch 116;t loss: 2.78189 ;t acc: 0.75 ;v loss: 2.76102 ;v acc: 0.76
Epoch 117;t loss: 2.77325 ;t acc: 0.76 ;v loss: 2.75991 ;v acc: 0.76
Epoch 118;t loss: 2.77138 ;t acc: 0.76 ;v loss: 2.76288 ;v acc: 0.75
Epoch 119;t loss: 2.77039 ;t acc: 0.76 ;v loss: 2.75986 ;v acc: 0.76
Epoch 120;t loss: 2.77157 ;t acc: 0.76 ;v loss: 2.75761 ;v acc: 0.76
tensor([4, 8], device='cuda:0') tensor([10,  8], device='cuda:0')
tensor([14, 20], device='cuda:0') tensor([14, 12], device='cuda:0')
tensor([11,  4], device='cuda:0') tensor([15, 10], device='cuda:0')
tensor([20, 19], device='cuda:0') tensor([20, 29], device='cuda:0')
tensor([17,  2], device='cuda:0') tensor([17,  2], device='cuda:0')
tensor([1, 3], device='cuda:0') tensor([2

tensor([ 4, 31], device='cuda:0') tensor([10, 31], device='cuda:0')
tensor([ 4, 26], device='cuda:0') tensor([10, 26], device='cuda:0')
tensor([28,  1], device='cuda:0') tensor([28,  1], device='cuda:0')
tensor([ 4, 28], device='cuda:0') tensor([ 4, 15], device='cuda:0')
tensor([22,  1], device='cuda:0') tensor([16, 29], device='cuda:0')
tensor([ 3, 11], device='cuda:0') tensor([ 3, 11], device='cuda:0')
tensor([ 6, 26], device='cuda:0') tensor([ 6, 26], device='cuda:0')
tensor([24, 30], device='cuda:0') tensor([24, 30], device='cuda:0')
tensor([17,  3], device='cuda:0') tensor([17,  3], device='cuda:0')
tensor([14, 13], device='cuda:0') tensor([14, 13], device='cuda:0')
tensor([27,  4], device='cuda:0') tensor([27, 10], device='cuda:0')
tensor([ 1, 22], device='cuda:0') tensor([ 1, 16], device='cuda:0')
tensor([22, 23], device='cuda:0') tensor([ 5, 23], device='cuda:0')
tensor([ 9, 26], device='cuda:0') tensor([ 9, 26], device='cuda:0')
tensor([31, 11], device='cuda:0') tensor([31, 11

Epoch 161;t loss: 2.72377 ;t acc: 0.80 ;v loss: 2.72254 ;v acc: 0.80
Epoch 162;t loss: 2.72629 ;t acc: 0.80 ;v loss: 2.72217 ;v acc: 0.80
Epoch 163;t loss: 2.72372 ;t acc: 0.80 ;v loss: 2.72278 ;v acc: 0.80
Epoch 164;t loss: 2.72508 ;t acc: 0.80 ;v loss: 2.72438 ;v acc: 0.80
Epoch 165;t loss: 2.72200 ;t acc: 0.80 ;v loss: 2.72307 ;v acc: 0.80
Epoch 166;t loss: 2.72179 ;t acc: 0.80 ;v loss: 2.72289 ;v acc: 0.80
Epoch 167;t loss: 2.72368 ;t acc: 0.80 ;v loss: 2.72285 ;v acc: 0.80
Epoch 168;t loss: 2.72374 ;t acc: 0.80 ;v loss: 2.72255 ;v acc: 0.80
Epoch 169;t loss: 2.71898 ;t acc: 0.80 ;v loss: 2.72296 ;v acc: 0.80
Epoch 170;t loss: 2.72153 ;t acc: 0.80 ;v loss: 2.72320 ;v acc: 0.80
Epoch 171;t loss: 2.72289 ;t acc: 0.80 ;v loss: 2.72298 ;v acc: 0.80
Epoch 172;t loss: 2.72108 ;t acc: 0.80 ;v loss: 2.72183 ;v acc: 0.80
Epoch 173;t loss: 2.71909 ;t acc: 0.80 ;v loss: 2.72173 ;v acc: 0.80
Epoch 174;t loss: 2.72056 ;t acc: 0.80 ;v loss: 2.72199 ;v acc: 0.80
Epoch 175;t loss: 2.71881 ;t acc: 

Epoch 191;t loss: 2.71950 ;t acc: 0.80 ;v loss: 2.72204 ;v acc: 0.80
Epoch 192;t loss: 2.71914 ;t acc: 0.80 ;v loss: 2.72199 ;v acc: 0.80
Epoch 193;t loss: 2.71989 ;t acc: 0.80 ;v loss: 2.72208 ;v acc: 0.80
Epoch 194;t loss: 2.72164 ;t acc: 0.80 ;v loss: 2.72077 ;v acc: 0.80
Epoch 195;t loss: 2.71583 ;t acc: 0.80 ;v loss: 2.71676 ;v acc: 0.81
Epoch 196;t loss: 2.71227 ;t acc: 0.81 ;v loss: 2.70945 ;v acc: 0.82
Epoch 197;t loss: 2.70628 ;t acc: 0.82 ;v loss: 2.70726 ;v acc: 0.83
Epoch 198;t loss: 2.70343 ;t acc: 0.82 ;v loss: 2.70258 ;v acc: 0.83
Epoch 199;t loss: 2.70665 ;t acc: 0.82 ;v loss: 2.69975 ;v acc: 0.83
Epoch 200;t loss: 2.70186 ;t acc: 0.82 ;v loss: 2.70193 ;v acc: 0.83
tensor([24,  8], device='cuda:0') tensor([10,  8], device='cuda:0')
tensor([14,  8], device='cuda:0') tensor([14, 12], device='cuda:0')
tensor([15,  4], device='cuda:0') tensor([15, 10], device='cuda:0')
tensor([20, 26], device='cuda:0') tensor([20, 29], device='cuda:0')
tensor([17,  2], device='cuda:0') tens

tensor([3, 2], device='cuda:0') tensor([18,  2], device='cuda:0')
tensor([25,  3], device='cuda:0') tensor([25, 16], device='cuda:0')
tensor([27,  6], device='cuda:0') tensor([27,  6], device='cuda:0')
tensor([ 0, 31], device='cuda:0') tensor([ 0, 10], device='cuda:0')
tensor([ 4, 31], device='cuda:0') tensor([10, 31], device='cuda:0')
tensor([ 4, 26], device='cuda:0') tensor([10, 26], device='cuda:0')
tensor([28,  1], device='cuda:0') tensor([28,  1], device='cuda:0')
tensor([ 4, 15], device='cuda:0') tensor([ 4, 15], device='cuda:0')
tensor([22,  1], device='cuda:0') tensor([16, 29], device='cuda:0')
tensor([ 3, 11], device='cuda:0') tensor([ 3, 11], device='cuda:0')
tensor([ 6, 26], device='cuda:0') tensor([ 6, 26], device='cuda:0')
tensor([24, 30], device='cuda:0') tensor([24, 30], device='cuda:0')
tensor([17,  3], device='cuda:0') tensor([17,  3], device='cuda:0')
tensor([14, 13], device='cuda:0') tensor([14, 13], device='cuda:0')
tensor([27,  4], device='cuda:0') tensor([27, 10],

Epoch 241;t loss: 2.69413 ;t acc: 0.83 ;v loss: 2.69831 ;v acc: 0.83
Epoch 242;t loss: 2.69141 ;t acc: 0.83 ;v loss: 2.69843 ;v acc: 0.83
Epoch 243;t loss: 2.69329 ;t acc: 0.83 ;v loss: 2.69818 ;v acc: 0.83
Epoch 244;t loss: 2.69172 ;t acc: 0.83 ;v loss: 2.69826 ;v acc: 0.83
Epoch 245;t loss: 2.69202 ;t acc: 0.83 ;v loss: 2.69826 ;v acc: 0.83
Epoch 246;t loss: 2.69268 ;t acc: 0.83 ;v loss: 2.69818 ;v acc: 0.83
Epoch 247;t loss: 2.69358 ;t acc: 0.83 ;v loss: 2.69806 ;v acc: 0.83
Epoch 248;t loss: 2.68992 ;t acc: 0.83 ;v loss: 2.69803 ;v acc: 0.83
Epoch 249;t loss: 2.69055 ;t acc: 0.83 ;v loss: 2.69817 ;v acc: 0.83
Epoch 250;t loss: 2.69248 ;t acc: 0.83 ;v loss: 2.69809 ;v acc: 0.83
Epoch 251;t loss: 2.69120 ;t acc: 0.83 ;v loss: 2.69809 ;v acc: 0.83
Epoch 252;t loss: 2.69039 ;t acc: 0.83 ;v loss: 2.69850 ;v acc: 0.83
Epoch 253;t loss: 2.69199 ;t acc: 0.83 ;v loss: 2.69837 ;v acc: 0.83
Epoch 254;t loss: 2.69055 ;t acc: 0.83 ;v loss: 2.69817 ;v acc: 0.83
Epoch 255;t loss: 2.69090 ;t acc: 

Epoch 271;t loss: 2.69013 ;t acc: 0.83 ;v loss: 2.69847 ;v acc: 0.83
Epoch 272;t loss: 2.69094 ;t acc: 0.83 ;v loss: 2.69812 ;v acc: 0.83
Epoch 273;t loss: 2.69044 ;t acc: 0.83 ;v loss: 2.69819 ;v acc: 0.83
Epoch 274;t loss: 2.68874 ;t acc: 0.83 ;v loss: 2.69801 ;v acc: 0.83
Epoch 275;t loss: 2.69048 ;t acc: 0.83 ;v loss: 2.69806 ;v acc: 0.83
Epoch 276;t loss: 2.68954 ;t acc: 0.83 ;v loss: 2.69817 ;v acc: 0.83
Epoch 277;t loss: 2.69163 ;t acc: 0.83 ;v loss: 2.69827 ;v acc: 0.83
Epoch 278;t loss: 2.69192 ;t acc: 0.83 ;v loss: 2.69820 ;v acc: 0.83
Epoch 279;t loss: 2.69368 ;t acc: 0.83 ;v loss: 2.69783 ;v acc: 0.83
Epoch 280;t loss: 2.69247 ;t acc: 0.83 ;v loss: 2.69794 ;v acc: 0.83
tensor([24,  8], device='cuda:0') tensor([10,  8], device='cuda:0')
tensor([14,  8], device='cuda:0') tensor([14, 12], device='cuda:0')
tensor([15,  4], device='cuda:0') tensor([15, 10], device='cuda:0')
tensor([20, 26], device='cuda:0') tensor([20, 29], device='cuda:0')
tensor([17,  2], device='cuda:0') tens

tensor([ 4, 31], device='cuda:0') tensor([10, 31], device='cuda:0')
tensor([ 4, 26], device='cuda:0') tensor([10, 26], device='cuda:0')
tensor([28,  1], device='cuda:0') tensor([28,  1], device='cuda:0')
tensor([ 4, 15], device='cuda:0') tensor([ 4, 15], device='cuda:0')
tensor([3, 1], device='cuda:0') tensor([16, 29], device='cuda:0')
tensor([ 3, 11], device='cuda:0') tensor([ 3, 11], device='cuda:0')
tensor([ 6, 26], device='cuda:0') tensor([ 6, 26], device='cuda:0')
tensor([24, 30], device='cuda:0') tensor([24, 30], device='cuda:0')
tensor([17,  3], device='cuda:0') tensor([17,  3], device='cuda:0')
tensor([14, 13], device='cuda:0') tensor([14, 13], device='cuda:0')
tensor([27,  4], device='cuda:0') tensor([27, 10], device='cuda:0')
tensor([1, 3], device='cuda:0') tensor([ 1, 16], device='cuda:0')
tensor([22, 23], device='cuda:0') tensor([ 5, 23], device='cuda:0')
tensor([ 9, 26], device='cuda:0') tensor([ 9, 26], device='cuda:0')
tensor([31, 11], device='cuda:0') tensor([31, 11], d

Epoch 321;t loss: 2.69054 ;t acc: 0.83 ;v loss: 2.69790 ;v acc: 0.83
Epoch 322;t loss: 2.69061 ;t acc: 0.83 ;v loss: 2.69794 ;v acc: 0.83
Epoch 323;t loss: 2.68955 ;t acc: 0.83 ;v loss: 2.69801 ;v acc: 0.83
Epoch 324;t loss: 2.68846 ;t acc: 0.83 ;v loss: 2.69803 ;v acc: 0.83
Epoch 325;t loss: 2.68869 ;t acc: 0.83 ;v loss: 2.69802 ;v acc: 0.83
Epoch 326;t loss: 2.68922 ;t acc: 0.83 ;v loss: 2.69794 ;v acc: 0.83
Epoch 327;t loss: 2.68861 ;t acc: 0.83 ;v loss: 2.69792 ;v acc: 0.83
Epoch 328;t loss: 2.68803 ;t acc: 0.83 ;v loss: 2.69802 ;v acc: 0.83
Epoch 329;t loss: 2.69308 ;t acc: 0.83 ;v loss: 2.69820 ;v acc: 0.83
Epoch 330;t loss: 2.69064 ;t acc: 0.83 ;v loss: 2.69805 ;v acc: 0.83
Epoch 331;t loss: 2.68956 ;t acc: 0.83 ;v loss: 2.69811 ;v acc: 0.83
Epoch 332;t loss: 2.69127 ;t acc: 0.83 ;v loss: 2.69809 ;v acc: 0.83
Epoch 333;t loss: 2.68984 ;t acc: 0.83 ;v loss: 2.69806 ;v acc: 0.83
Epoch 334;t loss: 2.68957 ;t acc: 0.83 ;v loss: 2.69797 ;v acc: 0.83
Epoch 335;t loss: 2.69085 ;t acc: 

Epoch 351;t loss: 2.68952 ;t acc: 0.83 ;v loss: 2.69790 ;v acc: 0.83
Epoch 352;t loss: 2.68755 ;t acc: 0.83 ;v loss: 2.69789 ;v acc: 0.83
Epoch 353;t loss: 2.68782 ;t acc: 0.83 ;v loss: 2.69792 ;v acc: 0.83
Epoch 354;t loss: 2.69129 ;t acc: 0.83 ;v loss: 2.69790 ;v acc: 0.83
Epoch 355;t loss: 2.68984 ;t acc: 0.83 ;v loss: 2.69775 ;v acc: 0.83
Epoch 356;t loss: 2.68948 ;t acc: 0.83 ;v loss: 2.69770 ;v acc: 0.83
Epoch 357;t loss: 2.68918 ;t acc: 0.83 ;v loss: 2.69769 ;v acc: 0.83
Epoch 358;t loss: 2.69071 ;t acc: 0.83 ;v loss: 2.69762 ;v acc: 0.83
Epoch 359;t loss: 2.68820 ;t acc: 0.83 ;v loss: 2.69763 ;v acc: 0.83
Epoch 360;t loss: 2.68981 ;t acc: 0.83 ;v loss: 2.69764 ;v acc: 0.83
tensor([24,  8], device='cuda:0') tensor([10,  8], device='cuda:0')
tensor([14,  8], device='cuda:0') tensor([14, 12], device='cuda:0')
tensor([15,  4], device='cuda:0') tensor([15, 10], device='cuda:0')
tensor([20, 26], device='cuda:0') tensor([20, 29], device='cuda:0')
tensor([17,  2], device='cuda:0') tens

tensor([ 0, 31], device='cuda:0') tensor([ 0, 10], device='cuda:0')
tensor([ 4, 31], device='cuda:0') tensor([10, 31], device='cuda:0')
tensor([ 4, 26], device='cuda:0') tensor([10, 26], device='cuda:0')
tensor([28,  1], device='cuda:0') tensor([28,  1], device='cuda:0')
tensor([ 4, 15], device='cuda:0') tensor([ 4, 15], device='cuda:0')
tensor([3, 1], device='cuda:0') tensor([16, 29], device='cuda:0')
tensor([ 3, 11], device='cuda:0') tensor([ 3, 11], device='cuda:0')
tensor([ 6, 26], device='cuda:0') tensor([ 6, 26], device='cuda:0')
tensor([24, 30], device='cuda:0') tensor([24, 30], device='cuda:0')
tensor([17,  3], device='cuda:0') tensor([17,  3], device='cuda:0')
tensor([14, 13], device='cuda:0') tensor([14, 13], device='cuda:0')
tensor([27,  4], device='cuda:0') tensor([27, 10], device='cuda:0')
tensor([1, 3], device='cuda:0') tensor([ 1, 16], device='cuda:0')
tensor([22, 23], device='cuda:0') tensor([ 5, 23], device='cuda:0')
tensor([ 9, 26], device='cuda:0') tensor([ 9, 26], d

Epoch 401;t loss: 2.69063 ;t acc: 0.83 ;v loss: 2.69775 ;v acc: 0.83
Epoch 402;t loss: 2.68819 ;t acc: 0.83 ;v loss: 2.69770 ;v acc: 0.83
Epoch 403;t loss: 2.68907 ;t acc: 0.83 ;v loss: 2.69763 ;v acc: 0.83
Epoch 404;t loss: 2.69000 ;t acc: 0.83 ;v loss: 2.69764 ;v acc: 0.83
Epoch 405;t loss: 2.68848 ;t acc: 0.83 ;v loss: 2.69763 ;v acc: 0.83
Epoch 406;t loss: 2.68839 ;t acc: 0.83 ;v loss: 2.69760 ;v acc: 0.83
Epoch 407;t loss: 2.69058 ;t acc: 0.83 ;v loss: 2.69764 ;v acc: 0.83
Epoch 408;t loss: 2.68843 ;t acc: 0.83 ;v loss: 2.69765 ;v acc: 0.83
Epoch 409;t loss: 2.69028 ;t acc: 0.83 ;v loss: 2.69766 ;v acc: 0.83
Epoch 410;t loss: 2.68998 ;t acc: 0.83 ;v loss: 2.69766 ;v acc: 0.83
Epoch 411;t loss: 2.69096 ;t acc: 0.83 ;v loss: 2.69768 ;v acc: 0.83
Epoch 412;t loss: 2.68932 ;t acc: 0.83 ;v loss: 2.69765 ;v acc: 0.83
Epoch 413;t loss: 2.68855 ;t acc: 0.83 ;v loss: 2.69765 ;v acc: 0.83
Epoch 414;t loss: 2.68909 ;t acc: 0.83 ;v loss: 2.69764 ;v acc: 0.83
Epoch 415;t loss: 2.68687 ;t acc: 

Epoch 431;t loss: 2.68974 ;t acc: 0.83 ;v loss: 2.69770 ;v acc: 0.83
Epoch 432;t loss: 2.68931 ;t acc: 0.83 ;v loss: 2.69772 ;v acc: 0.83
Epoch 433;t loss: 2.68807 ;t acc: 0.83 ;v loss: 2.69773 ;v acc: 0.83
Epoch 434;t loss: 2.68786 ;t acc: 0.83 ;v loss: 2.69769 ;v acc: 0.83
Epoch 435;t loss: 2.68669 ;t acc: 0.83 ;v loss: 2.69769 ;v acc: 0.83
Epoch 436;t loss: 2.68976 ;t acc: 0.83 ;v loss: 2.69768 ;v acc: 0.83
Epoch 437;t loss: 2.69131 ;t acc: 0.83 ;v loss: 2.69769 ;v acc: 0.83
Epoch 438;t loss: 2.69016 ;t acc: 0.83 ;v loss: 2.69775 ;v acc: 0.83
Epoch 439;t loss: 2.68665 ;t acc: 0.83 ;v loss: 2.69776 ;v acc: 0.83
Epoch 440;t loss: 2.68951 ;t acc: 0.83 ;v loss: 2.69771 ;v acc: 0.83
tensor([24,  8], device='cuda:0') tensor([10,  8], device='cuda:0')
tensor([14,  8], device='cuda:0') tensor([14, 12], device='cuda:0')
tensor([15,  4], device='cuda:0') tensor([15, 10], device='cuda:0')
tensor([20, 26], device='cuda:0') tensor([20, 29], device='cuda:0')
tensor([17,  2], device='cuda:0') tens

In [16]:
model.load_state_dict(torch.load('./best_params'))
test_loader = DataLoader(test_dataset, batch_size=1)
loss, acc = test(train_loader, False,target=target)
print(f'Train loss: {loss} ; Train acc: {acc}')
loss, acc = test(val_loader, False,target=target)

print(f'Val loss: {loss} ; Val acc: {acc}')
loss, acc = test(test_loader, False,target=target)
print(f'Test loss: {loss} ; Test acc: {acc}')

Train loss: 2.6868740022182465 ; Train acc: 0.828000009059906
Val loss: 2.697600245475769 ; Val acc: 0.8299999833106995
Test loss: 2.804517172442542 ; Test acc: 0.7111111283302307
