# Image Captioning

**Author:** [A_K_Nain](https://twitter.com/A_K_Nain)<br>
**Date created:** 2021/05/29<br>
**Last modified:** 2021/10/31<br>
**Description:** Implement an image captioning model using a CNN and a Transformer.

## Setup

In [1]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import re
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import keras
from keras import layers
from keras.applications import efficientnet
from keras.layers import TextVectorization

keras.utils.set_random_seed(111)

## Download the dataset

We will be using the Flickr8K dataset for this tutorial. This dataset comprises over
8,000 images, each paired with five different captions.

In [2]:
!wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
!wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
!unzip -qq Flickr8k_Dataset.zip
!unzip -qq Flickr8k_text.zip
!rm Flickr8k_Dataset.zip Flickr8k_text.zip

In [3]:
# Path to the images
IMAGES_PATH = "Flicker8k_Dataset"

# Desired image dimensions
IMAGE_SIZE = (299, 299)

# Vocabulary size
VOCAB_SIZE = 10000

# Fixed length allowed for any sequence
SEQ_LENGTH = 25

# Dimension for the image embeddings and token embeddings
EMBED_DIM = 512

# Per-layer units in the feed-forward network
FF_DIM = 512

# Other training parameters
BATCH_SIZE = 64
EPOCHS = 100
AUTOTUNE = tf.data.AUTOTUNE

## Preparing the dataset

In [4]:
def load_captions_data(filename):
    """Loads captions (text) data and maps them to corresponding images.

    Args:
        filename: Path to the text file containing caption data.

    Returns:
        caption_mapping: Dictionary mapping image names and the corresponding captions
        text_data: List containing all the available captions
    """

    with open(filename) as caption_file:
        caption_data = caption_file.readlines()
        caption_mapping = {}
        text_data = []
        images_to_skip = set()

        for line in caption_data:
            line = line.rstrip("\n")
            # Image name and captions are separated using a tab
            img_name, caption = line.split("\t")

            # Each image is repeated five times for the five different captions.
            # Each image name has a suffix `#(caption_number)`
            img_name = img_name.split("#")[0]
            img_name = os.path.join(IMAGES_PATH, img_name.strip())

            # Remove captions that are either too short or too long
            tokens = caption.strip().split()

            if len(tokens) < 5 or len(tokens) > SEQ_LENGTH:
                images_to_skip.add(img_name)
                continue

            if img_name.endswith("jpg") and img_name not in images_to_skip:
                # We will add a start and an end token to each caption
                caption = "<start> " + caption.strip() + " <end>"
                text_data.append(caption)

                if img_name in caption_mapping:
                    caption_mapping[img_name].append(caption)
                else:
                    caption_mapping[img_name] = [caption]

        for img_name in images_to_skip:
            if img_name in caption_mapping:
                del caption_mapping[img_name]

        return caption_mapping, text_data


def train_val_split(caption_data, train_size=0.8, shuffle=True):
    """Split the captioning dataset into train and validation sets.

    Args:
        caption_data (dict): Dictionary containing the mapped caption data
        train_size (float): Fraction of the full dataset to use as training data
        shuffle (bool): Whether to shuffle the dataset before splitting

    Returns:
        Training and validation datasets as two separate dicts
    """

    # 1. Get the list of all image names
    all_images = list(caption_data.keys())

    # 2. Shuffle if necessary
    if shuffle:
        np.random.shuffle(all_images)

    # 3. Split into training and validation sets
    train_size = int(len(caption_data) * train_size)

    training_data = {
        img_name: caption_data[img_name] for img_name in all_images[:train_size]
    }
    validation_data = {
        img_name: caption_data[img_name] for img_name in all_images[train_size:]
    }

    # 4. Return the splits
    return training_data, validation_data


# Load the dataset
captions_mapping, text_data = load_captions_data("Flickr8k.token.txt")

# Split the dataset into training and validation sets
train_data, valid_data = train_val_split(captions_mapping)
print("Number of training samples: ", len(train_data))
print("Number of validation samples: ", len(valid_data))

Number of training samples:  6114
Number of validation samples:  1529
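The filtering rule in `load_captions_data` drops an image entirely if *any* of its captions has fewer than 5 or more than `SEQ_LENGTH` whitespace-separated tokens, which is what keeps every surviving image at exactly five captions. The sketch below reproduces only that rule on a few hypothetical in-memory lines (the file names and captions are invented for illustration), without the directory prefix or the `.jpg` check, so the behaviour can be checked without downloading the dataset. `filter_captions` is a hypothetical helper, not part of the notebook.

```python
SEQ_LENGTH = 25  # same limit as the notebook


def filter_captions(lines, min_tokens=5, max_tokens=SEQ_LENGTH):
    """Group 'name#k<TAB>caption' lines by image and drop any image that has
    at least one caption outside [min_tokens, max_tokens] tokens."""
    mapping, to_skip = {}, set()
    for line in lines:
        name, caption = line.rstrip("\n").split("\t")
        name = name.split("#")[0].strip()
        n_tokens = len(caption.strip().split())
        if n_tokens < min_tokens or n_tokens > max_tokens:
            to_skip.add(name)
            continue
        mapping.setdefault(name, []).append("<start> " + caption.strip() + " <end>")
    # Remove skipped images even if some of their captions were already added
    return {name: caps for name, caps in mapping.items() if name not in to_skip}


# Hypothetical lines: image "b.jpg" has one caption that is too short.
lines = [
    "a.jpg#0\tA dog runs across the grass .",
    "a.jpg#1\tA brown dog is running outside .",
    "b.jpg#0\tTwo children play in the park .",
    "b.jpg#1\tKids .",
]
# Only "a.jpg" survives; "b.jpg" is dropped because of its 2-token caption.
print(filter_captions(lines))
```

One side effect worth knowing: in `load_captions_data`, captions appended to `text_data` before an image is marked for skipping are not removed afterwards, so they still contribute to the vocabulary built by `vectorization.adapt`.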


## Vectorizing the text data

We'll use the `TextVectorization` layer to vectorize the text data,
that is to say, to turn the
original strings into integer sequences where each integer represents the index of
a word in a vocabulary. We will use a custom string standardization scheme
(strip punctuation characters except `<` and `>`) and the default
splitting scheme (split on whitespace).
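To make the standardization and indexing concrete, here is a pure-Python approximation that does not need TensorFlow. It assumes `TextVectorization`'s convention of reserving index 0 for padding and index 1 for the out-of-vocabulary token `[UNK]` (both counted within `max_tokens`). The frequency-ranked vocabulary is a simplification: tie-breaking between equally frequent words may differ from the real layer. The helper names are hypothetical.

```python
import re
from collections import Counter

# Punctuation to strip: the notebook's set minus "<" and ">", so that the
# "<start>" and "<end>" markers survive standardization.
STRIP_CHARS = "!\"#$%&'()*+,-./:;=?@[\\]^_`{|}~"
PATTERN = re.compile("[%s]" % re.escape(STRIP_CHARS))


def standardize(text):
    """Pure-Python mirror of custom_standardization: lowercase, drop punctuation."""
    return PATTERN.sub("", text.lower())


def build_vocab(texts, max_tokens):
    """Index 0 = padding (""), index 1 = out-of-vocabulary ("[UNK]"),
    then the most frequent tokens (simplified ranking)."""
    counts = Counter(tok for text in texts for tok in standardize(text).split())
    return ["", "[UNK]"] + [tok for tok, _ in counts.most_common(max_tokens - 2)]


def vectorize(text, vocab, seq_length):
    """Map tokens to indices (unknown -> 1), then pad with 0s or truncate."""
    index = {tok: i for i, tok in enumerate(vocab)}
    ids = [index.get(tok, 1) for tok in standardize(text).split()]
    return (ids + [0] * seq_length)[:seq_length]


captions = ["<start> A dog runs. <end>", "<start> A dog, jumping! <end>"]
vocab = build_vocab(captions, max_tokens=10)
print(standardize(captions[1]))  # <start> a dog jumping <end>
print(vectorize("<start> A cat runs <end>", vocab, seq_length=8))
```

The unseen word "cat" maps to 1 (`[UNK]`) and the tail is padded with 0, which is why the model later treats token 0 as padding when building its masks.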

In [5]:
def custom_standardization(input_string):
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")


strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
strip_chars = strip_chars.replace("<", "")
strip_chars = strip_chars.replace(">", "")

vectorization = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
    standardize=custom_standardization,
)
vectorization.adapt(text_data)

# Data augmentation for image data
image_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.2),
        layers.RandomContrast(0.3),
    ]
)

## Building a `tf.data.Dataset` pipeline for training

We will generate pairs of images and corresponding captions using a `tf.data.Dataset` object.
The pipeline consists of two steps:

1. Read the image from disk and resize it to `IMAGE_SIZE`
2. Tokenize all five captions corresponding to the image

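Because every image kept by `load_captions_data` has exactly five captions, `from_tensor_slices` can stack the caption lists into a rectangular tensor. The back-of-the-envelope sketch below (plain Python, no TensorFlow) computes the batch shapes this pipeline should produce and the number of batches per epoch, using the split sizes printed above. `pipeline_summary` is a hypothetical helper.

```python
import math

IMAGE_SIZE = (299, 299)
SEQ_LENGTH = 25
BATCH_SIZE = 64
CAPTIONS_PER_IMAGE = 5


def pipeline_summary(num_images, batch_size=BATCH_SIZE):
    """Shapes of one full batch and the number of batches per epoch.

    `batch()` keeps the final partial batch (drop_remainder=False by default),
    so the batch count is rounded up.
    """
    image_batch = (batch_size, *IMAGE_SIZE, 3)
    caption_batch = (batch_size, CAPTIONS_PER_IMAGE, SEQ_LENGTH)
    return image_batch, caption_batch, math.ceil(num_images / batch_size)


print(pipeline_summary(6114))  # training split size printed above
print(pipeline_summary(1529))  # validation split size printed above
```

The 96 training batches match the `96/96` progress bars in the training log. Note also that `shuffle(BATCH_SIZE * 8)` uses a 512-element buffer, which only partially shuffles the 6,114 training examples.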
In [6]:
def decode_and_resize(img_path):
    img = tf.io.read_file(img_path)
    img = tf.image.decode_jpeg(img, channels=3)
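    img = tf.image.resize(img, IMAGE_SIZE)
    # `resize` already returns float32 values in [0, 255]; this float-to-float
    # conversion is a plain cast and does not rescale to [0, 1]
    img = tf.image.convert_image_dtype(img, tf.float32)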
    return img


def process_input(img_path, captions):
    return decode_and_resize(img_path), vectorization(captions)


def make_dataset(images, captions):
    dataset = tf.data.Dataset.from_tensor_slices((images, captions))
    dataset = dataset.shuffle(BATCH_SIZE * 8)
    dataset = dataset.map(process_input, num_parallel_calls=AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)

    return dataset


# Pass the list of images and the list of corresponding captions
train_dataset = make_dataset(list(train_data.keys()), list(train_data.values()))

valid_dataset = make_dataset(list(valid_data.keys()), list(valid_data.values()))


## Building the model

Our image captioning architecture consists of three models:

1. A CNN: used to extract the image features (in the ensemble below, five pretrained
   backbones whose pooled features are concatenated)
2. A TransformerEncoder: the extracted image features are passed to a Transformer-based
   encoder that generates a new representation of the inputs
3. A TransformerDecoder: this model takes the encoder output and the text data
   (sequences) as inputs and learns to generate the caption

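In this notebook's ensemble variant, each backbone's final feature map is globally average-pooled, so the "image" handed to the encoder is a single token. The sketch below traces the tensor shapes through the three stages. The per-backbone channel counts are the standard final-layer widths of the `include_top=False` Keras models (VGG16: 512, ResNet50: 2048, MobileNetV2 at alpha 1.0: 1280, InceptionV3: 2048, EfficientNetB0: 1280); `trace_shapes` is a hypothetical helper, not part of the model.

```python
EMBED_DIM = 512
SEQ_LENGTH = 25
VOCAB_SIZE = 10000

# Channels of the last feature map of each backbone (include_top=False).
BACKBONE_CHANNELS = {
    "VGG16": 512,
    "ResNet50": 2048,
    "MobileNetV2": 1280,
    "InceptionV3": 2048,
    "EfficientNetB0": 1280,
}


def trace_shapes(batch_size):
    """Return (stage, shape) pairs for one forward pass on one caption."""
    concat_features = sum(BACKBONE_CHANNELS.values())
    decoder_len = SEQ_LENGTH - 1  # captions are shifted by one token
    return [
        ("images", (batch_size, 299, 299, 3)),
        ("cnn_model output (pooled + concatenated)", (batch_size, 1, concat_features)),
        ("encoder output (Dense projection)", (batch_size, 1, EMBED_DIM)),
        ("decoder input tokens", (batch_size, decoder_len)),
        ("decoder output (softmax over vocab)", (batch_size, decoder_len, VOCAB_SIZE)),
    ]


for stage, shape in trace_shapes(batch_size=64):
    print(f"{stage:45s} {shape}")
```

With a single image token, the encoder's self-attention (and the decoder's cross-attention) has only one position to attend to, so its attention weights are trivially 1; what the ensemble contributes is the richer 7,168-dimensional pooled feature vector rather than attention over image regions.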
## Building the model (advanced ensemble)

In [7]:
import tensorflow as tf

# Use the same standalone Keras package as the Setup cell, rather than
# mixing it with `tf.keras`
import keras
from keras import layers
from keras.applications import (
    VGG16,
    ResNet50,
    MobileNetV2,
    InceptionV3,
    EfficientNetB0,
)

############################################################
# Updated CNN model function using multiple base backbones #
############################################################
def get_cnn_model() -> keras.Model:
    """
    Creates a CNN feature-extraction model using five different base backbones
    (VGG16, ResNet50, MobileNetV2, InceptionV3, and EfficientNetB0).
    The outputs from these models are concatenated to form a unified image
    embedding.

    Returns:
        keras.Model: A model that takes an image as input and outputs a combined
                     feature representation.
    """
    # Create a common input layer for all base models
    input_tensor = layers.Input(shape=(*IMAGE_SIZE, 3), name="image_input")

    ###################################################################
    # 1. Create each base model (with pretrained weights, no top) and #
    #    freeze them so they don't get updated during training        #
    ###################################################################
    # VGG16
    vgg16_base = VGG16(
        weights="imagenet",
        include_top=False,
        input_tensor=input_tensor,
    )
    for layer in vgg16_base.layers:
        layer.trainable = False

    # ResNet50
    resnet_base = ResNet50(
        weights="imagenet",
        include_top=False,
        input_tensor=input_tensor,
    )
    for layer in resnet_base.layers:
        layer.trainable = False

    # MobileNetV2
    mobilenet_base = MobileNetV2(
        weights="imagenet",
        include_top=False,
        input_tensor=input_tensor,
    )
    for layer in mobilenet_base.layers:
        layer.trainable = False

    # InceptionV3
    inception_base = InceptionV3(
        weights="imagenet",
        include_top=False,
        input_tensor=input_tensor,
    )
    for layer in inception_base.layers:
        layer.trainable = False

    # EfficientNetB0
    efficientnet_base = EfficientNetB0(
        weights="imagenet",
        include_top=False,
        input_tensor=input_tensor,
    )
    for layer in efficientnet_base.layers:
        layer.trainable = False

    ######################################################
    # 2. Concatenate all features along the last dimension
    ######################################################
    # We will globally average pool each model's output to get a (batch, features) shape,
    # then expand dims to (batch, 1, features) and concatenate them.

    vgg16_pooled = layers.GlobalAveragePooling2D()(vgg16_base.output)
    vgg16_pooled = layers.Reshape((1, -1))(vgg16_pooled)

    resnet_pooled = layers.GlobalAveragePooling2D()(resnet_base.output)
    resnet_pooled = layers.Reshape((1, -1))(resnet_pooled)

    mobilenet_pooled = layers.GlobalAveragePooling2D()(mobilenet_base.output)
    mobilenet_pooled = layers.Reshape((1, -1))(mobilenet_pooled)

    inception_pooled = layers.GlobalAveragePooling2D()(inception_base.output)
    inception_pooled = layers.Reshape((1, -1))(inception_pooled)

    efficientnet_pooled = layers.GlobalAveragePooling2D()(efficientnet_base.output)
    efficientnet_pooled = layers.Reshape((1, -1))(efficientnet_pooled)

    concatenated_out = layers.Concatenate(axis=-1)(
        [
            vgg16_pooled,
            resnet_pooled,
            mobilenet_pooled,
            inception_pooled,
            efficientnet_pooled,
        ]
    )
    # Shape is (batch, 1, sum_of_features)

    #########################################################################
    # 3. Create a Model that outputs the concatenated embeddings            #
    #########################################################################
    cnn_model = keras.Model(inputs=input_tensor, outputs=concatenated_out)
    return cnn_model


##########################################
#          Transformer components        #
##########################################
class TransformerEncoderBlock(layers.Layer):
    def __init__(self, embed_dim: int, dense_dim: int, num_heads: int, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.0
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.dense_1 = layers.Dense(embed_dim, activation="relu")

    def call(self, inputs: tf.Tensor, training: bool, mask: tf.Tensor = None) -> tf.Tensor:
        """
        Forward pass of the Transformer encoder block.

        Args:
            inputs (tf.Tensor): The input tensor with shape (batch, sequence, features).
            training (bool): Whether the model is in training mode.
            mask (tf.Tensor, optional): A boolean mask for padding. Defaults to None.

        Returns:
            tf.Tensor: Output tensor of shape (batch, sequence, embed_dim).
        """
        # Normalize and feed to a dense layer
        inputs = self.layernorm_1(inputs)
        inputs = self.dense_1(inputs)

        # Self-attention
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=None,
            training=training,
        )
        # Add & norm
        out_1 = self.layernorm_2(inputs + attention_output_1)
        return out_1


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Forward pass for positional embeddings.

        Args:
            inputs (tf.Tensor): Token IDs with shape (batch, sequence_length).

        Returns:
            tf.Tensor: Token embeddings + positional embeddings
        """
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_tokens = embedded_tokens * self.embed_scale
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs: tf.Tensor, mask=None) -> tf.Tensor:
        """
        Compute a mask for padding tokens (assumed to be '0').

        Args:
            inputs (tf.Tensor): The input token IDs.
            mask: Unused.

        Returns:
            tf.Tensor: Boolean mask where False means padding (token=0).
        """
        return tf.math.not_equal(inputs, 0)


class TransformerDecoderBlock(layers.Layer):
    def __init__(self, embed_dim: int, ff_dim: int, num_heads: int, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
        self.ffn_layer_2 = layers.Dense(embed_dim)

        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()

        self.embedding = PositionalEmbedding(
            embed_dim=EMBED_DIM,
            sequence_length=SEQ_LENGTH,
            vocab_size=VOCAB_SIZE,
        )
        self.out = layers.Dense(VOCAB_SIZE, activation="softmax")

        self.dropout_1 = layers.Dropout(0.3)
        self.dropout_2 = layers.Dropout(0.5)
        self.supports_masking = True

    def call(
        self, inputs: tf.Tensor, encoder_outputs: tf.Tensor, training: bool, mask: tf.Tensor = None
    ) -> tf.Tensor:
        """
        Forward pass of the Transformer decoder block.

        Args:
            inputs (tf.Tensor): The token IDs for the decoder, shape (batch, seq_length).
            encoder_outputs (tf.Tensor): Outputs from the encoder, shape (batch, 1, features).
            training (bool): Whether the model is in training mode.
            mask (tf.Tensor, optional): Boolean mask. Defaults to None.

        Returns:
            tf.Tensor: Softmax probabilities over the vocabulary, shape (batch, seq_length, vocab_size).
        """
        # Token + positional embeddings
        inputs = self.embedding(inputs)
        causal_mask = self.get_causal_attention_mask(inputs)

        if mask is not None:
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            combined_mask = tf.minimum(combined_mask, causal_mask)
        else:
            combined_mask = causal_mask
            padding_mask = None

        # Decoder self-attention
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=combined_mask,
            training=training,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        # Cross-attention with encoder outputs
        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
            training=training,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        # Feed-forward network
        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Builds a causal (future-masking) attention mask to use in the
        self-attention mechanism.

        Args:
            inputs (tf.Tensor): The input tensor, shape (batch, seq_length, features).

        Returns:
            tf.Tensor: A causal mask of shape (batch, seq_length, seq_length).
        """
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, sequence_length, sequence_length))
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], axis=0
        )
        return tf.tile(mask, mult)


#####################################################
#       The main Image Captioning model class       #
#####################################################
class ImageCaptioningModel(keras.Model):
    def __init__(
        self,
        cnn_model: keras.Model,
        encoder: TransformerEncoderBlock,
        decoder: TransformerDecoderBlock,
        num_captions_per_image: int = 5,
        image_aug: tf.keras.Sequential = None,
    ):
        """
        A custom model to tie together the CNN-based feature extractor,
        Transformer encoder, and Transformer decoder.

        Args:
            cnn_model (keras.Model): Pretrained CNN feature extractor.
            encoder (TransformerEncoderBlock): Transformer encoder.
            decoder (TransformerDecoderBlock): Transformer decoder.
            num_captions_per_image (int, optional): Number of captions per image. Defaults to 5.
            image_aug (tf.keras.Sequential, optional): Optional image augmentation pipeline. Defaults to None.
        """
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.acc_tracker = keras.metrics.Mean(name="accuracy")
        self.num_captions_per_image = num_captions_per_image
        self.image_aug = image_aug

    def apply_augmentation(self, images: tf.Tensor) -> tf.Tensor:
        """
        Applies random data augmentation transformations (e.g., flips, rotations, zoom)
        to the input images. The captions remain unchanged.

        Args:
            images (tf.Tensor): A batch of images of shape (batch, height, width, channels).

        Returns:
            tf.Tensor: A batch of augmented images, same shape as input.
        """
        if self.image_aug:
            images = self.image_aug(images)
        return images

    def calculate_loss(
        self, y_true: tf.Tensor, y_pred: tf.Tensor, mask: tf.Tensor
    ) -> tf.Tensor:
        """
        Applies the model's loss function, taking the mask into account.

        Args:
            y_true (tf.Tensor): Ground-truth token IDs, shape (batch, seq_length).
            y_pred (tf.Tensor): Predicted probabilities (softmax output), shape (batch, seq_length, vocab_size).
            mask (tf.Tensor): Boolean mask for non-padding tokens, shape (batch, seq_length).

        Returns:
            tf.Tensor: Scalar loss value.
        """
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(
        self, y_true: tf.Tensor, y_pred: tf.Tensor, mask: tf.Tensor
    ) -> tf.Tensor:
        """
        Computes token-level accuracy under a given mask.

        Args:
            y_true (tf.Tensor): Ground-truth token IDs, shape (batch, seq_length).
            y_pred (tf.Tensor): Predicted probabilities (softmax output), shape (batch, seq_length, vocab_size).
            mask (tf.Tensor): Boolean mask for non-padding tokens, shape (batch, seq_length).

        Returns:
            tf.Tensor: Scalar accuracy value.
        """
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def _compute_caption_loss_and_acc(
        self, img_embed: tf.Tensor, batch_seq: tf.Tensor, training: bool = True
    ) -> tuple[tf.Tensor, tf.Tensor]:
        """
        Given a batch of image embeddings and corresponding caption sequences,
        computes the loss and accuracy for one forward pass.

        Args:
            img_embed (tf.Tensor): Image embeddings from the CNN, shape (batch, 1, features).
            batch_seq (tf.Tensor): Caption token IDs, shape (batch, seq_length).
            training (bool): Whether we are in training mode.

        Returns:
            tuple[tf.Tensor, tf.Tensor]: Loss and accuracy for the given batch.
        """
        encoder_out = self.encoder(img_embed, training=training)
        batch_seq_inp = batch_seq[:, :-1]  # all but last token
        batch_seq_true = batch_seq[:, 1:]  # all but first token
        mask = tf.math.not_equal(batch_seq_true, 0)

        batch_seq_pred = self.decoder(
            batch_seq_inp, encoder_out, training=training, mask=mask
        )
        loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
        acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
        return loss, acc

    def train_step(self, batch_data: tuple[tf.Tensor, tf.Tensor]) -> dict:
        """
        Defines the forward + backward pass under the Keras .fit() loop.

        Args:
            batch_data (tuple[tf.Tensor, tf.Tensor]): A tuple of (images, token_sequences),
                where 'images' shape = (batch_size, height, width, channels)
                and 'token_sequences' shape = (batch_size, num_captions_per_image, seq_length).

        Returns:
            dict: A dictionary of {'loss': ..., 'acc': ...} metrics.
        """
        batch_img, batch_seq = batch_data
        batch_loss = 0.0
        batch_acc = 0.0

        # 1. Augment images if available (the captions remain the same)
        batch_img = self.apply_augmentation(batch_img)

        # 2. Get image embeddings
        img_embed = self.cnn_model(batch_img)

        # 3. Train on each of the `num_captions_per_image` captions in turn
        #    (one optimizer update per caption)
        for i in range(self.num_captions_per_image):
            with tf.GradientTape() as tape:
                loss, acc = self._compute_caption_loss_and_acc(
                    img_embed, batch_seq[:, i, :], training=True
                )
            batch_loss += loss
            batch_acc += acc

            # Update only encoder & decoder trainable variables
            train_vars = (
                self.encoder.trainable_variables + self.decoder.trainable_variables
            )
            grads = tape.gradient(loss, train_vars)
            self.optimizer.apply_gradients(zip(grads, train_vars))

        # Average the accuracy over the number of captions
        # (the reported loss is the sum of the per-caption losses)
        batch_acc /= float(self.num_captions_per_image)

        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    def test_step(self, batch_data: tuple[tf.Tensor, tf.Tensor]) -> dict:
        """
        Defines the forward pass for validation/testing under the Keras .fit() loop.

        Args:
            batch_data (tuple[tf.Tensor, tf.Tensor]): A tuple of (images, token_sequences).

        Returns:
            dict: A dictionary of {'loss': ..., 'acc': ...} metrics.
        """
        batch_img, batch_seq = batch_data
        batch_loss = 0.0
        batch_acc = 0.0

        # 1. Get image embeddings
        img_embed = self.cnn_model(batch_img)

        # 2. Pass each of the 'num_captions_per_image' captions
        for i in range(self.num_captions_per_image):
            loss, acc = self._compute_caption_loss_and_acc(
                img_embed, batch_seq[:, i, :], training=False
            )
            batch_loss += loss
            batch_acc += acc

        # Average the accuracy (the loss stays summed over captions)
        batch_acc /= float(self.num_captions_per_image)

        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    @property
    def metrics(self) -> list:
        """
        Lists the metrics so that Keras resets them automatically
        at the start of each epoch in model.fit().

        Returns:
            list: A list containing the loss and accuracy trackers.
        """
        return [self.loss_tracker, self.acc_tracker]


###############################################
# Instantiate your updated image caption model
###############################################
# 1. Create the multi-backbone CNN model
cnn_model = get_cnn_model()

# 2. Create transformer encoder & decoder
encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM, dense_dim=FF_DIM, num_heads=1)
decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)

# 3. (Optional) image augmentation (random flips, rotations, zoom). This replaces
#    the `image_augmentation` pipeline defined in the vectorization cell.
image_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.1),
    ],
    name="image_augmentation",
)

# 4. Build the final image captioning model with data augmentation
caption_model = ImageCaptioningModel(
    cnn_model=cnn_model,
    encoder=encoder,
    decoder=decoder,
    image_aug=image_augmentation,
)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
58889256/58889256 ━━━━━━━━━━━━━━━━━━━━ 3s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94765736/94765736 ━━━━━━━━━━━━━━━━━━━━ 4s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
9406464/9406464 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
87910968/87910968 ━━━━━━━━━━━━━━━━━━━━ 4s 0us/step
Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5
16705208/16705208 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step

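The decoder is trained with teacher forcing: `_compute_caption_loss_and_acc` feeds tokens `[:-1]` and scores the predictions against tokens `[1:]`, while `TransformerDecoderBlock` combines a causal mask with a padding mask. The NumPy sketch below reproduces that logic on one hypothetical token sequence, together with the masked cross-entropy and accuracy computed by `calculate_loss` and `calculate_accuracy`. The random probabilities only stand in for decoder outputs; they are not from the trained model.

```python
import numpy as np

# Hypothetical caption ids: <start>=2, three words, <end>=3, then padding (0).
seq = np.array([[2, 7, 9, 4, 3, 0, 0]])
inp, target = seq[:, :-1], seq[:, 1:]  # teacher-forcing shift
mask = target != 0                     # same rule as the notebook

T = inp.shape[1]
i = np.arange(T)[:, None]
j = np.arange(T)[None, :]
causal = (i >= j).astype(np.int32)     # position i may attend to positions j <= i
# Combine: a key position is usable only if it is in the past AND not padding
combined = np.minimum(mask[:, None, :].astype(np.int32), causal[None])  # (1, T, T)


def masked_cross_entropy(target, probs, mask):
    """Mean negative log-likelihood over non-padding positions only."""
    nll = -np.log(np.take_along_axis(probs, target[..., None], axis=-1)[..., 0])
    return float((nll * mask).sum() / mask.sum())


def masked_accuracy(target, probs, mask):
    """Fraction of non-padding positions where the argmax matches the target."""
    hits = (probs.argmax(axis=-1) == target) & mask
    return float(hits.sum() / mask.sum())


rng = np.random.default_rng(0)
logits = rng.normal(size=(1, T, 12))  # toy vocabulary of 12 tokens
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
print(combined[0])
print(masked_cross_entropy(target, probs, mask), masked_accuracy(target, probs, mask))
```

Dividing by `mask.sum()` rather than by the sequence length keeps short captions, which carry a lot of padding, from producing artificially small losses.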

## Model training

In [None]:
# Define the loss function
cross_entropy = keras.losses.SparseCategoricalCrossentropy(
    from_logits=False,
    reduction=None,
)

# EarlyStopping criteria
early_stopping = keras.callbacks.EarlyStopping(patience=50, restore_best_weights=True)


# Learning Rate Scheduler for the optimizer
class LRSchedule(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, post_warmup_learning_rate, warmup_steps):
        super().__init__()
        self.post_warmup_learning_rate = post_warmup_learning_rate
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        global_step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        warmup_progress = global_step / warmup_steps
        warmup_learning_rate = self.post_warmup_learning_rate * warmup_progress
        return tf.cond(
            global_step < warmup_steps,
            lambda: warmup_learning_rate,
            lambda: self.post_warmup_learning_rate,
        )


# Create a learning rate schedule
num_train_steps = len(train_dataset) * EPOCHS
num_warmup_steps = num_train_steps // 15
lr_schedule = LRSchedule(post_warmup_learning_rate=1e-3, warmup_steps=num_warmup_steps)

# Compile the model
caption_model.compile(optimizer=keras.optimizers.Adam(lr_schedule), loss=cross_entropy)

# Fit the model
caption_model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=valid_dataset,
    callbacks=[early_stopping],
)

Epoch 1/100
96/96 ━━━━━━━━━━━━━━━━━━━━ 353s 3s/step - acc: 0.1797 - loss: 31.2147 - val_acc: 0.3352 - val_loss: 18.1499
Epoch 2/100
96/96 ━━━━━━━━━━━━━━━━━━━━ 244s 3s/step - acc: 0.3394 - loss: 17.9645 - val_acc: 0.3491 - val_loss: 17.0664
Epoch 3/100
96/96 ━━━━━━━━━━━━━━━━━━━━ 0s 2s/step - acc: 0.3605 - loss: 16.4004

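`LRSchedule` is a linear warmup followed by a constant rate. With the values used above, `num_train_steps = 96 * 100 = 9600` (96 batches per epoch, as the progress bar shows) and `num_warmup_steps = 9600 // 15 = 640`. A plain-Python mirror of the schedule (the name `warmup_lr` is hypothetical) makes the values easy to check:

```python
def warmup_lr(step, post_warmup_learning_rate=1e-3, warmup_steps=640):
    """Linear ramp from 0 to post_warmup_learning_rate, then constant
    (mirrors LRSchedule.__call__ without TensorFlow)."""
    if step < warmup_steps:
        return post_warmup_learning_rate * step / warmup_steps
    return post_warmup_learning_rate


num_train_steps = 96 * 100  # batches per epoch * EPOCHS
num_warmup_steps = num_train_steps // 15
for step in (0, 320, num_warmup_steps, 5000):
    print(step, warmup_lr(step, warmup_steps=num_warmup_steps))
```

Because `train_step` calls `apply_gradients` once per caption, the optimizer's step counter (which is what the schedule receives) advances five times per batch. The 640-step warmup therefore completes after about 128 batches (roughly 1.3 epochs), not after 640 batches; if a warmup measured in batches is intended, multiplying `num_warmup_steps` by `num_captions_per_image` would be one way to get it.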
## Check sample predictions

In [None]:
vocab = vectorization.get_vocabulary()
index_lookup = dict(zip(range(len(vocab)), vocab))
max_decoded_sentence_length = SEQ_LENGTH - 1
valid_images = list(valid_data.keys())


def generate_caption():
    # Select a random image from the validation dataset
    sample_img = np.random.choice(valid_images)

    # Read the image from the disk
    sample_img = decode_and_resize(sample_img)
    img = sample_img.numpy().clip(0, 255).astype(np.uint8)
    plt.imshow(img)
    plt.show()

    # Pass the image to the CNN
    img = tf.expand_dims(sample_img, 0)
    img = caption_model.cnn_model(img)

    # Pass the image features to the Transformer encoder
    encoded_img = caption_model.encoder(img, training=False)

    # Generate the caption using the Transformer decoder
    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = caption_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        if sampled_token == "<end>":
            break
        decoded_caption += " " + sampled_token

    decoded_caption = decoded_caption.replace("<start> ", "")
    decoded_caption = decoded_caption.replace(" <end>", "").strip()
    print("Predicted Caption: ", decoded_caption)


# Check predictions for a few samples
generate_caption()
generate_caption()
generate_caption()
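`generate_caption` performs greedy decoding: at step `i` it re-runs the decoder on the caption built so far, takes the argmax of the distribution at position `i`, and stops at `<end>` or after `SEQ_LENGTH - 1` tokens. The loop structure can be isolated from the model with a toy next-token function; `toy_next_token_probs`, `greedy_decode` and the tiny vocabulary are hypothetical stand-ins for the trained decoder and the real vocabulary.

```python
import numpy as np

VOCAB = ["", "[UNK]", "<start>", "<end>", "a", "dog", "runs"]


def toy_next_token_probs(tokens):
    """Hypothetical stand-in for the decoder: deterministically proposes
    'a dog runs' and then '<end>'."""
    script = ["a", "dog", "runs", "<end>"]
    nxt = script[min(len(tokens) - 1, len(script) - 1)]
    probs = np.full(len(VOCAB), 0.01)
    probs[VOCAB.index(nxt)] = 1.0
    return probs / probs.sum()


def greedy_decode(next_token_probs, max_len=24):
    """Append the most likely token until '<end>' or max_len new tokens."""
    tokens = ["<start>"]
    for _ in range(max_len):
        token = VOCAB[int(np.argmax(next_token_probs(tokens)))]
        if token == "<end>":
            break
        tokens.append(token)
    return " ".join(tokens[1:])


print(greedy_decode(toy_next_token_probs))  # a dog runs
```

Greedy decoding is the simplest choice; it commits to the single most likely token at each step, so an early mistake cannot be revised later in the caption.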

## End Notes

The model is expected to start generating reasonable captions after a few epochs, although
the training log above ends during epoch 3. The example deliberately keeps some settings
small, such as the number of attention heads (one in the encoder, two in the decoder). To
improve the predictions, you can try changing these training settings and find a good model
for your use case.