# Libraries

In [1]:
# Mount Google Drive so the project folder (with dataset/ and models/) is reachable from Colab.
from google.colab import drive
drive.mount('/content/drive')

# Relative paths used later ("dataset/Images", "dataset/captions.txt", "models/...") resolve from here.
%cd /content/drive/MyDrive/ImageCaptioning

Mounted at /content/drive
/content/drive/MyDrive/ImageCaptioning


In [2]:
import re
import os
import cv2
import glob
import spacy
import random
import numpy as np
import pandas as pd
from time import time
from PIL import Image
from tqdm import tqdm
import tensorflow as tf
from collections import Counter
import matplotlib.pyplot as plt
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction


import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet50, ResNet50_Weights

# Run on the GPU when CUDA is available; the modules defined below move their layers to `device`.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# spaCy English pipeline; Vocabulary.tokenize uses its tokenizer.
spacy_eng = spacy.load("en_core_web_sm")

# Dataset

In [3]:
class Vocabulary:
    """Bidirectional word <-> index mapping built from caption text.

    Indices 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and <UNK>.
    A word gets its own index once it has been seen `freq_threshold` times.
    """

    def __init__(self, freq_threshold):
        self.index2word = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.word2index = {v: k for k, v in self.index2word.items()}

        # Minimum number of occurrences a word needs before it is added to the vocab.
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.index2word)

    @staticmethod
    def tokenize(text):
        """Split `text` with the spaCy tokenizer and lower-case every token."""
        return [token.text.lower() for token in spacy_eng.tokenizer(text)]

    def build_vocab(self, sentence_list):
        """Count word frequencies over `sentence_list` and index the frequent words.

        Words receive consecutive indices in the order in which they reach the
        frequency threshold. Numbering continues after the current vocabulary and
        words that already have an index keep it, so calling this more than once
        extends the vocabulary instead of overwriting existing entries.
        """
        frequencies = Counter()
        next_idx = len(self.index2word)

        for sentence in sentence_list:
            for word in self.tokenize(sentence):
                frequencies[word] += 1

                # Add the word to the vocab when it reaches the minimum frequency threshold
                # (`==` fires only once per word, on the crossing count).
                if frequencies[word] == self.freq_threshold and word not in self.word2index:
                    self.word2index[word] = next_idx
                    self.index2word[next_idx] = word
                    next_idx += 1

    def numericalize(self, text):
        """Map each token of `text` to its vocab index, using <UNK> for unknown words.

        No <SOS>/<EOS> markers are added.
        """
        unk_idx = self.word2index["<UNK>"]
        return [self.word2index.get(token, unk_idx) for token in self.tokenize(text)]

