In [1]:
# Turn off Jedi in the IPython completer.
%config Completer.use_jedi = False
# Load the autoreload extension; mode 2 reloads imported modules before each
# cell executes, so edits to local .py files take effect without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
import time
import argparse
import numpy as np
import torch as t
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as tr
from torch.utils.data import DataLoader, Dataset
import numpy as np

import torch
from torch.utils.data import TensorDataset

from utils import get_train_test

# Pin CUDA device visibility for this kernel to index '2'.
# NOTE(review): this overrides any value already set in the environment.
os.environ.update({'CUDA_VISIBLE_DEVICES': '2'})

In [3]:
def norm_ip(img, min, max):
    """Clamp ``img`` to ``[min, max]`` and rescale it towards ``[0, 1]``.

    Args:
        img: tensor to normalise (passed to ``t.clamp``).
        min: lower clamp bound; values below it are raised to it.
        max: upper clamp bound; values above it are lowered to it.

    Returns:
        A new tensor equal to ``(clamp(img) - min) / (max - min + 1e-5)``.
        Because of the ``1e-5`` term in the denominator, an input equal to
        ``max`` maps to a value marginally below 1.

    Note:
        The parameter names shadow the ``min``/``max`` builtins inside this
        function; they are kept because callers may pass them by keyword.
    """
    clamped = t.clamp(img, min=min, max=max)
    span = max - min + 1e-5  # epsilon keeps the division defined when max == min
    return (clamped - min) / span

In [4]:
# Use the GPU when torch reports CUDA as available, otherwise fall back to CPU.
device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')


In [5]:
# Settings object handed to get_train_test here and to Cifar10 further down.
args = argparse.Namespace(
    dataset='cifar10',
    data_root='../../data',
    seed=1,
)
data_loader, test_loader = get_train_test(args)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [6]:
from data import Cifar10

# Reference dataset: Cifar10 built with full=True and noise=False, served in
# shuffled batches of 100 without dropping the final partial batch.
test_dataset = Cifar10(args, full=True, noise=False)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=100,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)

# Filled by the loop below with one uint8 image array per sample.
test_ims = list()

def rescale_im(im):
    """Convert a float image array to ``uint8`` pixel values.

    The input is multiplied by 256, limited to ``[0, 255]`` and then cast to
    ``np.uint8`` (the cast drops the fractional part).

    Args:
        im: array-like of floats accepted by ``np.clip``.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as ``im``.
    """
    scaled = im * 256
    bounded = np.clip(scaled, 0, 255)
    return bounded.astype(np.uint8)

# Truncation at 60000 images applies only when the dataset is "imagenet" or
# its name contains 'img'; the check does not change during the loop.
cap_at_60k = args.dataset == "imagenet" or 'img' in args.dataset

for data_corrupt, data, label_gt in test_dataloader:
    # celeba128 / img32 take the first batch element with axis 1 moved to the
    # end; every other dataset takes the second element as-is.
    if args.dataset not in ['celeba128', 'img32']:
        data = data.numpy()
    else:
        data = data_corrupt.numpy().transpose(0, 2, 3, 1)
    test_ims += list(rescale_im(data))
    if cap_at_60k and len(test_ims) > 60000:
        del test_ims[60000:]
        break

print(np.array(test_ims).shape)

Files already downloaded and verified
Files already downloaded and verified
(60000, 32, 32, 3)


In [7]:
# Stack the collected images into one array and move the last axis to
# position 1, e.g. (N, 32, 32, 3) -> (N, 3, 32, 32).
test_imgs = np.transpose(np.array(test_ims), (0, 3, 1, 2))

In [8]:
# Same truncation rule as for the reference set: only "imagenet" or datasets
# whose name contains 'img' are cut at 60000 images.
truncate_train = args.dataset == "imagenet" or 'img' in args.dataset

train_images = list()
for data, y, _ in data_loader:
    # Clamp the batch to [-1, 1], rescale it via norm_ip, then quantise to uint8.
    data = norm_ip(data, -1, 1).numpy()
    train_images += list(rescale_im(data))
    if truncate_train and len(train_images) > 60000:
        del train_images[60000:]
        break

train_images = np.array(train_images)

In [9]:
# Only "imagenet" or datasets whose name contains 'img' are cut at 60000 images.
truncate_test = args.dataset == "imagenet" or 'img' in args.dataset

test_images = list()
for data, y, _ in test_loader:
    # Clamp the batch to [-1, 1], rescale it via norm_ip, then quantise to uint8.
    data = norm_ip(data, -1, 1).numpy()
    test_images += list(rescale_im(data))
    if truncate_test and len(test_images) > 60000:
        del test_images[60000:]
        break
test_images = np.array(test_images)

In [10]:
# Sanity check: shape of the reference array built from Cifar10(full=True).
test_imgs.shape

(60000, 3, 32, 32)

In [11]:
# Sanity check: shapes of the arrays collected from data_loader and test_loader.
print(train_images.shape, test_images.shape)

(50000, 3, 32, 32) (10000, 3, 32, 32)


# call the function of TF

In [12]:
# Import only the function this notebook calls. A wildcard import could
# silently rebind names already in use here (e.g. `t`, `np`, `time`) and hides
# where `calculate_fid_score` comes from.
from fid_score import calculate_fid_score

# benchmarks

CIFAR-10 train set, test set, and train_set[:10000], each scored against the 60000-image `test_imgs` array.

In [13]:
# FID of the test-loader images against the full-dataset reference array.
start = time.time()
fid = calculate_fid_score(test_images, test_imgs)
elapsed = time.time() - start
print("FID of score {} takes {}s".format(fid, elapsed))

# Numbers recorded from the TensorFlow version for the same comparison:
#   Obtained fid value of 2.203647314570105
#   FID of score 2.203647314570105 takes 166.0104057788849s

FID of score 5.124681832706301 takes 118.20916604995728s


In [18]:
# FID of the train-loader images against the full-dataset reference array.
start = time.time()
fid = calculate_fid_score(train_images, test_imgs)
elapsed = time.time() - start
print("FID of score {} takes {}s".format(fid, elapsed))

# Numbers recorded from the TensorFlow version for the same comparison:
#   Obtained fid value of 0.08768652773858321
#   FID of score 0.08768652773858321 takes 254.37319445610046s

FID of score 0.20264188262535754 takes 168.97182512283325s


In [14]:
# FID of the first 10000 train-loader images against the reference array.
start = time.time()
fid = calculate_fid_score(train_images[:10000], test_imgs)
elapsed = time.time() - start
print("FID of score {} takes {}s".format(fid, elapsed))

# Numbers recorded from the TensorFlow version for the same comparison:
#   Obtained fid value of 2.192182121402425
#   FID of score 2.192182121402425 takes 167.54650378227234s

FID of score 4.858539173947975 takes 113.88631391525269s


## A generated dataset

In [15]:
# Load the saved generated images and move the last axis to position 1 so the
# axis order matches test_imgs.
# NOTE(review): assumes feed_imgs.npy is stored as (N, H, W, C) — confirm.
feed_imgs = np.load('feed_imgs.npy')
feed_imgs2 = np.transpose(feed_imgs, (0, 3, 1, 2))

In [17]:
# FID of the generated images against the full-dataset reference array;
# the value is shown as the cell output.
calculate_fid_score(feed_imgs2, test_imgs)

# Numbers recorded from the TensorFlow version for the same comparison:
# Obtained fid value of 16.04265464806963
# FID of score 16.04265464806963 takes 168.163743019104s

9.871485188506995