In [None]:
import os
import time
import spacy

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset
from torch.nn.utils.rnn import pad_sequence
from keras.preprocessing.sequence import pad_sequences

import torchvision.transforms as T
import torchvision.models as models
import matplotlib.pyplot as plt
from collections import Counter
from PIL import Image
print("IMPORTED SUCCESSFULLY 😎")

In [None]:
def show_image(img, title=None):
    """Undo the training normalisation and display an image tensor.

    The per-channel scale/offset below is the inverse of the
    ``T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))`` step used
    in the notebook's transform pipeline.

    Args:
        img: tensor indexed as (channel, height, width) with three channels.
        title: optional title drawn above the image.

    The caller's tensor is left untouched: the de-normalisation is applied to
    a detached CPU copy, so showing the same tensor twice yields the same
    picture instead of shifting the values a second time.
    """
    channel_std = (0.229, 0.224, 0.225)
    channel_mean = (0.485, 0.456, 0.406)

    # Work on a private copy; the original code modified `img` in place.
    img = img.detach().cpu().clone()
    for channel, (std, mean) in enumerate(zip(channel_std, channel_mean)):
        img[channel] = img[channel] * std + mean

    # matplotlib expects (height, width, channel).
    img = img.numpy().transpose((1, 2, 0))
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

In [None]:
class Vocabulary:
    """Bidirectional word <-> index mapping built from caption text.

    Indices 0-3 are reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and
    ``<UNK>``. A word receives an index once it has been seen
    ``freq_threshold`` times by ``build_vocab``.
    """

    def __init__(self,freq_threshold):
        # index -> token for the reserved special tokens
        self.itos = {0:"<PAD>",1:"<SOS>",2:"<EOS>",3:"<UNK>"}

        # token -> index: the inverse of self.itos
        self.stoi = {v:k for k,v in self.itos.items()}
        self.freq_threshold = freq_threshold
        self.frequencies = Counter()

    def __len__(self):
        """Return the number of tokens in the vocabulary (specials included)."""
        return len(self.itos)

    @staticmethod
    def tokenize(text):
        """Split ``text`` into lower-cased tokens.

        Uses the module-level ``spacy_eng`` pipeline, which must be loaded
        before this method is called.
        """
        return [token.text.lower() for token in spacy_eng.tokenizer(text)]

    def build_vocab(self, sentence_list):
        """Count the tokens of every sentence and register the frequent ones."""
        # Continue after the highest index already assigned so that a second
        # call cannot overwrite existing entries (4 on a fresh vocabulary).
        idx = len(self.itos)
        for sentence in sentence_list:
            for word in self.tokenize(sentence):
                self.frequencies[word] += 1

                # Register the word the moment it reaches the minimum
                # frequency threshold; `==` guarantees it is added only once.
                if self.frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self,text):
        """Convert ``text`` to a list of indices; unknown tokens map to ``<UNK>``."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenize(text)]

    def word(self,ind):
        """Convert an iterable of indices to tokens; unknown indices map to ``<UNK>``.

        The previous implementation looked up ``self.itos["<UNK"]``, which
        raised ``KeyError`` for every out-of-vocabulary index.
        """
        return [self.itos.get(idx, "<UNK>") for idx in ind]

In [None]:
class FlickrDataset(Dataset):
    """Image/caption dataset backed by a CSV file with ``image`` and ``caption`` columns.

    Args:
        root_dir: directory that contains the image files.
        captions_file: path of the CSV file read with ``pandas.read_csv``.
        transform: optional callable applied to each loaded PIL image.
        freq_threshold: minimum word frequency passed to ``Vocabulary``.
    """

    def __init__(self,root_dir,captions_file,transform=None,freq_threshold=5):
        self.root_dir = root_dir
        # Read the file given by the parameter; the previous code read the
        # unrelated module-level name `caption_file` instead.
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Image file names and captions, row-aligned.
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # Build the vocabulary from every caption in the file.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.captions.tolist())

    def __len__(self):
        """Return the number of (image, caption) rows."""
        return len(self.df)

    def __getitem__(self,idx):
        """Return ``(image, caption_tensor)`` for row ``idx``.

        The caption is numericalised and wrapped in ``<SOS>`` / ``<EOS>``.
        """
        caption = self.captions[idx]
        img_name = self.imgs[idx]
        img_location = os.path.join(self.root_dir,img_name)
        img = Image.open(img_location).convert("RGB")

        # Apply the optional image transformation.
        if self.transform is not None:
            img = self.transform(img)

        caption_vec = (
            [self.vocab.stoi["<SOS>"]]
            + self.vocab.numericalize(caption)
            + [self.vocab.stoi["<EOS>"]]
        )

        return img, torch.tensor(caption_vec)

In [None]:
class CapsCollate:
    """Collate callable that batches images and pads captions to equal length.

    Args:
        pad_idx: value used to fill captions shorter than the longest one.
        batch_first: forwarded to ``pad_sequence``; selects whether the batch
            dimension of the padded captions comes first.
    """

    def __init__(self, pad_idx, batch_first=False):
        self.pad_idx = pad_idx
        self.batch_first = batch_first

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` samples into two batch tensors."""
        # Give every image a leading batch axis, then join along it.
        image_batch = torch.cat([sample[0][None] for sample in batch], dim=0)

        caption_batch = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=self.batch_first,
            padding_value=self.pad_idx,
        )
        return image_batch, caption_batch

