# FID, KID and MSID calculation demo

In [1]:
!pip install piq

Collecting piq
  Downloading piq-0.8.0-py3-none-any.whl.metadata (17 kB)
Downloading piq-0.8.0-py3-none-any.whl (106 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.9/106.9 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: piq
Successfully installed piq-0.8.0


In [2]:
import random
import numpy as np
import torch
from torchvision.datasets import MNIST
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader, Dataset
import piq



# Single source of truth for every RNG seed used in this notebook.
SEED = 42

# Seed PyTorch, NumPy and the stdlib generator with the same value.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Request deterministic cuDNN kernels and turn off the cuDNN autotuner,
# which could otherwise pick different algorithms between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [10]:
from torchvision.datasets import MNIST
from torchvision.datasets import CIFAR10

# Both datasets go through the same preprocessing pipeline:
# resize to 256, take a 224 center crop, then convert the PIL image to a tensor.
transform_mnist = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
transform_cifar = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])

# Test splits (train=False) of each dataset; they are downloaded on first use.
mnist_dataset = MNIST(
    root='MNIST/raw/train-images-idx3-ubyte',
    train=False,
    download=True,
    transform=transform_mnist,
)
cifar_dataset = CIFAR10(
    root='cifar-10-batches-py',
    train=False,
    download=True,
    transform=transform_cifar,
)

Files already downloaded and verified


In [11]:
from collections import defaultdict

def collator_mnist(batch):
    """Collate (image, label) pairs into a batch dict, tripling the image channels.

    Each image is concatenated with itself three times along its first axis
    (presumably turning a 1-channel MNIST image into a 3-channel one), and the
    per-sample results are stacked along a new leading batch axis.

    Args:
        batch: iterable of samples where ``sample[0]`` is an image tensor and
            ``sample[1]`` is its label.

    Returns:
        dict with ``'images'`` (stacked, channel-tripled images) and
        ``'targets'`` (stacked labels).
    """
    images = []
    targets = []
    for sample in batch:
        image, label = sample[0], sample[1]
        images.append(torch.cat((image, image, image), dim=0))
        targets.append(torch.tensor(label))

    return {
        'images': torch.stack(images, dim=0),
        'targets': torch.stack(targets, dim=0),
    }

def collator_cifar(batch):
    """Collate (image, label) pairs into a batch dict without touching the images.

    Args:
        batch: iterable of samples where ``sample[0]`` is an image tensor and
            ``sample[1]`` is its label.

    Returns:
        dict with ``'images'`` (images stacked along a new leading batch axis)
        and ``'targets'`` (stacked labels).
    """
    images = torch.stack([sample[0] for sample in batch], dim=0)
    targets = torch.stack([torch.tensor(sample[1]) for sample in batch], dim=0)
    return {'images': images, 'targets': targets}

# Fixed-order (unshuffled) loaders in batches of 128; each dataset uses its own
# collate function so batches arrive as {'images': ..., 'targets': ...} dicts.
mnist_dataloader = DataLoader(
    dataset=mnist_dataset,
    batch_size=128,
    shuffle=False,
    collate_fn=collator_mnist,
)
cifar_dataloader = DataLoader(
    dataset=cifar_dataset,
    batch_size=128,
    shuffle=False,
    collate_fn=collator_cifar,
)

In [12]:
@torch.no_grad()
def demo(x_features, y_features):
    """Print the FID, KID and MSID scores between two sets of image features.

    Both arguments are features pre-extracted from some feature extractor
    network. The scores are printed (in the order FID, KID, MSID) with four
    decimal places; nothing is returned.

    Args:
        x_features: feature tensor of the first image set.
        y_features: feature tensor of the second image set.
    """
    if torch.cuda.is_available():
        # Move to GPU to make computations faster
        x_features, y_features = x_features.cuda(), y_features.cuda()

    # Each metric object is only built right before it is evaluated, so
    # construction and evaluation happen in the same order as they are printed.
    metric_factories = (
        ("FID", lambda: piq.FID()),
        ("KID", lambda: piq.KID()),
        ("MSID", lambda: piq.MSID(niters=100)),
    )
    for label, make_metric in metric_factories:
        score: torch.Tensor = make_metric()(x_features, y_features)
        print(f"{label}: {score:0.4f}")

In [13]:
# Different distances between the CIFAR-10 and MNIST data distributions

# Fall back to the CPU when no GPU is present (the same availability check
# `demo` performs) instead of failing on a hard-coded 'cuda' device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Extract features for every image of each dataset with piq's default extractor.
cifar_features = piq.FID().compute_feats(cifar_dataloader, device=device)
mnist_features = piq.FID().compute_feats(mnist_dataloader, device=device)
demo(mnist_features, cifar_features)

FID: 263.5873
KID: 0.2426
MSID: 83.1270