In [4]:
class ImageCaptioningDataset(Dataset):
    """Image captioning dataset: one item per image together with its captions.

    Images are listed from `image_dir` in sorted filename order; captions come from
    the 'caption' column of `csv_file` in row order. Image `i` is paired with caption
    rows [captions_per_image * i, captions_per_image * (i + 1)).
    NOTE(review): this positional pairing assumes the CSV lists exactly
    `captions_per_image` captions per image, grouped in the same sorted filename
    order as the image directory — confirm for your captions file.
    """

    def __init__(self, csv_file, transform, freq_threshold=5, image_dir="dataset/Images",
                 captions_per_image=5, min_caption_width=50):
        self.dataframe = pd.read_csv(csv_file)
        self.transform = transform
        self.image_dir = image_dir
        self.captions_per_image = captions_per_image
        # Width of the all-<PAD> dummy row prepended to every target (see __getitem__).
        self.min_caption_width = min_caption_width

        self.images = sorted(os.listdir(image_dir))
        self.captions = self.dataframe['caption']

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.captions.tolist())

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return `(image, targets)` for the idx-th image.

        `image` is the RGB image after `transform` (if any). `targets` is an int64 tensor
        of shape (1 + captions_per_image, width). Row 0 is an all-zero (<PAD>) dummy row
        of length `min_caption_width` that pads every item to at least that width, so
        items can be stacked by the default collate as long as no caption is longer.
        The remaining rows are the numericalized captions, right-padded with 0.
        Callers are expected to drop row 0.
        """
        start = self.captions_per_image * idx
        captions = self.captions.iloc[start: start + self.captions_per_image].tolist()
        image_path = os.path.join(self.image_dir, self.images[idx])

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image)

        caption_vec = [torch.zeros(self.min_caption_width, dtype=torch.long)]
        for cap in captions:
            # Explicit dtype: torch.tensor([]) would otherwise be float for an empty caption.
            caption_vec.append(torch.tensor(self.vocab.numericalize(cap), dtype=torch.long))

        targets = pad_sequence(caption_vec, batch_first=True, padding_value=0)

        return image, targets

# Model

## Image

In [5]:
class ImageFeatureExtractor(torch.nn.Module):
    """Frozen ResNet-50 backbone producing one feature vector per spatial location.

    The last two children of ResNet-50 (global average pooling and the fc classifier)
    are dropped, so the final convolutional feature map is kept and flattened into a
    sequence of 2048-d region vectors — the keys the attention module attends over.
    """

    def __init__(self):
        super().__init__()

        backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        conv_layers = list(backbone.children())[:-2]
        # Keep the attribute name `model`: it is part of the checkpoint's state_dict keys.
        self.model = torch.nn.Sequential(*conv_layers).to(device)

        # Use the backbone as a fixed feature extractor (no gradient updates).
        self.model.requires_grad_(False)

    def forward(self, images):
        feature_map = self.model(images.to(device))      # (batch, 2048, 7, 7) for 224x224 input
        # (batch, 2048, 49) -> (batch, 49, 2048): one row per spatial location.
        return feature_map.flatten(start_dim=2).transpose(1, 2)

## Attention


In [6]:
class Attention(torch.nn.Module):
    """Additive (Bahdanau-style) attention over a set of encoder feature vectors.

    score_i = V(tanh(U(key_i) + W(query))); the weights are a softmax of the scores
    over the keys, and the context is the weighted sum of the keys.
    """

    def __init__(self, attention_dim, encoder_dim, decoder_dim):
        super().__init__()

        self.attention_dim = attention_dim
        # Layer names are part of the checkpoint's state_dict keys — do not rename.
        self.W_layer = torch.nn.Linear(decoder_dim, attention_dim).to(device)  # projects the query
        self.U_layer = torch.nn.Linear(encoder_dim, attention_dim).to(device)  # projects the keys
        self.V_layer = torch.nn.Linear(attention_dim, 1).to(device)            # scores each key

    def forward(self, keys, query):
        """Attend over `keys` (batch, num_keys, encoder_dim) with `query` (batch, decoder_dim).

        Returns (context, weights) of shapes (batch, encoder_dim) and (batch, num_keys).
        """
        projected_keys = self.U_layer(keys)                      # (batch, num_keys, attention_dim)
        projected_query = self.W_layer(query).unsqueeze(1)       # (batch, 1, attention_dim)
        energy = self.V_layer(torch.tanh(projected_keys + projected_query)).squeeze(2)  # (batch, num_keys)

        weights = F.softmax(energy, dim=1)
        context = (weights.unsqueeze(2) * keys).sum(dim=1)
        return context, weights


## Text

In [7]:
class TextFeatureExtractor(torch.nn.Module):
    """LSTM caption decoder with additive attention over image region features.

    At each step, an attention context is computed from the image features using the
    previous hidden state as query. The context is concatenated with the current word
    embedding and fed to an LSTMCell. The new hidden state is projected (through dropout)
    to vocabulary logits.
    """

    def __init__(self, vocab_size, embed_dim, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super(TextFeatureExtractor, self).__init__()
        self.vocab_size = vocab_size

        # Word embedding layer
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim).to(device)

        # LSTM cell; its input is [word embedding ; attention context]
        self.lstm = torch.nn.LSTMCell(input_size=embed_dim + encoder_dim,
                                      hidden_size=decoder_dim, bias=True).to(device)

        # Projection from the (dropped-out) hidden state to vocabulary logits
        self.fcn = torch.nn.Linear(decoder_dim, self.vocab_size).to(device)
        self.drop = torch.nn.Dropout(drop_prob)

        # Initial (h, c) are predicted from the mean image feature; attention layer
        self.init_h = torch.nn.Linear(encoder_dim, decoder_dim).to(device)
        self.init_c = torch.nn.Linear(encoder_dim, decoder_dim).to(device)
        self.attention = Attention(attention_dim, encoder_dim, decoder_dim)

    def init_hidden_state(self, features):
        """Return the initial LSTM (h, c), computed from `features` averaged over dim 1."""
        mean_features = features.mean(dim=1)
        h = self.init_h(mean_features)
        c = self.init_c(mean_features)
        return h, c

    def forward_step(self, embed_word, features, hidden_state, cell_state):
        """Run one decoding step.

        Returns (vocabulary logits, new hidden state, attention weights).
        FIXME(review): the new LSTM cell state is computed but not returned. Every caller
        therefore keeps passing the *initial* cell state back in. Fixing this changes the
        model's dynamics, so it should be done together with retraining the checkpoint.
        """
        # Context vector from the image features, queried by the previous hidden state
        context, attn_weight = self.attention(features, hidden_state)

        # LSTM input = [word embedding ; context]
        lstm_input = torch.cat((embed_word, context), dim=1)

        hidden_state, cell_state = self.lstm(lstm_input, (hidden_state, cell_state))

        # Vocabulary logits
        output = self.fcn(self.drop(hidden_state))

        return output, hidden_state, attn_weight

    def forward(self, features, sequences):
        """Teacher-forced decoding over a batch of token sequences.

        `sequences` holds token indices of shape (batch, seq_len). Step `idx` consumes
        token `idx`, so the returned logits have shape (batch, seq_len - 1, vocab_size).
        Presumably preds[:, idx] is trained against sequences[:, idx + 1] — TODO confirm
        in the training code.
        """
        sequence_length = len(sequences[0]) - 1
        sequences = sequences.to(device)

        # Prediction store
        preds = torch.zeros(sequences.shape[0], sequence_length, self.vocab_size).to(device)

        # Embed the whole input sequence once
        embeds = self.embedding(sequences)
        embeds = embeds.to(torch.float32)

        hidden_state, cell_state = self.init_hidden_state(features)

        for idx in range(sequence_length):
            embed_word = embeds[:, idx]
            output, hidden_state, _ = self.forward_step(embed_word, features, hidden_state, cell_state)
            preds[:, idx] = output

        return preds

    def predict(self, feature, max_length, vocab=None):
        """Greedy decoding for a single image (batch of one: `.item()` requires it).

        Starts from <SOS> and feeds the argmax word back in at each step. Stops when
        <EOS> is predicted (it is not emitted) or after `max_length` steps.
        Returns (space-joined caption, list of per-step attention-weight arrays).
        `vocab` must provide `word2index` / `index2word`; it is required despite the
        None default.
        """
        word = torch.tensor(vocab.word2index['<SOS>']).view(1, -1).to(device)
        feature = feature.to(device)

        embeds = self.embedding(word)

        captions = []
        attention = []
        hidden_state, cell_state = self.init_hidden_state(feature)

        for _ in range(max_length):
            embed_word = embeds[:, 0]
            output, hidden_state, attn_weight = self.forward_step(embed_word, feature, hidden_state, cell_state)
            attention.append(attn_weight.cpu().detach().numpy())

            predicted_word_idx = output.argmax(dim=1)

            # End if <EOS> appears
            if vocab.index2word[predicted_word_idx.item()] == "<EOS>":
                break

            captions.append(predicted_word_idx.item())

            # Feed the generated word back as the next input
            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        return ' '.join([vocab.index2word[idx] for idx in captions]), attention

    def predict_batch(self, feature, max_length, vocab=None):
        """Greedy decoding for a batch; always runs exactly `max_length` steps.

        Returns a CPU int64 tensor of shape (batch, max_length) holding the predicted
        word indices. There is no early stop: positions after <EOS> are still filled,
        so callers must cut each row at <EOS>.
        """
        word = torch.full((feature.shape[0], 1), vocab.word2index['<SOS>']).to(device)
        feature = feature.to(device)

        embeds = self.embedding(word)
        # One row per decoding step, sized by max_length. The buffer was previously
        # hard-coded to 20 rows: max_length > 20 raised IndexError, and max_length < 20
        # left trailing 0 (<PAD>) columns.
        predicted_captions = torch.zeros(max_length, feature.shape[0], dtype=torch.long)
        hidden_state, cell_state = self.init_hidden_state(feature)

        for idx in range(max_length):
            embed_word = embeds[:, 0]
            output, hidden_state, _ = self.forward_step(embed_word, feature, hidden_state, cell_state)
            predicted_word_idx = output.argmax(dim=1)          # (batch,)
            predicted_captions[idx] = predicted_word_idx

            # Feed the generated words back as the next input
            embeds = self.embedding(predicted_word_idx.unsqueeze(1))
        return predicted_captions.permute(1, 0)

## Captioner

In [8]:
class Captioner(torch.nn.Module):
    """End-to-end captioner: ResNet image encoder followed by an attention LSTM decoder.

    The vocabulary is stored on the model so generation can map indices back to words.
    """

    def __init__(self, vocab_size, embed_dim, attention_dim, encoder_dim, decoder_dim, vocab):
        super().__init__()
        # Attribute names define the state_dict prefixes (`image_encoder.*`, `text_decoder.*`).
        self.image_encoder = ImageFeatureExtractor()
        self.text_decoder = TextFeatureExtractor(
            vocab_size, embed_dim, attention_dim, encoder_dim, decoder_dim
        )
        self.vocab = vocab

    def forward(self, images, captions):
        """Teacher-forced pass: returns the decoder's per-step vocabulary logits."""
        return self.text_decoder(self.image_encoder(images), captions)

    def generate_caption(self, image, max_length=20):
        """Greedily caption one image; returns (caption string, per-step attention weights)."""
        features = self.image_encoder(image.to(device))
        return self.text_decoder.predict(features, max_length, self.vocab)

    def generate_caption_batch(self, images, max_length=20):
        """Greedily caption a batch of images; returns the decoder's index tensor."""
        features = self.image_encoder(images.to(device))
        return self.text_decoder.predict_batch(features, max_length, self.vocab)


