In [1]:
import jax
import jax.numpy as jnp
from jax import vmap
from mlp import MLP
import tensorflow.keras.datasets as datasets

In [2]:
def get_preprocessed_mnist():
    """Load MNIST through Keras and prepare it for the MLP.

    Images are reshaped to rows of 28*28 values and divided by 255; integer
    labels are converted to length-10 one-hot vectors.

    Returns:
        A tuple ``(x_train, y_train, x_test, y_test)``.
    """
    def to_one_hot(x, classes):
        # `jax.ops.index_update` has been removed from JAX; the indexed-update
        # property `.at[...].set(...)` is the supported functional replacement
        # and yields the same array (zeros with a 1 at position `x`).
        return jnp.zeros(classes).at[x].set(1)

    (x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
    # Flatten every 28x28 image into a single row and rescale by 1/255.
    x_train = x_train.reshape(-1, 28*28) / 255
    x_test = x_test.reshape(-1, 28*28) / 255
    # vmap applies the scalar-label encoder across the whole label vector.
    y_train = vmap(lambda x: to_one_hot(x, 10))(y_train)
    y_test = vmap(lambda x: to_one_hot(x, 10))(y_test)
    return x_train, y_train, x_test, y_test

In [3]:
# Materialize the four arrays returned by get_preprocessed_mnist() as
# notebook-level variables used by the training and evaluation cells below.
x_train, y_train, x_test, y_test = get_preprocessed_mnist()



In [4]:
# Build the model from the project-local `mlp` module.
# NOTE(review): the list presumably gives per-layer widths (28*28 inputs to
# match the flattened images, 10 outputs to match the one-hot labels) and the
# string presumably selects the task type — confirm against MLP's definition.
mlp = MLP([28*28, 512, 512, 10], 'classification')

In [5]:
%%time
# Fit the model on the training split; the %%time cell magic (which must stay
# on the first line of the cell) reports CPU and wall-clock time for the call.
# Training hyperparameters are whatever MLP.train defaults to (not visible here).
mlp.train(x_train, y_train)

CPU times: user 1min 32s, sys: 5.75 s, total: 1min 37s
Wall time: 32.7 s


In [6]:
# Report accuracy on the test split, scaled by 100 and shown with two decimals.
# NOTE(review): assumes mlp.accuracy returns a fraction in [0, 1] — confirm in mlp.py.
print(f'accuracy: {mlp.accuracy(x_test, y_test)*100:.2f}%')

accuracy: 92.31%
