In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
#for dirname, _, filenames in os.walk('/kaggle/input'):
 #   for filename in filenames:
  #      print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from collections import Counter 
import torchvision
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os

In [None]:
# Use the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Bare expression so the notebook displays the selected device.
device

### Text cleaning

In [None]:
import nltk
import string
# Fetch the NLTK resources used by the cleaning helpers and the tokenizer below:
# the stop-word list, the 'punkt' tokenizer models and WordNet.
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('wordnet')
from nltk.stem import PorterStemmer
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords

In [None]:
# Lower-casing
def lower_text(text):
    """Return ``text`` coerced to ``str`` and converted to lower case."""
    as_string = str(text)
    return as_string.lower()

In [None]:
# Punctuation
remove=string.punctuation
# Set of single punctuation characters, for exact per-character membership tests.
_PUNCT_CHARS=frozenset(remove)
def remove_text(text):
    """Drop every whitespace-separated token made up only of punctuation.

    The previous test ``w not in remove`` was a *substring* check against the
    punctuation string, so tokens such as ``...`` or ``!!`` (not contiguous
    substrings of ``string.punctuation``) were kept. Tokens that contain at
    least one non-punctuation character (``t-shirt``, ``'s``) are left as is.

    Args:
        text: the string to clean.

    Returns:
        The remaining tokens joined by single spaces.
    """
    kept=(w for w in text.split() if not all(ch in _PUNCT_CHARS for ch in w))
    return ' '.join(kept)

In [None]:
nltk.download('stopwords')
stopwords_eng = stopwords.words('english')
# Stop-word removal
def stop_words(text):
    """Return ``text`` without the tokens listed in ``stopwords_eng``.

    Tokens are obtained with ``str.split()`` and re-joined with single spaces.
    """
    kept = []
    for token in text.split():
        if token in stopwords_eng:
            continue
        kept.append(token)
    return ' '.join(kept)

In [None]:
# Stemming
stemmer = PorterStemmer()
def stem_words(text):
    """Apply the Porter stemmer to each whitespace-separated token of ``text``."""
    stems = [stemmer.stem(token) for token in text.split()]
    return ' '.join(stems)

In [None]:
# Lemmatization
nltk.download('wordnet')
from nltk.stem import WordNetLemmatizer
lemmatizer = WordNetLemmatizer()
def lemmatize_words(text):
    """Lemmatize each whitespace-separated token of ``text`` with WordNet."""
    lemmas = [lemmatizer.lemmatize(token) for token in text.split()]
    return ' '.join(lemmas)

In [None]:
def text_cleaning(text):
    """Clean one caption: lower-case it, then remove punctuation tokens.

    Stop-word removal, stemming and lemmatization are available as helpers
    but are deliberately not applied here:
        #text=stop_words(text)
        #text=stem_words(text)
        #text=lemmatize_words(text)
    """
    return remove_text(lower_text(text))
def list_cleaning(list_text):
    """Clean every entry of ``list_text`` in place and return the same list."""
    for position, raw_text in enumerate(list_text):
        list_text[position] = text_cleaning(raw_text)
    return list_text

### Text files

In [None]:
# Root of the Flickr8k dataset as mounted in the Kaggle environment.
path='/kaggle/input/flickr8k/'
# Display the entries of the dataset root.
os.listdir(path)

In [None]:
# Paths. Both end with '/' because file names are appended later by plain
# string concatenation (img_path+name, text_path+name).
img_path='/kaggle/input/flickr8k/Flickr8k_Dataset/'
text_path='/kaggle/input/flickr8k/Flickr8k_text/'

In [None]:
# List of files.
# NOTE(review): os.listdir returns entries in arbitrary order, yet the cells below
# select text files by position (text_files[2], [3], [4], [6]). Which file each
# index refers to therefore depends on the filesystem - selecting by file name
# would be safer; confirm the mapping before changing it.
img_files=os.listdir(img_path)
text_files=os.listdir(text_path)

