In [None]:
# Notebook setup: make the parent directory importable and hot-reload edited modules.
import sys
# Presumably needed so the project packages imported below (`dataloader`, `model`)
# resolve when the notebook is run from its own subdirectory — confirm repo layout.
sys.path.append('..')
# Re-import all modules before each cell executes, so edits to project .py files
# are picked up without restarting the kernel.
# (Comments are kept off the magic lines: text after a line magic is passed to it.)
%load_ext autoreload
%autoreload 2
%matplotlib inline


In [None]:
import os
from pathlib import Path

import torch
import numpy as np
import seaborn as sns
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

from dataloader.builder import build_dataset

from model.cnn import SimpleConv, MediumConv, StrongConv
from model.trainer import Trainer, EnsembleTrainer


# To pin training to a specific GPU, uncomment the next line:
# torch.cuda.set_device(1)

In [None]:
# White figure background keeps plots readable under dark notebook themes.
plt.rcParams['figure.facecolor'] = 'white'

# Experiment configuration: every value a reader might want to tune lives here.
# NOTE(review): 'use_cuda', 'nn_runs', 'batch_size', 'log_interval' and 'lr' are not
# read by any cell in this notebook; confirm whether Trainer / fit should receive them.
config = {
    'use_cuda': True,
    'seed': 1,

    'nn_runs': 150,
    'patience': 5,
    'dropout_train': 0.5,

    'dataset': 'cifar_10',

    'model_class': StrongConv,
    'train_samples': 20_000,
    'epochs': 10,
    'batch_size': 256,
    'log_interval': 150,
    'lr': 1e-2,
    'num_classes': 10
}
# When True and a checkpoint file exists, weights are loaded instead of retrained.
restore_model = False

# Apply the configured seed so Restart & Run All reproduces the same run.
# torch.manual_seed seeds the RNGs on all devices (CPU and CUDA).
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])


In [None]:
# Load the configured dataset and pull out the train and validation splits.
# NOTE(review): `val_size` presumably sets the size of the held-out validation split —
# confirm against dataloader.builder.
dataset = build_dataset(config['dataset'], val_size=10_000)
(x_train, y_train), (x_val, y_val) = (dataset.dataset(split) for split in ('train', 'val'))



In [None]:
def scale(images):
    """Map pixel intensities from [0, 255] to roughly [-1, 1).

    Mathematically this is ``(images - 128) / 128``. It is computed as
    ``images / 128 - 1`` so the true division runs first. That promotes integer
    inputs (e.g. ``uint8`` images) to float before the subtraction. With the
    original order, ``uint8 - 128`` wraps around (0 -> 128), so black pixels
    mapped to +1.0 instead of -1.0. For float inputs the result is identical,
    because dividing by a power of two is exact (barring underflow).

    Args:
        images: numpy array or torch tensor of pixel values (anything that
            supports ``/`` and ``-`` with Python scalars).

    Returns:
        Floating-point array/tensor of the same shape, centred at 0.
    """
    return images / 128 - 1
# Normalise both splits with the same transform.
# NOTE: this rebinds x_train / x_val, so re-running this cell without reloading
# the data applies the scaling twice.
x_train, x_val = (scale(split) for split in (x_train, x_val))

In [None]:
# Reshape images to channel-first (N, 3, 32, 32) and optionally subsample training data.
# NOTE(review): a plain reshape assumes the loader returns pixels already in CHW order;
# if it returns HWC (N, 32, 32, 3), this scrambles channels and a transpose is needed.
input_shape = (-1, 3, 32, 32)
n_train = config['train_samples']  # None keeps the full training split

x_train = x_train.reshape(input_shape)[:n_train]
x_val = x_val.reshape(input_shape)

# Use int64 explicitly: the dtype string 'long' means the platform C long, which is
# 32-bit on Windows, while torch class-index losses (e.g. cross_entropy) need int64.
y_train = y_train.astype(np.int64).reshape(-1)[:n_train]
y_val = y_val.astype(np.int64).reshape(-1)

# Features and labels must stay aligned after reshaping and subsetting.
assert len(x_train) == len(y_train), (len(x_train), len(y_train))
assert len(x_val) == len(y_val), (len(x_val), len(y_val))

train_set = (x_train, y_train)
val_set = (x_val, y_val)


In [None]:
model_class = config['model_class']

# Checkpoint file for this dataset/architecture pair, stored under ./data.
model_dir = Path('data')
model_path = model_dir / f"model_{config['dataset']}_{config['model_class'].__name__}.pt"

# Instantiate the network, passing CELU as its activation function.
model = model_class(
    config['num_classes'],
    activation=torch.nn.functional.celu,
)


In [None]:
# Train from scratch unless a restore is requested and a checkpoint is available.
trainer = Trainer(model, dropout_train=config['dropout_train'])

if restore_model and model_path.exists():
    # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
    # machine; load_state_dict then copies the tensors into the model's
    # existing parameters on whatever device they already live on.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
else:
    # Create the checkpoint directory before training, so a fresh checkout
    # fails fast here instead of crashing in torch.save after all epochs.
    model_path.parent.mkdir(parents=True, exist_ok=True)
    trainer.fit(
        train_set, val_set, epochs=config['epochs'], verbose=True,
        patience=config['patience'])
    torch.save(model.state_dict(), model_path)



In [None]:
# Training-set accuracy. sklearn's signature is accuracy_score(y_true, y_pred);
# keyword arguments make the roles explicit (the original call had them swapped).
accuracy_score(y_true=train_set[1], y_pred=trainer.predict(train_set[0]))


In [None]:
# Validation-set accuracy. sklearn's signature is accuracy_score(y_true, y_pred);
# keyword arguments make the roles explicit (the original call had them swapped).
accuracy_score(y_true=val_set[1], y_pred=trainer.predict(val_set[0]))


In [None]:
# Learning curves: a widening gap between train and validation loss signals overfitting.
# NOTE(review): when the model is restored from a checkpoint, trainer.fit never runs,
# so the history attributes may be missing or empty; getattr keeps this cell runnable.
fig, ax = plt.subplots(figsize=(16, 9))
ax.plot(getattr(trainer, 'train_loss_history', []), label='train')
ax.plot(getattr(trainer, 'val_loss_history', []), label='validation')
ax.set(title='Loss history', xlabel='logged step', ylabel='loss')
ax.legend();