# Libraries

In [None]:
import pandas as pd
import spacy
from tqdm import tqdm
import os
import math
import random
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image

In [None]:
import torchvision
from torchvision import models, transforms
import timm
# from transformers import ViTModel, ViTFeatureExtractor

In [None]:
# Load the small English spaCy pipeline with the NER, parser and lemmatizer
# components disabled.
nlp = spacy.load("en_core_web_sm", disable=["ner", "parser", "lemmatizer"])
# Run tensors/models on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Name of the active environment; selects which entry of CONFIG is used.
ENVIRON = "LOCAL"
# Per-environment paths: the captions CSV, the directory that contains the
# "flickr30k_images" folder, and the location of stored feature maps.
CONFIG = {
    "LOCAL" : {
        "DF_PATH": "data/results.csv",
        "IMAGES_DIR_ROOT": "data",
        "FEATURE_MAPS_PATH": "feature_maps",
    },
    "KAGGLE" : {
        "DF_PATH": "/kaggle/input/flickr-image-dataset/flickr30k_images/results.csv",
        "IMAGES_DIR_ROOT": "/kaggle/input/flickr-image-dataset/flickr30k_images",
        "FEATURE_MAPS_PATH": "/kaggle/input/feature-maps"
    }
}

# Hyper Parameters

In [None]:
# Training schedule: number of epochs, mini-batch size and Adam learning rate.
EPOCHS = 10
BATCH_SIZE = 32
LEARNING_RATE = 1e-5

# Model
# Transformer decoder size: attention heads, stacked layers, embedding width.
N_HEADS = 8
N_LAYERS = 6
EMBED_DIM = 128

# Data Processing

In [None]:
# Read the pipe-delimited captions table for the active environment and
# discard every row that contains a missing value.
DF_PATH = CONFIG[ENVIRON]["DF_PATH"]
df = pd.read_csv(DF_PATH, delimiter="|").dropna()

df.head()

In [None]:
# The raw caption column is read under the name " comment" (with a leading
# space); store a stripped, lower-cased copy under the clean name "comment".
df["comment"] = df[" comment"].map(lambda text: str(text).strip().lower())
df.head()

# Tokenization

In [None]:
captions_list = df["comment"].to_list()
# Stream documents straight out of the spaCy pipeline instead of first
# materialising every Doc in a list: only the token strings are kept, so peak
# memory stays low and the progress bar tracks the actual tokenisation work.
# nlp.pipe yields documents in input order, so rows stay aligned with `df`.
tokens_list = [
    [token.text for token in doc]
    for doc in tqdm(
        nlp.pipe(captions_list, n_process=-1),
        total=len(captions_list),
        desc="Document Processing",
    )
]
df["tokens"] = tokens_list

In [None]:
# Keep only the captions that contain at most `max_len` tokens and renumber
# the surviving rows from zero.
max_len = 15
caption_length = df["tokens"].map(len)
bool_map = caption_length.le(max_len)
df = df.loc[bool_map].reset_index(drop=True)
print("No. of rows -", df.shape[0])

# Vocabulary

In [None]:
# Gather every distinct token that appears in the remaining captions.
tokens_list = df["tokens"].to_list()
vocab = set()
for caption_tokens in tqdm(tokens_list, desc="Vocab Building"):
    vocab |= set(caption_tokens)

In [None]:
# Markers for the beginning of a caption, the end of a caption, and padding.
START_TOKEN = "</start>"
END_TOKEN = "</end>"
PAD_TOKEN = "</pad>"

# Add the special tokens and freeze the vocabulary into a sorted list so that
# token indices are reproducible.  Going through set(...) keeps this cell safe
# to re-run: after the first run `vocab` is already a list, which has no .add().
vocab = sorted(set(vocab) | {START_TOKEN, END_TOKEN, PAD_TOKEN})
vocab_size = len(vocab)

print("Vocab Size -", vocab_size)