# Test: BLEU evaluation of the trained captioner

In [9]:
def load_model(path, map_location='cpu'):
    """Rebuild a Captioner from a saved checkpoint dict.

    The checkpoint must contain the constructor hyper-parameters ('vocab_size',
    'embed_dim', 'attention_dim', 'encoder_dim', 'decoder_dim'), the vocabulary
    object under 'vocab', and the weights under 'model_state_dict'.

    SECURITY: 'vocab' is a pickled custom object (it is used through attributes such
    as `.word2index`). The file therefore has to be fully unpickled (weights_only=False),
    which can execute arbitrary code. Only load checkpoints you trust.
    """
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle the custom
    # vocabulary object (the argument itself exists since PyTorch 1.13).
    checkpoint = torch.load(path, map_location=map_location, weights_only=False)

    model = Captioner(
        vocab_size=checkpoint['vocab_size'],
        embed_dim=checkpoint['embed_dim'],
        attention_dim=checkpoint['attention_dim'],
        encoder_dim=checkpoint['encoder_dim'],
        decoder_dim=checkpoint['decoder_dim'],
        vocab=checkpoint['vocab']
    )
    # load_state_dict copies the loaded tensors into the model's existing parameters.
    model.load_state_dict(checkpoint['model_state_dict'])
    return model


