## <center>Exploring Generative Capabilities of Diffusion-Based Deep Generative Models <br><br> COMP3547 Deep Learning Assignment 2022/2023</center>

<hr>

### Work based on several sources. Below are the most important attributions:

* [2020, Denoising Diffusion Probabilistic Models (Ho, Jain, Abbeel)](https://arxiv.org/pdf/2006.11239.pdf)

* [2021, Improved Denoising Diffusion Probabilistic Models (Nichol, Dhariwal)](https://proceedings.mlr.press/v139/nichol21a/nichol21a.pdf)

* [2021, Diffusion Models Beat GANs on Image Synthesis (Dhariwal, Nichol)](https://arxiv.org/pdf/2105.05233v4.pdf)

<br>

* [2019, Generative Modeling by Estimating Gradients of the Data Distribution (Song, Ermon)](https://proceedings.neurips.cc/paper/2019/file/3001ef257407d5a371a96dcd947c7d93-Paper.pdf)

* [2020, Score-based Generative Modeling Through Stochastic Differential Equations (Song, Sohl-Dickstein, Kingma, Kumar, Ermon, Poole)](https://arxiv.org/pdf/2011.13456.pdf)

* [2020, Improved Techniques for Training Score-based Generative Models (Song, Ermon)](https://proceedings.neurips.cc/paper/2020/file/92c3b916311a5517d9290576e3ea37ad-Paper.pdf)

* [2020, Denoising Diffusion Implicit Models (Song, Meng, Ermon)](https://arxiv.org/pdf/2010.02502.pdf)


* [2022, High-Resolution Image Synthesis with Latent Diffusion Models (Rombach, Blattmann, Lorenz, Esser, Ommer)](https://arxiv.org/pdf/2112.10752.pdf)

* [2022, Diffusion Models: A Comprehensive Survey of Methods and Applications (Yang et al.)](https://arxiv.org/abs/2209.00796)


* [2022, How Much is Enough? A Study on Diffusion Times in Score-based Generative Models (Franzese et al.)](https://arxiv.org/abs/2206.05173)

* [https://github.com/dome272/Diffusion-Models-pytorch (Apache License 2.0)](https://github.com/dome272/Diffusion-Models-pytorch)

* [https://colab.research.google.com/drive/1sjy9odlSSy0RBVgMTgP7s99NXsqglsUL?usp=sharing (MIT License)](https://colab.research.google.com/drive/1sjy9odlSSy0RBVgMTgP7s99NXsqglsUL?usp=sharing)

* [https://github.com/lucidrains/denoising-diffusion-pytorch (MIT License)](https://github.com/lucidrains/denoising-diffusion-pytorch)

* [https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/annotated_diffusion.ipynb#scrollTo=3a159023 (No license information)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/annotated_diffusion.ipynb#scrollTo=3a159023)

* [https://github.com/labmlai (MIT License)](https://github.com/labmlai)

* [https://github.com/heejkoo/Awesome-Diffusion-Models (MIT License)](https://github.com/heejkoo/Awesome-Diffusion-Models)

* [https://github.com/yang-song/score_sde (Apache License 2.0)](https://github.com/yang-song/score_sde)

* [https://colab.research.google.com/drive/120kYYBOVa1i0TD85RjlEkFjaWDxSFUx3?usp=sharing#scrollTo=XCR6m0HjWGVV (Apache License 2.0)](https://colab.research.google.com/drive/120kYYBOVa1i0TD85RjlEkFjaWDxSFUx3?usp=sharing#scrollTo=XCR6m0HjWGVV)

<hr>
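The notebook relies on a local `model` module that is not included here. Its `sigma` argument and `marginal_probability_std` method suggest the variance-exploding (VE) SDE of Song et al. (2020), as implemented in the score-based tutorial colab listed above. Assuming that (it is not confirmed by the code shown), the SDE is dx = σ^t dw on t ∈ [0, 1], so the perturbation kernel has variance ∫₀ᵗ σ^{2s} ds = (σ^{2t} − 1) / (2 ln σ). A minimal sketch; the function names are hypothetical:

```python
import math
import torch

def marginal_prob_std(t: torch.Tensor, sigma: float) -> torch.Tensor:
    """Std of p_0t(x(t) | x(0)) for the VE SDE dx = sigma**t dw (assumed)."""
    # Closed form of sqrt(integral_0^t sigma^(2s) ds)
    return torch.sqrt((sigma ** (2 * t) - 1.0) / (2.0 * math.log(sigma)))

def diffusion_coeff(t: torch.Tensor, sigma: float) -> torch.Tensor:
    """Diffusion coefficient g(t) = sigma**t of the same SDE."""
    return sigma ** t

t = torch.tensor([1e-5, 0.5, 1.0])
std = marginal_prob_std(t, sigma=25.0)  # grows from ~0 at t -> 0 to ~9.85 at t = 1
g = diffusion_coeff(t, sigma=25.0)      # g(1) = sigma
```

The standard deviation vanishes as t → 0, which is why training draws t from [ε, 1] rather than [0, 1].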

```python
import model
import training
import sampling
import dataset
import config

import tqdm
import torch

# Seed PyTorch's random number generators (CPU and all CUDA devices) for reproducibility
torch.manual_seed(config.SEED)
torch.cuda.manual_seed_all(config.SEED)
```

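`config.DEVICE` comes from a `config` module that is not shown. The run at the end of this notebook fails with `RuntimeError: Placeholder storage has not been allocated on MPS device!`, which PyTorch typically raises when an operation on Apple's MPS backend receives a tensor whose storage is not on the MPS device, most often module parameters that were never moved off the CPU. A minimal sketch, assuming PyTorch 1.12 or later (for `torch.backends.mps`), of picking a device and keeping both the parameters and the inputs on it; `pick_device` and the toy `nn.Conv2d` are hypothetical stand-ins, not part of this project:

```python
import torch
import torch.nn as nn

def pick_device() -> torch.device:
    """Prefer CUDA, then Apple MPS, then fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()

# Hypothetical stand-in for ScoreMatchingModel: any nn.Module behaves the same way.
# .to(device) moves its parameters and buffers and returns the module itself.
net = nn.Conv2d(3, 3, kernel_size=3, padding=1).to(device)

# Inputs must be created on (or moved to) the same device as the parameters
x = torch.randn(4, 3, 32, 32, device=device)
y = net(x)
```

Passing a `device` argument to a constructor does not by itself guarantee that every registered parameter ends up on that device; calling `.to(device)` on the finished module is the usual safeguard.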
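```python
# Import the class directly: binding the instance to `model` below shadows the
# `model` module, and this import keeps the cell safe to re-run
from model import ScoreMatchingModel

data_loader = dataset.get_data(config.DATASET_NAME, config.BATCH_SIZE)

model = ScoreMatchingModel(
    batch_size=config.BATCH_SIZE,
    channels=config.CHANNELS,
    image_size=config.IMAGE_SIZE,
    dimensions=config.DIMENSIONS,
    embedding_size=config.EMBEDDING_SIZE,
    groups_number=config.GROUPS_NUMBER,
    epsilon=config.EPSILON,
    sigma=config.SIGMA,
    scale=config.SCALE,
    T=config.T,
    device=config.DEVICE,
).to(config.DEVICE)  # move the parameters to the same device as the training batches

optimiser = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
```

```
Files already downloaded and verified
```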

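`EMBEDDING_SIZE` and `SCALE` are passed to `ScoreMatchingModel` without explanation. In Song's score-based tutorial colab, a scale parameter of this kind sets the spread of random Gaussian Fourier features used to embed the scalar time t before it is fed to the network. Assuming this model does the same (unconfirmed), the embedding could look like the following; the class name and the default `scale=30.0` are taken as illustrative assumptions:

```python
import math
import torch
import torch.nn as nn

class GaussianFourierProjection(nn.Module):
    """Random Fourier features of the scalar time t; the frequencies are fixed, not trained."""

    def __init__(self, embedding_size: int, scale: float = 30.0):
        super().__init__()
        # Frequencies drawn once from N(0, scale^2); requires_grad=False keeps them frozen
        self.W = nn.Parameter(torch.randn(embedding_size // 2) * scale, requires_grad=False)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t: (B,) -> projections: (B, embedding_size // 2)
        x_proj = t[:, None] * self.W[None, :] * 2 * math.pi
        # Concatenating sin and cos gives (B, embedding_size)
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

embed = GaussianFourierProjection(embedding_size=256, scale=30.0)
t = torch.rand(8)
e = embed(t)
```

A larger `scale` spreads the frequencies further apart, so nearby values of t map to more distinct embeddings.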

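The training cell below minimises a denoising score-matching objective weighted by λ(t) = σ_t². For x(t) = x(0) + σ_t z, the score of the perturbation kernel is −z / σ_t, so σ_t² ‖s_θ + z/σ_t‖² = ‖σ_t s_θ + z‖², which is exactly the quantity in the `loss` line. A standalone sketch, assuming the VE-SDE standard deviation √((σ^{2t} − 1)/(2 ln σ)); `dsm_loss` and the all-zeros toy data are hypothetical:

```python
import math
import torch

def marginal_prob_std(t: torch.Tensor, sigma: float = 25.0) -> torch.Tensor:
    # Assumed VE-SDE perturbation std: sqrt((sigma^(2t) - 1) / (2 ln sigma))
    return torch.sqrt((sigma ** (2 * t) - 1.0) / (2.0 * math.log(sigma)))

def dsm_loss(score_fn, x: torch.Tensor, sigma: float = 25.0, eps: float = 1e-5) -> torch.Tensor:
    """Denoising score-matching loss with weighting lambda(t) = std(t)^2."""
    # One diffusion time per sample, uniform on [eps, 1]; the size comes from the batch
    t = torch.rand(x.shape[0], device=x.device) * (1.0 - eps) + eps
    std = marginal_prob_std(t, sigma)[:, None, None, None]
    z = torch.randn_like(x)
    x_noised = x + z * std
    score = score_fn(x_noised, t)
    # || std * s(x_noised, t) + z ||^2, summed over (C, H, W) and averaged over the batch
    return torch.mean(torch.sum((score * std + z) ** 2, dim=(1, 2, 3)))

def exact_score_for_zero_data(x_noised: torch.Tensor, t: torch.Tensor, sigma: float = 25.0) -> torch.Tensor:
    # If every training image is all zeros, x(t) ~ N(0, std^2 I), whose score is -x / std^2
    std = marginal_prob_std(t, sigma)[:, None, None, None]
    return -x_noised / std ** 2

x0 = torch.zeros(16, 3, 8, 8)  # hypothetical toy "dataset" of blank images
loss_exact = dsm_loss(exact_score_for_zero_data, x0).item()           # vanishes up to float error
loss_zero = dsm_loss(lambda xn, t: torch.zeros_like(xn), x0).item()  # expectation 3 * 8 * 8 = 192
```

The exact score drives the loss to zero, while a network that always predicts zero pays the full noise energy per image, so the loss directly measures how well s_θ matches −z/σ_t.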
```python
if config.CHECKPOINT_FILE is not None:
    # map_location lets a checkpoint saved on one device be resumed on another
    checkpoint = torch.load(config.CHECKPOINT_FILE, map_location=config.DEVICE)
    # The checkpoint records the last completed epoch, so resume from the next one
    epoch_start = checkpoint["epoch"] + 1
    losses = checkpoint["losses"]
    model.load_state_dict(checkpoint["model_state_dict"])
    optimiser.load_state_dict(checkpoint["optimiser_state_dict"])
else:
    epoch_start = 1
    losses = []

if epoch_start > config.EPOCHS:
    raise ValueError("Invalid number of epochs. Please set config.EPOCHS to a number greater than the number of epochs already trained.")

if config.SDE_SAMPLING_MODE not in config.SDE_SAMPLING_MODES:
    raise ValueError('Invalid SDE_SAMPLING_MODE. Please choose between "euler_maruyama_only" and "langevin_mcmc_and_euler_maruyama".')

# MAIN TRAINING AND SAMPLING LOOP
for epoch in range(epoch_start, config.EPOCHS + 1):
    print(f"EPOCH {epoch}/{config.EPOCHS}")
    epoch_loss = 0.0

    # TRAINING
    for batch in tqdm.tqdm(data_loader, desc="Processing batches"):
        # Load a batch of images (labels are ignored) and move it to the device
        x = batch[0].to(config.DEVICE)

        # Draw one diffusion time per sample, uniformly from [EPSILON, 1];
        # x.shape[0] also handles a smaller final batch
        t = torch.rand(x.shape[0], device=config.DEVICE) * (1.0 - config.EPSILON) + config.EPSILON

        # Standard deviation of the perturbation kernel at each sampled time
        standard_deviation = model.marginal_probability_std(t)

        # Sample Gaussian noise z ~ N(0, I)
        z = torch.randn_like(x)

        # Perturb the images: x(t) = x(0) + std(t) * z
        x_noised = x + z * standard_deviation[:, None, None, None]

        # Predict the score of the perturbed images
        score = model(x_noised, t)

        # Denoising score-matching loss: ||std(t) * score + z||^2, summed per image, averaged over the batch
        loss = torch.mean(torch.sum((score * standard_deviation[:, None, None, None] + z) ** 2, dim=(1, 2, 3)))
        epoch_loss += loss.item()

        # Backpropagate and update the parameters
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    # Record the mean loss of the epoch for plotting
    epoch_loss /= len(data_loader)
    losses.append(epoch_loss)

    # SAMPLING
    if config.SAMPLING_TYPE == "SDE":
        generated_images = model.sample_with_SDE(
            sde_sampling_mode=config.SDE_SAMPLING_MODE,
            signal_to_noise_ratio=config.SIGNAL_TO_NOISE_RATIO,
        )
    elif config.SAMPLING_TYPE == "ODE":
        # Passes the noise tensor from the last training batch as `z`
        generated_images = model.sample_with_ODE(
            ode_error_tolerance=config.ODE_ERROR_TOLERANCE,
            z=z,
        )
    else:
        raise ValueError('Invalid SAMPLING_TYPE. Please choose between "SDE" and "ODE".')

    # Display the generated images
    model.display_images(generated_images)

    # Print the mean loss of the epoch
    print(f"Epoch mean loss: {epoch_loss}")

    # Save a checkpoint every `config.CHECKPOINT_FREQUENCY` epochs
    if epoch % config.CHECKPOINT_FREQUENCY == 0:
        torch.save({
            "epoch": epoch,
            "loss": epoch_loss,
            "losses": losses,
            "model_state_dict": model.state_dict(),
            "optimiser_state_dict": optimiser.state_dict(),
        }, f"score_matching_LATEST_{config.DATASET_NAME}_checkpoint_epoch_{epoch}.pt")
```

```
EPOCH 1/10000
Processing batches:   0%|          | 0/781 [00:00<?, ?it/s]
RuntimeError: Placeholder storage has not been allocated on MPS device!
```
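`sample_with_SDE` and `sample_with_ODE` live in the unshown `model` module. As a hedged sketch of what the two SDE modes plausibly correspond to, assuming the VE SDE: `euler_maruyama_only` as plain Euler-Maruyama integration of the reverse-time SDE dx = −g(t)² ∇ₓ log p_t(x) dt + g(t) dw̄, and `langevin_mcmc_and_euler_maruyama` as the predictor-corrector sampler of Song et al. (2020), where a Langevin MCMC corrector step, with step size set from the signal-to-noise ratio, precedes each predictor step. The function name and the defaults (500 steps, `eps=1e-3`) are hypothetical:

```python
import math
import torch

def marginal_prob_std(t: torch.Tensor, sigma: float) -> torch.Tensor:
    # Assumed VE-SDE perturbation std
    return torch.sqrt((sigma ** (2 * t) - 1.0) / (2.0 * math.log(sigma)))

def diffusion_coeff(t: torch.Tensor, sigma: float) -> torch.Tensor:
    return sigma ** t

@torch.no_grad()
def reverse_sde_sampler(score_fn, shape, sigma=25.0, num_steps=500, eps=1e-3, snr=None, device="cpu"):
    """Euler-Maruyama on the reverse-time VE SDE; if `snr` is given, a Langevin
    corrector step runs before each predictor step (predictor-corrector)."""
    batch = shape[0]
    t = torch.ones(batch, device=device)
    # Start from the approximate prior N(0, std(1)^2 I)
    x = torch.randn(shape, device=device) * marginal_prob_std(t, sigma)[:, None, None, None]
    time_steps = torch.linspace(1.0, eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    mean_x = x
    for time_step in time_steps:
        batch_t = torch.full((batch,), float(time_step), device=device)
        if snr is not None:
            # Corrector: one Langevin MCMC step, step size chosen to match the target SNR
            grad = score_fn(x, batch_t)
            grad_norm = grad.reshape(batch, -1).norm(dim=-1).mean()
            noise_norm = math.sqrt(x[0].numel())
            langevin_step = 2 * (snr * noise_norm / grad_norm) ** 2
            x = x + langevin_step * grad + torch.sqrt(2 * langevin_step) * torch.randn_like(x)
        # Predictor: Euler-Maruyama step backwards in time
        g = diffusion_coeff(batch_t, sigma)[:, None, None, None]
        mean_x = x + g ** 2 * score_fn(x, batch_t) * step_size
        x = mean_x + torch.sqrt(step_size) * g * torch.randn_like(x)
    # Return the noise-free mean of the final step
    return mean_x

def zero_data_score(x: torch.Tensor, t: torch.Tensor, sigma: float = 25.0) -> torch.Tensor:
    # Exact score when every training image is all zeros (hypothetical toy target)
    std = marginal_prob_std(t, sigma)[:, None, None, None]
    return -x / std ** 2

samples = reverse_sde_sampler(zero_data_score, (4, 1, 8, 8))
pc_samples = reverse_sde_sampler(zero_data_score, (4, 1, 8, 8), snr=0.16)
```

With the exact score for all-zero data, both samplers shrink the initial noise (standard deviation about 9.8 per pixel at t = 1) to values close to zero. In the notebook, `score_fn` would be the trained network.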