In [1]:
# Set the Completer.use_jedi option to False for this kernel.
%config Completer.use_jedi = False
# Load the autoreload extension in mode 2 so edits to imported project modules
# are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
import time
import argparse
import numpy as np
import torch as t
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as tr
from torch.utils.data import DataLoader, Dataset
import numpy as np

from utils import get_train_test

# Expose only the device with index 2 to CUDA for code run from this kernel.
# NOTE(review): the index is machine-specific — adjust it for the host in use.
os.environ.update(CUDA_VISIBLE_DEVICES='2')

In [3]:
def norm_ip(img, min, max):
    """Clamp ``img`` to ``[min, max]`` and rescale it linearly towards ``[0, 1]``.

    Args:
        img: tensor accepted by ``torch.clamp``.
        min: lower clamp bound; values at or below it map to 0.
        max: upper clamp bound; values at or above it map to just under 1,
            because 1e-5 is added to the range before dividing.

    Returns:
        A new tensor equal to ``(clamp(img, min, max) - min) / (max - min + 1e-5)``.
    """
    # The epsilon keeps the division defined when max == min.
    span = max - min + 1e-5
    clipped = t.clamp(img, min=min, max=max)
    return (clipped - min) / span

In [4]:
# Use the GPU when PyTorch reports one as available, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = t.device('cuda')
else:
    device = t.device('cpu')


In [5]:
# Stand-in for command-line arguments: this namespace is what the project
# helpers below (get_train_test, Cifar10) receive as their configuration.
args = argparse.Namespace(
    dataset='cifar10',
    data_root='../../data',
    seed=1,
)
# Build the train and test loaders from the settings in `args`.
# Later cells unpack every batch from these loaders as a 3-tuple, clamp the first
# element to [-1, 1] and move its axis 1 to the end.
# NOTE(review): get_train_test lives in utils and is not shown here — confirm the
# batch layout and value range against that module.
data_loader, test_loader = get_train_test(args)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [9]:
from data import Cifar10
# Dataset that supplies the reference images used by the FID cells further down.
# NOTE(review): Cifar10 is project code not shown here; `full=True` presumably
# selects train and test together (the recorded run yields 60000 images) — confirm.
test_dataset = Cifar10(args, full=True, noise=False)

# Single-process loader that keeps the final partial batch.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=100,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)

# Filled batch by batch with uint8 images by the loop below.
test_ims = []

def rescale_im(im):
    """Quantise image values to 8-bit pixels.

    Multiplies ``im`` by 256, clips the result to ``[0, 255]`` and casts it to
    ``np.uint8``. Inputs at or below 0 become 0 and inputs at or above 1 become 255.

    Args:
        im: array of image values, intended to lie in ``[0, 1]``.

    Returns:
        ``np.uint8`` array with the same shape as ``im``.
    """
    scaled = im * 256
    clipped = np.clip(scaled, 0, 255)
    return clipped.astype(np.uint8)

# Decide once which batch element supplies the images: for these two datasets
# the first element is used and its axis 1 is moved to the end; otherwise the
# second element is used as delivered.
use_corrupt_batch = args.dataset in ['celeba128', 'img32']
# Only "imagenet" and datasets whose name contains 'img' are capped at 60000 images.
cap_reference_set = args.dataset == "imagenet" or 'img' in args.dataset

for data_corrupt, data, label_gt in test_dataloader:
    data = data_corrupt.numpy().transpose(0, 2, 3, 1) if use_corrupt_batch else data.numpy()
    # Extending with the batch array appends one image per entry along axis 0.
    test_ims.extend(rescale_im(data))
    if cap_reference_set and len(test_ims) > 60000:
        test_ims = test_ims[:60000]
        break

print(np.array(test_ims).shape)

Files already downloaded and verified
Files already downloaded and verified
(60000, 32, 32, 3)


In [10]:
# Collect the training images as uint8 arrays, one list entry per image.
# The 60000-image cap only applies to "imagenet" and names containing 'img'.
limit_images = args.dataset == "imagenet" or 'img' in args.dataset
train_images = []
for data, y, _ in data_loader:
    data = norm_ip(data, -1, 1)                 # clamp to [-1, 1], rescale towards [0, 1]
    data = data.numpy().transpose(0, 2, 3, 1)   # move axis 1 to the end
    train_images.extend(rescale_im(data))
    if limit_images and len(train_images) > 60000:
        train_images = train_images[:60000]
        break

In [11]:
# Collect the held-out images as uint8 arrays, one list entry per image.
# The 60000-image cap only applies to "imagenet" and names containing 'img'.
limit_images = args.dataset == "imagenet" or 'img' in args.dataset
test_images = []
for data, y, _ in test_loader:
    data = norm_ip(data, -1, 1)                 # clamp to [-1, 1], rescale towards [0, 1]
    data = data.numpy().transpose(0, 2, 3, 1)   # move axis 1 to the end
    test_images.extend(rescale_im(data))
    if limit_images and len(test_images) > 60000:
        test_images = test_images[:60000]
        break

In [12]:
# Show the stacked shape of each image list (train first, then test).
tuple(np.array(images).shape for images in (train_images, test_images))

((50000, 32, 32, 3), (10000, 32, 32, 3))

In [13]:
# Sanity-check the pixel range of the training images.
# The list is stacked into one array once and reused for both reductions,
# instead of converting all the images twice.
train_array = np.array(train_images)
train_array.max(), train_array.min()

(255, 0)

# Load the TensorFlow metric functions (Inception score and FID)

In [14]:
from inception import get_inception_score
from fid import get_fid_score



Instructions for updating:
Use tf.gfile.GFile.



# benchmarks

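Reference scores on real CIFAR-10 images: FID of the test set, the train set, and the first 10,000 train images (`train_images[:10000]`), each measured against `test_ims`, followed by the Inception score of the test set.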

In [79]:
# FID between the held-out images and the reference set, with wall-clock timing.
start = time.time()
fid = get_fid_score(test_images, test_ims)
elapsed = time.time() - start
print("FID of score {} takes {}s".format(fid, elapsed))

Obtained fid value of 2.203647314570105
FID of score 2.203647314570105 takes 166.0104057788849s


In [82]:
# FID between the training images and the reference set, with wall-clock timing.
start = time.time()
fid = get_fid_score(train_images, test_ims)
elapsed = time.time() - start
message = "FID of score {} takes {}s".format(fid, elapsed)
print(message)

Obtained fid value of 0.08768652773858321
FID of score 0.08768652773858321 takes 254.37319445610046s


In [81]:
# FID between the first 10000 training images and the reference set.
# The list is sliced before it is stacked, so only the 10000 images that are
# scored get copied into an array rather than the whole training list.
start = time.time()
fid = get_fid_score(np.array(train_images[:10000]), test_ims)
print("FID of score {} takes {}s".format(fid, time.time() - start))

Obtained fid value of 2.192182121402425
FID of score 2.192182121402425 takes 167.54650378227234s


In [80]:
# Inception score of the held-out images: one split per 5000 images, never fewer than one.
splits = len(test_images) // 5000 or 1
start = time.time()
score, std = get_inception_score(test_images, splits=splits)
elapsed = time.time() - start
print("Inception score of {} with std of {} takes {}s".format(score, std, elapsed))

iteration 0, total 10000
iteration 3333, total 10000
iteration 6666, total 10000
iteration 9999, total 10000
Inception score of 11.23015022277832 with std of 0.1562943458557129 takes 148.98509669303894s


## A generated dataset

In [15]:
# Load the generated samples to be scored.
# NOTE(review): the path is relative to the notebook's working directory, and
# nothing in this notebook shows how the file was produced.
feed_imgs = np.load('feed_imgs.npy')

In [16]:
# Dimensions of the loaded array; the recorded run shows 10000 images of 32x32x3.
feed_imgs.shape

(10000, 32, 32, 3)

In [17]:
# Value range of the generated samples.
# NOTE(review): the recorded output (254.99872, 0.0) shows float pixel values,
# whereas the real-image lists above are cast to uint8 by rescale_im — confirm that
# get_fid_score / get_inception_score treat both dtypes the same way.
feed_imgs.max(), feed_imgs.min()

(254.99872, 0.0)

In [18]:
# FID between the generated samples and the reference set, with wall-clock timing.
start = time.time()
fid = get_fid_score(feed_imgs, test_ims)
elapsed = time.time() - start
print("FID of score {} takes {}s".format(fid, elapsed))

Obtained fid value of 16.04265464806963
FID of score 16.04265464806963 takes 168.163743019104s


In [19]:
# Inception score of the generated samples: one split per 5000 images, never fewer than one.
splits = len(feed_imgs) // 5000 or 1
start = time.time()
score, std = get_inception_score(feed_imgs, splits=splits)
elapsed = time.time() - start
print("Inception score of {} with std of {} takes {}s".format(score, std, elapsed))

iteration 0, total 10000
iteration 3333, total 10000
iteration 6666, total 10000
iteration 9999, total 10000
Inception score of 7.869165420532227 with std of 0.04828786849975586 takes 153.34843182563782s
