In [None]:
import os
from datetime import datetime
from multiprocessing import cpu_count

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import wandb
from denoising_diffusion_pytorch import Unet, GaussianDiffusion
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm

## Config

In [None]:
# Compute device: use the first CUDA GPU when one is available, otherwise fall
# back to the CPU so every `.to(DEVICE)` call still works on CUDA-less machines.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Dataset selector; the data cells compare it against "mnist" and "YESNO".
DATA = "mnist"
# Checkpoint location prefix. It is concatenated directly with the file name
# ("model.pt"), so it has to end with a path separator.
MODEL_PATH = "models/"
# Number of passes over the data loader in the training cell.
EPOCHS = 420
# Optional class label; when not None the filter cell keeps only that class.
LABEL = None
# Forwarded to GaussianDiffusion as `timesteps`.
TIMESTEPS = 1000
# Forwarded to Unet as `dim`.
INITIAL_DIM = 32
# Resize target for the MNIST images; only the first entry is forwarded to
# GaussianDiffusion as `image_size`.
IMAGE_SIZE = (32, 32)
# Mini-batch size used by the DataLoader.
BATCH_SIZE = 200
# Learning rate handed to the AdamW optimizer.
INITIAL_LR = 1e-5

## WandB Config

In [None]:
# Log in to Weights & Biases before starting the run.
wandb.login()

# Start a run and record the hyper-parameters from the config cell with it.
# NOTE(review): the project name looks like a misspelling of
# "confident-diffusion"; it is kept as-is because renaming it would send new
# runs to a different project.
run = wandb.init(
    config={
        "Epochs": EPOCHS,
        "Timesteps": TIMESTEPS,
        "Initial Conv Dim": INITIAL_DIM,
        "Image Size": IMAGE_SIZE,
        "Batch Size": BATCH_SIZE,
    },
    project="conifdent-diffusion",
)

## Data

#### MNIST Dataset:

In [None]:
class SpectrogramYESNO(torch.utils.data.Dataset):
    """YESNO audio dataset that yields spectrograms instead of raw waveforms.

    Wraps ``torchaudio.datasets.YESNO`` (stored under ``train_yesno/`` and
    downloaded on construction if needed) and runs every waveform through
    ``torchaudio.transforms.Spectrogram`` with its default settings.
    """

    def __init__(self):
        """Create the underlying YESNO dataset and the spectrogram transform."""
        self.dataset = torchaudio.datasets.YESNO(
            root="train_yesno/",
            download=True
            )

        self.transform = torch.nn.Sequential(
            torchaudio.transforms.Spectrogram()
        )

    def __getitem__(self, idx):
        """Return ``(spectrogram, sample_rate, label)`` for the item at ``idx``.

        ``sample_rate`` and ``label`` are passed through unchanged from the
        wrapped dataset; only the waveform is transformed.
        """
        waveform, sample_rate, label = self.dataset[idx]
        # Deliberately no per-item printing here: this method runs once per
        # sample, so a debug print would flood the notebook output.
        spectrogram = self.transform(waveform)

        return (
            spectrogram,
            sample_rate,
            label
        )

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

In [None]:
if DATA == "mnist":
    # Preprocessing pipeline applied to every MNIST sample: convert to a
    # tensor, then resize to IMAGE_SIZE with antialiasing enabled.
    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(size=IMAGE_SIZE, antialias=True),
        ]
    )

    # Training split, stored under train_mnist/ and downloaded if missing.
    dataset = datasets.MNIST(
        transform=preprocess,
        download=True,
        train=True,
        root="train_mnist/",
    )

#### YESNO Dataset:

In [None]:
if DATA == "YESNO":
    # Spectrogram view of the torchaudio YESNO dataset.
    dataset = SpectrogramYESNO()
elif DATA != "mnist":
    # Neither dataset cell matched. Fail right here with a clear message
    # instead of leaving `dataset` undefined (or stale from an earlier kernel
    # run) for the cells further down.
    raise ValueError(f'Unknown DATA value {DATA!r}; expected "mnist" or "YESNO".')

#### Filter out classes:

