## Import Required Libraries

In [None]:
import time
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


## Solution 1: Linear Noise Schedule

In [None]:
class LinearSchedule:
    """
    Uniformly spaced inference schedule.

    Selects ``num_inference_steps`` timesteps evenly spaced from
    ``noise_scheduler.num_timesteps - 1`` down to 0 and pre-computes, for
    every timestep of the base scheduler, the posterior variance and the two
    posterior-mean coefficients of the t -> t-1 transition.
    """

    def __init__(self, noise_scheduler, num_inference_steps: int = 50):
        """
        Build the timestep list and the posterior lookup tables.

        Args:
            noise_scheduler: Object exposing ``num_timesteps`` and the 1-D
                tensors ``betas``, ``alphas`` and ``alphas_cumprod``.
            num_inference_steps: Number of timesteps to visit when sampling.

        Raises:
            ValueError: If ``num_inference_steps`` is smaller than 1.
        """
        if num_inference_steps < 1:
            raise ValueError(
                f"num_inference_steps must be >= 1, got {num_inference_steps}"
            )

        self.noise_scheduler = noise_scheduler
        self.num_steps = num_inference_steps

        # Descending timesteps: first entry is T - 1 (most noise), last is 0.
        self.timesteps = torch.linspace(
            noise_scheduler.num_timesteps - 1,
            0,
            num_inference_steps,
            dtype=torch.long,
        )

        betas = noise_scheduler.betas
        alphas_cumprod = noise_scheduler.alphas_cumprod

        # Cumulative alpha of the previous timestep, defined as 1 for t = 0.
        # new_ones keeps the leading 1 on the same device/dtype as the
        # scheduler tensors so the concatenation also works off-CPU.
        alphas_cumprod_prev = torch.cat(
            [alphas_cumprod.new_ones(1), alphas_cumprod[:-1]]
        )
        one_minus_cumprod = 1.0 - alphas_cumprod

        # Posterior variance per base timestep; it is exactly 0 at t = 0
        # because alphas_cumprod_prev[0] == 1.
        self.posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / one_minus_cumprod
        )
        # Log-variance, clamped first so the zero entry does not give -inf.
        self.posterior_log_variance_clipped = torch.log(
            self.posterior_variance.clamp(min=1e-20)
        )

        # Posterior mean = coef1 * predicted_x0 + coef2 * x_t.
        self.posterior_mean_coef1 = (
            betas * torch.sqrt(alphas_cumprod_prev) / one_minus_cumprod
        )
        self.posterior_mean_coef2 = (
            (1.0 - alphas_cumprod_prev)
            * torch.sqrt(noise_scheduler.alphas)
            / one_minus_cumprod
        )

    def get_timestep(self, step: int) -> torch.Tensor:
        """
        Return the timestep visited at position ``step`` of the schedule.

        The value is a 0-dim ``torch.long`` tensor; wrap it in ``int(...)``
        when a plain Python integer is needed.
        """
        return self.timesteps[step]


## Solution 2: Cosine Noise Schedule

