In [None]:
pip install transformers

In [None]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import image as mpimg
import torch
from torch import nn
import torchvision
from torch.utils.data import Dataset as Dataset
from torch.utils.data import DataLoader
from torchvision import transforms as tt
from PIL import Image
import torchtext
from torchtext.data import get_tokenizer
from transformers import BertTokenizer
from torchinfo import summary
import torch.optim as optim
import torch.nn.functional as F
from tqdm.notebook import tqdm

In [None]:
# Read every line of the caption file with an explicit encoding so the result does not
# depend on the platform default. The first line is skipped (presumably the CSV header).
with open('/kaggle/input/flickr8k/captions.txt', encoding='utf-8') as f:
    contents = f.readlines()[1:]

In [None]:
# Tokenizer shared by the rest of the notebook (caption encoding and decoding).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Round-trip demo: encode a sentence padded to 20 ids, then map the ids back to tokens.
token_ids = tokenizer.encode(
    "You can now install TorchText using pip!",
    padding="max_length",
    max_length=20,
)
back_to_words = tokenizer.convert_ids_to_tokens(token_ids)

In [None]:
# Parse each "<image_name>,<caption>" line into an image name and a fixed-length list of
# token ids (padded / truncated to 30).
image_name_list = []
caption_list = []

for line in contents:
    # Strip only the line terminator; slicing off the last character would eat a real
    # character when the final line has no trailing newline.
    line = line.rstrip('\r\n')
    if not line:
        continue  # a blank line would otherwise produce a row with an empty image name

    # Split on the FIRST comma only so commas inside the caption text are preserved.
    image_name, _, caption_text = line.partition(',')
    caption = tokenizer.encode(caption_text, padding='max_length', max_length=30, truncation=True)
    image_name_list.append(image_name)
    caption_list.append(caption)

caption_df = pd.DataFrame({'Image_name': image_name_list, 'Caption': caption_list})

In [None]:
# Keep one row per image (drop_duplicates keeps the first occurrence by default).
# Assigning the result instead of mutating in place; the index is rebuilt so it is
# contiguous after rows are removed.
caption_df = caption_df.drop_duplicates(subset='Image_name').reset_index(drop=True)
caption_df

In [None]:
# Quick visual check: read one image from the dataset folder and display it.
sample_image_path = "/kaggle/input/flickr8k/Images/2544246151_727427ee07.jpg"
image = mpimg.imread(sample_image_path)
plt.imshow(image)
plt.show()