In [None]:
# Two-way lookup between vocabulary entries and their position in `vocab`.
idx_to_token = dict(enumerate(vocab))
token_to_idx = {token: idx for idx, token in idx_to_token.items()}

In [None]:
# max_len += 2 # # +2 due to START and END tokens
# The caption filter above uses a limit of 15 tokens; 17 = 15 + 2.
# NOTE(review): presumably hard-coded (rather than incremented) so that
# re-running this cell does not keep growing the value; keep it in sync with
# the filter limit if that ever changes.
max_len = 17 # +2 due to START and END tokens

In [None]:
# df.drop(columns=["comment_number", "comment"], inplace=True)
def _frame_and_pad(caption_tokens):
    """Wrap a caption in START/END markers and right-pad it up to `max_len`."""
    framed = [START_TOKEN, *caption_tokens, END_TOKEN]
    # A non-positive count yields no padding, so over-long lists pass through.
    n_padding = max_len - len(framed)
    return framed + [PAD_TOKEN] * n_padding


df["tokens"] = df["tokens"].apply(_frame_and_pad)
df.head()

# Dataset & DataLoader

In [None]:
# df = df.sample(frac=.05).reset_index(drop=True)

In [None]:
def get_train_test_split(df, test_size):
    """Split ``df`` into a training frame and a validation frame.

    The split uses a fixed ``random_state`` of 42, and both returned frames
    have their index reset to 0..n-1.

    Args:
        df: Frame to split.
        test_size: Passed to ``train_test_split`` as the validation share.

    Returns:
        ``(train_df, val_df)``.
    """
    parts = train_test_split(df, test_size=test_size, random_state=42)
    train_df, val_df = (part.reset_index(drop=True) for part in parts)
    return train_df, val_df


train_df, val_df = get_train_test_split(df, test_size=.2)

#### Old Implementation - Ignore (Don't Remove)

In [None]:
# class CaptionsDataset(Dataset):
#     def __init__(self, df, token_to_idx, image_transforms):
#         self.df = df
#         self.token_to_idx = token_to_idx
#         self.image_transforms = image_transforms
#         resnet = models.resnet18(pretrained=True)
#         self.resnet = nn.Sequential(*list(resnet.children())[:-2])
#     def __len__(self):
#         return len(self.df)
#     def _encode_tokens(self, tokens):
#         return [self.token_to_idx[token] for token in tokens]
#     def _get_image_features(self, x):
#         with torch.no_grad():
#             image_features = self.resnet(x.unsqueeze(0))
#         return image_features.squeeze()
#     def __getitem__(self, i):
#         image_name, tokens = self.df.loc[i, "image_name"], self.df.loc[i, "tokens"]
#         target_tokens = tokens[1:] + [PAD_TOKEN, ]
#         image_path = os.path.join("data", "flickr30k_images", image_name)
#         image = Image.open(image_path)
#         image = self.image_transforms(image)
#         tokens = self._encode_tokens(tokens)
#         target_tokens = self._encode_tokens(target_tokens)
#         tokens = torch.tensor(tokens, dtype=torch.long)
#         target_tokens = torch.tensor(target_tokens, dtype=torch.long)
#         image_features = self._get_image_features(image)
#         return image_features, tokens, target_tokens

In [None]:
# train_dataset = CaptionsDataset(train_df, token_to_idx, transform)
# val_dataset = CaptionsDataset(val_df, token_to_idx, transform)

# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
# a, b, c = next(iter(train_loader))
# a.size(), b.size(), c.size()

#### New Implementation