In [None]:
# Open text files
def open_file(path):
    """Read a text file and return its content split on newline characters.

    The file is opened in a ``with`` block so the handle is always closed
    (the previous version leaked it).

    Args:
        path: path of the file to read.

    Returns:
        list[str]: one entry per line; a file ending with a newline yields a
        final empty string, which callers strip with ``[:-1]``.
    """
    with open(path,'r') as file:
        return file.read().split('\n')

In [None]:
# Caption file: each non-empty line is split below into an image id and a caption
# at a tab character.
# NOTE(review): chosen by position in os.listdir output - confirm it is the intended file.
cap_lem=open_file(text_path+text_files[3])
len(cap_lem)

Training text

In [None]:
# File names used for training, one per line (file picked by position - see note on text_files).
train_files=open_file(text_path+text_files[6])
# Drop the last entry (the empty string produced when the file ends with a newline).
train_files=train_files[:-1]

In [None]:
# A second split is loaded and merged into the training list.
dev_files=open_file(text_path+text_files[2])
dev_files=dev_files[:-1]
# NOTE: not idempotent - re-running this cell appends dev_files to train_files again.
train_files=train_files+dev_files
len(train_files)

Validation text

In [None]:
# File names held out for validation (file picked by position - see note on text_files).
val_files=open_file(text_path+text_files[4])
# Drop the last entry (the empty string produced when the file ends with a newline).
val_files=val_files[:-1]
len(val_files)

Description of captions

In [None]:
# Build desc: image file name -> list of its raw captions.
# Each caption line is expected to look like '<image>#<k>\t<caption>'.
desc={}
for line in cap_lem:
    # Skip blank lines (e.g. the empty string left after the file's final newline)
    # instead of unconditionally dropping the last entry.
    if not line.strip():
        continue
    # Split on the first tab only, so a tab inside the caption cannot break unpacking.
    img, caption=line.split('\t', 1)
    # Remove the '#<k>' caption counter without assuming it is exactly two characters.
    img=img.rsplit('#', 1)[0]
    # Strip the trailing ' .' only when it is really there; blind slicing used to
    # cut the last two characters of captions that do not end with a period.
    if caption.endswith(' .'):
        caption=caption[:-2]
    desc.setdefault(img, []).append(caption)

In [None]:
# Raw captions of the second listed image (before cleaning).
desc[img_files[1]]

In [None]:
# Clean every caption list; list_cleaning edits the list in place and returns it.
for image_name, captions in desc.items():
    desc[image_name]=list_cleaning(captions)

In [None]:
# Same image as above, now showing the cleaned captions.
desc[img_files[1]]

Tokenization

In [None]:
from nltk.tokenize import word_tokenize

In [None]:
# Collect every distinct token that occurs in any cleaned caption.
vocab=set()
for captions in desc.values():
    for caption in captions:
        vocab.update(word_tokenize(caption))

In [None]:
# Ids 0, 1 and 2 are reserved for UNK, SOS and EOS (assigned in the next cells),
# so regular words start at 3.
# The words are sorted first: iterating a set of strings directly gives a different
# order - and therefore a different word<->id mapping - on every kernel start.
# The special tokens are excluded so that re-running this cell after 'SOS'/'EOS'
# were added to `vocab` does not hand them a second, regular id.
vocab_words=sorted(vocab-{'UNK', 'SOS', 'EOS'})
idx_to_word={idx: word for idx, word in enumerate(vocab_words, start=3)}
word_to_idx={word: idx for idx, word in idx_to_word.items()}

In [None]:
# Id 0 is the unknown-word token. padded() also fills unused positions with 0,
# so id 0 doubles as the padding value.
idx_to_word[0]='UNK'
word_to_idx['UNK']=0