In [None]:
def show_image(img, title=None):
    """Display a normalized (3, H, W) image tensor with matplotlib.

    The per-channel normalization (std 0.229/0.224/0.225, mean 0.485/0.456/0.406)
    is undone on a copy, so the tensor passed in is left untouched and can be
    reused or shown again.

    :param img: tensor of shape (3, H, W); may live on any device
    :param title: optional string used as the plot title
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Work out-of-place on a CPU float copy: the original code wrote into `img`
    # itself, corrupting the caller's tensor on every call.
    restored = img.detach().cpu().float() * std + mean

    # imshow expects floats in [0, 1]; clamp values pushed slightly outside that range.
    restored = restored.clamp(0.0, 1.0).numpy().transpose((1, 2, 0))

    plt.imshow(restored)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

# Data preparation

In [None]:
# Both splits use the preprocessing preset that ships with the ResNet-101
# IMAGENET1K_V2 weights; each Compose wraps its own preset instance.
_resnet_weights = torchvision.models.ResNet101_Weights.IMAGENET1K_V2

train_transform_trivial_augment = tt.Compose([_resnet_weights.transforms()])

test_transform = tt.Compose([_resnet_weights.transforms()])

In [None]:
class ImageToCaption(Dataset):
    """Dataset yielding (transformed image, caption token-id tensor) pairs.

    :param df: DataFrame with an 'Image_name' column (file names) and a 'Caption'
               column (sequences of token ids)
    :param transform: callable applied to the loaded PIL image
    """

    def __init__(self, df, transform):
        self.data_dict = df
        self.transform = transform

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.data_dict)

    def __getitem__(self, idx):
        """Load the idx-th image from disk and return it with its caption ids."""
        img_name = self.data_dict['Image_name'].iloc[idx]
        label = self.data_dict['Caption'].iloc[idx]

        # The context manager closes the file handle as soon as the pixels are read.
        # convert("RGB") guarantees three channels, so grayscale / RGBA files cannot
        # break the 3-channel normalization downstream.
        with Image.open("/kaggle/input/flickr8k/Images/" + img_name) as raw_img:
            img = raw_img.convert("RGB")

        return self.transform(img), torch.tensor(label)

In [None]:
# Positional split: the first 7000 rows train the model, the remainder is held out.
train_data, test_data = caption_df.iloc[:7000], caption_df.iloc[7000:]

In [None]:
# Wrap each split in the caption dataset with its matching preprocessing pipeline.
train_data_custom = ImageToCaption(
    df=train_data,
    transform=train_transform_trivial_augment,
)
test_data_custom = ImageToCaption(
    df=test_data,
    transform=test_transform,
)

In [None]:
BATCH_SIZE = 32

# Training batches are shuffled; evaluation batches keep dataset order.
# num_workers is left at the DataLoader default.
train_dataloader_custom = DataLoader(train_data_custom, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader_custom = DataLoader(test_data_custom, batch_size=BATCH_SIZE, shuffle=False)

train_dataloader_custom, test_dataloader_custom

In [None]:
# Pull a single training batch and report the tensor shapes as a sanity check.
first_batch = next(iter(train_dataloader_custom))
img_data, caption = first_batch

print(f"Image shape: {img_data.shape} -> [batch_size, color_channels, height, width]\n")
print(f"Label shape: {caption.shape}")

# Model

In [None]:
class Encoder(nn.Module):
    """ResNet-101 feature extractor producing a flat grid of image feature vectors."""

    def __init__(self, encoded_image_size=14):
        super().__init__()
        self.enc_image_size = encoded_image_size

        backbone = torchvision.models.resnet101(weights=torchvision.models.ResNet101_Weights.DEFAULT)

        # Keep everything except the last two children (pooling and linear layers),
        # since the network is not used for classification here.
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])

        # Pool to a fixed spatial grid so input images may vary in size.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

        self.fine_tune()

    def forward(self, images):
        """
        Encode a batch of images.

        :param images: tensor of dimensions (batch_size, 3, image_size, image_size)
        :return: tensor of dimensions
                 (batch_size, encoded_image_size * encoded_image_size, channels)
        """
        feature_map = self.adaptive_pool(self.resnet(images))  # (batch, channels, enc, enc)
        feature_map = feature_map.permute(0, 2, 3, 1)           # (batch, enc, enc, channels)
        # Merge the two spatial axes into one "location" axis.
        return feature_map.view(feature_map.size(0), -1, feature_map.size(-1))

    def fine_tune(self, fine_tune=False):
        """
        Freeze the backbone, optionally leaving its later children trainable.

        :param fine_tune: when True, children of the backbone from index 5 onward
                          receive gradients; earlier children stay frozen regardless
        """
        for index, child in enumerate(self.resnet.children()):
            trainable = fine_tune if index >= 5 else False
            for param in child.parameters():
                param.requires_grad = trainable
                
class Attention(nn.Module):
    """Additive attention over image locations, conditioned on the decoder state."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()

        self.attention_dim = attention_dim

        # Projections of the decoder state (W) and of the image features (U) into a
        # shared attention space, followed by a scalar scoring layer (A).
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)

        self.A = nn.Linear(attention_dim, 1)

        self.relu = nn.ReLU()

    def forward(self, features, hidden_state):
        """
        Compute attention weights and the attended context vector.

        :param features: (batch_size, num_locations, encoder_dim)
        :param hidden_state: (batch_size, decoder_dim)
        :return: (alpha, context) where alpha is (batch_size, num_locations) and sums
                 to 1 over locations, and context is (batch_size, encoder_dim)
        """
        projected_features = self.U(features)     # (batch, locations, attention_dim)
        projected_state = self.W(hidden_state)    # (batch, attention_dim)

        # Broadcast the state over every location before the non-linearity.
        mixed = self.relu(projected_features + projected_state.unsqueeze(1))

        scores = self.A(mixed).squeeze(2)         # (batch, locations)
        alpha = F.softmax(scores, dim=1)          # (batch, locations)

        # Weighted sum of the location features.
        context = (features * alpha.unsqueeze(2)).sum(dim=1)  # (batch, encoder_dim)

        return alpha, context
        