In [None]:
# Per-channel normalisation statistics (these are the widely used ImageNet
# values — NOTE(review): confirm they match the pretrained encoder's config).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Preprocessing applied to each PIL image: resize to 224x224, convert to a
# tensor, then normalise each channel with the statistics above.
transform = transforms.Compose([
    
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

In [None]:
# Pretrained Inception-v3 used purely as a feature extractor: its last five
# child modules are dropped so the network emits an intermediate feature map
# (presumably the classifier head and final blocks — verify against the
# installed timm version).
image_encoder_model = timm.create_model('inception_v3', pretrained=True)
image_encoder_model = nn.Sequential(*list(image_encoder_model.children())[:-5]).to(device)
# Modules are created in training mode.  Switch to eval mode so BatchNorm uses
# its pretrained running statistics (and does not update them) while features
# are extracted one image at a time.
image_encoder_model.eval()

def get_feature_map(x):
    """Encode image tensor(s) with ``image_encoder_model``, channels last.

    Args:
        x: A single image tensor (3-D) or a batch of images (4-D).  A single
            image gets a leading batch dimension before encoding.

    Returns:
        For one image: a ``(height, width, channels)`` feature map.
        For a batch of several images:
        ``(batch, height, width, channels)``.
        Assumes the encoder returns a 4-D channels-first tensor.
    """
    if x.ndimension() == 3:
        x = x.unsqueeze(0)
    with torch.no_grad():
        feature_map = image_encoder_model(x.to(device))
    # (batch, channels, H, W) -> (batch, H, W, channels)
    feature_map = feature_map.permute(0, 2, 3, 1)
    # Drop only the batch axis, and only when it holds a single image.  A
    # blanket squeeze() would also remove spatial axes of size 1 and would
    # leave real batches 4-D, which the old 3-axis permute could not handle.
    if feature_map.size(0) == 1:
        feature_map = feature_map[0]
    # channels_size = feature_map.size()[-1]
    # feature_map = feature_map.view(1, -1, channels_size)
    return feature_map

    
# vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
# vit_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
    
# def get_feature_map(x):
#     if x.ndimension() == 3:
#         x = x.unsqueeze(0)
#     with torch.no_grad():
#         inputs = vit_feature_extractor(images=x, return_tensors="pt")
#         outputs = vit_model(**inputs)
#     features = outputs.last_hidden_state
#     return features

In [None]:
# import pickle
# Precompute one feature map per distinct image so that training never has to
# run the image encoder.
image_names = set(df["image_name"].to_list())
feature_maps = {}
images_dir = os.path.join(CONFIG[ENVIRON]["IMAGES_DIR_ROOT"], "flickr30k_images")
for image_name in tqdm(image_names, desc="Inception-v3 Feature Maps"):
    image_path = os.path.join(images_dir, image_name)
    # Close the file promptly and force three channels: the normalisation
    # step is defined for RGB and would fail on grayscale/RGBA images.
    with Image.open(image_path) as image_file:
        image = transform(image_file.convert("RGB"))
    # Keep the cache in host memory; holding every feature map on the GPU
    # would compete with the model for device memory.  Consumers move each
    # batch to `device` themselves.
    feature_maps[image_name] = get_feature_map(image).cpu()
# feature_map_path = os.path.join("feature_maps", "resnet50")
# with open(feature_map_path, "wb") as f:
#     pickle.dump(feature_maps, f)

# import pickle
# image_names = set(df["image_name"].to_list())
# feature_maps = {}
# for image_name in tqdm(image_names, desc="ViT Feature Extractor"):
#     image_path = os.path.join("data", "flickr30k_images", image_name)
#     image = Image.open(image_path)
#     image = transform(image)
#     feature_map = get_feature_map(image)
#     feature_maps[image_name] = feature_map

In [None]:
class CaptionsDataset(Dataset):
    """Dataset of (image feature map, input token ids, target token ids).

    Args:
        df: Frame with an ``"image_name"`` column and a ``"tokens"`` column
            holding each caption as a list of token strings.
        token_to_idx: Mapping from token string to integer index.
        feature_maps: Optional mapping from image name to a precomputed
            feature map.  When ``None``, the image is loaded from disk and
            encoded on every access.
    """

    def __init__(self, df, token_to_idx, feature_maps=None):
        self.df = df
        self.token_to_idx = token_to_idx
        self.feature_maps = feature_maps

    def __len__(self):
        return len(self.df)

    def _encode_tokens(self, tokens):
        """Map a list of token strings to their integer indices."""
        return [self.token_to_idx[token] for token in tokens]

    def _process_image(self, image_name):
        """Load one image from disk, preprocess it and encode it."""
        image_path = os.path.join(CONFIG[ENVIRON]["IMAGES_DIR_ROOT"], "flickr30k_images", image_name)
        # Close the file handle promptly and force three channels so the
        # RGB normalisation in `transform` always applies.
        with Image.open(image_path) as image_file:
            image = transform(image_file.convert("RGB"))
        return get_feature_map(image)

    def __getitem__(self, i):
        # Positional access: works even if the frame's index was not reset.
        image_name = self.df["image_name"].iloc[i]
        tokens = self.df["tokens"].iloc[i]
        # The target is the input shifted left by one position, padded at the
        # end to keep the same length.
        target_tokens = tokens[1:] + [PAD_TOKEN, ]
        tokens = torch.tensor(self._encode_tokens(tokens), dtype=torch.long)
        target_tokens = torch.tensor(self._encode_tokens(target_tokens), dtype=torch.long)

        if self.feature_maps is not None:
            image_features = self.feature_maps[image_name]
        else:
            image_features = self._process_image(image_name)
        image_features = image_features.squeeze()

        return image_features, tokens, target_tokens

In [None]:
# FEATURE_MAPS_FILENAME = os.path.join(CONFIG[ENVIRON]["FEATURE_MAPS_PATH"], "resnet18")
# feature_maps = pd.read_pickle(FEATURE_MAPS_FILENAME)

In [None]:
def get_datasets(df, token_to_idx, feature_maps):
    """Build the training and validation ``CaptionsDataset`` objects.

    ``df`` is split 80/20 with ``get_train_test_split`` and each part is
    wrapped in a dataset that shares ``token_to_idx`` and ``feature_maps``.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    splits = get_train_test_split(df, test_size=.2)
    train_dataset, val_dataset = (
        CaptionsDataset(split_df, token_to_idx, feature_maps)
        for split_df in splits
    )
    return train_dataset, val_dataset

def get_dataloaders(df, token_to_idx, feature_maps, batch_size):
    """Build the training and validation data loaders.

    Args:
        df: Caption frame handed to ``get_datasets``.
        token_to_idx: Mapping from token string to integer index.
        feature_maps: Mapping from image name to precomputed feature map.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_dataset, val_dataset = get_datasets(df, token_to_idx, feature_maps)
    # Reshuffle the training samples every epoch so batches are not always
    # composed of the same examples in the same order.  Validation order
    # does not affect the averaged loss, so it stays sequential.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    return train_loader, val_loader

# Model Architecture

In [None]:
class PositionalEmbedding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then dropout.

    Args:
        embed_dim: Width of the embeddings.
        p: Dropout probability applied after the encodings are added.
        max_length: Number of positions to precompute.  When omitted, the
            notebook-level ``max_len`` is used.
    """

    def __init__(self, embed_dim, p=0.1, max_length=None):
        super().__init__()
        if max_length is None:
            # Resolved at construction time from the notebook-level value.
            max_length = max_len
        self.dropout_layer = nn.Dropout(p)
        encoding = torch.zeros(max_length, embed_dim)
        positions = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        scale_factor = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        encoding[:, 0::2] = torch.sin(positions * scale_factor)
        # With an odd embed_dim there is one fewer cosine column than sine
        # columns, so trim the cosine block to fit.
        encoding[:, 1::2] = torch.cos(positions * scale_factor)[:, :embed_dim // 2]
        # Shape (1, max_length, embed_dim): the leading axis broadcasts over
        # the batch.  Registered as a buffer so it follows .to(device).
        self.register_buffer('encoding', encoding.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + encoding)`` for ``x`` of shape (batch, seq, embed_dim)."""
        seq_len = x.size(1)
        if seq_len > self.encoding.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds the {self.encoding.size(1)} "
                "precomputed positions"
            )
        # Broadcast over the batch instead of overwriting the buffer: the
        # stored encoding must not depend on the batches seen so far.
        x = x + self.encoding[:, :seq_len, :]
        return self.dropout_layer(x)

In [None]:
class ImageCaptioningModel(nn.Module):
    """Transformer decoder that generates caption tokens from image features.

    Image feature maps are flattened into a sequence of patch vectors,
    projected to ``embed_dim`` and used as the decoder memory; the caption
    tokens are embedded, position-encoded and decoded causally.

    Args:
        n_heads: Number of attention heads per decoder layer.
        n_layers: Number of stacked decoder layers.
        vocab_size: Size of the token vocabulary.
        embed_dim: Width of token embeddings and decoder layers.
        feature_dim: Channel count of the incoming image feature maps.
        pad_idx: Vocabulary index of the padding token.  When omitted it is
            looked up as ``token_to_idx[PAD_TOKEN]`` from the notebook scope.
    """

    def __init__(self, n_heads, n_layers, vocab_size, embed_dim, feature_dim=1280, pad_idx=None):
        super().__init__()

        self.position_encoder = PositionalEmbedding(embed_dim, 0.1)

        self.decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=n_heads)
        self.decoder = nn.TransformerDecoder(decoder_layer=self.decoder_layer, num_layers=n_layers)

        self.embed_dim = embed_dim
        # The vocabulary is sorted, so the padding token is generally NOT at
        # index 0; the masks must use its real index.
        self.pad_idx = token_to_idx[PAD_TOKEN] if pad_idx is None else pad_idx
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.output_layer = nn.Linear(embed_dim, vocab_size)

        # Project to the constructor's embed_dim (not the global constant) so
        # the memory width always matches the decoder.
        self.feature_map_reduce = nn.Linear(feature_dim, embed_dim)

        self._initialize_weights()

    def _initialize_weights(self, param_range=0.1):
        """Uniformly initialise embedding/output weights; zero the output bias."""
        self.embedding.weight.data.uniform_(-param_range, param_range)
        self.output_layer.bias.data.zero_()
        self.output_layer.weight.data.uniform_(-param_range, param_range)

    def _create_masks(self, size, decoder_input):
        """Build the causal mask and the padding masks for ``decoder_input``.

        Returns:
            ``(causal_mask, pad_mask, pad_mask_bool)`` where ``causal_mask``
            is a float ``(size, size)`` matrix with ``-inf`` above the
            diagonal, ``pad_mask`` is 1.0 at real tokens and 0.0 at padding,
            and ``pad_mask_bool`` is ``True`` at padding positions.
        """
        target_device = decoder_input.device
        allowed = torch.tril(torch.ones(size, size, dtype=torch.bool, device=target_device))
        causal_mask = torch.zeros(size, size, device=target_device).masked_fill(~allowed, float('-inf'))

        pad_mask_bool = decoder_input == self.pad_idx
        pad_mask = (~pad_mask_bool).float()

        return causal_mask, pad_mask, pad_mask_bool

    def forward(self, image_features, decoder_input):
        """Decode caption logits.

        Args:
            image_features: ``(batch, ..., channels)`` feature maps; all axes
                between the first and last are flattened into patches.
            decoder_input: ``(batch, seq_len)`` token indices.

        Returns:
            ``(logits, pad_mask)`` with logits of shape
            ``(seq_len, batch, vocab_size)`` and ``pad_mask`` of shape
            ``(batch, seq_len)``.
        """
        batch_size, *_, channels_size = image_features.size()
        # Flatten to (batch, num_patches, channels).  The result must stay
        # 3-D: the permute below reorders exactly three axes.
        image_features = image_features.reshape(batch_size, -1, channels_size)
        image_features = self.feature_map_reduce(image_features)
        image_features = image_features.permute(1, 0, 2)  # (num_patches, batch_size, embed_dim)

        decoder_input_embed = self.embedding(decoder_input) * math.sqrt(self.embed_dim)
        decoder_input_embed = self.position_encoder(decoder_input_embed)

        decoder_input_embed = decoder_input_embed.permute(1, 0, 2)  # (seq_len, batch_size, embed_dim)
        # Masks are created directly on the input's device.
        causal_mask, pad_mask, pad_mask_bool = self._create_masks(decoder_input.size(1), decoder_input)

        decoder_output = self.decoder(tgt=decoder_input_embed, memory=image_features, tgt_mask=causal_mask, tgt_key_padding_mask=pad_mask_bool)

        output = self.output_layer(decoder_output)
        return output, pad_mask

