In [1]:
# 10. FID Calculation 
# ---------------------------------
# Example 10: Calculating FID (Fréchet Inception Distance)
# ---------------------------------

import numpy as np
from scipy.linalg import sqrtm

def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """
    Calculate the Fréchet Inception Distance between two Gaussian distributions.

    FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    Parameters:
        mu1, mu2 : array_like, mean vectors of the distributions (length d).
                   Scalars are treated as length-1 vectors.
        sigma1, sigma2 : array_like, covariance matrices (d x d).
                   Scalars are treated as 1 x 1 matrices.
        eps : float, diagonal offset added to both covariances when the
              matrix square root of their product is not finite (e.g. the
              product is singular). Default 1e-6.

    Returns:
        fid : FID score (float)

    Raises:
        ValueError: if the mean vectors or the covariance matrices do not
            have matching shapes.
    """
    mu1 = np.atleast_1d(np.asarray(mu1, dtype=float))
    mu2 = np.atleast_1d(np.asarray(mu2, dtype=float))
    sigma1 = np.atleast_2d(np.asarray(sigma1, dtype=float))
    sigma2 = np.atleast_2d(np.asarray(sigma2, dtype=float))

    if mu1.shape != mu2.shape:
        raise ValueError(
            f"mean vectors have different shapes: {mu1.shape} vs {mu2.shape}"
        )
    if sigma1.shape != sigma2.shape:
        raise ValueError(
            f"covariance matrices have different shapes: "
            f"{sigma1.shape} vs {sigma2.shape}"
        )

    diff = mu1 - mu2

    covmean = sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # A singular/ill-conditioned product can make sqrtm blow up; nudging
        # both covariances by a small multiple of the identity stabilises it.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # sqrtm may return a complex array with tiny imaginary round-off even
    # for real PSD inputs; only the real part is meaningful here.
    covmean = np.real(covmean)

    fid = diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)
    return float(fid)

# 1. Generate example mean vectors and covariance matrices.
#    A fixed seed keeps the example (and its printed output) reproducible.
#    With identity covariances the trace term vanishes, so the FID equals
#    the squared Euclidean distance between the two mean vectors.
rng = np.random.default_rng(seed=0)
mu1, sigma1 = rng.random(3), np.eye(3)
mu2, sigma2 = rng.random(3), np.eye(3)

# 2. Calculate FID score
fid_score = calculate_fid(mu1, sigma1, mu2, sigma2)

# 3. Display results
print("--- FID Calculation Example ---")
print("Mean Vector 1:", mu1)
print("Mean Vector 2:", mu2)
print("FID Score:", fid_score)


--- FID Calculation Example ---
Mean Vector 1: [0.0867146  0.18502884 0.82932985]
Mean Vector 2: [0.31147775 0.28075904 0.29593517]
FID Score: 0.3441926325864379
