# AttnGAN evaluation notebook

- Quantitative analysis
    - Inception score
    - Fréchet inception distance (FID)
    - R-precision
- Interactive image generation from a caption
    - from a caption in the dataset
    - from a caption of your own

## Prerequisites

- Models
    - Generator
    - Text Encoder
    - (Image Encoder)
- Dataset
    - COCO
- Compute resources
    - 1 GPU

# Common Procedures

In [None]:
# imports
import torch
import torchvision.transforms as transforms
import numpy as np
import os
import sys
import random
import warnings
warnings.filterwarnings('ignore')

%matplotlib inline
from PIL import Image
import matplotlib.pyplot as plt

sys.path.append(os.pardir)

In [None]:
# device setup
# Expose only physical GPU 3 to this process.
# NOTE(review): this presumably has to run before the first CUDA call in the
# kernel for it to take effect — confirm when re-running cells out of order.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The notebook requires exactly one visible CUDA device, so on a machine with
# no usable GPU this assertion fails even though `device` falls back to "cpu".
assert torch.cuda.device_count() == 1

In [None]:
# config
data_dir = '../../data/COCO'
exp_dir = 'results/AttnGAN/COCO/2019_05_14_17_08'
G_epoch = 50
batch_size = 100
test_num = batch_size * 30

# load config from txt
# Only the keys listed below are read from <exp_dir>/config.txt; every other
# line is echoed and ignored. Each line is split at its first space and the
# last character of the first token is dropped (presumably a ':' separator,
# i.e. "key: value" — verify against the script that writes config.txt).
cfg_list_str = ['image_encoder_path', 'text_encoder_path']
cfg_list_int = ['words_num', 't_dim', 'z_dim', 'c_dim', 'ngf', 'branch_num', 'base_size']
cfg_list_float = []
cfg = {}

print('experimental setting')
with open(os.path.join(exp_dir, 'config.txt'), 'r') as f:
    for line in f:
        # Strip only the line terminator, so a final line without a trailing
        # newline does not lose its last character.
        line = line.rstrip('\n')
        print(line)
        key_token, _, raw_value = line.partition(' ')
        raw_value = raw_value.strip()
        if not raw_value:
            # Blank line or a key without a value: nothing to parse.
            continue
        key = key_token[:-1]
        if key in cfg_list_str:
            # Keep the whole remainder so paths containing spaces survive.
            cfg[key] = raw_value
        elif key in cfg_list_int:
            cfg[key] = int(raw_value.split()[0])
        elif key in cfg_list_float:
            cfg[key] = float(raw_value.split()[0])

# Fail early, naming the settings that config.txt did not provide.
missing_keys = [k for k in cfg_list_str + cfg_list_int + cfg_list_float if k not in cfg]
if missing_keys:
    raise KeyError('missing settings in config.txt: %s' % ', '.join(missing_keys))

In [None]:
# data preparation
from datasets import TextDataset, prepare_data

# Images are resized to imsize * 76 / 64 and then randomly cropped back to
# imsize with a random horizontal flip.
imsize = 299
image_transform = transforms.Compose(
    [
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
        transforms.RandomHorizontalFlip(),
    ]
)

# Evaluation uses the 'val2014' split; pass 'train2014' instead to run the
# same evaluation on the training split.
dataset = TextDataset(
    data_dir,
    'val2014',
    base_size=imsize,
    branch_num=1,
    words_num=cfg['words_num'],
    transform=image_transform,
)

# Fixed order, and incomplete final batches are dropped, so every batch holds
# exactly batch_size samples.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=4,
)

In [None]:
# model preparation
from models import CNNEncoder
from models import RNNEncoder
from models import AttentionalGNet

# All checkpoints are loaded with map_location=device: a state dict saved on a
# different GPU index (or on a GPU while this session runs on CPU) would
# otherwise fail to deserialize, because only one device is visible here.

# Image encoder; weights come from the path recorded in the experiment config.
image_encoder = CNNEncoder(cfg['t_dim'], download=False).to(device)
image_encoder.load_state_dict(torch.load(cfg['image_encoder_path'], map_location=device))
image_encoder.eval()

# Text encoder; its vocabulary size is taken from the dataset.
text_encoder = RNNEncoder(dataset.n_words, cfg['words_num'], nhidden=cfg['t_dim']).to(device)
text_encoder.load_state_dict(torch.load(cfg['text_encoder_path'], map_location=device))
text_encoder.eval()

# Generator; restored from the checkpoint of epoch G_epoch inside exp_dir.
G = AttentionalGNet(cfg['z_dim'], cfg['t_dim'], cfg['c_dim'], cfg['ngf'], device, cfg['branch_num']).to(device)
G.load_state_dict(torch.load(os.path.join(exp_dir, 'model', 'G_epoch%d.pth' % (G_epoch)), map_location=device))
G.eval()
print('model load complete')

In [None]:
# create images for scores
# Generates up to test_num fake images (one batch of captions at a time) and
# stores them, the matching real images and the captions as NumPy arrays.
gen_imsize = cfg['base_size'] * (2 ** (cfg['branch_num'] - 1))
num_batches = test_num // batch_size