In [None]:
# Start-of-sequence and end-of-sequence ids; padded() and the decoders hard-code
# the same values (1 and 2).
sos_token=1
eos_token=2
vocab.update({'SOS', 'EOS'})
word_to_idx['SOS']=sos_token
word_to_idx['EOS']=eos_token
idx_to_word[sos_token]='SOS'
idx_to_word[eos_token]='EOS'
# len(vocab) now counts the regular words plus SOS and EOS; the +1 accounts for
# UNK (id 0), which is never added to `vocab`.
vocab_size=len(vocab)+1
vocab_size

Padding

In [None]:
#from torch.nn.utils.rnn import pad_sequence
# Padding
def padded(x,n,sos=1,eos=2):
    """Wrap a token-id sequence in SOS/EOS markers and zero-pad it to length ``n``.

    Layout of the result: ``[sos, x[0], ..., x[k-1], eos, 0, ..., 0]``.

    Sequences longer than ``n-2`` tokens are truncated so the markers always
    fit; the previous version raised IndexError for such inputs.

    Args:
        x: sequence of token ids.
        n: total length of the returned array (must be at least 2).
        sos: id written at position 0 (default 1).
        eos: id written right after the last kept token (default 2).

    Returns:
        numpy.ndarray: float array of shape ``(n,)`` (default ``np.zeros`` dtype,
        unchanged from the original implementation).

    Raises:
        ValueError: if ``n`` is smaller than 2.
    """
    if n<2:
        raise ValueError("n must be at least 2 to hold the SOS and EOS markers")
    tokens=list(x)[:n-2]
    pad=np.zeros(n)
    pad[0]=sos
    if tokens:
        pad[1:len(tokens)+1]=tokens
    pad[len(tokens)+1]=eos
    return pad

In [None]:
# One fixed-length id sequence per image, built from its FIRST caption only.
desc_short={}
for image_name, captions in desc.items():
    token_ids=[word_to_idx[token] for token in word_tokenize(captions[0])]
    desc_short[image_name]=padded(token_ids,35)

In [None]:
# Length of the padded sequences; every entry of desc_short was padded to 35 above.
# The decoders read this global as the number of decoding steps.
max_length=max(len(desc_short[w]) for w in desc_short.keys())
max_length

Dataset

In [None]:
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, TensorDataset

In [None]:
# Image preprocessing: resize to 224x224, convert to a tensor, then normalise each
# channel with the given mean/std (presumably the ImageNet statistics expected by
# the pretrained ResNet - confirm against the torchvision model documentation).
transform=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
class img_dataset(Dataset):
    """Dataset yielding ``(image, caption_ids, image_name)`` triples.

    Args:
        files: list of image file names.
        path: directory prefix; it is concatenated directly with the file name,
            so it must end with a path separator.
        transform: callable applied to the loaded PIL image.

    The caption ids are read from the module-level ``desc_short`` mapping.
    """
    def __init__(self, files, path, transform):
        self.path=path
        self.transform=transform
        self.files=files
    def __len__(self):
        """Number of image files in the dataset."""
        return len(self.files)
    def __getitem__(self, idx):
        """Load sample ``idx``; raises KeyError if the image has no entry in ``desc_short``."""
        img_id=self.files[idx]
        cap=desc_short[img_id]
        # Force three channels: a grayscale, palette or RGBA file would otherwise
        # break the 3-channel Normalize step and the ResNet input layer.
        img=Image.open(self.path+img_id).convert('RGB')
        img=self.transform(img)
        return img, cap, img_id

In [None]:
# Training set and loader.
data_train=img_dataset(train_files, img_path, transform)
# shuffle=True: mini-batches are re-drawn in a new random order every epoch instead
# of always presenting the images in file-list order.
data_loader=DataLoader(data_train, batch_size=32, shuffle=True)

In [None]:
# Sanity check: show the first training sample with its token ids and caption.
x, y, z = data_train[0]
print(y)
print(desc[z][0])
# `x` is a normalised (channels, height, width) tensor. matplotlib expects
# (height, width, channels) with values in [0, 1]: move the channel axis last
# (a bare np.transpose would also swap height and width) and undo the
# normalisation applied by `transform`.
_mean=np.array([0.485, 0.456, 0.406])
_std=np.array([0.229, 0.224, 0.225])
plt.imshow(np.clip(x.permute(1, 2, 0).numpy()*_std+_mean, 0, 1))

