In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so that edits to the project's .py files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import os
from torch import nn
from torch.optim import Adam
import torchvision
from IPython import display
import matplotlib.pyplot as plt
import numpy as np
from PIL import ImageDraw
from PIL import ImageFont

In [None]:
from dataset import MemeCaptionDataset
from captionmodel import EncoderCNN, DecoderRNN
from torchgan.models import DCGANGenerator
from infer_caption import pred_vec_to_text

from settings import caption_batch_size, workers

# Reusable torchvision transform that converts a tensor/ndarray image into an
# RGB PIL image.
to_pil = torchvision.transforms.ToPILImage(mode='RGB')

# Point the constants below to the model checkpoints

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Checkpoint file names.
# (An alternative GAN checkpoint name used previously: 'trained_model.model'.)
GAN_PARAMS_TO_LOAD = 'gan0.model'
ENCODER_PARAMS_TO_LOAD = 'encoder-440.pth'
DECODER_PARAMS_TO_LOAD = 'decoder-440.pth'

# Full checkpoint paths, relative to the directory the notebook is run from.
GAN_CKPT_PATH = './model/' + GAN_PARAMS_TO_LOAD
ENCODER_CKPT_PATH = './caption-model-ckpts/' + ENCODER_PARAMS_TO_LOAD
DECODER_CKPT_PATH = './caption-model-ckpts/' + DECODER_PARAMS_TO_LOAD

# Load the models for inference

In [None]:
# Load the meme background generator

# The GAN checkpoint is read as a mapping whose 'generator' entry holds the
# generator's state dict. `device` is already a torch.device, so it is passed
# to map_location directly.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# PyTorch version / `weights_only` setting -- only load trusted checkpoints.
generator_state_dict = torch.load(
    GAN_CKPT_PATH,
    map_location=device
)['generator']

# Build the generator architecture that the weights are loaded into.
# NOTE(review): these hyperparameters presumably mirror the training
# configuration; load_state_dict below will complain if they do not match.
generator = DCGANGenerator(
    encoding_dims=100,
    out_size=64,
    out_channels=3,
    step_channels=64,
    nonlinearity=nn.LeakyReLU(0.2),
    last_nonlinearity=nn.Tanh(),
).to(device)

generator.load_state_dict(
    state_dict=generator_state_dict
)

# Switch to evaluation mode for inference, as is done for the encoder and
# decoder: any batch-norm / dropout layers then use their inference behaviour
# instead of their training behaviour while sampling.
generator.eval()

In [None]:
# Load the CNN Encoder and RNN Decoder
dataset = MemeCaptionDataset()

# Build the shuffled batch loader first, then keep an iterator over it;
# the iterator is what gets passed to create_meme further down.
caption_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=caption_batch_size,
    shuffle=True,
    num_workers=workers,
)
data_loader = iter(caption_loader)

# The decoder's vocabulary size is taken from the dataset's `itos` table
# (presumably the index-to-token mapping -- confirm in dataset.py).
vocab_size = len(dataset.itos)

encoder = EncoderCNN().to(device)
decoder = DecoderRNN(
    embed_size=1024,
    hidden_size=1024,
    vocab_size=vocab_size,
).to(device)

# Restore the trained weights for each model, mapping tensors onto `device`.
for module, ckpt_path in (
    (encoder, ENCODER_CKPT_PATH),
    (decoder, DECODER_CKPT_PATH),
):
    module.load_state_dict(torch.load(ckpt_path, map_location=device))

# Evaluation mode for inference.
encoder.eval()
decoder.eval()

# Generate Some Memes!

In [None]:
from infer_full import create_meme

# Number of memes to generate and the maximum number of rows in the grid.
memes_to_generate = 10
grid_rows = 2

if memes_to_generate < 1:
    raise ValueError("memes_to_generate must be at least 1")

# Generate the memes one at a time and keep each as an array tile.
# NOTE(review): create_meme presumably returns a PIL image (or something else
# np.array can convert); all tiles are assumed to share the same shape, as the
# concatenation below requires.
meme_tiles = [
    np.array(
        create_meme(
            encoder=encoder,
            decoder=decoder,
            generator=generator,
            data_loader=data_loader,
            device=device,
            dataset=dataset,
        )
    )
    for _ in range(memes_to_generate)
]

# Lay the tiles out row by row. Never use more rows than there are tiles, and
# pad the final row with blank (all-zero) tiles so every row has the same
# width -- otherwise odd counts (or a single meme) make np.concatenate fail.
n_rows = min(grid_rows, len(meme_tiles))
n_cols = -(-len(meme_tiles) // n_rows)  # ceiling division
n_blank = n_rows * n_cols - len(meme_tiles)
meme_tiles.extend(np.zeros_like(meme_tiles[0]) for _ in range(n_blank))

# Join each row horizontally, then stack the rows vertically.
memes = np.concatenate(
    [
        np.concatenate(meme_tiles[row * n_cols:(row + 1) * n_cols], axis=1)
        for row in range(n_rows)
    ],
    axis=0,
)

plt.axis('off')
plt.imshow(memes, interpolation=None)
plt.show()