<a href="https://colab.research.google.com/github/ritwiks9635/Image_Captioning/blob/main/Image_Captioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#**ðŸŽ‘ImageðŸŒ‰ðŸ’¬CaptioningðŸ”¡**

[DATASET](https://www.kaggle.com/datasets/adityajn105/flickr8k?select=Images)

In [None]:
!unzip /content/https:/www.kaggle.com/datasets/adityajn105/flickr8k/flickr8k.zip

In [3]:
import re
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import TextVectorization
from keras.applications import efficientnet

tf.random.set_seed(111)

In [4]:
image_dir = "/content/Images"
caption_file = "/content/captions.txt"

batch_size = 64
image_size = (299, 299)
vocab_size = 10000
sequence_length = 25
embedding_dim = 512
ff_dim = 512

In [7]:
def load_caption_file(filename):
    with open(filename) as files:
        caption_data = files.readlines()
        caption_mapping = {}
        text_data = []
        image_to_skip = set()

        for line in caption_data:
            line = line.rstrip("\n")
            lines = line.split(",")
            img_path = lines[0]
            caption = lines[1]

            img_path = os.path.join(image_dir, img_path.strip())

            tokens = caption.strip().split()
            if len(tokens) < 5 or len(tokens) > sequence_length:
                image_to_skip.add(img_path)
                continue

            if img_path.endswith(".jpg") and img_path not in image_to_skip:
                caption = "<start> " + caption.strip() + " <end>"
                text_data.append(caption)

                if img_path in caption_mapping:
                    caption_mapping[img_path].append(caption)
                else:
                    caption_mapping[img_path] = [caption]

        for img_path in image_to_skip:
            if img_path in caption_mapping:
                del caption_mapping[img_path]

        return caption_mapping, text_data

In [8]:
caption_mapping, text_data = load_caption_file(caption_file)

In [9]:
def train_val_split(caption_data, train_size = 0.8, shuffle = True):
    all_images = list(caption_data.keys())

    if shuffle:
        np.random.shuffle(all_images)

    train_split = int(len(all_images) * train_size)

    train_data = {img_path : caption_data[img_path] for img_path in all_images[: train_split]}

    valid_data = {img_path : caption_data[img_path] for img_path in all_images[train_split :]}

    return train_data, valid_data


train_data, valid_data = train_val_split(caption_mapping)
print("Total train data is :: ", len(train_data))
print("Total valid data is :: ", len(valid_data))

Total train data is ::  5616
Total valid data is ::  1405


In [10]:
def custom_standardization(input_string):
    lowercase = tf.strings.lower(input_string)
    return tf.strings.regex_replace(lowercase, "[%s]" % re.escape(strip_chars), "")

strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"
strip_chars = strip_chars.replace("<", "")
strip_chars = strip_chars.replace(">", "")

vectorize_layer = TextVectorization(
    standardize = custom_standardization,
    max_tokens = vocab_size,
    output_sequence_length = sequence_length,
    output_mode = "int")

vectorize_layer.adapt(text_data)

In [11]:
def load_image(img_path):
    img = tf.io.read_file(img_path)
    img = tf.image.decode_jpeg(img, channels = 3)
    img = tf.image.resize(img, image_size)
    img = tf.image.convert_image_dtype(img, tf.float32)
    return img


def preprocess_data(img_path, caption):
    img = load_image(img_path)
    caption = vectorize_layer(caption)
    return img, caption


def build_dataset(images, caption):
    dataset = tf.data.Dataset.from_tensor_slices((images, caption))
    dataset = dataset.shuffle(batch_size * 8)
    dataset = dataset.map(preprocess_data, num_parallel_calls = tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size = tf.data.AUTOTUNE)
    return dataset

train_dataset = build_dataset(list(train_data.keys()), list(train_data.values()))
valid_dataset = build_dataset(list(valid_data.keys()), list(valid_data.values()))

In [12]:
for i, j in train_dataset.take(1):
    print(i.shape)
    print(j.shape)
    break

(64, 299, 299, 3)
(64, 5, 25)


In [13]:
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor = 0.2),
        layers.RandomContrast(0.3)
    ])

## Building the model

Our image captioning architecture consists of three models:

1. A CNN: used to extract the image features
2. A TransformerEncoder: The extracted image features are then passed to a Transformer based encoder that generates a new representation of the inputs
3. A TransformerDecoder: This model takes the encoder output and the text data (sequences) as inputs and tries to learn to generate the caption.