# Training & Validation Loop

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    For every batch the loss returned by ``criterion`` is multiplied by the
    padding mask returned by the model, summed and divided by the mask sum.
    The mask only has an effect when ``criterion`` returns one loss value
    per token.

    Returns:
        The mean of the per-batch losses.
    """
    model.train()
    running_loss = 0.0

    progress_bar = tqdm(train_loader, desc="Training")
    for batch in progress_bar:
        image_features, tokens, target_tokens = (part.to(device) for part in batch)

        optimizer.zero_grad()

        logits, padding_mask = model(image_features, tokens)
        # Move the class axis to position 1, as the criterion expects.
        loss = criterion(logits.permute(1, 2, 0), target_tokens)

        batch_loss = (loss * padding_mask).sum() / padding_mask.sum()
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        running_loss += loss_value
        progress_bar.set_postfix(batch_loss=loss_value, refresh=True)

    return running_loss / len(train_loader)

In [None]:
def val_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    The per-batch loss is computed the same way as in training: criterion
    output times the model's padding mask, summed, divided by the mask sum.

    Returns:
        The mean of the per-batch losses.
    """
    model.eval()
    running_loss = 0.0

    progress_bar = tqdm(val_loader, desc="Validation")
    with torch.inference_mode():
        for batch in progress_bar:
            image_features, tokens, target_tokens = (part.to(device) for part in batch)

            logits, padding_mask = model(image_features, tokens)
            # Move the class axis to position 1, as the criterion expects.
            loss = criterion(logits.permute(1, 2, 0), target_tokens)

            batch_loss = (loss * padding_mask).sum() / padding_mask.sum()

            loss_value = batch_loss.item()
            running_loss += loss_value
            progress_bar.set_postfix(batch_loss=loss_value, refresh=True)

    return running_loss / len(val_loader)