noise = torch.FloatTensor(batch_size, cfg['z_dim']).to(device)
# np.empty leaves the buffers uninitialised; rows are only valid once the loop
# below has written them (see the truncation after the loop).
caps = np.empty((test_num, cfg['words_num']))
real_imgs = np.empty((test_num, 3, imsize, imsize))
imgs = np.empty((test_num, 3, gen_imsize, gen_imsize))
start, end = 0, 0
with torch.no_grad():
    for i, data in enumerate(dataloader):
        if i >= num_batches:
            break
        start = end
        end = start + batch_size

        # Resample the latent noise in place for every batch.
        noise.normal_()
        real_images, captions, cap_lens, _, _ = prepare_data(data, device)
        real_imgs[start:end] = real_images[-1].detach().cpu().numpy()
        caps[start:end] = captions.detach().cpu().numpy()
        words_embs, sent_emb = text_encoder(captions, cap_lens)
        # True where the caption holds index 0 (presumably padding — confirm
        # against the dataset vocabulary); cropped to the number of word
        # embeddings the encoder returned.
        mask = (captions == 0)
        num_words = words_embs.size(2)
        if mask.size(1) > num_words:
            mask = mask[:, :num_words]
        fake_imgs, _, _, _ = G(noise, sent_emb.detach(), words_embs.detach(), mask)

        # Keep only the last image the generator returns for each caption.
        imgs[start:end] = fake_imgs[-1].detach().cpu().numpy()

# If the loader ran out early (or test_num is not a multiple of batch_size)
# the tail of the buffers was never written; drop it so the metrics are not
# computed on uninitialised memory. print is used because warnings are
# silenced at the top of the notebook.
if end < test_num:
    print('warning: only %d of the %d requested samples were generated' % (end, test_num))
    caps, real_imgs, imgs = caps[:end], real_imgs[:end], imgs[:end]

# Inception score

In [None]:
from metrics.inception_score.inception_score import inception_score

# Score the generated images; inception_score returns a pair and only its
# first element is reported.
inception_value, _ = inception_score(imgs, device)
inception_message = 'inception score: %.3f' % inception_value
print(inception_message)

# To score the real images for reference, call inception_score(real_imgs, device)
# the same way and print it as 'inception score of the dataset: %.3f'.

# Fréchet inception distance

In [None]:
from metrics.fid.fid_score import fid

# Build the FID scorer on the evaluation device, then compare the real images
# with the generated ones.
fid_scorer = fid(device)
fid_value = fid_scorer.calculate_score(real_imgs, imgs)
print('fid score: %.3f' % fid_value)

# R-precision

In [None]:
from metrics.R_precision.r_precision import r_precision

# R-precision is computed by the project helper from the dataset, the
# generator and both encoders.
# NOTE(review): the meaning of the trailing 100 is defined by r_precision — confirm there.
r_precision_value = r_precision(
    dataset,
    G,
    cfg['z_dim'],
    image_encoder,
    text_encoder,
    device,
    100,
)
print('R-precision: %.3f' % r_precision_value)

# Save evaluation results to file

In [None]:
# Format every metric before the file is opened: opening with 'w' truncates
# the file, so a metric that has not been computed yet must fail here rather
# than leave a previous result file emptied or half written.
result_lines = [
    'inception_score: %.3f' % (inception_value),
    'fid: %.3f' % (fid_value),
    'R-precision: %.3f' % (r_precision_value),
]

# Echo each line and write the same text to <exp_dir>/eval_result<G_epoch>.txt,
# so the printed and saved results cannot drift apart.
with open(os.path.join(exp_dir, 'eval_result%d.txt' % (G_epoch)), 'w') as f:
    for result_line in result_lines:
        print(result_line)
        f.write(result_line + '\n')

# Image Generation

In [None]:
import utils
from nltk.tokenize import RegexpTokenizer
# Shared transform that turns an image tensor back into a PIL image; used by show_tensor.
unloader = transforms.ToPILImage()  # reconvert into PIL image

def show_tensor(tensor):
    """Display an image tensor with matplotlib without modifying it.

    The values are shifted by +1 and halved before display (presumably mapping
    a [-1, 1] generator output to [0, 1] — confirm against the generator).

    Args:
        tensor: image tensor; a leading dimension of size 1 is squeezed away
            before the conversion to a PIL image.
    """
    # Copy first, then rescale the copy. Rescaling in place before cloning
    # would change the caller's tensor, so showing it twice would rescale twice.
    image = tensor.detach().cpu().clone()
    image.add_(1).div_(2)
    image = image.squeeze(0)      # remove the fake batch dimension
    image = unloader(image)
    plt.imshow(image)
    
def show_np_arr(arr):
    """Display `arr` with matplotlib by passing it unchanged to plt.imshow."""
    plt.imshow(arr)
    
