In [21]:
import torch
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import pickle 
import os
from torchvision import transforms 
from build_vocab import Vocabulary
from model import EncoderCNN, DecoderRNN
from PIL import Image
import os

In [30]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vocab_path = "vocab.pkl"
encoder_path ="final_models/encoder-3-3000.ckpt"
decoder_path = "final_models/decoder-3-3000.ckpt"
test_data_path = "/datasets/COCO-2017/test2017/"

In [31]:
def load_image(image_path, transform=None):
    image = Image.open(image_path)
    image = image.resize([224, 224], Image.LANCZOS)
    
    if transform is not None:
        image = transform(image).unsqueeze(0)
    
    return image

In [32]:
def predict(embed_size,hidden_size,num_layers,image_paths):
    # Image net parameters 
    transform = transforms.Compose([
        transforms.ToTensor(), 
        transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))])
    
    # Load vocabulary wrapper
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build models
    encoder = EncoderCNN(embed_size).eval()  # eval mode (batchnorm uses moving mean/variance)
    decoder = DecoderRNN(embed_size, hidden_size, len(vocab), num_layers)
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # Load the trained model parameters
    encoder.load_state_dict(torch.load(encoder_path))
    decoder.load_state_dict(torch.load(decoder_path))

    # Prepare an image
    for image_path in image_paths:
        image = load_image(image_path, transform)
        image_tensor = image.to(device)

        # Generate an caption from the image
        feature = encoder(image_tensor)
        sampled_ids = decoder.sample(feature)
        sampled_ids = sampled_ids[0].cpu().numpy()          # (1, max_seq_length) -> (max_seq_length)

        # Convert word_ids to words
        sampled_caption = []
        for word_id in sampled_ids:
            word = vocab.idx2word[word_id]
            sampled_caption.append(word)
            if word == '<end>':
                break
        sentence = ' '.join(sampled_caption)

        # Print out the image and the generated caption
        print (sentence)
        image = Image.open(image_path)
        plt.imshow(np.asarray(image))
    
# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--image', type=str, required=True, help='input image for generating caption')
#     parser.add_argument('--encoder_path', type=str, default='exp1/encoder-5-3000.ckpt', help='path for trained encoder')
#     parser.add_argument('--decoder_path', type=str, default='exp1/decoder-5-3000.ckpt', help='path for trained decoder')
#     parser.add_argument('--vocab_path', type=str, default='data/vocab.pkl', help='path for vocabulary wrapper')
    
#     # Model parameters (should be same as paramters in train.py)
#     parser.add_argument('--embed_size', type=int , default=256, help='dimension of word embedding vectors')
#     parser.add_argument('--hidden_size', type=int , default=512, help='dimension of lstm hidden states')
#     parser.add_argument('--num_layers', type=int , default=1, help='number of layers in lstm')
#     args = parser.parse_args()
#     main(args)

In [35]:
list_imgs = os.listdir(test_data_path)
image_paths = list_imgs[0:10]
predict(embed_size=256,hidden_size=512,num_layers=2,image_paths=image_paths)

RuntimeError: Error(s) in loading state_dict for DecoderRNN:
	Missing key(s) in state_dict: "lstm.weight_ih_l1", "lstm.weight_hh_l1", "lstm.bias_ih_l1", "lstm.bias_hh_l1". 
	size mismatch for lstm.weight_ih_l0: copying a param with shape torch.Size([2048, 256]) from checkpoint, the shape in current model is torch.Size([512, 256]).
	size mismatch for lstm.weight_hh_l0: copying a param with shape torch.Size([2048, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for lstm.bias_ih_l0: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for lstm.bias_hh_l0: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).