# Neural Networks

## The Data

In [None]:
from nnet.data import get_mnist_data
from nnet.data import get_batch
from nnet.data import Data

from nnet.plotting import plot_examples
from nnet.plotting import plot_performance

In [None]:
# Load the dataset via the project helper; the returned object exposes
# `.train` and `.test` splits (both are used further down).
data = get_mnist_data()

In [None]:
# Plot examples from the training split.
plot_examples(data.train)

## The Network

In [None]:
from nnet.network import get_layer
from nnet.network import get_forward
from nnet.network import get_accuracy
from nnet.network import get_update_step
from nnet.network import get_optimizer_triplet
from nnet.network import get_start_params
from nnet.network import get_loss
from nnet.network import get_fitting_function

from tqdm.auto import tqdm

import jax
# Fixed PRNG seed so the parameter initialisation below is reproducible.
key = jax.random.PRNGKey(seed=1)

## Building the Network

In [None]:
# Per-layer function from the project's `get_layer`, configured with a
# sigmoid activation.
layer = get_layer(activation_func="sigmoid")

In [None]:
# Whole-network forward pass built from `layer`; it is later called as
# `forward_pass(params, images)`.
forward_pass = get_forward(layer)

In [None]:
# Loss fed to the update step below; `loss_type` selects the mean-square variant.
loss_func = get_loss(forward_pass, loss_type="mean_square")

In [None]:
# Accuracy metric derived from the same forward pass; handed to the fitting
# function so it can score the network during training.
accuracy = get_accuracy(forward_pass)

In [None]:
image_dim = 28  # pixels along each side of a square input image
n_hidden = 28  # units in each hidden layer (unrelated to image_dim)
n_classes = 10  # output units; argmax over them gives the predicted label

# Widths of [input, hidden, hidden, output]. The input layer takes
# image_dim ** 2 values, i.e. one per pixel (images are presumably stored
# flattened -- confirm against get_mnist_data).
layer_sizes = [image_dim ** 2, n_hidden, n_hidden, n_classes]

# Initial network parameters, drawn using the fixed PRNG `key`.
start_params = get_start_params(layer_sizes, key)

In [None]:
opt_state, opt_update, get_params = get_optimizer_triplet("sgd", start_params=start_params, step_size=0.01)

In [None]:
update = get_update_step(get_params, opt_update, loss_func)

In [None]:
fit_network = get_fitting_function(data, update, accuracy, get_params, tqdm=tqdm)

In [None]:
result = fit_network(n_epochs=50, batch_size=100, opt_state=opt_state, verbose=False)

In [None]:
predictions = forward_pass(result["params"], data.test.images)
predictions = predictions.argmax(axis=1)

In [None]:
wrong = jax.numpy.where(predictions != data.test.labels)[0]

In [None]:
_data = Data(images=data.test.images[wrong], labels=data.test.labels[wrong], predictions=predictions[wrong])

In [None]:
len(wrong) / len(data.test.labels)

In [None]:
plot_examples(_data)

In [None]:
plot_performance(result["log"])