In [None]:
from tqdm import tqdm 

import os

import numpy as np

import torch
import torch.nn as nn

from torch.optim import Adam

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid

from unet import Unet
from scheduler import get_schedules

In [None]:
# Hyper parameters

# --- optimisation ---
# passes over the training set, images per mini-batch, Adam learning rate
epochs, batch_size, lr = 1, 64, 1e-5

# --- diffusion process ---
n_T = 1000            # number of diffusion timesteps
betas = [1e-4, 0.02]  # (start, end) pair handed to get_schedules(betas[0], betas[1], n_T)

In [None]:
# Seed torch's global RNG (all devices) so weight init, shuffling and noise draws are repeatable
torch.manual_seed(42)

# Train on the GPU when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Import dataset
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1]
transform = Compose([
    ToTensor(),
    Normalize((0.5,), (0.5,)),
])
dataset = MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches, loaded by 2 worker processes
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Noise-prediction network, its regression objective and the optimiser
unet = Unet().to(device)
loss_fn = nn.MSELoss()
optim = Adam(unet.parameters(), lr=lr)

# Pre-compute the noise schedules once and place every tensor on `device`
schedules = {
    name: tensor.to(device)
    for name, tensor in get_schedules(betas[0], betas[1], n_T).items()
}

In [None]:
# Create the output directories if missing: sample grids go to results/, checkpoints to saved_models/
os.makedirs('results', exist_ok=True)
os.makedirs('saved_models', exist_ok=True)

In [None]:
# Sampling function (used to log the results)
@torch.no_grad()  # 1000-step chain: never build an autograd graph, even if the caller forgets no_grad
def sample(model, n_T, n_samples, sample_shape, schedules, device=None):
    """Generate samples by running the reverse diffusion chain from pure noise.

    Starting from x_T ~ N(0, I), for i = n_T, ..., 1 it applies

        x_{i-1} = one_over_sqrt_a[i] * (x_i - eps_hat * inv_alpha_over_sqrt_inv_abar[i])
                  + sqrt_beta[i] * z,

    where eps_hat = model(x_i, i / n_T) and z ~ N(0, I) (z = 0 on the final step).
    The "Step N" comments presumably follow the sampling algorithm of
    Ho et al. (2020); TODO confirm against the reference used.

    Args:
        model: noise-prediction network, called as ``model(x, t)`` where ``t`` is a
            1-D tensor of length ``n_samples`` filled with ``i / n_T``.
        n_T: number of diffusion steps to run.
        n_samples: number of samples generated in one batch.
        sample_shape: per-sample shape, e.g. ``(1, 28, 28)``.
        schedules: dict of schedule tensors indexed by timestep ``1..n_T``; must
            contain ``"one_over_sqrt_a"``, ``"inv_alpha_over_sqrt_inv_abar"`` and
            ``"sqrt_beta"``.
        device: device on which to draw noise. Defaults to the device of
            ``schedules["sqrt_beta"]`` so the function no longer depends on a
            notebook-global ``device`` variable.

    Returns:
        Tensor of shape ``(n_samples, *sample_shape)``: the final x_0 estimate.
    """
    if device is None:
        device = schedules["sqrt_beta"].device

    # Step 1: start from pure Gaussian noise x_T
    x_i = torch.randn(n_samples, *sample_shape).to(device)

    # Step 2: walk the chain backwards, i = n_T, ..., 1
    for i in tqdm(range(n_T, 0, -1)):
        # Step 3: fresh noise on every step except the last one
        z = torch.randn(n_samples, *sample_shape).to(device) if i > 1 else 0
        # Step 4: predict the noise at timestep i and take one denoising step
        ts = torch.full((n_samples,), i / n_T, device=device)
        eps = model(x_i, ts)
        x_i = (
            schedules["one_over_sqrt_a"][i] * (x_i - eps * schedules["inv_alpha_over_sqrt_inv_abar"][i])
            + schedules["sqrt_beta"][i] * z
        )

    # Step 6: the chain's end point is the generated sample
    return x_i

In [None]:
# Training loop
losses = []
for epoch in range(epochs):
    print(f"Epoch {epoch} : ")

    # Set unet in training mode
    unet.train()

    pbar = tqdm(dataloader)
    for x, _ in pbar:
        
        x = x.to(device)
        x = x.view(-1, 1, 28, 28)

        # Step 3
        timesteps = torch.randint(1, n_T + 1, (x.shape[0],)).to(device)

        # Step 4
        eps = torch.randn_like(x)

        # Step 5
        optim.zero_grad()

        x_t = schedules["sqrt_abar"][timesteps, None, None, None] * x + schedules["sqrt_inv_abar"][timesteps, None, None, None] * eps
        t = timesteps/n_T
        eps_hat = unet(x_t, t)


        loss = loss_fn(eps_hat, eps)
        loss.backward()
        losses.append(loss.item())
        pbar.set_description(f"Avg loss: {np.mean(losses):.4f}")
        optim.step()

    # Set unet in eval mode
    unet.eval()
    with torch.no_grad():
        x_hat = sample(unet, n_T, 8, (1, 28, 28), device, schedules)
        x_comp = torch.cat([x_hat, x[:8]], dim=0)  # compare original and reconstructed examples
        grid = make_grid(x_comp, normalize=True, value_range=(-1, 1), nrow=4)
        save_image(grid, f"results/sample_mnist{epoch}.png")

        # save model
        torch.save(unet.state_dict(), f"saved_models/unet_mnist.pth")