In [None]:
from train import train
from data import initialize_loader, Flickr8k
from encoder_decoder import ResNetEncoder, Decoder
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import pickle
import random
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class AttrDict(dict):
    """A dict whose keys can also be read and written as attributes.

    The instance ``__dict__`` is aliased to the mapping itself, so
    ``obj.key`` and ``obj["key"]`` refer to the same storage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attribute storage *is* the dict: attribute access == item access.
        self.__dict__ = self

In [None]:
# Hyperparameters and paths handed to train(). Tuning them is optional; the
# defaults are enough to finish the assignment.
# NOTE(review): load_model=True presumably makes train() start from the
# checkpoints at encoder_path / decoder_path — confirm in train.py.
args_dict = dict(
    learn_rate=0.001,
    batch_size=32,
    epochs=5,
    log_step=25,
    save_epoch=1,
    model_path="models/",
    load_model=True,
    encoder_path="models/encoder-5.ckpt",
    decoder_path="models/decoder-5.ckpt",
)
args = AttrDict()
args.update(args_dict)

In [None]:
# Load the pre-built vocabulary object.
# NOTE: pickle.load can execute arbitrary code — only load vocab.pkl from a
# trusted source.
with open("vocab.pkl", "rb") as vocab_file:
    vocab = pickle.load(vocab_file)

In [None]:
# Build the encoder and decoder (encoder is constructed first, as before).
# NOTE(review): 256 is passed to both constructors and presumably is the
# shared feature/embedding size; 512 is presumably the decoder's hidden size —
# confirm in encoder_decoder.py.
e, d = ResNetEncoder(256), Decoder(len(vocab), 256, 512)

In [None]:
# Train the encoder/decoder pair with the hyperparameters in `args`.
# NOTE(review): with args.load_model=True this presumably starts from the
# checkpoints at args.encoder_path / args.decoder_path — confirm in train.py.
# The captioning cell below uses `e` and `d` directly, so it assumes train()
# updates them in place — confirm.
train(e, d, args)

In [None]:
# Image preprocessing: resize to 224x224, convert to a tensor, then normalize
# per channel. The mean/std values are the standard ImageNet statistics
# (presumably matching a pretrained ResNet in encoder_decoder — confirm).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_data = Flickr8k(
    csv_file="flickr8k/train.csv",
    root_dir="flickr8k/images",
    vocab=vocab,
    transform=transform,
)
train_loader = initialize_loader(train_data, batch_size=args.batch_size)

In [None]:
def caption_image(encoderCNN, decoderRNN, image, vocabulary, max_length=50, end_token="<eos>"):
    """Greedily decode a caption for a single image.

    Adapted from:
    https://github.com/aladdinpersson/Machine-Learning-Collection/blob/4bd862577ae445852da1c1603ade344d3eb03679/ML/Pytorch/more_advanced/image_captioning/model.py#L49

    Decoding works in steps. The encoder output is the first input to
    ``decoderRNN.lstm``. On each later step, the embedding of the previously
    predicted (argmax) token is fed back in, one step at a time.

    Args:
        encoderCNN: callable mapping ``image`` to a feature tensor. A leading
            length-1 dim is added to its output before it enters the LSTM.
        decoderRNN: object exposing ``lstm``, ``linear`` and ``embedding``.
        image: input holding a single image. ``predicted.item()`` is called
            on each step's prediction, so only one prediction per step is
            supported.
        vocabulary: object whose ``itos`` maps token ids to strings.
        max_length: maximum number of tokens to generate.
        end_token: token string that stops decoding. When produced, it is
            included as the last element of the result.

    Returns:
        list[str]: the generated tokens, in order.

    Note:
        Modules are not switched to eval mode here; call ``.eval()`` first.
    """
    token_ids = []

    with torch.no_grad():
        # Add a leading length-1 dim so the image features form one LSTM step.
        inputs = encoderCNN(image).unsqueeze(0)
        states = None

        for _ in range(max_length):
            hiddens, states = decoderRNN.lstm(inputs, states)
            logits = decoderRNN.linear(hiddens.squeeze(0))
            predicted = logits.argmax(1)
            token_id = predicted.item()
            token_ids.append(token_id)
            if vocabulary.itos[token_id] == end_token:
                break
            # Feed the predicted token back in as the next step's input.
            inputs = decoderRNN.embedding(predicted).unsqueeze(0)

    return [vocabulary.itos[idx] for idx in token_ids]

# Grab a single batch and compare the model's greedy caption for one random
# image with that image's ground-truth caption.
imgs, captions, lengths = next(iter(train_loader))
imgs = imgs.to(device)
captions = captions.to(device)
# randrange excludes its upper bound. The previous randint(0, 32) is
# inclusive, so it could pick index 32, which is out of range for a 32-image
# batch. It would also break for any batch smaller than 32.
r = random.randrange(imgs.size(0))
img = imgs[r].unsqueeze(0)  # keep a leading batch dim of 1 for caption_image
e.eval()
d.eval()
print(" ".join(caption_image(e, d, img, vocab)))
# Undo transforms.Normalize (same mean/std as the transform cell) so imshow
# receives values in [0, 1] instead of clipping the normalized tensor.
norm_mean = torch.tensor([0.485, 0.456, 0.406], device=imgs.device).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225], device=imgs.device).view(3, 1, 1)
plt.imshow((imgs[r] * norm_std + norm_mean).clamp(0, 1).permute(1, 2, 0).cpu())
# .tolist() yields plain Python ints, so the itos lookups work whether itos
# is a list or a dict. It also does not depend on the tensor's device.
sentence = [train_data.vocab.itos[token_id] for token_id in captions[r].tolist()]
print(" ".join(sentence))