def caption_convert(caption):
    """Convert a free-form caption into vocabulary indices.

    The caption is lower-cased, split into runs of word characters, stripped
    of non-ASCII characters, and each remaining word is looked up in
    dataset.wordtoix.

    Args:
        caption: caption text.

    Returns:
        A pair (indices, length): an int64 NumPy array with one vocabulary
        index per word, and the number of words.

    Raises:
        KeyError: if any word is not in the dataset vocabulary; the message
            lists all offending words.
        ValueError: if the caption contains no usable word, or has more than
            cfg['words_num'] words.
    """
    cap = caption.replace("\ufffd\ufffd", " ")
    tokenizer = RegexpTokenizer(r'\w+')
    tokens = tokenizer.tokenize(cap.lower())
    caption_new = []
    unknown_words = []
    for t in tokens:
        t = t.encode('ascii', 'ignore').decode('ascii')
        if not t:
            continue
        if t in dataset.wordtoix:
            caption_new.append(dataset.wordtoix[t])
        else:
            unknown_words.append(t)
    # Report every unknown word at once instead of failing on the first one.
    if unknown_words:
        raise KeyError('words not in the dataset vocabulary: %s' % ', '.join(unknown_words))
    cap_len = len(caption_new)
    # An empty caption would otherwise produce a zero-length index array.
    if cap_len == 0:
        raise ValueError('caption contains no usable words')
    # Explicit raise rather than assert, so the check also runs under python -O.
    if cap_len > cfg['words_num']:
        raise ValueError('caption has %d words but at most %d are supported' % (cap_len, cfg['words_num']))
    return np.asarray(caption_new).astype('int64'), cap_len

# Image Generation with a Caption from the Dataset

In [None]:
# Take the first batch of the dataloader and generate one image per caption.
data_iter = iter(dataloader)
data = next(data_iter)
_, captions, cap_lens, _, _ = prepare_data(data, device)
noise = torch.FloatTensor(batch_size, cfg['z_dim']).normal_().to(device)
with torch.no_grad():
    words_embs, sent_emb = text_encoder(captions, cap_lens)
    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
    # True where the caption holds index 0 (presumably padding — confirm),
    # cropped to the number of word embeddings the encoder returned.
    mask = (captions == 0)
    num_words = words_embs.size(2)
    if mask.size(1) > num_words:
        mask = mask[:, :num_words]

    fake_imgs, attention_maps, _, _ = G(noise, sent_emb, words_embs, mask)
    # Same guard as the custom-caption cell: without at least one attention
    # map the indexing below (attention_maps[-1], len(...) - 1) is not meaningful.
    assert len(attention_maps) >= 1
    # Image at index len(attention_maps) and the one before it are handed to
    # build_super_images together with the last attention map.
    img = fake_imgs[len(attention_maps)].detach().cpu()
    lr_img = fake_imgs[len(attention_maps) - 1].detach().cpu()
    attn_maps = attention_maps[-1]
    att_sze = attn_maps.size(2)

img_set, _ = utils.build_super_images(img, captions, dataset.ixtoword, attn_maps, att_sze, batch_size, cfg['words_num'], lr_imgs=lr_img)

plt.figure(figsize=(10,10), dpi=500)
show_np_arr(img_set)

# Interactive Image Generation with Your Own Caption

In [None]:
# Caption to render in the next cell. caption_convert looks every word up in
# dataset.wordtoix and requires at most cfg['words_num'] words.
caption = 'a cat is lying on the desk'

In [None]:
# Encode the user caption as a batch of one: indices, length and latent noise.
cap, cap_len = caption_convert(caption)
cap = torch.from_numpy(cap).unsqueeze(0).to(device)
cap_len = torch.tensor(cap_len).unsqueeze(0).to(device)
noise = torch.FloatTensor(1, cfg['z_dim']).normal_().to(device)

with torch.no_grad():
    words_embs, sent_emb = text_encoder(cap, cap_len)
    words_embs = words_embs.detach()
    sent_emb = sent_emb.detach()

    # Mask marks positions holding index 0, cropped to the encoder's word count.
    mask = cap.eq(0)
    num_words = words_embs.size(2)
    if mask.size(1) > num_words:
        mask = mask[:, :num_words]

    fake_imgs, attention_maps, _, _ = G(noise, sent_emb, words_embs, mask)
    assert len(attention_maps) >= 1

    # Images at index len(attention_maps) and the one before it, plus the last
    # attention map, feed the attention visualisation below.
    n_attn = len(attention_maps)
    img = fake_imgs[n_attn].detach().cpu()
    lr_img = fake_imgs[n_attn - 1].detach().cpu()
    attn_maps = attention_maps[-1]
    att_sze = attn_maps.size(2)

img_set, _ = utils.build_super_images(
    img, cap, dataset.ixtoword, attn_maps, att_sze,
    batch_size, cfg['words_num'], lr_imgs=lr_img, nvis=1,
)

print('given caption:', caption)

# First figure: the last generated image on its own.
plt.figure()
show_tensor(fake_imgs[-1])

# Second figure: the image set returned by build_super_images.
plt.figure(figsize=(10,10),dpi=500)
show_np_arr(img_set)