In [4]:
from models import *
import torchvision.models as models
from torch import nn
import torch
import pandas as pd
import random
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from torchvision.transforms import transforms
from nltk.tokenize import word_tokenize
from string import punctuation
from torchtext.vocab import build_vocab_from_iterator
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from models import Encoder, Decoder
import torchtext; torchtext.disable_torchtext_deprecation_warning()
torch.manual_seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'



In [5]:
data = pd.read_csv("./flickr8k/captions.txt")

In [6]:
def clean_text(text, lowercase=False, remove_punc=False, remove_num=False, sos_token='<sos>', eos_token='<eos>'):
    if lowercase:
        text = text.lower()
    if remove_punc:
        text = ''.join([ch for ch in text if ch not in punctuation])
    if remove_num:
        text = ''.join([ch for ch in text if ch not in '1234567890'])
    text = [sos_token] + word_tokenize(text) + [eos_token]
    return text

In [7]:
clean_text("A cat is sitting on the table.", lowercase=True, remove_punc=True, remove_num=True)

['<sos>', 'a', 'cat', 'is', 'sitting', 'on', 'the', 'table', '<eos>']

In [8]:
unk_token = '<unk>'
pad_token = '<pad>'
sos_token = '<sos>'
eos_token = '<eos>'

In [9]:
clean_cap = data['caption'].apply(lambda x: clean_text(x, lowercase=True, remove_punc=True, remove_num=True))

In [10]:
data['clean_caption'] = clean_cap

In [11]:
vocab = build_vocab_from_iterator(clean_cap, specials=[unk_token, pad_token, sos_token, eos_token])

In [12]:
pad_token_idx = vocab[pad_token]
unk_token_idx = vocab[unk_token]

In [13]:
vocab.set_default_index(unk_token_idx)

In [14]:
# to number
def text_to_number(text, vocab):
    return [vocab[token] for token in text]

In [15]:
to_int = clean_cap.apply(lambda x: text_to_number(x, vocab))

In [16]:
data['embed_caption'] = to_int

In [17]:
train, test = train_test_split(data, test_size=0.2, random_state=42)
train = train.reset_index(drop=True)
test = test.reset_index(drop=True)

In [18]:
test.head()

Unnamed: 0,image,caption,clean_caption,embed_caption
0,2973269132_252bfd0160.jpg,A large wild cat is pursuing a horse across a ...,"[<sos>, a, large, wild, cat, is, pursuing, a, ...","[2, 4, 56, 1693, 584, 8, 4821, 4, 229, 125, 4,..."
1,270263570_3160f360d3.jpg,Two brown dogs fight on the leafy ground .,"[<sos>, two, brown, dogs, fight, on, the, leaf...","[2, 14, 28, 32, 517, 7, 6, 1525, 170, 3]"
2,2053006423_6adf69ca67.jpg,A man in shorts is standing on a rock looking ...,"[<sos>, a, man, in, shorts, is, standing, on, ...","[2, 4, 12, 5, 161, 8, 39, 7, 4, 85, 89, 84, 23..."
3,512101751_05a6d93e19.jpg,a muzzled white dog is running on the grass .,"[<sos>, a, muzzled, white, dog, is, running, o...","[2, 4, 900, 15, 10, 8, 33, 7, 6, 42, 3]"
4,3156406419_38fbd52007.jpg,A person skiing downhill .,"[<sos>, a, person, skiing, downhill, <eos>]","[2, 4, 44, 377, 709, 3]"


In [19]:
transform = transforms.Compose([
    # data type convert to tensor
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    
])

In [20]:
embed_dim = 256
hidden_dim = 512
vocab_size = len(vocab)
num_layers = 2
dropout = 0.5

In [21]:
encoder = Encoder(embed_dim, dropout)
model_load = Decoder(embed_dim, hidden_dim, vocab_size, num_layers, device, encoder, dropout )

In [22]:
model_load.load_state_dict(torch.load('best-model.pt', map_location=device))

<All keys matched successfully>

In [23]:
def predict_caption(model, image, vocab, max_length=30):
    model.eval()
    with torch.no_grad():
        features = model.encoder(image).unsqueeze(1)
        input = features
        hidden = torch.zeros(model.num_layers, 1, model.lstm.hidden_size).to(model.device)
        cell = torch.zeros(model.num_layers, 1, model.lstm.hidden_size).to(model.device)

        caption = []
        for _ in range(max_length):
            output, (hidden, cell) = model.lstm(input, (hidden, cell))
            output = model.linear(output.squeeze(1))
            predicted = output.argmax(1)
            caption.append(predicted.item())
            input = model.dropout(model.embed(predicted)).unsqueeze(1)
            if predicted.item() == vocab['<eos>']:
                break
    pred = vocab.lookup_tokens(caption)      
    return ' '.join(pred[1:-1])

In [24]:
import nltk
from nltk.translate.bleu_score import corpus_bleu
import matplotlib.pyplot as plt

In [58]:
test_data = test[['image', 'caption', 'clean_caption']].values

In [59]:
test_data

array([['2973269132_252bfd0160.jpg',
        'A large wild cat is pursuing a horse across a meadow .',
        list(['<sos>', 'a', 'large', 'wild', 'cat', 'is', 'pursuing', 'a', 'horse', 'across', 'a', 'meadow', '<eos>'])],
       ['270263570_3160f360d3.jpg',
        'Two brown dogs fight on the leafy ground .',
        list(['<sos>', 'two', 'brown', 'dogs', 'fight', 'on', 'the', 'leafy', 'ground', '<eos>'])],
       ['2053006423_6adf69ca67.jpg',
        'A man in shorts is standing on a rock looking out at the view from the hilltop .',
        list(['<sos>', 'a', 'man', 'in', 'shorts', 'is', 'standing', 'on', 'a', 'rock', 'looking', 'out', 'at', 'the', 'view', 'from', 'the', 'hilltop', '<eos>'])],
       ...,
       ['2848895544_6d06210e9d.jpg',
        'Two little boys in uniforms play soccer .',
        list(['<sos>', 'two', 'little', 'boys', 'in', 'uniforms', 'play', 'soccer', '<eos>'])],
       ['431410325_f4916b5460.jpg',
        'A wet brown dog is leaving the water .',
        

In [97]:
def calculate_bleu_score(model, data, vocab, device):
    captioning_corpus = []
    reference_corpus = []
    for img, caption, clean_cap in data:
        img = Image.open(f"./flickr8k/Images/{img}")
        img = transform(img).unsqueeze(0).to(device)

        captioned = predict_caption(model, img, vocab)
        
        captioning_corpus.append(captioned.split())
        reference_corpus.append([clean_cap[1:-1]])
    
    return corpus_bleu(reference_corpus, captioning_corpus)

In [98]:
calculate_bleu_score(model_load, test_data, vocab, device)

0.04511407184799061