In [1]:
# Required imports
import numpy as np 
import pandas as pd 
import os
from fastai.vision.all import *
from fastai.text.all import *
from pathlib import Path
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision
import torchtext
from torchtext.data import get_tokenizer   # for tokenization
from collections import Counter     # for tokenizer
import torchvision.transforms as T
import torchvision.models as models
import matplotlib.pyplot as plt
import PIL
from PIL import Image
from nltk.translate import bleu
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import corpus_bleu

# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Flickr8k inputs: the captions file and the image directory.
# NOTE(review): captions_path is absolute while images_path is relative to the working
# directory; both point at /kaggle/input/flickr8k only when run from /kaggle/working.
captions_path = '/kaggle/input/flickr8k/captions.txt'
images_path = "../input/flickr8k/Images/"

# Seed the global RNGs up front so weight init, dropout and SubsetRandomSampler ordering
# are reproducible across Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

tokenizer = get_tokenizer("basic_english")
token_counter = Counter()


    

In [3]:
class my_dictionary(dict):
    """A dict that collects every value added under a key into a list (a simple multimap)."""

    def __init__(self, *args, **kwargs):
        # The original body was `self = dict()`, which only rebinds the local name and
        # never initialises the instance; forward to dict so initial contents also work.
        super().__init__(*args, **kwargs)

    def add(self, key, value):
        """Append ``value`` to the list stored under ``key``, creating the list on first use."""
        self.setdefault(key, []).append(value)
        
# Map each image id (column 0 of df) to all of its captions (column 1), each wrapped in
# <start>/<end> markers and split on whitespace into a word list.
# NOTE(review): textVocab tokenizes with torchtext's "basic_english" (lowercased,
# punctuation split off), while these word lists keep the raw case and punctuation —
# if they serve as BLEU references, consider tokenizing them the same way.
descriptors = my_dictionary()

for img_id, caption_text in zip(df.iloc[:, 0], df.iloc[:, 1]):
    sentence = ("<start> " + caption_text + " <end>").split()
    descriptors.add(img_id, sentence)

