<a href="https://colab.research.google.com/github/peterchang0414/lecun1989-flax/blob/main/lecun_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
!pip install -q flax

In [47]:
import jax
import jax.numpy as jnp
import numpy as np

from flax import linen as nn
from torchvision import datasets

# Module-level handle on the MNIST training split, stored under ./data
# (download=True asks torchvision to fetch it when needed).
# NOTE(review): get_datasets() below builds its own local `data`, and the other
# functions in this notebook take `data` as a parameter, so this global does not
# appear to be read by any cell visible here -- confirm before relying on it.
data = datasets.MNIST('./data', train=True, download=True)

# Adapted from https://github.com/karpathy/lecun1989-repro/blob/master/prepro.py
def get_datasets(n_train=1000, n_test=500, seed=0):
  """Build small MNIST subsets resized to 16x16 with +/-1 targets.

  For each split a random subset of the torchvision MNIST dataset is selected,
  every image is mapped through `pixel / 127.5 - 1.0` (i.e. [0, 255] -> [-1, 1]
  for 8-bit pixels), resized bilinearly to 16x16 and stored with a trailing
  channel axis.

  Args:
    n_train: number of images sampled from the training split.
    n_test: number of images sampled from the test split.
    seed: integer seed for the JAX PRNG key that picks the subset. The same
      key is used for both splits.

  Returns:
    A dict with keys 'train' and 'test'. Each value is a tuple `(X, Y)` of
    float32 JAX arrays: `X` has shape (n, 16, 16, 1) and `Y` has shape (n, 10)
    holding 1.0 at the label index and -1.0 everywhere else.
  """
  split_sizes = {'train': n_train, 'test': n_test}
  key = jax.random.PRNGKey(seed)

  train_test = {}
  for split, n in split_sizes.items():
    data = datasets.MNIST('./data', train=split=='train', download=True)

    # Bring the selected indices to the host once instead of iterating over a
    # device array element by element.
    rp = np.asarray(jax.random.permutation(key, len(data))[:n])

    # Fill mutable host buffers; `X.at[i].set(...)` on a JAX array copies the
    # whole array on every iteration, which is quadratic in n.
    X = np.zeros((n, 16, 16, 1), dtype=np.float32)
    Y = np.full((n, 10), -1.0, dtype=np.float32)
    for i, ix in enumerate(rp):
      I, yint = data[int(ix)]
      xi = jnp.array(I, dtype=np.float32) / 127.5 - 1.0
      xi = jax.image.resize(xi, (16, 16), 'bilinear')
      X[i, :, :, 0] = np.asarray(xi)
      Y[i, yint] = 1.0

    # Single host -> device conversion per split.
    train_test[split] = (jnp.asarray(X), jnp.asarray(Y))
  return train_test

In [48]:
from flax import linen as nn
from flax.training import train_state
from flax.linen.activation import tanh
import optax
from typing import Callable

class Net(nn.Module):
  """Four-layer tanh network: two strided 5x5 conv stages, then two dense layers.

  The second conv stage is three parallel 4-feature convolutions, each reading
  a different 8-channel subset of the 12 first-stage feature maps. All layers
  are built with `use_bias=False`; the biases are separate parameters
  (`bias1` .. `bias4`) added just before each tanh, so the convolutional biases
  are per spatial position instead of being shared along with the kernels.

  The hard-coded bias shapes (8, 8, 12) and (4, 4, 12) correspond to inputs of
  shape (batch, 16, 16, 1).
  """
  # Initializer for bias1..bias3 (bias4 always starts at -1.0).
  bias_init: Callable = nn.initializers.zeros
  # Initializer for the four conv kernels only; the Dense layers keep Flax's
  # default kernel initializer.
  # Background: Karpathy's reference uses a uniform range with numerator 2.4,
  # very close to sqrt(6) = 2.449... used by he_uniform().
  # NOTE(review): the default here is nn.initializers.uniform(), which per the
  # JAX docs uses a small fixed scale rather than a fan-in based one -- confirm
  # which initializer is actually intended.
  kernel_init: Callable = nn.initializers.uniform()

  @nn.compact
  def __call__(self, x):
    """Apply the network to `x` and return the 10 tanh outputs per example."""

    def pad_border(t):
      # Two cells of -1.0 on each side of both spatial axes.
      return jnp.pad(t, [(0,0),(2,2),(2,2),(0,0)], constant_values=-1.0)

    def strided_conv(features):
      # 5x5, stride-2, unpadded, bias-free convolution shared by both stages.
      return nn.Conv(features=features, kernel_size=(5,5), strides=2,
                     padding='VALID', use_bias=False,
                     kernel_init=self.kernel_init)

    # Weight sharing would also force bias sharing, so the biases are declared
    # explicitly and added by hand after each bias-free layer.
    bias1 = self.param('bias1', self.bias_init, (8, 8, 12))
    bias2 = self.param('bias2', self.bias_init, (4, 4, 12))
    bias3 = self.param('bias3', self.bias_init, (30,))
    bias4 = self.param('bias4', nn.initializers.constant(-1.0), (10,))

    # Stage 1: 12 feature maps.
    h1 = tanh(strided_conv(12)(pad_border(x)) + bias1)

    # Stage 2: three convolutions over overlapping channel groups
    # (0-7, 4-11, and 0-3 together with 8-11), 4 feature maps each.
    h1 = pad_border(h1)
    channel_groups = (
        h1[..., 0:8],
        h1[..., 4:12],
        jnp.concatenate((h1[..., 0:4], h1[..., 8:12]), axis=-1),
    )
    group_maps = [strided_conv(4)(group) for group in channel_groups]
    h2 = tanh(jnp.concatenate(group_maps, axis=-1) + bias2)

    # Stages 3 and 4: flatten, then 30 hidden units and 10 outputs.
    flat = h2.reshape((h2.shape[0], -1))
    h3 = tanh(nn.Dense(features=30, use_bias=False)(flat) + bias3)
    out = tanh(nn.Dense(features=10, use_bias=False)(h3) + bias4)

    return out