In [None]:
# Optional single-class training: keep only the samples whose target is LABEL.
# Assumes `dataset` exposes indexable `targets` and `data` attributes, as the
# MNIST cell's dataset is used here — TODO confirm before using other datasets.
if LABEL is not None:
    idx = dataset.targets == LABEL
    # Both attributes are indexed with the same selection so they stay aligned.
    dataset.data, dataset.targets = dataset.data[idx], dataset.targets[idx]

#### Set up:

In [None]:
# Batched, shuffled loader over the selected dataset; used by the preview and
# training cells.
dl = torch.utils.data.DataLoader(
    dataset,
    pin_memory=True,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [None]:
# Sanity check: take the first sample of the first element of one batch and
# show it.
example_image = next(iter(dl))[0][0].numpy()
print(example_image.shape)

# Move the leading axis to the end before handing the array to imshow.
plt.imshow(np.transpose(example_image, (1, 2, 0)))
plt.show()

## Model

In [None]:
def get_model():
    """Build the diffusion model used by this notebook and move it to DEVICE.

    Returns:
        A ``GaussianDiffusion`` instance wrapping a single-channel ``Unet``,
        configured from the notebook constants INITIAL_DIM, IMAGE_SIZE and
        TIMESTEPS and using the 'l1' loss type.
    """
    denoiser = Unet(dim=INITIAL_DIM, dim_mults=(1, 2, 4, 8), channels=1).to(DEVICE)

    # Only the first entry of IMAGE_SIZE is forwarded as the image size.
    return GaussianDiffusion(
        denoiser,
        image_size=IMAGE_SIZE[0],
        timesteps=TIMESTEPS,
        loss_type='l1',
    ).to(DEVICE)

## Training

In [None]:
# Train the diffusion model and keep the checkpoint with the lowest mean
# epoch loss.
best_loss = float("inf")
diffusion = get_model()
models_saved = 0
diffusion.train()
optimizer = torch.optim.AdamW(diffusion.parameters(), INITIAL_LR)

# torch.save does not create missing directories, so make sure the directory
# part of the checkpoint path exists before the first save.
checkpoint_path = MODEL_PATH + "model.pt"
checkpoint_dir = os.path.dirname(checkpoint_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(EPOCHS):
    pbar = tqdm(dl, leave=True, desc=f"Epoch {epoch + 1}/{EPOCHS}", colour="#55D3FF")

    # Accumulate per-batch losses so the checkpoint decision is based on the
    # whole epoch rather than on whichever mini-batch happened to come last.
    running_loss = 0.0
    num_batches = 0

    for batch in pbar:
        # Each batch is a tuple; only its first element is fed to the model.
        batch = batch[0].to(DEVICE)
        loss = diffusion(batch)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # Read the scalar once and reuse it for the progress bar and wandb.
        batch_loss = loss.detach().item()
        running_loss += batch_loss
        num_batches += 1

        logs = {"loss": batch_loss}
        pbar.set_postfix(**logs)
        wandb.log(logs)

    if num_batches == 0:
        # Empty data loader: there is no loss to compare, so skip checkpointing.
        continue

    epoch_loss = running_loss / num_batches
    if epoch_loss < best_loss:
        torch.save({"model_state_dict": diffusion.state_dict()}, checkpoint_path)
        best_loss = epoch_loss
        print("INFO: New model saved.")
        models_saved += 1
        wandb.log({"models_saved": models_saved})

## Inference

In [None]:
# Rebuild the model (get_model already places it on DEVICE) and restore the
# weights written by the training cell.
diffusion = get_model()
# Use the same MODEL_PATH-based location the training cell saves to, and map
# the stored tensors onto the current DEVICE so the checkpoint also loads on a
# machine whose device differs from the one it was saved on.
checkpoint = torch.load(MODEL_PATH + "model.pt", map_location=DEVICE)
diffusion.load_state_dict(checkpoint["model_state_dict"])
# Switch to evaluation mode for sampling.
diffusion.eval()

sampled_images = diffusion.sample(batch_size=1)
for image in sampled_images:
    image = image.cpu().detach().numpy()
    # Move the leading axis to the end before handing the array to imshow.
    image = image.transpose((1,2,0))

    plt.imshow(image)
    plt.show()