In [3]:
from __future__ import annotations

from pathlib import Path

%env KERAS_BACKEND=torch

import torch
from torchvision import datasets, transforms
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from skimage.transform import resize
import scipy
import numpy as np
from tqdm.auto import tqdm

env: KERAS_BACKEND=torch


In [4]:
try:
    from google.colab import drive
    import os

    drive.mount("/content/drive")
    os.chdir("/content/drive/MyDrive/Colab Notebooks")
except ImportError:
    print("Running locally.")

assert Path("ddpm_models").exists(), "Couldn't find model folder"

Mounted at /content/drive


In [5]:
torch.manual_seed(0)

# Rather than treating MNIST images as discrete objects, as done in Ho et al 2020,
# we here treat them as continuous input data, by dequantizing the pixel values (adding noise to the input data)
# Also note that we map the 0..255 pixel values to [-1, 1], and that we process the 28x28 pixel values as a flattened 784 tensor.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(
            lambda x: x + torch.rand(x.shape) / 255
        ),  # Dequantize pixel values
        transforms.Lambda(lambda x: (x - 0.5) * 2.0),  # Map from [0,1] -> [-1, -1]
        transforms.Lambda(lambda x: x.flatten()),
    ]
)

dataset_test = datasets.MNIST(
    "./mnist_data", download=True, train=False, transform=transform
)

In [6]:
def hacky_resize(data: torch.Tensor, output_shape=(75, 75, 3)):
    def resize_(a: np.ndarray) -> np.ndarray:
        return resize(a.reshape(28, 28, 1), output_shape=output_shape)

    transformed_np = np.apply_along_axis(
        resize_, 1, data.cpu().numpy().reshape((-1, 28 * 28))
    )
    return torch.from_numpy(transformed_np)

In [7]:
# Select device
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.mps.is_available():
    device = "mps:0"
else:
    device = "cpu"

print(f"Running on {device}")

Running on cuda:0


In [8]:
inc_model = InceptionV3(include_top=False, pooling="avg", input_shape=(75, 75, 3))


@torch.no_grad()
def run_inc_model(imgs: torch.Tensor) -> np.ndarray:
    return inc_model.predict(
        preprocess_input(hacky_resize(imgs)), verbose=False, batch_size=100
    )

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m87910968/87910968[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 0us/step


In [9]:
def calc_fid(imgs_or_act_1: torch.Tensor | np.ndarray, imgs_2: torch.Tensor):
    if isinstance(imgs_or_act_1, np.ndarray):
        act_1 = imgs_or_act_1
    else:
        act_1 = run_inc_model(imgs_or_act_1)
    act_2 = run_inc_model(imgs_2)
    mean_1, cov_1 = act_1.mean(axis=0), np.cov(act_1, rowvar=False)
    mean_2, cov_2 = act_2.mean(axis=0), np.cov(act_2, rowvar=False)
    mean_diff = np.sum((mean_1 - mean_2) ** 2.0)
    cov_prod = np.real(scipy.linalg.sqrtm(cov_1 @ cov_2))
    return round(mean_diff + np.trace(cov_1 + cov_2 - 2 * cov_prod), 2)


ex_d1, ex_d2 = dataset_test.data[:2050], dataset_test.data[2050:4100]
print(calc_fid(run_inc_model(ex_d1), ex_d1)) # -0.0
print(calc_fid(ex_d1, ex_d2)) # 1.76

-0.0
1.76


In [10]:
n_samples = 10_000
dataset_activations = run_inc_model(dataset_test.data[:n_samples])

In [13]:
bar = tqdm(list(Path("ddpm_models").glob("*.pt")))
for model_path in bar:
    bar.set_postfix(model=model_path)
    model = torch.load(model_path, weights_only=False, map_location=device)
    samples = model.sample((n_samples, 28 * 28))
    fid = calc_fid(dataset_activations, samples)
    print(model_path, ":", fid)

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?timestep/s]

ddpm_models/base_e.pt : 569.07


  0%|          | 0/1000 [00:00<?, ?timestep/s]

ddpm_models/base_u.pt : 1676.23


  0%|          | 0/1000 [00:00<?, ?timestep/s]

ddpm_models/base_x0.pt : 432.83


  0%|          | 0/1000 [00:00<?, ?timestep/s]

ddpm_models/variance_low-discrepency.pt : 593.39


  0%|          | 0/1000 [00:00<?, ?timestep/s]

ddpm_models/variance_importance-batch.pt : 633.86


  0%|          | 0/1000 [00:00<?, ?timestep/s]

ddpm_models/variance_importance-sampling.pt : 596.85


ddpm_models/base_e.pt : 569.07

ddpm_models/base_u.pt : 1676.23

ddpm_models/base_x0.pt : 432.83

ddpm_models/variance_low-discrepency.pt : 593.39

ddpm_models/variance_importance-batch.pt : 633.86

ddpm_models/variance_importance-sampling.pt : 596.85