class DecoderRNN(nn.Module):
    """LSTM caption decoder that attends over image features at every step.

    :param embed_size: dimension of the token embeddings
    :param vocab_size: number of token ids the embedding / output layer cover
    :param attention_dim: hidden size of the attention scorer
    :param encoder_dim: size of each image feature vector
    :param decoder_dim: size of the LSTM hidden / cell state
    :param drop_prob: dropout probability applied to the hidden state before the
                      output layer
    """

    def __init__(self, embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()

        # save the model params
        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)

        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.lstm_cell = nn.LSTMCell(embed_size + encoder_dim, decoder_dim, bias=True)
        # NOTE(review): f_beta is created but not used by forward() or generate_caption();
        # it is kept so the set of parameters (and state_dict keys) stays unchanged.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)

        self.fcn = nn.Linear(decoder_dim, vocab_size)
        self.drop = nn.Dropout(drop_prob)

    @staticmethod
    def _default_tokenizer():
        """Return the module-level `tokenizer`, looked up at call time."""
        return tokenizer

    def forward(self, features, captions):
        """
        Teacher-forced decoding.

        :param features: (batch_size, num_locations, encoder_dim) image features
        :param captions: (batch_size, caption_len) token ids
        :return: (preds, alphas) with preds of shape
                 (batch_size, caption_len - 1, vocab_size) and alphas of shape
                 (batch_size, caption_len - 1, num_locations)
        """
        # vectorize the caption
        embeds = self.embedding(captions)

        # Initialize LSTM state
        h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)

        # The last token is never fed as input, so there is one step fewer than tokens.
        seq_length = captions.size(1) - 1
        batch_size = captions.size(0)
        num_features = features.size(1)

        # Allocate on the device of the inputs instead of relying on a global `device`.
        device = features.device
        preds = torch.zeros(batch_size, seq_length, self.vocab_size, device=device)
        alphas = torch.zeros(batch_size, seq_length, num_features, device=device)

        for s in range(seq_length):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))

            output = self.fcn(self.drop(h))

            preds[:, s] = output
            alphas[:, s] = alpha

        return preds, alphas

    def generate_caption(self, features, max_len=20, tokenizer=None):
        """
        Greedy caption generation for a single image.

        Runs with dropout disabled and without autograd, then restores the module's
        previous train/eval mode. Only a batch of one image is supported: the start
        token is created with batch size 1 and `.item()` is called on the prediction.

        :param features: (1, num_locations, encoder_dim) image features
        :param max_len: maximum number of tokens to generate
        :param tokenizer: object providing `vocab['[CLS]']` and `convert_ids_to_tokens`;
                          when None, the module-level `tokenizer` is used
        :return: (tokens, alphas) - list of generated token strings (including the
                 final "[SEP]" if one was produced) and a list of per-step attention
                 maps as numpy arrays
        """
        if tokenizer is None:
            tokenizer = self._default_tokenizer()

        device = features.device
        batch_size = features.size(0)

        was_training = self.training
        self.eval()  # dropout must be off, otherwise greedy decoding is randomized
        try:
            with torch.no_grad():
                h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)

                alphas = []
                tokens = []

                # starting input
                word = torch.tensor(tokenizer.vocab['[CLS]']).view(1, -1).to(device)
                embeds = self.embedding(word)

                for _ in range(max_len):
                    alpha, context = self.attention(features, h)

                    # store the alpha score
                    alphas.append(alpha.cpu().detach().numpy())

                    lstm_input = torch.cat((embeds[:, 0], context), dim=1)
                    h, c = self.lstm_cell(lstm_input, (h, c))
                    output = self.fcn(self.drop(h))
                    output = output.view(batch_size, -1)

                    # select the highest-scoring token
                    predicted_word_idx = output.argmax(dim=1)

                    # save the generated word
                    token = tokenizer.convert_ids_to_tokens(predicted_word_idx.item())
                    tokens.append(token)

                    # stop once the end-of-sequence token is produced
                    if token == "[SEP]":
                        break

                    # feed the generated word back in as the next input
                    embeds = self.embedding(predicted_word_idx.unsqueeze(0))
        finally:
            self.train(was_training)

        return tokens, alphas

    def init_hidden_state(self, encoder_out):
        """Derive the initial (h, c) from the mean image feature vector."""
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c
    
