<a href="https://colab.research.google.com/github/singwang-cn/Neural-Network/blob/master/Simple_Implementation_of_FID.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [84]:
import numpy as np
from scipy.linalg import sqrtm
from keras.applications import inception_v3
import cv2

In [85]:
from keras.datasets import mnist
from keras.datasets import cifar10
from keras.datasets import cifar100

In [87]:
# scale an array of images to a new size
def scale_images(images, new_shape=(299, 299)):
  """Resize every image in a batch, expanding single-channel input to 3 channels.

  Args:
    images: ndarray of shape (N, H, W) or (N, H, W, 1) for single-channel
      images (e.g. MNIST), or (N, H, W, C) otherwise.
    new_shape: target size as (width, height) -- the order cv2.resize expects
      for ``dsize``. Defaults to (299, 299).

  Returns:
    ndarray of shape (N, height, width, C), with C == 3 for single-channel
    input. Each image is resized with bicubic interpolation.
  """
  images = np.asarray(images)
  # cv2.resize drops a trailing singleton channel axis, which would yield 2-D
  # images; squeeze it here so the branch below rebuilds 3 identical channels.
  if images.ndim == 4 and images.shape[-1] == 1:
    images = images[..., 0]
  # extend images with 1 channel (e.g. mnist) to 3 channels
  if images.ndim < 4:
    images = np.stack((images,) * 3, axis=-1)
  width, height = new_shape
  if len(images) == 0:
    # np.asarray([]) would collapse to shape (0,); keep the 4-D batch layout
    return np.empty((0, height, width, images.shape[-1]), dtype=images.dtype)
  resized = [cv2.resize(image, dsize=(width, height), interpolation=cv2.INTER_CUBIC)
             for image in images]
  return np.asarray(resized)

In [93]:
# calculate frechet inception distance
# input images must be in ndarray type
def calculate_fid(model, images1, images2, eps=1e-6):
  """Compute the Frechet Inception Distance (FID) between two image sets.

  FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2)),
  where (mu, sigma) are the mean and covariance of the model's activations.

  Args:
    model: feature extractor whose ``predict`` returns one activation row per
      image (e.g. InceptionV3 with include_top=False, pooling='avg').
    images1, images2: ndarrays of images, each holding at least 2 images.
      They are cast to float32, resized by ``scale_images`` and passed
      through ``inception_v3.preprocess_input`` before ``model.predict``.
    eps: diagonal offset added to both covariances when their product has no
      finite matrix square root (singular covariances, e.g. fewer samples
      than feature dimensions).

  Returns:
    The FID as a numpy float; lower means closer activation statistics.

  Raises:
    ValueError: if either set has fewer than 2 images (covariance undefined).
  """
  for name, images in (('images1', images1), ('images2', images2)):
    if len(images) < 2:
      raise ValueError('%s needs at least 2 images to estimate a covariance, got %d'
                       % (name, len(images)))
  # process one set at a time so only one resized batch is alive at once
  act1 = _inception_activations(model, images1)
  act2 = _inception_activations(model, images2)
  # calculate mean and covariance statistics
  mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
  mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
  return _frechet_distance(mu1, sigma1, mu2, sigma2, eps=eps)


def _inception_activations(model, images):
  """Cast, resize, preprocess and embed one batch of images with `model`."""
  # convert integer to floating point values before resizing
  images = scale_images(images.astype('float32'))
  images = inception_v3.preprocess_input(images)
  return model.predict(images)


def _frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
  """Frechet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
  # np.cov returns a 0-d array for a single feature; keep matrix shapes
  mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
  sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
  # calculate sum squared difference between means
  ssdiff = np.sum((mu1 - mu2) ** 2.0)
  # calculate sqrt of product between cov
  covmean = sqrtm(sigma1.dot(sigma2))
  if not np.isfinite(covmean).all():
    # singular product (e.g. fewer samples than feature dims): regularize both
    # covariances, as done in the reference TTUR / pytorch-fid implementations
    offset = np.eye(sigma1.shape[0]) * eps
    covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
  # numerical error in sqrtm can leave an imaginary component; keep the real part
  if np.iscomplexobj(covmean):
    covmean = covmean.real
  return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

In [89]:
# build the InceptionV3 feature extractor used by calculate_fid:
# no classification head, average pooling into one feature vector per image,
# and a fixed 299x299 RGB input (matching what scale_images produces by default)
incv3 = inception_v3.InceptionV3(
    input_shape=(299, 299, 3),
    pooling='avg',
    include_top=False,
)

In [90]:
# mnist test: FID between a random train subset and a random test subset
n_samples = 1000  # images per set; FID estimates are biased upward for small samples
rng = np.random.default_rng(0)  # fixed seed so the printed score is reproducible
(train_images, _), (test_images, _) = mnist.load_data()
# draw random indices instead of shuffling the full arrays in place
train_subset = train_images[rng.choice(len(train_images), n_samples, replace=False)]
test_subset = test_images[rng.choice(len(test_images), n_samples, replace=False)]
fid = calculate_fid(incv3, train_subset, test_subset)
print('FID: %.3f' % fid)

FID: 13.258


In [91]:
# cifar10 test: FID between a random train subset and a random test subset
# NOTE: 100 images is far below the dimensionality of pooled InceptionV3
# features (2048), so the covariances are singular and the score is strongly
# biased upward; only compare FIDs computed with the same n_samples.
n_samples = 100
rng = np.random.default_rng(0)  # fixed seed so the printed score is reproducible
(train_images, _), (test_images, _) = cifar10.load_data()
# draw random indices instead of shuffling the full arrays in place
train_subset = train_images[rng.choice(len(train_images), n_samples, replace=False)]
test_subset = test_images[rng.choice(len(test_images), n_samples, replace=False)]
fid = calculate_fid(incv3, train_subset, test_subset)
print('FID: %.3f' % fid)

FID: 167.271


In [92]:
# cifar100 test: FID between a random train subset and a random test subset
# NOTE: 100 images is far below the dimensionality of pooled InceptionV3
# features (2048), so the covariances are singular and the score is strongly
# biased upward; only compare FIDs computed with the same n_samples.
n_samples = 100
rng = np.random.default_rng(0)  # fixed seed so the printed score is reproducible
(train_images, _), (test_images, _) = cifar100.load_data()
# draw random indices instead of shuffling the full arrays in place
train_subset = train_images[rng.choice(len(train_images), n_samples, replace=False)]
test_subset = test_images[rng.choice(len(test_images), n_samples, replace=False)]
fid = calculate_fid(incv3, train_subset, test_subset)
print('FID: %.3f' % fid)

FID: 205.363
