In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models, datasets
import torchvision.transforms as T
from torch.utils.data import DataLoader,Dataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch import optim
from torch.nn.utils.rnn import pad_sequence

In [2]:
class EncoderCNN(nn.Module):
    """Frozen ResNet-50 backbone that turns images into a set of spatial feature vectors.

    forward(images) returns the last convolutional feature map flattened over its
    spatial positions: (batch, C, H, W) -> (batch, H*W, C).
    """

    def __init__(self):
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        # The backbone is a fixed feature extractor: its weights never receive gradients.
        for param in resnet.parameters():
            param.requires_grad_(False)

        # Drop the final two children (global average pool and classifier) so the
        # spatial feature map is kept.
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

    def train(self, mode=True):
        """Set training mode on this module while keeping the frozen backbone in eval mode.

        requires_grad_(False) only stops weight updates; BatchNorm layers in training
        mode still replace their running mean/variance with (small-)batch statistics,
        which would silently corrupt the pretrained features. Pinning the backbone to
        eval mode keeps those statistics fixed.

        Returns:
            self, as nn.Module.train does.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        features = self.resnet(images)            # (batch, C, H, W)
        features = features.permute(0, 2, 3, 1)   # (batch, H, W, C)
        return features.flatten(1, 2)             # (batch, H*W, C)

In [3]:
class Attention(nn.Module):
    """Additive attention over a set of encoder feature vectors.

    Scores each feature vector against the decoder hidden state with
    A(tanh(U(features) + W(hidden))) and returns the softmax weights together
    with the weighted sum (context vector) of the features.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.attention_dim = attention_dim
        # Submodules are registered in the order W, U, A (state_dict keys / init order).
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)
        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        """Return (alpha, context).

        alpha has one weight per feature vector (softmax over dim 1); context is
        features weighted by alpha and summed over dim 1.
        """
        projected_features = self.U(features)
        projected_hidden = self.W(hidden_state).unsqueeze(1)  # broadcast over feature positions
        energy = self.A(torch.tanh(projected_features + projected_hidden)).squeeze(2)
        alpha = energy.softmax(dim=1)
        context = (features * alpha.unsqueeze(2)).sum(dim=1)
        return alpha, context
        

In [4]:
class DecoderRNN(nn.Module):
    """LSTM caption decoder with soft attention over encoder feature vectors.

    Args:
        embed_size: word-embedding width.
        vocab_size: number of output classes (vocabulary entries).
        attention_dim: hidden width of the attention scorer.
        encoder_dim: width of each encoder feature vector.
        decoder_dim: LSTM hidden/cell size.
        drop_prob: dropout applied to the hidden state before the output layer.
    """

    def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        super().__init__()

        # Save the model hyper-parameters.
        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.attention = Attention(encoder_dim,decoder_dim,attention_dim)

        # Map the mean encoder feature to the initial LSTM hidden and cell states.
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        # Each step consumes [word embedding ; attention context].
        self.lstm_cell = nn.LSTMCell(embed_size+encoder_dim,decoder_dim,bias=True)
        # NOTE(review): f_beta is registered (so it appears in parameters/state_dict)
        # but is never used by forward or generate_caption.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)

        self.fcn = nn.Linear(decoder_dim,vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        """Teacher-forced decoding.

        Args:
            features: encoder output, (batch, num_features, encoder_dim).
            captions: token ids, (batch, seq_len).

        Returns:
            preds: (batch, seq_len - 1, vocab_size) logits; step s predicts token s + 1.
            alphas: (batch, seq_len - 1, num_features) attention weights per step.
        """
        embeds = self.embedding(captions)
        h, c = self.init_hidden_state(features)  # (batch, decoder_dim) each

        batch_size = captions.size(0)
        seq_length = captions.size(1) - 1  # the last token is never fed as input
        num_features = features.size(1)

        # Allocate on the same device/dtype as the inputs instead of a global `device`.
        preds = features.new_zeros(batch_size, seq_length, self.vocab_size)
        alphas = features.new_zeros(batch_size, seq_length, num_features)

        for s in range(seq_length):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            preds[:, s] = self.fcn(self.drop(h))
            alphas[:, s] = alpha

        return preds, alphas

    @torch.no_grad()
    def generate_caption(self,features,max_len=20,vocab=None):
        """Greedy decoding of a single image.

        Args:
            features: encoder output for ONE image, (1, num_features, encoder_dim).
            max_len: maximum number of generated tokens.
            vocab: object with `wtoi` (word -> id) and `itow` (id -> word) mappings.

        Returns:
            (words, alphas): the generated words (including "<EOS>" if it was produced)
            and one numpy attention-weight array per generated step.

        Raises:
            ValueError: if `vocab` is None or `features` holds more than one image.

        Dropout is applied as in training; call `.eval()` first for deterministic output.
        """
        if vocab is None:
            raise ValueError("generate_caption requires a vocab with wtoi/itow mappings")
        if features.size(0) != 1:
            raise ValueError("generate_caption decodes one image at a time; got batch of %d" % features.size(0))

        h, c = self.init_hidden_state(features)
        alphas = []

        word = torch.tensor(vocab.wtoi['<SOS>'], device=features.device).view(1, -1)
        embeds = self.embedding(word)
        captions = []
        for _ in range(max_len):
            alpha, context = self.attention(features, h)
            alphas.append(alpha.cpu().detach().numpy())
            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h)).view(1, -1)
            predicted_word_idx = output.argmax(dim=1)
            idx = predicted_word_idx.item()
            captions.append(idx)
            if vocab.itow[idx] == "<EOS>":
                break
            embeds = self.embedding(predicted_word_idx.unsqueeze(0))
        return [vocab.itow[i] for i in captions], alphas

    def init_hidden_state(self, encoder_out):
        """Initial (h, c) from the mean over the feature dimension (dim 1) of encoder_out."""
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c