In [None]:
# Initialise the tokenizer, dataset and data loader.

# spaCy v3 removed shortcut links such as "en"; prefer the full package name
# and fall back to the legacy shortcut for older installations.
try:
    spacy_eng = spacy.load("en_core_web_sm")
except OSError:
    spacy_eng = spacy.load("en")

caption_file = '../input/flickr8k/captions.txt'
data_location =  "../input/flickr8k"

# Loader constants
BATCH_SIZE = 500
NUM_WORKER = 0

transforms = T.Compose([
    T.Resize(226),
    T.RandomCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
])

dataset =  FlickrDataset(
    root_dir = data_location+"/Images",
    captions_file = caption_file,
    transform=transforms
)

# Vocabulary-derived values (needed by the collate function below).
vocab_size = len(dataset.vocab)
pad_idx = dataset.vocab.stoi["<PAD>"]

# Captions differ in length, so the default collate function cannot stack
# them; CapsCollate pads every caption in a batch to the longest one.
data_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    shuffle=True,
    collate_fn=CapsCollate(pad_idx=pad_idx,batch_first=True)
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [None]:

# Rebuild the dataset and loader with a 256-pixel resize before the 224 crop.
BATCH_SIZE = 500
NUM_WORKER = 0

transforms = T.Compose([
    T.Resize(256),
    T.RandomCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)),
])
dataset =  FlickrDataset(
    root_dir = '../input/flickr8k/Images',
    captions_file = '../input/flickr8k/captions.txt',
    transform=transforms,
)

# Derive both values from the dataset built in this cell rather than relying
# on a `pad_idx` left over from an earlier cell.
vocab_size = len(dataset.vocab)
pad_idx = dataset.vocab.stoi["<PAD>"]

data_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    shuffle=True,
    collate_fn=CapsCollate(pad_idx=pad_idx,batch_first=True),
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
class EncoderCNN(nn.Module):
    """Frozen pretrained ResNet-50 trunk that emits one feature vector per spatial location."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(pretrained=True)

        # The backbone is used as a fixed feature extractor: no gradients.
        for weight in backbone.parameters():
            weight.requires_grad_(False)

        # Keep every child module except the last two, so the output is still
        # a spatial feature map.
        trunk = list(backbone.children())[:-2]
        self.resnet = nn.Sequential(*trunk)

    def forward(self, images):
        """Return features shaped (batch, locations, channels) for ``images``."""
        feature_map = self.resnet(images)
        batch, channels = feature_map.size(0), feature_map.size(1)
        # Move channels last, then merge the two spatial axes into one.
        channels_last = feature_map.permute(0, 2, 3, 1)
        return channels_last.view(batch, -1, channels)

In [None]:
class Attention(nn.Module):
    """Additive attention over encoder features, conditioned on a decoder state.

    Args:
        encoder_dim: size of each encoder feature vector.
        decoder_dim: size of the decoder hidden state.
        attention_dim: size of the shared projection space.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.attention_dim = attention_dim
        self.W = nn.Linear(decoder_dim, attention_dim)   # projects the hidden state
        self.U = nn.Linear(encoder_dim, attention_dim)   # projects the features
        self.A = nn.Linear(attention_dim, 1)             # one score per location

    def forward(self, features, hidden_state):
        """Return ``(alpha, context)``.

        ``alpha`` is the softmax over dim 1 of the per-location scores and
        ``context`` is the alpha-weighted sum of ``features`` over dim 1.
        """
        projected_features = self.U(features)
        # Add a location axis so the hidden projection broadcasts over dim 1.
        projected_hidden = self.W(hidden_state).unsqueeze(1)

        scores = self.A(torch.tanh(projected_features + projected_hidden)).squeeze(2)
        alpha = F.softmax(scores, dim=1)

        context = (features * alpha.unsqueeze(2)).sum(dim=1)
        return alpha, context

