<a href="https://colab.research.google.com/github/peterchang0414/lecun1989-flax/blob/main/lecun_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q flax

[?25l[K     |█▉                              | 10 kB 17.1 MB/s eta 0:00:01[K     |███▋                            | 20 kB 22.1 MB/s eta 0:00:01[K     |█████▍                          | 30 kB 25.7 MB/s eta 0:00:01[K     |███████▏                        | 40 kB 8.5 MB/s eta 0:00:01[K     |█████████                       | 51 kB 5.8 MB/s eta 0:00:01[K     |██████████▊                     | 61 kB 6.8 MB/s eta 0:00:01[K     |████████████▌                   | 71 kB 7.7 MB/s eta 0:00:01[K     |██████████████▎                 | 81 kB 7.0 MB/s eta 0:00:01[K     |████████████████                | 92 kB 7.7 MB/s eta 0:00:01[K     |█████████████████▉              | 102 kB 8.4 MB/s eta 0:00:01[K     |███████████████████▋            | 112 kB 8.4 MB/s eta 0:00:01[K     |█████████████████████▍          | 122 kB 8.4 MB/s eta 0:00:01[K     |███████████████████████▏        | 133 kB 8.4 MB/s eta 0:00:01[K     |█████████████████████████       | 143 kB 8.4 MB/s eta 0:00:01[K  

In [102]:
import jax
import jax.numpy as jnp
import numpy as np

from flax import linen as nn
from torchvision import datasets

data = datasets.MNIST('./data', train=True, download=True)

# Adapted from https://github.com/karpathy/lecun1989-repro/blob/master/prepro.py
def get_datasets():
  train_test = {}
  for split in {'train', 'test'}:
    data = datasets.MNIST('./data', train=split=='train', download=True)
    
    n = 1000 if split == 'train' else 1
    key = jax.random.PRNGKey(0)
    rp = jax.random.permutation(key, len(data))[:n]

    X = jnp.full((n, 16, 16, 1), 0.0, dtype=jnp.float32)
    Y = jnp.full((n, 10), -1.0, dtype=jnp.float32)
    for i, ix in enumerate(rp):
      I, yint = data[int(ix)]
      xi = jnp.array(I, dtype=np.float32) / 127.5 - 1.0
      xi = jax.image.resize(xi, (16, 16), 'bilinear')
      X = X.at[i].set(np.expand_dims(xi, axis=2))
      Y = Y.at[i, yint].set(1.0)
    train_test[split] = (X, Y)
  return train_test

In [85]:
from flax import linen as nn
from flax.training import train_state
from flax.linen.activation import tanh
import optax
from typing import Callable

class Net(nn.Module):
  bias_init: Callable = nn.initializers.zeros

  @nn.compact
  def __call__(self, x):
    # H1 layer
    # For weight initialization, Karpathy used numerator of 2.4 
    # which is very close to sqrt(6) used by he_uniform
    # By default, weight-sharing forces bias-sharing and therefore
    # we add the bias separately.
    bias1 = self.param('bias1', self.bias_init, (8, 8, 12))
    bias2 = self.param('bias2', self.bias_init, (4, 4, 12))
    bias3 = self.param('bias3', self.bias_init, (30,))
    bias4 = self.param('bias4', nn.initializers.constant(-1.0), (10,))
    print("1: ", x.shape)
    x = jnp.pad(x, [(0,0),(2,2),(2,2),(0,0)], constant_values=-1.0)
    print("2: ", x.shape)
    x = nn.Conv(features=12, kernel_size=(5,5), strides=2, padding='VALID',
                use_bias=False, kernel_init=nn.initializers.he_uniform())(x)
    print("3: ", x.shape)
    x = tanh(x + bias1)
    print("4: ", x.shape)
    # slice1 = nn.Conv(features=8, kernel_size=(5,5), strides=2, padding=-1.0,
    #                 use_bias=False, kernel_init=nn.initializers.he_uniform())(x[..., 0:8])
    # slice1 = nn.Conv(features=8, kernel_size=(5,5), strides=2, padding=-1.0,
    #                 use_bias=False, kernel_init=nn.initializers.he_uniform())(x[..., 0:8])
    x = jnp.pad(x, [(0,0),(2,2),(2,2),(0,0)], constant_values=-1.0)
    print("5: ", x.shape)
    x = nn.Conv(features=12, kernel_size=(5,5), strides=2, padding='VALID',
                use_bias=False, kernel_init=nn.initializers.he_uniform())(x)
    print("6: ", x.shape)
    x = tanh(x + bias2)
    print("7: ", x.shape)
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=30, use_bias=False)(x)
    x = tanh(x + bias3)
    print("8: ", x.shape)
    x = nn.Dense(features=10, use_bias=False)(x)
    x = tanh(x + bias4)
    print("9: ", x.shape)

    return x


In [100]:
def eval_step(params, X, Y):
  Yhat = Net().apply({'params': params}, X)
  print(type(Y))
  print("yhat.argmax() = ", jnp.argmax(Yhat, -1))
  loss = jnp.mean((Yhat - Y)**2)
  err = jnp.mean(jnp.argmax(Y, -1) != jnp.argmax(Yhat, -1)).astype(float)
  return loss, err

In [45]:
def eval_split(data, split, params):
  X, Y = data[split]
  loss, err = eval_step(params, X, Y)
  print(f"eval: split {split:5s}. loss {loss:e}. error {err*100:.2f}%. misses: {int(err*Y.shape[0])}")

In [None]:
@jax.jit
def train_step(state, imgs, labels):
  def loss_fn(params):
    result = Net().apply({'params': params}, imgs)
    loss = jnp.mean((result - labels)**2)
    return loss, result
  
  (_, result), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics()

In [103]:
train_test = get_datasets()

In [104]:
train = train_test['train'][0]

In [105]:
train.shape

(1000, 16, 16, 1)

In [106]:
key = jax.random.PRNGKey(0)
cnn = Net()
param = cnn.init(key, train)['params']

1:  (1000, 16, 16, 1)
2:  (1000, 20, 20, 1)
3:  (1000, 8, 8, 12)
4:  (1000, 8, 8, 12)
5:  (1000, 12, 12, 12)
6:  (1000, 4, 4, 12)
7:  (1000, 4, 4, 12)
8:  (1000, 30)
9:  (1000, 10)


In [107]:
print(jax.tree_map(lambda x: x.shape, param))

FrozenDict({
    Conv_0: {
        kernel: (5, 5, 1, 12),
    },
    Conv_1: {
        kernel: (5, 5, 12, 12),
    },
    Dense_0: {
        kernel: (192, 30),
    },
    Dense_1: {
        kernel: (30, 10),
    },
    bias1: (8, 8, 12),
    bias2: (4, 4, 12),
    bias3: (30,),
    bias4: (10,),
})


In [108]:
eval_split(train_test, 'train', param)

1:  (1000, 16, 16, 1)
2:  (1000, 20, 20, 1)
3:  (1000, 8, 8, 12)
4:  (1000, 8, 8, 12)
5:  (1000, 12, 12, 12)
6:  (1000, 4, 4, 12)
7:  (1000, 4, 4, 12)
8:  (1000, 30)
9:  (1000, 10)
<class 'jaxlib.xla_extension.DeviceArray'>
yhat.argmax() =  [6 7 3 3 6 7 3 3 3 7 7 3 6 3 3 3 3 6 3 6 6 6 3 6 3 3 3 6 3 3 3 6 6 6 3 3 3
 3 3 3 6 6 3 3 3 6 6 3 3 3 3 3 3 3 6 3 6 6 3 6 3 3 6 3 3 3 6 3 7 3 3 3 6 3
 3 3 3 3 3 3 3 7 6 3 3 6 6 3 3 3 6 3 6 6 3 6 3 6 3 3 3 3 6 3 6 3 6 3 3 3 6
 3 6 6 6 3 6 6 3 3 4 3 3 6 3 6 3 6 3 6 3 3 3 3 6 6 6 3 6 6 3 6 3 3 7 3 6 3
 3 3 3 6 3 3 3 6 6 6 3 6 6 6 6 3 6 3 6 6 3 6 3 3 6 6 7 6 3 3 3 6 6 6 4 6 6
 4 3 3 3 3 6 3 6 3 3 6 3 4 6 6 3 6 3 6 6 3 4 6 3 6 6 6 3 6 3 3 3 3 6 3 6 6
 3 6 7 6 3 6 3 3 3 6 7 6 3 6 3 6 6 3 6 6 3 6 3 6 6 3 6 6 6 3 6 3 6 3 3 6 7
 3 3 3 3 6 3 3 6 6 3 3 3 3 3 3 3 6 3 6 6 6 6 3 3 3 6 3 7 6 7 6 7 6 3 3 7 3
 6 6 3 3 3 3 6 6 3 3 3 6 3 6 3 3 3 3 6 3 6 3 3 6 6 6 6 6 4 3 3 3 3 6 3 3 3
 3 3 6 7 3 6 6 6 6 3 6 6 6 6 6 7 3 6 6 3 6 6 3 6 3 3 3 6 6 6 6 6 6 3 6 7 3
 3 6 6 3 