In [1]:
# %%

from inout import load_mnist
from utils import preprocess
from network import Network

In [2]:
# %%

# Run configuration shared by the cells below.
# dataset_name is printed in the loading banner and passed to model.build_model().
dataset_name = 'mnist'
# The next six values are handed positionally to model.train() in the training cell.
# NOTE(review): their exact meaning is defined by Network.train, which is not shown here.
num_epochs = 1
learning_rate = 0.01
validate = 1
# regularization is forwarded to both model.train() and evaluate().
regularization = 0
# verbose is forwarded to both model.train() and evaluate(); evaluate() prints the
# test loss/accuracy when it is truthy.
verbose = 1
plot_weights = 1
# The three flags below are only consumed by evaluate(), where they are used as
# truth values (1 = on, 0 = off).
plot_correct = 1
plot_missclassified = 1
plot_feature_maps = 1

In [3]:
# %%

# Only the MNIST loader is imported, and it is called regardless of dataset_name,
# while the model is later built from dataset_name. Fail fast on any other value
# instead of silently pairing MNIST data with a model built for another dataset.
if dataset_name != 'mnist':
    raise ValueError(
        "Unsupported dataset_name %r: this notebook only loads 'mnist'" % (dataset_name,)
    )

print('\n--- Loading ' + dataset_name + ' dataset ---')                 # load dataset
# TODO: dispatch on dataset_name (e.g. to a CIFAR loader) once such a loader is imported.
dataset = load_mnist()


--- Loading mnist dataset ---


In [4]:
# %%

# Replace the loaded dataset with the result of preprocess() before it is used
# for training and evaluation.
# NOTE(review): `dataset` is rebound to its own preprocessed version, so re-running
# this cell on its own feeds already-preprocessed data back into preprocess().
# Re-run the loading cell first unless preprocess() is known to be idempotent.
print('\n--- Processing the dataset ---')                               # pre process dataset
dataset = preprocess(dataset)


--- Processing the dataset ---


In [5]:
# %%

# Create the Network instance used by the training and evaluation cells, then let
# it build its layers for the configured dataset name.
# NOTE(review): evaluate()/forward() further down read `model.layers`, so this cell
# must have run before them.
print('\n--- Building the model ---')                                   # build model
model = Network()
model.build_model(dataset_name)


--- Building the model ---


In [6]:
# %%

# Train the model with the settings from the configuration cell.
# All arguments are passed positionally, so their order must match the signature
# of Network.train. Note that plot_weights comes before verbose here.
# NOTE(review): Network.train is not visible in this notebook; confirm the order.
print('\n--- Training the model ---')                                   # train model
model.train(
    dataset,
    num_epochs,
    learning_rate,
    validate,
    regularization,
    plot_weights,
    verbose
)


--- Training the model ---

--- Epoch 1 ---
[Step 00100]: Loss 2.212 | Accuracy: 18.000 | Time: 55.16 seconds | Validation Loss 2.146 | Validation Accuracy: 22.660
[Step 00200]: Loss 2.039 | Accuracy: 26.500 | Time: 52.34 seconds | Validation Loss 1.397 | Validation Accuracy: 56.120
[Step 00300]: Loss 1.766 | Accuracy: 38.333 | Time: 60.99 seconds | Validation Loss 1.012 | Validation Accuracy: 68.600
[Step 00400]: Loss 1.584 | Accuracy: 45.000 | Time: 54.36 seconds | Validation Loss 0.871 | Validation Accuracy: 71.620
[Step 00500]: Loss 1.421 | Accuracy: 50.400 | Time: 64.88 seconds | Validation Loss 0.718 | Validation Accuracy: 77.920
[Step 00600]: Loss 1.311 | Accuracy: 54.667 | Time: 54.78 seconds | Validation Loss 0.689 | Validation Accuracy: 77.140
[Step 00700]: Loss 1.207 | Accuracy: 58.571 | Time: 56.72 seconds | Validation Loss 0.637 | Validation Accuracy: 81.360
[Step 00800]: Loss 1.169 | Accuracy: 61.000 | Time: 53.87 seconds | Validation Loss 0.693 | Validation Accuracy: 79

