# Lab 6: Conditional Generation - From Random to Controllable
**Course: Diffusion Models: Theory and Applications**  
**Duration: 90 minutes**  
**Team Size: 2 students (same teams from Labs 1-5)**

---

## Learning Objectives
By the end of this lab, students will be able to:
1. **Implement** class-conditional diffusion models with embedding injection
2. **Build** classifier guidance systems for external steering during sampling
3. **Create** classifier-free guidance (CFG) for modern conditional generation
4. **Construct** U-Net modifications for conditioning at multiple scales
5. **Analyze** trade-offs between conditioning approaches in terms of quality, speed, and flexibility
6. **Deploy** conditional generation systems for practical creative applications

---

## Lab Setup and Conditional Generation Framework

### Part 1: Team Setup & Conditional Generation Mission (10 minutes)

In [None]:
# Conditional generation implementation setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
from torch.distributions import Normal
from typing import Tuple, Dict, List, Optional, Union
import time
from dataclasses import dataclass
from collections import defaultdict

# Set seeds for reproducible conditional generation
torch.manual_seed(42)
np.random.seed(42)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Conditional generation experiments on: {device}")

@dataclass
class ConditionalConfig:
    """Configuration for conditional generation experiments"""
    T: int = 50
    beta_start: float = 1e-4
    beta_end: float = 2e-2
    img_size: int = 32
    channels: int = 1
    num_classes: int = 4
    dropout_prob: float = 0.1  # For classifier-free guidance training
    
    def __post_init__(self):
        # Compute noise schedule
        self.betas = torch.linspace(self.beta_start, self.beta_end, self.T).to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([torch.ones(1).to(device), self.alphas_cumprod[:-1]])

# Create conditional configuration
config = ConditionalConfig(T=50, img_size=8, channels=1, num_classes=4)
print(f"Conditional config: T={config.T}, classes={config.num_classes}, image_size={config.img_size}x{config.img_size}")

# Create synthetic dataset with class labels for conditioning experiments
def create_conditional_dataset(n_samples_per_class: int = 50) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create 2D dataset with clear class structure for conditional generation"""
    all_data = []
    all_labels = []
    
    # Class 0: Circle in top-right
    t = torch.linspace(0, 2*math.pi, n_samples_per_class)
    x = 2 + 0.8 * torch.cos(t) + 0.1 * torch.randn(n_samples_per_class)
    y = 2 + 0.8 * torch.sin(t) + 0.1 * torch.randn(n_samples_per_class)
    all_data.append(torch.stack([x, y], dim=1))
    all_labels.extend([0] * n_samples_per_class)
    
    # Class 1: Square in top-left
    x = -2 + 1.6 * torch.rand(n_samples_per_class) - 0.8 + 0.1 * torch.randn(n_samples_per_class)
    y = 2 + 1.6 * torch.rand(n_samples_per_class) - 0.8 + 0.1 * torch.randn(n_samples_per_class)
    all_data.append(torch.stack([x, y], dim=1))
    all_labels.extend([1] * n_samples_per_class)
    
    # Class 2: Triangle in bottom-left (vectorized so samples stay aligned with labels)
    r1 = torch.rand(n_samples_per_class)
    r2 = torch.rand(n_samples_per_class)
    # Reflect points that fall outside the triangle back inside it
    outside = r1 + r2 > 1
    r1, r2 = torch.where(outside, 1 - r1, r1), torch.where(outside, 1 - r2, r2)
    x = -2 + 1.6 * r1 + 0.1 * torch.randn(n_samples_per_class)
    y = -2 + 1.6 * r2 + 0.1 * torch.randn(n_samples_per_class)
    all_data.append(torch.stack([x, y], dim=1))
    all_labels.extend([2] * n_samples_per_class)
    
    # Class 3: Line in bottom-right
    x = 1.2 + 1.6 * torch.rand(n_samples_per_class) + 0.1 * torch.randn(n_samples_per_class)
    y = -2.5 + 0.1 * x + 0.2 * torch.randn(n_samples_per_class)
    all_data.append(torch.stack([x, y], dim=1))
    all_labels.extend([3] * n_samples_per_class)
    
    # Combine all data; blocks are appended in label order, so data[i] matches labels[i]
    data = torch.cat(all_data, dim=0)
    
    labels = torch.tensor(all_labels)
    
    return data.to(device), labels.to(device)

# Create dataset
train_data, train_labels = create_conditional_dataset(60)
print(f"Dataset shape: {train_data.shape}, Labels shape: {train_labels.shape}")

# Visualize dataset with class colors
plt.figure(figsize=(10, 8))
colors = ['red', 'blue', 'green', 'orange']
class_names = ['Circle', 'Square', 'Triangle', 'Line']

for class_id in range(config.num_classes):
    class_mask = train_labels == class_id
    class_data = train_data[class_mask]
    plt.scatter(class_data[:, 0].cpu(), class_data[:, 1].cpu(), 
               c=colors[class_id], label=f'Class {class_id}: {class_names[class_id]}', 
               alpha=0.7, s=50)

plt.title('Conditional Dataset: 4 Distinct Classes')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.show()

print("✓ Conditional dataset created with clear class structure")
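
As a quick sanity check on the noise schedule in `ConditionalConfig`, the sketch below recomputes the cumulative product ᾱ_t for the same linear schedule in plain Python. The helper name `linear_alpha_bar` is hypothetical and not part of the lab code. With only T=50 steps and `beta_end=2e-2`, ᾱ at the last step stays at roughly 0.6, so a noticeable share of the signal survives; this is worth keeping in mind later, when sampling starts from pure Gaussian noise.

```python
import math

def linear_alpha_bar(T: int = 50, beta_start: float = 1e-4, beta_end: float = 2e-2):
    """Return [alpha_bar_0, ..., alpha_bar_{T-1}] for a linear beta schedule.

    Hypothetical helper; mirrors torch.linspace + torch.cumprod from the config.
    """
    betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]
    alpha_bar, running = [], 1.0
    for beta in betas:
        running *= (1.0 - beta)   # cumulative product of alpha_t = 1 - beta_t
        alpha_bar.append(running)
    return alpha_bar

alpha_bar = linear_alpha_bar()
signal_scale = math.sqrt(alpha_bar[-1])        # coefficient on x_0 at the last step
noise_scale = math.sqrt(1.0 - alpha_bar[-1])   # coefficient on epsilon at the last step

print(f"alpha_bar at t=T-1: {alpha_bar[-1]:.3f}")
print(f"x_t = {signal_scale:.2f} * x_0 + {noise_scale:.2f} * eps")
```

If the generated samples look biased toward the origin, increasing `T` or `beta_end` so that ᾱ at the last step approaches zero is one thing to try.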

---

## Part 2: Understanding Unconditional Limitations (10 minutes)

### Task 2.1: Demonstrate Unconditional Generation Problems

**Your Mission**: Show why unconditional generation is insufficient for practical applications.

In [None]:
class UnconditionalLimitations:
    """
    Demonstrate the fundamental limitations of unconditional diffusion models.
    This motivates the need for conditional generation approaches.
    """
    
    def __init__(self, unconditional_model: nn.Module, config: ConditionalConfig):
        self.model = unconditional_model
        self.config = config
        
    def demonstrate_random_generation_problem(self, n_samples: int = 100):
        """
        Show how unconditional generation produces random samples from the full distribution
        """
        print("=== Unconditional Generation Problem ===\n")
        
        # Simulate unconditional sampling (we'll use a simple approach for demonstration)
        print("Generating unconditional samples...")
        
        # Create a simple mixture of the training data for demonstration
        samples = []
        for _ in range(n_samples):
            # Randomly pick a data point and add some noise (simulating unconditional generation)
            idx = torch.randint(0, len(train_data), (1,))
            base_sample = train_data[idx]
            noise = 0.3 * torch.randn_like(base_sample)
            sample = base_sample + noise
            samples.append(sample)
        
        unconditional_samples = torch.cat(samples, dim=0)
        
        # Visualize the problem
        self.visualize_unconditional_problem(unconditional_samples)
        
        return unconditional_samples
    
    def visualize_unconditional_problem(self, unconditional_samples: torch.Tensor):
        """Visualize why unconditional generation is problematic"""
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        # Original data with classes
        colors = ['red', 'blue', 'green', 'orange']
        class_names = ['Circle', 'Square', 'Triangle', 'Line']
        
        for class_id in range(self.config.num_classes):
            class_mask = train_labels == class_id
            class_data = train_data[class_mask]
            axes[0].scatter(class_data[:, 0].cpu(), class_data[:, 1].cpu(), 
                           c=colors[class_id], label=f'{class_names[class_id]}', 
                           alpha=0.7, s=50)
        
        axes[0].set_title('Original Data\n(Clear Class Structure)')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        axes[0].axis('equal')
        
        # Unconditional samples (mixed)
        axes[1].scatter(unconditional_samples[:, 0].cpu(), unconditional_samples[:, 1].cpu(), 
                       c='gray', alpha=0.6, s=30)
        axes[1].set_title('Unconditional Generation\n(Random Mix)')
        axes[1].grid(True, alpha=0.3)
        axes[1].axis('equal')
        
        # What we want: conditional control
        axes[2].text(0.5, 0.7, 'WHAT WE WANT:', ha='center', va='center', 
                    fontsize=14, fontweight='bold', transform=axes[2].transAxes)
        axes[2].text(0.5, 0.5, '"Generate a Circle"', ha='center', va='center', 
                    fontsize=12, color='red', transform=axes[2].transAxes)
        axes[2].text(0.5, 0.4, '"Generate a Square"', ha='center', va='center', 
                    fontsize=12, color='blue', transform=axes[2].transAxes)
        axes[2].text(0.5, 0.3, '"Generate a Triangle"', ha='center', va='center', 
                    fontsize=12, color='green', transform=axes[2].transAxes)
        axes[2].text(0.5, 0.1, 'CONTROLLED GENERATION!', ha='center', va='center', 
                    fontsize=14, fontweight='bold', color='purple', transform=axes[2].transAxes)
        axes[2].set_title('Desired: Conditional Control')
        axes[2].axis('off')
        
        plt.tight_layout()
        plt.show()
        
        print("The Problem:")
        print("• Unconditional generation gives random samples from the full distribution")
        print("• No control over what specific content is generated")
        print("• Must generate many samples to find desired content")
        print("• Wastes computation and provides poor user experience")
        print("\nThe Solution: Replace p(x) with p(x|y) where y is our condition!")
    
    def analyze_generation_efficiency(self, target_class: int = 0, max_attempts: int = 50):
        """
        Analyze how many unconditional samples we need to get desired class
        """
        print(f"\n=== Efficiency Analysis: Finding Class {target_class} ===\n")
        
        # Simulate trying to find samples of target class through unconditional generation
        attempts = 0
        target_found = 0
        target_samples = []
        
        # Simple classifier to determine which class a sample belongs to
        def classify_sample(sample):
            """Simple distance-based classifier for our synthetic data"""
            # Class centers (approximate)
            centers = torch.tensor([
                [2.0, 2.0],   # Circle
                [-2.0, 2.0],  # Square  
                [-2.0, -2.0], # Triangle
                [2.0, -2.5]   # Line
            ]).to(device)
            
            distances = torch.norm(sample - centers, dim=1)
            return torch.argmin(distances).item()
        
        while attempts < max_attempts and target_found < 10:
            # Generate unconditional sample
            idx = torch.randint(0, len(train_data), (1,))
            base_sample = train_data[idx]
            noise = 0.3 * torch.randn_like(base_sample)
            sample = base_sample + noise
            
            attempts += 1
            
            # Check if it's the target class
            predicted_class = classify_sample(sample.squeeze())
            if predicted_class == target_class:
                target_found += 1
                target_samples.append(sample)
        
        success_rate = target_found / attempts
        print(f"Attempts: {attempts}")
        print(f"Target class samples found: {target_found}")
        print(f"Success rate: {success_rate:.2%}")
        print(f"Average attempts per target sample: {attempts/max(target_found, 1):.1f}")
        
        print(f"\n💡 Insight: Unconditional generation is inefficient!")
        print(f"   With conditional generation, we could generate target samples directly!")
        
        return success_rate

# Simple unconditional model simulation
class SimpleUnconditionalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(3, 64),  # 2D data + 1 time dimension
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 2)
        )
    
    def forward(self, x, t):
        t_embed = t.float().unsqueeze(-1) / config.T
        input_with_time = torch.cat([x, t_embed], dim=-1)
        return self.net(input_with_time)

# Test unconditional limitations
unconditional_model = SimpleUnconditionalModel().to(device)
limitations_demo = UnconditionalLimitations(unconditional_model, config)

# Demonstrate the problems
unconditional_samples = limitations_demo.demonstrate_random_generation_problem(n_samples=80)
efficiency = limitations_demo.analyze_generation_efficiency(target_class=0, max_attempts=30)
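
The efficiency numbers above can be compared against a simple baseline. If each unconditional sample lands in the target class with probability p, the number of draws until the first hit follows a geometric distribution with mean 1/p. The sketch below assumes four balanced classes (as in this dataset) and a perfect classifier; `attempts_until_hit` is a hypothetical helper used only for this illustration.

```python
import random

def attempts_until_hit(p: float, rng: random.Random) -> int:
    """Count draws until the first success when each draw succeeds with probability p."""
    attempts = 1
    while rng.random() >= p:
        attempts += 1
    return attempts

num_classes = 4
p_target = 1.0 / num_classes          # balanced classes assumed
expected_attempts = 1.0 / p_target    # mean of the geometric distribution

rng = random.Random(0)
trials = [attempts_until_hit(p_target, rng) for _ in range(10_000)]
empirical_attempts = sum(trials) / len(trials)

print(f"Expected attempts per target sample: {expected_attempts:.1f}")
print(f"Simulated attempts per target sample: {empirical_attempts:.2f}")
```

With more classes, or with a rare target class, the cost grows as 1/p, which is the practical argument for conditioning directly on the label.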

---

## Part 3: Class-Conditional Diffusion Implementation (25 minutes)

### Task 3.1: Implement Class-Conditional U-Net

**Your Mission**: Build the simplest form of conditional generation using class embeddings.

In [None]:
class ClassConditionalUNet(nn.Module):
    """
    Simple U-Net-style network with class conditioning.
    This demonstrates the basic approach to conditional generation.
    """
    
    def __init__(self, data_dim: int = 2, num_classes: int = 4, embed_dim: int = 64):
        super().__init__()
        self.data_dim = data_dim
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        
        # Class embedding table
        self.class_embedding = nn.Embedding(num_classes, embed_dim)
        
        # Time embedding (simple version)
        self.time_mlp = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
        
        # Main network - simplified U-Net structure
        self.encoder = nn.Sequential(
            nn.Linear(data_dim + embed_dim, 128),  # data + combined embedding
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(128 + embed_dim, 128),  # features + combined embedding
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, data_dim)
        )
    
    def forward(self, x: torch.Tensor, class_labels: Optional[torch.Tensor], t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with class conditioning
        
        Args:
            x: Noisy data [batch_size, data_dim]
            class_labels: Class labels [batch_size] or None for unconditional
            t: Timesteps [batch_size]
            
        Returns:
            Predicted noise [batch_size, data_dim]
        """
        batch_size = x.shape[0]
        
        # Handle timestep
        if t.dim() == 0:
            t = t.repeat(batch_size)
        t_embed = self.time_mlp(t.float().unsqueeze(-1) / config.T)
        
        # Handle class conditioning
        if class_labels is not None:
            class_embed = self.class_embedding(class_labels)
            # Combine time and class embeddings
            combined_embed = t_embed + class_embed
        else:
            # Unconditional case
            combined_embed = t_embed
        
        # Encoder path
        x_with_embed = torch.cat([x, combined_embed], dim=-1)
        encoded = self.encoder(x_with_embed)
        
        # Decoder path
        decoder_input = torch.cat([encoded, combined_embed], dim=-1)
        output = self.decoder(decoder_input)
        
        return output

class ClassConditionalTrainer:
    """
    Training system for class-conditional diffusion models.
    """
    
    def __init__(self, model: ClassConditionalUNet, config: ConditionalConfig):
        self.model = model
        self.config = config
        self.model.train()
        
    def add_noise(self, x_start: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add noise according to the forward process"""
        if noise is None:
            noise = torch.randn_like(x_start)
        
        alpha_cumprod_t = self.config.alphas_cumprod[t]
        if alpha_cumprod_t.dim() == 0:
            alpha_cumprod_t = alpha_cumprod_t.unsqueeze(0)
        if alpha_cumprod_t.shape[0] != x_start.shape[0]:
            alpha_cumprod_t = alpha_cumprod_t.repeat(x_start.shape[0])
        
        alpha_cumprod_t = alpha_cumprod_t.view(-1, 1)
        
        x_noisy = torch.sqrt(alpha_cumprod_t) * x_start + torch.sqrt(1 - alpha_cumprod_t) * noise
        return x_noisy, noise
    
    def training_step(self, x_batch: torch.Tensor, class_batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        TODO: Implement class-conditional training step
        
        Implement the class-conditional diffusion training objective:
        L = E[||ε - ε_θ(x_t, y, t)||²]
        
        Steps:
        1. Sample random timesteps t for each item in batch
        2. Sample noise ε from N(0, I)
        3. Create noisy data x_t using forward process
        4. Predict noise using model with class conditioning
        5. Compute MSE loss between predicted and actual noise
        6. Return loss and metrics
        
        Args:
            x_batch: Clean data batch [batch_size, data_dim]
            class_batch: Class labels [batch_size]
            
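        Returns:
            Dictionary with 'loss' (a scalar tensor that still carries
            gradients, because train_epoch() calls loss.backward() on it)
            and any other metrics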
        """
        # TODO: Your implementation here
        # Hint: Use self.add_noise() and self.model()
        # Remember to sample timesteps uniformly from [0, T-1]
        pass
    
    def train_epoch(self, data: torch.Tensor, labels: torch.Tensor, 
                   batch_size: int = 32, lr: float = 1e-3) -> List[float]:
        """Train for one epoch"""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        losses = []
        
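        # Shuffle each epoch: the dataset is stored class by class, so
        # sequential slices would produce single-class batches
        perm = torch.randperm(len(data), device=data.device)
        n_batches = len(data) // batch_size
        
        for i in range(n_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            
            x_batch = data[batch_idx]
            class_batch = labels[batch_idx]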
            
            optimizer.zero_grad()
            
            try:
                result = self.training_step(x_batch, class_batch)
                if result and 'loss' in result:
                    loss = result['loss']
                    loss.backward()
                    optimizer.step()
                    losses.append(loss.item())
                else:
                    print(f"Batch {i}: Implement training_step() method")
                    break
            except Exception as e:
                print(f"Batch {i}: Error in training - {e}")
                break
        
        return losses

class ClassConditionalSampler:
    """
    Sampling system for class-conditional diffusion models.
    """
    
    def __init__(self, model: ClassConditionalUNet, config: ConditionalConfig):
        self.model = model
        self.config = config
        self.model.eval()
    
    def ddim_step(self, x_t: torch.Tensor, class_labels: torch.Tensor, t: int, s: int) -> torch.Tensor:
        """
        TODO: Implement class-conditional DDIM sampling step
        
        Perform one step of class-conditional DDIM sampling.
        
        Steps:
        1. Predict noise using the model with class conditioning
        2. Predict clean data x_0 from current state and predicted noise:
           x̂_0 = (x_t - √(1-ᾱ_t) * ε̂) / √ᾱ_t
        3. Compute the DDIM update: x_s = √ᾱ_s * x̂_0 + √(1-ᾱ_s) * ε̂
        4. Return the updated state
        
        Args:
            x_t: Current noisy state [batch_size, data_dim]
            class_labels: Target class labels [batch_size]
            t: Current timestep
            s: Target timestep (s < t)
            
        Returns:
            x_s: Updated state [batch_size, data_dim]
        """
        # TODO: Your implementation here
        # Hint: Use self.model() for noise prediction
        # Remember the DDIM formulas from Lab 5
        pass
    
    def sample_class_conditional(self, class_labels: torch.Tensor, num_steps: int = 20) -> torch.Tensor:
        """
        TODO: Implement complete class-conditional sampling
        
        Generate samples conditioned on specific class labels.
        
        Steps:
        1. Create timestep schedule (uniform spacing)
        2. Initialize x_T from pure noise
        3. For each timestep pair (t, s): apply ddim_step
        4. Return final samples
        
        Args:
            class_labels: Desired class labels [batch_size]
            num_steps: Number of sampling steps
            
        Returns:
            Generated samples [batch_size, data_dim]
        """
        # TODO: Your implementation here
        # Hint: Create uniform timestep schedule like in Lab 5
        # Use ddim_step() for each sampling step
        pass
    
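    def create_timestep_schedule(self, num_steps: int) -> List[int]:
        """Create a uniform, strictly decreasing schedule from T-1 down to 0"""
        # Clamp so consecutive timesteps are at least one apart
        num_steps = max(1, min(num_steps, self.config.T - 1))
        # linspace keeps both endpoints, so the final step always lands on t=0
        timesteps = np.linspace(self.config.T - 1, 0, num_steps + 1)
        return [int(round(v)) for v in timesteps]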
    
    def demonstrate_class_conditional_generation(self, samples_per_class: int = 15, num_steps: int = 20):
        """Demonstrate conditional generation for all classes"""
        print("=== Class-Conditional Generation Demo ===\n")
        
        all_samples = []
        all_class_labels = []
        
        for class_id in range(self.config.num_classes):
            print(f"Generating class {class_id} samples...")
            
            # Create class labels for this class
            class_labels = torch.full((samples_per_class,), class_id, dtype=torch.long).to(device)
            
            try:
                # Generate samples
                samples = self.sample_class_conditional(class_labels, num_steps)
                
                if samples is not None:
                    all_samples.append(samples)
                    all_class_labels.extend([class_id] * samples_per_class)
                    print(f"  ✓ Generated {samples_per_class} samples for class {class_id}")
                else:
                    print(f"  ❌ Implement sample_class_conditional() method")
                    return None, None
            except Exception as e:
                print(f"  ❌ Error generating class {class_id}: {e}")
                return None, None
        
        if all_samples:
            generated_samples = torch.cat(all_samples, dim=0)
            generated_labels = torch.tensor(all_class_labels, device=generated_samples.device)
            
            self.visualize_conditional_results(generated_samples, generated_labels)
            return generated_samples, generated_labels
        
        return None, None
    
    def visualize_conditional_results(self, samples: torch.Tensor, labels: torch.Tensor):
        """Visualize class-conditional generation results"""
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        colors = ['red', 'blue', 'green', 'orange']
        class_names = ['Circle', 'Square', 'Triangle', 'Line']
        
        # Original training data
        for class_id in range(self.config.num_classes):
            class_mask = train_labels == class_id
            class_data = train_data[class_mask]
            axes[0].scatter(class_data[:, 0].cpu(), class_data[:, 1].cpu(), 
                           c=colors[class_id], label=f'{class_names[class_id]}', 
                           alpha=0.7, s=50)
        
        axes[0].set_title('Original Training Data')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        axes[0].axis('equal')
        
        # Generated conditional samples
        for class_id in range(self.config.num_classes):
            class_mask = labels == class_id
            if class_mask.any():
                class_samples = samples[class_mask]
                axes[1].scatter(class_samples[:, 0].cpu(), class_samples[:, 1].cpu(), 
                               c=colors[class_id], label=f'{class_names[class_id]}', 
                               alpha=0.7, s=50)
        
        axes[1].set_title('Generated Conditional Samples')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        axes[1].axis('equal')
        
        # Individual class generations
        axes[2].text(0.5, 0.8, 'Class-Conditional Success!', ha='center', va='center', 
                    fontsize=14, fontweight='bold', transform=axes[2].transAxes)
        axes[2].text(0.5, 0.6, '✓ Generate specific classes on demand', ha='center', va='center', 
                    fontsize=11, transform=axes[2].transAxes)
        axes[2].text(0.5, 0.5, '✓ No random sampling required', ha='center', va='center', 
                    fontsize=11, transform=axes[2].transAxes)
        axes[2].text(0.5, 0.4, '✓ Direct control over generation', ha='center', va='center', 
                    fontsize=11, transform=axes[2].transAxes)
        axes[2].text(0.5, 0.2, 'Simple & Effective!', ha='center', va='center', 
                    fontsize=12, fontweight='bold', color='green', transform=axes[2].transAxes)
        axes[2].axis('off')
        
        plt.tight_layout()
        plt.show()

# Create and test class-conditional model
print("Creating class-conditional model...")
class_conditional_model = ClassConditionalUNet(data_dim=2, num_classes=config.num_classes).to(device)

# Test training system
trainer = ClassConditionalTrainer(class_conditional_model, config)

# Quick training demonstration (just a few steps for this lab)
print("\nQuick training demonstration...")
losses = trainer.train_epoch(train_data, train_labels, batch_size=16, lr=1e-3)
if losses:
    print(f"Training losses: {losses[:5]}...")  # Show first few losses
    
    # Test sampling
    sampler = ClassConditionalSampler(class_conditional_model, config)
    generated_samples, generated_labels = sampler.demonstrate_class_conditional_generation(
        samples_per_class=12, num_steps=20)
else:
    print("❌ Implement training_step() to proceed with sampling")
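
The training objective described in `training_step()` can be sketched in isolation as follows. This is one possible shape for the step, not the reference solution: `TinyEps` is a hypothetical stand-in for `ClassConditionalUNet` with the same call signature, and the schedule is rebuilt locally so the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T = 50
betas = torch.linspace(1e-4, 2e-2, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

class TinyEps(nn.Module):
    """Hypothetical stand-in noise predictor: forward(x, class_labels, t)."""
    def __init__(self, num_classes: int = 4, embed_dim: int = 16):
        super().__init__()
        self.class_embedding = nn.Embedding(num_classes, embed_dim)
        self.net = nn.Sequential(
            nn.Linear(2 + embed_dim + 1, 64), nn.ReLU(), nn.Linear(64, 2))

    def forward(self, x, class_labels, t):
        t_feat = t.float().unsqueeze(-1) / T   # scalar time feature in [0, 1)
        h = torch.cat([x, self.class_embedding(class_labels), t_feat], dim=-1)
        return self.net(h)

def training_step_sketch(model, x_batch, class_batch):
    """One noise-prediction step: L = E[||eps - eps_theta(x_t, y, t)||^2]."""
    batch_size = x_batch.shape[0]
    # 1. One uniformly sampled timestep per item, in [0, T-1]
    t = torch.randint(0, T, (batch_size,), device=x_batch.device)
    # 2. Gaussian noise
    noise = torch.randn_like(x_batch)
    # 3. Forward process: x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps
    a_bar = alphas_cumprod.to(x_batch.device)[t].view(-1, 1)
    x_t = torch.sqrt(a_bar) * x_batch + torch.sqrt(1.0 - a_bar) * noise
    # 4-5. Predict the noise with class conditioning and compare to the truth
    loss = F.mse_loss(model(x_t, class_batch, t), noise)
    # 6. Keep the loss as a tensor so the caller can backpropagate through it
    return {"loss": loss}

model = TinyEps()
x = torch.randn(8, 2)
y = torch.randint(0, 4, (8,))
out = training_step_sketch(model, x, y)
out["loss"].backward()
print(f"loss = {out['loss'].item():.4f}")
```

Returning the loss as a tensor rather than a Python float matters here, because `train_epoch()` calls `.backward()` on whatever is stored under `'loss'`.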

### Task 3.2: Analyze Class-Conditional Performance

In [None]:
class ClassConditionalAnalyzer:
    """
    Analyze the performance and characteristics of class-conditional generation.
    """
    
    def __init__(self, sampler: ClassConditionalSampler):
        self.sampler = sampler
        
    def analyze_class_separation(self, samples_per_class: int = 20):
        """Analyze how well the model separates different classes"""
        print("=== Class Separation Analysis ===\n")
        
        # Generate samples for each class
        class_samples = {}
        
        for class_id in range(self.sampler.config.num_classes):
            class_labels = torch.full((samples_per_class,), class_id, dtype=torch.long).to(device)
            
            try:
                samples = self.sampler.sample_class_conditional(class_labels, num_steps=20)
                if samples is not None:
                    class_samples[class_id] = samples
                    print(f"✓ Generated class {class_id} samples")
                else:
                    print(f"❌ Implement sampling methods first")
                    return
            except Exception as e:
                print(f"❌ Error generating class {class_id}: {e}")
                return
        
        # Compute class centers and inter-class distances
        self.compute_class_statistics(class_samples)
        
    def compute_class_statistics(self, class_samples: Dict[int, torch.Tensor]):
        """Compute statistics about class separation"""
        class_centers = {}
        class_spreads = {}
        
        print("Class Statistics:")
        for class_id, samples in class_samples.items():
            center = samples.mean(dim=0)
            spread = samples.std(dim=0).mean()  # Average std across dimensions
            
            class_centers[class_id] = center
            class_spreads[class_id] = spread
            
            print(f"  Class {class_id}: Center=({center[0]:.2f}, {center[1]:.2f}), Spread={spread:.2f}")
        
        # Compute inter-class distances
        print("\nInter-class distances:")
        for i in range(config.num_classes):
            for j in range(i+1, config.num_classes):
                distance = torch.norm(class_centers[i] - class_centers[j]).item()
                print(f"  Class {i} ↔ Class {j}: {distance:.2f}")
        
        return class_centers, class_spreads
    
    def test_class_consistency(self, num_runs: int = 3):
        """Test if the same class produces consistent samples across runs"""
        print("=== Class Consistency Test ===\n")
        
        target_class = 0  # Test with class 0
        samples_per_run = 10
        all_runs = []
        
        for run in range(num_runs):
            class_labels = torch.full((samples_per_run,), target_class, dtype=torch.long).to(device)
            
            try:
                samples = self.sampler.sample_class_conditional(class_labels, num_steps=20)
                if samples is not None:
                    all_runs.append(samples)
                    print(f"Run {run+1}: Generated {len(samples)} samples")
                else:
                    print(f"❌ Implement sampling methods first")
                    return
            except Exception as e:
                print(f"❌ Error in run {run+1}: {e}")
                return
        
        # Analyze consistency across runs
        run_centers = [samples.mean(dim=0) for samples in all_runs]
        
        print(f"\nConsistency analysis for class {target_class}:")
        for i, center in enumerate(run_centers):
            print(f"  Run {i+1} center: ({center[0]:.2f}, {center[1]:.2f})")
        
        # Compute variance across runs
        center_variance = torch.stack(run_centers).var(dim=0).mean().item()
        print(f"  Center variance across runs: {center_variance:.4f}")
        
        if center_variance < 0.1:
            print("✓ Good consistency - similar centers across runs")
        else:
            print("⚠️  High variance - class generation may be unstable")

# Test class-conditional analyzer (uncomment after implementing sampling methods)
# analyzer = ClassConditionalAnalyzer(sampler)
# analyzer.analyze_class_separation(samples_per_class=15)
# analyzer.test_class_consistency(num_runs=3)
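
The analyzer depends on a working `ddim_step()`. A minimal sketch of the deterministic DDIM update (η = 0) is shown below, under the assumption that the model predicts noise and takes `(x, class_labels, t)` as inputs. The function name `ddim_step_sketch` and the oracle predictor are hypothetical; the oracle returns the true noise so the update can be checked against the closed-form forward process.

```python
import torch

T = 50
betas = torch.linspace(1e-4, 2e-2, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

@torch.no_grad()
def ddim_step_sketch(eps_model, x_t, class_labels, t: int, s: int):
    """Deterministic DDIM update from timestep t to an earlier timestep s < t."""
    t_batch = torch.full((x_t.shape[0],), t, dtype=torch.long, device=x_t.device)
    eps_hat = eps_model(x_t, class_labels, t_batch)
    a_t, a_s = alphas_cumprod[t], alphas_cumprod[s]
    # Estimate the clean sample implied by the predicted noise
    x0_hat = (x_t - torch.sqrt(1.0 - a_t) * eps_hat) / torch.sqrt(a_t)
    # Re-noise the estimate to the level of timestep s, reusing the same noise
    return torch.sqrt(a_s) * x0_hat + torch.sqrt(1.0 - a_s) * eps_hat

torch.manual_seed(0)
x0 = torch.randn(5, 2)
eps = torch.randn(5, 2)
labels = torch.zeros(5, dtype=torch.long)
t, s = 40, 20

x_t = torch.sqrt(alphas_cumprod[t]) * x0 + torch.sqrt(1.0 - alphas_cumprod[t]) * eps
oracle = lambda x, y, tt: eps   # hypothetical perfect noise predictor

x_s = ddim_step_sketch(oracle, x_t, labels, t, s)
expected = torch.sqrt(alphas_cumprod[s]) * x0 + torch.sqrt(1.0 - alphas_cumprod[s]) * eps
print("matches closed form:", torch.allclose(x_s, expected, atol=1e-5))
```

With a perfect noise predictor the step lands exactly on the forward-process sample at timestep `s`; a trained model only approximates this, which is why fewer sampling steps usually cost some quality.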

---

## Part 4: Classifier Guidance Implementation (25 minutes)

### Task 4.1: Build Noise-Aware Classifier

**Your Mission**: Implement a classifier that works on noisy data for classifier guidance.
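
Classifier guidance steers sampling with the gradient of log p(y | x_t, t) with respect to the noisy input, which is why the classifier has to accept noisy data and a timestep. The sketch below shows one way to obtain that gradient with autograd. `TinyClassifier` is a hypothetical linear stand-in that ignores `t`, chosen so the result can be checked against the analytic gradient; the real `NoiseAwareClassifier` would be called the same way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyClassifier(nn.Module):
    """Hypothetical stand-in: linear logits, same (x, t) signature, t unused."""
    def __init__(self, data_dim: int = 2, num_classes: int = 4):
        super().__init__()
        self.linear = nn.Linear(data_dim, num_classes)

    def forward(self, x, t):
        return self.linear(x)

def classifier_log_prob_grad(classifier, x_t, t, target_labels):
    """Return d/dx_t of log p(y = target | x_t, t) for each item in the batch."""
    x_in = x_t.detach().requires_grad_(True)       # fresh leaf so we can differentiate
    log_probs = F.log_softmax(classifier(x_in, t), dim=-1)
    # Pick the log-probability of each item's target class, then sum over the batch;
    # items are independent, so the gradient of the sum is the per-item gradient
    selected = log_probs.gather(1, target_labels.unsqueeze(1)).sum()
    return torch.autograd.grad(selected, x_in)[0]

torch.manual_seed(0)
clf = TinyClassifier()
x_t = torch.randn(6, 2)
t = torch.full((6,), 10, dtype=torch.long)
y = torch.randint(0, 4, (6,))

grad = classifier_log_prob_grad(clf, x_t, t, y)
print("gradient shape:", tuple(grad.shape))
```

During sampling this gradient is scaled by a guidance strength and used to shift the update toward the target class; larger scales typically trade diversity for class fidelity.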

In [None]:
class NoiseAwareClassifier(nn.Module):
    """
    Classifier that can work on noisy data at different timesteps.
    This is essential for classifier guidance during sampling.
    """
    
    def __init__(self, data_dim: int = 2, num_classes: int = 4, embed_dim: int = 64):
        super().__init__()
        self.data_dim = data_dim
        self.num_classes = num_classes
        
        # Time embedding for noise level awareness
        self.time_mlp = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
        
        # Main classifier network
        self.classifier = nn.Sequential(
            nn.Linear(data_dim + embed_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )
    
    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Classify noisy data at timestep t
        
        Args:
            x: Noisy data [batch_size, data_dim]
            t: Timesteps [batch_size]
            
        Returns:
            Class logits [batch_size, num_classes]
        """
        batch_size = x.shape[0]
        
        # Handle timestep
        if t.dim() == 0:
            t = t.repeat(batch_size)
        t_embed = self.time_mlp(t.float().unsqueeze(-1) / config.T)
        
        # Combine input with time embedding
        x_with_time = torch.cat([x, t_embed], dim=-1)
        logits = self.classifier(x_with_time)
        
        return logits

class ClassifierTrainer:
    """
    Training system for noise-aware classifier.
    """
    
    def __init__(self, classifier: NoiseAwareClassifier, config: ConditionalConfig):
        self.classifier = classifier
        self.config = config
        
    def add_noise(self, x_start: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Add noise for classifier training on noisy data"""
        noise = torch.randn_like(x_start)
        alpha_cumprod_t = self.config.alphas_cumprod[t]
        
        if alpha_cumprod_t.dim() == 0:
            alpha_cumprod_t = alpha_cumprod_t.unsqueeze(0)
        if alpha_cumprod_t.shape[0] != x_start.shape[0]:
            alpha_cumprod_t = alpha_cumprod_t.repeat(x_start.shape[0])
        
        alpha_cumprod_t = alpha_cumprod_t.view(-1, 1)
        
        x_noisy = torch.sqrt(alpha_cumprod_t) * x_start + torch.sqrt(1 - alpha_cumprod_t) * noise
        return x_noisy
    
    def train_classifier(self, data: torch.Tensor, labels: torch.Tensor, 
                        epochs: int = 5, batch_size: int = 32, lr: float = 1e-3):
        """Train the noise-aware classifier"""
        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        
        self.classifier.train()
        
        print("Training noise-aware classifier...")
        
        for epoch in range(epochs):
            epoch_losses = []
            n_batches = len(data) // batch_size
            
            for i in range(n_batches):
                start_idx = i * batch_size
                end_idx = min((i + 1) * batch_size, len(data))
                
                x_batch = data[start_idx:end_idx]
                y_batch = labels[start_idx:end_idx]
                
                # Sample random timesteps
                t_batch = torch.randint(0, self.config.T, (len(x_batch),)).to(device)
                
                # Add noise according to timestep
                x_noisy = self.add_noise(x_batch, t_batch)
                
                # Classify noisy data
                optimizer.zero_grad()
                logits = self.classifier(x_noisy, t_batch)
                loss = criterion(logits, y_batch)
                loss.backward()
                optimizer.step()
                
                epoch_losses.append(loss.item())
            
            avg_loss = np.mean(epoch_losses)
            print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")
        
        self.classifier.eval()
        print("✓ Classifier training completed")

class ClassifierGuidanceSampler:
    """
    Sampling system using classifier guidance.
    """
    
    def __init__(self, unconditional_model: nn.Module, classifier: NoiseAwareClassifier, 
                 config: ConditionalConfig):
        self.unconditional_model = unconditional_model
        self.classifier = classifier
        self.config = config
        
        self.unconditional_model.eval()
        self.classifier.eval()
    
    def compute_classifier_gradient(self, x_t: torch.Tensor, target_class: int, t: int) -> torch.Tensor:
        """
        TODO: Implement classifier gradient computation
        
        Compute ∇_x log p(y|x_t) for classifier guidance.
        
        Steps:
        1. Enable gradients for x_t
        2. Get classifier logits for x_t at timestep t
        3. Compute log probability for target class
        4. Compute gradient with respect to x_t
        5. Return the gradient
        
        Args:
            x_t: Current noisy state [batch_size, data_dim]
            target_class: Target class for guidance
            t: Current timestep
            
        Returns:
            Classifier gradient [batch_size, data_dim]
        """
        # TODO: Your implementation here
        # Hint: Use x_t.requires_grad_(True) and torch.autograd.grad()
        # Remember to handle the log probability correctly
        pass
    
    def classifier_guided_step(self, x_t: torch.Tensor, target_class: int, t: int, s: int, 
                             guidance_scale: float = 1.0) -> torch.Tensor:
        """
        TODO: Implement classifier-guided sampling step
        
        Apply classifier guidance to modify the noise prediction:
        ε̃ = ε_θ(x_t, t) - ω√(1-ᾱ_t) ∇_x log p(y|x_t)
        
        Steps:
        1. Get unconditional noise prediction
        2. Compute classifier gradient
        3. Apply guidance: modify noise prediction using gradient
        4. Use modified noise for DDIM update
        5. Return updated state
        
        Args:
            x_t: Current noisy state
            target_class: Target class for guidance
            t: Current timestep
            s: Target timestep (s < t)
            guidance_scale: Strength of classifier guidance (ω)
            
        Returns:
            x_s: Updated state with classifier guidance
        """
        # TODO: Your implementation here
        # Hint: Combine unconditional prediction with classifier gradient
        # Use DDIM update formula from previous labs
        pass
    
    def sample_with_classifier_guidance(self, target_class: int, num_samples: int = 10, 
                                      num_steps: int = 20, guidance_scale: float = 2.0) -> torch.Tensor:
        """
        TODO: Implement complete classifier-guided sampling
        
        Generate samples using classifier guidance.
        
        Steps:
        1. Create timestep schedule
        2. Initialize x_T from pure noise
        3. For each timestep pair: apply classifier_guided_step
        4. Return final samples
        
        Args:
            target_class: Class to generate
            num_samples: Number of samples to generate
            num_steps: Number of sampling steps
            guidance_scale: Strength of guidance
            
        Returns:
            Generated samples [num_samples, data_dim]
        """
        # TODO: Your implementation here
        # Hint: Similar to class-conditional sampling but use classifier_guided_step
        pass
    
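    def create_timestep_schedule(self, num_steps: int) -> List[int]:
        """Create a uniform, strictly decreasing schedule from T-1 down to 0"""
        # Clamp so consecutive entries stay distinct after rounding
        num_steps = max(1, min(num_steps, self.config.T - 1))
        # num_steps + 1 evenly spaced points; the last entry is always 0
        timesteps = np.linspace(self.config.T - 1, 0, num_steps + 1).round().astype(int)
        return timesteps.tolist()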
    
    def demonstrate_classifier_guidance(self, guidance_scales: List[float] = [0.0, 1.0, 2.0, 5.0]):
        """Demonstrate classifier guidance with different scales"""
        print("=== Classifier Guidance Demonstration ===\n")
        
        target_class = 0  # Generate circles
        samples_per_scale = 15
        
        results = {}
        
        for scale in guidance_scales:
            print(f"Testing guidance scale ω = {scale}")
            
            try:
                samples = self.sample_with_classifier_guidance(
                    target_class=target_class,
                    num_samples=samples_per_scale,
                    num_steps=20,
                    guidance_scale=scale
                )
                
                if samples is not None:
                    results[scale] = samples
                    print(f"  ✓ Generated {len(samples)} samples")
                else:
                    print(f"  ❌ Implement classifier guidance methods first")
                    return
            except Exception as e:
                print(f"  ❌ Error with scale {scale}: {e}")
                return
        
        if results:
            self.visualize_guidance_effects(results, target_class)
    
    def visualize_guidance_effects(self, results: Dict[float, torch.Tensor], target_class: int):
        """Visualize the effect of different guidance scales"""
        n_scales = len(results)
        fig, axes = plt.subplots(1, n_scales, figsize=(4*n_scales, 4))
        
        if n_scales == 1:
            axes = [axes]
        
        colors = ['red', 'blue', 'green', 'orange']
        class_names = ['Circle', 'Square', 'Triangle', 'Line']
        
        for i, (scale, samples) in enumerate(results.items()):
            # Plot generated samples
            axes[i].scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), 
                           c=colors[target_class], alpha=0.7, s=50, 
                           label=f'Generated {class_names[target_class]}')
            
            # Plot reference data for comparison
            class_mask = train_labels == target_class
            ref_data = train_data[class_mask]
            axes[i].scatter(ref_data[:, 0].cpu(), ref_data[:, 1].cpu(), 
                           c='gray', alpha=0.3, s=20, label='Reference')
            
            axes[i].set_title(f'Guidance Scale ω = {scale}')
            axes[i].legend()
            axes[i].grid(True, alpha=0.3)
            axes[i].axis('equal')
        
        plt.suptitle(f'Classifier Guidance Effect on {class_names[target_class]} Generation')
        plt.tight_layout()
        plt.show()

# Create unconditional model for classifier guidance
class SimpleUnconditionalDiffusion(nn.Module):
    """Simple unconditional model for classifier guidance demo"""
    def __init__(self, data_dim: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(data_dim + 64, 128),  # data + time embedding
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, data_dim)
        )
        
        self.time_mlp = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64)
        )
    
    def forward(self, x, t):
        if t.dim() == 0:
            t = t.repeat(x.shape[0])
        t_embed = self.time_mlp(t.float().unsqueeze(-1) / config.T)
        return self.net(torch.cat([x, t_embed], dim=-1))

# Create and train noise-aware classifier
print("Creating noise-aware classifier...")
noise_aware_classifier = NoiseAwareClassifier(data_dim=2, num_classes=config.num_classes).to(device)

classifier_trainer = ClassifierTrainer(noise_aware_classifier, config)
classifier_trainer.train_classifier(train_data, train_labels, epochs=3, batch_size=16)

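# Create unconditional model and classifier guidance sampler
# Note: unconditional_model is randomly initialized and is not trained in this cell.
# Train it on the noise-prediction objective before expecting guided samples to match the data.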
unconditional_model = SimpleUnconditionalDiffusion().to(device)
guidance_sampler = ClassifierGuidanceSampler(unconditional_model, noise_aware_classifier, config)

# Demonstrate classifier guidance (uncomment after implementing TODOs)
# guidance_sampler.demonstrate_classifier_guidance([0.0, 1.0, 2.0, 4.0])
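
The five steps listed in the `compute_classifier_gradient` docstring, together with the guided-noise formula from `classifier_guided_step`, can be sketched on a toy problem. This is an illustration under stated assumptions, not the reference solution: `toy_classifier` is a hypothetical untrained stand-in that ignores the timestep, and `log_prob_gradient`, `eps`, `abar_t`, and `omega` are placeholder names and values chosen for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for a trained noise-aware classifier (timestep input omitted for brevity)
toy_classifier = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 4))

def log_prob_gradient(x_t: torch.Tensor, target_class: int) -> torch.Tensor:
    """Return the gradient of log p(y = target_class | x) with respect to each row of x_t."""
    # Fresh leaf tensor so the caller's x_t is left untouched
    x_in = x_t.detach().clone().requires_grad_(True)
    log_probs = F.log_softmax(toy_classifier(x_in), dim=-1)
    # Rows are independent, so summing over the batch yields per-sample gradients
    selected = log_probs[:, target_class].sum()
    (grad,) = torch.autograd.grad(selected, x_in)
    return grad

x_t = torch.randn(5, 2)
grad = log_prob_gradient(x_t, target_class=0)

# Guided noise: eps_tilde = eps - omega * sqrt(1 - abar_t) * grad
eps = torch.randn(5, 2)       # placeholder for the unconditional noise prediction
abar_t = torch.tensor(0.7)    # placeholder value of alpha-bar at the current timestep
omega = 2.0
eps_tilde = eps - omega * torch.sqrt(1.0 - abar_t) * grad

print(grad.shape, eps_tilde.shape)
```

Using `log_softmax` rather than `log(softmax(...))` keeps the log-probability numerically stable when the classifier is very confident.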

### Task 4.2: Analyze Classifier Guidance Performance

In [None]:
class ClassifierGuidanceAnalyzer:
    """
    Analyze the performance characteristics of classifier guidance.
    """
    
    def __init__(self, guidance_sampler: ClassifierGuidanceSampler):
        self.guidance_sampler = guidance_sampler
        
    def analyze_guidance_strength(self, target_class: int = 0):
        """Analyze how guidance strength affects generation quality and diversity"""
        print("=== Guidance Strength Analysis ===\n")
        
        guidance_scales = [0.0, 0.5, 1.0, 2.0, 5.0, 10.0]
        samples_per_scale = 20
        
        results = {}
        
        for scale in guidance_scales:
            print(f"Testing guidance scale {scale}...")
            
            try:
                samples = self.guidance_sampler.sample_with_classifier_guidance(
                    target_class=target_class,
                    num_samples=samples_per_scale,
                    num_steps=20,
                    guidance_scale=scale
                )
                
                if samples is not None:
                    # Compute quality metrics
                    center = samples.mean(dim=0)
                    spread = samples.std(dim=0).mean().item()
                    
                    results[scale] = {
                        'samples': samples,
                        'center': center,
                        'spread': spread
                    }
                    
                    print(f"  Center: ({center[0]:.2f}, {center[1]:.2f}), Spread: {spread:.3f}")
                else:
                    print("  ❌ Implement guidance methods first")
                    return
            except Exception as e:
                print(f"  Error with scale {scale}: {e}")
                return
        
        if results:
            self.plot_guidance_analysis(results, target_class)
        
        return results
    
    def plot_guidance_analysis(self, results: Dict, target_class: int):
        """Plot guidance strength analysis results"""
        scales = list(results.keys())
        spreads = [results[s]['spread'] for s in scales]
        
        # Plot spread vs guidance scale
        plt.figure(figsize=(12, 5))
        
        plt.subplot(1, 2, 1)
        plt.plot(scales, spreads, 'bo-', linewidth=2, markersize=8)
        plt.xlabel('Guidance Scale ω')
        plt.ylabel('Sample Spread (Diversity)')
        plt.title('Diversity vs Guidance Strength')
        plt.grid(True, alpha=0.3)
        
        # Plot sample evolution
        plt.subplot(1, 2, 2)
        colors = plt.cm.viridis(np.linspace(0, 1, len(scales)))
        
        for i, (scale, result) in enumerate(results.items()):
            samples = result['samples']
            plt.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), 
                       c=[colors[i]], alpha=0.6, s=30, label=f'ω={scale}')
        
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.title('Sample Distribution vs Guidance')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.axis('equal')
        
        plt.tight_layout()
        plt.show()
        
        print("\nKey Insights:")
        print("• Higher guidance → Lower diversity (more focused)")
        print("• ω=0: Unconditional generation (high diversity)")
        print("• ω>5: Risk of mode collapse or artifacts")
        print("• Optimal ω depends on application needs")
    
    def compare_with_class_conditional(self, target_class: int = 0):
        """Compare classifier guidance with class-conditional generation"""
        print("=== Classifier vs Class-Conditional Comparison ===\n")
        
        samples_per_method = 25
        
        try:
            # Classifier guidance samples
            guidance_samples = self.guidance_sampler.sample_with_classifier_guidance(
                target_class=target_class,
                num_samples=samples_per_method,
                num_steps=20,
                guidance_scale=2.0
            )
            
            if guidance_samples is None:
                print("❌ Implement classifier guidance methods first")
                return
                
        except Exception as e:
            print(f"❌ Error with classifier guidance: {e}")
            return
        
        # Note: This would compare with class-conditional if implemented
        print("Classifier guidance samples generated successfully")
        print("(To compare with class-conditional, implement Part 3 methods)")
        
        return guidance_samples

# Test classifier guidance analyzer (uncomment after implementing guidance methods)
# guidance_analyzer = ClassifierGuidanceAnalyzer(guidance_sampler)
# guidance_results = guidance_analyzer.analyze_guidance_strength(target_class=0)
# guidance_analyzer.compare_with_class_conditional(target_class=0)
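
Both `classifier_guided_step` and the CFG step in Part 5 finish with a DDIM update driven by a (guided) noise prediction. The sketch below shows the standard deterministic DDIM update (η = 0); the notation in the earlier labs may differ, and `ddim_step` is a name chosen for this example.

```python
import torch

def ddim_step(x_t, eps, abar_t, abar_s):
    """Deterministic DDIM update (eta = 0) from noise level abar_t to abar_s."""
    # Estimate the clean sample implied by the noise prediction
    x0_pred = (x_t - torch.sqrt(1.0 - abar_t) * eps) / torch.sqrt(abar_t)
    # Re-noise the estimate to the target level, reusing the same noise direction
    return torch.sqrt(abar_s) * x0_pred + torch.sqrt(1.0 - abar_s) * eps

# Sanity check: with the exact noise, the step lands on the sample that
# the forward process would produce at level abar_s with that same noise
torch.manual_seed(0)
x0 = torch.randn(4, 2)
noise = torch.randn(4, 2)
abar_t, abar_s = torch.tensor(0.5), torch.tensor(0.9)

x_t = torch.sqrt(abar_t) * x0 + torch.sqrt(1.0 - abar_t) * noise
x_s = ddim_step(x_t, noise, abar_t, abar_s)
expected = torch.sqrt(abar_s) * x0 + torch.sqrt(1.0 - abar_s) * noise

print(torch.allclose(x_s, expected, atol=1e-5))
```

In a guided sampler, `eps` would be the modified prediction ε̃ and `abar_t`, `abar_s` would be read from `config.alphas_cumprod` at timesteps `t` and `s`.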

---

## Part 5: Classifier-Free Guidance Implementation (20 minutes)

### Task 5.1: Implement CFG Training and Sampling

**Your Mission**: Build the modern standard for conditional generation - classifier-free guidance.
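
The model in the next cell reserves one extra embedding index as a null token, so the condition can be dropped per sample by swapping labels rather than by passing `None`. Here is a minimal sketch of that conditioning dropout, assuming the null token is the index `num_classes` and the drop probability is 0.1 as in `ConditionalConfig.dropout_prob`; the variable names are illustrative.

```python
import torch

torch.manual_seed(0)

num_classes = 4
null_token = num_classes   # last embedding index, reserved for "no condition"
dropout_prob = 0.1         # same value as ConditionalConfig.dropout_prob

class_batch = torch.randint(0, num_classes, (1000,))

# Bernoulli mask: True where the condition is dropped for that sample
drop_mask = torch.rand(class_batch.shape[0]) < dropout_prob

# Dropped samples get the null token; the rest keep their original label
dropped_labels = torch.where(drop_mask, torch.full_like(class_batch, null_token), class_batch)

print(f"fraction dropped: {drop_mask.float().mean().item():.3f}")
```

Dropping per sample means each batch mixes conditional and unconditional examples, so a single forward pass trains both behaviors.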

In [None]:
class ClassifierFreeUNet(nn.Module):
    """
    U-Net that supports both conditional and unconditional generation.
    This is the key to classifier-free guidance.
    """
    
    def __init__(self, data_dim: int = 2, num_classes: int = 4, embed_dim: int = 64):
        super().__init__()
        self.data_dim = data_dim
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        
        # Class embedding table (with null token for unconditional)
        self.class_embedding = nn.Embedding(num_classes + 1, embed_dim)  # +1 for null token
        self.null_token = num_classes  # Use last index as null token
        
        # Time embedding
        self.time_mlp = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim)
        )
        
        # Main network
        self.encoder = nn.Sequential(
            nn.Linear(data_dim + embed_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(128 + embed_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, data_dim)
        )
    
    def forward(self, x: torch.Tensor, class_labels: Optional[torch.Tensor], t: torch.Tensor) -> torch.Tensor:
        """
        Forward pass supporting both conditional and unconditional generation
        
        Args:
            x: Noisy data [batch_size, data_dim]
            class_labels: Class labels [batch_size] or None for unconditional
            t: Timesteps [batch_size]
        """
        batch_size = x.shape[0]
        
        # Handle timestep
        if t.dim() == 0:
            t = t.repeat(batch_size)
        t_embed = self.time_mlp(t.float().unsqueeze(-1) / config.T)
        
        # Handle class conditioning
        if class_labels is not None:
            class_embed = self.class_embedding(class_labels)
        else:
            # Use null token for unconditional
            null_labels = torch.full((batch_size,), self.null_token, dtype=torch.long, device=x.device)
            class_embed = self.class_embedding(null_labels)
        
        # Combine embeddings
        combined_embed = t_embed + class_embed
        
        # Forward pass
        x_with_embed = torch.cat([x, combined_embed], dim=-1)
        encoded = self.encoder(x_with_embed)
        decoder_input = torch.cat([encoded, combined_embed], dim=-1)
        output = self.decoder(decoder_input)
        
        return output

class CFGTrainer:
    """
    Training system for classifier-free guidance.
    """
    
    def __init__(self, model: ClassifierFreeUNet, config: ConditionalConfig):
        self.model = model
        self.config = config
        self.model.train()
        
    def add_noise(self, x_start: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add noise according to forward process"""
        noise = torch.randn_like(x_start)
        alpha_cumprod_t = self.config.alphas_cumprod[t]
        
        if alpha_cumprod_t.dim() == 0:
            alpha_cumprod_t = alpha_cumprod_t.unsqueeze(0)
        if alpha_cumprod_t.shape[0] != x_start.shape[0]:
            alpha_cumprod_t = alpha_cumprod_t.repeat(x_start.shape[0])
        
        alpha_cumprod_t = alpha_cumprod_t.view(-1, 1)
        
        x_noisy = torch.sqrt(alpha_cumprod_t) * x_start + torch.sqrt(1 - alpha_cumprod_t) * noise
        return x_noisy, noise
    
    def cfg_training_step(self, x_batch: torch.Tensor, class_batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        TODO: Implement classifier-free guidance training step
        
        The key insight: train one model to do both conditional and unconditional generation
        by randomly dropping the class condition during training.
        
        Steps:
        1. Sample random timesteps for each item in batch
        2. Sample noise and create noisy data
        3. Apply conditioning dropout: replace a random subset of class labels with the
           null token (self.model.null_token)
        4. Predict noise using model (handles both conditional and unconditional)
        5. Compute MSE loss between predicted and actual noise
        6. Return loss and metrics
        
        Args:
            x_batch: Clean data [batch_size, data_dim]  
            class_batch: Class labels [batch_size]
            
        Returns:
            Dictionary with 'loss' (a scalar tensor that keeps its graph, since
            train_cfg_epoch calls loss.backward() on it) and any extra metrics
        """
        # TODO: Your implementation here
        # Hint: Drop each sample's condition with probability config.dropout_prob
        # Per-sample dropping: swap the label for self.model.null_token
        # Passing None as class_labels drops the condition for the whole batch instead
        pass
    
    def train_cfg_epoch(self, data: torch.Tensor, labels: torch.Tensor, 
                       batch_size: int = 32, lr: float = 1e-3) -> List[float]:
        """Train one epoch with CFG objective"""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        losses = []
        
        n_batches = len(data) // batch_size
        
        for i in range(n_batches):
            start_idx = i * batch_size
            end_idx = min((i + 1) * batch_size, len(data))
            
            x_batch = data[start_idx:end_idx]
            class_batch = labels[start_idx:end_idx]
            
            optimizer.zero_grad()
            
            try:
                result = self.cfg_training_step(x_batch, class_batch)
                if result and 'loss' in result:
                    loss = result['loss']
                    loss.backward()
                    optimizer.step()
                    losses.append(loss.item())
                else:
                    print(f"Batch {i}: Implement cfg_training_step() method")
                    break
            except Exception as e:
                print(f"Batch {i}: Error in CFG training - {e}")
                break
        
        return losses

class CFGSampler:
    """
    Sampling system using classifier-free guidance.
    """
    
    def __init__(self, model: ClassifierFreeUNet, config: ConditionalConfig):
        self.model = model
        self.config = config
        self.model.eval()
    
    def cfg_step(self, x_t: torch.Tensor, class_labels: torch.Tensor, t: int, s: int, 
                 guidance_scale: float = 1.0) -> torch.Tensor:
        """
        TODO: Implement classifier-free guidance sampling step
        
        The CFG magic: ε̃ = (1+ω)ε_cond - ωε_uncond
        
        Steps:
        1. Get conditional noise prediction: ε_cond = model(x_t, class_labels, t)
        2. Get unconditional noise prediction: ε_uncond = model(x_t, None, t)  
        3. Apply CFG formula: ε̃ = (1+ω)ε_cond - ωε_uncond
        4. Use guided noise prediction for DDIM update
        5. Return updated state
        
        Args:
            x_t: Current noisy state [batch_size, data_dim]
            class_labels: Target class labels [batch_size]
            t: Current timestep
            s: Target timestep (s < t)
            guidance_scale: CFG guidance scale (ω)
            
        Returns:
            x_s: Updated state with CFG guidance
        """
        # TODO: Your implementation here
        # Hint: Call model twice - once with class_labels, once with None
        # Apply CFG formula and then use DDIM update
        pass
    
    def sample_cfg(self, class_labels: torch.Tensor, num_steps: int = 20, 
                   guidance_scale: float = 2.0) -> torch.Tensor:
        """
        TODO: Implement complete CFG sampling
        
        Generate samples using classifier-free guidance.
        
        Steps:
        1. Create timestep schedule
        2. Initialize x_T from pure noise
        3. For each timestep pair: apply cfg_step
        4. Return final samples
        
        Args:
            class_labels: Target class labels [batch_size]
            num_steps: Number of sampling steps
            guidance_scale: CFG guidance scale
            
        Returns:
            Generated samples [batch_size, data_dim]
        """
        # TODO: Your implementation here
        # Hint: Similar to previous samplers but use cfg_step
        pass
    
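    def create_timestep_schedule(self, num_steps: int) -> List[int]:
        """Create a uniform, strictly decreasing schedule from T-1 down to 0"""
        # Clamp so consecutive entries stay distinct after rounding
        num_steps = max(1, min(num_steps, self.config.T - 1))
        # num_steps + 1 evenly spaced points; the last entry is always 0
        timesteps = np.linspace(self.config.T - 1, 0, num_steps + 1).round().astype(int)
        return timesteps.tolist()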
    
    def demonstrate_cfg_generation(self, guidance_scales: List[float] = [0.0, 1.0, 2.0, 5.0]):
        """Demonstrate CFG with different guidance scales"""
        print("=== Classifier-Free Guidance Demonstration ===\n")
        
        target_class = 0
        samples_per_scale = 15
        
        results = {}
        
        for scale in guidance_scales:
            print(f"Testing CFG scale ω = {scale}")
            
            class_labels = torch.full((samples_per_scale,), target_class, dtype=torch.long).to(device)
            
            try:
                samples = self.sample_cfg(class_labels, num_steps=20, guidance_scale=scale)
                
                if samples is not None:
                    results[scale] = samples
                    print(f"  ✓ Generated {len(samples)} samples")
                else:
                    print(f"  ❌ Implement CFG sampling methods first")
                    return
            except Exception as e:
                print(f"  ❌ Error with scale {scale}: {e}")
                return
        
        if results:
            self.visualize_cfg_effects(results, target_class)
    
    def visualize_cfg_effects(self, results: Dict[float, torch.Tensor], target_class: int):
        """Visualize CFG effects across different scales"""
        n_scales = len(results)
        fig, axes = plt.subplots(1, n_scales, figsize=(4*n_scales, 4))
        
        if n_scales == 1:
            axes = [axes]
        
        colors = ['red', 'blue', 'green', 'orange']
        class_names = ['Circle', 'Square', 'Triangle', 'Line']
        
        for i, (scale, samples) in enumerate(results.items()):
            # Plot generated samples
            axes[i].scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), 
                           c=colors[target_class], alpha=0.7, s=50)
            
            # Plot reference data
            class_mask = train_labels == target_class
            ref_data = train_data[class_mask]
            axes[i].scatter(ref_data[:, 0].cpu(), ref_data[:, 1].cpu(), 
                           c='gray', alpha=0.3, s=20, label='Reference')
            
            # Determine behavior
            if scale == 0.0:
                # With ε̃ = (1+ω)ε_cond - ωε_uncond, ω = 0 leaves the conditional prediction unchanged
                behavior = "Conditional (no guidance)"
            elif scale < 2.0:
                behavior = "Weak Guidance"
            elif scale < 5.0:
                behavior = "Strong Guidance"
            else:
                behavior = "Very Strong"
            
            axes[i].set_title(f'CFG ω = {scale}\n{behavior}')
            axes[i].grid(True, alpha=0.3)
            axes[i].axis('equal')
        
        plt.suptitle(f'Classifier-Free Guidance: {class_names[target_class]} Generation')
        plt.tight_layout()
        plt.show()

# Create and train CFG model
print("Creating classifier-free guidance model...")
cfg_model = ClassifierFreeUNet(data_dim=2, num_classes=config.num_classes).to(device)

# Test CFG training
cfg_trainer = CFGTrainer(cfg_model, config)

print("\nCFG training demonstration...")
cfg_losses = cfg_trainer.train_cfg_epoch(train_data, train_labels, batch_size=16, lr=1e-3)
if cfg_losses:
    print(f"CFG training losses: {cfg_losses[:5]}...")
    
    # Test CFG sampling
    cfg_sampler = CFGSampler(cfg_model, config)
    cfg_sampler.demonstrate_cfg_generation([0.0, 1.0, 2.0, 4.0])
else:
    print("❌ Implement cfg_training_step() to proceed")
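
The CFG formula from the `cfg_step` docstring is easy to check in isolation. The sketch below uses random placeholder tensors in place of the two model calls; `cfg_combine` is a name introduced for this example. It confirms that ω = 0 returns the plain conditional prediction, and that the formula is the same as extrapolating from `eps_cond` away from `eps_uncond`.

```python
import torch

def cfg_combine(eps_cond, eps_uncond, guidance_scale):
    """Guided noise prediction: (1 + w) * eps_cond - w * eps_uncond."""
    return (1.0 + guidance_scale) * eps_cond - guidance_scale * eps_uncond

torch.manual_seed(0)
eps_cond = torch.randn(3, 2)     # placeholder for model(x_t, class_labels, t)
eps_uncond = torch.randn(3, 2)   # placeholder for model(x_t, None, t)

# w = 0 returns the conditional prediction unchanged
assert torch.allclose(cfg_combine(eps_cond, eps_uncond, 0.0), eps_cond)

# Equivalent view: start at eps_cond and push further away from eps_uncond
w = 2.0
extrapolated = eps_cond + w * (eps_cond - eps_uncond)
assert torch.allclose(cfg_combine(eps_cond, eps_uncond, w), extrapolated, atol=1e-6)

print("CFG identities hold")
```

Under this convention the unconditional prediction is never sampled on its own; it only sets the direction that guidance moves away from.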

### Task 5.2: CFG vs Other Methods Comparison

In [None]:
class ComprehensiveConditionalComparison:
    """
    Compare all three conditional generation approaches:
    1. Class-conditional
    2. Classifier guidance  
    3. Classifier-free guidance
    """
    
    def __init__(self, class_conditional_sampler: ClassConditionalSampler,
                 classifier_guidance_sampler: ClassifierGuidanceSampler,
                 cfg_sampler: CFGSampler):
        self.class_conditional = class_conditional_sampler
        self.classifier_guidance = classifier_guidance_sampler
        self.cfg = cfg_sampler
    
    def compare_all_methods(self, target_class: int = 0, samples_per_method: int = 20):
        """Compare all three conditioning approaches"""
        print("=== Comprehensive Conditional Generation Comparison ===\n")
        
        results = {}
        timings = {}
        
        # Test class-conditional
        print("1. Testing Class-Conditional...")
        start_time = time.time()
        try:
            class_labels = torch.full((samples_per_method,), target_class, dtype=torch.long).to(device)
            cc_samples = self.class_conditional.sample_class_conditional(class_labels, num_steps=20)
            cc_time = time.time() - start_time
            
            if cc_samples is not None:
                results['Class-Conditional'] = cc_samples
                timings['Class-Conditional'] = cc_time
                print(f"   ✓ Generated in {cc_time:.3f}s")
            else:
                print("   ❌ Class-conditional not implemented")
        except Exception as e:
            print(f"   ❌ Class-conditional error: {e}")
        
        # Test classifier guidance
        print("2. Testing Classifier Guidance...")
        start_time = time.time()
        try:
            cg_samples = self.classifier_guidance.sample_with_classifier_guidance(
                target_class=target_class, num_samples=samples_per_method, 
                num_steps=20, guidance_scale=2.0)
            cg_time = time.time() - start_time
            
            if cg_samples is not None:
                results['Classifier Guidance'] = cg_samples
                timings['Classifier Guidance'] = cg_time
                print(f"   ✓ Generated in {cg_time:.3f}s")
            else:
                print("   ❌ Classifier guidance not implemented")
        except Exception as e:
            print(f"   ❌ Classifier guidance error: {e}")
        
        # Test CFG
        print("3. Testing Classifier-Free Guidance...")
        start_time = time.time()
        try:
            class_labels = torch.full((samples_per_method,), target_class, dtype=torch.long).to(device)
            cfg_samples = self.cfg.sample_cfg(class_labels, num_steps=20, guidance_scale=2.0)
            cfg_time = time.time() - start_time
            
            if cfg_samples is not None:
                results['CFG'] = cfg_samples
                timings['CFG'] = cfg_time
                print(f"   ✓ Generated in {cfg_time:.3f}s")
            else:
                print("   ❌ CFG not implemented")
        except Exception as e:
            print(f"   ❌ CFG error: {e}")
        
        if results:
            self.visualize_method_comparison(results, target_class)
            self.analyze_method_characteristics(results, timings)
        
        return results, timings
    
    def visualize_method_comparison(self, results: Dict[str, torch.Tensor], target_class: int):
        """Visualize comparison between methods"""
        n_methods = len(results)
        fig, axes = plt.subplots(1, n_methods + 1, figsize=(4*(n_methods+1), 4))
        
        colors = ['red', 'blue', 'green', 'orange']
        class_names = ['Circle', 'Square', 'Triangle', 'Line']
        
        # Reference data
        class_mask = train_labels == target_class
        ref_data = train_data[class_mask]
        axes[0].scatter(ref_data[:, 0].cpu(), ref_data[:, 1].cpu(), 
                       c=colors[target_class], alpha=0.7, s=50)
        axes[0].set_title('Reference\nTraining Data')
        axes[0].grid(True, alpha=0.3)
        axes[0].axis('equal')
        
        # Generated samples from each method
        for i, (method, samples) in enumerate(results.items()):
            axes[i+1].scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), 
                             c=colors[target_class], alpha=0.7, s=50)
            axes[i+1].scatter(ref_data[:, 0].cpu(), ref_data[:, 1].cpu(), 
                             c='gray', alpha=0.2, s=10)
            axes[i+1].set_title(f'{method}\nGeneration')
            axes[i+1].grid(True, alpha=0.3)
            axes[i+1].axis('equal')
        
        plt.suptitle(f'Conditional Generation Methods: {class_names[target_class]}')
        plt.tight_layout()
        plt.show()
    
    def analyze_method_characteristics(self, results: Dict[str, torch.Tensor], 
                                     timings: Dict[str, float]):
        """Analyze characteristics of each method"""
        print("\n=== Method Analysis ===\n")
        
        characteristics = {}
        
        for method, samples in results.items():
            # Compute basic statistics
            center = samples.mean(dim=0)
            spread = samples.std(dim=0).mean().item()
            
            characteristics[method] = {
                'center': center,
                'spread': spread,
                'time': timings.get(method, 0)
            }
            
            print(f"{method}:")
            print(f"  Center: ({center[0]:.2f}, {center[1]:.2f})")
            print(f"  Spread: {spread:.3f}")
            print(f"  Time: {timings.get(method, 0):.3f}s")
            print()
        
        # Create comparison chart
        self.plot_method_comparison_chart(characteristics)
        
        return characteristics
    
    def plot_method_comparison_chart(self, characteristics: Dict):
        """Plot comparison chart of method characteristics"""
        methods = list(characteristics.keys())
        spreads = [characteristics[m]['spread'] for m in methods]
        times = [characteristics[m]['time'] for m in methods]
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        
        # Diversity comparison
        bars1 = ax1.bar(methods, spreads, color=['lightblue', 'lightcoral', 'lightgreen'])
        ax1.set_ylabel('Sample Spread (Diversity)')
        ax1.set_title('Generation Diversity')
        ax1.tick_params(axis='x', rotation=45)
        
        for bar, spread in zip(bars1, spreads):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 0.001,
                    f'{spread:.3f}', ha='center', va='bottom')
        
        # Speed comparison
        bars2 = ax2.bar(methods, times, color=['lightblue', 'lightcoral', 'lightgreen'])
        ax2.set_ylabel('Generation Time (s)')
        ax2.set_title('Generation Speed')
        ax2.tick_params(axis='x', rotation=45)
        
        for bar, time_val in zip(bars2, times):
            height = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width()/2., height + 0.001,
                    f'{time_val:.3f}s', ha='center', va='bottom')
        
        plt.tight_layout()
        plt.show()
    
    def summarize_trade_offs(self):
        """Summarize the trade-offs between methods"""
        print("=== Method Trade-offs Summary ===\n")
        
        trade_offs = {
            "Class-Conditional": {
                "Pros": ["✓ Simple implementation", "✓ Fast sampling", "✓ Reliable results", "✓ Memory efficient"],
                "Cons": ["❌ Fixed classes only", "❌ Limited flexibility", "❌ Requires retraining for new classes"],
                "Best for": "Simple applications with fixed, known classes"
            },
            "Classifier Guidance": {
                "Pros": ["✓ Works with any pretrained model", "✓ Modular approach", "✓ Strong control"],
                "Cons": ["❌ Requires separate classifier", "❌ Slower (extra gradients)", "❌ Complex implementation"],
                "Best for": "Research and experimentation with existing models"
            },
            "Classifier-Free Guidance": {
                "Pros": ["✓ No separate classifier", "✓ Modern standard", "✓ Excellent text conditioning", "✓ Flexible control"],
                "Cons": ["❌ Requires retraining", "❌ 2x forward passes", "❌ More complex training"],
                "Best for": "Production systems, text-to-image, modern applications"
            }
        }
        
        for method, info in trade_offs.items():
            print(f"🔹 {method}:")
            print(f"   Pros: {', '.join(info['Pros'])}")
            print(f"   Cons: {', '.join(info['Cons'])}")
            print(f"   Best for: {info['Best for']}")
            print()
        
        print("🏆 Winner for modern applications: Classifier-Free Guidance")
        print("   • Used by Stable Diffusion and DALL-E 2, among other text-to-image systems")
        print("   • Best balance of quality, flexibility, and practicality")

# Create comprehensive comparison
# Note: this runs only once sampler, guidance_sampler, and cfg_sampler have all been created
try:
    if 'sampler' in locals() and 'guidance_sampler' in locals() and 'cfg_sampler' in locals():
        comprehensive_comparison = ComprehensiveConditionalComparison(
            sampler, guidance_sampler, cfg_sampler)
        
        comparison_results, comparison_timings = comprehensive_comparison.compare_all_methods(
            target_class=0, samples_per_method=15)
        
        comprehensive_comparison.summarize_trade_offs()
    else:
        print("Implement all three methods to run comprehensive comparison")
except Exception as e:
    print(f"Comprehensive comparison failed: {e}")
    print("Implement all three sampling methods to run comprehensive comparison")
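
**Note on timing:** `compare_all_methods` measures each method with `time.time()`. On a GPU, CUDA kernels launch asynchronously, so the measured time can under-report the real cost unless the device is synchronized before reading the clock. A minimal sketch of a timing wrapper follows; `timed_call` is a hypothetical helper name, not part of the lab starter code, and the workload shown is a stand-in for a real sampler call.

```python
import time
import torch

def timed_call(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) and return (result, elapsed_seconds).

    Hypothetical helper, not part of the lab starter code.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # finish pending GPU work before starting the clock
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for kernels launched inside fn
    return result, time.perf_counter() - start

# Stand-in workload instead of a real sampler
out, seconds = timed_call(torch.randn, 15, 2)
print(f"shape={tuple(out.shape)}, time={seconds:.6f}s")
```

With the lab objects in place, the same wrapper could be applied to a call such as `self.cfg.sample_cfg(...)` so that all three methods are timed the same way.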

---

## Part 6: Advanced Conditioning and Real-World Applications (10 minutes)

### Task 6.1: Multi-Scale Conditioning and Modern Techniques

**Your Mission**: Explore advanced conditioning techniques used in production systems.

In [None]:
class AdvancedConditioningDemo:
    """
    Demonstrate advanced conditioning techniques used in modern diffusion models.
    """
    
    def __init__(self, cfg_model: ClassifierFreeUNet, config: ConditionalConfig):
        self.model = cfg_model
        self.config = config
        
    def demonstrate_guidance_interpolation(self, class_a: int = 0, class_b: int = 1, 
                                         num_interpolation_steps: int = 5):
        """Demonstrate interpolation between different class conditions"""
        print("=== Guidance Interpolation Demo ===\n")
        
        # This would require more sophisticated implementation
        # For now, we'll demonstrate the concept
        
        print(f"Interpolating between class {class_a} and class {class_b}")
        print("In advanced systems, this enables:")
        print("• Smooth transitions between conditions")
        print("• Creative control over generation")
        print("• Morphing between different styles/classes")
        
        # Simulate interpolation visualization
        fig, axes = plt.subplots(1, num_interpolation_steps, figsize=(3*num_interpolation_steps, 3))
        
        colors = ['red', 'blue', 'green', 'orange']
        class_names = ['Circle', 'Square', 'Triangle', 'Line']
        
        for i in range(num_interpolation_steps):
            # Interpolation weight
            weight = i / (num_interpolation_steps - 1)
            
            # Simulate interpolated samples (in real implementation, you'd blend class embeddings)
            if i < num_interpolation_steps // 2:
                color = colors[class_a]
                title = f'{class_names[class_a]}\n({1-weight:.1f})'
            else:
                color = colors[class_b] 
                title = f'{class_names[class_b]}\n({weight:.1f})'
            
            # Mock interpolated data
            center_a = torch.tensor([2.0, 2.0])  # Circle center
            center_b = torch.tensor([-2.0, 2.0])  # Square center
            interpolated_center = (1 - weight) * center_a + weight * center_b
            
            mock_samples = interpolated_center.unsqueeze(0) + 0.3 * torch.randn(10, 2)
            
            axes[i].scatter(mock_samples[:, 0], mock_samples[:, 1], c=color, alpha=0.7, s=50)
            axes[i].set_title(title)
            axes[i].grid(True, alpha=0.3)
            axes[i].axis('equal')
            axes[i].set_xlim(-3, 3)
            axes[i].set_ylim(1, 3)
        
        plt.suptitle('Class Interpolation (Conceptual)')
        plt.tight_layout()
        plt.show()
    
    def demonstrate_hierarchical_conditioning(self):
        """Demonstrate multi-level conditioning concepts"""
        print("=== Hierarchical Conditioning Demo ===\n")
        
        conditioning_hierarchy = {
            "Global Control": {
                "Level": "High-level semantics",
                "Examples": ["Class labels", "Style tokens", "Global attributes"],
                "Effect": "Overall generation direction"
            },
            "Regional Control": {
                "Level": "Spatial conditioning", 
                "Examples": ["Bounding boxes", "Segmentation maps", "Spatial layouts"],
                "Effect": "Where objects appear"
            },
            "Local Control": {
                "Level": "Fine-grained details",
                "Examples": ["Edge maps", "Depth maps", "Texture guidance"],
                "Effect": "Surface details and textures"
            }
        }
        
        print("Modern diffusion models support multiple conditioning levels:")
        for level, info in conditioning_hierarchy.items():
            print(f"\n🔹 {level}:")
            print(f"   Level: {info['Level']}")
            print(f"   Examples: {', '.join(info['Examples'])}")
            print(f"   Effect: {info['Effect']}")
        
        # Visualize hierarchy concept
        fig, ax = plt.subplots(1, 1, figsize=(10, 6))
        
        levels = list(conditioning_hierarchy.keys())
        y_positions = [2, 1, 0]
        colors = ['lightblue', 'lightgreen', 'lightyellow']
        
        for i, (level, y_pos, color) in enumerate(zip(levels, y_positions, colors)):
            # Draw level box
            rect = plt.Rectangle((0.1, y_pos-0.3), 0.8, 0.6, 
                               facecolor=color, edgecolor='black', alpha=0.7)
            ax.add_patch(rect)
            
            # Add text
            ax.text(0.5, y_pos, level, ha='center', va='center', 
                   fontsize=12, fontweight='bold')
            
            # Add arrows between levels
            if i < len(levels) - 1:
                ax.arrow(0.5, y_pos-0.4, 0, -0.2, head_width=0.05, head_length=0.05,
                        fc='gray', ec='gray')
        
        ax.set_xlim(0, 1)
        ax.set_ylim(-0.5, 2.5)
        ax.set_title('Hierarchical Conditioning in Modern Diffusion Models')
        ax.axis('off')
        
        plt.tight_layout()
        plt.show()
    
    def demonstrate_production_considerations(self):
        """Demonstrate practical considerations for production deployment"""
        print("=== Production Deployment Considerations ===\n")
        
        considerations = {
            "Performance Optimization": [
                "• Cached embeddings for repeated conditions",
                "• Batch processing for multiple samples", 
                "• Mixed precision training and inference",
                "• Model distillation for faster inference"
            ],
            "Quality Control": [
                "• Validation metrics for conditioning adherence",
                "• Safety filters for inappropriate content",
                "• Consistency checks across guidance scales",
                "• A/B testing for optimal hyperparameters"
            ],
            "User Experience": [
                "• Intuitive guidance scale recommendations",
                "• Progressive generation with early previews",
                "• Fallback mechanisms for failed generations",
                "• Real-time feedback for interactive applications"
            ],
            "Scalability": [
                "• Distributed inference for high throughput",
                "• Load balancing across multiple models",
                "• Efficient memory management for large batches",
                "• Auto-scaling based on demand"
            ]
        }
        
        for category, items in considerations.items():
            print(f"🔹 {category}:")
            for item in items:
                print(f"   {item}")
            print()
        
        print("💡 Key Insight: Production systems require extensive engineering beyond the core algorithm!")
    
    def showcase_modern_applications(self):
        """Showcase real-world applications of conditional diffusion"""
        print("=== Modern Applications Showcase ===\n")
        
        applications = {
            "Creative AI Tools": {
                "Examples": ["Stable Diffusion", "Midjourney", "DALL-E"],
                "Key Features": ["Text-to-image", "Style transfer", "Inpainting"],
                "Impact": "Democratized AI art creation"
            },
            "Content Creation": {
                "Examples": ["Marketing materials", "Game assets", "Product visualization"],
                "Key Features": ["Brand consistency", "Rapid iteration", "Custom styles"],
                "Impact": "Accelerated creative workflows"
            },
            "Scientific Applications": {
                "Examples": ["Medical imaging", "Material design", "Drug discovery"],
                "Key Features": ["Conditional synthesis", "Data augmentation", "Property control"],
                "Impact": "Enhanced research capabilities"
            },
            "Interactive Media": {
                "Examples": ["Video games", "AR/VR", "Real-time avatars"],
                "Key Features": ["Real-time generation", "User control", "Adaptive content"],
                "Impact": "Personalized experiences"
            }
        }
        
        print("Conditional diffusion is now applied across multiple domains:")
        for domain, info in applications.items():
            print(f"\n🚀 {domain}:")
            print(f"   Examples: {', '.join(info['Examples'])}")
            print(f"   Key Features: {', '.join(info['Key Features'])}")
            print(f"   Impact: {info['Impact']}")
        
        print(f"\n🌟 The Future:")
        print(f"   • Multimodal conditioning (text + image + audio)")
        print(f"   • Real-time interactive generation")
        print(f"   • Personalized AI creative assistants")
        print(f"   • Domain-specific specialized models")

# Demonstrate advanced techniques
advanced_demo = AdvancedConditioningDemo(cfg_model, config)

# Run demonstrations
advanced_demo.demonstrate_guidance_interpolation(class_a=0, class_b=1, num_interpolation_steps=5)
advanced_demo.demonstrate_hierarchical_conditioning()
advanced_demo.demonstrate_production_considerations()
advanced_demo.showcase_modern_applications()
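
The interpolation demo above only simulates the idea with mock data. One way a real blend could look is a linear interpolation between two rows of a class-embedding table, with the blended vector passed to the denoiser in place of a single class embedding. The sketch below assumes a standalone `nn.Embedding` table; `class_embedding` and `interpolate_class_embedding` are illustrative names, and the lab's `ClassifierFreeUNet` may store its embeddings differently.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes, embed_dim = 4, 16
class_embedding = nn.Embedding(num_classes, embed_dim)  # stand-in embedding table

def interpolate_class_embedding(emb, class_a, class_b, weight):
    """Linear blend: weight=0 gives class_a's embedding, weight=1 gives class_b's."""
    idx = torch.tensor([class_a, class_b])
    e_a, e_b = emb(idx)  # two vectors of shape (embed_dim,)
    return (1.0 - weight) * e_a + weight * e_b

# Five blend weights, matching num_interpolation_steps=5 in the demo
weights = torch.linspace(0, 1, 5)
blended = torch.stack(
    [interpolate_class_embedding(class_embedding, 0, 1, w) for w in weights]
)
print(blended.shape)
```

Whether intermediate weights produce meaningful in-between samples depends on the trained model; the embedding space is not guaranteed to be smooth between classes.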

---

## Part 7: Integration and Final Validation (5 minutes)

### Task 7.1: Complete System Integration and Validation

**Your Mission**: Integrate all conditional generation approaches and validate the complete system.

In [None]:
def comprehensive_conditional_validation():
    """
    Final validation of the complete conditional generation system
    """
    print("=== Comprehensive Conditional Generation Validation ===\n")
    
    validation_results = {
        'class_conditional_implemented': False,
        'classifier_guidance_implemented': False,
        'cfg_implemented': False,
        'understanding_demonstrated': False
    }
    
    # Test 1: Class-Conditional Implementation
    print("1. Validating Class-Conditional Implementation...")
    try:
        if 'trainer' in globals():  # locals() inside a function never sees notebook-level names
            test_result = trainer.training_step(train_data[:5], train_labels[:5])
            if test_result and 'loss' in test_result:
                validation_results['class_conditional_implemented'] = True
                print("   ✓ Class-conditional training functional")
            else:
                print("   ❌ Implement training_step() method")
        else:
            print("   ❌ Class-conditional trainer not created")
    except Exception as e:
        print(f"   ❌ Class-conditional error: {e}")
    
    # Test 2: Classifier Guidance Implementation
    print("2. Validating Classifier Guidance Implementation...")
    try:
        if 'guidance_sampler' in globals():
            test_gradient = guidance_sampler.compute_classifier_gradient(
                torch.randn(3, 2).to(device), target_class=0, t=10)
            if test_gradient is not None:
                validation_results['classifier_guidance_implemented'] = True
                print("   ✓ Classifier guidance functional")
            else:
                print("   ❌ Implement compute_classifier_gradient() method")
        else:
            print("   ❌ Classifier guidance sampler not created")
    except Exception as e:
        print(f"   ❌ Classifier guidance error: {e}")
    
    # Test 3: CFG Implementation
    print("3. Validating Classifier-Free Guidance Implementation...")
    try:
        if 'cfg_trainer' in globals():
            test_result = cfg_trainer.cfg_training_step(train_data[:5], train_labels[:5])
            if test_result and 'loss' in test_result:
                validation_results['cfg_implemented'] = True
                print("   ✓ CFG training functional")
            else:
                print("   ❌ Implement cfg_training_step() method")
        else:
            print("   ❌ CFG trainer not created")
    except Exception as e:
        print(f"   ❌ CFG error: {e}")
    
    # Test 4: Understanding Demonstration
    print("4. Validating Conceptual Understanding...")
    try:
        # Check if student ran the limitation demos
        if 'limitations_demo' in globals():
            validation_results['understanding_demonstrated'] = True
            print("   ✓ Unconditional limitations demonstrated")
        else:
            print("   ❌ Run unconditional limitations demo")
    except Exception as e:
        print(f"   ❌ Understanding demo error: {e}")
    
    # Overall assessment
    print("\n" + "="*60)
    print("FINAL VALIDATION RESULTS:")
    print("="*60)
    
    for component, status in validation_results.items():
        status_str = "✓ PASS" if status else "❌ FAIL"
        component_name = component.replace('_', ' ').title()
        print(f"{component_name}: {status_str}")
    
    overall_pass = sum(validation_results.values()) >= 2
    print(f"\nOverall System Status: {'✓ FUNCTIONAL' if overall_pass else '❌ NEEDS WORK'}")
    
    if overall_pass:
        print("\n🎉 Congratulations! Your conditional generation system is working!")
        print("You've implemented the core techniques that power modern AI art tools!")
    else:
        print("\n🔧 Keep working on the TODO implementations.")
        print("Focus on the training_step() and sampling methods first.")
    
    return validation_results

def demonstrate_conditional_generation_journey():
    """
    Demonstrate the complete journey from unconditional to conditional generation
    """
    print("=== Conditional Generation Journey ===\n")
    
    journey_stages = [
        "1. 🎲 Problem: Unconditional generation gives random results",
        "2. 🏷️  Solution 1: Class-conditional - simple embedding approach",
        "3. 🧭 Solution 2: Classifier guidance - external steering",
        "4. ⚡ Solution 3: Classifier-free guidance - modern breakthrough",
        "5. 🚀 Advanced: Hierarchical and multi-modal conditioning",
        "6. 🏭 Production: Real-world deployment considerations"
    ]
    
    print("Your journey through conditional diffusion generation:")
    for stage in journey_stages:
        print(f"  {stage}")
    
    print(f"\n💡 Key insights achieved:")
    print(f"   • Conditional generation solves the control problem")
    print(f"   • Multiple approaches with different trade-offs")
    print(f"   • CFG became the modern standard (Stable Diffusion, DALL-E 2)")
    print(f"   • Production requires extensive engineering beyond algorithms")
    
    print(f"\n🌟 What this enables:")
    print(f"   • Text-to-image generation (DALL-E, Stable Diffusion)")
    print(f"   • Creative AI tools (Midjourney, Playground AI)")
    print(f"   • Style transfer and artistic applications")
    print(f"   • Controllable content creation")
    
    # Create journey visualization
    fig, ax = plt.subplots(1, 1, figsize=(14, 8))
    
    stages = ["Problem", "Class-Cond", "Classifier", "CFG", "Advanced", "Production"]
    stage_colors = ['red', 'lightblue', 'orange', 'lightgreen', 'purple', 'gold']
    
    # Draw timeline
    y_pos = 0.5
    stage_width = 0.12
    
    for i, (stage, color) in enumerate(zip(stages, stage_colors)):
        x_pos = 0.1 + i * 0.15
        
        # Stage circle
        circle = plt.Circle((x_pos, y_pos), stage_width/2, 
                           facecolor=color, edgecolor='black', alpha=0.8)
        ax.add_patch(circle)
        ax.text(x_pos, y_pos, stage, ha='center', va='center', 
               fontsize=9, weight='bold')
        
        # Arrow to next stage
        if i < len(stages) - 1:
            ax.arrow(x_pos + stage_width/2, y_pos, 
                    0.15 - stage_width, 0, 
                    head_width=0.02, head_length=0.01, 
                    fc='gray', ec='gray')
    
    # Add evolution annotations, centered on each stage (x = 0.1 + i * 0.15)
    evolutions = [
        (0.10, 0.7, "Random\nGeneration"),
        (0.25, 0.3, "Simple\nControl"),
        (0.40, 0.7, "External\nSteering"),
        (0.55, 0.3, "Unified\nModel"),
        (0.70, 0.7, "Multi-Modal\nControl"),
        (0.85, 0.3, "Real-World\nSystems")
    ]
    
    for x, y, text in evolutions:
        ax.text(x, y, text, ha='center', va='center', fontsize=8,
               bbox=dict(boxstyle="round,pad=0.3", facecolor='lightyellow', alpha=0.8))
    
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title('Evolution of Conditional Diffusion Generation', fontsize=16, weight='bold')
    ax.axis('off')
    
    plt.tight_layout()
    plt.show()

def create_final_summary():
    """Create a comprehensive summary of conditional generation approaches"""
    print("=== Final Summary: Conditional Generation Mastery ===\n")
    
    # Method comparison table
    methods_summary = {
        "Approach": ["Class-Conditional", "Classifier Guidance", "Classifier-Free"],
        "Training": ["Conditional only", "Separate classifier", "Joint training"],
        "Sampling": ["1 forward pass", "2 passes + gradients", "2 forward passes"],
        "Flexibility": ["Fixed classes", "Any classifier", "Learned conditions"],
        "Speed": ["Fastest", "Slowest", "Medium"],
        "Quality": ["Good", "Excellent", "Excellent"],
        "Modern Use": ["Limited", "Research", "Production Standard"]
    }
    
    print("📊 Method Comparison Summary:")
    col = 22  # column width, wide enough for the longest cell
    rule = "-" * (15 + 3 * (col + 1))
    print(rule)
    
    # Print table header from the "Approach" row
    header = methods_summary["Approach"]
    print(f"{'Aspect':<15} " + " ".join(f"{name:<{col}}" for name in header))
    print(rule)
    
    # Print one row per aspect, one column per method
    aspects = ["Training", "Sampling", "Flexibility", "Speed", "Quality", "Modern Use"]
    for aspect in aspects:
        print(f"{aspect:<15} " + " ".join(f"{value:<{col}}" for value in methods_summary[aspect]))
    
    print(rule)
    
    print(f"\n🏆 Winner for Modern Applications: Classifier-Free Guidance")
    print(f"   • Used by most major text-to-image diffusion models")
    print(f"   • Best balance of quality, flexibility, and practicality")
    print(f"   • Enables complex multi-modal conditioning")
    
    print(f"\n🔮 Future Directions:")
    print(f"   • Real-time interactive generation")
    print(f"   • Multimodal conditioning (text + image + audio + 3D)")
    print(f"   • Personalized models that adapt to user preferences")
    print(f"   • Specialized domain models (medical, scientific, creative)")
    
    print(f"\n🎯 Key Takeaways:")
    print(f"   • Conditional generation solves the control problem in diffusion models")
    print(f"   • Different approaches offer different trade-offs")
    print(f"   • CFG has become the foundation of modern AI creativity tools")
    print(f"   • Implementation details matter for production deployment")

# Run final validation and demonstrations
validation_results = comprehensive_conditional_validation()
demonstrate_conditional_generation_journey()
create_final_summary()

print("\n" + "="*80)
print("🎓 LAB 6 COMPLETE: Conditional Generation Mastery Achieved!")
print("="*80)
print("You've learned the techniques that power modern AI art and content creation!")
print("From DALL-E to Stable Diffusion, you now understand the core algorithms.")
print("Ready to build the next generation of creative AI tools? 🚀")

---

## Implementation Checklist

### Core Conditional Functions (Students Implement):

**✅ Essential TODOs:**
- [ ] `training_step()` (Class-Conditional) - Basic conditional training objective
- [ ] `ddim_step()` (Class-Conditional) - Class-conditional DDIM sampling step
- [ ] `sample_class_conditional()` - Complete class-conditional generation
- [ ] `compute_classifier_gradient()` - Gradient computation for classifier guidance
- [ ] `classifier_guided_step()` - Single step with classifier guidance
- [ ] `sample_with_classifier_guidance()` - Complete classifier-guided sampling
- [ ] `cfg_training_step()` - CFG training with conditioning dropout
- [ ] `cfg_step()` - CFG sampling step with guidance formula
- [ ] `sample_cfg()` - Complete classifier-free guidance sampling

**✅ Provided Starter Code:**
- [ ] All U-Net architectures and model definitions
- [ ] Noise-aware classifier implementation
- [ ] Complete visualization and analysis frameworks
- [ ] Comparison and benchmarking systems
- [ ] Advanced conditioning demonstrations
- [ ] Production considerations and deployment guidance

---

## Submission Requirements

### What to Submit

Submit your completed Jupyter notebook (.ipynb file) with:

**✅ Class-Conditional Implementation:**
- Complete training objective with class embedding injection
- DDIM sampling modified for class conditioning
- Demonstration of controlled class generation

**✅ Classifier Guidance Implementation:**
- Noise-aware classifier training and gradient computation
- Classifier-guided sampling with adjustable guidance scales
- Analysis of guidance strength effects on quality and diversity

**✅ Classifier-Free Guidance Implementation:**
- Joint training for conditional and unconditional generation
- CFG sampling with the guidance formula implementation
- Comparison of different guidance scales and their effects

**✅ Comprehensive Analysis:**
- Comparison between all three conditioning approaches
- Trade-off analysis: speed vs quality vs flexibility
- Understanding of when to use each method

**✅ Advanced Understanding:**
- Demonstration of unconditional generation limitations
- Analysis of production deployment considerations
- Connection between mathematical formulations and practical implementations

**✅ Documentation and Insights:**
- Clear explanations of implementation choices
- Discussion of real-world applications and impact
- Understanding of how these techniques power modern AI art tools

---

## Quick Reference: Key Implementation Formulas

### For Implementation Reference:

**Class-Conditional Training:**

In [None]:
# Modified training objective with class conditioning
# L = E[||ε - ε_θ(x_t, y, t)||²]
predicted_noise = model(x_noisy, class_labels, t)
loss = F.mse_loss(predicted_noise, actual_noise)
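
To make the two lines above concrete, here is a minimal, self-contained sketch of one class-conditional training step on toy 2D data. `TinyCondDenoiser` is a hypothetical stand-in for the lab's U-Net, and the schedule values mirror `ConditionalConfig` (T=50, betas from 1e-4 to 2e-2). It is meant as an illustration of the objective, not as the reference solution for `training_step()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T, num_classes = 50, 4
betas = torch.linspace(1e-4, 2e-2, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

class TinyCondDenoiser(nn.Module):
    """Hypothetical stand-in for the lab's U-Net: predicts noise from (x_t, y, t)."""
    def __init__(self, dim=2, hidden=32):
        super().__init__()
        self.class_emb = nn.Embedding(num_classes, hidden)
        self.time_emb = nn.Embedding(T, hidden)
        self.net = nn.Sequential(
            nn.Linear(dim + hidden, hidden), nn.SiLU(), nn.Linear(hidden, dim)
        )

    def forward(self, x, class_labels, t):
        cond = self.class_emb(class_labels) + self.time_emb(t)  # sum class and time embeddings
        return self.net(torch.cat([x, cond], dim=-1))

model = TinyCondDenoiser()
x0 = torch.randn(8, 2)                                # toy clean data
class_labels = torch.randint(0, num_classes, (8,))
t = torch.randint(0, T, (8,))
actual_noise = torch.randn_like(x0)

# Forward diffusion: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
a_bar = alphas_cumprod[t].unsqueeze(-1)
x_noisy = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * actual_noise

predicted_noise = model(x_noisy, class_labels, t)
loss = F.mse_loss(predicted_noise, actual_noise)
loss.backward()
print(f"loss = {loss.item():.4f}")
```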

**Classifier Guidance Formula:**

In [None]:
# ε̃ = ε_θ(x_t, t) - ω√(1-ᾱ_t) ∇_x log p(y|x_t)
unconditional_noise = model(x_t, t)
classifier_grad = compute_gradient(classifier, x_t, target_class, t)
guided_noise = unconditional_noise - guidance_scale * math.sqrt(1 - alpha_cumprod_t) * classifier_grad
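
The `compute_gradient` call above is a placeholder. One way to obtain ∇_x log p(y|x_t) is with `torch.autograd.grad` on the selected log-probability. The sketch below uses a small stand-in classifier that ignores the timestep for brevity (the lab's noise-aware classifier also receives `t`), and `classifier_log_prob_grad` is an illustrative name rather than the lab's API.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def classifier_log_prob_grad(classifier, x_t, target_class):
    """Return grad_x log p(y = target_class | x_t) for every sample in the batch."""
    x_in = x_t.detach().requires_grad_(True)
    log_probs = F.log_softmax(classifier(x_in), dim=-1)
    # Summing over the batch is safe: each sample's log-prob depends only on its own x
    selected = log_probs[:, target_class].sum()
    return torch.autograd.grad(selected, x_in)[0]

# Stand-in classifier on 2D points with 4 classes (assumption: no timestep input)
classifier = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 4))
x_t = torch.randn(3, 2)
unconditional_noise = torch.randn(3, 2)  # stand-in for model(x_t, t)
alpha_cumprod_t, guidance_scale = 0.5, 2.0

classifier_grad = classifier_log_prob_grad(classifier, x_t, target_class=0)
guided_noise = (
    unconditional_noise
    - guidance_scale * math.sqrt(1 - alpha_cumprod_t) * classifier_grad
)
print(guided_noise.shape)
```

Detaching `x_t` before enabling gradients keeps the classifier's backward pass from reaching into the diffusion model's graph.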

**CFG Training with Dropout:**

In [None]:
# Randomly drop conditioning during training (here for the whole batch at once)
if torch.rand(1) < dropout_prob:
    class_labels = None  # Unconditional training
predicted_noise = model(x_noisy, class_labels, t)
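
The snippet above drops the labels for an entire batch in one draw. A common alternative is to drop each sample's label independently by swapping it for a dedicated null index. The sketch below assumes the null index is `num_classes` (one extra embedding row); `drop_labels` and `null_class` are illustrative names, and the lab's model may represent the unconditional case differently (for example with `None`).

```python
import torch

torch.manual_seed(0)
num_classes, dropout_prob = 4, 0.1
null_class = num_classes  # assumed index of the extra "null" embedding row

def drop_labels(class_labels, dropout_prob, null_class):
    """Replace each label with the null index independently with prob dropout_prob."""
    drop_mask = torch.rand(class_labels.shape[0], device=class_labels.device) < dropout_prob
    return torch.where(drop_mask, torch.full_like(class_labels, null_class), class_labels)

class_labels = torch.randint(0, num_classes, (10_000,))
dropped = drop_labels(class_labels, dropout_prob, null_class)
frac = (dropped == null_class).float().mean().item()
print(f"fraction dropped: {frac:.3f}")
```

Per-sample dropout means every batch contributes to both the conditional and the unconditional objective, rather than alternating between the two across batches.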

**CFG Sampling Formula:**

In [None]:
# ε̃ = (1+ω)ε_cond - ωε_uncond
conditional_noise = model(x_t, class_labels, t)
unconditional_noise = model(x_t, None, t)
guided_noise = (1 + guidance_scale) * conditional_noise - guidance_scale * unconditional_noise
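
A quick numerical check of the formula above can help when debugging guidance scales. With this parameterization, ω = 0 returns the plain conditional prediction, and the expression equals the extrapolation form ε_uncond + s(ε_cond − ε_uncond) with s = 1 + ω, which some codebases use instead. The sketch uses random tensors in place of model outputs; `cfg_combine` is an illustrative name.

```python
import torch

torch.manual_seed(0)

def cfg_combine(eps_cond, eps_uncond, guidance_scale):
    """(1 + w) * eps_cond - w * eps_uncond, the form used in this lab."""
    return (1 + guidance_scale) * eps_cond - guidance_scale * eps_uncond

# Random stand-ins for the two model predictions
eps_cond, eps_uncond = torch.randn(4, 2), torch.randn(4, 2)

# w = 0 recovers the purely conditional prediction
assert torch.allclose(cfg_combine(eps_cond, eps_uncond, 0.0), eps_cond)

# Equivalent extrapolation form with s = 1 + w
w = 2.0
alt = eps_uncond + (1 + w) * (eps_cond - eps_uncond)
assert torch.allclose(cfg_combine(eps_cond, eps_uncond, w), alt, atol=1e-6)
print("both forms agree")
```

When comparing guidance scales against another implementation, check which of the two conventions it uses, since the same number means a different strength in each.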

---

## Common Implementation Issues & Solutions

### Debugging Tips:

**Class Embedding Issues:**
- Ensure class labels are within valid range [0, num_classes-1]
- Check that embedding dimensions match other embeddings
- Verify proper combination of time and class embeddings
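
The checks above can be automated with a few lines run before the first training step. This is a sketch under assumptions: the embeddings are plain `nn.Embedding` tables that are summed, and `check_labels` is a hypothetical helper, not part of the starter code.

```python
import torch
import torch.nn as nn

num_classes, embed_dim, T = 4, 32, 50
class_emb = nn.Embedding(num_classes, embed_dim)
time_emb = nn.Embedding(T, embed_dim)  # stand-in for a sinusoidal time embedding

def check_labels(class_labels, num_classes):
    """Raise early if labels are not integer indices in [0, num_classes - 1]."""
    if class_labels.dtype != torch.long:
        raise TypeError(f"expected torch.long labels, got {class_labels.dtype}")
    lo, hi = class_labels.min().item(), class_labels.max().item()
    if lo < 0 or hi >= num_classes:
        raise ValueError(f"labels span [{lo}, {hi}], expected [0, {num_classes - 1}]")

labels = torch.tensor([0, 3, 1, 2])
t = torch.tensor([0, 10, 25, 49])
check_labels(labels, num_classes)

# Summing the two embeddings only works when their widths match
cond = class_emb(labels) + time_emb(t)
print(cond.shape)
```

An out-of-range label otherwise surfaces as an index error inside the embedding lookup (or a device-side assert on GPU), which is much harder to trace back to the data.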

**Guidance Scale Problems:**
- Start with guidance_scale=1.0 for debugging
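- For classifier guidance, scale=0.0 should reproduce unconditional results; with the CFG sampling formula in the Quick Reference, ω=0 reduces to the plain conditional prediction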
- Very high scales (>10) may cause artifacts or instability

**CFG Training Issues:**
- Verify dropout is applied correctly during training only
- Check that model handles None class_labels properly
- Ensure null token is properly defined and used
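
One possible way to handle `None` labels is to reserve an extra embedding row as the null token and substitute its index whenever no labels are given. The sketch below illustrates that pattern; `NullableClassEmbedding` is a hypothetical module, and the lab's `ClassifierFreeUNet` may implement the null token differently.

```python
import torch
import torch.nn as nn

class NullableClassEmbedding(nn.Module):
    """Class embedding with one extra row used when class_labels is None."""
    def __init__(self, num_classes, embed_dim):
        super().__init__()
        self.null_index = num_classes  # last row is the null token
        self.table = nn.Embedding(num_classes + 1, embed_dim)

    def forward(self, class_labels, batch_size):
        if class_labels is None:
            device = self.table.weight.device
            class_labels = torch.full(
                (batch_size,), self.null_index, dtype=torch.long, device=device
            )
        return self.table(class_labels)

emb = NullableClassEmbedding(num_classes=4, embed_dim=8)
uncond = emb(None, batch_size=3)                   # every row is the null embedding
cond = emb(torch.tensor([0, 1, 2]), batch_size=3)  # regular class embeddings
print(uncond.shape, cond.shape)
```

Because the null row is a learned parameter, it only becomes meaningful if conditioning dropout actually routes some training samples through it.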

**Sampling Convergence:**
- Reduce number of steps for faster debugging
- Check that noise predictions are reasonable (not NaN/inf)
- Verify timestep scheduling is correct
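
The last two checks can be scripted. The sketch below builds one possible evenly spaced timestep schedule for a shortened sampler and verifies that it starts at T-1, ends at 0, and strictly decreases; it also shows a small finite-value guard for noise predictions. `make_timesteps` and `assert_finite` are hypothetical helpers, and the lab's samplers may space their timesteps differently.

```python
import torch

def make_timesteps(T, num_steps):
    """Evenly spaced integer timesteps from T-1 down to 0 (one possible schedule)."""
    return torch.linspace(T - 1, 0, num_steps).round().long()

def assert_finite(name, tensor):
    """Raise if the tensor contains NaN or inf."""
    if not torch.isfinite(tensor).all():
        raise FloatingPointError(f"{name} contains NaN or inf")

timesteps = make_timesteps(T=50, num_steps=20)
assert timesteps[0].item() == 49 and timesteps[-1].item() == 0
assert (timesteps[1:] < timesteps[:-1]).all(), "timesteps must strictly decrease"

# Stand-in for a model output; call this on real predictions inside the sampling loop
assert_finite("noise_pred", torch.randn(5, 2))
print(timesteps.tolist())
```

Calling `assert_finite` at every step makes the first failing timestep visible, which usually narrows the problem to either the schedule coefficients or the guidance term.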

---

