In [None]:
import gc
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision.utils import save_image, make_grid
from tqdm import tqdm

from scheduler import get_schedules
from unet import Unet

In [None]:
# Seed NumPy's global RNG and PyTorch's default generator so a fresh
# Restart & Run All reproduces the same random draws.
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Number of diffusion steps in the reverse chain.
n_T = 1_000
# The two beta-schedule endpoints handed to get_schedules (presumably
# beta_1 and beta_T of a linear schedule — confirm in scheduler.py).
betas = [1e-4, 2e-2]

In [None]:
# Ensure the output directory exists. exist_ok=True makes this a single,
# idempotent call with no check-then-create race.
os.makedirs('results', exist_ok=True)

In [None]:
# Instantiate the noise-prediction U-Net directly on the selected device.
unet = Unet().to(device)

# Pre-compute every schedule tensor once and move each one onto `device`,
# so the sampling loop can index them without host/device transfers.
schedules = {
    name: tensor.to(device)
    for name, tensor in get_schedules(betas[0], betas[1], n_T).items()
}

In [None]:
# Restore the trained weights. map_location remaps tensors saved on another
# device (e.g. a GPU) onto the device selected above.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust
# (on torch >= 1.13, passing weights_only=True restricts what can be loaded).
loaded_state_dict = torch.load("saved_models/unet_mnist.pth", map_location=torch.device(device))
unet.load_state_dict(loaded_state_dict)
# Switch any dropout/batch-norm layers to inference behaviour; the trailing
# `;` keeps the notebook from printing the whole module repr.
unet.eval();


In [None]:
n_samples = 8

# Step 1: start the reverse chain from pure Gaussian noise x_T ~ N(0, I).
x_T = torch.randn(n_samples, 1, 28, 28).to(device)

# One entry per sample; scaled below to build the per-step timestep tensor.
ones = torch.ones(n_samples).to(device)

# Step 2: iterate i = n_T, ..., 1.
# no_grad stops autograd from chaining all n_T U-Net calls through x_i into a
# single retained graph. That graph is what made memory grow step after step,
# so the old per-step gc.collect()/torch.cuda.empty_cache() calls are gone.
x_i = x_T
with torch.no_grad():
    for i in tqdm(range(n_T, 0, -1)):
        # Step 3: fresh Gaussian noise for every step except the last one
        # (i == 1), where DDPM sampling uses z = 0 so the output is not re-noised.
        z = torch.randn_like(x_i) if i > 1 else torch.zeros_like(x_i)

        # Step 4: predict the noise at the current step, then take one reverse step.
        # NOTE(review): the original passed the constant t = 1 at every step, so
        # the network was never told which noise level it was denoising. This
        # assumes training conditioned on i / n_T (as in the minDiffusion
        # reference) — TODO confirm against the training code; use `ones * i`
        # instead if it used raw integer timesteps.
        t = ones * (i / n_T)
        eps = unet(x_i, t)

        x_i = (
            schedules["one_over_sqrt_a"][i]
            * (x_i - eps * schedules["inv_alpha_over_sqrt_inv_abar"][i])
            + schedules["sqrt_beta"][i] * z
        )

In [None]:
# Bring the final denoised batch to the host as an (N, 28, 28) array for plotting.
x_hat = x_i.detach().cpu().numpy().reshape(-1, 28, 28)

# Lay the samples out in rows of at most 4 panels (2 x 4 for the default 8
# samples). The grid grows with the batch instead of failing past 8 images.
n_images = len(x_hat)
n_cols = max(1, min(4, n_images))
n_rows = max(1, -(-n_images // n_cols))  # ceiling division

fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
for ax in axes.flat:
    ax.axis('off')  # also blanks any unused trailing panels
for ax, img in zip(axes.flat, x_hat):
    ax.imshow(img, cmap='binary')
fig.suptitle('DDPM samples')
plt.show()