In [49]:
@jax.jit
def eval_step(params, X, Y):
  """Evaluate `Net` with `params` on a batch.

  Args:
    params: parameter tree passed to `Net().apply` under the 'params' key.
    X: batch of inputs.
    Y: targets with one row per example; the argmax over the last axis is
      treated as the true class.

  Returns:
    `(loss, err)`: the mean squared error between predictions and `Y`, and the
    fraction of examples whose predicted argmax differs from the target argmax.
  """
  predictions = Net().apply({'params': params}, X)
  residual = predictions - Y
  mse = jnp.mean(residual**2)
  is_wrong = jnp.argmax(Y, -1) != jnp.argmax(predictions, -1)
  error_rate = jnp.mean(is_wrong).astype(float)
  return mse, error_rate

In [50]:
def eval_split(data, split, params):
  """Evaluate `params` on one split and print loss, error rate and miss count.

  Args:
    data: mapping from split name to an `(X, Y)` tuple.
    split: key into `data` (e.g. 'train' or 'test').
    params: model parameters forwarded to `eval_step`.
  """
  X, Y = data[split]
  loss, err = eval_step(params, X, Y)
  # `err` is a floating-point mean, so err * N may land a hair below the true
  # integer count; round to nearest instead of truncating to avoid reporting
  # one miss too few.
  misses = int(round(float(err) * Y.shape[0]))
  print(f"eval: split {split:5s}. loss {loss:e}. error {err*100:.2f}%. misses: {misses}")

In [51]:
from jax import value_and_grad
import optax
from flax.training import train_state

def create_train_state(key, lr, X):
  """Initialise a `Net` and wrap it in a Flax `TrainState` with plain SGD.

  Args:
    key: JAX PRNG key used for parameter initialisation.
    lr: learning rate handed to `optax.sgd`.
    X: example input batch passed to `Net.init` to create the parameters.

  Returns:
    A `TrainState` holding the model's apply function, its freshly
    initialised parameters and the SGD optimiser.
  """
  net = Net()
  variables = net.init(key, X)
  return train_state.TrainState.create(
      apply_fn=net.apply,
      params=variables['params'],
      tx=optax.sgd(lr),
  )

@jax.jit
def train_step(state, X, Y):
  """Run one gradient update of the mean-squared-error loss on `(X, Y)`.

  Args:
    state: train state providing `params` and `apply_gradients`.
    X: batch of inputs.
    Y: batch of targets, one row per example.

  Returns:
    The train state after applying the gradients.
  """
  def loss_fn(params):
    # Returns (loss, aux); only the loss is differentiated, the error rate is
    # carried along as the auxiliary output.
    predictions = Net().apply({'params': params}, X)
    mse = jnp.mean((predictions - Y)**2)
    is_wrong = jnp.argmax(Y, -1) != jnp.argmax(predictions, -1)
    error_rate = jnp.mean(is_wrong).astype(float)
    return mse, error_rate

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  # Loss and error rate are computed but not used by the update itself.
  (_unused_loss, _unused_err), grads = grad_fn(state.params)
  return state.apply_gradients(grads=grads)