KeyboardInterrupt: 

In [None]:
# %%

from utils import regularized_cross_entropy, plot_learning_curve, plot_accuracy_curve, plot_histogram, plot_sample, lr_update
import numpy as np

def forward(image, plot_feature_maps):                # forward propagate
    """Run ``image`` through every layer of the notebook-level ``model``.

    Parameters
    ----------
    image : array-like
        Input sample. It is handed to the first layer's ``forward`` and then
        replaced by each layer's output in turn.
    plot_feature_maps : bool or int
        When truthy, the ``[0, :, :]`` slice of each layer's input (scaled by
        255) is drawn with ``plot_sample`` before the layer runs.

    Returns
    -------
    The output of the last layer in ``model.layers``.
    """
    for layer in model.layers:
        # The `[0, :, :]` slice needs at least three axes. Activations with fewer
        # axes (presumably the flattened vectors between dense layers; confirm
        # against the layer implementations) are skipped instead of raising
        # IndexError in the middle of the forward pass.
        if plot_feature_maps and np.ndim(image) >= 3:
            plot_sample((image * 255)[0, :, :], None, None)
        image = layer.forward(image)

    return image
  
def evaluate(X, y, regularization, plot_correct, plot_missclassified, plot_feature_maps, verbose):
    """Score the notebook-level ``model`` on the samples ``X`` with labels ``y``.

    Each sample is pushed through ``forward``; the prediction is the argmax of
    the output, and the per-sample loss is whatever ``regularized_cross_entropy``
    returns for the output entry at the true label's index.

    Parameters
    ----------
    X : sequence
        Samples, indexable by position. Must not be empty.
    y : sequence
        Labels, indexed in step with ``X``; each label is used as an index into
        the model output.
    regularization
        Forwarded unchanged to ``regularized_cross_entropy``.
    plot_correct : bool or int
        When truthy, every correctly classified sample is drawn with
        ``plot_sample``.
    plot_missclassified : bool or int
        When truthy, every misclassified sample is drawn with ``plot_sample``.
    plot_feature_maps : bool or int
        Forwarded unchanged to ``forward``.
    verbose : bool or int
        When truthy, the mean loss and the accuracy are printed.

    Returns
    -------
    tuple
        ``(loss, accuracy)``: the loss averaged over ``len(X)`` samples and the
        percentage of samples whose prediction equals the label.

    Raises
    ------
    ValueError
        If ``X`` is empty (previously this surfaced as a ZeroDivisionError when
        the averages were computed).
    """
    test_size = len(X)
    if test_size == 0:
        raise ValueError('evaluate() needs at least one sample, got an empty X')

    loss, num_correct = 0, 0
    for i in range(test_size):
        tmp_output = forward(X[i], plot_feature_maps)              # forward propagation
        label = y[i]

        # accumulate the loss computed from the output entry at the true label
        loss += regularized_cross_entropy(model.layers, regularization, tmp_output[label])

        prediction = np.argmax(tmp_output)                              # update accuracy
        is_correct = prediction == label
        if is_correct:
            num_correct += 1

        # Both outcomes plot the same way; only the controlling flag differs.
        # NOTE(review): the flags used to be reassigned to 1 after each plot, which
        # is a no-op for a truthy flag, so every matching sample is plotted. If
        # only the first example of each kind was intended, clear the flag here.
        should_plot = plot_correct if is_correct else plot_missclassified
        if should_plot:
            plot_sample((X[i] * 255)[0, :, :], label, prediction)

    accuracy = (num_correct / test_size) * 100
    loss = loss / test_size
    if verbose:
        print('Test Loss: %02.3f' % loss)
        print('Test Accuracy: %02.3f' % accuracy)
    return loss, accuracy

# Score the model on the test images/labels using the flags from the configuration
# cell; arguments are named so the call does not depend on parameter order.
evaluate(
    X=dataset['test_images'],
    y=dataset['test_labels'],
    regularization=regularization,
    plot_correct=plot_correct,
    plot_missclassified=plot_missclassified,
    plot_feature_maps=plot_feature_maps,
    verbose=verbose,
)