In [None]:
# Attention decoder
class DecoderRNN(nn.Module):
    """LSTM-cell caption decoder that attends over encoder features at every step.

    Args:
        embed_size: word-embedding size.
        vocab_size: number of tokens in the vocabulary.
        attention_dim: projection size used inside ``Attention``.
        encoder_dim: size of each encoder feature vector.
        decoder_dim: LSTM hidden-state size.
        drop_prob: dropout probability applied before the output layer.
    """

    def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        super().__init__()

        # Keep the sizes that forward() needs later.
        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.attention = Attention(encoder_dim,decoder_dim,attention_dim)

        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.lstm_cell = nn.LSTMCell(embed_size+encoder_dim,decoder_dim,bias=True)
        # Not used by forward()/generate_caption(); kept so the set of
        # parameters (and therefore saved state dicts) stays unchanged.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)

        self.fcn = nn.Linear(decoder_dim,vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        """Teacher-forced decoding.

        Args:
            features: encoder output indexed as (batch, locations, encoder_dim).
            captions: token indices indexed as (batch, length).

        Returns:
            ``(preds, alphas)`` where ``preds`` is
            (batch, length - 1, vocab_size) and ``alphas`` is
            (batch, length - 1, locations). Step ``s`` consumes token ``s``.
        """
        embeds = self.embedding(captions)

        h, c = self.init_hidden_state(features)

        # The last token is never fed as input, hence length - 1 steps.
        seq_length = captions.size(1) - 1
        batch_size = captions.size(0)
        num_features = features.size(1)

        # Allocate on the same device as the inputs instead of relying on a
        # module-level `device` variable.
        preds = torch.zeros(batch_size, seq_length, self.vocab_size, device=features.device)
        alphas = torch.zeros(batch_size, seq_length, num_features, device=features.device)

        for s in range(seq_length):
            alpha,context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            preds[:,s] = output
            alphas[:,s] = alpha
        return preds, alphas

    def generate_caption(self,features,max_len=20,vocab=None):
        """Greedily decode a caption for a single image.

        Args:
            features: encoder output for ONE image (batch size 1); the
                ``.item()`` calls below fail for larger batches.
            max_len: maximum number of tokens to generate.
            vocab: object exposing ``stoi`` and ``itos`` mappings (required).

        Returns:
            ``(tokens, alphas)``: the generated tokens, including ``<EOS>``
            when it was produced, and one attention array per decoding step.

        Raises:
            ValueError: if ``vocab`` is not supplied.
        """
        if vocab is None:
            raise ValueError("generate_caption requires a vocabulary (vocab=...)")

        batch_size = features.size(0)
        h, c = self.init_hidden_state(features)

        alphas = []

        # Start from <SOS>, created on the same device as the features.
        word = torch.tensor([[vocab.stoi['<SOS>']]], device=features.device)
        embeds = self.embedding(word)
        captions = []

        for _ in range(max_len):
            alpha,context = self.attention(features, h)

            # Store the attention map of this step.
            alphas.append(alpha.cpu().detach().numpy())

            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            output = output.view(batch_size,-1)

            # Greedy choice: the highest-scoring token.
            predicted_word_idx = output.argmax(dim=1)
            token_idx = predicted_word_idx.item()
            captions.append(token_idx)

            # Stop once the end-of-sequence token has been produced.
            if vocab.itos[token_idx] == "<EOS>":
                break

            # Feed the prediction back in as the next input.
            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        return [vocab.itos[idx] for idx in captions],alphas

    def init_hidden_state(self, encoder_out):
        """Initialise ``(h, c)`` from the mean of the encoder features over dim 1."""
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

In [None]:
class EncoderDecoder(nn.Module):
    """Complete captioning model: ``EncoderCNN`` followed by ``DecoderRNN``.

    Args:
        embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim,
        drop_prob: forwarded unchanged to ``DecoderRNN``.
    """

    def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        super().__init__()
        self.encoder = EncoderCNN()
        # Use the arguments that were passed in: the previous code ignored
        # `vocab_size` in favour of the global `dataset` and dropped `drop_prob`.
        self.decoder = DecoderRNN(
            embed_size=embed_size,
            vocab_size=vocab_size,
            attention_dim=attention_dim,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            drop_prob=drop_prob,
        )

    def forward(self, images, captions):
        """Encode ``images`` and return the decoder's output for ``captions``."""
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

In [None]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:

# Model and optimiser hyper-parameters.
embed_size = 300
vocab_size = len(dataset.vocab)
attention_dim = 256
# The model in this notebook is constructed with encoder_dim=2048 (the channel
# count the encoder produces); the previous value of 1024 contradicted that.
encoder_dim = 2048
decoder_dim = 512
learning_rate = 3e-4

In [None]:

# Build the captioning model and move it to the selected device.
model = EncoderDecoder(
    embed_size=300,
    vocab_size=len(dataset.vocab),
    attention_dim=256,
    encoder_dim=2048,
    decoder_dim=512,
)
model = model.to(device)

# Padding positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Training loop: one optimiser step per batch, with a loss report and a sample
# caption every `print_every` batches.
num_epochs = 1
print_every = 10
start_time = time.time()   # `time` is already imported in the first cell

for epoch in range(1,num_epochs+1):
    for idx, (image, captions) in enumerate(data_loader):
        image,captions = image.to(device),captions.to(device)

        optimizer.zero_grad()
        outputs,attentions = model(image, captions)

        # The decoder predicts token s+1 from token s, so the targets are the
        # captions shifted left by one position.
        targets = captions[:,1:]
        loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))

        loss.backward()
        optimizer.step()

        if (idx+1)%print_every == 0:
            print("Epoch: {} loss: {:.5f}".format(epoch,loss.item()))

            # Preview a caption for the first image of the current batch.
            # (Previously a whole extra shuffled batch was loaded just to
            # display a single image.)
            model.eval()
            with torch.no_grad():
                sample = image[0:1]
                features = model.encoder(sample)
                caps,alphas = model.decoder.generate_caption(features,vocab=dataset.vocab)
                caption = ' '.join(caps)
                # Pass a CPU copy so that displaying never alters the batch.
                show_image(sample[0].detach().cpu().clone(),title=caption)
            model.train()