def train_one_epoch(state, X, Y):
  """Make one pass over the data, updating on a single example at a time.

  Args:
    state: train state to update.
    X: inputs, indexed along the first axis.
    Y: targets, aligned with `X` along the first axis.

  Returns:
    The train state after `X.shape[0]` single-example updates, applied in
    index order.
  """
  for step_num in range(X.shape[0]):
    # Re-add the batch axis with jnp rather than np: np.expand_dims would copy
    # each example to host memory only for train_step to send it straight back
    # to the device.
    x = jnp.expand_dims(X[step_num], 0)
    y = jnp.expand_dims(Y[step_num], 0)
    state = train_step(state, x, y)
  return state

def train(key, data, epochs, lr):
  """Train a fresh `Net` and print train/test metrics after every epoch.

  Args:
    key: JAX PRNG key for parameter initialisation.
    data: mapping with 'train' and 'test' entries, each an `(X, Y)` tuple.
    epochs: number of passes over the training split.
    lr: SGD learning rate.

  Returns:
    None. The trained state is not returned; results are only printed.
  """
  train_inputs, train_targets = data['train']
  # Named `state` so the imported `train_state` module is not shadowed.
  state = create_train_state(key, lr, train_inputs)
  for epoch in range(epochs):
    print(f"epoch {epoch+1}")
    state = train_one_epoch(state, train_inputs, train_targets)
    for split in ('train', 'test'):
      eval_split(data, split, state.params)


In [52]:
# Fixed PRNG key for parameter initialisation.
key = jax.random.PRNGKey(0)

# Build the datasets and train for 10 epochs with learning rate 0.03.
train(key=key, data=get_datasets(), epochs=10, lr=0.03)

epoch 1
eval: split train. loss 3.603377e-01. error 90.10%. misses: 901
eval: split test . loss 3.603952e-01. error 89.60%. misses: 448
epoch 2
eval: split train. loss 2.758465e-01. error 46.30%. misses: 463
eval: split test . loss 2.725264e-01. error 43.40%. misses: 217
epoch 3
eval: split train. loss 1.376759e-01. error 15.60%. misses: 156
eval: split test . loss 1.359383e-01. error 17.20%. misses: 86
epoch 4
eval: split train. loss 9.459769e-02. error 11.40%. misses: 114
eval: split test . loss 9.940626e-02. error 13.80%. misses: 69
epoch 5
eval: split train. loss 7.227705e-02. error 8.90%. misses: 89
eval: split test . loss 8.325174e-02. error 12.40%. misses: 62
epoch 6
eval: split train. loss 5.911439e-02. error 7.10%. misses: 71
eval: split test . loss 7.545932e-02. error 10.60%. misses: 53
epoch 7
eval: split train. loss 4.991591e-02. error 6.00%. misses: 60
eval: split test . loss 7.037602e-02. error 10.20%. misses: 51
epoch 8
eval: split train. loss 4.331648e-02. error 5.30%. 

In [42]:
train_test = get_datasets()

In [43]:
train = train_test['train'][0]

In [44]:
train.shape

(100, 16, 16, 1)

In [45]:
key = jax.random.PRNGKey(0)
cnn = Net()
param = cnn.init(key, train)['params']

In [46]:
# Show the shape of every parameter leaf. Uses jax.tree_util.tree_map because
# the top-level jax.tree_map alias is deprecated and absent from newer JAX.
print(jax.tree_util.tree_map(lambda x: x.shape, param))

FrozenDict({
    Conv_0: {
        kernel: (5, 5, 1, 12),
    },
    Conv_1: {
        kernel: (5, 5, 8, 4),
    },
    Conv_2: {
        kernel: (5, 5, 8, 4),
    },
    Conv_3: {
        kernel: (5, 5, 8, 4),
    },
    Dense_0: {
        kernel: (192, 30),
    },
    Dense_1: {
        kernel: (30, 10),
    },
    bias1: (8, 8, 12),
    bias2: (4, 4, 12),
    bias3: (30,),
    bias4: (10,),
})


In [None]:
# Report metrics for the freshly initialised parameters on the training split.
eval_split(data=train_test, split='train', params=param)

1:  (1000, 16, 16, 1)
2:  (1000, 20, 20, 1)
3:  (1000, 8, 8, 12)
4:  (1000, 8, 8, 12)
5:  (1000, 12, 12, 12)
6:  (1000, 4, 4, 12)
7:  (1000, 4, 4, 12)
8:  (1000, 30)
9:  (1000, 10)
eval: split train. loss 4.543942e-01. error 87.60%. misses: 876
