In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import argparse
import pickle
import os
import glob
from torchvision import transforms
from build_vocab import Vocabulary
from model import EncoderCNNWithAttention, DecoderRNNWithAttention
from PIL import Image

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def load_image(image_path, transform=None):
    image = Image.open(image_path).convert('RGB')
    image = image.resize([224, 224], Image.LANCZOS)

    if transform is not None:
        image = transform(image).unsqueeze(0)

    return image

In [4]:
image_path = 'data/images/elephants.png'
encoder_path = 'models/encoder-with-attention-3-3000.ckpt'
decoder_path = 'models/decoder-with-attention-3-3000.ckpt'
vocab_path = 'data/vocab.pkl'
embed_size = 512
hidden_size = 512
encoded_size = 512
num_layers = 1
pixel_num = 16

In [5]:
def main(image, encoder_path, decoder_path, vocab_path, embed_size, hidden_size, num_layers):
    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load vocabulary wrapper
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build models
    # eval mode (batchnorm uses moving mean/variance)
    encoder = EncoderCNNWithAttention(embed_size).eval()
    decoder = DecoderRNNWithAttention(
        embed_size,
        hidden_size,
        len(vocab),
        num_layers, encoded_size, device).eval()
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # Load the trained model parameters
    try:
        if torch.cuda.is_available():
            encoder.load_state_dict(torch.load(encoder_path))
        else:
            encoder.load_state_dict(
                torch.load(
                    encoder_path,
                    map_location=torch.device('cpu')))
    except BaseException as e:
        print(e)
    try:
        if torch.cuda.is_available():
            decoder.load_state_dict(torch.load(decoder_path))
        else:
            decoder.load_state_dict(
                torch.load(
                    decoder_path,
                    map_location=torch.device('cpu')))
    except BaseException as e:
        print(e)

    # Prepare an image
    image = load_image(image, transform)
    image_tensor = image.to(device)

    # Generate an caption from the image
    with torch.no_grad():
        feature = encoder(image_tensor)
        sampled_ids = decoder.sample(feature)
    # (1, max_seq_length) -> (max_seq_length)
    sampled_ids = sampled_ids[0].cpu().numpy()
    print(sampled_ids)

    # Convert word_ids to words
    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)
    return sentence

In [6]:
image_path = 'data/images/elephants.png'
sentence = main(image_path, encoder_path, decoder_path, vocab_path, embed_size, hidden_size, num_layers)

  image = image.resize([224, 224], Image.LANCZOS)


torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
torch.Size([1, 1, 512])
torch.Size([1, 512])
[  4 371 131   4   4   4   4   4  19  19  19  19  19  19  19  19  19  19
  19  19]
