In [None]:
# Reproducibility and log hygiene: one shared seed for every RNG used in this
# notebook, plus environment/warning settings applied before the heavy imports.
seed = 42

import os
# NOTE(review): PYTHONHASHSEED set from inside a running interpreter presumably
# only affects child processes, not this kernel's own str hashing - confirm.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '3',
    'PYTHONHASHSEED': str(seed),
    'MPLCONFIGDIR': os.getcwd()+'/configs/',
})

import warnings
# FutureWarning first, then the Warning base class (which silences every category).
for ignored_category in (FutureWarning, Warning):
    warnings.simplefilter(action='ignore', category=ignored_category)

import logging

import random
random.seed(seed)

import numpy as np
np.random.seed(seed)

In [None]:
# Import PyTorch, seed it, and report the compute devices that are available
import torch
import torch.version
from torch import nn
from torch.utils.data import DataLoader
import torch.optim as optim


# Seed PyTorch's global RNG with the notebook-wide seed defined in the first cell.
torch.manual_seed(seed)

print('PyTorch Version:',torch.version.__version__)
print('Cuda Version:',torch.version.cuda,'\n')

print('Available devices:')
for i in range(torch.cuda.device_count()):
    # Query the device properties once per GPU instead of once per printed field.
    properties = torch.cuda.get_device_properties(i)
    print('\t',properties.name)
    print('\t\tMultiprocessor Count:',properties.multi_processor_count)
    print('\t\tTotal Memory:',properties.total_memory/1024/1024, 'MB')


# Everything below runs on the first GPU when CUDA is usable, otherwise on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('\n',device)

In [None]:
# Import other libraries
import matplotlib.pyplot as plt
import cv2
from skimage import transform
import pickle
from data_tools import *
from tqdm import tqdm
from copy import deepcopy

In [None]:
DATASETS_DIR = '../datasets/VTNet/'
DATASET_VERSION = 2  # numeric suffix of the pickled split files (..._vtnet_2.pkl)


