# Convolutional Neural Networks (LeNet)
:label:`sec_lenet`

We now have all the ingredients required to assemble
a fully-functional CNN.
In our earlier encounter with image data, we applied
a linear model with softmax regression (:numref:`sec_softmax_scratch`)
and an MLP (:numref:`sec_mlp-implementation`)
to pictures of clothing in the Fashion-MNIST dataset.
To feed such data to these models, we first flattened each image from a $28\times28$ matrix
into a fixed-length $784$-dimensional vector
and then processed it with fully connected layers.
Now that we have a handle on convolutional layers,
we can retain the spatial structure in our images.
As an additional benefit of replacing fully connected layers with convolutional layers,
we will enjoy more parsimonious models that require far fewer parameters.
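For a sense of scale, the arithmetic below compares parameter counts for one specific task: producing six $28\times28$ feature maps from a single $28\times28$ grayscale image. This is the shape our LeNet's first convolutional layer produces later in this section. The fully connected variant is a hypothetical comparison layer that connects every output unit to every input pixel. Both counts include biases, and this is only a back-of-the-envelope sketch.

```python
# Input: one 28x28 grayscale image; output: 6 feature maps of size 28x28.
h, w, c_in, c_out, k = 28, 28, 1, 6, 5

# Hypothetical fully connected layer: every output unit sees every input pixel.
dense_params = (h * w * c_in) * (h * w * c_out) + h * w * c_out
# Convolutional layer: one 5x5 kernel per (input, output) channel pair,
# shared across all spatial positions, plus one bias per output channel.
conv_params = k * k * c_in * c_out + c_out

print(dense_params, conv_params)  # 3692640 156
```

Weight sharing accounts for the gap of more than four orders of magnitude. The convolution reuses the same 150 kernel weights at every one of the $28\times28$ output positions.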

In this section, we will introduce *LeNet*,
among the first published CNNs
to capture wide attention for its performance on computer vision tasks.
The model was introduced by (and named for) Yann LeCun,
then a researcher at AT&T Bell Labs,
for the purpose of recognizing handwritten digits in images :cite:`LeCun.Bottou.Bengio.ea.1998`.
This work represented the culmination
of a decade of research developing the technology;
LeCun's team published the first study to successfully
train CNNs via backpropagation :cite:`LeCun.Boser.Denker.ea.1989`.

At the time, LeNet achieved outstanding results.
It matched the performance of support vector machines,
then a dominant approach in supervised learning,
with an error rate of less than 1% per digit.
LeNet was eventually adapted to recognize digits
for processing deposits in ATMs.
To this day, some ATMs still run the code
that Yann LeCun and his colleague Leon Bottou wrote in the 1990s!


```python
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf  # used only to load and batch Fashion-MNIST
import jax
from jax import numpy as jnp
import optax
from flax import nnx
```

## LeNet

At a high level, (**LeNet (LeNet-5) consists of two parts:
(i) a convolutional encoder consisting of two convolutional layers; and
(ii) a dense block consisting of three fully connected layers**).
The architecture is summarized in :numref:`img_lenet`.

![Data flow in LeNet. The input is a handwritten digit, the output is a probability over 10 possible outcomes.](../img/lenet.svg)
:label:`img_lenet`

The basic units in each convolutional block
are a convolutional layer, a sigmoid activation function,
and a subsequent average pooling operation.
Note that while ReLUs and max-pooling work better,
they had not yet been discovered at the time.
Each convolutional layer uses a $5\times 5$ kernel.
These layers map spatially arranged inputs
to a number of two-dimensional feature maps, typically
increasing the number of channels.
The first convolutional layer has 6 output channels,
while the second has 16.
In our implementation, the first convolution pads its input so that the $28\times28$ spatial size is preserved.
The second convolution uses no padding, so it shrinks each feature map by 4 pixels along each axis.
Each $2\times2$ pooling operation (stride 2)
reduces dimensionality by a factor of $4$ via spatial downsampling.
Flax uses a channels-last layout,
so the convolutional block emits an output with shape
(batch size, height, width, number of channels).
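The spatial sizes follow from the usual formula $\lfloor (n + p - k)/s \rfloor + 1$. Here $n$ is the input size, $p$ the total padding, $k$ the window size, and $s$ the stride. The sketch below applies this formula along one axis. It assumes the padding choices of the implementation in this section: `'SAME'` for the first convolution and none for the second. For a $5\times5$ kernel at stride 1, `'SAME'` amounts to 2 pixels on each side, so $p=4$. `conv_out` is a hypothetical helper, not a library function.

```python
def conv_out(size: int, kernel: int, padding: int, stride: int = 1) -> int:
    """Output size along one axis for a convolution or pooling window.

    `padding` is the total padding along that axis (both sides combined).
    """
    return (size + padding - kernel) // stride + 1

h = 28
after_conv1 = conv_out(h, kernel=5, padding=4)                     # 28 -> 28 ('SAME')
after_pool1 = conv_out(after_conv1, kernel=2, padding=0, stride=2) # 28 -> 14
after_conv2 = conv_out(after_pool1, kernel=5, padding=0)           # 14 -> 10 ('VALID')
after_pool2 = conv_out(after_conv2, kernel=2, padding=0, stride=2) # 10 -> 5
flat = after_pool2 * after_pool2 * 16                              # 5 * 5 * 16 channels

print(after_conv1, after_pool1, after_conv2, after_pool2, flat)  # 28 14 10 5 400
```

The final value, 400, is the number of input features that the first fully connected layer must accept.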

In order to pass output from the convolutional block
to the dense block,
we must flatten each example in the minibatch.
In other words, we take this four-dimensional input and transform it
into the two-dimensional input expected by fully connected layers:
as a reminder, the two-dimensional representation that we desire uses the first dimension to index examples in the minibatch
and the second to give the flat vector representation of each example.
LeNet's dense block has three fully connected layers,
with 120, 84, and 10 outputs, respectively.
Because we are still performing classification,
the 10-dimensional output layer corresponds
to the number of possible output classes.
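In JAX, the flattening is a single `reshape` that keeps the batch axis and merges everything else. Below is a minimal sketch. It uses a hypothetical minibatch of 8 examples, each with the $(5, 5, 16)$ channels-last shape that LeNet's convolutional block produces for $28\times28$ inputs.

```python
import jax.numpy as jnp

# Hypothetical minibatch: (batch, height, width, channels) = (8, 5, 5, 16)
features = jnp.arange(8 * 5 * 5 * 16, dtype=jnp.float32).reshape(8, 5, 5, 16)

# Keep axis 0 (examples) and merge the remaining axes into one flat vector
flat = features.reshape(features.shape[0], -1)

print(flat.shape)  # (8, 400)
# Each row is exactly one example's values, flattened in height-width-channel order
print(bool(jnp.array_equal(flat[3], features[3].ravel())))  # True
```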

While getting to the point where you truly understand
what is going on inside LeNet may have taken a bit of work,
we hope that the following code snippet will convince you
that implementing such models with modern deep learning frameworks
is remarkably simple.
We need only to subclass `nnx.Module`,
create the layers in `__init__`, and chain them together in `__call__`.
The weights are initialized with Xavier initialization as
introduced in :numref:`subsec_xavier`.


```python
class LeNet(nnx.Module):  #@save
  """The LeNet-5 model for channels-last inputs of shape (N, 28, 28, C)."""

  def __init__(self, in_channels: int, num_classes: int, *, rngs: nnx.Rngs):
    # Xavier (Glorot) uniform initialization for every kernel
    self.kernel_init = nnx.initializers.xavier_uniform
    # 'SAME' padding keeps 28x28; 'VALID' shrinks 14x14 to 10x10
    self.conv1 = nnx.Conv(in_features=in_channels, out_features=6,
                          kernel_size=(5, 5), padding='SAME',
                          kernel_init=self.kernel_init(), rngs=rngs)
    self.conv2 = nnx.Conv(in_features=6, out_features=16,
                          kernel_size=(5, 5), padding='VALID',
                          kernel_init=self.kernel_init(), rngs=rngs)
    # 2x2 average pooling with stride 2 halves height and width
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    # 400 = 5 * 5 * 16 flattened features from the convolutional block
    self.linear1 = nnx.Linear(400, 120, kernel_init=self.kernel_init(), rngs=rngs)
    self.linear2 = nnx.Linear(120, 84, kernel_init=self.kernel_init(), rngs=rngs)
    self.linear3 = nnx.Linear(84, num_classes, kernel_init=self.kernel_init(), rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.sigmoid(self.conv1(x)))  # (N, 14, 14, 6)
    x = self.avg_pool(nnx.sigmoid(self.conv2(x)))  # (N, 5, 5, 16)
    x = x.reshape(x.shape[0], -1)                   # flatten to (N, 400)
    x = nnx.sigmoid(self.linear1(x))                # (N, 120)
    x = nnx.sigmoid(self.linear2(x))                # (N, 84)
    return self.linear3(x)                          # logits, (N, num_classes)
```

We have taken some liberty in the reproduction of LeNet insofar as we have replaced the Gaussian activation layer by
a softmax layer. This greatly simplifies the implementation, not least because
the Gaussian decoder is rarely used nowadays. Note that `__call__` returns unnormalized scores (logits).
The softmax itself is applied inside the loss function during training. Other than that, this network matches
the original LeNet-5 architecture.


Let's see what happens inside the network.
Because NNX modules are ordinary stateful Python objects,
we can create a model and [**inspect it**] directly.
Printing the model lists every layer together with its configuration
and the shapes of its parameters.
We can then pass a single-channel (black and white)
$28 \times 28$ image through the network
to check that its operations line up with
what we expect from :numref:`img_lenet_vert`.
Flax expects channels-last inputs,
so a single image has shape $(1, 28, 28, 1)$,
i.e., (batch size, height, width, channels).


![Compressed notation for LeNet-5.](../img/lenet-vert.svg)
:label:`img_lenet_vert`


```python
model = LeNet(in_channels=1, num_classes=10, rngs=nnx.Rngs(0))
```

```python
model
```

```
LeNet(
  kernel_init=<function glorot_uniform at 0x12f502fc0>,
  conv1=Conv(
    kernel_shape=(5, 5, 1, 6),
    kernel=Param(
      value=Array(shape=(5, 5, 1, 6), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(6,), dtype=float32)
    ),
    in_features=1,
    out_features=6,
    kernel_size=(5, 5),
    strides=1,
    padding='SAME',
    input_dilation=1,
    kernel_dilation=1,
    feature_group_count=1,
    use_bias=True,
    mask=None,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x152dc1e40>,
    bias_init=<function zeros at 0x12f47d300>,
    conv_general_dilated=<function conv_general_dilated at 0x11f714860>
  ),
  conv2=Conv(
    kernel_shape=(5, 5, 6, 16),
    kernel=Param(
      value=Array(shape=(5, 5, 6, 16), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(16,), dtype=float32)
    ),
    in_features=6,
    out_features=16,
    kernel_size=(5, 5),
    ...
```

```python
x = jnp.ones((1, 28, 28, 1))  # one all-ones 28x28 grayscale image, channels last
model(x)
```

```
Array([[-0.28011444,  0.2215457 ,  0.51564384, -0.6677027 ,  0.5182838 ,
        -0.511331  , -0.6132752 ,  0.76808715,  0.6332212 ,  0.18320563]],      dtype=float32)
```

## Training

Now that we have implemented the model,
let's [**run an experiment to see how the LeNet-5 model fares on Fashion-MNIST**].

While CNNs have fewer parameters,
they can still be more expensive to compute
than similarly deep MLPs
because each parameter participates in many more
multiplications.
If you have access to a GPU, this might be a good time
to put it into action to speed up training.
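To make the compute argument concrete, the count below tallies parameters and multiplications per example for each layer of our LeNet. It uses the activation shapes from this section: $28\times28$ after the first convolution and $10\times10$ after the second. This is a back-of-the-envelope sketch that ignores bias additions, pooling, and activation costs. `conv_cost` and `dense_cost` are hypothetical helpers. Under these assumptions, the two convolutional layers hold only about 4% of the parameters but perform about 86% of the multiplications.

```python
def conv_cost(k: int, c_in: int, c_out: int, out_h: int, out_w: int):
    weights = k * k * c_in * c_out
    params = weights + c_out           # kernel weights plus one bias per output channel
    mults = weights * out_h * out_w    # every weight is reused at each output position
    return params, mults

def dense_cost(n_in: int, n_out: int):
    weights = n_in * n_out
    return weights + n_out, weights    # each weight is used once per example

costs = {
    'conv1': conv_cost(5, 1, 6, 28, 28),
    'conv2': conv_cost(5, 6, 16, 10, 10),
    'linear1': dense_cost(400, 120),
    'linear2': dense_cost(120, 84),
    'linear3': dense_cost(84, 10),
}
for name, (params, mults) in costs.items():
    print(f'{name:8s} params={params:6d}  mults/example={mults:7d}')

total_params = sum(p for p, _ in costs.values())
total_mults = sum(m for _, m in costs.values())
print(total_params, total_mults)  # 61706 416520
```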

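```python
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
```

```python
x_train.shape
```

```
(60000, 28, 28)
```

```python
# Scale pixels to [0, 1] as float32 (JAX's default precision)
x_train = (x_train / 255.).astype(np.float32)
# Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)
x_train = x_train[..., np.newaxis]
y_train = y_train.astype(np.int32)
```

```python
x_train.shape
```

```
(60000, 28, 28, 1)
```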

```python
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
```

```python
n_epoch = 10
batch_size = 256
# Shuffle over the whole training set (reshuffled every epoch), then repeat
train_data = train_data.shuffle(buffer_size=len(x_train)).repeat(n_epoch)
train_data = train_data.batch(batch_size, drop_remainder=True).prefetch(1)
```

```python
train_data.cardinality().numpy()  # (60000 * n_epoch) // batch_size
```

```
2343
```

```python
x_test = (x_test / 255.).astype(np.float32)
x_test = x_test[..., np.newaxis]
y_test = y_test.astype(np.int32)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
```

```python
# Evaluation needs no shuffling. Keep the final partial batch so that all
# 10,000 test images are scored (this costs one extra jit compilation).
test_data = test_data.batch(batch_size).prefetch(1)
```

```python
def loss_fn(model, batch):
  x, y = batch
  logits = model(x)
  # Softmax and cross-entropy in one numerically stable call, averaged over the batch
  loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
  return loss, logits
```

```python
# A conventional Adam step size; 0.1 is far too large for Adam
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-3))
```

```python
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)
```

```python
@nnx.jit
def train_step(model, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Train for a single step."""
  # Differentiate w.r.t. the model's parameters; logits come back as auxiliary output
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch[1])
  optimizer.update(grads)  # updates the model's parameters in place
```

```python
@nnx.jit
def eval_step(model, metrics: nnx.MultiMetric, batch):
  """Accumulate loss and accuracy on one batch without updating parameters."""
  loss, logits = loss_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch[1])
```

```python
eval_every = 250  # evaluate on the test set every 250 training steps

metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'test_loss': [],
  'test_accuracy': [],
}

for step, batch in enumerate(train_data.as_numpy_iterator()):
  train_step(model, optimizer, metrics, batch)

  if step > 0 and (step % eval_every == 0):
    # Record training metrics accumulated since the last evaluation
    for metric, value in metrics.compute().items():
      metrics_history[f'train_{metric}'].append(value)
    metrics.reset()

    # Evaluate on the full test set
    for test_batch in test_data.as_numpy_iterator():
      eval_step(model, metrics, test_batch)

    for metric, value in metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)
    metrics.reset()
```


```python
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train', 'test'):
  ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
  ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
for ax in (ax1, ax2):
  ax.set_xlabel(f'evaluation round (every {eval_every} steps)')
  ax.legend()
plt.show()
```