In [2]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from datasets.captioning.dataset import *

from utils import load_checkpoint
from models.captioning import CaptionNet

import config

import torch
import torch.nn as nn
import torch.nn.functional as F

import random
import os

from models.compress import HyperpriorWrapper

%matplotlib inline

In [3]:
# Show which device the project config selected (this run reported 'cuda').
config.DEVICE

'cuda'

In [34]:
# Caption decoder that consumes the compressor's decoded tensor instead of CNN features.
# NOTE(review): 192 presumably matches the channel count of the compressor latent and
# 512 * 49 the flattened feature size after the model's own conv/pool stage — confirm
# against models.captioning.CaptionNet.
network = CaptionNet(
    n_tokens,
    pad_ix=pad_ix,
    cnn_feature_size=512 * 49,
    cnn_in_channels=192,
    cnn_out_channels=512,
    pool=2,
)
network = network.to(config.DEVICE)

# The optimizer is only needed here because load_checkpoint expects one.
optimizer = torch.optim.Adam(
    params=network.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)

In [37]:
# Load the trained captioning checkpoint (prints "=> Loading checkpoint").
# NOTE(review): presumably restores both `network` weights and `optimizer` state, and the
# trailing 1e-4 (same value as the Adam lr above) looks like the learning rate to re-apply —
# confirm against utils.load_checkpoint.
load_checkpoint("checkpoints/captions_compressed_enc_512_2.pth.tar", network, optimizer, 1e-4)

=> Loading checkpoint


In [38]:
# from models.beheaded_inception3 import beheaded_inception_v3
# inception = beheaded_inception_v3().train(False)

In [39]:
# Pretrained image compressor; generate_caption feeds its entropy-decoded output to `network`.
# NOTE(review): the positional 1 is presumably a quality/level selector — confirm against
# models.compress.HyperpriorWrapper.
compressor = HyperpriorWrapper(1, pretrained=True)
compressor = compressor.eval()  # inference mode
compressor = compressor.to(config.DEVICE)

In [40]:
def generate_caption(image, caption_prefix=("#START#",), t=1, sample=True, max_len=100):
    """Generate a caption for a single image, one token at a time.

    The image is run through the module-level ``compressor`` (``compress`` followed by
    ``entropy_decode``); the decoded tensor, together with the growing token prefix, is
    fed to the module-level ``network`` to obtain logits for the next word.

    Args:
        image: ``np.ndarray`` of shape (H, W, 3) with values in [0, 1].
        caption_prefix: iterable of tokens the caption starts with.
        t: exponent applied to the next-word probabilities before renormalising
            (``p ** t / sum(p ** t)``). ``t == 1`` leaves the distribution unchanged,
            larger values sharpen it, smaller values flatten it.
        sample: if true, draw the next word from the (re-weighted) distribution;
            otherwise take the most probable word.
        max_len: maximum number of tokens appended to the prefix.

    Returns:
        list: the prefix followed by the generated tokens. The last token is
        ``"#END#"`` unless ``max_len`` was reached first.

    Raises:
        AssertionError: if ``image`` is not a numpy array in [0, 1] with 3 channels last.
    """
    assert (
        isinstance(image, np.ndarray)
        and np.max(image) <= 1
        and np.min(image) >= 0
        and image.shape[-1] == 3
    ), "image must be an HxWx3 numpy array with values in [0, 1]"

    # NOTE(review): `network` is never switched to eval() in the visible cells; if it
    # contains dropout/batch-norm, call network.eval() before generating — confirm.
    with torch.no_grad():
        # HWC -> CHW, then a leading batch axis is added via image[None] below.
        image = torch.tensor(image.transpose([2, 0, 1]), dtype=torch.float32).to(
            config.DEVICE
        )

        # vectors_8x8, vectors_neck, logits = inception(image[None])
        compressed = compressor.compress(image[None])
        vectors_neck = compressor.entropy_decode(compressed["strings"], compressed["shape"])
        # The image features do not change between decoding steps: move them once.
        vectors_neck = vectors_neck.to(config.DEVICE)

        caption_prefix = list(caption_prefix)

        for _ in range(max_len):
            prefix_ix = torch.tensor(
                as_matrix([caption_prefix]), dtype=torch.int64
            ).to(config.DEVICE)

            # Logits for the word following the last prefix token of the only batch item.
            next_word_logits = network(vectors_neck, prefix_ix)[0, -1]

            # softmax(logits) ** t / sum(softmax(logits) ** t) == softmax(t * logits).
            # Scaling the logits in float64 gives the same distribution without the
            # float32 underflow (0 / 0 -> NaN) that p ** t hits for large t.
            next_word_probs = (
                F.softmax(next_word_logits.double() * float(t), dim=-1).cpu().numpy()
            )
            assert len(next_word_probs.shape) == 1, "probs must be one-dimensional"
            # Guard against rounding so np.random.choice accepts the vector.
            next_word_probs = next_word_probs / np.sum(next_word_probs)

            if sample:
                # Sample an index and look the token up, mirroring the greedy branch.
                next_ix = np.random.choice(len(next_word_probs), p=next_word_probs)
            else:
                next_ix = np.argmax(next_word_probs)
            next_word = vocab[next_ix]

            caption_prefix.append(next_word)

            if next_word == "#END#":
                break

    return caption_prefix

In [None]:
# sample image
# path = 'datasets/image_captioning/flickr30k_images/flickr30k_images/'
path = "datasets/captioning/coco2017/"
image = random.choice(os.listdir(path))

# Open with PIL and force 3-channel RGB: the dataset may contain grayscale/RGBA files,
# which would otherwise fail generate_caption's HxWx3 check.
with Image.open(os.path.join(path, image)) as pil_img:
    img = pil_img.convert("RGB").resize((256, 256))
img = np.array(img).astype("float32") / 255.0  # uint8 [0, 255] -> float32 [0, 1]
plt.imshow(img)

for i in range(10):
    tokens = generate_caption(img, t=5.0)[1:]  # drop the "#START#" token
    # Only strip the terminator when it is really there; if generation stopped at
    # max_len the last token is a real word and must be kept.
    if tokens and tokens[-1] == "#END#":
        tokens = tokens[:-1]
    print(" ".join(tokens))

a man with a tennis racket in a baseball game .
a man is standing on a bench in front of a building .
a man in a suit and a man holding a frisbee .
a man is sitting on a bench in a room .
a man is standing in the air with a skateboard .
a man riding a skateboard in the water .
a white cat sitting on a table with a laptop .
a bathroom with a toilet and a sink .
