<a href="https://colab.research.google.com/github/tocom242242/notebooks/blob/master/metrics/FID.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
import tensorflow as tf
import cv2
import numpy as np
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input

# 特徴量抽出用のInceptionモデルを読み込む
model = InceptionV3(weights='imagenet', include_top=False, pooling="avg")

In [10]:
# CIFAR-10データセットを読み込む
(_, _), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

# データの正規化
test_images = test_images.astype('float32') / 255.0

In [11]:
# クラスごとにデータを分ける
class_images = [[] for _ in range(10)]
for image, label in zip(test_images, test_labels):
    class_images[label[0]].append(image)

class_images = [np.array(images) for images in class_images]

In [12]:
np.array(class_images).shape

(10, 1000, 32, 32, 3)

In [13]:
from scipy.linalg import sqrtm
# Inceptionモデルに入力できるサイズにリサイズする
target_size = (224, 224)
def resize(imgs):
    resized_imgs = []
    for img in imgs:
        resized = cv2.resize(img, target_size)
        resized_imgs.append(resized)
    return np.array(resized_imgs)


# fidを計算する
def calc_fid(model, imgs1, imgs2):
    # 特徴量の抽出
    f1 = model.predict(preprocess_input(imgs1))
    f2 = model.predict(preprocess_input(imgs2))
    # 平均を求める
    f1_mean = np.mean(f1,axis=0)
    f2_mean = np.mean(f2,axis=0)
    # 平均の差を求める
    diff = f1_mean - f2_mean
    # 共分散行列を求める
    f1_sigma = np.cov(f1, rowvar=False)
    f2_sigma = np.cov(f2, rowvar=False)
    # 共分散行列の積を取り平方根を計算する
    sqrt_cov_dotted = sqrtm(f1_sigma.dot(f2_sigma))
    # 虚数が含まれる場合には実数のみ用いる
    if np.iscomplexobj(sqrt_cov_dotted):
        sqrt_cov_dotted = sqrt_cov_dotted.real
    fid = np.sum(diff**2.0) + np.trace(f1_sigma+f2_sigma - 2.0*sqrt_cov_dotted)
    return fid

# 同一クラスのfid
resized_imgs1 = resize(class_images[0][:100])
resized_imgs2 = resize(class_images[0][:100])

fid = calc_fid(model, resized_imgs1, resized_imgs2)
print("fid1:",fid)

# 別クラスのfid
resized_imgs1 = resize(class_images[0][:100])
resized_imgs2 = resize(class_images[1][:100])

fid = calc_fid(model, resized_imgs1, resized_imgs2)
print("fid2:",fid)


fid1: -4.0880048937223173e-07
fid2: 0.5700966585587528
