In [None]:
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import wandb

In [None]:
# Run-wide settings. Several of these are logged to Weights & Biases as the
# run's config, so they must describe what the training cell actually does.
DEVICE = 'cuda:0'        # device the Unet and the diffusion wrapper are moved to
# Logged to wandb only. The Trainer is driven by a step count instead:
# 46875 steps x 2 accumulated batches x 32 images = 3,000,000 images, i.e. 50
# passes over 60,000 images (presumably the MNIST training set -- confirm).
EPOCHS = 50
LABEL = 1                # not referenced by any of the visible cells
TIMESTEPS = 1000         # passed to GaussianDiffusion(timesteps=...)
INITIAL_DIM = 64         # passed to Unet(dim=...)
IMAGE_SIZE = (64, 64)    # only IMAGE_SIZE[0] is passed to GaussianDiffusion
# Was 50, but the Trainer is constructed with train_batch_size = 32, so the
# "Batch Size" logged to wandb misreported the run. 32 is also the value the
# step counts are built around (1875 steps x 32 images = 60,000).
BATCH_SIZE = 32

In [None]:
# Log in to Weights & Biases before opening a run.
wandb.login()

# Hyperparameters recorded alongside the run; every value comes from the
# configuration cell.
wandb_config = {
    "Epochs": EPOCHS,
    "Timesteps": TIMESTEPS,
    "Initial Conv Dim": INITIAL_DIM,
    "Image Size": IMAGE_SIZE,
    "Batch Size": BATCH_SIZE,
}

# NOTE(review): "conifdent" looks like a typo for "confident". It is left
# untouched because renaming would send new runs to a different project.
run = wandb.init(project="conifdent-diffusion", config=wandb_config)

In [None]:
# Denoising network. channels=1 means single-channel input (presumably
# grayscale MNIST digits -- confirm against the dataset folder).
model = Unet(dim=INITIAL_DIM, dim_mults=(1, 2, 4, 8), channels=1)
model = model.to(DEVICE)

In [None]:
# Wrap the Unet in the Gaussian diffusion process. Only the first entry of
# IMAGE_SIZE is passed as image_size.
diffusion = GaussianDiffusion(
    model,
    image_size=IMAGE_SIZE[0],
    timesteps=TIMESTEPS,
    loss_type='l1',
)
diffusion = diffusion.to(DEVICE)

In [None]:
# Build the training harness over the images under 'mnist_jpg/' and run it.
# NOTE(review): train_batch_size is a literal here rather than BATCH_SIZE from
# the configuration cell; keep the two in agreement.
trainer = Trainer(
    diffusion,
    'mnist_jpg/',
    # --- optimisation ---
    train_lr=8e-5,
    train_batch_size=32,
    gradient_accumulate_every=2,   # gradient accumulation steps
    ema_decay=0.995,               # exponential moving average decay
    amp=True,                      # mixed precision on
    # --- schedule ---
    train_num_steps=46875,         # total training steps
    save_and_sample_every=1875,    # 46875 / 1875 = 25 intervals over the run
    # --- evaluation ---
    calculate_fid=True,            # compute FID during training
)

trainer.train()