# Diffusion on MNIST: predicting the noise

Steps:

1. Train a UNet to predict the noise given a noisified image
2. Train a UNet to predict the noise given a noisified image AND the noise level used

Sampling is so finicky!

## References

* fastai 2022 / 2023 course part II:
    * [notebook 26](https://github.com/fastai/course22p2/blob/master/nbs/26_diffusion_unet.ipynb)
    * [lesson 19](https://course.fast.ai/Lessons/lesson19.html)

## Setup

In [None]:
# TODO: include all digits
# TODO: include the digit itself (beyond the image) as information to pass to the model (especially for sampling)

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import typing as T
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torchinfo
import tqdm
from sklearn.datasets import fetch_openml
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

import random_neural_net_models.convolution_lecun1990 as conv_lecun1990
import random_neural_net_models.data as rnnm_data
import random_neural_net_models.learner as rnnm_learner
import random_neural_net_models.losses as rnnm_losses
import random_neural_net_models.telemetry as telemetry
import random_neural_net_models.unet as unet
import random_neural_net_models.unet_with_noise as unet_with_noise
import random_neural_net_models.utils as utils

logger = utils.get_logger("nb")
sns.set_theme()

In [None]:
DO_OVERFITTING_ONLY = True

In [None]:
mnist = fetch_openml("mnist_784", version=1, cache=True, parser="auto")

Setting seeds

In [None]:
utils.make_deterministic(42)

Getting device

In [None]:
device = utils.get_device()
device

In [None]:
X = mnist["data"]
y = mnist["target"]
X.shape, y.shape

Selecting a few images to overfit on (restricted to the digit 5)

In [None]:
# Split the images labelled "5" into two consecutive, non-overlapping subsets.
n0 = 32  # size of the first subset (X0, y0)
n1 = 1_000  # size of the second subset (X1, y1), taken right after the first n0 rows
is_5 = y == "5"  # boolean mask; the target is compared against the string "5"
# First n0 fives.
X0, y0 = X.loc[is_5].iloc[:n0], y.loc[is_5].iloc[:n0]
# The following n1 fives (rows n0 .. n0 + n1 - 1 of the filtered data).
X1, y1 = X.loc[is_5].iloc[n0 : n1 + n0], y.loc[is_5].iloc[n0 : n0 + n1]
X0.shape, X1.shape

## Defining dataset and dataloader

In [None]:
ds_train = rnnm_data.MNISTDatasetWithLabels(
    X0, y0, one_hot=False, add_channel=False
)
ds_valid = rnnm_data.MNISTDatasetWithLabels(
    X1, y1, one_hot=False, add_channel=False
)

In [None]:
ds_train[0][0].dtype

In [None]:
img, label = ds_train[0]
plt.imshow(img, cmap="gray", origin="upper")
plt.title(f"Label: {label}")
plt.axis("off")
plt.tight_layout()

In [None]:
img, label = ds_valid[0]
plt.imshow(img, cmap="gray", origin="upper")
plt.title(f"Label: {label}")
plt.axis("off")
plt.tight_layout()

applying noise based on 
```python
def noisify(x0):
    device = x0.device
    sig = (torch.randn([len(x0)])*1.2-1.2).exp().to(x0).reshape(-1,1,1,1)
    noise = torch.randn_like(x0, device=device)
    c_skip,c_out,c_in = scalings(sig)
    noised_input = x0 + noise*sig
    target = (x0-c_skip*noised_input)/c_out
    return (noised_input*c_in,sig.squeeze()),target
```
from https://github.com/fastai/course22p2/blob/master/nbs/26_diffusion_unet.ipynb

defining a dataloader

In [None]:
# The training batch size equals n0, the number of rows selected for the
# training subset, so one batch covers the whole training set.
batch_size = n0
dl_train = DataLoader(
    ds_train,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=unet_with_noise.mnist_noisy_collate_train,
)
# NOTE(review): the validation loader reuses the *train* collate function —
# confirm this is intended.
dl_valid = DataLoader(
    ds_valid,
    batch_size=500,
    shuffle=False,
    collate_fn=unet_with_noise.mnist_noisy_collate_train,
)

inspecting noise levels

In [None]:
# Make ten passes over the training loader and keep the noise levels of every
# batch seen (presumably the collate function draws new levels on each pass —
# TODO confirm), then plot their distribution.
noise_levels = [
    batch.noise_level.detach() for _ in range(10) for batch in dl_train
]

fig, ax = plt.subplots(figsize=(7, 3))
all_noise_levels = torch.concat(noise_levels).numpy()
sns.histplot(all_noise_levels, ax=ax)
ax.set(xlabel="Noise level", ylabel="Count")
plt.tight_layout()

inspecting the noisified images

In [None]:
b: unet_with_noise.MNISTNoisyDataTrain = next(iter(dl_train))

In [None]:
ix_img = 0
noisy_input_image = b.noisy_image[ix_img].cpu()
target_noise = b.target_noise[ix_img].cpu()
noise_level = b.noise_level[ix_img].cpu()
noisy_input_image.shape, target_noise.shape, noise_level.shape

In [None]:
# Scaling coefficients for the selected noise level. They are not used by the
# plotting code in this cell.
c_skip, c_out, c_in = unet_with_noise.get_cs(noise_level)
# Presumably reconstructs the clean image from the noisy input and the target
# noise at this noise level — see unet_with_noise.get_denoised_images.
denoised_image = unet_with_noise.get_denoised_images(
    noisy_input_image, target_noise, noise_level
)

# Print the level of the image being shown. `noise_levels` is the list of
# per-batch tensors collected for the histogram, so `noise_levels[ix_img]`
# would print an entire unrelated batch instead of this image's noise level.
print(f"noise level: {noise_level}")

fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(10, 7))
panels = (
    (noisy_input_image, "Noisy input image"),
    (target_noise, "Target noise"),
    (denoised_image, "Denoised image"),
)
# One panel per image, all rendered the same way.
for ax, (panel_image, panel_title) in zip(axs, panels):
    ax.imshow(panel_image, cmap="gray")
    ax.set_title(panel_title)
    ax.axis("off")
plt.show()

In [None]:
unet_with_noise.compare_input_noise_and_denoised_image(
    noisy_input_image, target_noise, denoised_image
)

## overfitting digit 5

defining the model

In [None]:
model = unet_with_noise.UNetModelTensordict(
    in_channels=1,
    out_channels=1,
    list_num_features=(
        8,
        16,
    ),
    num_layers=2,
)

In [None]:
# Training configuration for the UNet defined in the previous cell.
n_epochs = 10_000
lr = 1e-1
optimizer = optim.SGD(model.parameters(), lr=lr)

loss = unet_with_noise.MSELossMNISTNoisy()
save_dir = Path("./models")

# Telemetry callbacks: the loss callback takes no arguments, the others are
# configured with every_n=100 and select modules by regex on their names.
loss_callback = rnnm_learner.TrainLossCallback()
activations_callback = rnnm_learner.TrainActivationsCallback(
    every_n=100,
    max_depth_search=4,
    name_patterns=(".*act.*",),
)
# Gradients and parameters are tracked with the same set of name patterns.
gradients_callback = rnnm_learner.TrainGradientsCallback(
    every_n=100,
    max_depth_search=4,
    name_patterns=(r".*conv\d", r".*convs\.[25]$", r".*idconv$"),
)
parameters_callback = rnnm_learner.TrainParametersCallback(
    every_n=100,
    max_depth_search=4,
    name_patterns=(r".*conv\d", r".*convs\.[25]$", r".*idconv$"),
)


callbacks = [
    loss_callback,
    activations_callback,
    gradients_callback,
    parameters_callback,
]

# Bundle model, optimizer, loss and callbacks into the learner used below for
# the learning-rate search, fitting and prediction.
learner = rnnm_learner.Learner(
    model,
    optimizer,
    loss,
    callbacks=callbacks,
    save_dir=save_dir,
    device=device,
)

In [None]:
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 10, 100)