In [None]:
def show_image(img, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized channel-first image tensor with matplotlib.

    The per-channel normalization is undone (``x * std + mean``) on a private copy, so the
    caller's tensor is left untouched.

    Args:
        img: 3-channel image tensor laid out as (C, H, W).
        title: optional figure title.
        mean: per-channel means used when the image was normalized (defaults are the
            values previously hard-coded here — the common ImageNet statistics).
        std: per-channel standard deviations used when the image was normalized.
    """
    # Detached CPU copy: the previous version edited `img` in place (showing the same
    # tensor twice un-normalized it twice), and .numpy() needs a CPU tensor without grad.
    img = img.detach().cpu().clone()
    for channel, (channel_mean, channel_std) in enumerate(zip(mean, std)):
        img[channel] = img[channel] * channel_std + channel_mean

    # matplotlib expects (H, W, C).
    img = img.numpy().transpose((1, 2, 0))
    plt.imshow(img)
    if title is not None:
        plt.title(title)

    plt.pause(0.001)  # give the GUI event loop a moment to draw the figure

In [4]:
class textVocab:
    """Bidirectional token <-> index vocabulary built from raw caption strings.

    Indices 0-3 are reserved for "<PAD>", "<start>", "<end>" and "<UNK>"; every other
    token gets the next free index at the moment its count first reaches ``min_freq``.
    """

    def __init__(self, min_freq=1, tokenizer=None):
        """
        Args:
            min_freq: occurrences (accumulated across build_vocab calls) a token needs
                before it is added to the vocabulary.
            tokenizer: callable mapping a string to a list of tokens; defaults to
                torchtext's "basic_english" tokenizer.
        """
        self.itos = {0: "<PAD>", 1: "<start>", 2: "<end>", 3: "<UNK>"}
        self.stoi = {token: idx for idx, token in self.itos.items()}
        self.min_freq = min_freq
        self.tokenizer = tokenizer if tokenizer is not None else get_tokenizer("basic_english")
        # Per-instance counts. build_vocab used to update the module-level token_counter,
        # so counts leaked between vocabularies and accumulated across cell re-runs.
        self.token_counter = Counter()

    def __len__(self):
        return len(self.itos)

    def tokenize(self, text):
        return self.tokenizer(text)

    def numericalize(self, text):
        """Map ``text`` to token ids, using the "<UNK>" id for unknown tokens."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenize(text)]

    def build_vocab(self, sentence_list):
        """Count tokens in ``sentence_list`` and add those that reach ``min_freq``.

        Safe to call repeatedly: new tokens continue from the current size instead of
        restarting at a fixed index (which used to overwrite earlier entries).
        """
        next_idx = len(self.itos)
        for sentence in sentence_list:
            tokens = self.tokenizer(sentence)
            self.token_counter.update(tokens)
            for token in tokens:
                if self.token_counter[token] >= self.min_freq and token not in self.stoi:
                    self.stoi[token] = next_idx
                    self.itos[next_idx] = token
                    next_idx += 1
                    
# Pretrained Inception-v3 backbone (its weights are downloaded on first use, as the cell
# output shows); the Encoder passes it to VGG16Extractor in a later cell.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in favour of
# `weights=models.Inception_V3_Weights.DEFAULT` — confirm the installed version.
inception = models.inception_v3(pretrained=True)

Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

In [9]:
# One random sampler per split, built from the training and validation index lists
# (train_indices / val_indices come from a cell not shown here). The DataLoaders in the
# next cell take these samplers in place of shuffle=True.
train_sampler, valid_sampler = (
    SubsetRandomSampler(subset) for subset in (train_indices, val_indices)
)

In [10]:
def _make_caption_loader(sampler):
    """Build a DataLoader over `dataset` that draws indices from `sampler`.

    Each loader gets its own Collate_fn instance. shuffle stays False because a custom
    sampler already decides the order (DataLoader rejects shuffle=True with a sampler).
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=Collate_fn(pad_value=pad_value, batch_first=True),
        sampler=sampler,
    )


dls = _make_caption_loader(train_sampler)
validation_loader = _make_caption_loader(valid_sampler)

In [None]:
# Sanity check: pull one training batch and decode its first caption back into words.
dlsItr = iter(dls)
batch = next(dlsItr)
imgs, captions, img_ids = batch

# The previous `for i in range(batch_size): ... break` only ever looked at sample 0,
# so index it directly. img/caption are kept for later inspection (e.g. show_image).
img, caption = imgs[0], captions[0]
sentence = [dataset.vocab.itos[token] for token in caption.tolist()]

# A caption without an "<end>" token (e.g. if it was truncated) made .index() raise
# ValueError; fall back to the full length and drop padding instead.
end_indx = sentence.index('<end>') if '<end>' in sentence else len(sentence)
sentence = ' '.join(word for word in sentence[1:end_indx] if word != '<PAD>')
sentence

In [12]:
# Pretrained VGG-16 backbone. torchvision exposes the factory as lowercase `vgg16`;
# `models.VGG16` does not exist and raised AttributeError. `pretrained=True` matches the
# Inception-v3 load earlier in the notebook.
VGG16 = models.vgg16(pretrained=True)

class Encoder(nn.Module):
    """CNN feature extractor that returns one feature vector per spatial location.

    ``forward`` moves the channel axis of the extractor's 4-D output last and merges the
    two spatial axes, turning a (batch, channels, h, w) map into (batch, h * w, channels).
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the extractor is named VGG16Extractor but receives the Inception-v3
        # model (`inception`), not the `VGG16` model built in this cell — confirm which
        # backbone is intended. The attribute name fixes the state_dict keys, so keep it.
        self.my_inception = VGG16Extractor(inception)

    def forward(self, images):
        feature_map = self.my_inception(images)
        channels_last = feature_map.permute(0, 2, 3, 1)
        batch = channels_last.size(0)
        channels = channels_last.size(-1)
        return channels_last.view(batch, -1, channels)

In [13]:
class Attention(nn.Module):
    """Additive attention that pools encoder feature vectors using the decoder state.

    Args:
        encoder_dim: size of each encoder feature vector.
        decoder_dim: size of the decoder hidden state.
        attention_dim: size of the shared projection space the two are mapped into.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.attention_dim = attention_dim
        # Layers are registered in the same order as before, so parameter initialisation
        # order and state_dict keys are unchanged.
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_states):
        """Score each feature vector against the decoder state and return their blend.

        Args:
            features: encoder outputs, assumed (batch, num_features, encoder_dim).
            hidden_states: decoder state, assumed (batch, decoder_dim).

        Returns:
            (alpha, weighted_encoding): softmax weights over dim 1 of ``features`` and the
            ``alpha``-weighted sum of ``features`` over that same dimension.
        """
        projected_feats = self.encoder_att(features)
        projected_state = self.decoder_att(hidden_states).unsqueeze(1)  # broadcast over features
        energy = self.full_att(torch.tanh(projected_feats + projected_state)).squeeze(2)
        alpha = F.softmax(energy, dim=1)
        context = (features * alpha.unsqueeze(2)).sum(dim=1)
        return alpha, context

In [14]:
class Decoder(nn.Module):
    """LSTM caption decoder that attends over encoder features at every step.

    Args:
        embed_sz: word-embedding size.
        vocab_sz: vocabulary size (width of the output logits).
        att_dim: hidden size of the attention projection.
        enc_dim: size of each encoder feature vector.
        dec_dim: LSTM hidden/cell state size.
        drop_prob: dropout probability applied to the hidden state before the output layer.
    """

    def __init__(self, embed_sz, vocab_sz, att_dim, enc_dim, dec_dim, drop_prob=0.3):
        super().__init__()
        self.vocab_sz = vocab_sz
        self.att_dim = att_dim
        self.dec_dim = dec_dim
        self.embedding = nn.Embedding(vocab_sz, embed_sz)
        self.attention = Attention(enc_dim, dec_dim, att_dim)
        self.init_h = nn.Linear(enc_dim, dec_dim)
        self.init_c = nn.Linear(enc_dim, dec_dim)
        # Each step consumes the word embedding concatenated with the attention context.
        self.lstm_cell = nn.LSTMCell(embed_sz + enc_dim, dec_dim, bias=True)
        # NOTE(review): f_beta is never used in forward/generate_caption (presumably meant
        # as an attention gate); it is kept so existing checkpoints still load.
        self.f_beta = nn.Linear(dec_dim, enc_dim)
        self.fcn = nn.Linear(dec_dim, vocab_sz)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        """Teacher-forced decoding over a batch of captions.

        Args:
            features: encoder output indexed as (batch, num_features, enc_dim).
            captions: token ids indexed as (batch, seq_len).

        Returns:
            preds: logits of shape (batch, seq_len - 1, vocab_sz); step i is predicted
                after feeding token i, so the last token is never fed in.
            alphas: attention weights of shape (batch, seq_len - 1, num_features).
        """
        embeds = self.embedding(captions)
        h, c = self.init_hidden_state(features)
        cap_len = captions.size(1) - 1
        batch_sz = captions.size(0)
        num_features = features.size(1)
        # Allocate outputs on the features' device. They used to default to the CPU, so on
        # GPU every step was copied to the host and the logits came back on the wrong device.
        preds = torch.zeros(batch_sz, cap_len, self.vocab_sz, device=features.device)
        alphas = torch.zeros(batch_sz, cap_len, num_features, device=features.device)
        for i in range(cap_len):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, i], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            preds[:, i] = self.fcn(self.drop(h))
            alphas[:, i] = alpha
        return preds, alphas

    def generate_caption(self, features, max_len=20, vocab=None):
        """Greedy decoding for a single image.

        Args:
            features: encoder output for one image, indexed as (1, num_features, enc_dim);
                the ``.item()`` calls below only work for a batch of one.
            max_len: maximum number of tokens generated after "<start>".
            vocab: vocabulary exposing ``stoi``/``itos`` (such as textVocab); required.

        Returns:
            (words, alphas): the decoded tokens, starting with "<start>" and ending with
            "<end>" when it was produced, plus one attention-weight array per step.

        Raises:
            ValueError: if ``vocab`` is not given.
        """
        if vocab is None:
            raise ValueError("generate_caption requires a vocab to map token ids to words")
        batch_sz = features.size(0)
        h, c = self.init_hidden_state(features)
        alphas = []
        captions = [vocab.stoi['<start>']]
        # Build the start token on the model's device; a CPU index tensor breaks the
        # embedding lookup when the model runs on GPU.
        word = torch.tensor(vocab.stoi['<start>'], device=features.device).view(1, -1)
        embeds = self.embedding(word)
        for _ in range(max_len):
            alpha, weighted_encoding = self.attention(features, h)
            alphas.append(alpha.cpu().detach().numpy())
            lstm_input = torch.cat((embeds[:, 0], weighted_encoding), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))  # dropout is a no-op once the model is in eval()
            output = output.view(batch_sz, -1)
            pred_word_idx = output.argmax(dim=1)
            captions.append(pred_word_idx.item())
            if vocab.itos[pred_word_idx.item()] == '<end>':
                break
            embeds = self.embedding(pred_word_idx.unsqueeze(0))
        return [vocab.itos[idx] for idx in captions], alphas

    def init_hidden_state(self, encoder_out):
        """Initial (h, c) for the LSTM, projected from the mean of ``encoder_out`` over dim 1."""
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

In [15]:
class EncoderDecoder(nn.Module):
    """Image-captioning model: an ``Encoder`` feeding an attention LSTM ``Decoder``.

    Args:
        embed_sz, vocab_sz, att_dim, enc_dim, dec_dim: forwarded to ``Decoder``.
        drop_prob: decoder dropout probability.
    """

    def __init__(self, embed_sz, vocab_sz, att_dim, enc_dim, dec_dim, drop_prob=0.3):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(
            embed_sz=embed_sz,
            vocab_sz=vocab_sz,
            att_dim=att_dim,
            enc_dim=enc_dim,
            dec_dim=dec_dim,
            # Previously not forwarded, so the decoder always used its own 0.3 default.
            drop_prob=drop_prob,
        )

    def forward(self, images, captions):
        """Encode ``images`` and teacher-force ``captions``; returns the decoder's outputs."""
        features = self.encoder(images)
        return self.decoder(features, captions)