## <center>Exploring Generative Capabilities of Diffusion-Based Deep Generative Models <br><br> COMP3547 Deep Learning Assignment 2022/2023</center>

<hr>

### This work is based on several sources. The most important attributions are listed below:

* [2020, Denoising Diffusion Probabilistic Models (Ho, Jain, Abbeel)](https://arxiv.org/pdf/2006.11239.pdf)

* [2021, Improved Denoising Diffusion Probabilistic Models (Nichol, Dhariwal)](https://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf)

* [2021, Diffusion Models Beat GANs on Image Synthesis (Dhariwal, Nichol)](https://arxiv.org/pdf/2105.05233v4.pdf)

<br>

* [2019, Generative Modeling by Estimating Gradients of the Data Distribution (Song, Ermon)](https://proceedings.neurips.cc/paper/2019/file/3001ef257407d5a371a96dcd947c7d93-Paper.pdf)

* [2020, Score-based Generative Modeling Through Stochastic Differential Equations (Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole)](https://arxiv.org/pdf/2011.13456.pdf)

* [2020, Improved Techniques for Training Score-based Generative Models (Song, Ermon)](https://proceedings.neurips.cc/paper/2020/file/92c3b916311a5517d9290576e3ea37ad-Paper.pdf)

* [2020, Denoising Diffusion Implicit Models (Song, Meng, Ermon)](https://arxiv.org/pdf/2010.02502.pdf)


* [2022, High-Resolution Image Synthesis with Latent Diffusion Models (Rombach, Blattmann, Lorenz, Esser, Ommer)](https://arxiv.org/pdf/2112.10752.pdf)

* [2022, Diffusion Models: A Comprehensive Survey of Methods and Applications (Yang et al.)](https://arxiv.org/abs/2209.00796)


* [2022, How Much is Enough? A Study on Diffusion Times in Score-based Generative Models (Franzese et al.)](https://arxiv.org/abs/2206.05173)

* [https://github.com/dome272/Diffusion-Models-pytorch (Apache License 2.0)](https://github.com/dome272/Diffusion-Models-pytorch)

* [https://colab.research.google.com/drive/1sjy9odlSSy0RBVgMTgP7s99NXsqglsUL?usp=sharing (MIT License)](https://colab.research.google.com/drive/1sjy9odlSSy0RBVgMTgP7s99NXsqglsUL?usp=sharing)

* [https://github.com/lucidrains/denoising-diffusion-pytorch (MIT License)](https://github.com/lucidrains/denoising-diffusion-pytorch)

* [https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/annotated_diffusion.ipynb#scrollTo=3a159023 (No license information)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/annotated_diffusion.ipynb#scrollTo=3a159023)

* [https://github.com/labmlai (MIT License)](https://github.com/labmlai)

* [https://github.com/heejkoo/Awesome-Diffusion-Models (MIT License)](https://github.com/heejkoo/Awesome-Diffusion-Models)

* [https://github.com/yang-song/score_sde (Apache License 2.0)](https://github.com/yang-song/score_sde)

* [https://colab.research.google.com/drive/120kYYBOVa1i0TD85RjlEkFjaWDxSFUx3?usp=sharing#scrollTo=XCR6m0HjWGVV (Apache License 2.0)](https://colab.research.google.com/drive/120kYYBOVa1i0TD85RjlEkFjaWDxSFUx3?usp=sharing#scrollTo=XCR6m0HjWGVV)

<hr>

In [None]:
import model
import diffusion
import dataset
import config

import tqdm
import torch

torch.manual_seed(config.SEED)
torch.cuda.manual_seed_all(config.SEED)

In [None]:
data_loader = dataset.get_data(config.DATASET_NAME, config.BATCH_SIZE)

model = model.EpsilonTheta(
    channels=config.CHANNELS,
    feature_map_size=config.FEATURE_MAP_SIZE,
    groups_number=config.GROUPS_NUMBER,
    heads_number=config.HEADS_NUMBER,
    blocks_number=config.BLOCKS_NUMBER,
).to(config.DEVICE)

diffusion = diffusion.DenoisingDiffusion(
    epsilon_theta_model=model,
    beta_initial=config.BETA_INITIAL,
    beta_final=config.BETA_FINAL,
    T=config.T,
    device=config.DEVICE
)

optimiser = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
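
The `DenoisingDiffusion` object above is built from `beta_initial`, `beta_final` and `T`, but its internals live in `diffusion.py` and are not shown here. As a rough illustration of what such a noise schedule implies, the sketch below assumes a linear $\beta_t$ schedule as in Ho et al. (2020) and evaluates the closed form $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$ on a single scalar "pixel" in plain Python. The helper names (`linear_betas`, `cumulative_alphas`, `forward_sample`) and the values `1e-4`, `0.02` and `1000` are assumptions for illustration; they are not read from `config` and may differ from the notebook's actual implementation.

```python
import math
import random


def linear_betas(beta_initial, beta_final, T):
    """Linearly spaced beta_t for t = 0..T-1 (assumed schedule, requires T > 1)."""
    step = (beta_final - beta_initial) / (T - 1)
    return [beta_initial + step * t for t in range(T)]


def cumulative_alphas(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def forward_sample(x_0, t, epsilon, alpha_bars):
    """x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon for one scalar."""
    return math.sqrt(alpha_bars[t]) * x_0 + math.sqrt(1.0 - alpha_bars[t]) * epsilon


betas = linear_betas(1e-4, 0.02, 1000)  # illustrative values, not taken from config
alpha_bars = cumulative_alphas(betas)

random.seed(0)
epsilon = random.gauss(0.0, 1.0)
for t in (0, 499, 999):
    x_t = forward_sample(0.5, t, epsilon, alpha_bars)
    print(f"t={t:4d}  alpha_bar={alpha_bars[t]:.5f}  x_t={x_t:+.4f}")
```

As $t$ grows, $\bar\alpha_t$ shrinks towards zero, so the signal term fades and $x_t$ approaches pure noise. In the notebook the same coefficients would be applied element-wise to image tensors, with one $t$ per sample in the batch.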

In [None]:
if config.CHECKPOINT_FILE is not None:
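    # Map stored tensors onto the configured device so that a checkpoint saved on GPU also loads on CPU
    checkpoint = torch.load(config.CHECKPOINT_FILE, map_location=config.DEVICE)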
    epoch_start = checkpoint["epoch"]
    loss = checkpoint["loss"]
    losses = checkpoint["losses"]
    model.load_state_dict(checkpoint["model_state_dict"])
    optimiser.load_state_dict(checkpoint["optimiser_state_dict"])
else:
    epoch_start = 1
    losses = []
    
if epoch_start > config.EPOCHS:
    raise ValueError(
        f"The checkpoint already covers {epoch_start - 1} epochs; "
        f"set config.EPOCHS to at least {epoch_start} to continue training."
    )

    
for epoch in range(epoch_start, config.EPOCHS + 1):
    print(f"EPOCH {epoch}/{config.EPOCHS}")
    epoch_loss = 0.0
    
    for batch in tqdm.tqdm(data_loader, desc="Processing batches"):
        
        # Load a batch of data and assign it to device
        x_0 = batch[0].to(config.DEVICE)
        
        # Sample a random timestep t for each sample (the last batch may be smaller than BATCH_SIZE)
        t = torch.randint(0, config.T, (x_0.shape[0],), device=config.DEVICE, dtype=torch.long)
        
        # Sample noise from the standard normal distribution
        epsilon = torch.randn_like(x_0)

        # Sample x_t from q(x_t|x_0)
        x_t = diffusion.forward_diffusion(x_0, t, epsilon)

        # Predict the noise: \epsilon_\theta(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
        epsilon_theta = diffusion.epsilon_theta_model(x_t, t)
        
        # Compute the mean squared error between the predicted and the true noise
        loss = torch.nn.functional.mse_loss(epsilon_theta, epsilon)
        epoch_loss += loss.item()
        
        # Backpropagate and step the optimiser
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        
    # Collect the mean epoch loss for plotting
    epoch_loss /= len(data_loader)
    losses.append(epoch_loss)
            
    # Sample new data
    with torch.no_grad():

        # x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})
        x_T = torch.randn(
            [64, config.CHANNELS, config.IMAGE_SIZE, config.IMAGE_SIZE], 
            device=config.DEVICE
        )

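        # Remove noise over T steps, from t = T - 1 down to t = 0
        # (`tqdm` is imported as a module, so the progress bar is `tqdm.tqdm`;
        # the step count is assumed to be the same `config.T` used to build `diffusion`)
        for t_ in tqdm.tqdm(range(0, config.T), desc="Denoising timesteps"):
            t = config.T - t_ - 1

            # Sample x_{t-1} \sim p_\theta(x_{t-1}|x_t)
            x_T = diffusion.reverse_diffusion(x_T, x_T.new_full((64, ), t, dtype=torch.long))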

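        # Display the data (assumes `display_images` is a method of the DenoisingDiffusion object;
        # `model` is now the EpsilonTheta network, so `model.diffusion` would not resolve)
        diffusion.display_images(x_T)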
        
    # Print mean loss at the end of epoch
    print(f"Epoch mean loss: {epoch_loss}")
    print("-" * 100)

    # Save progress (checkpoint) every `config.CHECKPOINT_FREQUENCY` epochs
    if epoch % config.CHECKPOINT_FREQUENCY == 0:
        torch.save({
            "epoch": epoch + 1,  # epoch to resume from
            "loss": loss.item(),  # store a plain float rather than a tensor attached to the graph
            "losses": losses,
            "model_state_dict": model.state_dict(),
            "optimiser_state_dict": optimiser.state_dict()
        }, f"ddpm_{config.DATASET_NAME}_checkpoint_epoch_{epoch}.pt")
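
The sampling loop above calls `diffusion.reverse_diffusion` once per timestep, but that method is defined in `diffusion.py` and is not shown. As a hedged sketch of what one such step typically computes, the code below follows the ancestral sampling rule of Ho et al. (2020), $x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z$, on a single scalar in plain Python. The function name `reverse_step`, the variance choice $\sigma_t^2 = \beta_t$, the schedule values, and the "oracle" that returns the true noise in place of a trained network are all assumptions for illustration, not the notebook's actual implementation.

```python
import math
import random


def reverse_step(x_t, t, epsilon_pred, betas, alpha_bars, rng):
    """One ancestral sampling step x_t -> x_{t-1} for a single scalar value."""
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
    mean = (x_t - beta_t / math.sqrt(1.0 - alpha_bars[t]) * epsilon_pred) / math.sqrt(alpha_t)
    if t == 0:
        return mean  # no noise is added on the final step
    # Assumed variance choice: sigma_t^2 = beta_t
    return mean + math.sqrt(beta_t) * rng.gauss(0.0, 1.0)


# Illustrative linear schedule (values are assumptions, not read from config)
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * t / (T - 1) for t in range(T)]
alpha_bars, running = [], 1.0
for beta in betas:
    running *= 1.0 - beta
    alpha_bars.append(running)

rng = random.Random(0)
x_0, epsilon = 0.5, rng.gauss(0.0, 1.0)

# Noise x_0 at the first timestep (index 0), then denoise with an oracle
# that returns the true noise instead of a trained network's prediction
x_noisy = math.sqrt(alpha_bars[0]) * x_0 + math.sqrt(1.0 - alpha_bars[0]) * epsilon
recovered = reverse_step(x_noisy, 0, epsilon, betas, alpha_bars, rng)
print(f"x_0={x_0:.4f}  recovered={recovered:.4f}")
```

With a perfect noise prediction, the final step recovers `x_0` up to floating-point error. In the notebook the oracle would be replaced by `epsilon_theta_model(x_t, t)`, and the step would be repeated from `t = T - 1` down to `t = 0`.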