In [14]:
def feature_extractor():
    base_model = efficientnet.EfficientNetB0(
        input_shape = (*image_size, 3),
        weights = "imagenet",
        include_top = False)

    base_model.trainable = False

    base_model_output = base_model.output

    base_model_output = layers.Reshape((-1, base_model_output.shape[-1]))(base_model_output)

    feature_extractor = keras.Model(base_model.input, base_model_output)
    return feature_extractor

In [15]:
class TransformerEncoderBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        #self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_layer = layers.MultiHeadAttention(num_heads = num_heads, key_dim = embed_dim, dropout = 0.0)
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.dense_layer = layers.Dense(self.embed_dim, activation = "relu")

    def call(self, inputs, training, mask = None):
        inputs = self.layernorm_1(inputs)
        inputs = self.dense_layer(inputs)
        attention_output = self.attention_layer(
            query = inputs,
            value = inputs,
            key = inputs,
            attention_mask = mask,
            training = training)

        output = self.layernorm_2(inputs + attention_output)
        return output


class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_embed = layers.Embedding(input_dim = vocab_size, output_dim  = embed_dim)
        self.pos_embed = layers.Embedding(input_dim = sequence_length, output_dim = embed_dim)
        self.embed_scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))


    def call(self, inputs):
        length = tf.shape(inputs)[-1]
        positions = tf.range(start = 0, limit = length, delta = 1)
        embedded_tokens = self.token_embed(inputs)
        embedded_tokens = embedded_tokens + self.embed_scale
        embedded_positions = self.pos_embed(positions)
        return embedded_positions + embedded_tokens

    def compute_mask(self, inputs, mask = None):
        return tf.math.not_equal(inputs, 0)

In [16]:
class TransformerDecoderBlock(layers.Layer):
    def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        self.ffn_layer_1 = layers.Dense(ff_dim, activation="relu")
        self.ffn_layer_2 = layers.Dense(embed_dim)

        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()

        self.embedding = PositionalEmbedding(
            embed_dim = embedding_dim,
            sequence_length = sequence_length,
            vocab_size = vocab_size,
        )
        self.out = layers.Dense(vocab_size, activation="softmax")

        self.dropout_1 = layers.Dropout(0.3)
        self.dropout_2 = layers.Dropout(0.5)
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, training, mask=None):
        inputs = self.embedding(inputs)
        causal_mask = self.get_causal_attention_mask(inputs)

        if mask is not None:
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            combined_mask = tf.minimum(combined_mask, causal_mask)

        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=combined_mask,
            training=training,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
            training=training,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2, training=training)
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [
                tf.expand_dims(batch_size, -1),
                tf.constant([1, 1], dtype=tf.int32),
            ],
            axis=0,
        )
        return tf.tile(mask, mult)

