<a href="https://colab.research.google.com/github/peterchang0414/lecun1989-flax/blob/main/lecun1989-flax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Backpropagation Applied to MNIST
Based on LeCun 1989: http://yann.lecun.com/exdb/publis/pdf/lecun-89e.pdf

Adapted to JAX from https://github.com/karpathy/lecun1989-repro/blob/master/prepro.py

Author: Peter G. Chang ([@peterchang0414](https://github.com/peterchang0414))

# 1989 Reproduction

In [None]:
!pip install -q flax

In [None]:
import jax
import jax.numpy as jnp

try:
    from flax import linen as nn
except ModuleNotFoundError:
    %pip install -qq flax
    from flax import linen as nn
try:
    from torchvision import datasets
except ModuleNotFoundError:
    %pip install -qq torchvision
    from torchvision import datasets


def get_datasets(n_tr, n_te):
    """Load MNIST and build the down-sampled train/test arrays.

    For each split, examples are picked with a fixed-seed random permutation
    (the same seed, 42, is used for both splits), scaled with
    ``x / 127.5 - 1.0``, resized to 16x16 with bilinear interpolation, and
    paired with a target vector that is 1.0 at the label index and -1.0
    everywhere else.

    Args:
        n_tr: Number of training examples to draw.
        n_te: Number of test examples to draw.

    Returns:
        A dict mapping ``"train"`` and ``"test"`` to ``(X, Y)``, where ``X`` is
        a float32 array of shape ``(n, 16, 16, 1)`` and ``Y`` is a float32
        array of shape ``(n, 10)``.
    """
    # Local import: NumPy is only needed here, for the host-side staging buffers.
    import numpy as np

    train_test = {}
    for split in ("train", "test"):
        data = datasets.MNIST("./data", train=split == "train", download=True)
        n = n_tr if split == "train" else n_te
        key = jax.random.PRNGKey(42)
        # Bring the sampled indices to the host once instead of converting one
        # device scalar per loop iteration.
        rp = np.asarray(jax.random.permutation(key, len(data))[:n])
        # Stage into mutable NumPy buffers. `arr.at[i].set(...)` on a JAX array
        # outside of jit returns a fresh copy of the whole array each time,
        # which makes an element-wise fill loop quadratic in n.
        X = np.zeros((n, 16, 16, 1), dtype=np.float32)
        Y = np.full((n, 10), -1.0, dtype=np.float32)
        for i, ix in enumerate(rp.tolist()):
            I, yint = data[ix]
            xi = jnp.array(I, dtype=jnp.float32) / 127.5 - 1.0
            xi = jax.image.resize(xi, (16, 16), "bilinear")
            X[i] = np.asarray(xi)[..., None]  # append the channel axis
            Y[i, yint] = 1.0
        train_test[split] = (jnp.asarray(X), jnp.asarray(Y))
    return train_test

In [None]:
try:
    from flax import linen as nn
except ModuleNotFoundError:
    %pip install -qq flax
    from flax import linen as nn
from flax.training import train_state
from flax.linen.activation import tanh

try:
    import optax
except ModuleNotFoundError:
    %pip install -qq optax
    import optax
from typing import Callable


class Net(nn.Module):
    """Four-layer network: two strided 5x5 convolution stages and two dense layers.

    Every layer is followed by an explicitly declared bias parameter and a
    tanh non-linearity. The convolutions use ``kernel_init``; the dense layers
    are created without a ``kernel_init`` argument and so use the Flax default.
    The bias shapes imply a ``(batch, 16, 16, 1)`` input.
    """

    # Initializer for bias1, bias2 and bias3 (bias4 uses a constant -1.0).
    bias_init: Callable = nn.initializers.zeros
    # sqrt(6) = 2.449... used by he_uniform() approximates Karpathy's 2.4
    kernel_init: Callable = nn.initializers.he_uniform()

    @nn.compact
    def __call__(self, x):
        def pad_border(t):
            # Two-pixel spatial border filled with -1.0 (batch/channel axes untouched).
            return jnp.pad(t, [(0, 0), (2, 2), (2, 2), (0, 0)], constant_values=-1.0)

        def conv5x5(features):
            # Bias-free 5x5, stride-2, unpadded convolution.
            return nn.Conv(
                features=features,
                kernel_size=(5, 5),
                strides=2,
                padding="VALID",
                use_bias=False,
                kernel_init=self.kernel_init,
            )

        # H1: 12 feature maps of size 8x8, one bias per unit.
        h1 = conv5x5(12)(pad_border(x))
        h1 = tanh(h1 + self.param("bias1", self.bias_init, (8, 8, 12)))

        # H2: three convolutions, each seeing a different 8-channel subset of H1
        # (channels 0-7, 4-11, and 0-3 plus 8-11), 4 output maps apiece.
        h1 = pad_border(h1)
        channel_groups = (
            h1[..., 0:8],
            h1[..., 4:12],
            jnp.concatenate((h1[..., 0:4], h1[..., 8:12]), axis=-1),
        )
        h2 = jnp.concatenate([conv5x5(4)(group) for group in channel_groups], axis=-1)
        h2 = tanh(h2 + self.param("bias2", self.bias_init, (4, 4, 12)))

        # H3: dense layer with 30 units on the flattened H2 activations.
        flat = h2.reshape((h2.shape[0], -1))
        h3 = nn.Dense(features=30, use_bias=False)(flat)
        h3 = tanh(h3 + self.param("bias3", self.bias_init, (30,)))

        # Output: 10 tanh units whose bias starts at -1.0.
        out = nn.Dense(features=10, use_bias=False)(h3)
        return tanh(out + self.param("bias4", nn.initializers.constant(-1.0), (10,)))

In [None]:
@jax.jit
def eval_step(params, X, Y):
    """Compute mean-squared-error loss and classification error rate.

    Args:
        params: Parameter tree for ``Net``.
        X: Batch of inputs passed to the network.
        Y: Target vectors; the arg-max along the last axis is the true class.

    Returns:
        ``(loss, err)``: the mean squared difference between predictions and
        targets, and the fraction of rows whose predicted class differs from
        the target class.
    """
    predictions = Net().apply({"params": params}, X)
    mse = jnp.mean(jnp.square(predictions - Y))
    true_cls = jnp.argmax(Y, -1)
    pred_cls = jnp.argmax(predictions, -1)
    error_rate = jnp.mean(true_cls != pred_cls).astype(float)
    return mse, error_rate

In [None]:
def eval_split(data, split, params):
    """Evaluate ``params`` on one split and print loss, error rate and miss count.

    Args:
        data: Dict mapping split names to ``(X, Y)`` pairs.
        split: Key of the split to evaluate (e.g. ``"train"`` or ``"test"``).
        params: Model parameters forwarded to ``eval_step``.
    """
    X, Y = data[split]
    loss, err = eval_step(params, X, Y)
    # `err` is a floating-point mean of booleans, so `err * N` can land a hair
    # below the true integer count; round instead of truncating so the printed
    # number of misses is never off by one.
    misses = round(float(err) * Y.shape[0])
    print(f"eval: split {split:5s}. loss {loss:e}. error {err*100:.2f}%. misses: {misses}")

In [None]:
from jax import value_and_grad

try:
    import optax
except ModuleNotFoundError:
    %pip install -qq optax
    import optax
try:
    from flax.training import train_state
except ModuleNotFoundError:
    %pip install -qq flax
    from flax.training import train_state


def create_train_state(key, lr, X):
    """Initialise ``Net`` and wrap it with an SGD optimizer in a ``TrainState``.

    Args:
        key: PRNG key used for parameter initialisation.
        lr: Learning rate handed to ``optax.sgd``.
        X: Example input used to trace the network during initialisation.

    Returns:
        A ``TrainState`` holding the model's apply function, its freshly
        initialised parameters and the optimizer.
    """
    net = Net()
    variables = net.init(key, X)
    return train_state.TrainState.create(
        apply_fn=net.apply,
        params=variables["params"],
        tx=optax.sgd(lr),
    )


@jax.jit
def train_step(state, X, Y):
    """Run one gradient update on the mean-squared-error loss.

    Args:
        state: Train state providing ``apply_fn``, ``params`` and
            ``apply_gradients``.
        X: Batch of inputs.
        Y: Batch of target vectors with the same shape as the model output.

    Returns:
        The train state after applying the gradients.
    """

    def loss_fn(params):
        # Use the apply function stored in the state so the step always runs
        # the same model the state was created with.
        Yhat = state.apply_fn({"params": params}, X)
        return jnp.mean((Yhat - Y) ** 2)

    # Only the gradients are needed here; the loss value and error rate are
    # reported separately by the evaluation code.
    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)


def train_one_epoch(state, X, Y):
    """Make one pass over the data, updating on one example at a time.

    Examples are visited in index order; each is given a leading batch axis of
    size 1 before being passed to ``train_step``.

    Args:
        state: Train state to update.
        X: Inputs, indexed along the first axis.
        Y: Targets, indexed along the first axis.

    Returns:
        The train state after the final update.
    """
    num_examples = X.shape[0]
    for idx in range(num_examples):
        example = X[idx][None, ...]
        target = Y[idx][None, ...]
        state = train_step(state, example, target)
    return state


def train(key, data, epochs, lr):
    """Train the network and print metrics on both splits after every epoch.

    Args:
        key: PRNG key used to initialise the model parameters.
        data: Dict with ``"train"`` and ``"test"`` entries, each an ``(X, Y)`` pair.
        epochs: Number of passes over the training split.
        lr: Learning rate for the optimizer.
    """
    Xtr, Ytr = data["train"]
    # Named `state` so the local does not shadow the imported
    # `flax.training.train_state` module.
    state = create_train_state(key, lr, Xtr)
    for epoch in range(epochs):
        print(f"epoch {epoch+1}")
        state = train_one_epoch(state, Xtr, Ytr)
        for split in ("train", "test"):
            eval_split(data, split, state.params)

In [None]:
# Build both splits: 7291 training examples and 2007 test examples.
data = get_datasets(7291, 2007)

In [None]:
# Split the seed-42 root key and keep only the first sub-key.
key, _ = jax.random.split(jax.random.PRNGKey(42))

# Train for 23 epochs with learning rate 0.03; metrics are printed after each epoch.
train(key, data, 23, 0.03)

epoch 1
eval: split train. loss 5.576071e-02. error 8.11%. misses: 591
eval: split test . loss 5.287848e-02. error 7.37%. misses: 148
epoch 2
eval: split train. loss 4.097378e-02. error 5.80%. misses: 423
eval: split test . loss 4.257497e-02. error 6.08%. misses: 122
epoch 3
eval: split train. loss 3.390130e-02. error 4.92%. misses: 359
eval: split test . loss 3.796291e-02. error 5.48%. misses: 110
epoch 4
eval: split train. loss 2.989994e-02. error 4.38%. misses: 319
eval: split test . loss 3.480190e-02. error 5.23%. misses: 105
epoch 5
eval: split train. loss 2.566473e-02. error 3.77%. misses: 275
eval: split test . loss 3.232093e-02. error 4.73%. misses: 95
epoch 6
eval: split train. loss 2.348944e-02. error 3.33%. misses: 242
eval: split test . loss 3.208887e-02. error 4.58%. misses: 92
epoch 7
eval: split train. loss 2.151174e-02. error 3.09%. misses: 225
eval: split test . loss 3.206819e-02. error 4.93%. misses: 99
epoch 8
eval: split train. loss 1.941714e-02. error 2.77%. misses

Results:

```
epoch 23
eval: split train. loss 5.265484e-03. error 0.82%. misses: 60
eval: split test . loss 2.467080e-02. error 3.69%. misses: 74
```



# "Modern" Adjustments

In [None]:
!pip install -q flax

In [None]:
import jax
import jax.numpy as jnp

try:
    from flax import linen as nn
except ModuleNotFoundError:
    %pip install -qq flax
    from flax import linen as nn
try:
    from torchvision import datasets
except ModuleNotFoundError:
    %pip install -qq torchvision
    from torchvision import datasets


def get_datasets(n_tr, n_te):
    """Load MNIST and build the down-sampled train/test arrays with one-hot targets.

    For each split, examples are picked with a fixed-seed random permutation
    (the same seed, 42, is used for both splits), scaled with
    ``x / 127.5 - 1.0``, resized to 16x16 with bilinear interpolation, and
    paired with a target vector that is 1.0 at the label index and 0.0
    everywhere else.

    Args:
        n_tr: Number of training examples to draw.
        n_te: Number of test examples to draw.

    Returns:
        A dict mapping ``"train"`` and ``"test"`` to ``(X, Y)``, where ``X`` is
        a float32 array of shape ``(n, 16, 16, 1)`` and ``Y`` is a float32
        array of shape ``(n, 10)``.
    """
    # Local import: NumPy is only needed here, for the host-side staging buffers.
    import numpy as np

    train_test = {}
    for split in ("train", "test"):
        data = datasets.MNIST("./data", train=split == "train", download=True)
        n = n_tr if split == "train" else n_te
        key = jax.random.PRNGKey(42)
        # Bring the sampled indices to the host once instead of converting one
        # device scalar per loop iteration.
        rp = np.asarray(jax.random.permutation(key, len(data))[:n])
        # Stage into mutable NumPy buffers. `arr.at[i].set(...)` on a JAX array
        # outside of jit returns a fresh copy of the whole array each time,
        # which makes an element-wise fill loop quadratic in n.
        X = np.zeros((n, 16, 16, 1), dtype=np.float32)
        Y = np.zeros((n, 10), dtype=np.float32)
        for i, ix in enumerate(rp.tolist()):
            I, yint = data[ix]
            xi = jnp.array(I, dtype=jnp.float32) / 127.5 - 1.0
            xi = jax.image.resize(xi, (16, 16), "bilinear")
            X[i] = np.asarray(xi)[..., None]  # append the channel axis
            Y[i, yint] = 1.0
        train_test[split] = (jnp.asarray(X), jnp.asarray(Y))
    return train_test

In [None]:
try:
    from flax import linen as nn
except ModuleNotFoundError:
    %pip install -qq flax
    from flax import linen as nn
from flax.training import train_state
from flax.linen.activation import tanh

try:
    import optax
except ModuleNotFoundError:
    %pip install -qq optax
    import optax
from typing import Callable


class Net(nn.Module):
    """Modernised variant of the network: ReLU activations, dropout and logit outputs.

    In training mode the input is randomly rolled along both spatial axes and
    dropout is active; otherwise both are disabled. The final layer returns
    raw scores (no activation). The convolutions use ``kernel_init``; the dense
    layers are created without a ``kernel_init`` argument and so use the Flax
    default. The bias shapes imply a ``(batch, 16, 16, 1)`` input.
    """

    # When True, apply the random shift augmentation and enable dropout.
    training: bool
    # Initializer for all four explicit bias parameters.
    bias_init: Callable = nn.initializers.zeros
    # sqrt(6) = 2.449... used by he_uniform() approximates Karpathy's 2.4
    kernel_init: Callable = nn.initializers.he_uniform()

    @nn.compact
    def __call__(self, x):
        def pad_border(t):
            # Two-pixel spatial border filled with -1.0 (batch/channel axes untouched).
            return jnp.pad(t, [(0, 0), (2, 2), (2, 2), (0, 0)], constant_values=-1.0)

        def conv5x5(features):
            # Bias-free 5x5, stride-2, unpadded convolution.
            return nn.Conv(
                features=features,
                kernel_size=(5, 5),
                strides=2,
                padding="VALID",
                use_bias=False,
                kernel_init=self.kernel_init,
            )

        if self.training:
            # Augmentation: one random integer shift per spatial axis, drawn
            # from the "aug" RNG stream, applied to the whole batch.
            shift_rows, shift_cols = jax.random.randint(self.make_rng("aug"), (2,), -1, 2)
            x = jnp.roll(x, (shift_rows, shift_cols), (1, 2))

        # H1: 12 feature maps of size 8x8, one bias per unit.
        h1 = conv5x5(12)(pad_border(x))
        h1 = nn.relu(h1 + self.param("bias1", self.bias_init, (8, 8, 12)))

        # H2: three convolutions, each seeing a different 8-channel subset of H1
        # (channels 0-7, 4-11, and 0-3 plus 8-11), 4 output maps apiece.
        h1 = pad_border(h1)
        channel_groups = (
            h1[..., 0:8],
            h1[..., 4:12],
            jnp.concatenate((h1[..., 0:4], h1[..., 8:12]), axis=-1),
        )
        h2 = jnp.concatenate([conv5x5(4)(group) for group in channel_groups], axis=-1)
        h2 = nn.relu(h2 + self.param("bias2", self.bias_init, (4, 4, 12)))
        h2 = nn.Dropout(0.25, deterministic=not self.training)(h2)

        # H3: dense layer with 30 units on the flattened H2 activations.
        flat = h2.reshape((h2.shape[0], -1))
        h3 = nn.Dense(features=30, use_bias=False)(flat)
        h3 = nn.relu(h3 + self.param("bias3", self.bias_init, (30,)))

        # Output: 10 raw scores.
        scores = nn.Dense(features=10, use_bias=False)(h3)
        return scores + self.param("bias4", self.bias_init, (10,))

In [None]:
from jax import value_and_grad

try:
    import optax
except ModuleNotFoundError:
    %pip install -qq optax
    import optax
try:
    from flax.training import train_state
except ModuleNotFoundError:
    %pip install -qq flax
    from flax.training import train_state


def learning_rate_fn(initial_rate, epochs, steps_per_epoch):
    """Build a linear learning-rate schedule that ends at one third of its start.

    Args:
        initial_rate: Learning rate at step 0.
        epochs: Number of training epochs.
        steps_per_epoch: Number of optimizer steps per epoch.

    Returns:
        An optax schedule going linearly from ``initial_rate`` to
        ``initial_rate / 3`` over ``epochs * steps_per_epoch`` steps.
    """
    total_steps = epochs * steps_per_epoch
    final_rate = initial_rate / 3
    return optax.linear_schedule(
        init_value=initial_rate,
        end_value=final_rate,
        transition_steps=total_steps,
    )


def create_train_state(key, X, lr_fn):
    """Initialise a training-mode ``Net`` and wrap it with AdamW in a ``TrainState``.

    Args:
        key: PRNG key; split three ways for the "params", "aug" and "dropout"
            RNG streams used during initialisation.
        X: Example input used to trace the network during initialisation.
        lr_fn: Learning rate (or schedule) handed to ``optax.adamw``.

    Returns:
        A ``TrainState`` holding the model's apply function, its freshly
        initialised parameters and the optimizer.
    """
    net = Net(training=True)
    params_key, aug_key, dropout_key = jax.random.split(key, 3)
    init_rngs = {"params": params_key, "aug": aug_key, "dropout": dropout_key}
    variables = net.init(init_rngs, X)
    return train_state.TrainState.create(
        apply_fn=net.apply,
        params=variables["params"],
        tx=optax.adamw(lr_fn),
    )


@jax.jit
def train_step(state, X, Y, rng=jax.random.PRNGKey(0)):
    """Run one gradient update on the softmax cross-entropy loss.

    Args:
        state: Train state providing ``params``, ``step`` and ``apply_gradients``.
        X: Batch of inputs.
        Y: Batch of target distributions, one row per example.
        rng: Base PRNG key. It is folded with ``state.step`` so every step gets
            distinct augmentation and dropout randomness even when the
            default key is used.

    Returns:
        The train state after applying the gradients.
    """
    aug_rng, dropout_rng = jax.random.split(jax.random.fold_in(rng, state.step))

    def loss_fn(params):
        # Training mode is requested explicitly so augmentation and dropout are on.
        logits = Net(training=True).apply(
            {"params": params}, X, rngs={"aug": aug_rng, "dropout": dropout_rng}
        )
        return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=Y))

    # Only the gradients are needed here; the loss value and error rate are
    # reported separately by the evaluation code.
    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)


