In [1]:
import jax
import jax.numpy as jnp
import jax.example_libraries.optimizers as jopt
import mnist
import matplotlib.pyplot as plt
import numpy.random as npr
import numpy as np
import math
import datetime

# NaN debugging is switched off here. Flip the flag to True when hunting a
# divergence so that NaNs are reported where they are produced (see the JAX
# docs for `jax_debug_nans`); expect that mode to be slower.
jax.config.update("jax_debug_nans", False)

In [2]:
# Guard: the rest of the notebook is expected to run on this specific device.
# Change EXPECTED_DEVICE_KIND when running on different hardware.
EXPECTED_DEVICE_KIND = 'NVIDIA GeForce RTX 3060'
_device_kind = jax.devices()[0].device_kind
assert _device_kind == EXPECTED_DEVICE_KIND, (
    f'expected JAX default device {EXPECTED_DEVICE_KIND!r}, got {_device_kind!r}'
)

In [3]:
def layer_norm(vals, axis=1, eps=0.0001):
    """Shift and scale `vals` to zero mean and roughly unit spread along `axis`.

    Args:
        vals: array whose entries are normalized along `axis`.
        axis: axis over which mean and standard deviation are computed.
            Defaults to 1, i.e. each row of a 2-D (samples, features) array
            is normalized independently.
        eps: small constant added to the standard deviation so that a
            constant slice (std == 0) maps to zeros instead of dividing by 0.

    Returns:
        Array with the same shape as `vals`.
    """
    mean = jnp.mean(vals, axis=axis, keepdims=True)
    std = jnp.std(vals, axis=axis, keepdims=True)
    return (vals - mean) / (eps + std)

In [4]:
def _flatten_and_normalize(raw_images):
    """Reshape raw images into rows of 784 values and pass them through layer_norm."""
    return layer_norm(raw_images.reshape((-1, 784)))

# Training split: raw images/labels plus their model-ready versions.
images = mnist.train_images()
labels = mnist.train_labels()
labels_vec = jax.nn.one_hot(labels, 10)
images_vec = _flatten_and_normalize(images)

# Fixed (seed 0) permutation of the feature columns; the same `order` is
# applied to the test split below so both splits stay consistent.
order = np.arange(images_vec.shape[1])
npr.RandomState(0).shuffle(order)
images_vec = images_vec[:, order]

# Test split, preprocessed exactly like the training images.
test_img = _flatten_and_normalize(mnist.test_images())[:, order]
test_lbl = mnist.test_labels()

In [5]:
def init_params_old():
    """Earlier initializer: Gaussian weights scaled by 1/255, much smaller biases.

    Reads the module-level `width` (hidden size) and `height` (number of
    hidden-to-hidden layers) at call time and uses a fixed seed (0).

    Returns:
        Flat list [W0, b0, W1, b1, ...]: the 784->width input layer, `height`
        width->width layers, then the width->10 output layer.
    """
    rng = npr.RandomState(0)
    weight_scale = 1/255.0
    bias_scale = weight_scale * 0.001

    def dense(n_in, n_out):
        # The weight matrix is drawn before the bias vector; keeping that
        # order keeps the values produced by the seeded RNG unchanged.
        return [weight_scale*rng.randn(n_in, n_out), bias_scale*rng.randn(n_out)]

    layer_dims = [(784, width)] + [(width, width) for _ in range(height)] + [(width, 10)]
    out = []
    for n_in, n_out in layer_dims:
        out += dense(n_in, n_out)
    return out

def init_params():
    """Build the initial parameter list for the network with a fixed seed (0).

    Reads the module-level `width` (hidden size) and `height` (number of
    hidden-to-hidden layers) at call time.

    Returns:
        Flat list [W0, b0, W1, b1, ...]: the 784->width input layer, `height`
        width->width layers, then the width->10 output layer. Weights come
        from `he_init`; biases are 0.01 * standard normal.
    """
    rng = npr.RandomState(0)
    bias_scale = 0.01

    def he_init(fan_in, fan_out):
        # https://mmuratarat.github.io/2019-02-25/xavier-glorot-he-weight-init
        # why do i need 0.1??? something is wrong
        sigma = math.sqrt(4/(fan_in+fan_out))
        return 0.1*rng.normal(0.0, sigma, (fan_in, fan_out))

    layer_dims = [(784, width), *[(width, width) for _ in range(height)], (width, 10)]
    out = []
    for fan_in, fan_out in layer_dims:
        # Weight first, then bias: this keeps the seeded RNG draw order fixed.
        out.append(he_init(fan_in, fan_out))
        out.append(bias_scale*rng.randn(fan_out))
    return out

