# Image Captioning

## Setup

### Import modules

In [None]:
import collections
import json
import os
import pathlib
import random
import re
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from tqdm import tqdm

In [None]:
# Enable on-demand GPU memory growth for every GPU that TensorFlow can see.
devices = tf.config.experimental.list_physical_devices("GPU")
for gpu in devices:
    tf.config.experimental.set_memory_growth(gpu, enable=True)

### Define constants

In [None]:
# Maximum caption length in tokens: used further down as the RepeatVector
# length when building the decoder and as the limit on the token sequence
# when generating a caption.
MAX_CAPTION_LENGTH = 30

## Load dataset from Flickr8k

In [None]:
# Folder with the Flickr8k text files (the caption token file and the
# train/dev/test image lists are read from here).
annotation_folder = pathlib.Path("data/flickr8k/Flickr8k_text")
# Folder that the image file names from those lists are resolved against.
image_folder = pathlib.Path("data/flickr8k/Flickr8k_Dataset")

In [None]:
# Create a dictionary from all captions: image file name -> list of captions.
# Each line of Flickr8k.token.txt is read as "<image>#<n>\t<caption>" and every
# caption is wrapped in <start> / <end> markers.
token_lines = (annotation_folder / "Flickr8k.token.txt").read_text().splitlines()

cap_dict = collections.defaultdict(list)
# `captions` is kept as a plain list of the wrapped caption strings (not a
# one-shot generator), so it can still be iterated after cap_dict is filled.
captions = []
for line in token_lines:
    # Split on the first tab only, so a tab inside the caption text cannot break
    # the unpacking; blank or malformed lines are skipped instead of raising.
    image_tag, tab, text = line.partition("\t")
    if not tab:
        continue
    cap = f"<start> {text} <end>"
    # Drop the "#<n>" caption counter to get the bare image file name.
    cap_dict[image_tag.split("#")[0]].append(cap)
    captions.append(cap)

In [None]:
# Create vocabulary
# Words are counted straight from cap_dict (the materialized image -> captions
# mapping), so the result does not depend on whether `captions` can still be
# iterated at this point.
MIN_WORD_COUNT = 5  # Minimum word count threshold

word_counter = collections.Counter()
for image_captions in cap_dict.values():
    for caption in image_captions:
        word_counter.update(caption.split())

# Keep only sufficiently frequent words; indices follow first-seen order.
vocabulary = [word for word, count in word_counter.items() if count >= MIN_WORD_COUNT]
word_to_index = {word: i for i, word in enumerate(vocabulary)}
index_to_word = dict(enumerate(vocabulary))

In [None]:
def load_images_and_captions(captions_file):
    """Return parallel lists of image paths and captions for one dataset split.

    Reads image file names (one per line) from
    ``annotation_folder / captions_file`` and, for every caption stored for
    that image in ``cap_dict``, emits the image's path under ``image_folder``
    together with the caption.

    Args:
        captions_file: Name of the split file inside ``annotation_folder``.

    Returns:
        A ``(loaded_images, loaded_captions)`` tuple of equal-length lists;
        each image path string is repeated once per caption of that image.
    """
    image_files = (annotation_folder / captions_file).read_text().splitlines()
    loaded_images, loaded_captions = [], []
    for image_file in image_files:
        # Use .get() rather than indexing: cap_dict is a defaultdict, so
        # indexing an unknown name (e.g. a blank line) would insert an empty
        # entry and silently grow the caption dictionary.
        image_captions = cap_dict.get(image_file, [])
        loaded_images.extend([str(image_folder / image_file)] * len(image_captions))
        loaded_captions.extend(image_captions)
    return loaded_images, loaded_captions

In [None]:
# Load the image paths and captions for the train, dev and test splits
# (loaded in that order, one split file each).
(
    (train_images, train_captions),
    (dev_images, dev_captions),
    (test_images, test_captions),
) = [
    load_images_and_captions(split_file)
    for split_file in (
        "Flickr_8k.trainImages.txt",
        "Flickr_8k.devImages.txt",
        "Flickr_8k.testImages.txt",
    )
]

