In [1]:
from train import train
from data import initialize_loader, Flickr8k
from encoder_decoder import ResNetEncoder, Decoder, DecoderWithAttention, ResNetAttentionEncoder
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import pickle
import random
import os
from utils import *
from validation import *
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\Ayush\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
class AttrDict(dict):
    """A dict whose items can also be read and written as attributes.

    ``cfg.key`` and ``cfg["key"]`` refer to the same entry, because the
    instance's attribute namespace is the dictionary itself.
    """

    def __init__(self, *args, **kwargs):
        # Populate the mapping exactly as the built-in dict constructor would.
        super().__init__(*args, **kwargs)
        # Alias the attribute namespace to the dict so both access styles share storage.
        self.__dict__ = self

In [3]:
# Hyperparameters, checkpoint paths and switches for the run.
# Tuning these is optional: the assignment can be completed with the values below.
args_dict = dict(
    embed_size=256,
    hidden_size=512,
    encoder_dim=512,  # MUST MATCH THE RESNET ENCODER OUTPUT
    attention_dim=128,
    learn_rate=0.001,
    batch_size=32,
    epochs=10,
    log_step=25,
    save_epoch=1,
    model_path="models/",
    load_model=False,
    encoder_path="models/encoder-attention-7.ckpt",
    decoder_path="models/decoder-attention-7.ckpt",
    model_type="attention",
    early_stopping_patience=2,
    early_stopping_metric="bleu",
    finetune_attention=True,
)

# Wrap the settings so they can be accessed as attributes (e.g. args.batch_size).
args = AttrDict(args_dict)

In [4]:
# Call the project's `train` function (imported from the `train` module in the
# first cell) with the configuration assembled in `args`.
train(args)

cuda
Epoch [1/10], Step [0/1012], Loss: 9.0763
Epoch [1/10], Step [25/1012], Loss: 4.9860
Epoch [1/10], Step [50/1012], Loss: 4.4838
Epoch [1/10], Step [75/1012], Loss: 4.3801
Epoch [1/10], Step [100/1012], Loss: 4.0152
Epoch [1/10], Step [125/1012], Loss: 3.9713
Epoch [1/10], Step [150/1012], Loss: 4.1160
Epoch [1/10], Step [175/1012], Loss: 3.6588
Epoch [1/10], Step [200/1012], Loss: 3.4512
Epoch [1/10], Step [225/1012], Loss: 3.5083
Epoch [1/10], Step [250/1012], Loss: 3.3980
Epoch [1/10], Step [275/1012], Loss: 3.5858
Epoch [1/10], Step [300/1012], Loss: 3.6223
Epoch [1/10], Step [325/1012], Loss: 3.5001
Epoch [1/10], Step [350/1012], Loss: 3.5167
Epoch [1/10], Step [375/1012], Loss: 3.3607
Epoch [1/10], Step [400/1012], Loss: 3.6197
Epoch [1/10], Step [425/1012], Loss: 3.1537
Epoch [1/10], Step [450/1012], Loss: 3.4472
Epoch [1/10], Step [475/1012], Loss: 3.1947
Epoch [1/10], Step [500/1012], Loss: 3.2218
Epoch [1/10], Step [525/1012], Loss: 3.3452
Epoch [1/10], Step [550/1012], L

In [5]:
# Load the vocabulary object stored in vocab.pkl.
# NOTE(review): pickle.load can execute arbitrary code — only open a vocab.pkl you trust.
with open("vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

# Image preprocessing: resize to 224x224, convert to a tensor, then normalise
# each channel with the mean/std values given below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Training split and its loader.
train_data = Flickr8k(
    csv_file="flickr8k/train.csv",
    root_dir="flickr8k/images",
    vocab=vocab,
    transform=transform,
)
train_loader = initialize_loader(train_data, batch_size=args.batch_size)

# Validation split and its loader (same vocabulary and preprocessing).
val_data = Flickr8k(
    csv_file="flickr8k/val.csv",
    root_dir="flickr8k/images",
    vocab=vocab,
    transform=transform,
)
val_loader = initialize_loader(val_data, batch_size=args.batch_size)

In [6]:
# Build the encoder/decoder pair that matches the configured model type.
if args.model_type == "attention":
    e = ResNetAttentionEncoder(args.embed_size)
    d = DecoderWithAttention(
        len(vocab), args.embed_size, args.hidden_size, args.encoder_dim, args.attention_dim)
else:
    e = ResNetEncoder(args.embed_size)
    d = Decoder(len(vocab), args.embed_size, args.hidden_size)

# Fail early, naming the config keys to change, when a checkpoint file is missing.
# The exception type is the same one torch.load would raise on a missing file.
for ckpt_path in (args.encoder_path, args.decoder_path):
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(
            "Checkpoint not found: '{}'. Point args.encoder_path / args.decoder_path "
            "at an existing checkpoint file.".format(ckpt_path))

# map_location remaps stored tensors onto `device`, so checkpoints saved on a
# GPU also load on a CPU-only machine.
e.load_state_dict(torch.load(args.encoder_path, map_location=device))
d.load_state_dict(torch.load(args.decoder_path, map_location=device))

e.to(device)
d.to(device)

# Pick a random validation sample. The original 0..1000 range is kept, but
# capped at the last valid index so a smaller validation set cannot be
# indexed out of range.
rand_num = random.randint(0, min(1000, len(val_data) - 1))
img = val_data[rand_num][0]

# NOTE(review): confirm that the get_caption_* helpers switch the models to
# eval mode; this cell does not call e.eval() / d.eval() itself.
if args.model_type == "attention":
    # Generate a caption together with its attention weights and plot them.
    caps, alphas = get_caption_attention(e, d, img, vocab)
    plot_attention(img, caps, alphas)
else:
    caps = get_caption_lstm(e, d, img, vocab)

FileNotFoundError: [Errno 2] No such file or directory: 'models/encoder-attention-7.ckpt'

In [7]:
# Score the loaded encoder/decoder on the validation data (presumably BLEU-1,
# judging by the helper's name). The comparison already yields a bool, so it is
# passed directly as the `attention` flag.
validation_bleu1(e, d, vocab, val_data, attention=(args.model_type == "attention"))

0.6074705719947815