## Image Caption Generation Demo

By: dzryk (discord, https://twitter.com/dzryk, https://github.com/dzryk)

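This notebook provides an image-captioning demo that accompanies the antarctic-captions repository (https://github.com/dzryk/antarctic-captions). After setting up the model, it also batch-captions photos stored in Google Drive and writes the CLIP-ranked captions for each image to a text file.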

In [None]:
# Report the GPU(s) visible to this runtime; the device cell below falls back to CPU if CUDA is unavailable.
!nvidia-smi

In [None]:
# Fetch the antarctic-captions code and cd into it; the `import model` / `import utils`
# lines below presumably resolve against this checkout. OpenAI's CLIP is cloned inside it
# so that `from CLIP import clip` works from this directory.
# NOTE(review): not idempotent — on a re-run `git clone` fails because the directories
# already exist, and the relative `%cd` fails if the cwd is already antarctic-captions/.
!git clone https://github.com/dzryk/antarctic-captions.git
%cd antarctic-captions/
!git clone https://github.com/openai/CLIP

In [None]:
!pip3 install gdown
!pip3 install ftfy
!pip3 install transformers
!pip3 install git+https://github.com/PyTorchLightning/pytorch-lightning

In [None]:
# Download the model checkpoint and caption cache (postcache.txt / postcache.npy).
# wget flags: -m mirror the directory, -np never ascend to the parent directory,
# -c resume partially downloaded files, -U send a custom User-Agent,
# -w 2 wait 2 seconds between requests, -R "index.html*" discard directory-listing pages.
# Files land under ./the-eye.eu/public/AI/models/antarctic-captions/, which the
# Settings cell reads via `filedir`.
!wget -m -np -c -U "eye02" -w 2 -R "index.html*" "https://the-eye.eu/public/AI/models/antarctic-captions/"

In [6]:
import argparse
import io
import numpy as np
import torch
import torch.nn as nn
import requests
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F

from CLIP import clip
from PIL import Image
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision.utils import make_grid

import model
import utils

In [None]:
# Prefer the first CUDA GPU; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Using device:', device)

In [8]:
# Helper functions
def fetch(url_or_path, timeout=60):
    """Open an image source as a binary file-like object.

    Args:
        url_or_path: An ``http://`` / ``https://`` URL or a local filesystem path
            (anything ``str()``-able, e.g. a ``pathlib.Path``).
        timeout: Seconds to wait for the HTTP server (passed to ``requests.get``);
            ignored for local paths. ``None`` waits indefinitely.

    Returns:
        A binary file-like object positioned at the start of the data. The caller
        owns it and should close it (e.g. with a ``with`` statement).

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout``.
        OSError: if the local path cannot be opened.
    """
    source = str(url_or_path)
    if source.startswith(('http://', 'https://')):
        # Without a timeout, requests can block forever on an unresponsive host.
        r = requests.get(source, timeout=timeout)
        r.raise_for_status()
        return io.BytesIO(r.content)
    return open(url_or_path, 'rb')

def load_image(img, preprocess):
    """Load an image and build the model input for it.

    Args:
        img: URL or local path of the image (anything ``fetch`` accepts).
        preprocess: Image transform returned by ``clip.load``.

    Returns:
        ``(pil_image, batch)`` where ``batch`` is ``preprocess(pil_image)`` with a
        leading dimension added by ``unsqueeze(0)``, moved to the global ``device``.
    """
    with fetch(img) as fd:
        pil_img = Image.open(fd)
        # Image.open is lazy; decode the pixels now so the file handle can be
        # closed here instead of leaking one descriptor per image in batch loops.
        pil_img.load()
    return pil_img, preprocess(pil_img).unsqueeze(0).to(device)

def show(imgs):
    """Plot one or more image tensors side by side in a single row.

    Args:
        imgs: A tensor image or a list of tensor images; each is detached and
            converted with ``to_pil_image`` before plotting. Ticks are hidden.
    """
    tensors = imgs if isinstance(imgs, list) else [imgs]
    _fig, axes = plt.subplots(ncols=len(tensors), squeeze=False)
    for ax, tensor in zip(axes[0], tensors):
        pil = F.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

def display_grid(imgs, size=256):
    """Display PIL images tiled into one grid figure.

    Args:
        imgs: Iterable of PIL images.
        size: Side length in pixels that each image is resized to (square)
            before tiling. Defaults to 256.
    """
    # Normalise to 3-channel RGB first: grayscale/RGBA/CMYK inputs otherwise give
    # tensors with different channel counts (which make_grid cannot stack), and
    # palette ('P') or CMYK images would be rendered with wrong colours.
    tiles = [F.to_tensor(im.convert('RGB').resize((size, size))) for im in imgs]
    show(make_grid(tiles))
    # With the Jupyter inline backend this displays the open figure(s) right away
    # and closes them, so a batch loop does not keep one figure open per image.
    plt.show()
    
def clip_rescoring(args, net, candidates, x):
    """Rank candidate captions by CLIP similarity to an image embedding.

    Args:
        args: Settings namespace; reads ``args.device`` and
            ``args.num_return_sequences`` (how many top captions to keep).
        net: Loaded model whose ``perceiver`` provides ``encode_text``.
        candidates: List of caption strings to score.
        x: Image embedding from ``utils.build_table``; the caller passes
            ``x[None, :]``. NOTE(review): the ``[0]`` indexing below assumes the
            similarity tensor has two leading singleton dims — confirm against
            ``build_table``'s output shape.

    Returns:
        Up to ``args.num_return_sequences`` strings from ``candidates``, most
        similar first.
    """
    if not candidates:
        return []
    # Pure inference: avoid building an autograd graph for the (potentially
    # ~1000) candidate captions pushed through the text encoder.
    with torch.no_grad():
        textemb = net.perceiver.encode_text(
            clip.tokenize(candidates).to(args.device)).float()
        textemb = textemb / textemb.norm(dim=-1, keepdim=True)
        similarity = (100.0 * x @ textemb.T).softmax(dim=-1)
    # topk raises if k exceeds the number of scored candidates.
    k = min(args.num_return_sequences, len(candidates))
    _, indices = similarity[0].topk(k)
    return [candidates[idx] for idx in indices[0]]

def loader(args):
    """Load the captioning checkpoint, its caption cache, and CLIP's preprocessor.

    Args:
        args: Settings namespace; reads ``textfile`` (one cached caption per line),
            ``embfile`` (array file read with ``np.load``), ``clip_model``,
            ``device``, plus whatever ``utils.load_ckpt`` consumes (e.g. ``ckpt``).

    Returns:
        ``(net, preprocess)``: the model from ``utils.load_ckpt`` with ``cache``
        (list of stripped lines) and ``cache_emb`` (tensor on ``args.device``)
        attached, and the image transform returned by ``clip.load``.
    """
    # Explicit UTF-8 so the cached captions decode identically under any locale.
    with open(args.textfile, encoding='utf-8') as f:
        cache = [line.strip() for line in f]
    cache_emb = np.load(args.embfile)
    net = utils.load_ckpt(args)
    net.cache = cache
    net.cache_emb = torch.from_numpy(cache_emb).to(args.device)
    # Only the preprocessing transform is kept; load the throwaway CLIP model on
    # the CPU so it does not occupy GPU memory alongside `net`.
    preprocess = clip.load(args.clip_model, device='cpu', jit=False)[1]
    return net, preprocess
    
def caption_image(path, args, net, preprocess, show_image=True):
    """Generate captions for one image and return them ranked by CLIP.

    Steps: load and preprocess the image; build a text table from the cached
    captions with ``utils.build_table`` (``topk=args.topk``); generate
    ``args.num_return_sequences`` candidate captions from that table with
    ``net.model.generate``; re-rank the candidates with ``clip_rescoring``.

    Args:
        path: URL or local path of the image.
        args: Settings namespace (see the Settings cell).
        net: Model returned by ``loader``.
        preprocess: Image transform returned by ``loader``.
        show_image: If True (default), display the image via ``display_grid``.

    Returns:
        List of caption strings, best-ranked first.
    """
    img, mat = load_image(path, preprocess)
    # Inference only: no autograd graph is needed for retrieval or generation.
    with torch.no_grad():
        table, x = utils.build_table(mat.to(args.device),
                                     perceiver=net.perceiver,
                                     cache=net.cache,
                                     cache_emb=net.cache_emb,
                                     topk=args.topk,
                                     return_images=True)
        table = net.tokenizer.encode(table[0], return_tensors='pt').to(args.device)
        out = net.model.generate(table,
                                 do_sample=args.do_sample,
                                 num_beams=args.num_beams,
                                 temperature=args.temperature,
                                 top_p=args.top_p,
                                 num_return_sequences=args.num_return_sequences)
    candidates = net.tokenizer.batch_decode(out, skip_special_tokens=True)
    captions = clip_rescoring(args, net, candidates, x[None, :])
    if show_image:
        display_grid([img])
    return captions

In [9]:
# Settings
# Directory created by the wget mirror cell (no trailing slash: paths below add their own '/').
filedir = 'the-eye.eu/public/AI/models/antarctic-captions'

# Generation samples stochastically (do_sample=True); seed torch so a
# Restart & Run All produces repeatable captions.
SEED = 0
torch.manual_seed(SEED)

args = argparse.Namespace(
    ckpt=f'{filedir}/-epoch=05-vloss=2.163.ckpt',
    textfile=f'{filedir}/postcache.txt',   # cached captions, one per line
    embfile=f'{filedir}/postcache.npy',    # array loaded with np.load as the cache embeddings
    clip_model='ViT-B/32',                 # name passed to clip.load
    topk=10,                               # passed to utils.build_table
    num_return_sequences=1000,             # candidates generated per image; also how many are kept after rescoring
    num_beams=1,
    temperature=1.0,
    top_p=1.0,
    display=1000,                          # NOTE(review): not read by any code in this notebook
    do_sample=True,
    device=device,
)

In [10]:
# Build the captioning network (with its caption cache attached) and the CLIP image transform.
net, preprocess = loader(args=args)

In [None]:
# Mount Google Drive so the photo folders and the caption output directory below are reachable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
import os
from os import listdir

# Directory the per-image caption files are written to (created if missing).
WRITE_PATH = '/content/drive/My Drive/sobem/Captions_Test/'
# Root folder holding one sub-folder per subject, named 1..NUM_SUBJECTS.
PHOTOS_ROOT = '/content/drive/My Drive/sobem/Photos/'
NUM_SUBJECTS = 10

os.makedirs(WRITE_PATH, exist_ok=True)

# Iterate over the image subjects.
for i in range(1, NUM_SUBJECTS + 1):
    target_path = os.path.join(PHOTOS_ROOT, str(i))
    # Sorted for a deterministic order; skip sub-directories and hidden files
    # (e.g. .ipynb_checkpoints, .DS_Store) that PIL cannot open.
    imgs = sorted(
        name for name in listdir(target_path)
        if not name.startswith('.') and os.path.isfile(os.path.join(target_path, name))
    )

    # Caption each image of the subject.
    for image in imgs:
        captions = caption_image(os.path.join(target_path, image), args, net, preprocess)

        # Write one caption per line. NOTE(review): the output file is keyed only by
        # the image's base name, so identically named images in different subject
        # folders overwrite each other.
        write_string = '\n'.join(captions)
        # splitext handles any extension length (.jpeg, .webp), unlike image[:-4].
        stem = os.path.splitext(image)[0]
        with open(os.path.join(WRITE_PATH, f'{stem}.txt'), 'w', encoding='utf-8') as writer:
            writer.write(write_string)

        print(write_string)