In [2]:
import torch
from torchvision.models import inception_v3
import torchvision.transforms as transforms
import numpy as np
from scipy.linalg import sqrtm
import glob
from PIL import Image

def calculate_fid(model, images1, images2):
    # Function to calculate Frechet Inception Distance (FID) score
    def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
        covmean = sqrtm(sigma1.dot(sigma2), disp=False)[0]
        return np.square(mu1 - mu2).sum() + np.trace(sigma1 + sigma2 - 2 * covmean)

    # Preprocess images
    preprocess = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    images1 = torch.stack([preprocess(img) for img in images1])
    images2 = torch.stack([preprocess(img) for img in images2])

    # Get predictions
    with torch.no_grad():
        pred1 = model(images1)
        pred2 = model(images2)

    # Calculate mean and covariance
    mu1, sigma1 = pred1.mean(0).numpy(), np.cov(pred1.numpy(), rowvar=False)
    mu2, sigma2 = pred2.mean(0).numpy(), np.cov(pred2.numpy(), rowvar=False)

    # Calculate FID
    fid = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)

    return fid

def load_images(folder_path, file_ext):
    images = []
    for filename in glob.glob(f'{folder_path}/*.{file_ext}'):  # Adjust file extension as needed
        with open(filename, 'rb') as f:
            image = Image.open(f).convert('RGB')
            images.append(image)
    return images[:-1]

# Load InceptionV3 model
inception_model = inception_v3(pretrained=True).eval()

# Load your datasets
real_images = load_images('./test', 'jpg')  # Update with the path to real images
fake_images = load_images('./model_outptus/diffusion', 'jpeg')  # Update with the path to generated images

# Calculate FID Score
fid_score_diff = calculate_fid(inception_model, real_images, fake_images)
print(f'FID score for diffusion: {fid_score_diff}')

# Gans
fake_images = load_images('./model_outptus/gan', 'png')  # Update with the path to generated images

# Calculate FID Score
fid_score_gan = calculate_fid(inception_model, real_images, fake_images)
print(f'FID score for Gan: {fid_score_gan}')


FID score for diffusion: 52
FID score for Gan: 97
