# Instrument the training of a diffusion model with Weights & Biases
(DeepLearning.AI (DLAI) offers a full course on diffusion models.)

In [11]:
from types import SimpleNamespace
from pathlib import Path
from tqdm.notebook import tqdm
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from utilities import *

import wandb

In [2]:
# Authenticate this session with Weights & Biases before any run is created.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mthatgardnerone[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Setup

In [3]:
# Filesystem layout: input data lives in DATA_DIR, model checkpoints are
# written to SAVE_DIR (created on first run).
DATA_DIR = Path("data")
SAVE_DIR = Path("data/weights")
SAVE_DIR.mkdir(exist_ok=True, parents=True)

# Pick the best available accelerator instead of hard-coding one, so the
# notebook runs unchanged on Nvidia GPUs ("cuda"), Apple Silicon ("mps") and
# CPU-only machines. DEVICE stays a plain string, as before.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# All tunable values live in one namespace so they can be handed to W&B as
# the run config.
config = SimpleNamespace(
    # Hyperparameters
    num_samples=30,  # number of preview images sampled at each checkpoint

    # Diffusion hyperparameters
    timesteps=500,
    beta1=1e-4,
    beta2=0.02,

    # Network hyperparameters
    n_feat=64,  # hidden dimension features
    n_cfeat=5,  # context vector size
    height=16,  # image size

    # Training hyperparameters
    batch_size=100,
    n_epoch=32,
    lrate=1e-3,
)

### Set up noise scheduler and sampler

In [5]:
# Build the two DDPM helpers from the noise-schedule settings in `config`:
#   - perturb_input: adds noise to an input image
#   - sample_ddpm_context: generates images using the DDPM sampler
perturb_input, sample_ddpm_context = setup_ddpm(config.beta1, config.beta2, config.timesteps, DEVICE)

In [6]:
# Instantiate the context-conditioned U-Net for 3-channel images, sized from
# the network hyperparameters in `config`, then move it to the target device.
nn_model = ContextUnet(in_channels=3, n_feat=config.n_feat, n_cfeat=config.n_cfeat, height=config.height)
nn_model = nn_model.to(DEVICE)

In [7]:
# Adam optimiser over every model parameter, starting at the configured rate.
optim = torch.optim.Adam(nn_model.parameters(), lr=config.lrate)

# Training data: built from DATA_DIR and served in shuffled mini-batches.
dataset = CustomDataset.from_np(path=DATA_DIR)
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

## Training

In [8]:
# Initial noise x_T ~ N(0, 1) for the preview samples: one 3-channel square
# image (height x height) per sample, drawn once and then moved to DEVICE.
noises = torch.randn(
    (config.num_samples, 3, config.height, config.height)
).to(DEVICE)

In [9]:
# One-hot context vectors for the preview samples, one row per noise vector.
# Class indices keep the original layout:
#   0 hero, 1 non-hero, 2 food, 3 spell, 4 side-facing
# Samples are split into contiguous, (near-)equal groups per class. With the
# defaults (num_samples=30, n_cfeat=5) this is six samples per class, exactly
# the tensor that used to be written out by hand; deriving it from `config`
# keeps it the same length as `noises` when num_samples changes.
ctx_vector = F.one_hot(
    torch.div(
        torch.arange(config.num_samples) * config.n_cfeat,
        config.num_samples,
        rounding_mode="floor",
    ),
    config.n_cfeat,
).to(DEVICE).float()

In [12]:
# Create a W&B run. The run is used as a context manager so that it is always
# closed, even when training raises or the cell is interrupted; no explicit
# wandb.finish() call is needed at the end.
# NOTE(review): `config` is rebound to the W&B config below, so re-running this
# cell passes that object back into wandb.init — confirm that is intended, or
# re-run the setup cell first.
with wandb.init(
    project="dlai_sprite_diffusion",
    job_type="train",
    config=config,
) as run:
    # Pass the config back from W&B
    config = run.config

    # Warning: training is very slow on CPU; DEVICE is chosen in the setup cell.
    for epoch in tqdm(range(config.n_epoch), leave=True, total=config.n_epoch):
        # Set into training mode
        nn_model.train()

        # Linear learning-rate decay from config.lrate towards 0 over the run,
        # applied to every parameter group of the optimiser.
        lr = config.lrate * (1 - epoch / config.n_epoch)
        for param_group in optim.param_groups:
            param_group["lr"] = lr

        pbar = tqdm(dataloader, leave=False)
        for x, c in pbar:
            optim.zero_grad()  # zero the gradients
            x = x.to(DEVICE)
            c = c.to(DEVICE)

            # Per-sample 0/1 mask (1 with probability 0.8); samples that draw 0
            # have their whole context vector zeroed for this step.
            context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.8).to(DEVICE)
            c = c * context_mask.unsqueeze(-1)

            # Noise the batch at a random timestep in [1, timesteps] and train
            # the model to predict the noise that was added.
            noise = torch.randn_like(x)
            t = torch.randint(1, config.timesteps + 1, (x.shape[0],)).to(DEVICE)
            x_pert = perturb_input(x, t, noise)
            pred_noise = nn_model(x_pert, t / config.timesteps, c=c)
            loss = F.mse_loss(pred_noise, noise)
            loss.backward()
            optim.step()

            run.log({
                "loss": loss.item(),
                "lr": lr,
                "epoch": epoch,
            })

        # Save the model periodically (every 4th epoch and the final epoch)
        if epoch % 4 == 0 or epoch == config.n_epoch - 1:
            nn_model.eval()
            ckpt_file = SAVE_DIR / f"model_{epoch}.pth"
            torch.save(nn_model.state_dict(), ckpt_file)

            # Version the checkpoint as a W&B model artifact tied to this run.
            artifact_name = f"{run.id}_model_{epoch}"
            artifact = wandb.Artifact(artifact_name, type="model")
            artifact.add_file(str(ckpt_file))
            run.log_artifact(artifact)

            # Generate preview images from the fixed noise/context pairs.
            # Sampling needs no gradients, so autograd tracking is disabled.
            with torch.no_grad():
                samples, _ = sample_ddpm_context(
                    nn_model,
                    noises,
                    ctx_vector[:config.num_samples],
                )

            run.log({
                "train_samples": [
                    wandb.Image(img) for img in samples.split(1)
                ]
            })

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/894 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

VBox(children=(Label(value='51.563 MB of 51.568 MB uploaded\r'), FloatProgress(value=0.9999084191587786, max=1…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
loss,█▆▄▃▃▃▃▃▃▂▂▂▂▃▂▂▂▂▃▂▃▂▁▃▁▂▂▂▂▂▂▂▂▁▁▁▁▂▂▂
lr,████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁

0,1
epoch,31.0
loss,0.08597
lr,3e-05