class EncoderDecoder(nn.Module):
    """Image captioning model: image encoder followed by an attention decoder.

    :param embed_size: token embedding dimension passed to the decoder
    :param vocab_size: vocabulary size passed to the decoder
    :param attention_dim: attention hidden size passed to the decoder
    :param encoder_dim: size of each encoder feature vector
    :param decoder_dim: decoder LSTM state size
    :param drop_prob: decoder dropout probability
    """

    def __init__(self, embed_size=300, vocab_size=30_000, attention_dim=256, encoder_dim=2048, decoder_dim=512, drop_prob=0.3):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = DecoderRNN(
            embed_size=embed_size,
            vocab_size=vocab_size,
            attention_dim=attention_dim,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            # Previously dropped on the floor, so the decoder always used its own default.
            drop_prob=drop_prob,
        )

    def forward(self, images, captions):
        """Encode `images`, then decode against `captions`; returns the decoder's output."""
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

# Training

In [None]:
# Fall back to CPU when no GPU is available instead of failing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 0.0004
embed_size = 512
# Size the embedding / output layer from the tokenizer itself. A hard-coded 30_000 is
# smaller than the tokenizer's vocabulary, so any larger token id would index out of
# range in nn.Embedding and in the cross-entropy targets.
vocab_size = len(tokenizer)
attention_dim = 512
encoder_dim = 2048
decoder_dim = 512
drop_prob = 0.5
epochs = 30
output_interval = 50  # run validation / logging every N training batches
alpha_c = 1.          # weight of the attention regularization term in the loss
grad_clip = 5.        # gradients are clamped to [-grad_clip, grad_clip]

model = EncoderDecoder(embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob).to(device)
# Padding positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.vocab['[PAD]']).to(device)
optimizer = optim.RMSprop(model.parameters(), lr=lr)
summary(model)

In [None]:
def _caption_loss(outputs, attentions, caption):
    """Cross-entropy on next-token predictions plus the attention regularizer.

    :param outputs: (batch, seq_len - 1, vocab) scores returned by the model
    :param attentions: (batch, seq_len - 1, num_locations) attention weights
    :param caption: (batch, seq_len) token ids; targets are the ids shifted by one
    :return: scalar loss tensor
    """
    targets = caption[:, 1:]
    ce = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
    # Encourages the attention over time steps to sum to 1 per location
    # (mentioned in the paper).
    reg = alpha_c * ((1. - attentions.sum(dim=1)) ** 2).mean()
    return ce + reg


dataiter = iter(test_dataloader_custom)

for epoch in tqdm(range(1, epochs + 1)):
    for train_idx, (image, caption) in enumerate(train_dataloader_custom):
        image, caption = image.to(device), caption.to(device)
        optimizer.zero_grad()
        outputs, attentions = model(image, caption)

        loss = _caption_loss(outputs, attentions, caption)
        loss.backward()

        # Element-wise gradient clipping to [-grad_clip, grad_clip].
        torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip)

        optimizer.step()

        if (train_idx + 1) % output_interval == 0:
            # Evaluate on the whole held-out set.
            model.eval()
            with torch.no_grad():
                val_loss = 0.0
                count = 0
                # Separate names so the training batch variables are not shadowed.
                for val_image, val_caption in test_dataloader_custom:
                    val_image, val_caption = val_image.to(device), val_caption.to(device)
                    val_outputs, val_attentions = model(val_image, val_caption)
                    # .item() on the full loss keeps val_loss a plain float
                    # (previously the attention term made it a tensor).
                    val_loss += _caption_loss(val_outputs, val_attentions, val_caption).item()
                    count += 1
                val_loss /= max(count, 1)

            print(f"Epoch: {epoch} | Index: {train_idx+1} | Training Loss : {loss.item()} | Validation Loss : {val_loss}")
            model.train()

    # Show one generated caption per epoch; restart the iterator once the held-out
    # loader is exhausted instead of raising StopIteration.
    try:
        img, _ = next(dataiter)
    except StopIteration:
        dataiter = iter(test_dataloader_custom)
        img, _ = next(dataiter)

    # Generate with dropout off and without building an autograd graph.
    model.eval()
    with torch.no_grad():
        features = model.encoder(img[0:1].to(device))
        caps, alphas = model.decoder.generate_caption(features, max_len=30, tokenizer=tokenizer)
    model.train()

    caption = ' '.join(caps)
    show_image(img[0], title=caption)
            