In [2]:
import jax.numpy as jnp
from jax import random, jit, vmap, grad
from tensorflow.keras.datasets import mnist


In [3]:
# Create a program to classify MNIST data, using JAX
# First load data

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train, x_test, y_train, y_test = jnp.array(x_train), jnp.array(x_test), jnp.array(y_train), jnp.array(y_test)


In [17]:
# Define functions to create network

def initialise_layer_params(inp, out, key=None, scale=1e-2):
    if key == None:
        raise RuntimeError("Key must be provided for generating weights")
    
    key1, key2 = random.split(key)

    return scale*random.normal(key1, (out, inp)), scale*random.normal(key2)

def initialise_nn(layers, key):
    keys = random.split(key, num=len(layers))
    return [initialise_layer_params(i, o, k) for i, o, k in zip(layers[:-1], layers[1:], keys)]

layer_sizes = [784,256,256,256,10]

nn = initialise_nn(layer_sizes, random.key(1878))

In [72]:
# Define one-hot encoding for y data, prediction function and loss function

def one_hot(y, num_classes, dtype=jnp.float32):
    return jnp.array(y[:,None] == jnp.arange(num_classes), dtype=dtype)

def relu(x):
    return jnp.max(0,x)

def predict(params, image):
    res = image
    for weights, bias in params[:-1]:
        res = jnp.dot(weights, res) + bias
        res = relu(res)
    res = jnp.dot(params[-1][0], res) + params[-1][1]
    return jnp.exp(-res) / jnp.sum(jnp.exp(-res))

batch_pred = vmap(predict, in_axes=(None, 0))

def loss(params, images, targets):
    preds = batch_pred(params, images)
    return -jnp.mean(jnp.log(preds)*targets)

def accuracy(params, images, targets):
    preds = batch_pred(params, images)
    return jnp.mean(jnp.argmax(preds, axis=1)==jnp.argmax(targets, axis=1))

@jit
def update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return [(w - lr * dw, b - lr * db)
            for (w, b), (dw, db) in zip(params, grads)]


In [73]:
# Define some parameters
epochs = 100
batch_size = 128
learning_rate = 0.01
num_digits = 10

In [74]:
# One-hot encode data
y_train, y_test = one_hot(y_train, num_digits), one_hot(y_test, num_digits)
x_train, x_test = jnp.reshape(x_train, (-1, 28*28)), jnp.reshape(x_test, (-1, 28*28))


In [75]:
def batch_data(images, key, bsize=128):
    order = random.permutation(key, len(images))
    print(order[0])
    for i in range(jnp.floor(len(images)/bsize)):
        yield images[order[bsize*i:bsize*(i+1)]]
    if images%bsize != 0:
        yield images[order[bsize*(i+1):]]

In [77]:
for epoch in range(10):
    nn = update(nn, x_train, y_train, learning_rate)
    
    train_acc = accuracy(nn, x_train, y_train)
    test_acc = accuracy(nn, x_test, y_test)
    print("Training set accuracy {}".format(train_acc))

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[256]
The axis argument must be known statically.
This BatchTracer with object id 1426817064992 was created on line:
  C:\Users\thoma\AppData\Local\Temp\ipykernel_13036\740884550.py:12:14 (predict)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError