In [6]:
from diffusers import DDPMPipeline
from torchvision import transforms, datasets
from torchvision.models import inception_v3
import torch
import numpy as np
from scipy.linalg import sqrtm
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import os
from PIL import Image

In [9]:
DATASET_DIR = r"D:\data0\lsun\bedroom"
MODEL_ID = "google/ddpm-bedroom-256"
OUTPUT_DIR = "ddpm/images"
DATASET_SUBSET = 0.05

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Image generation

In [4]:
ddpm = DDPMPipeline.from_pretrained(MODEL_ID)
generated_images = [ddpm().images[0] for _ in tqdm(range(9), desc="Generating images")]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Generating images:   0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generating images:  11%|█         | 1/9 [32:18<4:18:24, 1938.09s/it]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generating images:  22%|██▏       | 2/9 [1:04:50<3:47:05, 1946.55s/it]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generating images:  33%|███▎      | 3/9 [1:38:12<3:17:11, 1971.99s/it]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generating images:  44%|████▍     | 4/9 [2:09:36<2:41:25, 1937.11s/it]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generating images:  56%|█████▌    | 5/9 [2:39:45<2:06:03, 1890.97s/it]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generating images:  67%|██████▋   | 6/9 [3:08:43<1:31:56, 1838.84s/it]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generating images:  78%|███████▊  | 7/9 [3:37:42<1:00:12, 1806.22s/it]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generating images:  89%|████████▉ | 8/9 [4:06:37<29:43, 1783.71s/it]  

  0%|          | 0/1000 [00:00<?, ?it/s]

Generating images: 100%|██████████| 9/9 [4:35:34<00:00, 1837.21s/it]


In [5]:
generated_images

[<PIL.Image.Image image mode=RGB size=256x256>,
 <PIL.Image.Image image mode=RGB size=256x256>,
 <PIL.Image.Image image mode=RGB size=256x256>,
 <PIL.Image.Image image mode=RGB size=256x256>,
 <PIL.Image.Image image mode=RGB size=256x256>,
 <PIL.Image.Image image mode=RGB size=256x256>,
 <PIL.Image.Image image mode=RGB size=256x256>,
 <PIL.Image.Image image mode=RGB size=256x256>,
 <PIL.Image.Image image mode=RGB size=256x256>]

In [7]:
output_dir = 'ddpm/images'
os.makedirs(output_dir, exist_ok=True)

for i, img in enumerate(generated_images):
    img.save(os.path.join(output_dir, f'fake_samples_{i+1}.png'))

grid_size = (3, 3)
image_size = generated_images[0].size
grid_image = Image.new('RGB', (grid_size[0] * image_size[0], grid_size[1] * image_size[1]))

for index, img in enumerate(generated_images):
    x = (index % grid_size[0]) * image_size[0]
    y = (index // grid_size[0]) * image_size[1]
    grid_image.paste(img, (x, y))

grid_image_path = os.path.join(output_dir, 'fake_samples.png')
grid_image.save(grid_image_path)

print(f'Images saved in directory: {output_dir}')
print(f'Grid image saved as: {grid_image_path}')

Images saved in directory: ddpm/images
Grid image saved as: ddpm/images\fake_samples.png


## FID calculation

In [11]:
transform = transforms.Compose([
    transforms.Resize((299, 299)),  # Resize to 299x299 for Inception
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = datasets.ImageFolder(root=DATASET_DIR, transform=transform)
generated_dataset = datasets.ImageFolder(root='ddpm', transform=transform)

total_len = len(dataset)
final_subset_indices = np.random.choice(total_len, int(DATASET_SUBSET * total_len), replace=False)

subset_dataset = Subset(dataset, final_subset_indices)
dataloader = DataLoader(subset_dataset, batch_size=32, shuffle=True)
generated_dataloader = DataLoader(generated_dataset, batch_size=32, shuffle=True)

In [12]:
def get_features(dataloader, model, device):
    model = model.to(device)
    features = []
    with torch.no_grad():
        for images, _ in tqdm(dataloader, desc="Extracting features"):
            images = images.to(device)
            outputs = model(images)
            features.append(outputs.cpu().numpy())
    features = np.concatenate(features, axis=0)
    return features

In [13]:
inception = inception_v3(pretrained=True, transform_input=False).eval()

In [14]:
real_features = get_features(dataloader, inception, device)

Extracting features: 100%|██████████| 474/474 [07:20<00:00,  1.08it/s]


In [15]:
generated_features = get_features(generated_dataloader, inception, device)

Extracting features: 100%|██████████| 1/1 [00:00<00:00,  3.62it/s]


In [17]:
mu_real = np.mean(real_features, axis=0)
sigma_real = np.cov(real_features, rowvar=False)

mu_gen = np.mean(generated_features, axis=0)
sigma_gen = np.cov(generated_features, rowvar=False)

def calculate_fid(mu1, sigma1, mu2, sigma2):
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))

    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

fid_score = calculate_fid(mu_real, sigma_real, mu_gen, sigma_gen)
print('FID score:', fid_score)

FID score: 666.9733674596641
