# Stable Diffusion from Scratch - Demo

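This notebook demonstrates the key components of our Stable Diffusion implementation:
1. Noise scheduler visualization
2. Forward diffusion visualization
3. VAE encoding/decoding
4. U-Net noise prediction
5. Timestep embedding visualization
6. DDPM vs DDIM sampling timesteps
7. Classifier-free guidance

Text encoding with CLIP and the full text-to-image pipeline are not run in this notebook;
for those, see the inference scripts in `scripts/`.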

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f'Using device: {device}')

## 1. Noise Scheduler Visualization

The noise scheduler controls how noise is added during the forward diffusion process
and removed during reverse diffusion (sampling).

In [None]:
from sd.schedulers.ddpm import DDPMScheduler
from sd.schedulers.ddim import DDIMScheduler

# Two DDPM schedulers over the same 1000 training timesteps. They differ in the
# beta schedule type and in the start/end beta values passed to it.
scheduler_linear = DDPMScheduler(
    beta_schedule='linear',
    beta_start=0.0001,
    beta_end=0.02,
    num_train_timesteps=1000,
)

scheduler_scaled = DDPMScheduler(
    beta_schedule='scaled_linear',
    beta_start=0.00085,
    beta_end=0.012,
    num_train_timesteps=1000,
)

# One row of three panels: betas, cumulative alpha product, and SNR
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
beta_ax, alpha_ax, snr_ax = axes

# SNR (Signal-to-Noise Ratio): alphas_cumprod / (1 - alphas_cumprod)
snr_linear = scheduler_linear.alphas_cumprod / (1 - scheduler_linear.alphas_cumprod)
snr_scaled = scheduler_scaled.alphas_cumprod / (1 - scheduler_scaled.alphas_cumprod)

# Draw the linear curve first, then the scaled-linear one, on every panel
curves = (
    ('Linear', scheduler_linear, snr_linear),
    ('Scaled Linear', scheduler_scaled, snr_scaled),
)
for curve_label, sched, snr in curves:
    beta_ax.plot(sched.betas.numpy(), label=curve_label)
    # Alpha cumulative product (signal preservation)
    alpha_ax.plot(sched.alphas_cumprod.numpy(), label=curve_label)
    # SNR is drawn on a logarithmic y-axis
    snr_ax.semilogy(snr.numpy(), label=curve_label)

# Shared x-label and legend; per-panel y-label and title
panel_text = (
    (beta_ax, 'Beta', 'Beta Schedule'),
    (alpha_ax, 'Alpha Cumprod', 'Signal Preservation'),
    (snr_ax, 'SNR', 'Signal-to-Noise Ratio'),
)
for panel, y_label, panel_title in panel_text:
    panel.set_xlabel('Timestep')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()

plt.tight_layout()
plt.show()

## 2. Forward Diffusion Visualization

Visualize how an image is progressively noised during forward diffusion.

In [None]:
# Create a simple test image
def create_test_image(size=256):
    """Build a smooth synthetic RGB test pattern.

    Each colour channel is a sinusoid over a square grid spanning [-1, 1] on
    both axes, rescaled so that every value lies in [0, 1].

    Args:
        size: Number of pixels along each side of the square image.

    Returns:
        A float32 tensor of shape (1, 3, size, size): a batch holding one
        channels-first image.
    """
    coords = np.linspace(-1, 1, size)
    grid_x, grid_y = np.meshgrid(coords, coords)

    # Red varies along x, green along y, blue with the product of the two.
    phases = (grid_x * 3, grid_y * 3, grid_x * grid_y * 3)
    channels = [(np.sin(phase) + 1) / 2 for phase in phases]

    # Stack to (3, H, W), then prepend the batch axis -> (1, 3, H, W)
    batch = np.stack(channels, axis=0)[np.newaxis]
    return torch.tensor(batch, dtype=torch.float32)

# 64x64 synthetic image on the active device, rescaled from [0, 1] to [-1, 1]
test_image = create_test_image(64).to(device) * 2 - 1

# Timesteps at which to show the noised image
timesteps = [0, 100, 250, 500, 750, 999]
scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='scaled_linear')

fig, axes = plt.subplots(1, len(timesteps), figsize=(3*len(timesteps), 3))

for panel, t in zip(axes, timesteps):
    # Fresh Gaussian noise for every panel
    noise = torch.randn_like(test_image)
    # NOTE(review): the timestep tensor is created on the CPU while test_image
    # lives on `device` -- confirm add_noise accepts that combination.
    noisy = scheduler.add_noise(test_image, noise, torch.tensor([t]))

    # Channels-first -> channels-last, map [-1, 1] back to [0, 1], clip for display
    noisy_hwc = noisy[0].cpu().permute(1, 2, 0).numpy()
    img = np.clip((noisy_hwc + 1) / 2, 0, 1)

    panel.imshow(img)
    panel.set_title(f't={t}')
    panel.axis('off')

plt.suptitle('Forward Diffusion Process')
plt.tight_layout()
plt.show()

## 3. VAE Encoding and Decoding

The VAE compresses images to a lower-dimensional latent space.

In [None]:
from sd.models.vae import VAE

# A deliberately small VAE so the demo runs quickly
vae = VAE(
    latent_channels=4,
    in_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
)
vae = vae.to(device)

vae_param_count = sum(p.numel() for p in vae.parameters())
print(f'VAE parameters: {vae_param_count:,}')

# Round-trip the test image through the VAE without tracking gradients
with torch.no_grad():
    latents = vae.encode(test_image)
    reconstructed = vae.decode(latents)

# Ratio of element counts: input tensor versus latent tensor
compression = test_image.numel() / latents.numel()
print(f'Input shape: {test_image.shape}')
print(f'Latent shape: {latents.shape}')
print(f'Compression ratio: {compression:.1f}x')

# Three panels: input image, latent channels, decoded reconstruction
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
original_ax, latent_ax, recon_ax = axes

# Input: channels-first -> channels-last, shifted from [-1, 1] to [0, 1]
original_hwc = test_image[0].cpu().permute(1, 2, 0).numpy()
original_ax.imshow((original_hwc + 1) / 2)
original_ax.set_title('Original')

# Latent: place channels 0-3 of the first sample side by side
latent_vis = latents[0].cpu().numpy()
latent_grid = np.concatenate([latent_vis[ch] for ch in range(4)], axis=1)
latent_ax.imshow(latent_grid, cmap='viridis')
latent_ax.set_title('Latent (4 channels)')

# Reconstruction: same conversion as the input, then clipped to [0, 1]
recon = (reconstructed[0].cpu().permute(1, 2, 0).numpy() + 1) / 2
recon_ax.imshow(np.clip(recon, 0, 1))
recon_ax.set_title('Reconstructed')

for panel in axes:
    panel.axis('off')

plt.tight_layout()
plt.show()

## 4. U-Net Architecture

The U-Net predicts noise from noisy latents, conditioned on timestep and text.

In [None]:
from sd.models.unet import UNet, get_timestep_embedding

# Sizes for the dummy forward pass; context_dim is shared with the model config
batch_size = 2
latent_size = 8
context_dim = 64
seq_len = 10

# A deliberately small U-Net so the demo runs quickly
unet = UNet(
    in_channels=4,
    out_channels=4,
    model_channels=64,
    num_res_blocks=1,
    attention_resolutions=(4,),
    channel_mult=(1, 2),
    num_heads=4,
    use_text_conditioning=True,
    context_dim=context_dim,
)
unet = unet.to(device)

unet_param_count = sum(p.numel() for p in unet.parameters())
print(f'U-Net parameters: {unet_param_count:,}')

# Random stand-ins for the three model inputs, all created on the active device
latent_shape = (batch_size, 4, latent_size, latent_size)
context_shape = (batch_size, seq_len, context_dim)
noisy_latents = torch.randn(*latent_shape, device=device)
timesteps = torch.randint(0, 1000, (batch_size,), device=device)
context = torch.randn(*context_shape, device=device)

# Single forward pass without tracking gradients
with torch.no_grad():
    noise_pred = unet(noisy_latents, timesteps, context=context)

print(f'Input shape: {noisy_latents.shape}')
print(f'Output shape: {noise_pred.shape}')
print(f'Timesteps: {timesteps.tolist()}')

## 5. Timestep Embedding Visualization

Sinusoidal embeddings encode timestep information.

In [None]:
# Visualize timestep embeddings
# Embed every `timestep_stride`-th timestep in [0, 1000)
timestep_stride = 10
timesteps = torch.arange(0, 1000, timestep_stride)
embeddings = get_timestep_embedding(timesteps, 64)

# By default imshow ticks show array indices (0..N-1), not timestep values.
# Pass an explicit extent so each pixel is centred on the timestep it
# represents and the "Timestep" axes read in real timestep units.
half_stride = timestep_stride / 2
t_min_edge = timesteps[0].item() - half_stride
t_max_edge = timesteps[-1].item() + half_stride
embedding_dim = embeddings.shape[1]

plt.figure(figsize=(12, 6))
plt.imshow(
    embeddings.numpy().T,  # rows = embedding dimensions, columns = timesteps
    aspect='auto',
    cmap='RdBu',
    # (left, right, bottom, top); rows keep their default index-centred ticks
    extent=(t_min_edge, t_max_edge, embedding_dim - 0.5, -0.5),
)
plt.colorbar(label='Value')
plt.xlabel('Timestep')
plt.ylabel('Embedding Dimension')
plt.title('Sinusoidal Timestep Embeddings')
plt.show()

# Show that nearby timesteps have similar embeddings
from sklearn.metrics.pairwise import cosine_similarity

# Pairwise cosine similarity between the embedding of every sampled timestep
sim_matrix = cosine_similarity(embeddings.numpy())

plt.figure(figsize=(8, 6))
plt.imshow(
    sim_matrix,
    cmap='viridis',
    # Both axes index the sampled timesteps, so both use timestep units
    extent=(t_min_edge, t_max_edge, t_max_edge, t_min_edge),
)
plt.colorbar(label='Cosine Similarity')
plt.xlabel('Timestep')
plt.ylabel('Timestep')
plt.title('Embedding Similarity Matrix')
plt.show()

## 6. DDPM vs DDIM Sampling

Compare the inference timestep schedules that the stochastic (DDPM) and deterministic (DDIM) schedulers select for different numbers of sampling steps.

In [None]:
# One scheduler of each kind, both configured for 1000 training timesteps
ddpm = DDPMScheduler(num_train_timesteps=1000)
ddim = DDIMScheduler(num_train_timesteps=1000)

# Inference step counts to compare, one panel each
steps_list = [10, 25, 50, 100]

fig, axes = plt.subplots(1, len(steps_list), figsize=(4*len(steps_list), 3))

for panel, num_steps in zip(axes, steps_list):
    ddpm.set_timesteps(num_steps)
    ddim.set_timesteps(num_steps)

    # x = position in the inference sequence, y = the scheduler's timestep value
    for sched, sched_label, sched_marker in ((ddpm, 'DDPM', 'o'), (ddim, 'DDIM', 's')):
        panel.plot(
            sched.timesteps.numpy(),
            label=sched_label,
            marker=sched_marker,
            markersize=3,
        )

    panel.set_xlabel('Step')
    panel.set_ylabel('Timestep')
    panel.set_title(f'{num_steps} steps')
    panel.legend()

plt.suptitle('Sampling Timesteps')
plt.tight_layout()
plt.show()

## 7. Classifier-Free Guidance

CFG combines conditional and unconditional predictions to improve sample quality.

In [None]:
from sd.guidance.cfg import classifier_free_guidance

# Random tensors standing in for the conditional / unconditional noise predictions
noise_pred_cond = torch.randn(1, 4, 8, 8)
noise_pred_uncond = torch.randn(1, 4, 8, 8)

# Guidance scales to compare, one column each
guidance_scales = [1.0, 3.0, 7.5, 15.0]

fig, axes = plt.subplots(2, len(guidance_scales), figsize=(3*len(guidance_scales), 6))

# Top row: channel 0 of the conditional prediction (the same in every column).
# Bottom row: channel 0 of the guided prediction at that column's scale.
for cond_ax, guided_ax, scale in zip(axes[0], axes[1], guidance_scales):
    guided = classifier_free_guidance(noise_pred_cond, noise_pred_uncond, scale)

    cond_ax.imshow(noise_pred_cond[0, 0].numpy(), cmap='RdBu')
    cond_ax.set_title('Conditional')
    cond_ax.axis('off')

    guided_ax.imshow(guided[0, 0].numpy(), cmap='RdBu')
    guided_ax.set_title(f'CFG scale={scale}')
    guided_ax.axis('off')

plt.suptitle('Classifier-Free Guidance Effect')
plt.tight_layout()
plt.show()

# Mean absolute value of the guided prediction at 100 scales spaced evenly over [1, 20]
scales = np.linspace(1, 20, 100)
magnitudes = [
    classifier_free_guidance(noise_pred_cond, noise_pred_uncond, s).abs().mean().item()
    for s in scales
]

plt.figure(figsize=(8, 4))
plt.plot(scales, magnitudes)
# Dashed red marker at guidance scale 7.5
plt.axvline(7.5, color='r', linestyle='--', label='Typical scale (7.5)')
plt.title('CFG Magnitude vs Guidance Scale')
plt.xlabel('Guidance Scale')
plt.ylabel('Mean Absolute Value')
plt.legend()
plt.grid(True)
plt.show()

## Summary

This demo covered the key components of Stable Diffusion:

1. **Noise Scheduler**: Controls the diffusion process with beta schedules
2. **VAE**: Compresses images to latent space (8x spatial reduction)
3. **U-Net**: Predicts noise conditioned on timestep and text
4. **Timestep Embedding**: Sinusoidal encoding of diffusion timestep
5. **DDPM/DDIM**: Different sampling strategies (stochastic vs deterministic)
6. **Classifier-Free Guidance**: Improves sample quality by combining predictions

For full text-to-image generation, see the inference scripts in `scripts/`.