## MNIST with Linear Layers

In [1]:
from jax import random, jit, grad, jacrev, value_and_grad
import jax.numpy as jnp
import tensorflow as tf
import numpy as np
from jax.scipy.special import logsumexp

In [2]:
def relu(x):
    return jnp.maximum(0,x)

In [3]:
class Network:
    def __init__(self, layers=[784, 300, 100, 10]):
        
        self.params = []
        key = random.PRNGKey(0)

        for i in range(len(layers)-1):
            # FIXME: Move me outside, random initialization
            key_w, key_b = random.split(key, 2)
            w = 1e-2 * random.normal(key_w, (layers[i+1], layers[i]))
            b = 1e-2 * random.normal(key_b, (1, layers[i+1]))
            
            self.params.append([w, b])

        
    def forward(self, x):
        
        for w,b in self.params[:-1]:
            x = jnp.dot(x, w.T) + b
            x = relu(x)
        
        w,b = self.params[-1]
        x = jnp.dot(x, w.T) + b
        
        return x
    
    def get_params(self):
        return self.params
    
    @staticmethod
    def loss(params, x, y):
        
        for w,b in params[:-1]:
            x = jnp.dot(x, w.T) + b
            x = relu(x)
        
        w,b = params[-1]
        y_pred = jnp.dot(x, w.T) + b
        
        
        # Apply log-softmax
        sum_y_pred = logsumexp(y_pred, 1, keepdims=True)
        y_pred = y_pred - sum_y_pred
        
        # Calculate loss
        l = -jnp.sum(y*y_pred, 1)
        
        return jnp.mean(l)
    

    def update(self, grads, lr = 0.01):
        for g,p in zip(grad, self.params):
            p[0] = p[0] - lr * g[0]
            p[1] = p[1] - lr * g[1]
            

## Load Data

In [4]:
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
X, Y = mnist_train
X_t, Y_t = mnist_test

X = X.reshape(X.shape[0], -1)
X_t = X_t.reshape(X_t.shape[0], -1)

X = X / 255
X_t = X_t / 255

Y = tf.keras.utils.to_categorical(Y)
Y_t = tf.keras.utils.to_categorical(Y_t)

print(X.shape)
print(Y.shape)

print(X_t.shape)
print(Y_t.shape)

(60000, 784)
(60000, 10)
(10000, 784)
(10000, 10)


## Training

In [5]:
epochs = 20
batch = 128
lr = 0.01

model = Network([784, 300, 100, 10])
params = model.get_params()

for i in range(epochs):
    train_dataset = tf.data.Dataset.from_tensor_slices((X, Y)).shuffle(X.shape[0])
    train_iter = iter(train_dataset.batch(batch))
    
    running_loss = []
    for batch_x, batch_y in train_iter:
        batch_x = batch_x.numpy()
        batch_y = batch_y.numpy()
        
        loss, grad = value_and_grad(jit(Network.loss))(params, batch_x, batch_y)
        
        running_loss.append(loss)
        
        model.update(grad, lr)
    
    print(f"Epoch {i} : {np.mean(running_loss)}")
    



Epoch 0 : 2.301670789718628
Epoch 1 : 2.299506902694702
Epoch 2 : 2.295412063598633
Epoch 3 : 2.2795825004577637
Epoch 4 : 2.1601624488830566
Epoch 5 : 1.5485061407089233
Epoch 6 : 0.9187452793121338
Epoch 7 : 0.7305006384849548
Epoch 8 : 0.6579647064208984
Epoch 9 : 0.609839677810669
Epoch 10 : 0.5654870867729187
Epoch 11 : 0.5185030698776245
Epoch 12 : 0.4745645821094513
Epoch 13 : 0.44207873940467834
Epoch 14 : 0.41690927743911743
Epoch 15 : 0.39584100246429443
Epoch 16 : 0.3778747022151947
Epoch 17 : 0.3621681034564972
Epoch 18 : 0.3482908606529236
Epoch 19 : 0.33564209938049316


### Accuracy

In [6]:
test_dataset = tf.data.Dataset.from_tensor_slices((X_t, Y_t)).shuffle(X_t.shape[0])
test_iter = iter(test_dataset.batch(batch))

acc = []
for batch_x, batch_y in test_iter:
        batch_x = batch_x.numpy()
        batch_y = batch_y.numpy()
        
        y_pred = model.forward(batch_x)
        y_pred = jnp.argmax(y_pred,1)
        acc.append(jnp.mean(y_pred == jnp.argmax(batch_y,1)))
        
print(np.mean(acc))

0.9075356
