In [1]:
from data import initialize_loader, Flickr8k
from encoder_decoder import ResNetEncoder, Decoder, ResNetAttentionEncoder, DecoderWithAttention
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import pickle
import random
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
class AttrDict(dict):
    """A ``dict`` whose entries are also reachable as attributes.

    The instance ``__dict__`` is pointed at the mapping itself, so
    ``obj.key`` and ``obj["key"]`` read and write the same storage.
    Accepts the same constructor arguments as ``dict``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias attribute storage to the dict contents.
        self.__dict__ = self

In [3]:
# Run settings. You can play with the hyperparameters here, but to finish the
# assignment there is no need to tune them.
# NOTE(review): "load_model" is not read anywhere visible in this notebook; the
# checkpoint cell loads encoder_path / decoder_path unconditionally.
args_dict = dict(
    learn_rate=0.001,
    batch_size=32,
    epochs=5,
    log_step=15,
    save_epoch=1,
    model_path="models/",
    load_model=False,
    encoder_path="models/encoder-attention-5.ckpt",
    decoder_path="models/decoder-attention-5.ckpt",
)
args = AttrDict(args_dict)

In [4]:
# Load the pickled vocabulary; len(vocab) sizes the decoder output below.
# SECURITY: pickle.load can execute arbitrary code -- only load a trusted vocab.pkl.
with open("vocab.pkl", mode="rb") as vocab_file:
    vocab = pickle.load(vocab_file)

In [5]:
# Attention-based encoder/decoder pair. The non-attention baseline would instead be
# ResNetEncoder(256) with Decoder(len(vocab), 256, 512).
# NOTE(review): the meaning of the positional sizes (256, 512, 512, 128) is defined in
# encoder_decoder.DecoderWithAttention -- confirm there. They must match the saved
# checkpoints, because load_state_dict is strict by default.
vocab_size = len(vocab)
e = ResNetAttentionEncoder(256)
d = DecoderWithAttention(vocab_size, 256, 512, 512, 128)

In [6]:
# train(e, d, args)

In [7]:
# Per-channel mean/std: the standard ImageNet normalization values from the torchvision docs.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Fixed-size 224x224 tensors, normalized per channel.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Both splits share the same image folder and preprocessing; only the CSV differs.
images_dir = "flickr8k/images"
train_data = Flickr8k(csv_file="flickr8k/train.csv", root_dir=images_dir, vocab=vocab, transform=transform)
train_loader = initialize_loader(train_data, batch_size=args.batch_size)
val_data = Flickr8k(csv_file="flickr8k/val.csv", root_dir=images_dir, vocab=vocab, transform=transform)

In [8]:
from validation import *

# Put both modules on the notebook's `device` (the commented-out exploration cells
# send their inputs there), and switch them to inference mode. eval() changes the
# behaviour of layers such as Dropout/BatchNorm, if the models contain them.
e.to(device).eval()
d.to(device).eval()

# Map checkpoint tensors to CPU first so they load even on a GPU-less machine.
# load_state_dict then copies them into the parameters on whatever device they now live.
e.load_state_dict(torch.load(args.encoder_path, map_location=torch.device('cpu')))
d.load_state_dict(torch.load(args.decoder_path, map_location=torch.device('cpu')))

<All keys matched successfully>

In [9]:
# for i, (imgs, captions, lengths) in enumerate(train_loader):
#     imgs = imgs.to(device)
#     captions = captions.to(device)
#     e.eval()
#     d.eval()
#     c = caption_image(e, d, imgs[0].unsqueeze(0).to(device), vocab)
#     print(c)
#     captions = bulk_caption_image(e,d,imgs, vocab)
#     l = []
#     for i in range(32 // 5):
#         l.append(captions[i*5])
#     print(l)
    # r = random.randint(0,32)
    # img = imgs[r].unsqueeze(0).to(device)
    # e.eval()
    # d.eval()
    # print(" ".join(caption_image(e, d, imgs[r].unsqueeze(0), vocab)))
    # plt.imshow(imgs[r].permute(1, 2, 0).cpu())
    # sentence = map(lambda x: train_data.vocab.itos[x], captions[r])
    # print(" ".join(sentence))
    # break

In [10]:
from validation import *
# x = p()
# print(x)
# bleu_score = evaluate_bleu_batch(e, d, vocab, val_data, 
#                 attention=True,
#                 maxn_gram=1,
#                 batch_size=128)
# c,r = get_captions_and_references(e, d, vocab, val_data, 
#                 attention=True,
#                 batch_size=128)

In [11]:
# for i, (imgs, captions, lengths) in enumerate(train_loader):
#     imgs = imgs.to(device)
#     with torch.no_grad():
#         features = e(imgs)
#         caps = d.generate_caption(features, vocab=vocab)
#         break


In [12]:
# bleu_score

In [13]:
# print(len(c), len(r))
# print(c[10], r[10])

In [14]:
# List-based alternative API, kept for reference:
# bleus = validation([e], [d], vocab, val_data, bleu_max=4, attention=True)

# Pure evaluation: disable autograd so no computation graph is recorded while
# captioning the validation set (saves memory and time).
with torch.no_grad():
    bleus = validation_bleu3(e, d, vocab, val_data, attention=True)

56,
        0.05710256],
       [0.04303322, 0.0453335 , 0.04716038, ..., 0.00947117, 0.00931784,
        0.01250342],
       ...,
       [0.03889818, 0.04174523, 0.04647877, ..., 0.00377324, 0.10699219,
        0.07076762],
       [0.02735622, 0.01556078, 0.02026512, ..., 0.02177152, 0.0059887 ,
        0.00577473],
       [0.11323255, 0.10237764, 0.11904323, ..., 0.0039934 , 0.00433815,
        0.00621169]], dtype=float32), array([[0.06043086, 0.03857721, 0.01603036, ..., 0.05745499, 0.05192371,
        0.04487402],
       [0.02543471, 0.01445676, 0.03011228, ..., 0.0174533 , 0.05034891,
        0.08675915],
       [0.08412878, 0.06264436, 0.04929078, ..., 0.00510818, 0.00589486,
        0.0109731 ],
       ...,
       [0.0579655 , 0.06067324, 0.06442319, ..., 0.00311828, 0.05409901,
        0.13499984],
       [0.04751624, 0.01466253, 0.0198857 , ..., 0.02091633, 0.0036224 ,
        0.0031801 ],
       [0.04360093, 0.03704317, 0.08604947, ..., 0.01123418, 0.0122946 ,
        0.02156

KeyboardInterrupt: 

In [18]:
# Label each entry so the printed result stands on its own.
# NOTE(review): entries are presumed to be BLEU-1..BLEU-4 in order (four decreasing
# scores, bleu_max=4 in the alternative call) -- confirm against validation_bleu3.
for n_gram_order, bleu_entry in enumerate(bleus, start=1):
    print(f"BLEU-{n_gram_order}: {bleu_entry}")

[[0.5746132135391235], [0.37304043769836426], [0.24396073818206787], [0.16179923713207245]]