In [10]:
# Trained checkpoint in inference mode (dropout off, BatchNorm uses running statistics).
model = load_model("models/bahdanau_attn/model_best.pth")
model.eval()

# Resize/center-crop to 224x224 and normalize with the ImageNet statistics
# expected by the ResNet-50 backbone.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(232, antialias=True),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

dataset = ImageCaptioningDataset(csv_file="dataset/captions.txt", transform=eval_transform)

# No shuffling: batches follow the dataset's image order.
loader = DataLoader(dataset=dataset, batch_size=32, num_workers=2)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 110MB/s]


In [11]:
def map_target(in_caption, vocab=None):
    """Convert padded reference index lists into word lists for BLEU.

    `in_caption` is nested as [image][caption][token index]. Each caption is cut at
    its first <PAD> index. The result has the same nesting with words instead of
    indices, i.e. the `list_of_references` layout used by nltk's corpus_bleu.

    `vocab` defaults to the global `dataset.vocab` when not given.
    """
    if vocab is None:
        vocab = dataset.vocab
    pad_idx = vocab.word2index["<PAD>"]

    def _decode(cap):
        # Indices up to (not including) the first <PAD>, mapped to words.
        words = []
        for idx in cap:
            idx = int(idx)
            if idx == pad_idx:
                break
            words.append(vocab.index2word[idx])
        return words

    return [[_decode(cap) for cap in captions] for captions in in_caption]


def map_predict(in_caption, vocab=None):
    """Convert one predicted index sequence into a word list, stopping at <EOS>.

    Indices may be floats (e.g. from a float tensor's `.tolist()`); they are cast to
    int before lookup. `vocab` defaults to the global `dataset.vocab` when not given.
    """
    if vocab is None:
        vocab = dataset.vocab
    eos_idx = vocab.word2index["<EOS>"]

    words = []
    for idx in in_caption:
        idx = int(idx)
        if idx == eos_idx:
            break
        words.append(vocab.index2word[idx])
    return words

In [None]:
# Greedy-decode every image and collect references / hypotheses for corpus-level BLEU.
with torch.no_grad():
    list_of_references = []  # per image: its reference captions, each a list of words
    hypotheses = []          # per image: one predicted caption as a list of words
    for image, target in tqdm(loader):
        # Row 0 of each target is the all-<PAD> dummy row added by the dataset; drop it.
        list_of_references.extend(map_target(target[:, 1:, :].tolist()))

        predicted_captions = model.generate_caption_batch(image.to(device)).tolist()
        hypotheses.extend(map_predict(caption) for caption in predicted_captions)

# Corpus BLEU pools n-gram statistics over all images, so compute it once at the end.
# Re-scoring the growing corpus after every batch was quadratic, and averaging those
# cumulative scores over-weighted the first batches.
corpus_bleu_score = corpus_bleu(list_of_references, hypotheses)
# Kept as a one-element list so `sum(bleu_score) / len(bleu_score)` gives the corpus score.
bleu_score = [corpus_bleu_score]

2it [00:12,  5.16s/it]

In [None]:
# Average of the BLEU scores collected by the evaluation loop.
mean_bleu = sum(bleu_score) / len(bleu_score)
mean_bleu