In [1]:
# Fix randomness and hide warnings
seed = 42

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['PYTHONHASHSEED'] = str(seed)
os.environ['MPLCONFIGDIR'] = os.getcwd()+'/configs/'



import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=Warning)

import numpy as np
np.random.seed(seed)

import logging

import random
random.seed(seed)

import gc

In [2]:
# Import tensorflow
import torch
import torch.version
from torch import nn
from torch.utils.data import DataLoader
import torch.optim as optim


torch.manual_seed(seed)

torch.version.__version__
print('PyTorch Version:',torch.version.__version__)
print('Cuda Version:',torch.version.cuda,'\n')

print('Available devices:')
for i in range(torch.cuda.device_count()):
   print('\t',torch.cuda.get_device_properties(i).name)
   print('\t\tMultiprocessor Count:',torch.cuda.get_device_properties(i).multi_processor_count)
   print('\t\tTotal Memory:',torch.cuda.get_device_properties(i).total_memory/1024/1024, 'MB')
   

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('\n',device)

PyTorch Version: 2.2.2+cu121
Cuda Version: 12.1 

Available devices:
	 NVIDIA GeForce RTX 3050 Ti Laptop GPU
		Multiprocessor Count: 20
		Total Memory: 3693.875 MB

 cuda:0


In [3]:
# Import other libraries
import matplotlib.pyplot as plt
import cv2
from skimage import transform
import pickle
from data_tools import *
from tqdm import tqdm
from copy import deepcopy

In [4]:
DATASETS_DIR = '../datasets/VTNet/'

with open(f'{DATASETS_DIR}trainset_vtnet.pkl', 'rb') as file:
    train_set = pickle.load(file)

with open(f'{DATASETS_DIR}valset_vtnet.pkl', 'rb') as file:
    val_set = pickle.load(file)

with open(f'{DATASETS_DIR}testset_vtnet.pkl', 'rb') as file:
    test_set = pickle.load(file)

