<a href="https://colab.research.google.com/github/peterchang0414/lecun1989-flax/blob/main/lecun_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1989 Reproduction

In [1]:
!pip install -q flax

[?25l[K     |█▉                              | 10 kB 26.8 MB/s eta 0:00:01[K     |███▋                            | 20 kB 30.8 MB/s eta 0:00:01[K     |█████▍                          | 30 kB 20.8 MB/s eta 0:00:01[K     |███████▏                        | 40 kB 13.1 MB/s eta 0:00:01[K     |█████████                       | 51 kB 6.9 MB/s eta 0:00:01[K     |██████████▊                     | 61 kB 8.1 MB/s eta 0:00:01[K     |████████████▌                   | 71 kB 7.7 MB/s eta 0:00:01[K     |██████████████▎                 | 81 kB 5.9 MB/s eta 0:00:01[K     |████████████████                | 92 kB 6.6 MB/s eta 0:00:01[K     |█████████████████▉              | 102 kB 7.2 MB/s eta 0:00:01[K     |███████████████████▋            | 112 kB 7.2 MB/s eta 0:00:01[K     |█████████████████████▍          | 122 kB 7.2 MB/s eta 0:00:01[K     |███████████████████████▏        | 133 kB 7.2 MB/s eta 0:00:01[K     |█████████████████████████       | 143 kB 7.2 MB/s eta 0:00:01[K 

In [2]:
import jax
import jax.numpy as jnp
import numpy as np

from flax import linen as nn
from torchvision import datasets

# Adapted from https://github.com/karpathy/lecun1989-repro/blob/master/prepro.py
def get_datasets(n_tr, n_te):
  train_test = {}
  for split in {'train', 'test'}:
    data = datasets.MNIST('./data', train=split=='train', download=True)
    
    n = n_tr if split == 'train' else n_te
    key = jax.random.PRNGKey(42)
    rp = jax.random.permutation(key, len(data))[:n]

    X = jnp.full((n, 16, 16, 1), 0.0, dtype=jnp.float32)
    Y = jnp.full((n, 10), -1.0, dtype=jnp.float32)
    for i, ix in enumerate(rp):
      I, yint = data[int(ix)]
      xi = jnp.array(I, dtype=np.float32) / 127.5 - 1.0
      xi = jax.image.resize(xi, (16, 16), 'bilinear')
      X = X.at[i].set(np.expand_dims(xi, axis=2))
      Y = Y.at[i, yint].set(1.0)
    train_test[split] = (X, Y)
  return train_test

In [3]:
from flax import linen as nn
from flax.training import train_state
from flax.linen.activation import tanh
import optax
from typing import Callable

class Net(nn.Module):
  bias_init: Callable = nn.initializers.zeros
  kernel_init: Callable = nn.initializers.uniform()

  @nn.compact
  def __call__(self, x):
    # For weight initialization, Karpathy used numerator of 2.4 
    # which is very close to sqrt(6) = 2.449... used by he_uniform()
    # By default, weight-sharing forces bias-sharing and therefore
    # we add the bias separately.
    bias1 = self.param('bias1', self.bias_init, (8, 8, 12))
    bias2 = self.param('bias2', self.bias_init, (4, 4, 12))
    bias3 = self.param('bias3', self.bias_init, (30,))
    bias4 = self.param('bias4', nn.initializers.constant(-1.0), (10,))
    x = jnp.pad(x, [(0,0),(2,2),(2,2),(0,0)], constant_values=-1.0)
    x = nn.Conv(features=12, kernel_size=(5,5), strides=2, padding='VALID',
                use_bias=False, kernel_init=self.kernel_init)(x)
    x = tanh(x + bias1)
    x = jnp.pad(x, [(0,0),(2,2),(2,2),(0,0)], constant_values=-1.0)
    x1, x2, x3 = (x[..., 0:8], x[..., 4:12], 
                  jnp.concatenate((x[..., 0:4], x[..., 8:12]), axis=-1))
    slice1 = nn.Conv(features=4, kernel_size=(5,5), strides=2, padding='VALID', 
                     use_bias=False, kernel_init=self.kernel_init)(x1)
    slice2 = nn.Conv(features=4, kernel_size=(5,5), strides=2, padding='VALID',
                     use_bias=False, kernel_init=self.kernel_init)(x2)
    slice3 = nn.Conv(features=4, kernel_size=(5,5), strides=2, padding='VALID',
                     use_bias=False, kernel_init=self.kernel_init)(x3)
    x = jnp.concatenate((slice1, slice2, slice3), axis=-1)
    x = tanh(x + bias2)
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=30, use_bias=False)(x)
    x = tanh(x + bias3)
    x = nn.Dense(features=10, use_bias=False)(x)
    x = tanh(x + bias4)

    return x

In [4]:
@jax.jit
def eval_step(params, X, Y):
  Yhat = Net().apply({'params': params}, X)
  loss = jnp.mean((Yhat - Y)**2)
  err = jnp.mean(jnp.argmax(Y, -1) != jnp.argmax(Yhat, -1)).astype(float)
  return loss, err

In [5]:
def eval_split(data, split, params):
  X, Y = data[split]
  loss, err = eval_step(params, X, Y)
  print(f"eval: split {split:5s}. loss {loss:e}. error {err*100:.2f}%. misses: {int(err*Y.shape[0])}")

In [6]:
from jax import value_and_grad
import optax
from flax.training import train_state

def create_train_state(key, lr, X):
  model = Net()
  params = model.init(key, X)['params']
  sgd_opt = optax.sgd(lr)
  return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=sgd_opt)

@jax.jit
def train_step(state, X, Y):
  def loss_fn(params):
    Yhat = Net().apply({'params': params}, X)
    loss = jnp.mean((Yhat - Y)**2)
    err = jnp.mean(jnp.argmax(Y, -1) != jnp.argmax(Yhat, -1)).astype(float)
    return loss, err
  (_, Yhats), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  state = state.apply_gradients(grads=grads)
  return state

def train_one_epoch(state, X, Y):
  for step_num in range(X.shape[0]):
    x, y = np.expand_dims(X[step_num], 0), np.expand_dims(Y[step_num], 0)
    state = train_step(state, x, y)
  return state

def train(key, data, epochs, lr):
  Xtr, Ytr = data['train']
  Xte, Yte = data['test']
  train_state = create_train_state(key, lr, Xtr)
  for epoch in range(epochs):
    print(f"epoch {epoch+1}")
    train_state = train_one_epoch(train_state, Xtr, Ytr)
    for split in ['train', 'test']:
      eval_split(data, split, train_state.params)


In [8]:
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)

train(key, get_datasets(7291, 2007), 23, 0.03)

KeyboardInterrupt: ignored

# "Modern" Adjustments

In [50]:
!pip install -q flax

In [51]:
import jax
import jax.numpy as jnp
import numpy as np

from flax import linen as nn
from torchvision import datasets

# Adapted from https://github.com/karpathy/lecun1989-repro/blob/master/prepro.py
def get_datasets(n_tr, n_te):
  train_test = {}
  for split in {'train', 'test'}:
    data = datasets.MNIST('./data', train=split=='train', download=True)
    
    n = n_tr if split == 'train' else n_te
    key = jax.random.PRNGKey(42)
    rp = jax.random.permutation(key, len(data))[:n]

    X = jnp.full((n, 16, 16, 1), 0.0, dtype=jnp.float32)
    Y = jnp.full((n, 10), 0, dtype=jnp.float32)
    for i, ix in enumerate(rp):
      I, yint = data[int(ix)]
      xi = jnp.array(I, dtype=np.float32) / 127.5 - 1.0
      xi = jax.image.resize(xi, (16, 16), 'bilinear')
      X = X.at[i].set(np.expand_dims(xi, axis=2))
      Y = Y.at[i, yint].set(1.0)
    train_test[split] = (X, Y)
  return train_test

In [52]:
from flax import linen as nn
from flax.training import train_state
from flax.linen.activation import tanh
import optax
from typing import Callable

class Net(nn.Module):
  bias_init: Callable = nn.initializers.zeros
  kernel_init: Callable = nn.initializers.uniform()

  @nn.compact
  def __call__(self, x):
    # For weight initialization, Karpathy used numerator of 2.4 
    # which is very close to sqrt(6) = 2.449... used by he_uniform()
    # By default, weight-sharing forces bias-sharing and therefore
    # we add the bias separately.
    bias1 = self.param('bias1', self.bias_init, (8, 8, 12))
    bias2 = self.param('bias2', self.bias_init, (4, 4, 12))
    bias3 = self.param('bias3', self.bias_init, (30,))
    bias4 = self.param('bias4', nn.initializers.constant(-1.0), (10,))
    x = jnp.pad(x, [(0,0),(2,2),(2,2),(0,0)], constant_values=-1.0)
    x = nn.Conv(features=12, kernel_size=(5,5), strides=2, padding='VALID',
                use_bias=False, kernel_init=self.kernel_init)(x)
    x = tanh(x + bias1)
    x = jnp.pad(x, [(0,0),(2,2),(2,2),(0,0)], constant_values=-1.0)
    x1, x2, x3 = (x[..., 0:8], x[..., 4:12], 
                  jnp.concatenate((x[..., 0:4], x[..., 8:12]), axis=-1))
    slice1 = nn.Conv(features=4, kernel_size=(5,5), strides=2, padding='VALID', 
                     use_bias=False, kernel_init=self.kernel_init)(x1)
    slice2 = nn.Conv(features=4, kernel_size=(5,5), strides=2, padding='VALID',
                     use_bias=False, kernel_init=self.kernel_init)(x2)
    slice3 = nn.Conv(features=4, kernel_size=(5,5), strides=2, padding='VALID',
                     use_bias=False, kernel_init=self.kernel_init)(x3)
    x = jnp.concatenate((slice1, slice2, slice3), axis=-1)
    x = tanh(x + bias2)
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=30, use_bias=False)(x)
    x = tanh(x + bias3)
    x = nn.Dense(features=10, use_bias=False)(x)
    x = x + bias4

    return x

In [58]:
from jax import value_and_grad
import optax
from flax.training import train_state

def create_train_state(key, lr, X):
  model = Net()
  params = model.init(key, X)['params']
  sgd_opt = optax.sgd(lr)
  return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=sgd_opt)

@jax.jit
def train_step(state, X, Y):
  def loss_fn(params):
    Yhat = Net().apply({'params': params}, X)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=Yhat, labels=Y))
    err = jnp.mean(jnp.argmax(Y, -1) != jnp.argmax(Yhat, -1)).astype(float)
    return loss, err
  (_, Yhats), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  state = state.apply_gradients(grads=grads)
  return state

def train_one_epoch(state, X, Y):
  for step_num in range(X.shape[0]):
    x, y = np.expand_dims(X[step_num], 0), np.expand_dims(Y[step_num], 0)
    state = train_step(state, x, y)
  return state

def train(key, data, epochs, lr):
  Xtr, Ytr = data['train']
  Xte, Yte = data['test']
  train_state = create_train_state(key, lr, Xtr)
  for epoch in range(epochs):
    print(f"epoch {epoch+1}")
    train_state = train_one_epoch(train_state, Xtr, Ytr)
    for split in ['train', 'test']:
      eval_split(data, split, train_state.params)


In [54]:
@jax.jit
def eval_step(params, X, Y):
  Yhat = Net().apply({'params': params}, X)
  loss = jnp.mean((Yhat - Y)**2)
  err = jnp.mean(jnp.argmax(Y, -1) != jnp.argmax(Yhat, -1)).astype(float)
  return loss, err

def eval_split(data, split, params):
  X, Y = data[split]
  loss, err = eval_step(params, X, Y)
  print(f"eval: split {split:5s}. loss {loss:e}. error {err*100:.2f}%. misses: {int(err*Y.shape[0])}")

In [59]:
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)

train(key, get_datasets(7291, 2007), 23, 0.01)

epoch 1
eval: split train. loss 8.296457e+00. error 8.48%. misses: 618
eval: split test . loss 8.328159e+00. error 8.27%. misses: 166
epoch 2
eval: split train. loss 9.976032e+00. error 5.88%. misses: 429
eval: split test . loss 9.958797e+00. error 6.48%. misses: 130
epoch 3


KeyboardInterrupt: ignored

Change 1: replace tanh on last layer with FC and use softmax. Lower learning rate to 0.01

```
# epoch 23
eval: split train. loss 2.984636e+01. error 0.05%. misses: 4
eval: split test . loss 2.979341e+01. error 4.14%. misses: 83

```

Change 2: change from SGD to AdamW with LR 3e-4, double epochs to 46, decay LR to 1e-4 over the course of training.

```
# 코드로 형식 지정됨
```

Change 3: Introduce data augmentation, e.g. a shift by at most 1 pixel in both x/y directions, and bump up training time to 60 epochs.

```
# 코드로 형식 지정됨
```

Change 4: add dropout at layer H3, shift activation function to relu, and bring up iterations to 80.

```
# 코드로 형식 지정됨
```