def train_one_epoch(state, X, Y):
    """Make one pass over the data, updating on one example at a time.

    Examples are visited in index order; each is given a leading batch axis of
    size 1 before being passed to ``train_step`` (which is called without an
    explicit ``rng``, so its default key is used).

    Args:
        state: Train state to update.
        X: Inputs, indexed along the first axis.
        Y: Targets, indexed along the first axis.

    Returns:
        The train state after the final update.
    """
    num_examples = X.shape[0]
    for idx in range(num_examples):
        example = X[idx][None, ...]
        target = Y[idx][None, ...]
        state = train_step(state, example, target)
    return state


def train(key, data, epochs, lr):
    """Train the network with a decaying learning rate, printing metrics each epoch.

    Args:
        key: PRNG key used to initialise the model.
        data: Dict with ``"train"`` and ``"test"`` entries, each an ``(X, Y)`` pair.
        epochs: Number of passes over the training split.
        lr: Initial learning rate; the schedule decays it linearly over
            ``epochs * len(train)`` steps.
    """
    Xtr, Ytr = data["train"]
    lr_fn = learning_rate_fn(lr, epochs, Xtr.shape[0])
    # Named `state` so the local does not shadow the imported
    # `flax.training.train_state` module.
    state = create_train_state(key, Xtr, lr_fn)
    for epoch in range(epochs):
        print(f"epoch {epoch+1} with learning rate {lr_fn(state.step):.6f}")
        state = train_one_epoch(state, Xtr, Ytr)
        for split in ("train", "test"):
            eval_split(data, split, state.params)



