Importing modules

In [5]:
import numpy as np
import os
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from Layer import *
from network import Network
from activation_func import tanh, tanh_prime, sigmoid, sigmoid_prime, softmax, softmax_prime, relu, relu_prime
from loss_func import mse, mse_prime, cross_entropy, cross_entropy_prime
from data_func import vectorize_labels, k_fold, import_data
from performance_func import plot_confusion_matrix, plot_error

Importing data

In [6]:
# input width and number of target classes handed to the network setup
features = 784
output_classes = 10

# seed NumPy's global RNG before anything else runs so repeated runs match
np.random.seed(10)

# load the data set; size and normalize are forwarded to import_data as-is
training_size = 6000
normalize = True
(training, labels,
 test, original_test_labels, test_labels) = import_data(size=training_size,
                                                        normalize=normalize)

Setting configuration

In [9]:
# --- network shape (consumed by net.setup_net) ---
hidden_layers = [300, 200, 100]

# --- optimisation settings (consumed by net.fit) ---
max_epochs = 10
batch_size = 32
learning_rate = 0.005
weight_decay = 0.01
momentum = False  # presumably switches momentum off -- confirm in Network.fit



Train the network

In [10]:
# Build the model: relu (and its derivative) as the activation, softmax as the
# loss activation, cross-entropy as the loss.
net = Network()
net.setup_net(
    hidden_layers, features, output_classes,
    activation=relu,
    activation_prime=relu_prime,
    loss_activation=softmax,
    loss_activation_prime=softmax_prime,
    loss=cross_entropy,
    loss_prime=cross_entropy_prime,
)

# Split the training set into a train part and a validation part.
# NOTE(review): k=5, n=5 presumably means "5 folds, use fold 5 for validation"
# -- confirm against data_func.k_fold.
(fold_train_data, fold_train_labels,
 fold_val_data, fold_val_labels) = k_fold(training, labels, k=5, n=5)

# Fit with the hyper-parameters chosen above; the two returned error series
# are kept so they can be plotted afterwards.
errors, val_errors = net.fit(
    fold_train_data, fold_train_labels,
    fold_val_data, fold_val_labels,
    max_epochs, learning_rate, batch_size, momentum, weight_decay,
)

# Score the fitted network on the test set and report the result.
test_accuracy = net.accuracy(x=test, y_true=original_test_labels,
                             errors=errors, val_errors=val_errors)
print("The test accuracy of the network is: {}".format(test_accuracy))

epoch 1/10   training error=1.301110  validation error=0.641473 validation accuracy=0.785000
epoch 2/10   training error=0.495587  validation error=0.504692 validation accuracy=0.840000
epoch 3/10   training error=0.368760  validation error=0.441009 validation accuracy=0.864167
epoch 4/10   training error=0.305300  validation error=0.393392 validation accuracy=0.885000
epoch 5/10   training error=0.260007  validation error=0.378865 validation accuracy=0.880833
epoch 6/10   training error=0.227129  validation error=0.344258 validation accuracy=0.893333
epoch 7/10   training error=0.203162  validation error=0.345213 validation accuracy=0.890000
epoch 8/10   training error=0.184184  validation error=0.316737 validation accuracy=0.905833
epoch 9/10   training error=0.168443  validation error=0.309046 validation accuracy=0.905000
epoch 10/10   training error=0.154835  validation error=0.306213 validation accuracy=0.902500
The test accuracy of the network is: 0.9035


Plot results

In [None]:
# plot the training and validation error series returned by net.fit
plot_error(errors, val_errors)