In [1]:
from train import train
from data import initialize_loader, Flickr8k
from encoder_decoder import ResNetEncoder, Decoder
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import pickle
import random
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

In [2]:
class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes.

    The instance's ``__dict__`` is aliased to the mapping itself, so
    ``obj.key`` and ``obj["key"]`` always refer to the same entry.
    """

    def __init__(self, *args, **kwargs):
        # Accept every constructor form that ``dict`` accepts.
        super().__init__(*args, **kwargs)
        # One store, two access styles: attribute storage *is* the mapping.
        self.__dict__ = self

In [3]:
# Run configuration: hyperparameters plus checkpoint locations.
# You can play with the hyperparameters here, but to finish the assignment,
# there is no need to tune the hyperparameters here.
args_dict = dict(
    learn_rate=0.001,
    batch_size=32,  # also used when building the training loader
    epochs=5,
    log_step=15,
    save_epoch=1,
    model_path="models/",
    load_model=False,
    # Checkpoints restored by the weight-loading cell further down.
    encoder_path="models/encoder-5.ckpt",
    decoder_path="models/decoder-5.ckpt",
)
# Wrap the mapping so settings can be read as attributes (e.g. args.batch_size).
args = AttrDict(args_dict)

In [4]:
# Load the pickled vocabulary; it is passed to the datasets and sizes the decoder.
# NOTE: unpickling can execute arbitrary code, so only open a vocab.pkl you trust.
with open("vocab.pkl", "rb") as vocab_file:
    vocab = pickle.load(vocab_file)

In [5]:
# The encoder and decoder share the first size (256); the decoder additionally
# takes the vocabulary size and a second size (512).
# NOTE(review): presumably embedding and hidden dimensions; confirm in encoder_decoder.py.
shared_size = 256
decoder_size = 512
e = ResNetEncoder(shared_size)
d = Decoder(len(vocab), shared_size, decoder_size)

In [6]:
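# Training is left disabled: a later cell restores saved weights from
# args.encoder_path / args.decoder_path instead. Uncomment to retrain.
# train(e, d, args)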

In [7]:
# Preprocessing applied to every image: resize to 224x224, convert to a tensor,
# then normalise each channel with the mean/std below (these match the ImageNet
# statistics commonly used with torchvision's pretrained models).
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocess_steps)

# The train and validation splits share the image folder, vocabulary and transform;
# only the CSV listing differs.
dataset_kwargs = dict(root_dir="flickr8k/images", vocab=vocab, transform=transform)
train_data = Flickr8k(csv_file="flickr8k/train.csv", **dataset_kwargs)
train_loader = initialize_loader(train_data, batch_size=args.batch_size)
val_data = Flickr8k(csv_file="flickr8k/val.csv", **dataset_kwargs)

In [8]:
# Import only the helpers this notebook calls instead of `from validation import *`,
# so it stays clear where each name comes from.
from validation import bulk_caption_image, caption_image, evaluate_bleu_batch

# Restore the trained weights. The checkpoints are mapped onto the CPU so they can
# be read on a machine without a GPU; load_state_dict then copies the values into
# the existing model parameters.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
cpu = torch.device("cpu")
e.load_state_dict(torch.load(args.encoder_path, map_location=cpu))
# Kept as the last expression so the cell displays the key-matching result.
d.load_state_dict(torch.load(args.decoder_path, map_location=cpu))

<All keys matched successfully>

In [9]:
# Sanity check: caption one training batch with the restored models.
CAPTION_STRIDE = 5  # print every 5th generated caption to keep the output short

# Keep the models on the same device as the inputs; otherwise a CUDA machine
# would feed GPU tensors to CPU-resident weights.
e.to(device)
d.to(device)
# Put both models in evaluation mode once, before any inference.
e.eval()
d.eval()

# Only the first batch is needed; its ground-truth captions and lengths are not used here.
imgs, _, _ = next(iter(train_loader))
imgs = imgs.to(device)

with torch.no_grad():  # caption generation needs no gradients
    # Caption for the first image of the batch on its own.
    first_caption = caption_image(e, d, imgs[0].unsqueeze(0), vocab)
    # Captions for the whole batch in one call.
    predicted_captions = bulk_caption_image(e, d, imgs, vocab)

print(first_caption)
# Derive the sample count from the actual batch instead of a hard-coded 32.
sample_count = len(predicted_captions) // CAPTION_STRIDE
print([predicted_captions[k * CAPTION_STRIDE] for k in range(sample_count)])

['a', 'dog', 'with', 'a', 'tennis', 'ball', 'in', 'its', 'mouth', 'is', 'running', 'in', 'a', 'field', '.']
[['a', 'dog', 'with', 'a', 'tennis', 'ball', 'in', 'its', 'mouth', 'is', 'running', 'in', 'a', 'field', '.'], ['a', 'little', 'girl', 'in', 'a', 'pink', 'shirt', 'and', 'a', 'boy', 'in', 'a', 'white', 'shirt', 'and', 'a', 'boy', 'in', 'a', 'blue', 'shirt', 'and', 'white', 'shoes', 'jumping', 'on', 'a', 'trampoline', '.'], ['a', 'baby', 'laughs', 'at', 'a', 'table', '.'], ['a', 'woman', 'with', 'a', 'pink', 'headband', 'and', 'black', 'skirt', 'plays', 'with', 'a', 'hula', '-', 'camera', 'with', 'a', 'beer', '.'], ['a', 'young', 'boy', 'is', 'smiling', 'whilst', 'playing', 'a', 'guitar', '.'], ['a', 'dog', 'running', 'in', 'the', 'snow', '.']]


In [10]:
# Import the one helper this cell needs explicitly rather than via a star import.
from validation import evaluate_bleu_batch

# Make sure both models are in evaluation mode even if this cell is run
# without the captioning cell before it.
e.eval()
d.eval()

# Score the validation split in batches of 128. Gradients are not needed for
# evaluation, so they are disabled for the whole pass.
# NOTE: the recorded run below was interrupted manually, so expect this to take a while.
with torch.no_grad():
    bleu_result = evaluate_bleu_batch(e, d, vocab, val_data, batch_size=128)
# Last expression: display whatever the helper returns (nothing if it returns None).
bleu_result

128 128
0.16090047359466553
128 128
0.1599031612277031
128 128
0.15887526671091715
128 128
0.1607203111052513
128 128
0.16056452989578246
128 128
0.16489560157060623
128 128
0.1664470966373171
128 128
0.16793225891888142
128 128
0.17027768823835585
128 128
0.1702870950102806
128 128
0.1697507703846151
128 128
0.16924437632163367
128 128
0.16875636462981886
128 128
0.16763258512531007
128 128
0.1675866295893987


KeyboardInterrupt: 