In [17]:
class ImageCaptioningModel(keras.Model):
    def __init__(self, base_model, encoder, decoder, num_captions_per_image=5, image_aug=None):
        super().__init__()
        self.base_model = base_model
        self.encoder = encoder
        self.decoder = decoder
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.acc_tracker = keras.metrics.Mean(name="accuracy")
        self.num_captions_per_image = num_captions_per_image
        self.image_aug = image_aug

    def calculate_loss(self, y_true, y_pred, mask):
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_true, y_pred, mask):
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype = tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def _compute_caption_loss_and_acc(self, img_embed, batch_seq, training=True):
        encoder_out = self.encoder(img_embed, training=training)
        batch_seq_inp = batch_seq[:, :-1]
        batch_seq_true = batch_seq[:, 1:]
        mask = tf.math.not_equal(batch_seq_true, 0)
        batch_seq_pred = self.decoder(
            batch_seq_inp, encoder_out, training=training, mask=mask
        )
        loss = self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
        acc = self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)
        return loss, acc

    def train_step(self, batch_data):
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        if self.image_aug:
            batch_img = self.image_aug(batch_img)

        # 1. Get image embeddings
        img_embed = self.base_model(batch_img)

        # 2. Pass each of the five captions one by one to the decoder
        # along with the encoder outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.num_captions_per_image):
            with tf.GradientTape() as tape:
                loss, acc = self._compute_caption_loss_and_acc(
                    img_embed, batch_seq[:, i, :], training=True
                )

                # 3. Update loss and accuracy
                batch_loss += loss
                batch_acc += acc

            # 4. Get the list of all the trainable weights
            train_vars = (
                self.encoder.trainable_variables + self.decoder.trainable_variables
            )

            # 5. Get the gradients
            grads = tape.gradient(loss, train_vars)

            # 6. Update the trainable weights
            self.optimizer.apply_gradients(zip(grads, train_vars))

        # 7. Update the trackers
        batch_acc /= float(self.num_captions_per_image)
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        # 8. Return the loss and accuracy values
        return {
            "loss": self.loss_tracker.result(),
            "acc": self.acc_tracker.result(),
        }

    def test_step(self, batch_data):
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        # 1. Get image embeddings
        img_embed = self.base_model(batch_img)

        # 2. Pass each of the five captions one by one to the decoder
        # along with the encoder outputs and compute the loss as well as accuracy
        # for each caption.
        for i in range(self.num_captions_per_image):
            loss, acc = self._compute_caption_loss_and_acc(
                img_embed, batch_seq[:, i, :], training=False
            )

            # 3. Update batch loss and batch accuracy
            batch_loss += loss
            batch_acc += acc

        batch_acc /= float(self.num_captions_per_image)

        # 4. Update the trackers
        self.loss_tracker.update_state(batch_loss)
        self.acc_tracker.update_state(batch_acc)

        # 5. Return the loss and accuracy values
        return {
            "loss": self.loss_tracker.result(),
            "acc": self.acc_tracker.result(),
        }

    @property
    def metrics(self):
        # We need to list our metrics here so the `reset_states()` can be
        # called automatically.
        return [self.loss_tracker, self.acc_tracker]

In [18]:
base_model = feature_extractor()
encoder = TransformerEncoderBlock(embed_dim = embedding_dim, num_heads = 1)
decoder = TransformerDecoderBlock(embed_dim = embedding_dim, ff_dim = ff_dim, num_heads = 2)
captioning_model = ImageCaptioningModel(
    base_model = base_model,
    encoder = encoder,
    decoder = decoder,
    image_aug = data_augmentation)

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5


In [19]:
class LRSchedule(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, post_warmup_learning_rate, warmup_steps):
        super().__init__()
        self.post_warmup_learning_rate = post_warmup_learning_rate
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        global_step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        warmup_progress = global_step / warmup_steps
        warmup_learning_rate = self.post_warmup_learning_rate * warmup_progress
        return tf.cond(
            global_step < warmup_steps,
            lambda: warmup_learning_rate,
            lambda: self.post_warmup_learning_rate,
        )

num_train_step = len(train_dataset) * 30
num_warmup_steps = num_train_step // 15
lr_schedule = LRSchedule(post_warmup_learning_rate=1e-4, warmup_steps=num_warmup_steps)

captioning_model.compile(
    optimizer = keras.optimizers.Adam(learning_rate = lr_schedule),
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits = False, reduction = "none")
    )

In [None]:
early_stopping_cb = keras.callbacks.EarlyStopping(monitor = "val_loss", patience = 3, restore_best_weights = True)

captioning_model.fit(
    train_dataset,
    validation_data = valid_dataset,
    epochs = 50,
    callbacks = [early_stopping_cb])

In [None]:
vocab = vectorize_layer.get_vocabulary()
index_lookup = dict(zip(range(len(vocab)), vocab))
max_decoded_sentence_length = sequence_length - 1
valid_images = list(valid_data.keys())


def generate_caption():
    # Select a random image from the validation dataset
    sample_img = np.random.choice(valid_images)

    # Read the image from the disk
    sample_img = load_image(sample_img)
    img = sample_img.numpy().clip(0, 255).astype(np.uint8)
    plt.imshow(img)
    plt.show()

    # Pass the image to the CNN
    img = tf.expand_dims(sample_img, 0)
    img = captioning_model.base_model(img)

    # Pass the image features to the Transformer encoder
    encoded_img = captioning_model.encoder(img, training=False)

    # Generate the caption using the Transformer decoder
    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorize_layer([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = captioning_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        if sampled_token == "<end>":
            break
        decoded_caption += " " + sampled_token

    decoded_caption = decoded_caption.replace("<start> ", "")
    decoded_caption = decoded_caption.replace(" <end>", "").strip()
    print("Predicted Caption: ", decoded_caption)


# Check predictions for a few samples
generate_caption()
generate_caption()
generate_caption()