# Name: Exploring Temporal Latent Bottlenecks for Image Classification

**Authors:** Aritra Roy Gosthipaty, Suvaditya Mukherjee  
**Date Created:** 06/03/2023  
**Last Modified:** 07/03/2023  
**Description:** Performing Image Classification with State-of-the-art Temporal Latent Bottleneck Mechanism.  
**Accelerator:** GPU

## Introduction

This example explores how to use the Temporal Latent Bottleneck mechanism, proposed by [Didolkar et al.](https://arxiv.org/abs/2205.14794), to perform image classification on the CIFAR-10 dataset. We implement the model with a custom `RNNCell` so that the design is performant and vectorized.  
A simple Recurrent Neural Network exhibits a strong [inductive bias](https://en.wikipedia.org/wiki/Inductive_bias), i.e. built-in assumptions that help it generalize well within a specific domain. However, it suffers from the well-known problem of vanishing/exploding gradients, and it struggles to retain hidden-state information over long sequences.  
On the other end of the spectrum, the [Attention-based Transformer architecture introduced by Vaswani et al.](https://arxiv.org/abs/1706.03762) has shown considerable improvements in these areas: it has achieved state-of-the-art results in Natural Language Processing tasks and has also been widely adapted to the vision domain. While the Transformer can "attend" to different parts of the input sequence, it lacks a strong inductive bias, which makes it prone to generalizing poorly on domain-specific tasks.  
The Temporal Latent Bottleneck paper combines concepts from both ends of the spectrum into a new mechanism that tackles the lack of inductive bias, vanishing/exploding gradients and the loss of information over long sequences. While the method is novel in introducing separate processing streams to preserve and process latent states, it has parallels in other works such as the [Perceiver mechanism (by Jaegle et al.)](https://arxiv.org/abs/2103.03206) and [Grounded Language Learning Fast and Slow (by Hill et al.)](https://arxiv.org/pdf/2009.01719.pdf).  

This example is structured as follows:
- Perform necessary imports
- Set up required configurations and settings
- Load the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html)
- Define augmentation pipelines and build the `tf.data` input pipelines
- Define the `PatchEmbed` layer for the patching and embedding operations, along with the `FeedForwardNetwork` and `BaseAttention` layers
- Define the `AttentionWithFFN` layer
- Compose the Perceptual Module and Temporal Latent Bottleneck Module from `AttentionWithFFN` layers used for self-attention and cross-attention
- Create a custom `RNNCell` implementation that uses the above-mentioned modules (vectorized) and wrap it in a `keras.layers.RNN`
- Define hyperparameters and the `model.fit()` pipeline
- Visualize the training curves and the Temporal Latent Bottleneck attention maps

This example uses `TensorFlow 2.11.0`, which must be installed on your system.

## Setup imports

In [None]:
!pip install -q tensorflow==2.11.0

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers.experimental import AdamW
from tensorflow.keras import mixed_precision
from typing import Tuple
from matplotlib import pyplot as plt

# Set seed for reproducibility.
tf.keras.utils.set_random_seed(42)

AUTO = tf.data.AUTOTUNE

## Setting required configuration

We set a few configuration parameters that are needed within the pipeline. The current values are for use with the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html).  
The model also supports mixed precision, which runs most computations in 16-bit floating point while keeping the model variables in 32-bit for numerical stability. On GPUs with fast float16 support, this reduces memory usage and speeds up both training and inference.
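The trade-off behind mixed precision can be seen with plain NumPy scalars and arrays. The sketch below is only an illustration of float16 versus float32 behaviour, not of the Keras policy machinery: float16 halves the memory of a tensor, but it has a much smaller range and precision, which is why small weight updates can vanish unless the variables themselves stay in float32.

```python
import numpy as np

# Memory: the same tensor takes half the bytes in float16.
weights32 = np.zeros((128, 512), dtype=np.float32)
weights16 = weights32.astype(np.float16)
print(weights32.nbytes, weights16.nbytes)  # 262144 131072

# Range: very small values (e.g. tiny gradients) underflow to zero in float16.
print(np.float32(1e-8) == 0, np.float16(1e-8) == 0)  # False True

# Precision: float16 has a 10-bit mantissa, so above 2048 only even integers exist.
print(float(np.float16(2049)))  # 2048.0

# A small update to a weight stored in float16 is lost entirely, which is why
# the variables are kept in float32 under the "mixed_float16" policy.
w16 = np.float16(1.0) + np.float16(1e-4)
w32 = np.float32(1.0) + np.float32(1e-4)
print(w16 == 1.0, w32 == 1.0)  # True False
```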

In [None]:
config = {
    "mixed_precision": True,
    "dataset": "cifar10",
    "train_slice": 40_000,
    "batch_size": 1024,
    "buffer_size": 1024 * 2,
    "input_shape": [32, 32, 3],
    "image_size": 48,
    "num_classes": 10,
    "learning_rate": 1e-4,
    "weight_decay": 1e-4,
    "epochs": 100,
    "patch_size": 4,
    "embed_dim": 128,
    "chunk_size": 8,
    "r": 2,
    "num_layers": 6,
    "ffn_drop": 0.2,
    "attn_drop": 0.2,
    "num_heads": 1,
}


# Enable mixed precision only if it is requested in the config.
if config["mixed_precision"]:
    policy = mixed_precision.Policy("mixed_float16")
    mixed_precision.set_global_policy(policy)

## Loading the CIFAR-10 dataset

As mentioned in the Introduction, we use the CIFAR-10 dataset for our experiments. It contains a training set of 50,000 images across 10 classes, each with the standard size of (32, 32, 3), and a separate test set of 10,000 images with the same characteristics. We keep the first 40,000 training images for training and hold out the remaining 10,000 for validation. More information about the dataset can be found on its official site and in the [`keras.datasets.cifar10`](https://keras.io/api/datasets/cifar10/) API reference.

In [None]:
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[: config["train_slice"]], y_train[: config["train_slice"]]),
    (x_train[config["train_slice"] :], y_train[config["train_slice"] :]),
)
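The slicing above keeps the first 40,000 training images for training and the remaining 10,000 for validation. Because `Dataset.batch()` is later called without `drop_remainder=True`, the final batch of each split is smaller than `batch_size`. The plain-Python arithmetic below, using the values from `config`, shows how many batches each split yields. `batch_stats` is a hypothetical helper written only for this illustration. The training count should match what `tf.data.experimental.cardinality` reports further down.

```python
import math

# CIFAR-10 ships 50,000 training and 10,000 test images.
num_train_total, num_test = 50_000, 10_000
train_slice, batch_size = 40_000, 1024  # config["train_slice"], config["batch_size"]

num_train = train_slice  # x_train[:train_slice]
num_val = num_train_total - train_slice  # x_train[train_slice:]


def batch_stats(num_examples, batch_size):
    """Return (number of batches, size of the last batch) when the remainder is kept."""
    num_batches = math.ceil(num_examples / batch_size)
    last_batch = num_examples - (num_batches - 1) * batch_size
    return num_batches, last_batch


for split, n in [("train", num_train), ("val", num_val), ("test", num_test)]:
    print(split, n, batch_stats(n, batch_size))
# train 40000 (40, 64)
# val 10000 (10, 784)
# test 10000 (10, 784)
```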

## Define augmentation pipelines for training and validation/testing

We define separate pipelines for augmenting the training data and for preprocessing the validation/test data. This step makes the model more robust to variations in its inputs and helps it generalize better. The steps we perform are as follows:

- `Rescaling` (Training, Test): This step normalizes all image pixel values from the [0, 255] range to [0, 1], which helps maintain numerical stability during training.

- `Resizing` (Training, Test): For training, we resize the image from its original size of (32, 32) to (52, 52), which leaves room for the random crop. For validation and testing, we resize it directly to (48, 48). This also complies with the data specifications given in the paper.

- `RandomCrop` (Training): This layer randomly selects a crop/sub-region of the image with size (48, 48).

- `RandomFlip` (Training): This layer randomly flips each image horizontally (with probability 0.5), keeping its size unchanged.

In [None]:
# Build the `train` augmentation pipeline.
train_aug = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0, dtype="float32"),
        layers.Resizing(
            config["input_shape"][0] + 20,
            config["input_shape"][0] + 20,
            dtype="float32",
        ),
        layers.RandomCrop(config["image_size"], config["image_size"], dtype="float32"),
        layers.RandomFlip("horizontal", dtype="float32"),
    ],
    name="train_data_augmentation",
)

# Build the `val` and `test` data pipeline.
test_aug = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0, dtype="float32"),
        layers.Resizing(config["image_size"], config["image_size"], dtype="float32"),
    ],
    name="test_data_augmentation",
)

In [None]:
# We define named functions instead of lambdas to apply the `keras.Sequential`
# pipelines inside `Dataset.map()`, which avoids this warning:
# https://github.com/tensorflow/tensorflow/issues/56089


def train_map_fn(image, label):
    return train_aug(image), label


def test_map_fn(image, label):
    return test_aug(image), label

## Load dataset into `tf.data.Dataset` and perform optimizations

- We take the `np.ndarray` instance of the datasets and move them into a `tf.data.Dataset` instance
- Apply augmentations using `.map()`
- Shuffle the dataset using `.shuffle()`
- Batch the dataset using `.batch()`
- Enable pre-fetching of batches using `.prefetch()`
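Note that `shuffle()` does not shuffle the whole dataset at once. It keeps a buffer of `buffer_size` elements (2,048 here, versus 40,000 training examples), emits a randomly chosen element from the buffer, and refills that slot with the next element. The pure-Python sketch below models this behaviour together with `batch()` keeping the final partial batch. It is a conceptual illustration with hypothetical helper names, not the `tf.data` implementation.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=0):
    """Conceptual model of a fixed-size shuffle buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer first
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]  # emit a random buffered element ...
        buffer[idx] = item  # ... and replace it with the incoming one
    rng.shuffle(buffer)  # drain what is left in random order
    yield from buffer


def make_batches(iterable, batch_size):
    """Group items into lists of `batch_size`, keeping a final partial batch."""
    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == batch_size:
            yield chunk
            chunk = []
    if chunk:  # like drop_remainder=False
        yield chunk


out = list(make_batches(buffered_shuffle(range(10), buffer_size=4), batch_size=4))
print([len(b) for b in out])  # [4, 4, 2]
```

One consequence is that the first element emitted always comes from the first `buffer_size` elements of the input.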

In [None]:
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = (
    train_ds.map(train_map_fn, num_parallel_calls=AUTO)
    .shuffle(config["buffer_size"])
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = (
    val_ds.map(test_map_fn, num_parallel_calls=AUTO)
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = (
    test_ds.map(test_map_fn, num_parallel_calls=AUTO)
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

In [None]:
# Calculate number of batches
total_batches_train = tf.data.experimental.cardinality(train_ds).numpy()
print(f"Total batches to train on: {total_batches_train}")

In [None]:
# Sanity-check the shapes of one batch from the training pipeline.

image, label = next(iter(train_ds))
print(image.shape)
print(label.shape)

## Define `PatchEmbed` layer

This custom `keras.layers.Layer` generates patches from the image and projects them into a higher-dimensional embedding space.  
The patching is done with a `keras.layers.Conv2D` whose kernel size and stride both equal the patch size, instead of a traditional `tf.image.extract_patches`. This performs patch extraction and the linear projection in a single vectorized operation.  
Once the image is patched, we reshape the output into a flattened sequence of tokens, each of the Embedding dimension. At this stage, we add learned positional embeddings (from a `keras.layers.Embedding` layer) to the tokens. We then apply a `keras.layers.LayerNormalization`, and finally perform the 'chunking' operation.  
Chunking splits the token sequence into fixed-size groups, or 'chunks', which become the timesteps fed to the recurrent model. With the default configuration (48x48 images, 4x4 patches, chunks of 8 tokens), each image yields 144 tokens grouped into 18 chunks.
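The sketch below reproduces the patch-and-chunk bookkeeping in NumPy with the default configuration (48x48 RGB images, 4x4 patches, `embed_dim=128`, `chunk_size=8`). It relies on the fact that a convolution whose kernel size equals its stride is equivalent to cutting non-overlapping patches and applying one shared linear projection to each. The random kernel and positional table are stand-ins for learned weights, the convolution bias and `LayerNormalization` are omitted, and `patchify` is a hypothetical helper.

```python
import numpy as np


def patchify(images, patch):
    """(B, H, W, C) -> (B, num_patches, patch * patch * C), row-major over the patch grid."""
    b, h, w, c = images.shape
    x = images.reshape(b, h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, grid_h, grid_w, patch, patch, C)
    return x.reshape(b, (h // patch) * (w // patch), patch * patch * c)


rng = np.random.default_rng(0)
B, H, W, C = 2, 48, 48, 3
patch, embed_dim, chunk_size = 4, 128, 8
grid_w = W // patch  # 12

images = rng.random((B, H, W, C), dtype=np.float32)
# Random stand-ins for the Conv2D kernel (kh, kw, in, out) and the positional table.
kernel = rng.standard_normal((patch, patch, C, embed_dim)).astype(np.float32)
pos_table = rng.standard_normal(((H // patch) * grid_w, embed_dim)).astype(np.float32)

tokens = patchify(images, patch) @ kernel.reshape(-1, embed_dim)  # (2, 144, 128)
tokens = tokens + pos_table  # broadcast over the batch
chunks = tokens.reshape(B, -1, chunk_size, embed_dim)  # (2, 18, 8, 128)
print(tokens.shape, chunks.shape)

# Compare one token with a direct stride-4 convolution at grid cell (i, j).
i, j = 3, 7
window = images[0, i * patch:(i + 1) * patch, j * patch:(j + 1) * patch]
direct = np.einsum("hwc,hwcd->d", window, kernel) + pos_table[i * grid_w + j]
print(np.allclose(tokens[0, i * grid_w + j], direct, rtol=1e-4, atol=1e-4))  # True
```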

In [None]:
class PatchEmbed(layers.Layer):
    """Image patch embedding layer.

    Args:
        image_size (Tuple[int, int]): Input image resolution.
        patch_size (Tuple[int, int]): Patch spatial resolution.
        embed_dim (int): Embedding dimension.
        chunk_size (int): Number of tokens per chunk; must divide the number of patches.
    """

    def __init__(
        self,
        image_size: Tuple[int, int] = (48, 48),
        patch_size: Tuple[int, int] = (4, 4),
        embed_dim: int = 32,
        chunk_size: int = 8,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Get patch resolution
        patch_resolution = [
            image_size[0] // patch_size[0],
            image_size[1] // patch_size[1],
        ]
        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_resolution = patch_resolution

        # Calculate number of patches per image
        self.num_patches = patch_resolution[0] * patch_resolution[1]
        self.proj = layers.Conv2D(
            filters=embed_dim, kernel_size=patch_size, strides=patch_size
        )
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
        self.position_embedding = layers.Embedding(
            input_dim=self.num_patches, output_dim=embed_dim
        )

        # Calculate number of positions for the patches.
        self.positions = tf.range(start=0, limit=self.num_patches, delta=1)
        self.norm = keras.layers.LayerNormalization(epsilon=1e-5)

        # Perform Chunking
        self.chunking_layer = layers.Reshape(
            target_shape=(self.num_patches // chunk_size, chunk_size, embed_dim)
        )

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Patchifies the image, converts it into tokens, adds positional
        information and splits the tokens into chunks.

        Args:
            x: Tensor of shape (B, H, W, C)

        Returns:
            A tensor of shape (B, num_patches // chunk_size, chunk_size, embed_dim).
        """
        # Project the inputs: (B, H, W, C) -> (B, H/P, W/P, embed_dim).
        x = self.proj(x)
        # Flatten the patch grid: (B, H/P, W/P, embed_dim) -> (B, num_patches, embed_dim).
        x = self.flatten(x)
        x = x + self.position_embedding(self.positions)
        x = self.norm(x)

        # Split the tokens into chunks of `chunk_size` tokens each.
        x = self.chunking_layer(x)

        return x

## Define `FeedForwardNetwork`

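This custom `keras.layers.Layer` defines a generic pre-norm feed-forward network (FFN). The input is layer-normalized, expanded to `4 * dims` units with a GELU activation, projected back to `dims` units, passed through dropout, and added back to the input as a residual connection.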
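A NumPy sketch of the same pre-norm residual computation follows. It assumes `LayerNormalization` at its initial scale and offset (1 and 0). It uses the tanh approximation of GELU, whereas `tf.nn.gelu` defaults to the exact erf-based form, and the weights are random stand-ins. Because of the residual connection, the block reduces to the identity when the second projection is zero, which makes a handy sanity check.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize over the last axis (gamma = 1, beta = 0).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def gelu_tanh(x):
    # Tanh approximation of GELU (tf.nn.gelu defaults to the exact form).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def feed_forward(x, w1, b1, w2, b2, dropout=0.0, rng=None):
    """Pre-norm residual FFN: x + Dropout(Dense(Dense_gelu(LayerNorm(x))))."""
    h = gelu_tanh(layer_norm(x) @ w1 + b1)  # expand to 4 * dims
    h = h @ w2 + b2  # project back to dims
    if dropout > 0.0:
        # Inverted dropout, as applied at training time only.
        keep = rng.random(h.shape) >= dropout
        h = h * keep / (1.0 - dropout)
    return x + h  # residual connection


dims = 128
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, dims))  # (batch, chunk_size, dims)
w1 = rng.standard_normal((dims, 4 * dims)) * 0.02
w2 = rng.standard_normal((4 * dims, dims)) * 0.02
b1, b2 = np.zeros(4 * dims), np.zeros(dims)

y = feed_forward(x, w1, b1, w2, b2)
print(y.shape)  # (2, 8, 128)
```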

In [None]:
class FeedForwardNetwork(layers.Layer):
    """Feed Forward Network.

    Args:
        dims (`int`): Dimension of the FFN.
        dropout (`float`): Dropout probability of FFN.
    """

    def __init__(self, dims: int, dropout: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        self.ffn = keras.Sequential(
            [
                layers.Dense(units=4 * dims, activation=tf.nn.gelu),
                layers.Dense(units=dims),
                layers.Dropout(rate=dropout),
            ]
        )
        self.add = layers.Add()
        self.layernorm = layers.LayerNormalization(epsilon=1e-5)

    def call(self, inputs: tf.Tensor):
        x = self.layernorm(inputs)
        x = self.add([inputs, self.ffn(x)])
        return x

## Define `BaseAttention` as base class for Attention modules

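This custom `keras.layers.Layer` wraps a `keras.layers.MultiHeadAttention` layer along with a few other components, giving all the attention blocks in our model a common core. It layer-normalizes the query, key and value separately, applies multi-head attention, adds the result back to the un-normalized query as a residual connection, and stores the attention scores for later visualization.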
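The following NumPy sketch mirrors what `BaseAttention` computes, under simplifying assumptions: a single head, no attention dropout, no projection biases, and random stand-in weights. It also shows how the same function acts as self-attention or cross-attention depending only on where the key and value come from. This is how `AttentionWithFFN` is reused in the cells that follow.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def base_attention(query, key, value, wq, wk, wv, wo):
    """Single-head sketch: pre-norm q/k/v, scaled dot-product attention, residual on raw query."""
    q = layer_norm(query) @ wq  # (B, Tq, key_dim)
    k = layer_norm(key) @ wk  # (B, Tk, key_dim)
    v = layer_norm(value) @ wv  # (B, Tk, key_dim)
    scores = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1]))  # (B, Tq, Tk)
    out = (scores @ v) @ wo  # back to the query's feature size
    return query + out, scores[:, None]  # add a heads axis: (B, 1, Tq, Tk)


rng = np.random.default_rng(0)
dims = key_dim = 128
wq, wk, wv = (rng.standard_normal((dims, key_dim)) * 0.02 for _ in range(3))
wo = rng.standard_normal((key_dim, dims)) * 0.02

fast = rng.standard_normal((2, 8, dims))  # e.g. one chunk of patch tokens
slow = rng.standard_normal((2, 8, dims))  # e.g. the latent state

# Self-attention: query, key and value all come from the same stream.
y_self, s_self = base_attention(fast, fast, fast, wq, wk, wv, wo)
# Cross-attention: key and value come from a different stream than the query.
y_cross, s_cross = base_attention(fast, slow, slow, wq, wk, wv, wo)
print(y_self.shape, s_cross.shape)  # (2, 8, 128) (2, 1, 8, 8)
```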

In [None]:
class BaseAttention(layers.Layer):
    """The base attention module.

    Args:
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for query and key.
        dropout (`float`): Dropout probability for Attention Module.

    """

    def __init__(self, num_heads: int, key_dim: int, dropout: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        self.mha = layers.MultiHeadAttention(num_heads, key_dim, dropout=dropout)
        self.q_layernorm = layers.LayerNormalization(epsilon=1e-5)
        self.k_layernorm = layers.LayerNormalization(epsilon=1e-5)
        self.v_layernorm = layers.LayerNormalization(epsilon=1e-5)
        self.add = layers.Add()

    def call(self, input_query: tf.Tensor, key: tf.Tensor, value: tf.Tensor):
        query = self.q_layernorm(input_query)
        key = self.k_layernorm(key)
        value = self.v_layernorm(value)
        (attn_outs, attn_scores) = self.mha(
            query=query,
            key=key,
            value=value,
            return_attention_scores=True,
        )
        self.attn_scores = attn_scores
        x = self.add([input_query, attn_outs])
        return x

## Define `Attention` module followed by a `FeedForwardNetwork`

This custom `keras.layers.Layer` implementation combines the `BaseAttention` and `FeedForwardNetwork` components into one block, which is used repeatedly within the model. The same block serves as self-attention when the query, key and value come from one stream, and as cross-attention when the key and value come from another stream.

In [None]:
class AttentionWithFFN(layers.Layer):
    """Attention module (self- or cross-attention) followed by an FFN.

    Args:
        ffn_dims (`int`): Number of units in FFN.
        ffn_dropout (`float`): Dropout probability for FFN.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for query and key.
        attn_dropout (`float`): Dropout probability for attention module.
    """

    def __init__(
        self,
        ffn_dims: int = 128,
        ffn_dropout: float = 0.1,
        num_heads: int = 4,
        key_dim: int = 256,
        attn_dropout: float = 0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.attention = BaseAttention(
            num_heads, key_dim, attn_dropout, name="BaseAttention"
        )
        self.ffn = FeedForwardNetwork(ffn_dims, ffn_dropout, name="FeedForward")

    def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor):
        x = self.attention(query, key, value)
        self.attn_scores = self.attention.attn_scores
        x = self.ffn(x)
        return x

## Define Custom Recurrent Neural Network cell layer for **Temporal Latent Bottleneck** and **Perceptual Module**

This custom cell, implemented as a `keras.layers.Layer`, holds the core logic of the model.
The cell's functionality can be divided into 2 parts:
- **Slow Stream (Temporal Latent Bottleneck):**
    - This module consists of a single `AttentionWithFFN` layer. It takes the output of the previous Slow Stream, an intermediate hidden representation (the *latent* in the Temporal Latent Bottleneck), as the Query, and the output of the latest Fast Stream as the Key and Value. This layer can also be construed as a *CrossAttention* layer.  

- **Fast Stream (Perceptual Module):**
    - This module consists of interleaved `AttentionWithFFN` layers: *n* *SelfAttention* layers applied sequentially, with a *CrossAttention* layer after every *r*-th Self-Attention layer (starting with the first).
    - The *SelfAttention* layers take the current Fast Stream (initially the input chunk) as the Query, Key and Value.
    - The *CrossAttention* layers take the output of the preceding Self-Attention layer as the Query, and the Slow Stream state from the Temporal Latent Bottleneck module as the Key and Value.
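To make the interleaving concrete, the plain-Python sketch below lists, for one recurrence step, which attention layer is applied and which stream supplies the query and the key/value. It follows the layer names and the `layer_idx % r == 0` rule used when the `CustomCell` constructor builds its layers. `cell_schedule` is a hypothetical helper for illustration only. With the default `num_layers=6` and `r=2`, each step runs 6 self-attention layers, 3 cross-attention layers and the TLB update.

```python
def cell_schedule(num_layers, r):
    """List the attention calls of one recurrence step as (layer_name, query, key/value)."""
    steps = []
    for layer_idx in range(num_layers):
        steps.append((f"PM_SelfAttentionFFN{layer_idx}", "fast", "fast"))
        if layer_idx % r == 0:
            # The fast stream reads from the latent (slow) state.
            steps.append((f"PM_CrossAttentionFFN{layer_idx}", "fast", "slow"))
    # After the Perceptual Module, the TLB updates the slow stream from the fast one.
    steps.append(("TLBM_CrossAttentionFFN", "slow", "fast"))
    return steps


for name, query, kv in cell_schedule(num_layers=6, r=2):
    print(f"{name:24s} query={query:4s} key/value={kv}")
```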

In [None]:
class CustomCell(layers.Layer):
    """Custom logic inside each recurrence.

    Args:
        chunk_size (`int`): Chunk size of the inputs.
        r (`int`): One Cross Attention per **r** Self Attention.
        num_layers (`int`): Number of layers in the Perceptual Model.
        ffn_dims (`int`): Number of units in FFN.
        ffn_dropout (`float`): Dropout probability for FFN.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for query and key.
        attn_dropout (`float`): Dropout probability for attention module.
    """

    def __init__(
        self,
        chunk_size,
        r=2,
        num_layers: int = 5,
        ffn_dims: int = 128,
        ffn_dropout: float = 0.1,
        num_heads: int = 4,
        key_dim: int = 256,
        attn_dropout: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.chunk_size = chunk_size
        self.r = r
        self.num_layers = num_layers
        self.ffn_dims = ffn_dims
        self.ffn_dropout = ffn_dropout
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.attn_dropout = attn_dropout

        # Required for any custom `RNNCell` implementations
        self.state_size = tf.TensorShape([chunk_size, ffn_dims])
        self.output_size = tf.TensorShape([chunk_size, ffn_dims])

        self.get_attn_scores = False
        self.attn_scores = []

        ########################################################################
        # Perceptual Module
        ########################################################################
        # One self-attention layer per `layer_idx`.
        self_attention_layers = list()
        # One cross-attention layer after every `r`-th self-attention layer.
        cross_attention_layers = list()
        for layer_idx in range(num_layers):
            self_attention_layers.append(
                AttentionWithFFN(
                    ffn_dims=ffn_dims,
                    ffn_dropout=ffn_dropout,
                    num_heads=num_heads,
                    key_dim=key_dim,
                    attn_dropout=attn_dropout,
                    name=f"PM_SelfAttentionFFN{layer_idx}",
                )
            )
            if layer_idx % r == 0:
                cross_attention_layers.append(
                    AttentionWithFFN(
                        ffn_dims=ffn_dims,
                        ffn_dropout=ffn_dropout,
                        num_heads=num_heads,
                        key_dim=key_dim,
                        attn_dropout=attn_dropout,
                        name=f"PM_CrossAttentionFFN{layer_idx}",
                    )
                )
        self.self_attention_layers = self_attention_layers
        self.cross_attention_layers = cross_attention_layers

        ########################################################################
        # Temporal Latent Bottleneck Module
        ########################################################################
        self.tlb_module = AttentionWithFFN(
            ffn_dims=ffn_dims,
            ffn_dropout=ffn_dropout,
            num_heads=num_heads,
            key_dim=key_dim,
            attn_dropout=attn_dropout,
            name="TLBM_CrossAttentionFFN",
        )

    def call(self, inputs, states, training=None):
        # inputs => (batch, chunk_size, dims)
        # states => [(batch, chunk_size, units)]

        slow_stream = states[0]
        fast_stream = inputs

        for layer_idx, self_attention in enumerate(self.self_attention_layers):
            # Self-attention within the current chunk.
            fast_stream = self_attention(
                query=fast_stream, key=fast_stream, value=fast_stream
            )

            if layer_idx % self.r == 0:
                # Cross-attention: the fast stream reads from the slow stream.
                cross_attention = self.cross_attention_layers[layer_idx // self.r]
                fast_stream = cross_attention(
                    query=fast_stream, key=slow_stream, value=slow_stream
                )

        slow_stream = self.tlb_module(
            query=slow_stream, key=fast_stream, value=fast_stream
        )

        if self.get_attn_scores:
            self.attn_scores.append(self.tlb_module.attn_scores)

        return fast_stream, [slow_stream]

## Define `ModelTrainer` to encapsulate full model

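Here, we wrap the full model so as to expose it for training. The patched and chunked images are fed through the RNN, whose final output is pooled and passed to a softmax classification head.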

In [None]:
class ModelTrainer(keras.Model):
    def __init__(self, patch_layer, custom_cell, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        self.patch_layer = patch_layer
        self.rnn = layers.RNN(custom_cell)
        self.gap = layers.GlobalAveragePooling1D()
        # Keep the output layer in float32 for numerical stability under mixed precision.
        self.head = layers.Dense(num_classes, activation="softmax", dtype="float32")

    def call(self, inputs):
        x = self.patch_layer(inputs)
        x = self.rnn(x)
        x = self.gap(x)
        outputs = self.head(x)
        return outputs
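Because `layers.RNN` is created with its default `return_sequences=False`, it returns only the cell output for the last chunk, i.e. the final fast stream of shape `(batch, chunk_size, embed_dim)`. Earlier chunks influence the prediction only through the slow stream carried in the state. The NumPy sketch below traces these shapes with a hypothetical `toy_step` standing in for `CustomCell.call` (no attention, just mixing the two streams). It is followed by average pooling over the chunk's 8 tokens and a softmax head with random weights.

```python
import numpy as np


def run_rnn(chunks, step, initial_state, return_sequences=False):
    """Minimal stand-in for keras.layers.RNN: apply `step` to each chunk along axis 1."""
    state = initial_state
    outputs = []
    for t in range(chunks.shape[1]):
        out, state = step(chunks[:, t], state)
        outputs.append(out)
    # Like layers.RNN, return only the last output unless return_sequences=True.
    return np.stack(outputs, axis=1) if return_sequences else outputs[-1]


def toy_step(fast, slow):
    """Hypothetical placeholder for CustomCell.call."""
    fast_out = fast + slow  # the fast stream reads from the latent state
    new_slow = 0.5 * slow + 0.5 * fast_out  # the latent state absorbs the chunk
    return fast_out, new_slow


B, num_chunks, chunk_size, dims, num_classes = 2, 18, 8, 128, 10
rng = np.random.default_rng(0)
chunks = rng.standard_normal((B, num_chunks, chunk_size, dims))

last = run_rnn(chunks, toy_step, np.zeros((B, chunk_size, dims)))  # (2, 8, 128)
pooled = last.mean(axis=1)  # GlobalAveragePooling1D over the chunk's tokens -> (2, 128)
logits = pooled @ (rng.standard_normal((dims, num_classes)) * 0.02)
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)  # softmax head -> (2, 10)
print(last.shape, pooled.shape, probs.shape)
```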

## Define Model Components for creating Training Pipeline

To begin training, we now define the components individually and pass them as arguments to our wrapper class, which prepares the final model for training. We define a `PatchEmbed` layer and the `CustomCell`-based RNN.

In [None]:
# Reset the global state managed by Keras (e.g. layer name counters); this
# also helps free memory held by previously created models.

keras.backend.clear_session()

# Patch & Embedding
patch_layer = PatchEmbed(
    image_size=(config["image_size"], config["image_size"]),
    patch_size=(config["patch_size"], config["patch_size"]),
    embed_dim=config["embed_dim"],
    chunk_size=config["chunk_size"],
)

# Recurrent Cell
cell = CustomCell(
    chunk_size=config["chunk_size"],
    r=config["r"],
    num_layers=config["num_layers"],
    ffn_dims=config["embed_dim"],
    ffn_dropout=config["ffn_drop"],
    num_heads=config["num_heads"],
    key_dim=config["embed_dim"],
    attn_dropout=config["attn_drop"],
)

# Model
model = ModelTrainer(patch_layer, cell)

## Compile the model and set up metrics and callbacks

We use the `AdamW` optimizer from `tf.keras.optimizers.experimental` (previously part of `tensorflow-addons`). It is a variant of the `tf.keras.optimizers.Adam` optimizer with decoupled weight decay, and a common choice for training attention-based models.  
For the loss function, we use `keras.losses.SparseCategoricalCrossentropy`, which computes the cross-entropy between the predicted class probabilities and the integer class labels. We also track accuracy on our dataset.  
We use the `keras.callbacks.TensorBoard` callback to log metrics at the end of every epoch so that they can be tracked in TensorBoard. [Learn more about TensorBoard here](https://www.tensorflow.org/tensorboard).
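To make the loss, metric and optimizer concrete, here is a NumPy sketch for a small batch of softmax outputs. The loss is the mean of `-log p[label]` over the batch. The AdamW step follows the textbook decoupled form: the weight decay is applied directly to the weights, scaled by the learning rate, separately from the Adam update. Keras' implementation may differ in minor details (for example, where `epsilon` enters), and the helper names are hypothetical.

```python
import numpy as np


def sparse_categorical_crossentropy(probs, labels, eps=1e-7):
    """Mean of -log(p[label]) for integer labels and softmax probabilities."""
    probs = np.clip(probs, eps, 1.0 - eps)
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.mean(np.log(picked)))


def sparse_categorical_accuracy(probs, labels):
    """Fraction of rows whose argmax equals the integer label."""
    return float(np.mean(np.argmax(probs, axis=-1) == labels))


def adamw_step(param, grad, m, v, t, lr=1e-4, wd=1e-4, b1=0.9, b2=0.999, eps=1e-7):
    """One AdamW update (textbook form): decoupled weight decay, then the Adam step."""
    param = param - lr * wd * param  # decay applied directly to the weights
    m = b1 * m + (1 - b1) * grad  # first-moment estimate
    v = b2 * v + (1 - b2) * grad**2  # second-moment estimate
    m_hat = m / (1 - b1**t)  # bias corrections
    v_hat = v / (1 - b2**t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v


probs = np.array([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]])
labels = np.array([0, 1])
print(sparse_categorical_crossentropy(probs, labels))  # mean of -log(0.7) and -log(0.3)
print(sparse_categorical_accuracy(probs, labels))  # 0.5
```

On the first step, Adam's bias correction makes the update roughly `lr` in magnitude regardless of the gradient's scale.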

In [None]:
optimizer = AdamW(
    learning_rate=config["learning_rate"], weight_decay=config["weight_decay"]
)
model.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(),
    # Name the metric "accuracy" so that `history.history["accuracy"]` exists.
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
)

callbacks = [
    keras.callbacks.TensorBoard(
        log_dir="tfboard-logs",
        histogram_freq=0,
        write_graph=True,
        write_images=True,
        update_freq="epoch",
        embeddings_freq=10,
    )
]

## Train the model with `model.fit()`

We pass the training dataset and train for `config["epochs"]` (100) epochs, evaluating on the validation dataset at the end of each epoch.

In [None]:
history = model.fit(
    train_ds,
    epochs=config["epochs"],
    validation_data=val_ds,
    callbacks=callbacks,
)

## Visualize training metrics

`model.fit()` returns a `History` object whose `history` attribute is a dictionary mapping each metric name to its value at every epoch. It only lives in memory, so it must be saved manually if you want to keep it.  
We now display the loss and accuracy curves for the training and validation sets.
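A minimal sketch for persisting these metrics with the standard-library `json` module follows. `save_history` and `load_history` are hypothetical helpers. The values below are made up purely to show the structure of the dictionary; with a real run you would pass `history.history`. Each value is converted with `float()` to be safe in case a metric is stored as a NumPy scalar.

```python
import json
import os
import tempfile


def save_history(history_dict, path):
    """Write a History.history-style dict (metric name -> per-epoch values) to JSON."""
    serializable = {k: [float(x) for x in v] for k, v in history_dict.items()}
    with open(path, "w") as f:
        json.dump(serializable, f)


def load_history(path):
    with open(path) as f:
        return json.load(f)


# Hypothetical values, only to illustrate the structure.
example = {
    "loss": [2.1, 1.7],
    "val_loss": [2.0, 1.8],
    "accuracy": [0.22, 0.38],
    "val_accuracy": [0.25, 0.35],
}
path = os.path.join(tempfile.mkdtemp(), "history.json")
save_history(example, path)
print(load_history(path) == example)  # True
```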

In [None]:
plt.plot(history.history["loss"], label="loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.legend()
plt.show()

plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.legend()
plt.show()

## Visualize attention maps from the Temporal Latent Bottleneck

We visualize the attention scores of the Temporal Latent Bottleneck. When the `get_attn_scores` flag is set, the cell appends the scores of its TLB cross-attention layer to its `attn_scores` list at every recurrence step, i.e. once per chunk. We reduce each chunk's scores to one value per patch token, concatenate the chunks, reshape the result back onto the patch grid and upsample it to the image resolution. Finally, we overlay the scores as a heatmap on the original image.
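The NumPy sketch below walks through the shapes involved. It uses random softmax-normalized matrices as stand-ins for the recorded TLB scores of the 18 chunks (default configuration, one head). Each chunk contributes 8 values, and the resulting 144 values fill the 12x12 patch grid in the same row-major order in which the patches were flattened. Each grid cell is then expanded to a 4x4 block of pixels. The sketch uses nearest-neighbour repetition in place of the bilinear `UpSampling2D` used in the notebook.

```python
import numpy as np

B, num_chunks, chunk_size, num_heads = 4, 18, 8, 1
grid, patch = 12, 4
rng = np.random.default_rng(0)


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


# Stand-ins for per-chunk TLB scores: (B, heads, latent queries, chunk tokens).
chunk_scores = [
    softmax(rng.standard_normal((B, num_heads, chunk_size, chunk_size)))
    for _ in range(num_chunks)
]


def score_to_viz(score):
    per_token = score.max(axis=-2)  # strongest attention each token receives
    return per_token.mean(axis=1)  # average over heads -> (B, chunk_size)


viz = np.concatenate([score_to_viz(s) for s in chunk_scores], axis=-1)  # (B, 144)
viz = viz.reshape(B, grid, grid)  # back onto the 12 x 12 patch grid
heat = np.repeat(np.repeat(viz, patch, axis=1), patch, axis=2)  # (B, 48, 48)
print(viz.shape, heat.shape)
```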

In [None]:
images, labels = next(iter(test_ds))

# Enable collection of the TLB attention scores, starting from an empty list.
model.rnn.cell.get_attn_scores = True
model.rnn.cell.attn_scores = []
outputs = model(images)

# Grab the list of chunk scores
list_chunk_scores = model.rnn.cell.attn_scores

In [None]:
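def score_to_viz(chunk_score):
    # chunk_score: (batch, num_heads, latent queries, chunk tokens)
    # For each token, keep the strongest attention it receives from any latent query.
    chunk_viz = tf.math.reduce_max(chunk_score, axis=-2)
    chunk_viz = tf.math.reduce_mean(chunk_viz, axis=1)  # average across heads
    return chunk_viz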

In [None]:
list_chunk_viz = [score_to_viz(x) for x in list_chunk_scores]

In [None]:
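# Drop the first entry: Keras' RNN loop calls the cell once before iterating
# over the chunks (to infer output shapes), which records an extra set of scores.
chunk_viz = tf.concat(list_chunk_viz[1:], axis=-1)
chunk_viz = tf.reshape(
    chunk_viz,
    (
        -1,
        config["image_size"] // config["patch_size"],
        config["image_size"] // config["patch_size"],
        1,
    ),
)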

In [None]:
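# Upsample each patch-grid cell back to `patch_size` x `patch_size` pixels.
upsampled_heat_map = layers.UpSampling2D(
    size=(config["patch_size"], config["patch_size"]),
    interpolation="bilinear",
    dtype="float32",
)(chunk_viz)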

In [None]:
index = 60
orig_image = images[index]
overlay_image = upsampled_heat_map[index, ..., 0]

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

ax[0].imshow(orig_image)
ax[0].set_title("Original:")
ax[0].axis("off")

image = ax[1].imshow(orig_image)
ax[1].imshow(
    overlay_image,
    cmap="inferno",
    alpha=0.6,
    extent=image.get_extent(),
)
ax[1].set_title("TLB Attention:")

plt.show()

## Conclusion

This example demonstrated an implementation of the Temporal Latent Bottleneck mechanism, which uses a mix of attention mechanisms to address the lack of inductive bias in Transformers. It highlights compressing and storing historical states in a Temporal Latent Bottleneck, regularly updated by a Perceptual Module, as an effective way to do so. In the original paper, the authors conducted extensive experiments across modalities, ranging from supervised image classification to reinforcement learning.  
While we have only shown how to apply this mechanism to image classification, it can be extended to other modalities with minimal changes.