In [None]:
@jax.jit
def eval_step(params, X, Y):
    """Compute softmax cross-entropy loss and classification error rate.

    The network is run with ``training=False``, so no augmentation or dropout
    is applied.

    Args:
        params: Parameter tree for ``Net``.
        X: Batch of inputs passed to the network.
        Y: Target vectors; the arg-max along the last axis is the true class.

    Returns:
        ``(loss, err)``: the mean cross-entropy over the batch, and the
        fraction of rows whose predicted class differs from the target class.
    """
    logits = Net(training=False).apply({"params": params}, X)
    per_example = optax.softmax_cross_entropy(logits=logits, labels=Y)
    mean_loss = jnp.mean(per_example)
    true_cls = jnp.argmax(Y, -1)
    pred_cls = jnp.argmax(logits, -1)
    error_rate = jnp.mean(true_cls != pred_cls).astype(float)
    return mean_loss, error_rate


def eval_split(data, split, params):
    """Evaluate ``params`` on one split and print loss, error rate and miss count.

    Args:
        data: Dict mapping split names to ``(X, Y)`` pairs.
        split: Key of the split to evaluate (e.g. ``"train"`` or ``"test"``).
        params: Model parameters forwarded to ``eval_step``.
    """
    X, Y = data[split]
    loss, err = eval_step(params, X, Y)
    # `err` is a floating-point mean of booleans, so `err * N` can land a hair
    # below the true integer count; round instead of truncating so the printed
    # number of misses is never off by one.
    misses = round(float(err) * Y.shape[0])
    print(f"eval: split {split:5s}. loss {loss:e}. error {err*100:.2f}%. misses: {misses}")

