In [None]:
import numpy as np 
from pathlib import Path
from tqdm.auto import tqdm
from functools import partial

import torch
import torch.nn.functional as F 
from torch.optim import Adam
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import utils
from models import UNet2DModel

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on {DEVICE}...")

# Fix the RNG seeds so weight initialisation, DataLoader shuffling and any
# torch/numpy-based noise sampling repeat across Restart & Run All
# (some GPU kernels may still be non-deterministic).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices

#### Training an unconditional diffusion model on Fashion-MNIST
Each input is a 28x28x1 grayscale image. `transforms.ToTensor()` first converts the pixel values from [0, 255] into a PyTorch tensor in the range [0, 1]. The image is then zero-padded by 2 pixels on every side, so its size grows to 32x32. Lastly, 0.5 is subtracted, which shifts the values into [-0.5, 0.5] (the padded border therefore becomes -0.5). Centering the inputs around zero helps training converge more efficiently.

In [None]:
# Preprocessing pipeline. The final step shifts pixels from [0, 1] to [-0.5, 0.5].
# `Normalize(mean=0.5, std=1)` computes exactly `(x - 0.5) / 1.0`, i.e. the same values
# as a bare `lambda x: x - 0.5`. Unlike a lambda, however, it can be pickled, so the
# DataLoader can use worker processes (num_workers > 0) even under the "spawn" start method.
tfm = transforms.Compose([
  transforms.ToTensor(),  # PIL image in [0, 255] -> float tensor in [0, 1]
  transforms.Pad(2),      # zero-pad 2 px on every side: 28x28 -> 32x32
  transforms.Normalize(mean=(0.5,), std=(1.0,)),  # x - 0.5 -> [-0.5, 0.5]
])

bs = 512  # batch size (also embedded in the checkpoint filenames)
dataset = datasets.FashionMNIST(root="./data", download=True, transform=tfm)
dls = DataLoader(dataset, batch_size=bs, shuffle=True)

In [None]:
# Single-channel in / single-channel out UNet, moved straight onto DEVICE.
# NOTE(review): `nfs` and `num_layers` are interpreted by models.UNet2DModel
# (presumably per-level channel widths and blocks per level) — confirm there.
model = UNet2DModel(
  in_channels=1,
  out_channels=1,
  nfs=(32, 64, 128, 256),
  num_layers=2,
).to(DEVICE)

Train the UNet with the Adam optimizer, scheduling the learning rate with the 1cycle policy (see [paper](https://arxiv.org/abs/1708.07120)).

In [None]:
# Hyperparameters. Adam's own default lr does not matter here: OneCycleLR
# overwrites the learning rate on every step, peaking at `lr`.
lr = 1e-2      # peak learning rate of the 1cycle schedule
epochs = 25
tmax = len(dls) * epochs  # total optimizer steps; the scheduler is stepped once per batch

optimizer = Adam(model.parameters(), eps=1e-5)
# `sched` is a factory, so a fresh schedule can be built for another optimizer if needed.
sched = partial(lr_scheduler.OneCycleLR, total_steps=tmax, max_lr=lr)
schedo = sched(optimizer)

# Checkpoint directory; creating it is a no-op if it already exists.
save_dir = Path("./weights")
save_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Training loop. Each step samples a noised batch via `utils.noisify`, regresses the
# model output onto the returned target with MSE, then advances the 1cycle schedule.
model.train()
train_losses = []  # per-batch losses across all epochs (kept for later inspection)
for epoch in range(epochs):
  batch_train_losses = []  # per-batch losses for the current epoch only
  pbar = tqdm(dls, mininterval=2)
  for xb, _ in pbar:  # class labels are unused: the model is unconditional
    optimizer.zero_grad()
    xb = xb.to(DEVICE)
    # NOTE(review): utils.noisify is not visible here; it returns
    # ((noised images, t), target) — presumably t are diffusion timesteps, confirm in utils.py.
    (noised_input, t), target = utils.noisify(xb)
    out = model((noised_input, t))
    loss = F.mse_loss(out, target)
    loss.backward()
    optimizer.step()  # same object as schedo.optimizer
    schedo.step()     # OneCycleLR advances once per batch (total_steps = epochs * len(dls))
    loss_value = loss.item()  # a single device->host sync per step
    batch_train_losses.append(loss_value)
    pbar.set_description(f"loss {loss_value:.2f}")

  train_losses.extend(batch_train_losses)
  # Report this epoch's mean loss; averaging `train_losses` would mix in every earlier epoch.
  print(f"Epoch {epoch}, loss: {np.mean(batch_train_losses)}")

  # Checkpoint on epochs 0, 4, 8, ... and always after the final epoch.
  # NOTE(review): the "emnist" prefix is misleading (the data is FashionMNIST) but is kept
  # unchanged so any code that loads these checkpoints by name keeps working.
  if epoch % 4 == 0 or epoch == epochs - 1:
    model_path = save_dir / f"emnist_model_{epoch}_bs_{bs}.pth"
    torch.save(model.state_dict(), model_path)
    print(f"saved model at {model_path.absolute().as_posix()}")