learner.find_learning_rate(
    dl_train, n_epochs=200, lr_find_callback=lr_find_callback
)

In [None]:
lr_find_callback.plot(yscale="log")

In [None]:
# Peak learning rate for the one-cycle schedule.
lr = 0.1
# The schedule is sized from n_epochs and the number of batches per epoch.
# NOTE(review): this wraps the same optimizer object that the learning-rate
# finder cell used — confirm the finder leaves optimizer and model in the
# intended state before training starts.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=lr,
    epochs=n_epochs,
    steps_per_epoch=len(dl_train),
)
# Register the scheduler with the learner through a callback (presumably
# stepped once per batch, going by the callback's name — TODO confirm).
scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)
learner.update_callback(scheduler_callback)

In [None]:
learner.fit(
    dl_train,
    n_epochs=n_epochs,
    # dataloader_valid=dl_valid
)

In [None]:
loss_callback.plot()

inspecting predictions

In [None]:
preds = learner.predict(dl_train)
preds[0, :5, :5]

In [None]:
preds = preds.detach().cpu()  # .numpy()
preds[0, :3, :5]

In [None]:
# Collect noisy inputs, target noises and noise levels from the training
# loader so they can be compared with `preds`.
# NOTE(review): this is a fresh pass over the loader. If the collate function
# draws new random noise on every pass (the noise-level histogram cell iterates
# the loader repeatedly to sample levels, which suggests it does), these
# batches are not the ones `preds` was computed on — confirm before reading
# the comparison plots as input/prediction pairs.
input_images, target_noises, noise_levels = [], [], []