In [5]:
class EncoderDecoder(nn.Module):
    """EncoderCNN followed by the attention DecoderRNN.

    All constructor arguments are forwarded to the decoder (previously `vocab_size`
    was ignored in favour of a global `dataset`, and `drop_prob` was dropped).
    forward(images, captions) returns only the decoder's per-step logits.
    """

    def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        super().__init__()
        self.encoder = EncoderCNN()
        self.decoder = DecoderRNN(
            embed_size=embed_size,
            vocab_size=vocab_size,
            attention_dim=attention_dim,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            drop_prob=drop_prob,
        )

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs, _ = self.decoder(features, captions)  # attention weights are discarded
        return outputs

In [6]:
#Kaggle API token and shit kaggle datasets download -d adityajn105/flickr8k
from google.colab import files
files.upload()
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 /root/.kaggle/kaggle.json
!kaggle datasets download -d adityajn105/flickr8k
!unzip /content/flickr8k.zip

flickr8k.zip: Skipping, found more recently modified local copy (use --force to force download)
Archive:  /content/flickr8k.zip
replace Images/1000268201_693b08cb0e.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: N


In [7]:
# Locations of the unzipped Flickr8k captions file and image folder.
annotations_path, images_path = "/content/captions.txt", "/content/Images"

In [8]:
import os
from collections import Counter

import nltk
import pandas as pd
from PIL import Image

# word_tokenize needs the Punkt sentence models. NLTK >= 3.9 loads them from the
# "punkt_tab" resource rather than the pickled "punkt" package, so fetch both;
# older NLTK releases only report that "punkt_tab" is not in their index.
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [9]:
class Vocab:
  """Bidirectional word <-> index mapping built from a corpus of captions.

  Indices 0-3 are reserved for "<PAD>", "<SOS>", "<EOS>" and "<UNK>". Only words
  that occur at least ``freq_thres`` times are added; any other word maps to
  "<UNK>" in ``text_to_index``.
  """

  def __init__(self, freq_thres):
    self.itow = {0:"<PAD>",1:"<SOS>",2:"<EOS>",3:"<UNK>"}
    self.freq_thres = freq_thres
    self.wtoi = {b:a for a,b in self.itow.items()}

  def __len__(self):
    return len(self.itow)

  @staticmethod
  def tokenize(text):
    """Lower-cased NLTK word tokens of ``text``."""
    return [token.lower() for token in nltk.tokenize.word_tokenize(text)]

  def vocab(self, doc):
    """Add every word of ``doc`` (an iterable of sentences) seen >= freq_thres times.

    Each unique word gets exactly one index; new words are numbered consecutively
    after the existing entries, in order of first appearance. Words already in the
    vocabulary are skipped, so calling this again only appends new words.
    """
    freq = Counter()
    for sentence in doc:
      freq.update(self.tokenize(sentence))

    idx = len(self.itow)
    # Counter keeps first-seen insertion order, so the index assignment is deterministic.
    for word, count in freq.items():
      if count >= self.freq_thres and word not in self.wtoi:
        self.itow[idx] = word
        self.wtoi[word] = idx
        idx += 1

  def text_to_index(self, text):
    """Token ids of ``text``; out-of-vocabulary words map to the "<UNK>" id."""
    unk = self.wtoi["<UNK>"]
    return [self.wtoi.get(word, unk) for word in self.tokenize(text)]