# Model Training

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device):
    """Train for ``num_epochs`` epochs, checkpointing the best model.

    After each epoch the validation loss is passed to ``scheduler.step``,
    the current learning rate of every parameter group is printed, and the
    whole model is saved to ``"best_model"`` whenever the validation loss
    improves on the best value seen so far.
    """
    best_val_loss = float('inf')
    for epoch_number in range(1, num_epochs + 1):
        print(f"Epoch : [{epoch_number}/{num_epochs}]")

        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss = val_epoch(model, val_loader, criterion, device)

        print(f"Training Loss: {train_loss:.4f}")
        print(f"Validation Loss: {val_loss:.4f}")

        scheduler.step(val_loss)

        for current_lr in (group['lr'] for group in optimizer.param_groups):
            print(f"Current Learning Rate: {current_lr}")

        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            torch.save(model, "best_model")
            print("Model saved!")

    print("Training Complete.")

In [None]:
# reduction="none" keeps one loss value per token.  The training/validation
# loops multiply the loss by the padding mask and normalise by the mask sum;
# with the default mean reduction the loss is a single scalar, so that
# masking cancels out and padded positions are still averaged into the loss.
criterion = nn.CrossEntropyLoss(reduction="none")

model = ImageCaptioningModel(N_HEADS, N_LAYERS, vocab_size, EMBED_DIM).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by 0.3 after 2 epochs without improvement in the
# monitored validation loss.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=.3, patience=2)