In [6]:
# Activation applied in the input layer and every hidden layer of predict().
# The (misspelled) name is kept because predict() looks it up by this name.
# gelu, relu6, selu bad -- they don't support as many layers.
#
# Notes:
#   * issue is that we seem to get better depth with relu, up to 7 layers
#       --> fix init
#       --> change learning rate
#   * take regular snapshots of the parameters: a high learning rate
#     sometimes breaks training, and then the learning rate needs to change
nonlinearlity = jax.nn.relu
def _logits(params, images):
    """Run the network and return the unnormalized class scores.

    `params` is the flat [W0, b0, W1, b1, ...] list: one input layer,
    `height` hidden layers (module-level `height`), and one output layer.
    """
    vals = nonlinearlity(jnp.dot(images, params[0]) + params[1])
    for i in range(height):
        vals = layer_norm(vals)
        # Residual block: the normalized input is added back to the layer output.
        vals = nonlinearlity(jnp.dot(vals, params[2+i*2]) + params[3+i*2]) + vals
    vals = layer_norm(vals)
    return jnp.dot(vals, params[2+height*2]) + params[3+height*2]


def predict(params, images):
    """Return class probabilities (softmax of the output layer) for `images`."""
    return jax.nn.softmax(_logits(params, images))


def loss(params, images, labels):
    """Mean cross-entropy between the predictions and `labels`.

    `labels` is expected to be one-hot along axis 1 (as produced by
    `jax.nn.one_hot`); for such labels this equals
    -mean(log(sum(predict(...) * labels, axis=1))).

    The log-probabilities come from `log_softmax` on the logits rather than
    `log(softmax(...))`: when the true class gets a probability that
    underflows to 0, the latter yields an infinite loss and NaN gradients.
    """
    log_probs = jax.nn.log_softmax(_logits(params, images))
    return -jnp.mean(jnp.sum(log_probs * labels, axis=1))
# regularization bad why: + 0.00001 * sum(jnp.sum(jnp.power(p, 2)) for p in params)

In [13]:
# Model size: hidden width and number of hidden-to-hidden layers.
width = 250
height = 5
params = init_params()
print(f'param count = {sum(p.size for p in params)} {[p.size for p in params]}')

# Optimization settings. `batch_size` is a shape tuple because it is passed
# directly as the `shape` argument of jax.random.choice.
batch_size = (50,)
step_size = 0.0001
steps = 10000

opt_init, opt_update, get_params = jopt.adam(step_size)
opt_state = opt_init(params)

@jax.jit
def update(i, opt_state, images_vec_batch, labels_vec_batch):
    """Apply one Adam step for the given mini-batch and return the new optimizer state."""
    params = get_params(opt_state)
    grads = jax.grad(loss)(params, images_vec_batch, labels_vec_batch)
    # jax.debug.print('{grads}', grads=grads)
    return opt_update(i, grads, opt_state)

t = datetime.datetime.now()
# The loop runs steps+1 iterations (i = 0..steps inclusive), so it needs
# steps+1 keys. With only `steps` keys, rnd[steps] is out of range; JAX clamps
# the index instead of raising, so the final step would silently reuse the
# previous step's key (and therefore the same batch).
rnd = jax.random.split(jax.random.PRNGKey(0), steps + 1)
num_train = images_vec.shape[0]
for i in range(0, steps+1):
    # Sample the row indices once and use them for both arrays, so images and
    # labels stay aligned by construction rather than by two separate
    # random draws happening to agree.
    batch_idx = jax.random.choice(rnd[i], num_train, batch_size)
    images_vec_batch = images_vec[batch_idx]
    labels_vec_batch = labels_vec[batch_idx]
    opt_state = update(i, opt_state, images_vec_batch, labels_vec_batch)
    # Report for the first 10 steps, then at most once every 5 seconds, and
    # always on the final step.
    if i < 10 or (datetime.datetime.now() - t).total_seconds() > 5.0 or i == steps:
        params = get_params(opt_state)
        test = jnp.mean(jnp.argmax(predict(params, test_img), axis=1) == test_lbl)
        accuracy = jnp.mean(jnp.argmax(predict(params, images_vec), axis=1) == labels)
        print(f'{i:03d} {accuracy:0.3f} / {test:0.3f}')
        t = datetime.datetime.now()

param count = 512510 [196000, 250, 62500, 250, 62500, 250, 62500, 250, 62500, 250, 62500, 250, 2500, 10]
000 0.263 / 0.270
001 0.358 / 0.368
002 0.417 / 0.426
003 0.470 / 0.476
004 0.517 / 0.526
005 0.556 / 0.568
006 0.595 / 0.605
007 0.624 / 0.635
008 0.643 / 0.651
009 0.663 / 0.671
6584 0.991 / 0.978
10000 0.994 / 0.981