for b in dl_train:
    input_images.append(b.noisy_image)
    target_noises.append(b.target_noise)
    noise_levels.append(b.noise_level)

# Concatenate the per-batch tensors along the first dimension.
input_images = torch.cat(input_images)
target_noises = torch.cat(target_noises)
noise_levels = torch.cat(noise_levels)
input_images.shape, target_noises.shape, noise_levels.shape

In [None]:
ix_img = 1
noisy_input_image = input_images[ix_img].cpu()
pred_noise = preds[ix_img]
target_noise = target_noises[ix_img].cpu()
noise_level = noise_levels[ix_img].cpu()

denoised_image = unet_with_noise.get_denoised_images(
    noisy_input_image, target_noise, noise_level
)
pred_denoised_image = unet_with_noise.get_denoised_images(
    noisy_input_image, pred_noise, noise_level
)

print(f"noise level: {noise_levels[ix_img]}")

unet_with_noise.compare_input_noise_and_denoised_image(
    noisy_input_image,
    target_noise,
    denoised_image,
    title=f"Target noise (noise level: {noise_levels[ix_img]:.4f})",
)
unet_with_noise.compare_input_noise_and_denoised_image(
    noisy_input_image,
    pred_noise,
    pred_denoised_image,
    title=f"Predicted noise (noise level: {noise_levels[ix_img]:.4f})",
)

## overfitting digit 5 - including the noise level as input

visualizing the noise embedding

In [None]:
# Embed 100 evenly spaced values in [-10, 10] and show the result as an image.
noise = torch.linspace(-10, 10, 100)
# Second argument is 8 * 4 = 32, presumably the embedding size — TODO confirm
# against get_noise_level_embedding's signature.
emb = unet_with_noise.get_noise_level_embedding(noise, 8 * 4, max_period=1000)
print(emb.T.shape)
# The embedding is transposed before plotting; the axis labels below put the
# noise level on x and the embedding on y.
plt.imshow(emb.T)
plt.xlabel("Noise level")
plt.ylabel("Embedding")
plt.grid(False)
plt.tight_layout()

defining the model

In [None]:
model = unet_with_noise.NoisyUNetModelTensordict(
    in_channels=1,
    out_channels=1,
    list_num_features=(
        8,
        16,
    ),
    num_layers=2,
)

In [None]:
n_epochs = 2_000
lr = 1e-1
optimizer = optim.SGD(model.parameters(), lr=lr)  # , momentum=1e-3

loss = unet_with_noise.MSELossMNISTNoisy()
save_dir = Path("./models")

