In [5]:
import torch
import numpy as np

from skimage.io import imread, imsave
from tqdm.auto import trange, tqdm
from torchvision.datasets import MNIST
from pytorch_fid import fid_score

from data_generator import DataGenerator
from default_mnist_config import create_default_mnist_config
from diffusion import DiffusionRunner
from models.classifier import ResNet, ResidualBlock, ConditionalResNet

from matplotlib import pyplot as plt

import os

from torchvision.transforms import Compose, Resize

#### Определим папку с настоящими картинками

In [6]:
def create_dir(path: str):
    if not os.path.exists(path):
        os.makedirs(path)

In [7]:
create_dir('./real_images_MNIST')

real_dataset = MNIST(root='../data', download=True, train=True, transform=Compose([Resize((32, 32))]))
for idx, (image_mnist, label) in enumerate(tqdm(real_dataset, total=len(real_dataset))):
    image = np.array(image_mnist)
    imsave("./real_images_MNIST/{}.png".format(idx), image)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 78723091.16it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 118877030.25it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 85547188.95it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 20116714.64it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw






  0%|          | 0/60000 [00:00<?, ?it/s]

#### Определим папку для синтетических картинок и сгенерируем 60к картинок

In [13]:
uncond_diff = DiffusionRunner(create_default_mnist_config(), eval=True)

P.S.: изменил TOTAL_IMAGES_COUNT с 60.000 до 10.000, так как не хватало времени (в чате разрешили)

In [14]:
create_dir('../uncond_mnist')

TOTAL_IMAGES_COUNT = 10_000
BATCH_SIZE = 200
NUM_ITERS = TOTAL_IMAGES_COUNT // BATCH_SIZE

global_idx = 0
for idx in trange(NUM_ITERS):
    images: torch.Tensor = uncond_diff.sample_images(batch_size=BATCH_SIZE).cpu()
    images = images.permute(0, 2, 3, 1).data.numpy().astype(np.uint8)

    for i in range(len(images)):
        imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
        global_idx += 1

  0%|          | 0/50 [00:00<?, ?it/s]

  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])


In [17]:
fid_value = fid_score.calculate_fid_given_paths(
    paths=['./real_images_MNIST', '../uncond_mnist'],
    batch_size=200,
    device='cuda:0',
    dims=2048
)
fid_value

Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 196MB/s]
100%|██████████| 300/300 [03:54<00:00,  1.28it/s]
100%|██████████| 50/50 [00:40<00:00,  1.25it/s]


132.33983047380514

> Какой фид получился? Сравните FID для безусловной генерации и для условной. Сгенерируйте для каждого класса по 6к картинок и посчитайте FID между реальными и условно сгенерированными картинками.

Фид равен 132.34. Вычислим для условной генерации

In [18]:
classifier_args = {
    "block": ResidualBlock,
    "layers": [2, 2, 2, 2]
}
noisy_classifier = ConditionalResNet(**classifier_args)
noisy_classifier.to('cuda:0')

noisy_classifier.load_state_dict(torch.load('./ddpm_checkpoints/classifier.pth'))

conditional_diffusion = DiffusionRunner(create_default_mnist_config(), eval=True)
conditional_diffusion.set_classifier(noisy_classifier, T=0.1)

In [20]:
create_dir('../conditional_mnist')

TOTAL_IMAGES_COUNT = 60_000
BATCH_SIZE = 6_000
NUM_ITERS = TOTAL_IMAGES_COUNT // BATCH_SIZE

global_idx = 0
for idx in trange(NUM_ITERS):
    labels = (idx % 10)*torch.ones(100).long().to('cuda:0')

    images: torch.Tensor = conditional_diffusion.sample_images(batch_size=BATCH_SIZE, labels=labels).cpu()
    images = images.permute(0, 2, 3, 1).data.numpy().astype(np.uint8)

    for i in range(len(images)):
        imsave(os.path.join('../conditional_mnist', f'{global_idx}.png'), images[i])
        global_idx += 1

  0%|          | 0/10 [00:00<?, ?it/s]

  imsave(os.path.join('../conditional_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../conditional_mnist', f'{global_idx}.png'), images[i])
  imsave(os.path.join('../conditional_mnist', f'{global_idx}.png'), images[i])


In [21]:
fid_value = fid_score.calculate_fid_given_paths(
    paths=['./real_images_MNIST', '../conditional_mnist'],
    batch_size=200,
    device='cuda:0',
    dims=2048
)
fid_value

100%|██████████| 300/300 [03:55<00:00,  1.27it/s]
100%|██████████| 5/5 [00:03<00:00,  1.26it/s]


141.26082344168532

Фид равен 141.26. Он больше, чем для безусловной генерации