In [21]:
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
from dataset import load_svhn, random_split_train_val
from gradient_check import check_layer_gradient, check_layer_param_gradient, check_model_gradient
from layers import FullyConnectedLayer, ReLULayer
from model import TwoLayerNet
from trainer import Trainer, Dataset
from optim import SGD, MomentumSGD
from metrics import multiclass_accuracy

In [52]:
def prepare_for_neural_network(train_X, test_X):
    """Flatten, rescale and mean-center the train and test arrays.

    Each input is reshaped to 2-D, keeping the first axis and collapsing all
    remaining axes into one, converted to float and divided by 255.0.
    The per-column mean is computed from the training array only and then
    subtracted from both arrays.

    Arguments:
      train_X: np array whose first axis indexes training samples
      test_X: np array whose first axis indexes test samples

    Returns:
      (train_flat, test_flat): the two preprocessed 2-D float arrays
    """
    def _flatten_and_scale(samples):
        # One row per sample; values divided by 255.0 after the float cast.
        return samples.reshape(samples.shape[0], -1).astype(float) / 255.0

    scaled_train = _flatten_and_scale(train_X)
    scaled_test = _flatten_and_scale(test_X)

    # Centering statistics come from the training data alone and are reused
    # unchanged for the test data.
    column_means = np.mean(scaled_train, axis=0)

    return scaled_train - column_means, scaled_test - column_means
    
# Sample counts passed to the loader and to the train/validation split.
MAX_TRAIN_SAMPLES = 20000
MAX_TEST_SAMPLES = 2000
NUM_VAL_SAMPLES = 2000

# Load SVHN from the "data" path.
train_X, train_y, test_X, test_y = load_svhn(
    "data",
    max_train=MAX_TRAIN_SAMPLES,
    max_test=MAX_TEST_SAMPLES,
)

# Flatten, rescale and mean-center both arrays (mean taken from train only).
train_X, test_X = prepare_for_neural_network(train_X, test_X)

# Hold out part of the training data as a validation set. The names are
# rebound on purpose so the full pre-split array is not kept alive.
train_X, train_y, val_X, val_y = random_split_train_val(
    train_X,
    train_y,
    num_val=NUM_VAL_SAMPLES,
)

In [53]:
# Sanity check: (rows left in train after the validation hold-out,
# number of flattened features per sample).
train_X.shape

(18000, 3072)

In [62]:
# Hyperparameters for this run, gathered in one place so they are easy to tune.
NUM_CLASSES = 10
HIDDEN_LAYER_SIZE = 1000
REG_STRENGTH = 2e-5
LEARNING_RATE = 1e-3
LEARNING_RATE_DECAY = 0.921

# Input width follows the flattened feature count of the training matrix.
model = TwoLayerNet(
    n_input=train_X.shape[1],
    n_output=NUM_CLASSES,
    hidden_layer_size=HIDDEN_LAYER_SIZE,
    reg=REG_STRENGTH,
)
dataset = Dataset(train_X, train_y, val_X, val_y)
trainer = Trainer(
    model,
    dataset,
    MomentumSGD(),
    learning_rate=LEARNING_RATE,
    learning_rate_decay=LEARNING_RATE_DECAY,
)

# fit() prints one "Loss / Train accuracy / val accuracy" line per step of
# its loop and returns the three history sequences.
loss_history, train_history, val_history = trainer.fit()

