In [2]:
# Reload edited project modules (e.g. the Scripts.* packages imported below) before each cell executes,
# so changes to the training utilities are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
from Project import Project

In [4]:
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from Scripts.data.DataLoading import get_image_dataloaders, get_image_dataset
from Scripts.data.SSVEPDataset import SSVEPDataset
from Scripts.data.SSVEPDataloader import SSVEPDataloader
from Scripts.neuralnets.NNPreinstalledModelSelection import *
from Scripts.neuralnets.NNTrainingUtils import train, kfold_train

In [5]:
# Device selection: use the first CUDA GPU when one is available and ngpu > 0, otherwise the CPU.
ngpu = 1
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Seed every RNG the training pipeline may draw from (Python, NumPy, PyTorch) so that weight
# initialisation and data shuffling are repeatable across runs. torch.manual_seed seeds all devices;
# some cuDNN kernels can still be nondeterministic on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

***

Training with simple hold-out validation (a single train/validation split).

In [None]:
# Hold-out run inputs: CCA preprocessing, signal length '512', figure type 'rp'
# (presumably recurrence plots — confirm in get_image_dataloaders), batches of 16.
preprocessing, signal_length, fig_type = 'cca', '512', 'rp'
dataloaders, dataset_sizes = get_image_dataloaders(
    Project,
    preprocessing,
    signal_length,
    fig_type,
    batch_size=16,
)

In [None]:
# Build 'resnet50' for num_classes = 5 from scratch (pretrained=False) and move it to the selected device.
model_type, model_size, num_classes = 'resnet', '50', 5
model_name = f'{model_type}{model_size}'
model = model_selection(model_name, num_classes, pretrained=False)
model = model.to(device)

In [None]:
# Hold-out optimisation setup: cross-entropy loss, SGD (lr 1e-3, momentum 0.9),
# and a StepLR schedule that multiplies the learning rate by 0.1 every 10 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [None]:
# Run the hold-out training loop for 15 epochs, passing the StepLR scheduler built in the previous cell.
# NOTE(review): the returned pair is presumably (trained model, per-epoch statistics) — confirm in
# Scripts.neuralnets.NNTrainingUtils.train.
model, stats = train(Project, model, dataloaders, dataset_sizes, criterion, optimizer, scheduler, num_epochs = 15)

***

Training with k-fold cross-validation.

In [6]:
# K-fold run inputs: CCA preprocessing, signal length '1280', figure type 'rp'.
# The whole dataset is loaded here; splitting into folds is left to kfold_train.
preprocessing, signal_length, fig_type = 'cca', '1280', 'rp'
dataset = get_image_dataset(
    Project,
    preprocessing,
    signal_length,
    fig_type,
)

In [11]:
# Build 'resnet50' for num_classes = 5 and move it to the selected device.
# NOTE(review): `pretrained` is left at model_selection's default here, while the hold-out
# section passes pretrained=False explicitly — confirm which is intended.
model_type, model_size, num_classes = 'resnet', '50', 5
model_name = f'{model_type}{model_size}'
model = model_selection(model_name, num_classes)
model = model.to(device)

In [12]:
# K-fold optimisation setup: cross-entropy loss and SGD with lr 1e-3, momentum 0.4.
# NOTE(review): momentum 0.4 differs from the 0.9 used in the other sections — presumably deliberate; confirm.
# NOTE(review): this StepLR (halve the LR every 5 scheduler steps) is not passed to kfold_train
# in the next cell and is never stepped in this notebook, so it currently has no effect.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

In [13]:
# 7-fold cross-validation, 20 epochs per fold, on the image dataset loaded above.
# NOTE(review): `scheduler` from the previous cell is not passed here, so its learning-rate decay
# never applies — pass it if kfold_train accepts a scheduler.
# NOTE(review): the same `model` (and optimizer state) is handed to every fold; unless kfold_train
# re-initialises them per fold, later folds start from weights already trained on their validation
# data — confirm.
# NOTE(review): the recorded fold-0 log shows training loss falling (1.65 -> 0.89) while validation
# loss rises from epoch 2 (1.61 -> 1.87) at ~0.24 accuracy, i.e. overfitting; the run was interrupted in epoch 7.
model, stats = kfold_train(Project, model, dataset, criterion, optimizer, num_epochs = 20, num_folds = 7)

