In [1]:
import numpy as np
import path, os
import glob
import torch
from torchmetrics import IS, KID
from torchvision import transforms as T
from pytorch_pretrained_gans import make_gan
import ast
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
label_file = open("imagenet1000_clsidx_to_labels.txt", "r")
label_dict = ast.literal_eval(label_file.read())

In [3]:
# load pretrain GAN
G = make_gan(gan_type='biggan')  # -> nn.Module

Loading BigGAN model biggan-deep-256 from cache at /Users/pinkaew/.pytorch_pretrained_biggan/e7d036ee69a83e83503c46902552573b27d856eaf6b8b84252a63a715dce7501.aec5caf3e5c5252b8857d6bb8adefa8d1d6092a8ba6c9aaed1e6678f8a46be43


In [4]:
# generate images from BigGAN in class "goldfish"
torch.manual_seed(69)
np.random.seed(69)
batch_size = 50
y = torch.zeros(1000)
y[1] = 1
y = y.repeat(batch_size, 1)
z = G.sample_latent(batch_size=batch_size)  # -> torch.Size([1, 128])
x = G(z=z, y=y)  # -> torch.Size([1, 3, 256, 256])

In [5]:
def show_images(img_tensor, labels):
    for i in range(labels.shape[0]):
        img = img_tensor[i].squeeze().permute(1, 2, 0)
        plt.imshow(img.detach().numpy())
        plt.title("class: {}".format(label_dict[np.argmax(labels[i].detach().numpy())]))
        plt.show()

In [6]:
#show_images(x[:5], y[:5])

In [8]:
# rescale generated images
fake_rgb = 256 * (x + 1)/2
fake_rgb = fake_rgb.type(torch.ByteTensor)

In [9]:
# load real images
base_dir = os.getcwd()
img_dir = base_dir + "/ImageNet/goldfish"

trans = T.Compose([T.Resize((256,256)), T.ToTensor()])
reals = []
for filename in glob.glob(img_dir + "/*.jpg"):
    im = Image.open(filename)
    reals.append(trans(im))
reals = torch.stack(reals)

In [10]:
#show_images(reals[:5], y[:5])

In [11]:
# rescale real images
real_rgb = reals * 256
real_rgb = reals.type(torch.ByteTensor)

In [12]:
# compute IS score
is_ = IS()
is_.reset()
is_.update(fake_rgb)
fake_is_mean, fake_is_std = is_.compute()
is_.reset()
is_.update(real_rgb[:batch_size])
real_is_mean, real_is_std = is_.compute()

print("Fake IS score: {:.5} +- {:.5}".format(fake_is_mean, fake_is_std))
print("Real IS score: {:.5} +- {:.5}".format(real_is_mean, real_is_std))



Fake IS score: 1.0078 +- 0.0042331
Real IS score: 1.0247 +- 0.017933


In [13]:
# compute KID score
kid = KID(subset_size=10)
kid.reset()
kid.update(real_rgb, real=True)
kid.update(fake_rgb, real=False)
kid_mean, kid_std = kid.compute()
print("KID score: {:.5} +- {:.5}".format(kid_mean, kid_std))



KID score: 0.75889 +- 0.027454
