# Imports, load data

In [10]:
import os
import pickle as pkl
from copy import deepcopy

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

import torch
from torch import nn
from torch.nn import functional
import torch.optim as optim

from IPython.display import display, clear_output
from ipywidgets import IntProgress, Text, Output

from src.utils import progress_bar
from src.network import train_CNN, LeNet

# Single seed shared by every random number generator used in this notebook,
# so the validation split and the weight initialisation are repeatable.
SEED = 3489
torch.manual_seed(SEED)  # PyTorch global generator
np.random.seed(SEED)     # NumPy legacy global generator (used for the validation split)

<torch._C.Generator at 0x7f11188890b0>

In [2]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules before
# each cell runs, so changes made under src/ are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
# Load the full training set.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# point this at data files you trust.
with open('data/X_train.p', 'rb') as f:
    X_train = pkl.load(f)
y_train = np.loadtxt('data/y_train.txt')

# Hold out a random validation subset, sampled without replacement from the
# global NumPy generator.
num_valid = 7500
num_total = X_train.shape[0]
idx_valid = list(np.random.choice(num_total, replace=False, size=num_valid))
# sorted() pins the order of the remaining training indices explicitly instead
# of relying on the (unspecified) iteration order of a set.
idx_train = sorted(set(range(num_total)).difference(idx_valid))
assert len(idx_train) + len(idx_valid) == num_total, 'train/valid split must partition the data'

# Split first, then rebind X_train / y_train to the training portion only.
X_valid, y_valid = X_train[idx_valid], y_train[idx_valid]
X_train, y_train = X_train[idx_train], y_train[idx_train]
num_train = X_train.shape[0]

# Load the held-out test set (same trust caveat as above applies to the pickle).
with open('data/X_test1.p', 'rb') as f:
    X_test1 = pkl.load(f)
y_test1 = np.loadtxt('data/y_test1.txt')

# Train

In [4]:
# Training schedule
batch_size = 20
num_epochs = 100

# Number of mini-batches per epoch: ceiling division, so a final partial batch
# is counted as well.
num_batches = -(-num_train // batch_size)

# Progress bars (one for batches within the epoch, one for epochs), a one-line
# status text box, and an output area that hosts the validation-accuracy plot.
batch_progress = IntProgress(value=0, max=num_batches)
epoch_progress = IntProgress(value=0, max=num_epochs)
info_box = Text(value='')
valid_stats = Output()

def batch_hook(model, stats, epoch, batch_num):
    """Mirror the current batch number on the batch progress bar.

    Handed to train_CNN through ``batch_hooks``; presumably invoked once per
    mini-batch (confirm against train_CNN). ``model``, ``stats`` and ``epoch``
    are accepted to match the hook signature but are not used here.
    """
    label = str(batch_num)
    batch_progress.value = batch_num
    batch_progress.description = label
    
def epoch_hook(model, stats, epoch):
    """Refresh the epoch-level monitors.

    Handed to train_CNN through ``epoch_hooks``. Advances the epoch progress
    bar, resets the batch progress bar, and, once at least two epochs of
    history exist, updates the status text and redraws the validation-accuracy
    curve inside the ``valid_stats`` output widget.

    Parameters
    ----------
    model : unused; accepted to match the hook signature.
    stats : mapping that holds the 'train_loss' and 'valid_acc' histories
        (sequences; the latest value is read from the end).
    epoch : epoch number reported by train_CNN.
    """
    epoch_progress.value = epoch
    batch_progress.value = 0
    batch_progress.description = '0'

    if epoch <= 1:
        return

    train_loss = stats['train_loss']
    valid_acc = stats['valid_acc']
    # The deltas below need a previous value to compare against; bail out rather
    # than raise IndexError if fewer than two values have been recorded so far.
    if len(train_loss) < 2 or len(valid_acc) < 2:
        return

    # Both deltas are "previous - current": positive means the training loss
    # went down, but for accuracy positive means it got worse.
    t_loss = train_loss[-1]
    t_diff = train_loss[-2] - t_loss
    v_acc = valid_acc[-1]
    v_diff = valid_acc[-2] - v_acc

    msg = 'Tr loss: %.4f (%.4f); ' % (t_loss, t_diff)
    msg = msg + 'Vd acc: %.4f (%.4f)' % (v_acc, v_diff)
    info_box.value = msg

    with valid_stats:
        clear_output()
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Validation accuracy')
        epochs = range(len(valid_acc))
        ax.plot(epochs, valid_acc)
        ax.scatter(epochs, valid_acc)
        plt.show()
        # A new figure is created every epoch; close it explicitly so figures
        # do not pile up on backends that keep them alive after show().
        plt.close(fig)
    
# Show the monitoring widgets in this cell's output area.
for widget in (batch_progress, epoch_progress, info_box, valid_stats):
    display(widget)

# Set up training
model = LeNet(p=0.4)
train_data = (X_train, y_train)
learning_rate = 0.001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss = nn.CrossEntropyLoss()
# Optional LR schedule, currently disabled:
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5)
scheduler = None
# Passed straight to train_CNN; presumably an early-stopping threshold on the
# validation metric - TODO confirm against src.network.train_CNN.
valid_stop_threshold = 0.0001
valid_data = (X_valid, y_valid)
epoch_hooks = [epoch_hook]
batch_hooks = [batch_hook]

try:
    model, stats = train_CNN(
        model, train_data, optimizer, loss, num_epochs, batch_size, scheduler, valid_data,
        valid_stop_threshold, epoch_hooks=epoch_hooks, batch_hooks=batch_hooks
    )
finally:
    # Close the transient monitors even if training raises or the kernel is
    # interrupted; valid_stats is left open so the last accuracy plot stays visible.
    batch_progress.close()
    epoch_progress.close()
    info_box.close()


IntProgress(value=0, max=2625)

IntProgress(value=0)

Text(value='')

Output()

In [16]:
# Persist only the trained weights (the state_dict), not the model object,
# the optimizer or the training statistics.
# NOTE(review): `model` is whatever train_CNN returned. If train_CNN also keeps
# a separate best-validation model (e.g. under stats['best_model']), confirm
# that this is the one you intend to save.
checkpoint_path = 'saved_models/LeNet_no_adversarial.pth'
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

# Get test accuracy

In [None]:
# Evaluate on the held-out test set.
# Switch to eval mode so layers such as dropout / batch-norm use their inference
# behaviour (LeNet is built with p=0.4, presumably a dropout rate), and disable
# autograd because no gradients are needed here.
was_training = model.training
model.eval()
with torch.no_grad():
    # The model takes inputs reshaped to (N, 1, 28, 28).
    logits = model(torch.Tensor(X_test1.reshape(-1, 1, 28, 28)))
# Put the model back into whatever mode it was in, for any later cells.
model.train(was_training)

# Predicted class = index of the largest output; accuracy = fraction of matches.
y_pred = logits.numpy().argmax(axis=1)
acc = float(np.mean(y_pred == y_test1))

print('Test acccuracy: %.5f' % acc)