In [6]:
import numpy as np
import torch
from models import *
from fid import get_fid

from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Resize, Normalize

In [2]:
G_models = ['mnist_dcgan/G_10.pth', 'mnist_bigan/G_10.pth', 'mnist_logan_b/G_10.pth']
E_models = ['mnist_dcgan/Epost_1.pth', 'mnist_bigan/Epost_1.pth', 'mnist_logan_b/E_10.pth']

In [3]:
num_images = 200

In [4]:
dataset = MNIST(root='.', download=True, transform=Compose([Resize(32), ToTensor(), Normalize((0.5,),(0.5,))]))
dataloader = DataLoader(dataset, batch_size=num_images, shuffle=True)

In [5]:
for (G_file, E_file) in zip(G_models, E_models):
    G = torch.load(G_file).to('cuda:1').eval()
    E = torch.load(E_file).to('cuda:1').eval()
    
    images1 = np.zeros((num_images, 3, 32, 32))
    images2 = np.zeros((num_images, 3, 32, 32))
    
    real = ((next(iter(dataloader))[0]+1)*127).squeeze().numpy()
    images1[:,0,:,:] = real.copy()
    images1[:,1,:,:] = real.copy()
    images1[:,2,:,:] = real.copy()
    images1 = images1.astype(np.uint8)
    
    real = next(iter(dataloader))[0]
    latents = E(real.to('cuda:1')).detach()
    recon = ((G(latents).cpu().detach()+1)*127).squeeze().numpy()
    images2[:,0,:,:] = recon.copy()
    images2[:,1,:,:] = recon.copy()
    images2[:,2,:,:] = recon.copy()
    images2 = images2.astype(np.uint8)
    
    print(G_file, get_fid(images1, images2))

Calculating FID with 200 images from each distribution
FID calculation time: 31.013670 s
mnist_dcgan/G_10.pth 71.681656
Calculating FID with 200 images from each distribution
FID calculation time: 30.933452 s
mnist_bigan/G_10.pth 431.01556
Calculating FID with 200 images from each distribution
FID calculation time: 30.719359 s
mnist_logan_b/G_10.pth 40.24572
