# Calculate Fréchet Inception Distance (FID) Score

In [2]:
import tensorflow as tf
from keras import backend as K
import pathlib
import numpy as np
import pickle

#from numpy import cov
#from numpy import trace
#from numpy import iscomplexobj
from scipy.linalg import sqrtm

### FID

The FID score was introduced by Heusel et al. (2017, see: https://arxiv.org/abs/1706.08500) to improve on the established Inception Score (IS) for evaluating image-generation deep learning methods, in particular GAN architectures. Unlike IS, FID evaluates the quality of generated images by comparing the statistical distribution of their latent feature vectors, taken from the InceptionV3 model, with the corresponding distribution for the original images.

FID achieves this by using the Fréchet distance between the two distributions representing the real and generated images as follows:

<div align="center">
    <img src="FID.PNG"></img><br>
    <i>source: Heusel et al, 2017, p.11 adapted by Hui, 2018 (see: https://jonathan-hui.medium.com/gan-how-to-measure-gan-performance-64b988c47732)</i>
</div>
<br>
<br>

where:<br>
<pre>
    x = real images<br>
    g = generated images<br>
    µx = mean of the multivariate Gaussian distribution representing the latent vector of real images<br>
    µg = mean of the multivariate Gaussian distribution representing the latent vector of generated images<br>
    Σx = covariance matrix for real images<br>
    Σg = covariance matrix for generated images<br>
    Tr() = the trace() function defined as the sum of diagonal elements of a square matrix<br>
</pre>
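
For readers without access to the image, the formula can be written out in LaTeX using the symbols defined above. This is a transcription of the standard FID definition, not a new result:

$$
\mathrm{FID}(x, g) = \lVert \mu_x - \mu_g \rVert_2^2 + \mathrm{Tr}\!\left(\Sigma_x + \Sigma_g - 2\left(\Sigma_x \Sigma_g\right)^{1/2}\right)
$$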

Essentially, calculating a FID score for two sets of images requires the following steps:

1. Load <i>Inception V3</i> model
2. Modify <i>Inception V3</i> by discarding the output layer (classification into image categories) and keeping only the final global average pooling layer, whose output serves as the latent space vector.
3. Calculate the <i>mean (µ)</i> and <i>covariance matrix (Σ)</i> based on the latent space vector of the real (training) images
4. Generate images (here using a GAN)
5. Calculate the <i>mean (µ)</i> and <i>covariance matrix (Σ)</i> based on the latent space vector of the generated ("fake") images
6. Calculate FID based on Fréchet Distance between the two statistical distributions

The process for calculating FID applied to this use case is shown in the graphic below (own depiction):

<div align="center">
<img src="FID_process.png"></img>
<i>FID Computational Graph. (Source: own depiction)</i>
</div>
<br>

NOTE: because the code provided by the paper's authors is out of date, the FID score calculation has been implemented from the ground up, roughly following this implementation: https://machinelearningmastery.com/how-to-implement-the-frechet-inception-distance-fid-from-scratch/


### Load Inception V3 and Modify

In [8]:
tf.keras.backend.clear_session()
model = tf.keras.applications.InceptionV3(include_top=False, pooling='avg', input_shape=(299,299,3)) #Inception V3 expects 299 x 299 input

ResourceExhaustedError: OOM when allocating tensor with shape[3,3,80,192] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:RandomUniform]

In [None]:
model.summary()

### Prepare Dataset for Real Images

In [10]:
# load training dataset (full)
IMAGE_SIZE = (299, 299) # here we specify the expected input size of Inception V3 to let image_dataset_from_directory() automatically resize the images
BATCH_SIZE = 64

data_dir = pathlib.Path('/data/input/crops')
imgs = list(data_dir.glob('*.png'))

#we split out a subset of about 10 percent of the samples as scoring 700'000+ images with FID will take a very long time
dataset_train = tf.keras.preprocessing.image_dataset_from_directory(  '/data/input/',
                                                                      image_size=IMAGE_SIZE, 
                                                                      batch_size=BATCH_SIZE, 
                                                                      labels=[1.] * len(imgs), # setting all labels to 1 (for 'real')
                                                                      #label_mode=None, # yields float32 type labels
                                                                      seed=42,
                                                                      validation_split=0.1,
                                                                      subset='validation'
                                                                    )

#NOTE: we do not need to worry about the dtype of our image data as the above will cast it to float32, as expected by Inception V3
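
One caveat worth noting: the ImageNet weights of Inception V3 were trained on pixel values scaled to [-1, 1], whereas `image_dataset_from_directory()` yields float32 values in [0, 255]. The sketch below shows that scaling in plain NumPy. The helper name `scale_to_inception_range` is hypothetical; in Keras the same mapping is provided by `tf.keras.applications.inception_v3.preprocess_input`. It is not applied in this notebook and is shown only as an optional step.

```python
import numpy as np

def scale_to_inception_range(images):
    """Map pixel values from [0, 255] to [-1, 1] (hypothetical helper)."""
    images = np.asarray(images, dtype=np.float32)
    return images / 127.5 - 1.0

# tiny stand-in for an image batch, shape (batch, height, width, channels)
batch = np.array([[[[0.0, 127.5, 255.0]]]], dtype=np.float32)
scaled = scale_to_inception_range(batch)
print(scaled.min(), scaled.max())
```

With a `tf.data.Dataset` that yields `(images, labels)` batches, the Keras function could be applied via `dataset.map(lambda x, y: (preprocess_input(x), y))`. If it is used, it has to be applied to both the real and the generated images so that the two sets of statistics stay comparable.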

Found 726300 files belonging to 1 classes.
Using 72630 files for validation.


### Calculate Distribution Parameters for Real Images (Baseline)

Scoring the roughly 70k samples of the validation subset above (10% of all samples) through Inception V3 takes a very long time. The statistical parameters required for FID (mean and covariance matrix) are therefore calculated from the resulting activations only once and then persisted (pickled) for further use.
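
If holding all activations in memory at once is a concern, the same two statistics can in principle be accumulated batch by batch. The following is a minimal sketch on synthetic data, not the approach used in this notebook; the function name `streaming_mean_cov` and the random `acts` array are assumptions made for illustration.

```python
import numpy as np

def streaming_mean_cov(batches):
    """Accumulate mean and sample covariance over an iterable of (n_i, d) arrays."""
    n = 0
    total = None   # running sum of feature vectors, shape (d,)
    outer = None   # running sum of outer products, shape (d, d)
    for batch in batches:
        batch = np.asarray(batch, dtype=np.float64)
        if total is None:
            d = batch.shape[1]
            total = np.zeros(d)
            outer = np.zeros((d, d))
        n += batch.shape[0]
        total += batch.sum(axis=0)
        outer += batch.T @ batch
    mean = total / n
    # unbiased estimate (divide by n - 1), matching np.cov(..., rowvar=False)
    cov = (outer - n * np.outer(mean, mean)) / (n - 1)
    return mean, cov

rng = np.random.default_rng(0)
acts = rng.normal(size=(1000, 8))  # synthetic stand-in for Inception activations
mean, cov = streaming_mean_cov(np.array_split(acts, 10))
print(mean.shape, cov.shape)
```

The accumulation is done in float64 because subtracting `n * outer(mean, mean)` from a large running sum can lose precision in float32.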

In [14]:
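# batch_size is not passed because dataset_train is already batched
activations = model.predict(dataset_train)
# np.cov with rowvar=False treats each column (latent dimension) as a variable
mux, sigmax = activations.mean(axis=0), np.cov(activations, rowvar=False)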
print(mux, sigmax)

[1.0269691  1.5217992  2.3997712  ... 0.48547718 0.23129249 0.59598255] [[ 2.45910410e+00  7.15231822e-02  2.31537616e-01 ... -1.07793145e-01
   2.55754755e-01  7.01196827e-02]
 [ 7.15231822e-02  3.87113008e+00  1.51734897e+00 ...  9.13199350e-01
  -7.01373881e-02  6.90343250e-02]
 [ 2.31537616e-01  1.51734897e+00  4.54826168e+00 ...  3.98259162e-01
  -3.85136469e-02  1.94263442e-01]
 ...
 [-1.07793145e-01  9.13199350e-01  3.98259162e-01 ...  9.08558728e-01
  -3.81071381e-02  3.56840394e-03]
 [ 2.55754755e-01 -7.01373881e-02 -3.85136469e-02 ... -3.81071381e-02
   3.05500014e-01  2.00811970e-02]
 [ 7.01196827e-02  6.90343250e-02  1.94263442e-01 ...  3.56840394e-03
   2.00811970e-02  1.98472330e+00]]


The result of the code above is a 2048-dimensional vector holding the mean of each latent space dimension (<i>mux</i>) and a 2048 x 2048 matrix holding the covariance of each pair of latent space dimensions (<i>sigmax</i>). We serialize these values into a pickle file so that this time-consuming calculation does not have to be repeated.

In [19]:
print(mux.shape)
print(sigmax.shape)

# we serialize these reference values into pickle objects
with open('fid_reference_values', 'wb') as f:
    pickle.dump(mux, f)
    pickle.dump(sigmax, f)

(2048,)
(2048, 2048)
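
Because both arrays are dumped into the same file one after the other, they have to be read back with two `pickle.load()` calls in the same order. A minimal round-trip sketch with small placeholder arrays (the `_demo` names and the temporary file path are assumptions, not the notebook's real values):

```python
import os
import pickle
import tempfile

import numpy as np

mux_demo = np.arange(4, dtype=np.float32)   # placeholder for the mean vector
sigmax_demo = np.eye(4, dtype=np.float32)   # placeholder for the covariance matrix

path = os.path.join(tempfile.mkdtemp(), 'fid_reference_demo')
with open(path, 'wb') as f:
    pickle.dump(mux_demo, f)
    pickle.dump(sigmax_demo, f)

with open(path, 'rb') as f:
    mu_loaded = pickle.load(f)      # first object written
    sigma_loaded = pickle.load(f)   # second object written

print(mu_loaded.shape, sigma_loaded.shape)
```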


### Define FID function

In [1]:
# calculate Fréchet inception distance
def calculate_fid(model, images, reference='fid_reference_values'):
    # calculate activations for images to compare to established baseline
    act = model.predict(images)
    # load the pickled baseline statistics (mean first, then covariance)
    with open(reference, 'rb') as f:
        mu1 = pickle.load(f)
        sigma1 = pickle.load(f)
    # calculate mean and covariance statistics of the new activations
    mu2, sigma2 = act.mean(axis=0), np.cov(act, rowvar=False)
    # calculate sum squared difference between means
    ssdiff = np.sum((mu1 - mu2)**2.0)
    # calculate sqrt of product between cov
    covmean = sqrtm(sigma1.dot(sigma2))
    # check and correct imaginary numbers from sqrt
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    # calculate score
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid
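
The formula can be sanity-checked on synthetic statistics without running Inception V3. The sketch below is an illustration only: the function name `fid_from_stats` is hypothetical, and instead of `scipy.linalg.sqrtm` it obtains the trace of the matrix square root from the eigenvalues of the covariance product, a swapped-in technique that relies on those eigenvalues being real and non-negative for covariance matrices.

```python
import numpy as np

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    """Frechet distance between two Gaussians given their means and covariances."""
    ssdiff = np.sum((mu1 - mu2) ** 2)
    # Tr((sigma1 sigma2)^(1/2)) equals the sum of the square roots of the
    # eigenvalues of sigma1 @ sigma2; tiny negative or imaginary parts caused
    # by rounding are discarded
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean

rng = np.random.default_rng(0)
a = rng.normal(size=(500, 6))   # synthetic stand-in for activations
mu_a, sigma_a = a.mean(axis=0), np.cov(a, rowvar=False)

shifted = a + 2.0               # same covariance, every mean shifted by 2
mu_b, sigma_b = shifted.mean(axis=0), np.cov(shifted, rowvar=False)

fid_same = fid_from_stats(mu_a, sigma_a, mu_a, sigma_a)
fid_shift = fid_from_stats(mu_a, sigma_a, mu_b, sigma_b)
print(fid_same, fid_shift)
```

Comparing a distribution with itself gives a value close to 0. Shifting all 6 dimensions by 2 leaves the covariance term unchanged and contributes 6 x 2² = 24 through the mean term, so the second value is close to 24.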

### Compare random generated images (DCGAN) against baseline

We load a random subset of generated images from a previous run into a dataset and pass it to the calculate_fid() function.

In [20]:
# load generated images
IMAGE_SIZE = (299, 299) # here we specify the expected input size of Inception V3 to let image_dataset_from_directory() automatically resize the images
BATCH_SIZE = 64

data_dir = pathlib.Path('/data/output/images/dwarfgan001')
imgs = list(data_dir.glob('*.png'))

check = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,
                                                                      image_size=IMAGE_SIZE, 
                                                                      batch_size=BATCH_SIZE, 
                                                                      #labels=[0.] * len(imgs), # setting all labels to 0 (for 'fake'), not relevant here
                                                                      #label_mode=None, # yields float32 type labels
                                                                      seed=42,
                                                                      validation_split=0.99, #only 220 images available, so nearly all are used; split has to be < 1 
                                                                      subset='validation'
                                                                    )

Found 220 files belonging to 2 classes.
Using 217 files for validation.


In [30]:
result = calculate_fid(model, check)
print(f'We see that the result is quite large with a FID score of: {round(result,2)}. A perfect imitation would score a FID score close to 0.')

We see that the result is quite large with a FID score of: 272565.03. A perfect imitation would score a FID score close to 0.


### Calculate FID for Real Images

In [6]:
# load real images
IMAGE_SIZE = (299, 299) # here we specify the expected input size of Inception V3 to let image_dataset_from_directory() automatically resize the images
BATCH_SIZE = 32

data_dir = pathlib.Path('/data/input')
imgs = list(data_dir.glob('*.png'))

check = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,
                                                              image_size=IMAGE_SIZE, 
                                                              batch_size=BATCH_SIZE, 
                                                              #labels=[0.] * len(imgs), # setting all labels to 0 (for 'fake'), not relevant here
                                                              #label_mode=None, # yields float32 type labels
                                                              seed=42,
                                                              validation_split=0.025, #only 2.5% of 700'000 images as reference 
                                                              subset='validation'
                                                            )

Found 726300 files belonging to 1 classes.
Using 18157 files for validation.


In [9]:
result = calculate_fid(model, check)

In [10]:
print(f'Here we see a much lower FID score of: {round(result,2)}. Due to the variety of pictures, a score of 0 is unlikely.')

Here we see a much lower FID score of: 45.25. Due to the variety of pictures, a score of 0 is unlikely.


### Compare WGAN-GP RUN02 Images to Baseline

In [4]:
# load generated images (WGAN-GP RUN02)
IMAGE_SIZE = (299, 299) # here we specify the expected input size of Inception V3 to let image_dataset_from_directory() automatically resize the images
BATCH_SIZE = 32

data_dir = pathlib.Path('/data/output/images/dwarfganWGANGPR02/')
imgs = list(data_dir.glob('*.png'))

check = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,
                                                              image_size=IMAGE_SIZE, 
                                                              batch_size=BATCH_SIZE, 
                                                              #labels=[0.] * len(imgs), # setting all labels to 0 (for 'fake'), not relevant here
                                                              #label_mode=None, # yields float32 type labels
                                                              seed=42,
                                                              #validation_split=0.025, #only 2.5% of 700'000 images as reference 
                                                              #subset='validation'
                                                            )

Found 120 files belonging to 1 classes.


In [5]:
result = calculate_fid(model, check)

NameError: name 'model' is not defined

In [None]:
print(f'FID score for the WGAN-GP RUN02 images: {round(result,2)}')