In [None]:
# Build both splits: 7291 training examples and 2007 test examples.
data = get_datasets(7291, 2007)

In [None]:
# Split the seed-42 root key and keep only the first sub-key.
key, _ = jax.random.split(jax.random.PRNGKey(42))

# Train for 80 epochs starting from learning rate 3e-4; metrics are printed after each epoch.
train(key, data, 80, 3e-4)

epoch 1 with learning rate 0.000300
eval: split train. loss 4.722151e-01. error 12.73%. misses: 928
eval: split test . loss 4.376389e-01. error 11.81%. misses: 237
epoch 2 with learning rate 0.000297
eval: split train. loss 3.456218e-01. error 9.77%. misses: 712
eval: split test . loss 3.105372e-01. error 8.87%. misses: 178
epoch 3 with learning rate 0.000295
eval: split train. loss 2.216365e-01. error 6.45%. misses: 469
eval: split test . loss 1.981873e-01. error 5.53%. misses: 111
epoch 4 with learning rate 0.000292
eval: split train. loss 2.072843e-01. error 5.99%. misses: 437
eval: split test . loss 1.910520e-01. error 5.48%. misses: 110
epoch 5 with learning rate 0.000290
eval: split train. loss 1.750381e-01. error 5.49%. misses: 399
eval: split test . loss 1.611853e-01. error 4.93%. misses: 99
epoch 6 with learning rate 0.000288
eval: split train. loss 1.538368e-01. error 4.42%. misses: 321
eval: split test . loss 1.411121e-01. error 4.19%. misses: 84
epoch 7 with learning rate 0

Change 1: replace the tanh on the last layer with a plain fully connected (FC) output and train with a softmax cross-entropy loss. Lower the learning rate to 0.01.

```
epoch 23
eval: split train. loss 7.162272e-03. error 0.05%. misses: 4
eval: split test . loss 1.687743e-01. error 4.14%. misses: 83

```

Change 2: change from SGD to AdamW with LR 3e-4, double epochs to 46, decay LR to 1e-4 over the course of training.

```
epoch 46
eval: split train. loss 1.890260e-03. error 0.04%. misses: 2
eval: split test . loss 1.953933e-01. error 4.04%. misses: 81
```

Change 3: Introduce data augmentation, e.g. a shift by at most 1 pixel in both x/y directions, and bump up training time to 60 epochs.

```
epoch 60
eval: split train. loss 5.098452e-02. error 1.65%. misses: 120
eval: split test . loss 9.166716e-02. error 2.59%. misses: 52
```

Change 4: add dropout on the input to layer H3, switch the activation function to ReLU, and increase training to 80 epochs.

```
epoch 80
eval: split train. loss 3.316079e-02. error 1.06%. misses: 77
eval: split test . loss 4.969697e-02. error 1.74%. misses: 35

```