In [None]:
# Print the number of entries in cap_dict followed by the length of each
# image/caption list, all on one line.
sizes = [
    len(collection)
    for collection in (
        cap_dict,
        train_images,
        train_captions,
        dev_images,
        dev_captions,
        test_images,
        test_captions,
    )
]
print(*sizes)
# Cell output: the first two training image paths.
train_images[:2]

In [None]:
def load_image(image_path):
    """Read a JPEG file and prepare it as InceptionV3 input.

    Args:
        image_path: Path of the JPEG file to read.

    Returns:
        ``(image, image_path)``: the image decoded with 3 channels, resized to
        299x299 and passed through InceptionV3's ``preprocess_input``, plus the
        path exactly as it was given.
    """
    raw_bytes = tf.io.read_file(image_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, (299, 299))
    return tf.keras.applications.inception_v3.preprocess_input(resized), image_path

In [None]:
# Create one dataset per split: the split's unique image paths in sorted order,
# each mapped through load_image and grouped into batches of 32.
encoded_train = sorted(set(train_images))
train_dataset = (
    tf.data.Dataset.from_tensor_slices(encoded_train)
    .map(load_image)
    .batch(32)
)

encoded_dev = sorted(set(dev_images))
dev_dataset = (
    tf.data.Dataset.from_tensor_slices(encoded_dev)
    .map(load_image)
    .batch(32)
)

encoded_test = sorted(set(test_images))
test_dataset = (
    tf.data.Dataset.from_tensor_slices(encoded_test)
    .map(load_image)
    .batch(32)
)

## Image feature extraction (scratch cells, ~~commented~~)

In [None]:
# Create a model to extract the image features: ImageNet-pretrained InceptionV3
# without its top layers, exposing the output of its last remaining layer.
inception_v3 = tf.keras.applications.InceptionV3(weights='imagenet', include_top=False)
new_input, hidden_layer = inception_v3.input, inception_v3.layers[-1].output

model = tf.keras.Model(inputs=new_input, outputs=hidden_layer)

In [None]:
# Smoke-test the feature extractor: process only the first batch and print the
# path of the first image in it.
for image, path in tqdm(train_dataset, total=len(train_dataset)):
    batch_features = model(image)
    n_images, n_channels = batch_features.shape[0], batch_features.shape[3]
    # Collapse the axes between the first and the last one into a single axis.
    batch_features = tf.reshape(batch_features, (n_images, -1, n_channels))
    first_pair = next(zip(batch_features, path), None)
    if first_pair is not None:
        bf, p = first_pair
        path_of_features = p.numpy().decode("utf-8")
        print(path_of_features)
        # Saving is not enabled here; the disabled call was: np.save(p, bf.numpy())
    break

## Load the CNN Encoder (Inception V3)

In [None]:
# Feature-extraction encoder: ImageNet-pretrained InceptionV3 without its top layers.
encoder = tf.keras.applications.InceptionV3(weights="imagenet", include_top=False)

# Mark every encoder layer as non-trainable so the pre-trained weights are not
# updated during training (optional).
for layer in encoder.layers:
    layer.trainable = False

## Define the RNN Decoder (LSTM)

In [None]:
# Size of the vectors used to represent words
embedding_dim = 256

# Number of words kept in the vocabulary built from the captions
vocabulary_size = len(vocabulary)

# Embedding layer mapping word indices to vectors (mask_zero=True makes index 0
# act as a masked padding value)
embedding = tf.keras.layers.Embedding(
    input_dim=vocabulary_size,
    output_dim=embedding_dim,
    mask_zero=True,
)

# LSTM layers: the first returns an output for every time step, the second
# uses the default and returns only its last output
lstm1 = tf.keras.layers.LSTM(units=256, return_sequences=True)
lstm2 = tf.keras.layers.LSTM(units=256)

# Softmax layer with one output per vocabulary word, for predicting the next word
decoder_dense = tf.keras.layers.Dense(units=vocabulary_size, activation='softmax')

## Create model by combining the Encoder and the Decoder

