In [1]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from keras.datasets.mnist import load_data

from unet import UNet
from diffusion_model import DiffusionModel
import matplotlib.pyplot as plt

import imageio

In [2]:
(trainX, trainy), (testX, testy) = load_data()
# Rescale the uint8 pixel values of both splits to float32 in [0, 1].
trainX, testX = (images.astype(np.float32) / 255. for images in (trainX, testX))

def sample_batch(batch_size, device, img_size=32, dataset=None):
    """Draw a random batch of distinct images as a (batch_size, 1, img_size, img_size) tensor.

    Args:
        batch_size: number of images to draw (sampled without replacement).
        device: torch device the batch is moved to.
        img_size: spatial size passed to ``interpolate`` (default 32, the previous fixed value).
        dataset: array of images shaped (N, H, W); defaults to the module-level ``trainX``.

    Returns:
        A float tensor on ``device``, resized with ``interpolate``'s default (nearest) mode.
    """
    if dataset is None:
        dataset = trainX
    # Index with a NumPy array, not a torch tensor: a 1-element integer tensor implements
    # __index__, so NumPy would treat it as a scalar index and drop the batch axis when
    # batch_size == 1.
    indices = torch.randperm(dataset.shape[0])[:batch_size].numpy()
    # Add the channel axis: (B, H, W) -> (B, 1, H, W).
    data = torch.from_numpy(dataset[indices]).unsqueeze(1).to(device)
    return torch.nn.functional.interpolate(data, img_size)

In [3]:
# Fall back to the CPU when no GPU is available so the notebook still runs end to end.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint) onto `device`.
# NOTE(review): torch.load unpickles the whole model object, so only load trusted files;
# torch >= 2.6 defaults to weights_only=True and would then need weights_only=False here.
model = torch.load('model_paper2_epoch_39999', map_location=device).to(device)
# Inference only: disable training-mode behaviour (dropout / norm statistics), if the UNet has any.
model.eval()
# Presumably 1000 = number of diffusion steps T; `sampling` iterates range(diffusion_model.T, 0, -1).
diffusion_model = DiffusionModel(1000, model, device)

In [4]:
@torch.no_grad()
def sampling(self, n_samples=1, image_channels=1, img_size=(32, 32), use_tqdm=True):
    """Run ancestral (DDPM-style) sampling using ``self`` as the diffusion model.

    Starting from standard Gaussian noise, applies ``self.T`` reverse steps of

        x <- 1/sqrt(alpha_t) * (x - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps(x, t-1))
             + sqrt(beta_t) * z

    where ``z`` is fresh Gaussian noise on every step except the last one (t == 1).

    Args:
        self: object exposing ``T``, ``device``, the schedules ``beta``, ``alpha`` and
            ``alpha_bar`` (indexed with 0-based step indices), and
            ``function_approximator(x, t)``.
        n_samples: number of images sampled in parallel.
        image_channels: channel count of the generated images.
        img_size: (height, width) of the generated images.
        use_tqdm: wrap the step loop in a tqdm progress bar.

    Returns:
        A list of ``self.T + 1`` tensors: the initial noise followed by the sample after
        each reverse step (the last element is the final sample).
    """
    height, width = img_size[0], img_size[1]
    sample = torch.randn((n_samples, image_channels, height, width), device=self.device)
    trajectory = [sample]

    steps = range(self.T, 0, -1)
    if use_tqdm:
        steps = tqdm(steps)

    for step in steps:
        # Noise is drawn before the network call; the final step adds no noise.
        noise = torch.randn_like(sample) if step > 1 else torch.zeros_like(sample)
        # Same 0-based step index for every sample in the batch.
        step_idx = torch.full((n_samples,), step - 1, dtype=torch.long, device=self.device)

        # Gather per-sample schedule values and append three axes to broadcast over (C, H, W).
        beta_t = self.beta[step_idx][..., None, None, None]
        alpha_t = self.alpha[step_idx][..., None, None, None]
        alpha_bar_t = self.alpha_bar[step_idx][..., None, None, None]

        eps_hat = self.function_approximator(sample, step_idx)
        eps_scale = (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)
        mean = 1 / torch.sqrt(alpha_t) * (sample - eps_scale * eps_hat)
        sample = mean + torch.sqrt(beta_t) * noise
        trajectory.append(sample)

    return trajectory

In [5]:
# Seed torch (CPU and CUDA generators) so the sampled digits, and thus the saved movies, are reproducible.
torch.manual_seed(0)
imgs = sampling(diffusion_model, n_samples=10)

100%|██████████| 1000/1000 [00:17<00:00, 56.15it/s]


In [51]:
def frame_indices(n_steps, base=1.1, max_exponent=73, n_points=80):
    """Choose which frames of an (n_steps + 1)-long sampling trajectory to animate.

    Offsets grow geometrically (``base ** k`` for integer k in 0..max_exponent), so kept
    frames are sparse during the early, noisy steps and dense near the end of sampling.

    Returns:
        Sorted list of unique indices in [1, n_steps]. Index 0 (the initial noise) is never
        selected; index n_steps (the final sample) always is, because base ** 0 == 1.
    """
    offsets = (base ** np.linspace(0, max_exponent, n_points, dtype=int)).astype(int)
    # Clip before deduplicating so values pushed onto a bound cannot produce repeated frames.
    offsets = np.unique(offsets.clip(1, n_steps))
    return sorted(n_steps + 1 - offsets)


# `sampling` returns diffusion_model.T + 1 frames, so derive the indices from T directly.
indices = frame_indices(diffusion_model.T)
idx = 1  # which of the sampled images to animate

In [53]:
# Select the animation frames first so only those are copied off the device, then map the
# [0, 1] intensities to uint8 (rounding instead of truncating) for imageio.
frames = torch.stack([imgs[i][idx] for i in indices])  # (n_frames, C, H, W)
imgs_np = frames.clamp(0, 1).squeeze(1).mul(255).round().to(torch.uint8).cpu().numpy()

In [54]:
# Sanity-check the frame stack handed to imageio: 2-D grayscale uint8 frames, (n_frames, H, W).
assert imgs_np.ndim == 3 and imgs_np.dtype == np.uint8, (imgs_np.shape, imgs_np.dtype)
imgs_np.shape

(58, 32, 32)

In [55]:
# Write the denoising trajectory as an animated GIF (mimwrite is imageio's name behind mimsave).
imageio.mimwrite('movie.gif', imgs_np)

In [56]:
# Same frames as an MP4 video. NOTE(review): presumably requires imageio's ffmpeg backend
# (imageio-ffmpeg) to be installed — confirm in the environment.
imageio.mimwrite('movie.mp4', imgs_np)