In [None]:
class CosineSchedule:
    """
    Cosine-spaced inference schedule.

    Evaluates a squared-cosine cumulative-alpha curve at
    ``num_inference_steps`` points and, for each point, picks the base
    scheduler timestep whose ``alphas_cumprod`` is closest to it. The result
    is a non-uniform list of timesteps stored from high noise to low noise.
    Each point is matched independently, so the same base timestep may be
    selected more than once.

    The posterior lookup tables are computed exactly as in the linear
    schedule: one entry per base timestep for the t -> t-1 transition.
    """

    def __init__(self, noise_scheduler, num_inference_steps: int = 50):
        """
        Build the timestep list and the posterior lookup tables.

        Args:
            noise_scheduler: Object exposing the 1-D tensors ``betas``,
                ``alphas`` and ``alphas_cumprod``.
            num_inference_steps: Number of timesteps to visit when sampling.

        Raises:
            ValueError: If ``num_inference_steps`` is smaller than 1 (the
                cosine curve divides by this value).
        """
        if num_inference_steps < 1:
            raise ValueError(
                f"num_inference_steps must be >= 1, got {num_inference_steps}"
            )

        self.noise_scheduler = noise_scheduler
        self.num_steps = num_inference_steps

        betas = noise_scheduler.betas
        base_alphas_cumprod = noise_scheduler.alphas_cumprod

        # Squared-cosine curve sampled at 0..num_inference_steps, normalised
        # so that its first value is exactly 1.
        s = 0.008  # Offset parameter
        steps = torch.arange(num_inference_steps + 1, dtype=torch.float32)
        cosine_curve = (
            torch.cos((steps / num_inference_steps + s) / (1 + s) * torch.pi * 0.5) ** 2
        )
        cosine_curve = cosine_curve / cosine_curve[0]

        # For every curve point after the first, take the base timestep with
        # the nearest cumulative alpha. Stored as plain ints so the final
        # tensor is built from Python numbers, not from 0-dim tensors.
        nearest = [
            int(torch.argmin(torch.abs(base_alphas_cumprod - target)))
            for target in cosine_curve[1:]
        ]

        # The curve decreases with the index, so reverse to go from high
        # noise to low noise.
        self.timesteps = torch.tensor(nearest[::-1], dtype=torch.long)

        # Cumulative alpha of the previous timestep, defined as 1 for t = 0.
        # new_ones keeps the leading 1 on the same device/dtype as the
        # scheduler tensors so the concatenation also works off-CPU.
        alphas_cumprod_prev = torch.cat(
            [base_alphas_cumprod.new_ones(1), base_alphas_cumprod[:-1]]
        )
        one_minus_cumprod = 1.0 - base_alphas_cumprod

        # Posterior variance per base timestep (0 at t = 0).
        self.posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / one_minus_cumprod
        )
        # Log-variance, clamped first so the zero entry does not give -inf.
        self.posterior_log_variance_clipped = torch.log(
            self.posterior_variance.clamp(min=1e-20)
        )

        # Posterior mean = coef1 * predicted_x0 + coef2 * x_t.
        self.posterior_mean_coef1 = (
            betas * torch.sqrt(alphas_cumprod_prev) / one_minus_cumprod
        )
        self.posterior_mean_coef2 = (
            (1.0 - alphas_cumprod_prev)
            * torch.sqrt(noise_scheduler.alphas)
            / one_minus_cumprod
        )

    def get_timestep(self, step: int) -> torch.Tensor:
        """
        Return the timestep visited at position ``step`` of the schedule.

        The value is a 0-dim ``torch.long`` tensor; wrap it in ``int(...)``
        when a plain Python integer is needed.
        """
        return self.timesteps[step]


## Solution 3: Reverse Diffusion Step

In [None]:
def reverse_diffusion_step(
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    noise_scheduler,
    schedule,
) -> torch.Tensor:
    """
    Single reverse diffusion step.

    Estimates the clean image from the predicted noise, forms the posterior
    mean from the schedule's pre-computed coefficients at ``timestep``, and
    adds Gaussian noise scaled by the posterior variance when that variance
    is positive.

    NOTE(review): the coefficients read from ``schedule`` are indexed by
    ``timestep`` alone; this function does not know which timestep the
    sampler visits next. If the schedule skips timesteps, confirm that the
    single-step posterior is what is intended.

    Args:
        model_output: Predicted noise, same shape as ``sample``.
        timestep: Current timestep t (Python int or single-element tensor).
        sample: Current noisy image x_t.
        noise_scheduler: Base scheduler; only ``alphas_cumprod`` is read.
        schedule: Object providing ``posterior_mean_coef1``,
            ``posterior_mean_coef2`` and ``posterior_variance`` tables
            indexed by timestep (e.g. LinearSchedule or CosineSchedule).

    Returns:
        The sample after one reverse step.
    """
    # Normalise so ints and 0-dim tensors index the tables identically.
    t = int(timestep)

    alpha_bar_t = noise_scheduler.alphas_cumprod[t]

    # Invert x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps for x_0.
    pred_original_sample = (
        sample - torch.sqrt(1 - alpha_bar_t) * model_output
    ) / torch.sqrt(alpha_bar_t)

    # Posterior mean as a blend of the predicted x_0 and the current x_t.
    mean = (
        schedule.posterior_mean_coef1[t] * pred_original_sample
        + schedule.posterior_mean_coef2[t] * sample
    )

    # Stochastic part; skipped when the tabulated variance is not positive.
    variance = schedule.posterior_variance[t]
    if variance > 0:
        return mean + torch.sqrt(variance) * torch.randn_like(sample)
    return mean