Loss: 294.333652, Train accuracy: 0.493111, val accuracy: 0.490000
Loss: 267.577196, Train accuracy: 0.643889, val accuracy: 0.617500
Loss: 231.017532, Train accuracy: 0.676056, val accuracy: 0.635000
Loss: 160.965240, Train accuracy: 0.707056, val accuracy: 0.647500
Loss: 177.102631, Train accuracy: 0.772722, val accuracy: 0.715000
Loss: 129.580805, Train accuracy: 0.758778, val accuracy: 0.684500
Loss: 136.531447, Train accuracy: 0.791611, val accuracy: 0.721500
Loss: 135.941436, Train accuracy: 0.828889, val accuracy: 0.744500
Loss: 143.450791, Train accuracy: 0.852778, val accuracy: 0.751000
Loss: 128.028089, Train accuracy: 0.858444, val accuracy: 0.754000
Loss: 142.161461, Train accuracy: 0.861556, val accuracy: 0.753000
Loss: 90.843162, Train accuracy: 0.884222, val accuracy: 0.778000
Loss: 89.406321, Train accuracy: 0.890333, val accuracy: 0.775000
Loss: 77.517742, Train accuracy: 0.901500, val accuracy: 0.775000
Loss: 56.028226, Train accuracy: 0.909722, val accuracy: 0.768500

In [63]:
# `accuracy_history` collects one test-set accuracy per evaluated run.
# No visible cell creates it, so on a fresh kernel (Restart & Run All) the
# append below would raise NameError. Start an empty list in that case; a list
# that already exists in the kernel is kept and extended as before.
if "accuracy_history" not in globals():
    accuracy_history = []

# Score the current model on the test split and record the result.
test_predictions = model.predict(test_X)
accuracy_history.append(multiclass_accuracy(test_predictions, test_y))

# NOTE(review): if hyperparameters are chosen by looking at these numbers,
# the test set is no longer an unbiased estimate -- prefer val accuracy for
# tuning and evaluate on test once at the end.
accuracy_history

[0.707,
 0.73,
 0.77,
 0.74,
 0.77,
 0.774,
 0.6435,
 0.767,
 0.7755,
 0.7755,
 0.799,
 0.799,
 0.808,
 0.808,
 0.6595,
 0.79]

In [64]:
# Call fit() again on the same trainer to keep training the existing model
# (the log below starts from a much higher train accuracy than the first run).
# The three history names are rebound here, so the histories returned by the
# previous fit() call are no longer reachable through them.
loss_history, train_history, val_history = trainer.fit()

Loss: 46.573772, Train accuracy: 0.954000, val accuracy: 0.792000
Loss: 46.855149, Train accuracy: 0.960000, val accuracy: 0.793000
Loss: 36.516839, Train accuracy: 0.958222, val accuracy: 0.787000
Loss: 29.808346, Train accuracy: 0.964056, val accuracy: 0.791500
Loss: 33.210782, Train accuracy: 0.961278, val accuracy: 0.791500
Loss: 27.329076, Train accuracy: 0.971167, val accuracy: 0.798000
Loss: 25.175704, Train accuracy: 0.973111, val accuracy: 0.803000
Loss: 30.154618, Train accuracy: 0.976778, val accuracy: 0.800500
Loss: 20.481632, Train accuracy: 0.975889, val accuracy: 0.803500
Loss: 28.742936, Train accuracy: 0.977389, val accuracy: 0.798500
Loss: 22.515776, Train accuracy: 0.980222, val accuracy: 0.799000
Loss: 26.105226, Train accuracy: 0.980667, val accuracy: 0.803500
Loss: 21.295256, Train accuracy: 0.982056, val accuracy: 0.803000
Loss: 26.987489, Train accuracy: 0.982444, val accuracy: 0.800500
Loss: 20.990532, Train accuracy: 0.984278, val accuracy: 0.800500
Loss: 24.0

In [65]:
# Re-evaluate on the test split after the additional training and append the
# new score to the running list of test accuracies.
test_predictions = model.predict(test_X)
test_accuracy = multiclass_accuracy(test_predictions, test_y)
accuracy_history.append(test_accuracy)
accuracy_history

[0.707,
 0.73,
 0.77,
 0.74,
 0.77,
 0.774,
 0.6435,
 0.767,
 0.7755,
 0.7755,
 0.799,
 0.799,
 0.808,
 0.808,
 0.6595,
 0.79,
 0.8045]