def _load_split(split_name):
    """Unpickle one dataset split from DATASETS_DIR.

    Args:
        split_name: file-name prefix of the split ('trainset', 'valset' or 'testset').

    Returns:
        Whatever object was pickled into
        ``{DATASETS_DIR}{split_name}_vtnet_{DATASET_VERSION}.pkl``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load split files that were produced by a trusted pipeline.
    with open(f'{DATASETS_DIR}{split_name}_vtnet_{DATASET_VERSION}.pkl', 'rb') as file:
        return pickle.load(file)


train_set = _load_split('trainset')
val_set = _load_split('valset')
test_set = _load_split('testset')

In [None]:
class VETNet(nn.Module):
    """Two-branch classifier combining a time series and a scanpath image.

    * Scanpath branch: two 5x5 'same' convolutions, each followed by a 2x2
      max-pool, then a linear projection to 50 features.
    * Time-series branch: single-head self-attention followed by a GRU whose
      final hidden state (256 features) summarises the sequence.
    * Classifier: two linear layers mapping the 306 concatenated features to
      2 class scores.

    The network returns raw, unnormalised class scores (logits). The notebook
    trains it with ``nn.CrossEntropyLoss``, which applies log-softmax itself;
    a trailing ``nn.Softmax`` would normalise twice and flatten the gradients.
    Apply ``torch.softmax(out, dim=1)`` to the output if probabilities are
    needed. ``argmax`` predictions are unaffected.

    Args:
        timeseries_size: shape of one raw time-series sample. Only the last
            entry (feature count) is used, minus one, because callers in this
            notebook drop the first feature column (``x[:, :, 1:]``) before
            calling the model.
        scanpath_size: shape of one scanpath sample, indexed as
            (channels, height, width).
        cnn_shape: output channels of the first and second convolution.
    """

    def __init__(self, timeseries_size ,scanpath_size, cnn_shape = (16,6)):
        super(VETNet, self).__init__()
        scanpath_features = 50
        gru_hidden_size = 256
        timeseries_features = timeseries_size[-1]-1

        # Each MaxPool2d(2, 2) floor-halves height and width, so after two of
        # them the map is (H // 4) x (W // 4). This equals H*W*C//16 whenever H
        # and W are multiples of 4 and stays correct when they are not.
        pooled_height = scanpath_size[1] // 4
        pooled_width = scanpath_size[2] // 4
        flattened_size = pooled_height * pooled_width * cnn_shape[1]

        self.scanpath_layer = nn.Sequential(
            nn.Conv2d(scanpath_size[0],cnn_shape[0],kernel_size=5,padding='same'),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(cnn_shape[0],cnn_shape[1],kernel_size=5,padding='same'),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(flattened_size,scanpath_features)
        )

        self.timeseries_layer_attention = nn.MultiheadAttention(timeseries_features,1, batch_first=True)
        self.timeseries_layer_gru = nn.GRU(input_size=timeseries_features, hidden_size=gru_hidden_size, batch_first=True)

        # NOTE(review): there is no non-linearity between the two linear layers
        # (the LeakyReLU below was already disabled), so together they act as a
        # single affine map. Left unchanged on purpose.
        self.classifier = nn.Sequential(
            nn.Linear(gru_hidden_size + scanpath_features,20),
            #nn.LeakyReLU(),
            nn.Linear(20,2)
        )

    def forward(self, x_timeseries, x_scanpath):
        """Compute class logits for a batch.

        Args:
            x_timeseries: batch-first sequence tensor
                (batch, steps, timeseries_size[-1] - 1).
            x_scanpath: image tensor (batch, scanpath_size[0], height, width).

        Returns:
            Tensor of shape (batch, 2) holding unnormalised class scores.
        """
        scanpath_features = self.scanpath_layer(x_scanpath)

        # Self-attention: the sequence serves as query, key and value.
        attended,_ = self.timeseries_layer_attention(x_timeseries,x_timeseries,x_timeseries)
        # Keep only the GRU's final hidden state, shaped (1, batch, hidden).
        _ ,last_hidden = self.timeseries_layer_gru(attended)

        # Drop the leading num_layers axis and join both branches per sample.
        features = torch.cat((torch.squeeze(last_hidden,0),scanpath_features),dim=1)

        return self.classifier(features)

class EarlyStopping():
    """Stop training once a monitored metric stops improving.

    Keeps a deep copy of the model's ``state_dict`` from the best check so far
    and loads it back into the model when patience runs out.

    Args:
        patience: number of consecutive non-improving checks that triggers the stop.
        mode: 'min' if lower metric values are better, 'max' if higher are better.
        minimum_delta: the metric must beat the best value by more than this
            amount to count as an improvement.
    """

    def __init__(self, patience: int, mode: str, minimum_delta = 0.0):
        assert mode in {'min', 'max'}, "mode has to be 'min' or 'max'"
        self.minimum_delta = minimum_delta
        self.patience = patience
        self.counter = 0
        # Start from the worst possible value so the first finite metric improves on it.
        self.tracking = torch.inf if mode == 'min' else -torch.inf
        # +1 / -1 sign so that "better" always means a positive signed difference.
        self.mode = 1 if mode=='max' else -1
        self.best_model = None

    def check(self, value, model: nn.Module):
        """Record one metric value and report whether training should stop.

        Args:
            value: latest value of the monitored metric.
            model: model being trained; snapshotted on improvement and restored
                to its best snapshot when the stop is triggered.

        Returns:
            True when ``patience`` consecutive checks brought no improvement,
            otherwise False.
        """
        # Written as "> delta" rather than "not (<= delta)" so that a NaN metric
        # compares False and counts as "no improvement" instead of becoming the
        # tracked best value (after which nothing could ever trigger the stop).
        improved = self.mode*(value-self.tracking) > self.minimum_delta
        if improved:
            self.counter = 0
            self.tracking = value
            self.best_model = deepcopy(model.state_dict())
        else:
            self.counter+=1

        # ">=" keeps reporting the stop if check() is called again after triggering.
        if self.counter >= self.patience:
            # No snapshot exists when no value ever improved (e.g. NaN from the start).
            if self.best_model is not None:
                model.load_state_dict(self.best_model)
            return True

        return False


class ReduceLROnPlateau():
    """Shrink the learning rate when a monitored metric stops improving.

    After ``patience`` consecutive non-improving checks, every parameter
    group's learning rate is multiplied by ``rate`` (never dropping below
    ``minimum_lr``), the model is rolled back to the best snapshot seen so far,
    and the patience counter restarts.

    Args:
        patience: number of consecutive non-improving checks before the rate is reduced.
        rate: multiplicative factor in (0, 1] applied to each learning rate.
        mode: 'min' if lower metric values are better, 'max' if higher are better.
        minimum_lr: lower bound for the reduced learning rate.
        minimum_delta: the metric must beat the best value by more than this
            amount to count as an improvement.
    """

    def __init__(self, patience: int, rate: float, mode: str, minimum_lr = 0.0, minimum_delta = 0.0):
        assert rate<=1 and rate>0, "rate has to be a number between 0 and 1"
        assert mode in {'min', 'max'}, "mode has to be 'min' or 'max'"

        self.minimum_delta = minimum_delta
        self.patience = patience
        self.counter = 0
        self.rate = rate
        self.minimum_lr = minimum_lr
        # Start from the worst possible value so the first finite metric improves on it.
        self.tracking = torch.inf if mode == 'min' else -torch.inf
        # +1 / -1 sign so that "better" always means a positive signed difference.
        self.mode = 1 if mode=='max' else -1
        self.best_model = None

    def check(self, value, optimizer, model):
        """Record one metric value and reduce the learning rate on a plateau.

        Args:
            value: latest value of the monitored metric.
            optimizer: optimizer whose ``param_groups`` learning rates are scaled.
            model: model being trained; snapshotted on improvement and rolled
                back to its best snapshot whenever the rate is reduced.
        """
        # "> delta" (not "not <= delta") so a NaN metric counts as no improvement
        # instead of replacing the tracked best value.
        improved = self.mode*(value-self.tracking) > self.minimum_delta
        if improved:
            self.counter = 0
            self.tracking = value
            self.best_model = deepcopy(model.state_dict())
        else:
            self.counter+=1

        if self.counter >= self.patience:
            for param_group in optimizer.param_groups:
                param_group['lr'] = max(self.rate*param_group['lr'], self.minimum_lr)
            # No snapshot exists when no value ever improved (e.g. NaN from the start).
            if self.best_model is not None:
                model.load_state_dict(self.best_model)
            self.counter = 0

In [None]:
# Size the network from the first training sample (element 0 is the time
# series, element 1 the scanpath), then move it to the selected device.
model = VETNet(
    timeseries_size=train_set[0][0].shape,
    scanpath_size=train_set[0][1].shape,
)
model = model.to(device)
model

In [None]:
# Optimisation set-up: loss functions, optimizer, plateau trackers, data loaders.
batchsize = 32
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-5)

# Both trackers monitor the validation loss, so lower is better ('min').
val_criterion = nn.CrossEntropyLoss()
lr_tracker = ReduceLROnPlateau(patience=5, rate=0.5, mode='min', minimum_lr=1e-6)
earlystop_tracker = EarlyStopping(patience=10, mode='min')

# Only the training split is reshuffled each epoch.
trainloader = DataLoader(dataset=train_set, batch_size=batchsize, shuffle=True)
valloader = DataLoader(dataset=val_set, batch_size=batchsize, shuffle=False)

In [None]:
# Per-epoch sums of the per-batch mean losses (one entry appended per epoch).
running_loss = []
val_running_loss = []
for epoch in range(1,101):  # train for at most 100 epochs

    running_loss += [0.0]
    val_running_loss += [0.0]

    # ---- training pass ----
    model.train()
    with tqdm(trainloader, unit="batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch}")
        for input_rawdata, input_scanpath, labels in tepoch:
            # Drop the first time-series feature column; the model is sized for
            # one feature fewer than the raw samples carry.
            input_rawdata = input_rawdata[:,:,1:].to(device)
            # Rescale the scanpath values (presumably 0-255 pixels mapped to
            # roughly [-1, 1) - TODO confirm).
            input_scanpath = (input_scanpath/128-1).to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(input_rawdata, input_scanpath)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # accumulate statistics
            running_loss[-1] += loss.item()

            tepoch.set_postfix(loss=running_loss[-1])

    # ---- validation pass: no gradients needed, so do not build autograd graphs ----
    model.eval()
    with torch.no_grad():
        for val_rawdata, val_scanpath, val_labels in valloader:
            val_rawdata = val_rawdata.to(device)
            val_scanpath = (val_scanpath/128-1).to(device)
            val_labels = val_labels.to(device)

            val_outputs = model(val_rawdata[:,:,1:], val_scanpath)
            val_loss = val_criterion(val_outputs, val_labels)
            val_running_loss[-1] += val_loss.item()

    # The accumulators hold sums of per-batch means, so average over the number
    # of batches (dividing by the number of samples under-reports the loss by
    # about the batch size).
    print(f"\t Training Loss (final): {running_loss[-1]/len(trainloader): .4f}, Validation Loss: {val_running_loss[-1]/len(valloader): .4f}, Learning Rate: {optimizer.param_groups[-1]['lr']: .2E}")

    lr_tracker.check(value=val_running_loss[-1], optimizer=optimizer, model=model)

    if earlystop_tracker.check(value=val_running_loss[-1], model=model):
        break
else:
    # All epochs ran without an early stop: still finish on the best weights
    # seen, which is what the early-stop path does.
    if earlystop_tracker.best_model is not None:
        model.load_state_dict(earlystop_tracker.best_model)
        

In [None]:
# Trial-level evaluation on the test split.
# Class 1 (PATIENT) is the positive class, matching the 'positive' row of the
# table below, so:
#   TP = patient predicted patient     FN = patient predicted control
#   FP = control predicted patient     TN = control predicted control
test_loader = DataLoader(test_set, batch_size=batchsize, shuffle=False)
classes = ['CONTROL', 'PATIENT']

model.eval()
with torch.no_grad():
    TP = 0
    FN = 0
    FP = 0
    TN = 0

    for input_rawdata, input_scanpath, labels in test_loader:

        input_rawdata = input_rawdata[:,:,1:].to(device)
        input_scanpath = (input_scanpath/128-1).to(device)
        labels = labels.to(device)
        outputs = model(input_rawdata, input_scanpath)
        # max returns (value ,index)
        _, predicted = torch.max(outputs, 1)
        TP += torch.sum((predicted==1) & (labels==1)).item()
        FN += torch.sum((predicted==0) & (labels==1)).item()
        FP += torch.sum((predicted==1) & (labels==0)).item()
        TN += torch.sum((predicted==0) & (labels==0)).item()

    # Sensitivity = recall on patients, specificity = recall on controls;
    # NaN when the corresponding class is absent from the test split.
    sensitivity = TP/(TP+FN) if (TP+FN) else float('nan')
    specificity = TN/(TN+FP) if (TN+FP) else float('nan')
    print(f'Sensitivity: {sensitivity*100} %')
    print(f'Specificity: {specificity*100} %')
    print()
    # Rows are the prediction, columns the true class.
    print(f'         | {classes[0]} | {classes[1]} ')
    print('---------|---------|----------')
    print(f'negative |   {int(TN)}   |   {int(FN)}   ')
    print(f'positive |   {int(FP)}   |   {int(TP)}   ')



In [None]:
# Subject-level evaluation: sum the model outputs of all trials belonging to
# the same subject and classify the subject from the summed scores.
#
# Assumes test_set.subject / test_set.groups hold one entry per sample in the
# order an unshuffled DataLoader yields them, and that groups is 0 for controls
# and 1 for patients - TODO confirm against the dataset class.
sample_subjects = torch.as_tensor(test_set.subject).reshape(-1).cpu()
sample_groups = torch.as_tensor(test_set.groups).reshape(-1).cpu()

# Signed subject id, kept under its original name for any downstream use.
subjects = test_set.subject*(test_set.groups*2-1)

# Key every sample by its (subject id, group) pair. return_inverse maps each
# sample to the row of its subject, so predictions and group labels stay
# aligned whatever order the subjects appear in, and subject id 0 keeps its group.
subject_keys = torch.stack((sample_subjects.double(), sample_groups.double()), dim=1)
unique_keys, subject_index = torch.unique(subject_keys, dim=0, return_inverse=True)
groups = unique_keys[:, 1].long()

test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
classes = ['CONTROL', 'PATIENT']

model.eval()
with torch.no_grad():
    TP = 0
    FN = 0
    FP = 0
    TN = 0

    summed_outputs = None  # (n_subjects, n_classes), allocated on the first batch
    offset = 0             # index of the first sample of the current batch

    for input_rawdata, input_scanpath, _ in test_loader:

        input_rawdata = input_rawdata[:,:,1:].to(device)
        input_scanpath = (input_scanpath/128-1).to(device)

        outputs = model(input_rawdata, input_scanpath).cpu()
        if summed_outputs is None:
            summed_outputs = torch.zeros(len(groups), outputs.shape[1], dtype=outputs.dtype)

        # Add each trial's scores to the row of the subject it belongs to.
        batch_index = subject_index[offset:offset+outputs.shape[0]]
        summed_outputs.index_add_(0, batch_index, outputs)
        offset += outputs.shape[0]

    # One prediction per subject, in the same order as `groups`.
    predicted = torch.max(summed_outputs, 1)[1]

    # Class 1 (sick) is the positive class, matching the 'positive' row below.
    TP += torch.sum((predicted==1) & (groups==1)).item()
    FN += torch.sum((predicted==0) & (groups==1)).item()
    FP += torch.sum((predicted==1) & (groups==0)).item()
    TN += torch.sum((predicted==0) & (groups==0)).item()

    # NaN when the corresponding group is absent from the test split.
    sensitivity = TP/(TP+FN) if (TP+FN) else float('nan')
    specificity = TN/(TN+FP) if (TN+FP) else float('nan')
    print(f'Sensitivity: {sensitivity*100} %')
    print(f'Specificity: {specificity*100} %')
    print()
    # Rows are the prediction, columns the true group.
    print('         | Healthy |   Sick   ')
    print('---------|---------|----------')
    print(f'negative |   {int(TN)}   |   {int(FN)}   ')
    print(f'positive |   {int(FP)}   |   {int(TP)}   ')