## Solution 4: Complete Sampling Loop

In [None]:
def sample_with_schedule(
    model: nn.Module,
    noise_scheduler,
    schedule,
    batch_size: int = 4,
    device: torch.device = torch.device("cpu"),
    image_shape: Tuple[int, int, int] = (1, 28, 28),
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """
    Complete reverse diffusion sampling loop.

    Process:
    1. Initialize with pure Gaussian noise
    2. Iterate over ``schedule.timesteps`` in the order they are stored
    3. At each step:
       - Predict noise with model
       - Take reverse step
       - Store a CPU copy every ``len(timesteps) // 5`` steps (at least 1)
    4. Return final images and trajectory

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: Denoiser called as ``model(sample, t_tensor)``, where
            ``t_tensor`` is a ``(batch_size,)`` long tensor.
        noise_scheduler: Base scheduler, forwarded to the reverse step.
        schedule: Object with a ``timesteps`` sequence plus the posterior
            tables used by ``reverse_diffusion_step``.
        batch_size: Number of images to generate.
        device: Device on which the samples are created.
        image_shape: Per-image shape (channels, height, width). The default
            keeps the previous fixed (1, 28, 28) behaviour.

    Returns:
        samples: Generated images of shape ``(batch_size, *image_shape)``
            on ``device``.
        trajectory: List of intermediate samples (CPU copies).
    """
    # Initialize with pure Gaussian noise
    sample = torch.randn(batch_size, *image_shape, device=device)
    trajectory: List[torch.Tensor] = []

    # Progress-bar label from the class name ("LinearSchedule" -> "Linear"),
    # so schedules other than the two built-in ones are labelled correctly.
    schedule_name = type(schedule).__name__
    suffix = "Schedule"
    if schedule_name.endswith(suffix) and len(schedule_name) > len(suffix):
        schedule_name = schedule_name[: -len(suffix)]

    # Loop-invariant distance between stored snapshots.
    snapshot_stride = max(1, len(schedule.timesteps) // 5)

    model.eval()
    with torch.no_grad():
        for i, t in enumerate(
            tqdm(schedule.timesteps, desc=f"Sampling ({schedule_name})")
        ):
            # Iterating a tensor yields 0-dim tensors; use a plain int.
            t_int = int(t)

            # Same timestep for every element of the batch.
            t_tensor = torch.full(
                (batch_size,),
                t_int,
                device=device,
                dtype=torch.long,
            )

            # Predict noise with the model
            noise_pred = model(sample, t_tensor)

            # Take reverse diffusion step
            sample = reverse_diffusion_step(
                noise_pred, t_int, sample, noise_scheduler, schedule
            )

            # Store intermediate results
            if i % snapshot_stride == 0:
                trajectory.append(sample.cpu().clone())

    return sample, trajectory


## Solution 5: Fidelity Metrics

In [None]:
def compute_sample_variance(samples: torch.Tensor) -> float:
    """
    Measure how much generated samples differ from one another.

    Each sample is flattened to a vector, the per-feature mean over the
    batch is subtracted, and the squared deviations are averaged over every
    batch element and feature. A value near zero means the samples are
    nearly identical; larger values mean a more varied batch.

    Args:
        samples: Tensor whose first dimension is the batch.

    Returns:
        Mean squared deviation from the batch mean, as a Python float.
    """
    batch = samples.shape[0]
    per_sample = samples.reshape(batch, -1)

    # Deviation of every feature from its mean over the batch dimension.
    centered = per_sample - per_sample.mean(dim=0, keepdim=True)

    return (centered ** 2).mean().item()


def compare_schedules(
    model, noise_scheduler, device, batch_size=16, num_inference_steps=50
):
    """
    Compare Linear vs Cosine sampling schedules.

    Builds each schedule, samples a batch with it, records the elapsed
    wall-clock time and the batch variance, and prints both.

    Args:
        model: Denoiser passed through to ``sample_with_schedule``.
        noise_scheduler: Base scheduler used to build both schedules.
        device: Device on which sampling runs.
        batch_size: Number of images generated per schedule.
        num_inference_steps: Number of timesteps per schedule. The default
            keeps the previous fixed value of 50.

    Returns:
        Dict keyed by schedule name ("Linear", "Cosine"); each value is a
        dict with keys "samples", "trajectory", "time" and "variance".
    """
    results = {}

    for schedule_class, schedule_name in [
        (LinearSchedule, "Linear"),
        (CosineSchedule, "Cosine"),
    ]:
        print(f"\nTesting {schedule_name} Schedule...")

        # Create schedule
        schedule = schedule_class(
            noise_scheduler, num_inference_steps=num_inference_steps
        )

        # Sample; perf_counter is monotonic, so the measured interval is not
        # affected by system clock adjustments.
        start_time = time.perf_counter()
        samples, trajectory = sample_with_schedule(
            model, noise_scheduler, schedule, batch_size, device
        )
        sampling_time = time.perf_counter() - start_time

        # Compute metrics
        variance = compute_sample_variance(samples)

        results[schedule_name] = {
            "samples": samples,
            "trajectory": trajectory,
            "time": sampling_time,
            "variance": variance,
        }

        print(f"  Time: {sampling_time:.2f}s")
        print(f"  Variance: {variance:.4f}")

    return results


## Solution 6: Visualization and Analysis

In [None]:
def visualize_sampling_comparison(results):
    """
    Visualize samples from each schedule, one row per schedule.

    Shows up to the first 8 samples of every entry in ``results``. Samples
    are moved to the CPU before conversion to NumPy, so results produced on
    an accelerator can be plotted directly.

    Args:
        results: Dict mapping a schedule name to a dict whose "samples"
            entry is a batch tensor, as returned by ``compare_schedules``.
    """
    max_cols = 8
    # One row per schedule; at least one so an empty dict still yields a
    # valid (blank) figure. Two schedules give the original 16 x 4 figure.
    n_rows = max(1, len(results))
    _, axes = plt.subplots(
        n_rows, max_cols, figsize=(16, 2 * n_rows), squeeze=False
    )

    # Hide ticks and frames everywhere, including cells left empty when a
    # batch holds fewer than ``max_cols`` samples.
    for ax in axes.ravel():
        ax.axis("off")

    for idx, (schedule_name, data) in enumerate(results.items()):
        # Tensor.numpy() only works on CPU tensors without grad.
        samples = data["samples"][:max_cols].detach().cpu()
        samples = (samples + 1) / 2  # Denormalize

        for i, sample in enumerate(samples):
            axes[idx, i].imshow(sample.squeeze().numpy(), cmap="gray")

        # Add title on left
        axes[idx, 0].text(
            -1.5,
            0.5,
            schedule_name,
            transform=axes[idx, 0].transAxes,
            fontsize=12,
            fontweight="bold",
            ha="right",
            va="center",
        )

    plt.suptitle(
        "Sampling Schedule Comparison: Generated MNIST Digits",
        fontsize=12,
        fontweight="bold",
    )
    plt.tight_layout()
    plt.show()


def print_schedule_analysis(results):
    """
    Write a plain-text report of the schedule comparison to stdout.

    The report has a banner followed by three sections: per-schedule
    sampling time, per-schedule sample variance, and a fixed list of
    recommendations.

    Args:
        results: Dict mapping a schedule name to a dict holding at least
            the numeric entries "time" and "variance".
    """
    banner = "=" * 80
    print("\n" + banner)
    print("SAMPLING SCHEDULE COMPARISON RESULTS")
    print(banner)

    print("\n1. SPEED COMPARISON:")
    for schedule_name in results:
        elapsed = results[schedule_name]["time"]
        print(f"   {schedule_name}: {elapsed:.3f} seconds")

    print("\n2. DIVERSITY (Sample Variance):")
    for schedule_name in results:
        spread = results[schedule_name]["variance"]
        print(f"   {schedule_name}: {spread:.6f}")

    print("\n3. RECOMMENDATIONS:")
    recommendations = (
        "   • Linear Schedule: Simple, predictable, good baseline",
        "   • Cosine Schedule: Often higher quality, focuses on hard steps",
        "   • Choice depends on quality vs speed tradeoff",
    )
    for line in recommendations:
        print(line)
