<a href="https://colab.research.google.com/github/peterchang0414/lecun1989-flax/blob/main/lecun_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q flax

[?25l[K     |█▉                              | 10 kB 25.7 MB/s eta 0:00:01[K     |███▋                            | 20 kB 29.1 MB/s eta 0:00:01[K     |█████▍                          | 30 kB 34.4 MB/s eta 0:00:01[K     |███████▏                        | 40 kB 16.8 MB/s eta 0:00:01[K     |█████████                       | 51 kB 14.6 MB/s eta 0:00:01[K     |██████████▊                     | 61 kB 16.8 MB/s eta 0:00:01[K     |████████████▌                   | 71 kB 13.2 MB/s eta 0:00:01[K     |██████████████▎                 | 81 kB 14.2 MB/s eta 0:00:01[K     |████████████████                | 92 kB 15.7 MB/s eta 0:00:01[K     |█████████████████▉              | 102 kB 14.8 MB/s eta 0:00:01[K     |███████████████████▋            | 112 kB 14.8 MB/s eta 0:00:01[K     |█████████████████████▍          | 122 kB 14.8 MB/s eta 0:00:01[K     |███████████████████████▏        | 133 kB 14.8 MB/s eta 0:00:01[K     |█████████████████████████       | 143 kB 14.8 MB/s eta 0:

In [50]:
import jax
import jax.numpy as jnp
import numpy as np

from flax import linen as nn
from torchvision import datasets

data = datasets.MNIST('./data', train=True, download=True)

# Adapted from https://github.com/karpathy/lecun1989-repro/blob/master/prepro.py
def get_datasets():
  train_test = {}
  for split in {'train', 'test'}:
    data = datasets.MNIST('./data', train=split=='train', download=True)
    
    n = 1 if split == 'train' else 1
    key = jax.random.PRNGKey(0)
    rp = jax.random.permutation(key, len(data))[:n]

    X = jnp.full((n, 16, 16, 1), 0.0, dtype=jnp.float32)
    Y = jnp.full((n, 10), -1.0, dtype=jnp.float32)
    for i, ix in enumerate(rp):
      I, yint = data[int(ix)]
      xi = jnp.array(I, dtype=np.float32) / 127.5 - 1.0
      xi = jax.image.resize(xi, (16, 16), 'bilinear')
      X = X.at[i].set(np.expand_dims(xi, axis=2))
      Y = Y.at[i, yint].set(1.0)
    train_test[split] = (X, Y)
  return train_test

In [141]:
from flax import linen as nn
from flax.training import train_state
from flax.linen.activation import tanh
import optax
from typing import Callable

class Net(nn.Module):
  bias_init: Callable = nn.initializers.zeros

  @nn.compact
  def __call__(self, x):
    # H1 layer
    # For weight initialization, Karpathy used numerator of 2.4 
    # which is very close to sqrt(6) used by he_uniform
    # By default, weight-sharing forces bias-sharing and therefore
    # we add the bias separately.
    bias1 = self.param('bias1', self.bias_init, (8, 8, 12))
    bias2 = self.param('bias2', self.bias_init, (4, 4, 12))
    bias3 = self.param('bias3', self.bias_init, (30,))
    bias4 = self.param('bias4', nn.initializers.constant(-1.0), (10,))
    print(x.shape)
    x = jnp.pad(x, [(0,0),(2,2),(2,2),(0,0)], constant_values=-1.0)
    print(x.shape)
    x = nn.Conv(features=12, kernel_size=(5,5), strides=2, padding='VALID',
                use_bias=False, kernel_init=nn.initializers.he_uniform())(x)
    print(x.shape)
    x = tanh(x + bias1)
    print(x.shape)
    # slice1 = nn.Conv(features=8, kernel_size=(5,5), strides=2, padding=-1.0,
    #                 use_bias=False, kernel_init=nn.initializers.he_uniform())(x[..., 0:8])
    # slice1 = nn.Conv(features=8, kernel_size=(5,5), strides=2, padding=-1.0,
    #                 use_bias=False, kernel_init=nn.initializers.he_uniform())(x[..., 0:8])
    x = jnp.pad(x, 2, constant_values=-1.0)
    x = nn.Conv(features=12, kernel_size=(5,5,12), strides=2, padding='VALID',
                use_bias=False, kernel_init=nn.initializers.he_uniform())(x)
    x = tanh(x + bias2)
    print(x.shape)
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=30, use_bias=False)(x)
    x = tanh(x + bias3)
    print(x.shape)
    x = nn.Dense(features=10, use_bias=False)(x)
    x = tanh(x + bias4)
    print(x.shape)

    return x


In [111]:
def eval_step(params, X, Y):
  Yhat = Net().apply({'params': params}, X)
  print(Y.shape)
  print(Yhat.shape)
  loss = jnp.mean((Yhat - Y)**2)
  err = jnp.mean((Y.argmax() != Yhat.argmax()).float())
  return loss, err

In [106]:
def eval_split(data, split, params):
  X, Y = data[split]
  loss, err = eval_step(params, X, Y)
  print(f"eval: split {split:5s}. loss {loss:e}. error {err*100:.2f}%. misses: {int(err*Y.size(0))}")

In [92]:
@jax.jit
def train_step(state, imgs, labels):
  def loss_fn(params):
    result = Net().apply({'params': params}, imgs)
    loss = jnp.mean((result - labels)**2)
    return loss, result
  
  (_, result), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics()

In [113]:
train_test = get_datasets()

In [114]:
train = train_test['train'][0]

In [115]:
train.shape

(1, 16, 16, 1)

In [120]:
key = jax.random.PRNGKey(0)
cnn = Net()
param = cnn.init(key, train)['params']

In [121]:
print(jax.tree_map(lambda x: x.shape, param))

FrozenDict({
    Conv_0: {
        kernel: (5, 5, 5, 12),
    },
    Conv_1: {
        kernel: (5, 5, 12, 16, 12),
    },
    Dense_0: {
        kernel: (192, 30),
    },
    Dense_1: {
        kernel: (30, 10),
    },
    bias1: (8, 8, 12),
    bias2: (4, 4, 12),
    bias3: (30,),
    bias4: (10,),
})


In [142]:
eval_split(train_test, 'train', param)

(1, 16, 16, 1)
(1, 20, 20, 1)


ScopeParamShapeError: ignored