In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell is executed, so edits to the project's .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from Project import Project

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import pandas as pd

from Scripts.data.DataLoading import get_image_dataloaders, get_image_dataset
from Scripts.data.SSVEPDataset import SSVEPDataset
from Scripts.data.SSVEPDataloader import SSVEPDataloader
from Scripts.neuralnets.NNPreinstalledModelSelection import *
from Scripts.neuralnets.NNTrainingUtils import train, kfold_train

In [4]:
# Number of GPUs to use; a value of 0 selects the CPU even when CUDA is available.
ngpu = 1

# Use the first CUDA device when CUDA is available and at least one GPU was
# requested; otherwise fall back to the CPU.
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

***

Training with simple cross validation.

In [None]:
# Selectors handed to get_image_dataloaders: preprocessing method, signal
# length and figure type (all passed as strings).
preprocessing, signal_length, fig_type = 'cca', '512', 'rp'

# The call returns a pair, unpacked here into the loaders and the dataset sizes.
dataloaders, dataset_sizes = get_image_dataloaders(
    Project,
    preprocessing,
    signal_length,
    fig_type,
    batch_size=16,
)

In [None]:
# The model identifier is the family name followed by the size, i.e. 'resnet50'.
model_type = 'resnet'
model_size = '50'
num_classes = 5
model_name = f'{model_type}{model_size}'

# Build the network with pretrained=False, then move it to the selected device.
model = model_selection(model_name, num_classes, pretrained=False)
model = model.to(device)

In [None]:
# Loss function, optimiser and learning-rate schedule for the run below.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)
# StepLR scales the learning rate by `gamma` once every `step_size` scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [None]:
# Run training for 15 epochs; the call returns a (model, stats) pair and
# `model` is rebound to the returned object.
model, stats = train(
    Project,
    model,
    dataloaders,
    dataset_sizes,
    criterion,
    optimizer,
    scheduler,
    num_epochs=15,
)

***

Training with K-Fold cross validation.

In [13]:
# Selectors handed to get_image_dataset: preprocessing method, signal length
# and figure type (all passed as strings).
preprocessing, signal_length, fig_type = 'cca', '1280', 'rp'

# A single dataset object is loaded here (no loaders); it is handed to
# kfold_train later in this section.
dataset = get_image_dataset(
    Project,
    preprocessing,
    signal_length,
    fig_type,
)

In [14]:
# The model identifier is the family name followed by the size, i.e. 'resnet50'.
model_type = 'resnet'
model_size = '50'
num_classes = 5
model_name = f'{model_type}{model_size}'

# NOTE(review): `pretrained` is not passed here, so model_selection's default
# applies (the simple-validation experiment passes pretrained=False) - confirm
# this difference is intended.
model = model_selection(model_name, num_classes)
model = model.to(device)

In [15]:
# Loss function and optimiser for the k-fold run.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)
# NOTE(review): this scheduler is created but is not passed to the kfold_train
# call in the next cell - confirm whether kfold_train is meant to use it.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

In [16]:
# 5-fold run, 30 epochs per fold, batch size 16; the call returns a
# (model, stats) pair and `model` is rebound to the returned object.
model, stats = kfold_train(
    Project,
    model,
    dataset,
    criterion,
    optimizer,
    num_epochs=30,
    num_folds=5,
    batch_size=16,
)

[24/Sep/2020 08:34:54] INFO - Fold 0
[24/Sep/2020 08:34:54] INFO - Epoch    Stage       Loss    Accuracy

[24/Sep/2020 08:35:40] INFO - 1/30     Training    1.71    0.21    
[24/Sep/2020 08:35:48] INFO -          Validation  1.67    0.24    
[24/Sep/2020 08:36:30] INFO - 2/30     Training    1.46    0.38    
[24/Sep/2020 08:36:35] INFO -          Validation  1.65    0.26    
[24/Sep/2020 08:37:15] INFO - 3/30     Training    1.11    0.60    
[24/Sep/2020 08:37:21] INFO -          Validation  1.85    0.28    
[24/Sep/2020 08:38:02] INFO - 4/30     Training    0.56    0.86    
[24/Sep/2020 08:38:08] INFO -          Validation  2.15    0.22    
[24/Sep/2020 08:38:47] INFO - 5/30     Training    0.21    0.96    
[24/Sep/2020 08:38:53] INFO -          Validation  2.28    0.25    
[24/Sep/2020 08:39:33] INFO - 6/30     Training    0.08    0.99    
[24/Sep/2020 08:39:38] INFO -          Validation  2.33    0.25    
[24/Sep/2020 08:40:16] INFO - 7/30     Training    0.04    1.00    
[24/Sep/20

KeyboardInterrupt: 

***

Training with K-Fold cross-validation on data preprocessed with CCA only.

In [None]:
preprocessing = 'cca'
signal_length = '512'

# Read the CSV for this signal length from the project's output directory,
# loading every column as a string.
data = pd.read_csv(
    str(Project.output_dir / f'SSVEPDataset_{signal_length}.csv'),
    dtype='str',
)
dataset = SSVEPDataset(Project, preprocessing, data, signal_length)

In [None]:
# Batch size used for the loaders built in this cell.
batch_size = 64

# SSVEPDataloader returns a pair, unpacked into the loaders and the dataset
# sizes; the sizes are printed as a quick sanity check.
(dataloaders, dataset_sizes) = SSVEPDataloader(dataset, batch_size)
print(dataset_sizes)

In [None]:
# The model identifier is the family name followed by the size, i.e. 'resnet50'.
model_type = 'resnet'
model_size = '50'
num_classes = 5
model_name = f'{model_type}{model_size}'

model = model_selection(model_name, num_classes)

# Replace the first convolution with one that accepts a single input channel.
# NOTE(review): presumably because SSVEPDataset yields 1-channel inputs - confirm.
model.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=(7, 7),
    stride=(2, 2),
    padding=(3, 3),
    bias=False,
)

# Move to the device only after the layer swap so the new conv1 is moved too.
model = model.to(device)

In [None]:
# Loss function and optimiser for the k-fold run.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)
# NOTE(review): this scheduler is created but is not passed to the kfold_train
# call in the next cell - confirm whether kfold_train is meant to use it.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [None]:
# 5-fold run, 20 epochs per fold. kfold_train is given `dataset`, not the
# `dataloaders` built earlier in this section, so the `batch_size` configured
# there is forwarded explicitly; otherwise the call would fall back to
# kfold_train's own default.
# NOTE(review): confirm that the notebook's batch_size (64) is the value
# intended for this run.
model, stats = kfold_train(
    Project,
    model,
    dataset,
    criterion,
    optimizer,
    num_epochs=20,
    num_folds=5,
    batch_size=batch_size,
)