loss_callback = rnnm_learner.TrainLossCallback()
activations_callback = rnnm_learner.TrainActivationsCallback(
    every_n=100,
    max_depth_search=4,
    name_patterns=(".*act.*",),
)
gradients_callback = rnnm_learner.TrainGradientsCallback(
    every_n=100,
    max_depth_search=4,
    name_patterns=(r".*conv\d", r".*convs\.[25]$", r".*idconv$"),
)
parameters_callback = rnnm_learner.TrainParametersCallback(
    every_n=100,
    max_depth_search=4,
    name_patterns=(r".*conv\d", r".*convs\.[25]$", r".*idconv$"),
)

callbacks = [
    loss_callback,
    activations_callback,
    gradients_callback,
    parameters_callback,
]

learner = rnnm_learner.Learner(
    model,
    optimizer,
    loss,
    callbacks=callbacks,
    save_dir=save_dir,
    device=device,
)

In [None]:
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 10, 100)

learner.find_learning_rate(
    dl_train, n_epochs=200, lr_find_callback=lr_find_callback
)

In [None]:
lr_find_callback.plot(yscale="log")

In [None]:
lr = 0.1
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=lr,
    epochs=n_epochs,
    steps_per_epoch=len(dl_train),
)
scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)
learner.update_callback(scheduler_callback)

In [None]:
learner.fit(dl_train, n_epochs=n_epochs)

In [None]:
loss_callback.plot()

In [None]:
preds = learner.predict(dl_train)
preds[0, :5, :5]

In [None]:
preds = preds.detach().cpu()  # .numpy()
preds[0, :3, :5]

In [None]:
# Collect noisy inputs, target noises and noise levels from the training
# loader so they can be compared with `preds` of the noise-conditioned model.
# NOTE(review): this is a fresh pass over the loader. If the collate function
# draws new random noise on every pass, these batches are not the ones `preds`
# was computed on — confirm before reading the comparison plots as
# input/prediction pairs.
input_images, target_noises, noise_levels = [], [], []

for b in dl_train:
    input_images.append(b.noisy_image)
    target_noises.append(b.target_noise)
    noise_levels.append(b.noise_level)

# Concatenate the per-batch tensors along the first dimension.
input_images = torch.cat(input_images)
target_noises = torch.cat(target_noises)
noise_levels = torch.cat(noise_levels)
input_images.shape, target_noises.shape, noise_levels.shape

In [None]:
ix_img = 1
noisy_input_image = input_images[ix_img].cpu()
pred_noise = preds[ix_img]
target_noise = target_noises[ix_img].cpu()
noise_level = noise_levels[ix_img].cpu()

denoised_image = unet_with_noise.get_denoised_images(
    noisy_input_image, target_noise, noise_level
)
pred_denoised_image = unet_with_noise.get_denoised_images(
    noisy_input_image, pred_noise, noise_level
)

print(f"noise level: {noise_levels[ix_img]}")

unet_with_noise.compare_input_noise_and_denoised_image(
    noisy_input_image,
    target_noise,
    denoised_image,
    title=f"Target noise (noise level: {noise_levels[ix_img]:.4f})",
)
unet_with_noise.compare_input_noise_and_denoised_image(
    noisy_input_image,
    pred_noise,
    pred_denoised_image,
    title=f"Predicted noise (noise level: {noise_levels[ix_img]:.4f})",
)

noise levels based on
```python
def sigmas_karras(n, sigma_min=0.01, sigma_max=80., rho=7.):
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min**(1/rho)
    max_inv_rho = sigma_max**(1/rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho-max_inv_rho))**rho
    return torch.cat([sigmas, tensor([0.])]).cuda()
```

In [None]:
def get_noise_level(
    n: int, max_noise_level: float = 15.0, d: float = 2.5
) -> torch.Tensor:
    """Return ``n`` geometrically decaying noise levels as a 1-D tensor.

    Entry ``i`` (counting from zero) equals ``max_noise_level / d**i``, so the
    schedule starts at ``max_noise_level`` and is divided by ``d`` at every
    subsequent step.
    """
    schedule = []
    for step in range(n):
        schedule.append(max_noise_level / (d**step))
    return torch.tensor(schedule)


