# Minimal DDPM Denoising Example

This notebook shows how to train a denoising diffusion probabilistic model (DDPM) with [lucidrains’ denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch) and then use it for image denoising.

- We use a *fake* dataset (random images) just for a quick demo.
- We train only briefly, then demonstrate *partial inference*: running the reverse diffusion process from an intermediate step to remove noise.

In [None]:
!pip install denoising_diffusion_pytorch torchvision

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt
import numpy as np

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using device:', device)

## 1. Create a tiny (fake) dataset
We'll use `torchvision.datasets.FakeData` just for illustration.
You can replace this with your *real* image dataset.

In [None]:
# Our images will be 3-channel (RGB-like), 64x64
image_size = 64
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor()
])

train_data = datasets.FakeData(
    size=256,  # just 256 random images for quick demo
    image_size=(3, image_size, image_size),
    transform=transform
)

train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
len(train_loader)

## 2. Define U-Net + DDPM
We set fewer timesteps (50) for faster experimentation. You’ll likely want ~1000 for real applications.

We also define a small U-Net for demonstration.
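
To get a feel for what the number of timesteps controls, the sketch below computes the cumulative signal fraction `alpha_bar_t` for 50 and 1000 steps. It assumes a linear beta schedule rescaled by `1000 / timesteps`; that schedule and the helper names `linear_betas` and `cumulative_alphas` are assumptions of this sketch, and the schedule the library uses by default may differ.

```python
def linear_betas(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule (assumed here), rescaled so that short schedules
    still end close to pure noise."""
    scale = 1000 / timesteps
    lo, hi = scale * beta_start, scale * beta_end
    if timesteps == 1:
        return [lo]
    return [lo + (hi - lo) * i / (timesteps - 1) for i in range(timesteps)]


def cumulative_alphas(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s).

    This is the fraction of the original signal variance left at step t.
    """
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


for T in (50, 1000):
    alpha_bar = cumulative_alphas(linear_betas(T))
    print(f"T={T}: alpha_bar at midpoint = {alpha_bar[T // 2]:.4f}, "
          f"at final step = {alpha_bar[-1]:.2e}")
```

With fewer steps, each step has to remove a larger share of the noise, which is one reason short schedules tend to give coarser results.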

In [None]:
model = Unet(
    dim=32,                 # base channel dimension (tiny for speed)
    dim_mults=(1, 2),       # how it expands at deeper layers
    channels=3              # RGB
).to(device)

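diffusion = GaussianDiffusion(
    model,
    image_size=image_size,
    timesteps=50,             # fewer steps for quick demo
    sampling_timesteps=50,    # same as timesteps here
    objective='pred_noise'    # the custom denoising code below assumes the network predicts noise
).to(device)
# Note: older releases of the library also accept loss_type='l1' / 'l2'. Recent
# releases appear to have removed that argument, so it is omitted here.

# We won't use the built-in Trainer class here, so we can demonstrate the custom partial inference later.
# The Trainer expects a folder of image files rather than a Dataset object:
#   trainer = Trainer(diffusion, 'path/to/images', ...)
print('Model and diffusion created.')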

## 3. Short Training Loop
We’ll do just a few hundred steps so we can finish quickly. *Real training would require many thousands of steps.*
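
As a rough picture of what each training step optimizes, here is a pure-Python sketch of one Monte Carlo sample of the noise-prediction (L2) objective for a single flattened image. It is a conceptual illustration, not the library's code: `noise_prediction_loss` and the `predict_noise` callable are hypothetical names, and a noise-prediction objective is assumed.

```python
import math
import random


def noise_prediction_loss(x0, alpha_bar, predict_noise, rng):
    """One sample of the noise-prediction loss for a single image.

    x0:            flat list of pixel values
    alpha_bar:     list of cumulative alphas, one per timestep
    predict_noise: hypothetical callable (x_t, t) -> list of predicted noise
    rng:           random.Random instance, so runs are reproducible
    """
    t = rng.randrange(len(alpha_bar))            # pick a random timestep
    eps = [rng.gauss(0.0, 1.0) for _ in x0]      # noise the model must recover
    a = math.sqrt(alpha_bar[t])
    s = math.sqrt(1.0 - alpha_bar[t])
    x_t = [a * x + s * e for x, e in zip(x0, eps)]  # closed-form forward diffusion
    pred = predict_noise(x_t, t)
    return sum((p - e) ** 2 for p, e in zip(pred, eps)) / len(x0)


# A predictor that always outputs zeros; its loss is the mean of eps**2.
rng = random.Random(0)
zero_model = lambda x_t, t: [0.0] * len(x_t)
print(noise_prediction_loss([0.2, -0.4, 0.7, 0.0], [0.9, 0.5, 0.1], zero_model, rng))
```

A trained network should score well below this zero-predictor baseline, which gives a quick sanity check on the printed training loss.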

In [None]:
optimizer = torch.optim.Adam(diffusion.model.parameters(), lr=1e-4)

num_training_steps = 300  # Just 300 iterations!
step = 0

diffusion.train()
while step < num_training_steps:
    for batch, _ in train_loader:
        batch = batch.to(device)
        loss = diffusion(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step += 1
        if step % 50 == 0:
            print(f"Step {step} / {num_training_steps}, Loss: {loss.item():.4f}")
        if step >= num_training_steps:
            break

print('Training done!')

## 4. Demonstration of Partial Inference (Denoising)
We’ll take an image from our dataset, add noise to it with the forward diffusion process, and treat the result as the state at a middle diffusion step. Then we’ll run a simple *custom partial denoising function* that applies the reverse process from that step down to step 0, which is how you might approach real-world denoising.
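
The core of partial denoising is a single reverse step. The scalar sketch below writes the DDPM posterior mean of x_{t-1} in two equivalent ways: directly from the predicted noise, and by first estimating x_0 and then mixing it with x_t. The function names and the numeric values are illustrative assumptions, not part of the library.

```python
import math


def reverse_step_mean(x_t, eps_pred, beta_t, alpha_bar_t):
    """Posterior mean of x_{t-1}, expressed with the predicted noise."""
    alpha_t = 1.0 - beta_t
    return (x_t - beta_t / math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_t)


def reverse_step_mean_via_x0(x_t, eps_pred, beta_t, alpha_bar_t):
    """Same mean, computed by estimating x_0 first and mixing it with x_t."""
    alpha_t = 1.0 - beta_t
    alpha_bar_prev = alpha_bar_t / alpha_t  # alpha_bar at step t-1
    x0_hat = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    coef_x0 = math.sqrt(alpha_bar_prev) * beta_t / (1.0 - alpha_bar_t)
    coef_xt = math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
    return coef_x0 * x0_hat + coef_xt * x_t


# Arbitrary scalar values, chosen only for illustration
x_t, eps_pred, beta_t, alpha_bar_t = 0.3, -1.2, 0.1, 0.45
print(reverse_step_mean(x_t, eps_pred, beta_t, alpha_bar_t))
print(reverse_step_mean_via_x0(x_t, eps_pred, beta_t, alpha_bar_t))
```

Both functions print the same value up to floating-point rounding, so either form can be used inside a sampling loop.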

In [None]:
def show_tensor_image(tensor_img, title=""):
    # tensor_img: (C, H, W)
    img_np = tensor_img.permute(1,2,0).detach().cpu().numpy()
    img_np = np.clip(img_np, 0, 1)
    plt.imshow(img_np)
    plt.axis('off')
    plt.title(title)
    plt.show()

In [None]:
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import extract

@torch.no_grad()
def partial_denoise(diffusion_model, x_t, t_start):
    """
    Run the DDPM reverse process from step t_start down to step 0.

    x_t: a noisy image at diffusion step t_start (Tensor, shape [B, C, H, W])
    t_start: integer step where 0 <= t_start < timesteps
    Assumes the network was trained to predict noise.
    """
    model = diffusion_model.model
    betas = diffusion_model.betas
    alphas_cumprod = torch.cumprod(1. - betas, dim=0)
    posterior_variance = diffusion_model.posterior_variance

    model.eval()
    batch_size = x_t.shape[0]
    for i in reversed(range(t_start + 1)):
        t_tensor = torch.full((batch_size,), i, device=x_t.device, dtype=torch.long)
        # Predict the noise contained in x_t
        pred_noise = model(x_t, t_tensor)

        beta_t = extract(betas, t_tensor, x_t.shape)
        alpha_t = 1. - beta_t
        alpha_bar_t = extract(alphas_cumprod, t_tensor, x_t.shape)

        # Standard DDPM posterior mean of x_{t-1}:
        #   mu = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean_pred = (x_t - beta_t / torch.sqrt(1. - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)

        if i > 0:
            # Sample x_{t-1} by adding noise scaled by the posterior variance
            posterior_var_t = extract(posterior_variance, t_tensor, x_t.shape)
            x_t = mean_pred + torch.sqrt(posterior_var_t) * torch.randn_like(x_t)
        else:
            # At step 0 the mean is the x_0 estimate; no more noise needed
            x_t = mean_pred

    return x_t


### 4.1 Create a "Noisy" Image
Pick an image from the dataset, diffuse it to step 25 (out of 50) with the forward process, and then check whether partial denoising recovers a clean image.

In [None]:
# Grab one sample from the dataset
sample_img, _ = next(iter(train_loader))
sample_img = sample_img[0:1].to(device)  # take the first image in the batch

# We'll pretend we are at step t=25 (the midpoint)
t_start = 25

# During training the library rescales [0, 1] images to [-1, 1] by default,
# so apply the same rescaling before calling q_sample directly.
x_0 = sample_img * 2 - 1

# Use the library's internal method to generate x_t from x_0
with torch.no_grad():
    noisy_img = diffusion.q_sample(x_0, t=torch.tensor([t_start], device=device))

print('Original clean image:')
show_tensor_image(sample_img[0], title="Clean (x_0)")
print(f'Noisy image at step t={t_start}:')
show_tensor_image((noisy_img[0] + 1) / 2, title=f"x_{t_start}")  # back to [0, 1] for display

### 4.2 Denoise from Step 25 → 0
Now we apply our `partial_denoise` function.

In [None]:
denoised_img = partial_denoise(diffusion, noisy_img, t_start=t_start)
denoised_img = (denoised_img + 1) / 2  # back to [0, 1] for display

print('Denoised image (x_0 predicted):')
show_tensor_image(denoised_img[0], title="Denoised x_0_pred")

You’ll likely see that the denoised image is still quite *blobby*, because we trained for only 300 steps on random data. With a real dataset and much longer training, you should see a clearer recovery of details.
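
Rather than judging the result only by eye, you could compare the clean and denoised images with a simple metric. The sketch below computes the peak signal-to-noise ratio (PSNR) on flattened pixel values; the `psnr` helper is a hypothetical name introduced here, and it assumes both inputs share the same value range.

```python
import math


def psnr(reference, estimate, max_value=1.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel sequences.

    max_value is the largest possible pixel value (1.0 for images in [0, 1]).
    """
    if len(reference) != len(estimate):
        raise ValueError("reference and estimate must have the same length")
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf  # identical inputs
    return 10.0 * math.log10(max_value ** 2 / mse)


# Toy example: every pixel is off by 0.1, so the MSE is 0.01
print(psnr([0.0, 0.5, 1.0], [0.1, 0.6, 0.9]))
```

For tensors from this notebook, pass flattened pixel lists in [0, 1], for example `sample_img.flatten().tolist()` as the reference and the flattened, clipped denoised output as the estimate. A higher PSNR for the denoised image than for the noisy one indicates that the reverse process actually helped.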

## Next Steps
- Replace `FakeData` with a *real* image folder or custom dataset.
- Increase `timesteps` (e.g. 1000) for better results.
- Train for many more iterations (tens or hundreds of thousands) until convergence.
- Adjust the U-Net architecture (bigger `dim`, more `dim_mults`) for higher-quality denoising.
- Use advanced sampling methods (like DDIM or DPM-Solver) for faster inference.
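
To illustrate the last point, here is a scalar sketch of one deterministic DDIM update (eta = 0). Unlike the stochastic DDPM step, it can jump directly from step t to any earlier step, which is what allows sampling with fewer network evaluations. `ddim_step` is a hypothetical helper written for this example and assumes a noise-predicting model.

```python
import math


def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update from step t to an earlier step.

    alpha_bar_t and alpha_bar_prev are the cumulative alphas at the current
    and target steps; the target step need not be t - 1.
    """
    # Estimate the clean value from the current state and the predicted noise
    x0_hat = (x_t - math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_bar_t)
    # Re-noise the estimate to the target noise level with the same noise direction
    return math.sqrt(alpha_bar_prev) * x0_hat + math.sqrt(1.0 - alpha_bar_prev) * eps_pred


# Illustrative values: build x_t from a known x_0 and noise, then step back
x0, eps = 0.5, -0.7
x_t = math.sqrt(0.2) * x0 + math.sqrt(0.8) * eps
print(ddim_step(x_t, eps, alpha_bar_t=0.2, alpha_bar_prev=0.6))
```

When the predicted noise equals the true noise, the update lands exactly on the forward-process value at the target step, and setting `alpha_bar_prev=1.0` returns the x_0 estimate itself.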