In [5]:
class VETNet(nn.Module):
    def __init__(self, timeseries_size ,scanpath_size):
        super(VETNet, self).__init__()
        self.scanpath_layer_conv1 = nn.Conv2d(scanpath_size[0],16,kernel_size=5,padding='same')
        self.scanpath_layer_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.scanpath_layer_conv2 = nn.Conv2d(16,6,kernel_size=5,padding='same')
        self.scanpath_layer_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.scanpath_layer_flatten = nn.Flatten()
        self.scanpath_layer_linear = nn.Linear(scanpath_size[1]*scanpath_size[2]*6//16,50)
        # self.scanpath_layer_relu = nn.LeakyReLU()
        
        self.timeseries_layer_attention = nn.MultiheadAttention(timeseries_size[-1]-1,1, batch_first=True)
        self.timeseries_layer_gru = nn.GRU(input_size=timeseries_size[-1]-1, hidden_size=256, batch_first=True)
        
        self.classifier = nn.Sequential(
            nn.Linear(306,20),
            #nn.LeakyReLU(),
            nn.Linear(20,2)
        )
        
    def forward(self, x_timeseries, x_scanpath):
        x_scanpath = self.scanpath_layer_conv1(x_scanpath)
        x_scanpath = self.scanpath_layer_pool1(x_scanpath)
        x_scanpath = self.scanpath_layer_conv2(x_scanpath)
        x_scanpath = self.scanpath_layer_pool2(x_scanpath)
        x_scanpath = self.scanpath_layer_flatten(x_scanpath)
        x_scanpath = self.scanpath_layer_linear(x_scanpath)
        # x_scanpath = self.scanpath_layer_relu(x_scanpath)
        
        x_timeseries,_ = self.timeseries_layer_attention(x_timeseries,x_timeseries,x_timeseries)
        _ ,x_timeseries = self.timeseries_layer_gru(x_timeseries)
        
        x = torch.cat((torch.squeeze(x_timeseries,0),x_scanpath),dim=1)
        
        return self.classifier(x)

class EarlyStopping():
    def __init__(self, patience: int, mode: str, minimum_delta = 0.0):
        assert mode in {'min', 'max'}, "mode has to be 'min' or 'max'"
        self.minimum_delta = minimum_delta
        self.patience = patience
        self.counter = 0
        self.tracking = torch.inf if mode == 'min' else -torch.inf
        self.mode = 1 if mode=='max' else -1
        self.best_model = None
        
    def check(self, value, model: nn.Module):
        if self.mode*(value-self.tracking) <= self.minimum_delta:
            self.counter+=1
        else:
            self.counter = 0
            self.tracking = value
            self.best_model = deepcopy(model.state_dict())
            
        if self.counter == self.patience:
            model.load_state_dict(self.best_model)
            return True
        
        return False


class ReduceLROnPlateau():
    def __init__(self, patience: int, rate: float, mode: str, minimum_lr = 0.0, minimum_delta = 0.0):
        assert rate<=1 and rate>0, "rate as to be a number between 0 and 1"
        assert mode in {'min', 'max'}, "mode has to be 'min' or 'max'"

        self.minimum_delta = minimum_delta
        self.patience = patience
        self.counter = 0
        self.rate = rate
        self.minimum_lr = minimum_lr
        self.tracking = torch.inf if mode == 'min' else -torch.inf
        self.mode = 1 if mode=='max' else -1
        self.best_model = None
        
    def check(self, value, optimizer, model):
        if self.mode*(value-self.tracking) <= self.minimum_delta:
            self.counter+=1
        else:
            self.counter = 0
            self.tracking = value
            self.best_model = deepcopy(model.state_dict())
            
        if self.counter == self.patience:
            for i in range(len(optimizer.param_groups)):
                optimizer.param_groups[i]['lr'] = max(self.rate*optimizer.param_groups[i]['lr'], self.minimum_lr)
            model.load_state_dict(self.best_model)
            self.counter = 0

In [6]:
model = VETNet(timeseries_size=train_set[0][0].shape, scanpath_size=train_set[0][1].shape).to(device)
model

VETNet(
  (scanpath_layer_conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=same)
  (scanpath_layer_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (scanpath_layer_conv2): Conv2d(16, 6, kernel_size=(5, 5), stride=(1, 1), padding=same)
  (scanpath_layer_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (scanpath_layer_flatten): Flatten(start_dim=1, end_dim=-1)
  (scanpath_layer_linear): Linear(in_features=98304, out_features=50, bias=True)
  (timeseries_layer_attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=6, out_features=6, bias=True)
  )
  (timeseries_layer_gru): GRU(6, 256, batch_first=True)
  (classifier): Sequential(
    (0): Linear(in_features=306, out_features=20, bias=True)
    (1): Linear(in_features=20, out_features=2, bias=True)
  )
)

In [7]:
batchsize = 32
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

val_criterion = nn.CrossEntropyLoss()
lr_tracker = ReduceLROnPlateau(5, 0.5, mode='min', minimum_lr=1e-5)
earlystop_tracker = EarlyStopping(10, mode='min')

trainloader = DataLoader(train_set, batch_size=batchsize, shuffle=True)
valloader = DataLoader(val_set, batch_size=batchsize, shuffle=False)

In [8]:
running_loss = []
val_running_loss = []
for epoch in range(1,101):  # loop over the dataset multiple times

    running_loss += [0.0]
    val_running_loss += [0.0]
    
    with tqdm(trainloader, unit="batch") as tepoch:
        for input_rawdata, input_scanpath, labels in tepoch:
            # get the inputs; data is a list of [inputs, labels]
            tepoch.set_description(f"Epoch {epoch}")
            input_rawdata = input_rawdata[:,:,1:].to(device)
            input_scanpath = (input_scanpath/128-1).to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(input_rawdata, input_scanpath)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss[-1] += loss.item()
            
            tepoch.set_postfix(loss=running_loss[-1])
    
        for val_rawdata, val_scanpath, val_labels in valloader:
            val_rawdata = val_rawdata.to(device)
            val_scanpath = (val_scanpath/128-1).to(device)
            val_labels = val_labels.to(device)
            
            val_outputs = model(val_rawdata[:,:,1:], val_scanpath)
            val_loss = val_criterion(val_outputs, val_labels)
            val_running_loss[-1] += val_loss.item()
        
        print(f"\t Training Loss (final): {running_loss[-1]/len(train_set): .4f}, Validation Loss: {val_running_loss[-1]/len(val_set): .4f}, Learning Rate: {optimizer.param_groups[-1]['lr']: .5f}")
        
        lr_tracker.check(value=val_running_loss[-1], optimizer=optimizer, model=model)
        
        if earlystop_tracker.check(value=val_running_loss[-1], model=model):
            break
        

Epoch 1: 100%|██████████| 144/144 [00:17<00:00,  8.08batch/s, loss=99.4]


	 Training Loss (final):  0.0216, Validation Loss:  0.0291, Learning Rate:  0.00010


Epoch 2: 100%|██████████| 144/144 [00:18<00:00,  7.87batch/s, loss=88.8]


	 Training Loss (final):  0.0193, Validation Loss:  0.0287, Learning Rate:  0.00010


Epoch 3: 100%|██████████| 144/144 [00:17<00:00,  8.12batch/s, loss=81]  


	 Training Loss (final):  0.0176, Validation Loss:  0.0309, Learning Rate:  0.00010


Epoch 4: 100%|██████████| 144/144 [00:17<00:00,  8.36batch/s, loss=68.7]


	 Training Loss (final):  0.0150, Validation Loss:  0.0372, Learning Rate:  0.00010


Epoch 5: 100%|██████████| 144/144 [00:17<00:00,  8.31batch/s, loss=53.1]


	 Training Loss (final):  0.0116, Validation Loss:  0.0457, Learning Rate:  0.00010


Epoch 6: 100%|██████████| 144/144 [00:17<00:00,  8.37batch/s, loss=35.8]


	 Training Loss (final):  0.0078, Validation Loss:  0.0524, Learning Rate:  0.00010


Epoch 7: 100%|██████████| 144/144 [00:17<00:00,  8.27batch/s, loss=23.3]


	 Training Loss (final):  0.0051, Validation Loss:  0.0635, Learning Rate:  0.00010


Epoch 8: 100%|██████████| 144/144 [00:17<00:00,  8.34batch/s, loss=79.8]


	 Training Loss (final):  0.0174, Validation Loss:  0.0307, Learning Rate:  0.00005


Epoch 9: 100%|██████████| 144/144 [00:17<00:00,  8.34batch/s, loss=71.6]


	 Training Loss (final):  0.0156, Validation Loss:  0.0338, Learning Rate:  0.00005


Epoch 10: 100%|██████████| 144/144 [00:17<00:00,  8.31batch/s, loss=62.3]


	 Training Loss (final):  0.0136, Validation Loss:  0.0339, Learning Rate:  0.00005


Epoch 11: 100%|██████████| 144/144 [00:17<00:00,  8.32batch/s, loss=52.3]


	 Training Loss (final):  0.0114, Validation Loss:  0.0394, Learning Rate:  0.00005


Epoch 12: 100%|██████████| 144/144 [00:17<00:00,  8.17batch/s, loss=40.8]


	 Training Loss (final):  0.0089, Validation Loss:  0.0462, Learning Rate:  0.00005


In [9]:
test_loader = DataLoader(test_set, batch_size=batchsize, shuffle=True)
classes = ['CONTROL', 'PATIENT']

with torch.no_grad():
    TP = 0
    FN = 0
    FP = 0
    TN = 0
    
    for input_rawdata, input_scanpath, labels in test_loader:
        
        input_rawdata = input_rawdata[:,:,1:].to(device)
        input_scanpath = (input_scanpath/128-1).to(device)
        labels = labels.to(device)
        outputs = model(input_rawdata, input_scanpath)
        # max returns (value ,index)
        _, predicted = torch.max(outputs, 1)
        TP += torch.sum((predicted==0)[labels==0])
        FN += torch.sum((predicted==1)[labels==0])
        FP += torch.sum((predicted==0)[labels==1])
        TN += torch.sum((predicted==1)[labels==1])

    sensitivity = TP/(TP+FN)
    specificity = TN/(TN+FP)
    print(f'Sensitivity: {sensitivity*100} %')
    print(f'Specificity: {specificity*100} %')
    print()
    print('         | Healthy |   Sick   ')
    print('---------|---------|----------')
    print(f'negative |   {int(TP)}   |   {int(FP)}   ')
    print(f'positive |   {int(FN)}   |   {int(TN)}   ')



Sensitivity: 53.125 %
Specificity: 61.224491119384766 %

         | Healthy |   Sick   
---------|---------|----------
negative |   136   |   133   
positive |   120   |   210   