# Parameters of the sampling noise schedule.
max_noise_level = 2.0  # first (largest) noise level of the schedule
n_levels = 20  # number of levels in the schedule
# NOTE(review): rho is not passed to get_noise_level and is not used in the
# visible cells; it looks like a leftover from the Karras schedule quoted
# above — confirm before removing.
rho = 7.0
d = 3.0  # each level is the previous one divided by d
noise_levels = get_noise_level(n_levels, max_noise_level=max_noise_level, d=d)

# The trailing semicolon suppresses the notebook's echo of the return value.
sns.scatterplot(x=range(len(noise_levels)), y=noise_levels);

In [None]:
n_samples = 5
# Every sample starts at the largest noise level of the schedule.
generative_sig = torch.full((n_samples,), max_noise_level)
# Reshaped to (n_samples, 1, 1) so each level lines up with one 28x28 image.
sig_per_image = generative_sig.reshape(-1, 1, 1)
sampled_noise = unet_with_noise.draw_img_noise_given_noise_level(
    sig_per_image,
    images_shape=(n_samples, 28, 28),
)
# Only the third coefficient (c_in) is kept; it scales the noise in later cells.
_, _, c_in = unet_with_noise.get_cs(sig_per_image)
sampled_noise.shape

In [None]:
fig, axs = plt.subplots(figsize=(7, 3), nrows=2)
ax = axs[0]
sns.histplot(x=sampled_noise.flatten(), ax=ax)
ax.set(xlabel="unscaled Pixel value", ylabel="Count")
ax = axs[1]
sns.histplot(x=(sampled_noise * c_in).flatten(), ax=ax)
ax.set(xlabel="scaled Pixel value", ylabel="Count")
plt.tight_layout()

denoising based on 
```python
def denoise(model, x, sig):
    sig = sig[None]
    c_skip,c_out,c_in = scalings(sig)
    return model((x*c_in, sig))*c_out + x*c_skip
    
def sample_lms(model, steps=100, order=4, sigma_max=80.):
    preds = []
    x = torch.randn(sz).cuda()*sigma_max
    sigs = sigmas_karras(steps, sigma_max=sigma_max)
    ds = []
    for i in progress_bar(range(len(sigs)-1)):
        sig = sigs[i]
        denoised = denoise(model, x, sig)
        d = (x-denoised)/sig
        ds.append(d)
        if len(ds) > order: ds.pop(0)
        cur_order = min(i+1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigs, i, j) for j in range(cur_order)]
        x = x + sum(coeff*d for coeff, d in zip(coeffs, reversed(ds)))
        preds.append(x)
    return preds
```

In [None]:
noise_preds, denoised_preds = unet_with_noise.denoise_with_model(
    learner, sampled_noise.float(), noise_levels
)

In [None]:
len(noise_levels)

In [None]:
# Index of the denoising step to inspect. With the 20 levels configured
# earlier, 19 is the last entry of `noise_levels`.
ix_denoise = 19

for ix_img in range(n_samples):
    # The input of step 0 is the scaled initial noise; for any later step it is
    # the denoised output of the previous step.
    if ix_denoise == 0:
        noisy_input_image = sampled_noise[ix_img].cpu() * c_in[0]
    else:
        noisy_input_image = denoised_preds[ix_denoise - 1][ix_img].cpu()
    predicted_noise = noise_preds[ix_denoise][ix_img].cpu()
    denoised_image = denoised_preds[ix_denoise][ix_img].cpu()

    noise_level = noise_levels[ix_denoise].cpu()
    # Show input, predicted noise and denoised output side by side for this sample.
    unet_with_noise.compare_input_noise_and_denoised_image(
        noisy_input_image,
        predicted_noise,
        denoised_image,
        title=f"{ix_img=}: Predicted noise (noise level: {noise_level:.4f} (sigma_max: {max_noise_level:.4f}))",
    )

In [None]:
# TODO: implement ResBlocks with attention as in https://github.com/fastai/course22p2/blob/master/nbs/28_diffusion-attn-cond.ipynb