class Flickr_dataset(Dataset):
  """Image/caption pairs, one sample per row of the captions CSV.

  Args:
    images_dir: directory holding the files named in the CSV's "image" column.
    captions_dir: path to the captions CSV file (needs "image" and "caption" columns).
    transform: optional callable applied to each RGB PIL image.
    freq_thres: minimum word frequency for inclusion in the vocabulary.

  Each item is (image, caption_ids) where caption_ids is a 1-D tensor
  [<SOS>, tokens..., <EOS>].
  """

  def __init__(self, images_dir, captions_dir, transform = None, freq_thres = 5):
    self.images_dir = images_dir
    self.captions_dir = captions_dir
    self.transform = transform
    self.df = pd.read_csv(captions_dir)
    self.imgs = self.df["image"]
    self.captions = self.df["caption"]

    self.vocab = Vocab(freq_thres)
    self.vocab.vocab(self.captions.tolist())

  def __len__(self):
    return len(self.df)

  def __getitem__(self, idx):
    captions = self.captions[idx]
    img_name = self.imgs[idx]
    img_path = os.path.join(self.images_dir, img_name)
    # convert() returns a new in-memory image, so the file can be closed immediately.
    with Image.open(img_path) as raw_img:
      img = raw_img.convert("RGB")
    if self.transform:
      img = self.transform(img)

    caption = [self.vocab.wtoi["<SOS>"]]
    caption += self.vocab.text_to_index(captions)
    caption += [self.vocab.wtoi["<EOS>"]]

    return img, torch.tensor(caption)


In [10]:
class PAD_length():
  """DataLoader collate_fn: stack images and right-pad captions to a common length.

  Each batch item is (image_tensor, caption_tensor). Returns (images, captions)
  where images has a new leading batch dimension and captions is a
  (batch, max_len) tensor padded with ``pad_idx``.
  """

  def __init__(self, pad_idx):
    self.pad_idx = pad_idx

  def __call__(self, batch):
    images = torch.stack([sample[0] for sample in batch], dim=0)
    padded_captions = pad_sequence(
        [sample[1] for sample in batch],
        batch_first=True,
        padding_value=self.pad_idx,
    )
    return images, padded_captions

In [13]:
BATCH_SIZE = 8
NUM_WORKER = 2
SEED = 42

# Seed torch so shuffling and random crops are reproducible across runs.
torch.manual_seed(SEED)

transforms = T.Compose([
    T.Resize(226),
    T.RandomCrop(224),
    T.ToTensor(),
    # ImageNet channel mean/std, matching the pretrained ResNet-50 encoder.
    T.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
])

dataset = Flickr_dataset(images_dir = images_path,captions_dir = annotations_path ,transform = transforms)
pad_idx = dataset.vocab.wtoi["<PAD>"]
# Shuffle for training: the captions file presumably lists each image's captions on
# consecutive rows, so unshuffled batches would be dominated by a single image.
dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = NUM_WORKER, collate_fn = PAD_length(pad_idx))

vocab_size = len(dataset.vocab)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [14]:
# Model / optimisation hyper-parameters.
vocab_size = len(dataset.vocab)  # decoder output classes
embed_size = 300                 # word-embedding width
attention_dim = 256              # hidden width of the attention scorer
encoder_dim = 2048               # feature width produced by EncoderCNN (ResNet-50 final stage)
decoder_dim = 512                # LSTM hidden/cell size
learning_rate = 3e-4             # Adam learning rate

In [15]:
# Build the model from the hyper-parameter cell instead of repeating the literals,
# so changing a value there actually takes effect.
model = EncoderDecoder(
    embed_size=embed_size,
    vocab_size=vocab_size,
    attention_dim=attention_dim,
    encoder_dim=encoder_dim,
    decoder_dim=decoder_dim
).to(device)

# <PAD> positions added by the collate function must not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.wtoi["<PAD>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# NOTE(review): at the observed ~1.8 s/it over 5057 batches, one epoch takes ~2.5 h.
epochs = 100

for epoch in range(epochs):
  model.train()
  running_loss = 0.0
  for image, captions in tqdm(dataloader, desc=f"epoch {epoch + 1}/{epochs}"):
    image, captions = image.to(device), captions.to(device)
    optimizer.zero_grad()
    outputs = model(image, captions)
    # Step s predicts token s + 1, so the targets drop the leading <SOS>.
    targets = captions[:, 1:]
    loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
  # Report the epoch mean as a plain float rather than the last batch's graph tensor.
  print(f"epoch {epoch + 1}: mean loss {running_loss / len(dataloader):.4f}")

  0%|          | 15/5057 [00:28<2:28:56,  1.77s/it]