In [None]:
# Inspect the shapes of the first training batch only.
for x, y, z in data_loader:
    print(y.shape)   # caption ids
    print(len(z))    # number of image names in the batch
    print(x.shape)   # images
    break

Validation dataset

In [None]:
# Validation set and loader, built from val_files with the same preprocessing.
data_val=img_dataset(val_files, img_path, transform)
val_loader=DataLoader(data_val, batch_size=32)

In [None]:
# Inspect the first validation batch only.
# `x` stays bound afterwards and is reused by the encoder smoke test further down.
for i, batch in enumerate(val_loader):
    x, y, z = batch
    print(i)
    print(y.shape)
    print(len(z))
    print(x.shape)
    break

### Encoder model

In [None]:
from torchvision.models import resnet50

In [None]:
class EncoderCNN(nn.Module):
    """Frozen pretrained ResNet-50 used as a fixed image feature extractor.

    ``forward`` returns the ReLU of the network's final output (one vector per
    image; the decoder's default ``feature_size`` is 1000).

    The backbone parameters are frozen and the backbone is kept in evaluation
    mode at all times. Without this, its BatchNorm layers would keep updating
    their running statistics during training even though no weight is trained,
    and the features would differ between training and inference.
    """
    def __init__(self):
        super(EncoderCNN, self).__init__()
        # NOTE(review): `pretrained=True` is reported as deprecated by recent
        # torchvision releases in favour of `weights=...`; kept for compatibility.
        self.resnet = resnet50(pretrained=True)
        for param in self.resnet.parameters():
            param.requires_grad_(False)
        self.resnet.eval()
        self.relu = nn.ReLU()
    def train(self, mode=True):
        """Switch the module mode but keep the frozen backbone in eval mode."""
        super(EncoderCNN, self).train(mode)
        self.resnet.eval()
        return self
    def forward(self, images):
        """Encode a batch of images into non-negative feature vectors."""
        # No parameter here is trainable, so no autograd graph is needed.
        with torch.no_grad():
            features = self.resnet(images)
        return self.relu(features)

In [None]:
# Smoke test: encode a single image and display the feature shape.
# `x` is the batch left over from the validation-loader inspection cell above
# (hidden state: that cell must have been run first).
ecd=EncoderCNN().to(device)
x_out=ecd(x[:1,:,:,:].to(device))
x_out.shape

### Decoder