In [None]:
# Input for the image (tf.keras.Input is the documented entry point for
# creating a symbolic input tensor)
image_input = tf.keras.Input(shape=(299, 299, 3))

# Extract features from the image using the encoder
encoded_features = encoder(image_input)

# The encoder was created with include_top=False, so it returns a spatial
# feature map, while RepeatVector only accepts a 2D (batch, features) input.
# Average over the spatial axes to get one feature vector per image.
pooled_features = tf.keras.layers.GlobalAveragePooling2D()(encoded_features)

# Define a repeat vector to feed the same features at each step of the decoder
decoder_hidden = tf.keras.layers.RepeatVector(MAX_CAPTION_LENGTH)(pooled_features)

# Pass the repeated features through the LSTM layers
decoder_output = lstm1(decoder_hidden)
decoder_output = lstm2(decoder_output)

# Predict the next word probability distribution
# NOTE(review): lstm2 does not return sequences, so the model emits a single
# distribution per image rather than one per caption position — confirm this
# is the intended training target.
decoder_logits = decoder_dense(decoder_output)

# Define the model by connecting inputs and outputs
model = tf.keras.models.Model(inputs=image_input, outputs=decoder_logits)

# Compile the model
# NOTE(review): 'categorical_crossentropy' expects one-hot targets; verify the
# labels passed to fit() are encoded that way.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

## Train the model

In [None]:
# Fit the model for 10 epochs on the training split, using the dev split as
# validation data.
# NOTE(review): train_images / train_captions (and the dev lists) are built
# above as lists of file-path strings and caption strings — confirm they are
# converted to image tensors and encoded targets before this call.
model.fit(
    x=train_images,
    y=train_captions,
    epochs=10,
    validation_data=(dev_images, dev_captions),
)

## Save the model

In [None]:
# Write the current model to image_captioning_model.h5, relative to the working
# directory (the .h5 suffix presumably selects Keras' HDF5 format — confirm for
# the installed Keras version).
model.save('image_captioning_model.h5')

## Load the saved model

In [None]:
# Rebind `model` to the model restored from image_captioning_model.h5.
model = tf.keras.models.load_model('image_captioning_model.h5')

## Generate captions

In [None]:
def generate_caption(image_path):
    """Generate a caption for the image at ``image_path`` by greedy decoding.

    Starting from ``<start>``, repeatedly asks ``model`` for the most likely
    next word index until ``<end>`` is predicted or the token sequence
    (including ``<start>``) reaches ``MAX_CAPTION_LENGTH`` entries.

    Args:
        image_path: Path of the image file to caption.

    Returns:
        The predicted words joined by single spaces, without the ``<start>``
        and ``<end>`` markers.

    Raises:
        KeyError: If ``<start>`` or ``<end>`` is missing from
            ``word_to_index``, or a predicted index is missing from
            ``index_to_word``.
    """
    # Load and preprocess the image
    image, _ = load_image(image_path)

    # Get encoded features from the image
    # NOTE(review): these features are computed but not consumed below —
    # confirm how they are meant to reach `model`.
    encoded_image = encoder.predict(np.expand_dims(image, axis=0))

    # Look the marker indices up once instead of on every step
    max_len = MAX_CAPTION_LENGTH
    start_index = word_to_index['<start>']
    end_index = word_to_index['<end>']
    sequence = [start_index]

    # Generate caption word by word. At most max_len - 1 words can follow
    # <start>, so looping that many times avoids a final prediction whose
    # result would be discarded.
    for _ in range(max_len - 1):
        current_sequence = np.array([sequence])
        # NOTE(review): this feeds the token sequence to `model`; verify that
        # matches the input the trained model was built with.
        predicted_probs = model.predict(current_sequence)
        predicted_index = int(np.argmax(predicted_probs))

        # Stop at the end-of-caption marker without storing it
        if predicted_index == end_index:
            break

        sequence.append(predicted_index)

    # <end> is never appended, so only the leading <start> has to be dropped;
    # slicing off the last element as well would lose the final predicted word.
    return ' '.join(index_to_word[idx] for idx in sequence[1:])


## Test with Real Image