In [16]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.nn.utils.rnn import pack_padded_sequence

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pickle
from build_vocab import Vocabulary

##### Create model

Image encoder (DenseNet-201 feature extractor + linear projection) and LSTM caption decoder.

In [23]:
class Encoder(nn.Module):
    """CNN image encoder mapping an image batch to a fixed-size embedding.

    A torchvision DenseNet-201 convolutional trunk is followed by ReLU,
    global average pooling, a linear projection to ``embed_size`` and
    1-D batch normalisation.

    Args:
        embed_size: Dimension of the output embedding.
        backbone: Name of the CNN backbone. Only ``'densenet201'`` is supported.
        pretrained: Forwarded to ``models.densenet201`` (ImageNet weights when true).
        fine_tune: If false, the backbone parameters are frozen
            (``requires_grad = False``) so only ``fc`` and ``bn`` are trainable.

    Raises:
        ValueError: If ``backbone`` is not supported.
    """

    def __init__(self, embed_size, backbone, pretrained, fine_tune):
        super(Encoder, self).__init__()
        if backbone == 'densenet201':
            model = models.densenet201(pretrained=pretrained)
            # Width of the pooled feature vector the original classifier consumed.
            self.num_ftrs = model.classifier.in_features
            # Keep only the convolutional trunk; the ImageNet classifier is dropped.
            self.features = model.features
        else:
            # Fail fast instead of crashing later on the missing self.num_ftrs.
            raise ValueError('Unsupported backbone: {!r}'.format(backbone))

        self.fc = nn.Linear(self.num_ftrs, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

        if not fine_tune:
            for param in self.features.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image tensor of shape (batch, 3, H, W).

        Returns:
            Tensor of shape (batch, embed_size). In training mode ``bn``
            needs batch > 1.
        """
        x = self.features(x)
        # DenseNet's `features` ends at a BatchNorm layer. torchvision's own
        # forward applies ReLU and global average pooling before the
        # classifier, which is what makes the width equal to `num_ftrs`.
        # Flattening the raw (C, h, w) map would not match `fc`'s input size.
        x = nn.functional.relu(x)
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        return self.bn(self.fc(x))

In [18]:
class Decoder(nn.Module):
    """LSTM caption decoder conditioned on an image embedding.

    The image embedding is fed as the first LSTM input step, followed by the
    word embeddings of the caption tokens.

    Args:
        embed_size: Word-embedding size. Must equal the encoder's output size,
            because the image feature is concatenated along the time axis.
        hidden_size: LSTM hidden-state size.
        vocab_size: Number of tokens in the vocabulary (output logits width).
        num_layers: Number of stacked LSTM layers.
        max_seq_length: Number of tokens generated by ``sample``.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.max_seq_length = max_seq_length

    def forward(self, features, captions, length):
        """Compute vocabulary logits for teacher-forced training.

        Args:
            features: Image embeddings, (batch, embed_size).
            captions: Padded token ids, (batch, T).
            length: Per-sample lengths of the (image + caption) sequence, each
                at most T + 1. With ``pack_padded_sequence``'s default
                ``enforce_sorted=True`` they must be in decreasing order.

        Returns:
            Logits for the packed (non-padding) time steps,
            shape (sum(length), vocab_size).
        """
        embeddings = self.embed(captions)
        # Prepend the image embedding as time step 0.
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        packed = pack_padded_sequence(embeddings, length, batch_first=True)
        hiddens, _ = self.lstm(packed)
        # hiddens is a PackedSequence; .data holds only the non-padding steps.
        outputs = self.fc(hiddens.data)

        return outputs

    def sample(self, features, states=None):
        """Greedily decode token ids from image embeddings.

        Args:
            features: Image embeddings, (batch, embed_size).
            states: Optional initial LSTM ``(h, c)`` state. None means zeros.

        Returns:
            LongTensor of shape (batch, max_seq_length) holding the argmax token
            id at each step. Decoding does not stop at an end token, so callers
            should truncate at their end marker. Run under ``torch.no_grad()``
            for inference.
        """
        sampled_ids = []
        inputs = features.unsqueeze(1)
        for _ in range(self.max_seq_length):
            hiddens, states = self.lstm(inputs, states)
            logits = self.fc(hiddens.squeeze(1))
            predicted = logits.argmax(1)
            sampled_ids.append(predicted)
            # Feed the predicted token back in as the next input step.
            inputs = self.embed(predicted).unsqueeze(1)
        return torch.stack(sampled_ids, 1)
    
    

In [19]:
# Use the current CUDA device when PyTorch can see one; otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [20]:
def load_image(image_path, transform=None, size=(224, 224)):
    """Load an image as RGB, resize it, and optionally apply a transform.

    Args:
        image_path: Path to the image file.
        transform: Optional callable applied to the resized PIL image. Its
            result gets a leading batch dimension via ``unsqueeze(0)``, so it
            must return a tensor-like object (e.g. a torchvision transform
            ending in ``ToTensor``/``Normalize``).
        size: Target ``(width, height)`` in pixels. Defaults to 224x224.

    Returns:
        The resized PIL image if ``transform`` is None, otherwise the
        transformed result with a leading batch dimension of size 1.
    """
    # The context manager releases the file handle even if decoding fails.
    # convert() returns a new, fully loaded image, so it stays usable after close.
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    # resize is an Image *instance* method; the PIL.Image module has no
    # resize() function.
    img = img.resize(size, Image.LANCZOS)

    if transform is not None:
        img = transform(img).unsqueeze(0)

    return img
    

In [21]:
def main(args):
    """Build the captioning models from saved checkpoints and load the query image.

    ``args`` must provide ``vocab_path``, ``embed_size``, ``hidden_size``,
    ``num_layers``, ``encoder_path``, ``decoder_path`` and ``image``.
    """
    # ImageNet channel statistics expected by torchvision backbones.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean, std)])

    # SECURITY: pickle.load can execute arbitrary code; only load trusted vocab files.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # pretrained=False: the strict load_state_dict below overwrites every
    # encoder parameter, so downloading ImageNet weights first is wasted work
    # and needs network access.
    encoder = Encoder(args.embed_size, 'densenet201', False, False).eval().to(device)
    decoder = Decoder(args.embed_size, args.hidden_size, len(vocab), args.num_layers).eval().to(device)

    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    encoder.load_state_dict(torch.load(args.encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(args.decoder_path, map_location=device))

    img = load_image(args.image, transform)