In [None]:
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on a single image feature vector.

    The projected image feature initialises both LSTM states and is also
    concatenated to the token embedding at every step.

    Args:
        vocab_size: number of token ids (default: the module-level ``vocab_size``
            at class-definition time).
        emb_size: token embedding width.
        hidden_size: LSTM state width; also the width of the projected image feature.
        feature_size: width of the encoder output.
    """
    def __init__(self, vocab_size=vocab_size, emb_size=100, hidden_size=256, feature_size=1000):
        super(DecoderRNN, self).__init__()
        self.linear1=nn.Linear(feature_size, hidden_size)
        self.embed = nn.Embedding(vocab_size, emb_size)
        self.lstm = nn.LSTM(emb_size+hidden_size, hidden_size, batch_first=True)
        self.linear2 = nn.Linear(hidden_size, vocab_size)
        self.relu=nn.ReLU()
    def forward(self, encoder_outputs, target=None):
        """Decode ``max_length`` steps (module-level global).

        Args:
            encoder_outputs: tensor of shape (batch, feature_size).
            target: optional (batch, max_length) tensor of token ids. When given,
                ``target[:, i]`` is fed as the next input after step ``i``
                (teacher forcing); otherwise the arg-max prediction is fed back.

        Returns:
            Tensor of shape (batch, max_length, vocab_size) holding unnormalised
            logits, as expected by ``nn.CrossEntropyLoss``. (The original ended
            with a ``log_softmax`` whose result was discarded; it is removed.)
        """
        encoder_outputs=self.linear1(encoder_outputs)
        batch_size=encoder_outputs.size(0)
        # Create tensors on the device the features live on instead of relying on
        # the notebook-level `device` global.
        run_device=encoder_outputs.device
        # First input token is id 1, i.e. SOS.
        decoder_input=torch.ones(batch_size, 1, dtype=torch.long, device=run_device)
        decoder_hidden=(encoder_outputs.unsqueeze(0), encoder_outputs.unsqueeze(0))
        # Loop-invariant: image feature shaped (batch, 1, hidden) for concatenation.
        image_step=encoder_outputs.unsqueeze(1)
        decoder_outputs=[]
        for i in range(max_length):
            decoder_emb=self.relu(self.embed(decoder_input))
            lstm_input=torch.cat((decoder_emb, image_step), dim=2)
            decoder_output, decoder_hidden= self.lstm(lstm_input, decoder_hidden)
            decoder_output=self.linear2(self.relu(decoder_output))
            decoder_outputs.append(decoder_output)
            if target is not None:
                decoder_input=target[:,i].unsqueeze(1).to(device=run_device, dtype=torch.long)
            else:
                _, dcd_input=decoder_output.topk(1)
                decoder_input=dcd_input.squeeze(-1).detach()
        return torch.cat(decoder_outputs, dim=1)

In [None]:
# Smoke test: decode the single encoded image without a target and display the output shape.
dcd=DecoderRNN().to(device)
x_out_dcd=dcd(x_out)
x_out_dcd.shape

### Training

In [None]:
# Fresh models for the baseline (no attention) captioner.
encoder=EncoderCNN().to(device)
decoder=DecoderRNN().to(device)
# `criterion` and `opt` are module-level names that training() reads as globals.
# NOTE(review): padded positions (id 0) are not excluded from the loss; consider
# nn.CrossEntropyLoss(ignore_index=0) if padding should not be learned.
criterion=nn.CrossEntropyLoss()
# Only the decoder's parameters are optimised.
opt=torch.optim.Adam(decoder.parameters(), lr=0.005)

In [None]:
def training(encoder, decoder, data, n_epochs, optimizer=None, loss_fn=None, train_encoder=False):
    """Train ``decoder`` on image features produced by ``encoder``.

    Args:
        encoder: module mapping an image batch to features.
        decoder: module called as ``decoder(features, captions)`` that returns
            logits of shape (batch, steps, vocab).
        data: iterable of ``(images, caption_ids, image_names)`` batches that
            supports ``len()``.
        n_epochs: number of passes over ``data``.
        optimizer: optimiser to step; defaults to the module-level ``opt``.
        loss_fn: loss callable; defaults to the module-level ``criterion``.
        train_encoder: set True only when the optimiser also holds encoder
            parameters. With the default False the encoder runs in eval mode
            under ``no_grad``.

    Returns:
        list[float]: mean batch loss of every epoch.

    Fixes over the original:
      * module modes are set explicitly - val_file() leaves both models in eval
        mode, and an encoder left in train mode lets its BatchNorm statistics
        drift although its weights are never updated;
      * no autograd graph is built through an encoder that is not optimised;
      * tensors are moved to the decoder's device rather than a global one.
    """
    optimizer=opt if optimizer is None else optimizer
    loss_fn=criterion if loss_fn is None else loss_fn
    first_param=next(decoder.parameters(), None)
    run_device=first_param.device if first_param is not None else torch.device('cpu')
    n_batches=len(data)
    decoder.train()
    encoder.train(train_encoder)
    Loss=[]
    for epoch in range(n_epochs):
        loss_train=0
        for img,cap,img_id in data:
            img=img.to(run_device)
            cap=cap.to(run_device)
            with torch.set_grad_enabled(train_encoder):
                encoder_output=encoder(img)
            # zero gradients
            optimizer.zero_grad()
            # forward (teacher forcing: the captions are passed to the decoder)
            decoder_out=decoder(encoder_output, cap)
            loss=loss_fn(decoder_out.reshape(-1, decoder_out.size(-1)),
                         cap.reshape(-1).to(torch.int64))
            # backward
            loss.backward()
            # update
            optimizer.step()
            loss_train+=loss.item()
        epoch_loss=loss_train/n_batches
        Loss.append(epoch_loss)
        if epoch%10==0:
            print(f'Epoch: {epoch}, Loss: {epoch_loss}')
    return Loss

In [None]:
# Train the baseline captioner for 101 epochs (the loss is printed every 10th epoch);
# `Loss` receives the per-epoch mean losses.
Loss=training(encoder, decoder, data_loader, 101)

In [None]:
# Plot the per-epoch training loss of the baseline model (seaborn only sets the plot style).
import seaborn as sns
sns.set()
plt.plot(Loss)

In [None]:
def val_file(ecd, dcd, img_file):
    """Generate a caption for one image file with greedy decoding.

    Args:
        ecd: encoder module.
        dcd: decoder module; called without a target so it feeds back its own
            predictions.
        img_file: file name relative to the module-level ``img_path``.

    Returns:
        str: predicted words joined by spaces, without the SOS/EOS markers.

    Side effect: both modules are switched to eval mode and left that way.
    """
    ecd.eval()
    dcd.eval()
    with torch.no_grad():
        # convert('RGB') guards against non-3-channel image files.
        img=Image.open(img_path+img_file).convert('RGB')
        img=transform(img).unsqueeze(dim=0)
        encoder_output=ecd(img.to(device))
        decoder_outputs=dcd(encoder_output)
        _, word_idx=decoder_outputs.topk(1)
        # reshape(-1) always yields a 1-D sequence, whereas squeeze() collapses to
        # a 0-d tensor when there is a single decoding step.
        decoder_idx=word_idx.reshape(-1).tolist()
    cap=[]
    # The first position is dropped, as in the original `cap[1:-1]` (presumably the
    # start marker, since the decoder's first training target is SOS).
    for idx in decoder_idx[1:]:
        # Stop at EOS without emitting it. The original sliced off the last element
        # unconditionally, losing a real word whenever EOS was never predicted.
        if idx==eos_token:
            break
        cap.append(idx_to_word[idx])
    return ' '.join(cap)

In [None]:
# Caption a random validation image with the baseline model and compare it with
# the first reference caption.
# Sample over the real list length (instead of a hard-coded 1000) and convert to a
# plain int before indexing the Python list.
n=torch.randint(0,len(val_files),(1,)).item()
caption_val=val_file(encoder, decoder, val_files[n])
plt.imshow(plt.imread(img_path+val_files[n]))
plt.title(caption_val)
print(desc[val_files[n]][0])

### With Bahdanau attention model

Encoder

In [None]:
from torchvision.models import resnet50

In [None]:
class EncoderAttention(nn.Module):
    """Pretrained ResNet-50 without its last two child modules, used as a spatial feature extractor.

    ``forward`` returns a tensor of shape (batch, locations, channels): the
    spatial grid of the final feature map is flattened into one axis and the
    channel axis is moved last, so the decoder can attend over locations.

    Args:
        trainable: keep False (default) when the optimiser holds no encoder
            parameters, as in this notebook. The backbone is then frozen, kept
            in eval mode (so BatchNorm running statistics do not drift) and run
            without building an autograd graph. Set True to fine-tune it.
    """
    def __init__(self, trainable=False):
        super(EncoderAttention, self).__init__()
        resnet = resnet50(pretrained=True)
        self.resnet=nn.Sequential(*list(resnet.children())[:-2])
        self.relu = nn.ReLU()
        self.trainable=trainable
        if not trainable:
            for param in self.resnet.parameters():
                param.requires_grad_(False)
            self.resnet.eval()
    def train(self, mode=True):
        """Switch the module mode; a non-trainable backbone always stays in eval mode."""
        super(EncoderAttention, self).train(mode)
        if not self.trainable:
            self.resnet.eval()
        return self
    def forward(self, images):
        """Encode images into (batch, locations, channels) features."""
        with torch.set_grad_enabled(self.trainable and torch.is_grad_enabled()):
            features = self.resnet(images)
        features=self.relu(features)
        # (batch, channels, h, w) -> (batch, channels, h*w); the channel count is
        # taken from the tensor instead of being hard-coded to 2048.
        features=features.flatten(2)
        features=features.permute(0,2,1)
        return features

In [None]:
class BahdanauAttention(nn.Module):
    """Additive attention: scores come from ``Va(tanh(Ua(query) + Wa(keys)))``.

    Args:
        hidden_size: width of both the decoder state and the encoder features.
        attention_dim: width of the intermediate scoring space.
    """
    def __init__(self, hidden_size=256, attention_dim=64):
        super(BahdanauAttention, self).__init__()
        self.Ua=nn.Linear(hidden_size, attention_dim)   # projects the decoder state
        self.Wa=nn.Linear(hidden_size, attention_dim)   # projects the encoder features
        self.Va=nn.Linear(attention_dim,1)              # one scalar score per location
    def forward(self, decoder_hidden, encoder_hidden):
        """Return the attention-weighted sum of ``encoder_hidden``.

        Args:
            decoder_hidden: (batch, 1, hidden_size) query.
            encoder_hidden: (batch, locations, hidden_size) features.

        Returns:
            Context tensor of shape (batch, 1, hidden_size).
        """
        query_proj=self.Ua(decoder_hidden)
        keys_proj=self.Wa(encoder_hidden)
        # The query broadcasts over the location axis.
        energy=torch.tanh(query_proj+keys_proj)
        # (batch, locations, 1) -> (batch, 1, locations)
        location_scores=self.Va(energy).squeeze(2).unsqueeze(1)
        attn=nn.functional.softmax(location_scores,dim=-1)
        return torch.bmm(attn, encoder_hidden)

In [None]:
class DecoderAttention(nn.Module):
    """LSTM caption decoder that attends over spatial image features at every step.

    Args:
        vocab_size: number of token ids (default: the module-level ``vocab_size``
            at class-definition time).
        feature_size: channel width of the encoder features.
        emb_size: token embedding width.
        hidden_size: LSTM state width; also the width of the projected features.
    """
    def __init__(self, vocab_size=vocab_size, feature_size=2048, emb_size=100, hidden_size=256):
        super(DecoderAttention, self).__init__()
        self.linear1 = nn.Linear(feature_size, hidden_size)
        self.embed = nn.Embedding(vocab_size, emb_size)
        # Pass hidden_size through: the attention previously always used its own
        # default (256) and failed for any other hidden_size.
        self.attention=BahdanauAttention(hidden_size=hidden_size)
        self.lstm = nn.LSTM(emb_size+hidden_size,hidden_size, batch_first=True)
        self.linear2 = nn.Linear(hidden_size, vocab_size)
        self.relu=nn.ReLU()
    def forward(self, encoder_outputs, target=None):
        """Decode ``max_length`` steps (module-level global).

        Args:
            encoder_outputs: (batch, locations, feature_size) features.
            target: optional (batch, max_length) token ids; when given,
                ``target[:, i]`` is fed as the next input after step ``i``
                (teacher forcing), otherwise the arg-max prediction is fed back.

        Returns:
            Tensor of shape (batch, max_length, vocab_size) holding unnormalised
            logits, as expected by ``nn.CrossEntropyLoss``. (The original ended
            with a ``log_softmax`` whose result was discarded; it is removed.)
        """
        encoder_outputs=self.linear1(encoder_outputs)  # batch, locations, hidden_size
        batch_size=encoder_outputs.size(0)
        # Create tensors on the device the features live on instead of relying on
        # the notebook-level `device` global.
        run_device=encoder_outputs.device
        # First input token is id 1, i.e. SOS.
        decoder_input=torch.ones(batch_size, 1, dtype=torch.long, device=run_device)
        # Both LSTM states start from the mean feature over all locations.
        mean_feature=encoder_outputs.mean(dim=1).unsqueeze(dim=0)  # 1, batch, hidden_size
        decoder_hidden=mean_feature
        decoder_cell=mean_feature
        decoder_outputs=[]
        for i in range(max_length):
            decoder_emb=self.relu(self.embed(decoder_input))  # batch, 1, emb_size
            hidden_att=decoder_hidden.permute(1,0,2)  # batch, 1, hidden_size
            context=self.attention(hidden_att, encoder_outputs)  # batch, 1, hidden_size
            input_lstm=torch.cat((decoder_emb, context), dim=2)
            decoder_output, (decoder_hidden, decoder_cell)= self.lstm(input_lstm, (decoder_hidden, decoder_cell))
            decoder_output=self.linear2(self.relu(decoder_output))
            decoder_outputs.append(decoder_output)
            if target is not None:
                decoder_input=target[:,i].unsqueeze(1).to(device=run_device, dtype=torch.long)
            else:
                _, dcd_input=decoder_output.topk(1)
                decoder_input=dcd_input.squeeze(-1).detach()
        return torch.cat(decoder_outputs, dim=1)

### Training attention model

In [None]:
# Fresh models for the attention captioner.
encoder_att=EncoderAttention().to(device)
decoder_att=DecoderAttention().to(device)
# Rebinding the module-level `criterion` and `opt` is what makes the next training()
# call optimise the attention decoder: training() reads both names as globals.
criterion=nn.CrossEntropyLoss()
# Only the decoder's parameters are optimised (Adam with its default learning rate).
opt=torch.optim.Adam(decoder_att.parameters())

In [None]:
# Train the attention captioner for 101 epochs; `LossAtt` receives the per-epoch mean losses.
LossAtt=training(encoder_att, decoder_att, data_loader, 101)

In [None]:
# Plot the per-epoch training loss of the attention model.
import seaborn as sns
sns.set()
plt.plot(LossAtt)

In [None]:
# NOTE: this cell re-defines val_file, which is already defined in the baseline
# section of the notebook; a single shared definition would be enough.
def val_file(ecd, dcd, img_file):
    """Generate a caption for one image file with greedy decoding.

    Args:
        ecd: encoder module.
        dcd: decoder module; called without a target so it feeds back its own
            predictions.
        img_file: file name relative to the module-level ``img_path``.

    Returns:
        str: predicted words joined by spaces, without the SOS/EOS markers.

    Side effect: both modules are switched to eval mode and left that way.
    """
    ecd.eval()
    dcd.eval()
    with torch.no_grad():
        # convert('RGB') guards against non-3-channel image files.
        img=Image.open(img_path+img_file).convert('RGB')
        img=transform(img).unsqueeze(dim=0)
        encoder_output=ecd(img.to(device))
        decoder_outputs=dcd(encoder_output)
        _, word_idx=decoder_outputs.topk(1)
        # reshape(-1) always yields a 1-D sequence, whereas squeeze() collapses to
        # a 0-d tensor when there is a single decoding step.
        decoder_idx=word_idx.reshape(-1).tolist()
    cap=[]
    # The first position is dropped, as in the original `cap[1:-1]` (presumably the
    # start marker, since the decoder's first training target is SOS).
    for idx in decoder_idx[1:]:
        # Stop at EOS without emitting it. The original sliced off the last element
        # unconditionally, losing a real word whenever EOS was never predicted.
        if idx==eos_token:
            break
        cap.append(idx_to_word[idx])
    return ' '.join(cap)

In [None]:
# Caption a random validation image with the attention model.
# Sample over the real list length (instead of a hard-coded 1000) and convert to a
# plain int before indexing the Python list.
n=torch.randint(0,len(val_files),(1,)).item()
caption_val=val_file(encoder_att, decoder_att, val_files[n])
plt.imshow(plt.imread(img_path+val_files[n]))
plt.title(caption_val)