In [None]:
# Build the training/validation loaders from the caption frame, the token
# lookup table and the precomputed feature maps.
train_loader, val_loader = get_dataloaders(df, token_to_idx, feature_maps, BATCH_SIZE)

In [None]:
# Run the full training loop; the best model is written to "best_model".
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, EPOCHS, device)

In [None]:
def generate_caption(K, image_name):
    """Display an image and print a caption sampled from the global ``model``.

    Decoding is autoregressive: at every step the next token is sampled from
    the ``K`` highest-scoring candidates (``K=1`` is greedy decoding) and
    written into the next input position.  Generation stops at the END token
    or once ``max_len`` positions are filled.

    Args:
        K: Number of top candidates to sample from at each step.
        image_name: Key into ``feature_maps`` and file name of the image
            inside the ``flickr30k_images`` directory.
    """
    image_path = os.path.join(CONFIG[ENVIRON]["IMAGES_DIR_ROOT"], "flickr30k_images", image_name)
    image = Image.open(image_path).convert("RGB")
    plt.imshow(image)

    model.eval()
    feature_map = feature_maps[image_name].to(device)
    # Cached maps are stored per image without a batch axis; the model
    # treats the first axis as the batch, so add one for this single image.
    if feature_map.ndimension() == 3:
        feature_map = feature_map.unsqueeze(0)

    # Start from an all-padding sequence whose first position is START.
    input_tokens = torch.full((1, max_len), token_to_idx[PAD_TOKEN], dtype=torch.long, device=device)
    input_tokens[0, 0] = token_to_idx[START_TOKEN]
    predicted_sentence = []

    with torch.no_grad():
        for eval_iter in range(0, max_len - 1):
            logits, _ = model(feature_map, input_tokens)
            # Scores for the token that follows position `eval_iter`.
            step_logits = logits[eval_iter, 0, :]

            top_logits, top_indices = torch.topk(step_logits, K)
            # Raw logits can be negative and are not valid sampling weights;
            # softmax turns them into positive probabilities.
            weights = torch.softmax(top_logits, dim=-1).tolist()
            next_word_index = random.choices(top_indices.tolist(), weights, k=1)[0]

            next_word = idx_to_token[next_word_index]
            input_tokens[:, eval_iter + 1] = next_word_index

            if next_word == END_TOKEN:
                break

            predicted_sentence.append(next_word)
    print("\n")
    print("Predicted caption : ")
    print(" ".join(predicted_sentence))

