In [2]:
pip install torch torchvision pillow numpy scipy torchmetrics[image] torch-fidelity

Collecting torch-fidelity
  Using cached torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Collecting torchmetrics[image]
  Using cached torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Using cached nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3

In [None]:
from PIL import Image
import os
import numpy as np
import torch
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt

# Compute the FID score between two image sets: a folder of real images and a
# folder of generated images. Set both folder paths before running this cell.
real_images_folder = ""
generated_images_folder = ""

# File extensions accepted as images; every other directory entry is ignored.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff")


def _list_image_paths(folder):
    """Return the sorted paths of the image files directly inside ``folder``.

    Sub-directories (for example Jupyter's ``.ipynb_checkpoints``) and files
    whose extension is not in ``IMAGE_EXTENSIONS`` are skipped, so stray
    entries do not abort the load.

    Raises:
        FileNotFoundError: if ``folder`` is not an existing directory, which
            includes the empty placeholder value.
    """
    if not os.path.isdir(folder):
        raise FileNotFoundError(
            f"Image folder {folder!r} does not exist; set real_images_folder "
            "and generated_images_folder before running this cell."
        )
    candidates = (os.path.join(folder, name) for name in os.listdir(folder))
    return sorted(
        path
        for path in candidates
        if os.path.isfile(path) and path.lower().endswith(IMAGE_EXTENSIONS)
    )


def _load_rgb_image(path):
    """Read the image at ``path`` and return it as an RGB numpy array."""
    # The context manager closes the underlying file as soon as the pixels
    # have been copied into the array.
    with Image.open(path) as image:
        return np.array(image.convert("RGB"))


real_images_paths = _list_image_paths(real_images_folder)
generated_images_paths = _list_image_paths(generated_images_folder)

real_images = [_load_rgb_image(path) for path in real_images_paths]
generated_images = [_load_rgb_image(path) for path in generated_images_paths]

def preprocess_image(image, size=(512, 512)):
    """Turn one channel-last image array into a resized, scaled batch of one.

    Args:
        image: Image in height x width x channel layout with 0-255 values,
            such as ``np.array(pil_image.convert("RGB"))``.
        size: Target ``(height, width)`` passed to ``F.resize``. Defaults to
            ``(512, 512)``; change it to evaluate at another resolution.

    Returns:
        A tensor of shape ``(1, channels, size[0], size[1])`` whose values
        were divided by 255 before resizing.
    """
    # Add a leading batch axis: (H, W, C) -> (1, H, W, C). as_tensor avoids
    # the extra copy torch.tensor would make; the division below allocates a
    # new tensor, so the caller's array is never modified.
    batch = torch.as_tensor(image).unsqueeze(0)
    # Channel-last -> channel-first, then scale 0-255 values by 1/255.
    batch = batch.permute(0, 3, 1, 2) / 255.0

    # antialias is set explicitly so the result does not depend on the
    # default of the installed torchvision version.
    return F.resize(batch, size, antialias=True)

# Preprocess every image and join the resulting single-image batches along
# dim 0, producing one tensor per image set.
real_images_preprocessed = torch.cat(
    tuple(preprocess_image(real_image) for real_image in real_images),
    dim=0,
)
generated_images_preprocessed = torch.cat(
    tuple(preprocess_image(generated_image) for generated_image in generated_images),
    dim=0,
)

In [10]:
from torchmetrics.image.fid import FrechetInceptionDistance

# Number of images sent through the feature extractor per update call. Feeding
# a whole image set at once makes memory use grow with the dataset size.
FID_BATCH_SIZE = 32

# Use the GPU when one is available; otherwise stay on the CPU.
fid_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# normalize=True: the inputs are float images that were scaled by 1/255 in
# preprocess_image rather than uint8 images.
fid = FrechetInceptionDistance(normalize=True).to(fid_device)

# The metric accumulates its statistics across update calls, so the two image
# sets can be streamed in batches.
for image_set, is_real in (
    (real_images_preprocessed, True),
    (generated_images_preprocessed, False),
):
    for batch in image_set.split(FID_BATCH_SIZE):
        fid.update(batch.to(fid_device), real=is_real)

print(f"FID: {fid.compute().item()}")

FID: 183.2335662841797