[23/Sep/2020 14:45:42] INFO - Fold 0
[23/Sep/2020 14:45:42] INFO - Epoch    Stage       Loss    Accuracy

[23/Sep/2020 14:46:40] INFO - 1/20     Training    1.65    0.21    
[23/Sep/2020 14:46:45] INFO -          Validation  1.69    0.19    
[23/Sep/2020 14:47:41] INFO - 2/20     Training    1.58    0.27    
[23/Sep/2020 14:47:45] INFO -          Validation  1.61    0.24    
[23/Sep/2020 14:48:38] INFO - 3/20     Training    1.54    0.30    
[23/Sep/2020 14:48:42] INFO -          Validation  1.62    0.24    
[23/Sep/2020 14:49:40] INFO - 4/20     Training    1.44    0.39    
[23/Sep/2020 14:49:44] INFO -          Validation  1.73    0.25    
[23/Sep/2020 14:50:39] INFO - 5/20     Training    1.33    0.46    
[23/Sep/2020 14:50:43] INFO -          Validation  1.76    0.24    
[23/Sep/2020 14:51:37] INFO - 6/20     Training    1.12    0.58    
[23/Sep/2020 14:51:41] INFO -          Validation  1.87    0.24    
[23/Sep/2020 14:52:40] INFO - 7/20     Training    0.89    0.68    
[23/Sep/20

KeyboardInterrupt: 

***

Training with k-fold cross-validation on data preprocessed only with CCA (no image/figure conversion step).

In [None]:
# Load the SSVEPDataset_<signal_length>.csv table from the project output directory
# (every column read as str) and wrap it in an SSVEPDataset with CCA preprocessing.
preprocessing = 'cca'
signal_length = '512'
data = pd.read_csv(
    str(Project.output_dir / f'SSVEPDataset_{signal_length}.csv'),
    dtype='str',
)
dataset = SSVEPDataset(Project, preprocessing, data, signal_length)

In [None]:
# Build batched loaders (batch size 64) for the SSVEP dataset and display the returned sizes,
# presumably a per-split sample count — confirm in SSVEPDataloader.
# NOTE(review): the k-fold call at the end of this section receives `dataset` directly, so these
# dataloaders and batch_size appear unused there — confirm whether kfold_train builds its own loaders.
batch_size = 64
dataloaders, dataset_sizes = SSVEPDataloader(dataset, batch_size)
dataset_sizes

In [None]:
# Build 'resnet50' for num_classes = 5, then replace its `conv1` with a convolution taking a
# single input channel (64 filters, 7x7 kernel, stride 2, padding 3, no bias).
# NOTE(review): presumably so the network accepts 1-channel CCA inputs — confirm against SSVEPDataset.
# If model_selection loads pretrained weights by default, the replaced stem starts from random init.
model_type, model_size, num_classes = 'resnet', '50', 5
model_name = f'{model_type}{model_size}'
model = model_selection(model_name, num_classes)
model.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
model = model.to(device)

In [None]:
# Optimisation setup for the CCA-only run: cross-entropy loss and SGD with lr 1e-3, momentum 0.9.
# NOTE(review): this StepLR (x0.1 every 7 scheduler steps) is not passed to kfold_train in the
# next cell and is never stepped in this notebook, so it currently has no effect.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [None]:
# 5-fold cross-validation, 20 epochs per fold, on the CCA-only SSVEPDataset.
# NOTE(review): neither `scheduler` nor the `dataloaders`/`batch_size` built earlier in this section
# are passed here — confirm kfold_train does not need them.
# NOTE(review): the same `model` and optimizer are reused across folds; unless kfold_train
# re-initialises them per fold, fold scores are not independent — confirm.
model, stats = kfold_train(Project, model, dataset, criterion, optimizer, num_epochs = 20, num_folds = 5)