print("--- %s seconds ---" % (time.time() - start_time))

In [None]:

def get_caps_from(features_tensors):
    """Generate and display a caption for a batch holding one image.

    Args:
        features_tensors: image tensor with a leading batch axis; only the
            first image is displayed.

    Returns:
        ``(caps, alphas)`` exactly as returned by
        ``model.decoder.generate_caption``.

    Uses the module-level ``model``, ``dataset`` and ``device``, and leaves
    ``model`` in evaluation mode.
    """
    model.eval()
    with torch.no_grad():
        features = model.encoder(features_tensors.to(device))
        caps,alphas = model.decoder.generate_caption(features,vocab=dataset.vocab)
        caption = ' '.join(caps)
        # Hand show_image a private copy: it rescales its argument, and
        # indexing alone would give it a view of the caller's tensor.
        show_image(features_tensors[0].detach().cpu().clone(),title=caption)

    return caps,alphas


In [None]:
def search(photo):
    """Return the caption tokens generated for ``photo``.

    Thin wrapper around ``get_caps_from`` that discards the attention maps.
    The earlier body also built a padded ``<start>`` sequence that was never
    used; that dead code has been removed.
    """
    words, _ = get_caps_from(photo)
    return words

In [None]:
from nltk.translate.bleu_score import sentence_bleu

# Deterministic evaluation preprocessing. The model is trained on resized,
# cropped and normalised inputs, and show_image() undoes that normalisation,
# so evaluation images receive the same normalisation (with a centre crop in
# place of the random crop). `T` is used directly because importing
# `torchvision.transforms` as `transforms` would shadow the pipeline variable
# of that name defined earlier.
eval_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)),
])

# Control tokens are not words and must not take part in the n-gram overlap;
# generate_caption() includes "<EOS>" in its output when it is produced.
special_tokens = {"<PAD>", "<SOS>", "<EOS>"}

tot_score=0
iterations=20
for idx in range(iterations):
    print(idx)
    caption = dataset.captions[idx]
    img_name = dataset.imgs[idx]
    img_location = os.path.join(dataset.root_dir,img_name)
    img = Image.open(img_location).convert("RGB")
    img1 = eval_transform(img).unsqueeze(0)
    candidate = [token for token in search(img1) if token not in special_tokens]
    caption=[token.text.lower() for token in spacy_eng.tokenizer(caption)]
    reference_list=[caption]
    print(reference_list)
    print(candidate)
    score = sentence_bleu(reference_list,candidate)
    tot_score+=score


avg = tot_score/iterations
print(avg)

In [None]:
# Load the sample photo from disk and convert it to RGB mode.
img=Image.open('../input/vr-proj-images/sample5.jpg').convert("RGB")

In [None]:
# Preprocess the sample the way training inputs are preprocessed (resize,
# 224 crop, normalisation), using a deterministic centre crop. `T` is used
# directly so the earlier `transforms` pipeline variable is not shadowed by a
# module import of the same name.
sample_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)),
])
# Add the leading batch axis expected by the encoder.
img1 = sample_transform(img).unsqueeze(0)

In [None]:
# Caption the sample image: `a` receives the first value returned by
# get_caps_from (the caption tokens) and `b` the second (the attention maps).
a,b=get_caps_from(img1)

In [None]:
# Show the sample with its caption as a readable sentence (joined the same way
# as in the training preview) rather than the raw token list. A copy is passed
# so that re-running this cell never changes `img1`.
show_image(img1[0].clone(),title=' '.join(a))