In [None]:
# The checkpoint is a fully pickled model object (written by torch.save(model, ...)),
# not a state dict, so weights-only loading (the default on recent PyTorch
# releases) must be disabled.  Only load checkpoint files you created
# yourself: unpickling can execute arbitrary code.  map_location makes a
# GPU-trained checkpoint loadable on a CPU-only machine.
model = torch.load("best_model", map_location=device, weights_only=False).to(device)

In [None]:
# Row counter into val_df for the demo cells below.
i = 0

In [None]:
# Each of the following cells advances the counter by one and captions that
# validation row with K=1 (always taking the top-scoring token).  Because the
# counter is incremented before use, row 0 is never shown.
i += 1
generate_caption(K=1, image_name=val_df.loc[i, "image_name"])

In [None]:
i += 1
generate_caption(K=1, image_name=val_df.loc[i, "image_name"])

In [None]:
i += 1
generate_caption(K=1, image_name=val_df.loc[i, "image_name"])

In [None]:
i += 1
generate_caption(K=1, image_name=val_df.loc[i, "image_name"])

In [None]:
i += 1
generate_caption(K=1, image_name=val_df.loc[i, "image_name"])

In [None]:
i += 1
generate_caption(K=1, image_name=val_df.loc[i, "image_name"])

In [None]:
i += 1
generate_caption(K=1, image_name=val_df.loc[i, "image_name"])