[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ruliana/pytorch-katas/blob/main/dan_1/kata_02_candle_burn_time_predictor_unrevised.ipynb)

## 🏮 The Ancient Scroll Unfurls 🏮

THE MYSTERIES OF ANCIENT PRAYER CANDLES: A MULTI-VARIABLE REVELATION

Dan Level: 1 (Temple Sweeper) | Time: 45 minutes | Sacred Arts: Multi-variable Linear Regression, Feature Scaling, Gradient Descent

## 📜 THE CHALLENGE

The temple's ancient prayer candles burn at different rates depending on wax composition, room temperature, and humidity levels. Each candle is sacred, handcrafted by generations of monks, and understanding their burn patterns is crucial for timing evening prayers. Master Ao-Tougrad approaches you with a cryptic observation: "Young grasshopper, you have learned to predict simple patterns with single variables. But the true mysteries of this temple require understanding multiple flows simultaneously."

"Understanding multiple flows leads to understanding gradient flows," Master Ao-Tougrad whispers before disappearing into the shadows. Your task is to create a multi-variable linear model that can predict candle burn times by analyzing wax composition (hardness percentage), room temperature, and humidity levels. This ancient knowledge will help the temple maintain its sacred rhythms while teaching you the fundamental art of multi-dimensional pattern recognition.
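
Before writing any model, it may help to see what "multi-variable linear model" means as arithmetic. The sketch below is only an illustration: the two candles are made-up values, and the coefficients are borrowed from the data-generation formula given later in this notebook.

```python
import torch

# Hypothetical candles: each row is [wax_hardness, temperature, humidity]
features = torch.tensor([[50.0, 20.0, 50.0],
                         [80.0, 15.0, 30.0]])

# Coefficients borrowed from the notebook's data-generation formula
weights = torch.tensor([2.5, 0.3, -0.8])  # one weight per feature
bias = 45.0

# Each prediction is a weighted sum of the three features plus the bias
burn_times = features @ weights + bias
print(burn_times)
```

A linear layer performs exactly this matrix product, except that it starts from random weights and learns them from data.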

## 🎯 THE SACRED OBJECTIVES

- [ ] Master multi-variable linear regression with PyTorch
- [ ] Understand how multiple input features combine to create predictions
- [ ] Practice gradient descent with multiple parameters
- [ ] Learn to interpret model weights across different features
- [ ] Visualize multi-dimensional relationships in data

In [None]:
# 📦 FIRST CELL - ALL IMPORTS AND CONFIGURATION
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D

# Set reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Global configuration constants
DEFAULT_CHAOS_LEVEL = 0.1
SACRED_SEED = 42
N_FEATURES = 3  # wax_hardness, temperature, humidity

## 🕯️ THE SACRED CANDLE DATA GENERATION SCROLL

In [None]:
def generate_candle_burn_data(n_candles: int = 100, chaos_level: float = 0.1, 
                              sacred_seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate observations of temple candle burning patterns.
    
    Ancient wisdom reveals the sacred formula:
    burn_time = 2.5 * wax_hardness + 0.3 * temperature - 0.8 * humidity + 45
    
    Where:
    - wax_hardness: 20-80% (harder wax burns longer)
    - temperature: 15-30°C (warmer rooms slightly lengthen burn time, per the +0.3 coefficient)
    - humidity: 30-70% (higher humidity reduces burn time)
    - burn_time: in minutes
    
    Args:
        n_candles: Number of candle observations to simulate
        chaos_level: Amount of environmental unpredictability (0.0 = perfect conditions)
        sacred_seed: Ensures consistent randomness for reproducible wisdom
    
    Returns:
        Tuple of (candle_features, burn_times) as sacred tensors
        candle_features shape: (n_candles, 3) [wax_hardness, temperature, humidity]
        burn_times shape: (n_candles, 1)
    """
    torch.manual_seed(sacred_seed)
    
    # Generate temple candle characteristics
    wax_hardness = torch.rand(n_candles) * 60 + 20  # 20-80% hardness
    temperature = torch.rand(n_candles) * 15 + 15   # 15-30°C
    humidity = torch.rand(n_candles) * 40 + 30      # 30-70% humidity
    
    # Combine features into a single tensor
    candle_features = torch.stack([wax_hardness, temperature, humidity], dim=1)
    
    # The ancient formula discovered by temple monks
    true_weights = torch.tensor([2.5, 0.3, -0.8])  # hardness, temp, humidity coefficients
    true_bias = 45.0
    
    # Calculate perfect burn times
    perfect_burn_times = torch.matmul(candle_features, true_weights) + true_bias
    
    # Add environmental chaos (drafts, altitude, spiritual energy fluctuations)
    chaos = torch.randn(n_candles) * chaos_level * perfect_burn_times.std()
    burn_times = perfect_burn_times + chaos
    
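    # Even mystical candles have physical limits; the ceiling sits above the
    # formula's noise-free maximum (230 minutes) so valid burn times are not clipped
    burn_times = torch.clamp(burn_times, 30, 250)  # 30-250 minutes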
    
    return candle_features, burn_times.unsqueeze(1)

def visualize_candle_wisdom(features: torch.Tensor, burn_times: torch.Tensor, 
                           predictions: torch.Tensor = None):
    """
    Display the sacred patterns of candle burning across multiple dimensions.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    feature_names = ['Wax Hardness (%)', 'Temperature (°C)', 'Humidity (%)']
    
    # Plot each feature vs burn time
    for i, feature_name in enumerate(feature_names):
        row, col = i // 2, i % 2
        ax = axes[row, col]
        
        # Actual data
        ax.scatter(features[:, i].numpy(), burn_times.numpy(), 
                  alpha=0.6, color='orange', label='Actual Burn Times')
        
        # Predictions if available
        if predictions is not None:
            ax.scatter(features[:, i].numpy(), predictions.detach().numpy(), 
                      alpha=0.6, color='gold', marker='x', s=50, 
                      label='Predicted Burn Times')
        
        ax.set_xlabel(feature_name)
        ax.set_ylabel('Burn Time (minutes)')
        ax.set_title(f'Temple Candles: {feature_name} vs Burn Time')
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    # Feature correlation heatmap
    ax = axes[1, 1]
    feature_data = features.numpy()
    correlation_matrix = np.corrcoef(feature_data.T)
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0,
                xticklabels=feature_names, yticklabels=feature_names, ax=ax)
    ax.set_title('Feature Correlations in Temple Candles')
    
    plt.tight_layout()
    plt.show()

# Generate and visualize the sacred candle data
candle_features, burn_times = generate_candle_burn_data(n_candles=150, chaos_level=0.15)
print(f"Generated data for {len(candle_features)} temple candles")
print(f"Feature tensor shape: {candle_features.shape}")
print(f"Burn times tensor shape: {burn_times.shape}")
print(f"\nFeature statistics:")
print(f"Wax hardness: {candle_features[:, 0].min():.1f}% - {candle_features[:, 0].max():.1f}%")
print(f"Temperature: {candle_features[:, 1].min():.1f}°C - {candle_features[:, 1].max():.1f}°C")
print(f"Humidity: {candle_features[:, 2].min():.1f}% - {candle_features[:, 2].max():.1f}%")
print(f"Burn times: {burn_times.min():.1f} - {burn_times.max():.1f} minutes")

visualize_candle_wisdom(candle_features, burn_times)

## 🕯️ FIRST MOVEMENTS - THE MULTI-VARIABLE PREDICTOR
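
As a shape check before filling in the class below, this sketch shows how a single `nn.Linear` layer treats a batch with three features per row. The batch of five random candles is hypothetical; only the input and output sizes mirror this kata.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 3 input features -> 1 output: the layer holds 3 weights and 1 bias
layer = nn.Linear(in_features=3, out_features=1)

batch = torch.rand(5, 3)   # 5 hypothetical candles, 3 features each
output = layer(batch)      # one prediction per candle

n_params = sum(p.numel() for p in layer.parameters())
print(output.shape, n_params)
```

The same layer handles any batch size, because the weighted sum is applied row by row.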

In [None]:
class CandleBurnPredictor(nn.Module):
    """A mystical artifact for understanding multi-dimensional candle burning patterns."""
    
    def __init__(self, input_features: int = 3):
        super().__init__()
        # TODO: Create a Linear layer that can handle multiple input features
        # Hint: input_features=3 (wax, temperature, humidity) -> output_features=1 (burn time)
        self.linear = None
    
    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Channel your understanding through the multi-dimensional network."""
        # TODO: Pass the multi-feature input through your Linear layer
        # Remember: this handles multiple features automatically!
        return None

def train_multi_variable(model: nn.Module, features: torch.Tensor, target: torch.Tensor,
                        epochs: int = 1000, learning_rate: float = 0.01) -> list:
    """
    Train the multi-variable candle burn prediction model.
    
    Returns:
        List of loss values during training
    """
    # TODO: Choose your loss calculation method
    # Hint: MSE Loss works well for multi-variable regression too!
    criterion = None
    
    # TODO: Choose your parameter updating method
    # Hint: SGD will update ALL parameters (weights + bias) automatically
    optimizer = None
    
    losses = []
    
    for epoch in range(epochs):
        # TODO: CRITICAL - Clear the gradient spirits from previous cycle
        # Hint: With multiple features, this is even more important!
        
        # TODO: Forward pass - get predictions from all features
        predictions = None
        
        # TODO: Compute the loss
        loss = None
        
        # TODO: Backward pass - compute gradients for all parameters
        
        # TODO: Update all parameters (weights and bias)
        
        losses.append(loss.item())
        
        # Report progress to Master Ao-Tougrad
        if (epoch + 1) % 200 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
            if loss.item() < 20:
                print("🌟 Master Ao-Tougrad nods approvingly at your multi-dimensional wisdom!")
    
    return losses

def analyze_learned_patterns(model: CandleBurnPredictor):
    """Examine what the model learned about candle burning."""
    weights = model.linear.weight.data.squeeze()
    bias = model.linear.bias.data.item()
    
    feature_names = ['Wax Hardness', 'Temperature', 'Humidity']
    
    print("\n🔥 LEARNED CANDLE BURNING WISDOM:")
    print(f"Base burn time (bias): {bias:.2f} minutes")
    print("\nFeature importance:")
    for name, weight in zip(feature_names, weights):
        effect = "increases" if weight > 0 else "decreases"
        print(f"  {name}: {weight:.3f} (each unit {effect} burn time)")
    
    # Compare to true values
    true_weights = torch.tensor([2.5, 0.3, -0.8])
    true_bias = 45.0
    
    print("\n🎯 COMPARISON TO ANCIENT WISDOM:")
    print(f"True bias: {true_bias:.2f}, Learned bias: {bias:.2f}")
    for name, true_w, learned_w in zip(feature_names, true_weights, weights):
        print(f"  {name}: True={true_w:.3f}, Learned={learned_w:.3f}")

# Create and examine the model structure
model = CandleBurnPredictor(input_features=3)
print("Model structure:")
print(model)
print(f"\nNumber of parameters: {sum(p.numel() for p in model.parameters())}")

## ⚡ THE TRIALS OF MASTERY

### Trial 1: Basic Multi-Variable Mastery
- [ ] Loss decreases consistently across all features
- [ ] Final loss levels off near the noise floor (roughly 45 with chaos_level=0.15, since the added noise alone has a variance of about 44)
- [ ] Model weights approximately: [2.5, 0.3, -0.8] (±0.3 each)
- [ ] Model bias approximately 45 (±5)
- [ ] Predictions align well with actual burn times across all features
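
One way to sanity-check learned weights is to compare them with a closed-form least-squares fit. The sketch below is an assumption-laden illustration rather than part of the kata: it builds its own noise-free candles from the notebook's formula and solves for the coefficients with `torch.linalg.lstsq`.

```python
import torch

torch.manual_seed(0)
n = 200

# Hypothetical noise-free candles drawn from the same ranges as the notebook
X = torch.stack([
    torch.rand(n, dtype=torch.float64) * 60 + 20,  # wax hardness
    torch.rand(n, dtype=torch.float64) * 15 + 15,  # temperature
    torch.rand(n, dtype=torch.float64) * 40 + 30,  # humidity
], dim=1)
true_w = torch.tensor([2.5, 0.3, -0.8], dtype=torch.float64)
y = (X @ true_w + 45.0).unsqueeze(1)

# Append a column of ones so the last coefficient plays the role of the bias
A = torch.cat([X, torch.ones(n, 1, dtype=torch.float64)], dim=1)
solution = torch.linalg.lstsq(A, y).solution.squeeze()
print(solution)  # three weights followed by the bias
```

On noisy data the closed-form answer will differ slightly from the true coefficients, which gives a realistic target for what gradient descent can reach.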

In [None]:
# Train your multi-variable model
print("🕯️ Beginning the sacred multi-variable training ritual...")
# Unscaled features force a small step: full-batch SGD diverges here for rates much above 1e-4
losses = train_multi_variable(model, candle_features, burn_times, epochs=1500, learning_rate=1e-4)

# Analyze what the model learned
analyze_learned_patterns(model)

# Generate predictions and visualize
with torch.no_grad():
    predictions = model(candle_features)

print(f"\n📊 FINAL PERFORMANCE:")
print(f"Final loss: {losses[-1]:.4f}")
print(f"Mean absolute error: {torch.mean(torch.abs(predictions - burn_times)):.2f} minutes")

# Visualize results
visualize_candle_wisdom(candle_features, burn_times, predictions)

# Plot training progress
plt.figure(figsize=(10, 6))
plt.plot(losses, color='orange', linewidth=2)
plt.title('Multi-Variable Training Progress: Loss Over Time')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)
plt.show()

### Trial 2: Understanding Test

In [None]:
def test_your_wisdom(model):
    """Master Ao-Tougrad's evaluation of your multi-dimensional understanding."""
    # Test with specific candle configurations
    test_features = torch.tensor([
        [50.0, 20.0, 50.0],  # Medium hardness, cool room, medium humidity
        [80.0, 15.0, 30.0],  # Hard wax, cold room, low humidity (should burn longest)
        [20.0, 30.0, 70.0]   # Soft wax, warm room, high humidity (should burn shortest)
    ])
    
    predictions = model(test_features)
    
    # Shape validation
    assert predictions.shape == (3, 1), f"Expected shape (3, 1), got {predictions.shape}"
    
    # Parameter validation
    weights = model.linear.weight.data.squeeze()
    bias = model.linear.bias.data.item()
    
    # Check if learned weights are close to true values
    expected_weights = torch.tensor([2.5, 0.3, -0.8])
    expected_bias = 45.0
    
    for i, (learned, expected) in enumerate(zip(weights, expected_weights)):
        assert abs(learned - expected) < 0.5, f"Weight {i} = {learned:.3f}, expected ~{expected:.3f}"
    
    assert abs(bias - expected_bias) < 8, f"Bias = {bias:.2f}, expected ~{expected_bias:.2f}"
    
    # Logical validation - wax hardness dominates the formula, so the hard, dry candle
    # should outlast the soft, humid one (temperature's +0.3 effect is comparatively small)
    longest_burn = predictions[1].item()  # Hard wax, cold, dry
    shortest_burn = predictions[2].item()  # Soft wax, warm, humid
    
    assert longest_burn > shortest_burn, "Hard wax in dry conditions should burn longer than soft wax in humid conditions!"
    
    print("🎉 Master Ao-Tougrad emerges from the shadows with approval!")
    print("   'You have grasped the flow of multiple gradients, young grasshopper.'")
    print(f"\n📊 Test predictions:")
    print(f"  Medium conditions: {predictions[0].item():.1f} minutes")
    print(f"  Optimal conditions: {predictions[1].item():.1f} minutes")
    print(f"  Poor conditions: {predictions[2].item():.1f} minutes")

# Run the wisdom test
test_your_wisdom(model)

## 🌸 THE FOUR PATHS OF MASTERY: PROGRESSIVE EXTENSIONS

### Extension 1: Cook Oh-Pai-Timizer's Feature Scaling Wisdom
*"In cooking, balancing flavors requires understanding their relative strengths!"*

*Cook Oh-Pai-Timizer approaches with measuring spoons*

"Ah, grasshopper! I see your model learns well, but notice how wax hardness ranges from 20-80 while temperature only goes 15-30? It's like adding a tablespoon of salt versus a teaspoon of pepper - the amounts are different but both affect the final dish! Your gradient descent might struggle because some features dominate others simply due to their scale."

**NEW CONCEPTS**: Feature normalization, standardization, gradient descent stability  
**DIFFICULTY**: +15% (still Dan 1, but with preprocessing)
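
The Cook's point about scale can be made concrete. For full-batch gradient descent on mean squared error, the largest stable step size is 2 divided by the largest eigenvalue of the loss Hessian. The sketch below estimates that bound before and after standardization; the random candles and the helper name `max_stable_lr` are assumptions for illustration only.

```python
import torch

torch.manual_seed(0)
n = 150

# Hypothetical raw features drawn from the notebook's ranges
raw = torch.stack([
    torch.rand(n) * 60 + 20,   # wax hardness
    torch.rand(n) * 15 + 15,   # temperature
    torch.rand(n) * 40 + 30,   # humidity
], dim=1)

# Standardize each column with its own mean and standard deviation
scaled = (raw - raw.mean(dim=0)) / raw.std(dim=0)

def max_stable_lr(features: torch.Tensor) -> float:
    """Step-size bound for full-batch gradient descent on MSE (hypothetical helper)."""
    ones = torch.ones(features.shape[0], 1)
    A = torch.cat([features, ones], dim=1).double()  # include the bias column
    hessian = 2.0 * A.T @ A / A.shape[0]             # Hessian of the mean squared error
    largest = torch.linalg.eigvalsh(hessian).max().item()
    return 2.0 / largest

print(f"raw features:    lr < {max_stable_lr(raw):.6f}")
print(f"scaled features: lr < {max_stable_lr(scaled):.6f}")
```

The bound for raw features should come out orders of magnitude smaller, which is why standardized training can use a much larger learning rate.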

In [None]:
def normalize_features(features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Normalize features to have zero mean and unit variance.
    
    Returns:
        Tuple of (normalized_features, feature_means, feature_stds)
    """
    # TODO: Implement feature normalization
    # Hint: normalized = (features - mean) / std
    # Remember: Store means and stds for denormalizing predictions later!
    
    feature_means = None
    feature_stds = None
    normalized_features = None
    
    return normalized_features, feature_means, feature_stds

def compare_training_with_without_normalization():
    """Compare training speed and stability with and without feature normalization."""
    
    # Train without normalization (original data)
    model_raw = CandleBurnPredictor(input_features=3)
    print("Training without normalization...")
    # Unscaled features need a small step: full-batch SGD diverges for rates much above 1e-4
    losses_raw = train_multi_variable(model_raw, candle_features, burn_times, 
                                     epochs=1000, learning_rate=1e-4)
    
    # Train with normalization
    normalized_features, means, stds = normalize_features(candle_features)
    model_norm = CandleBurnPredictor(input_features=3)
    print("\nTraining with normalization...")
    losses_norm = train_multi_variable(model_norm, normalized_features, burn_times, 
                                      epochs=1000, learning_rate=0.01)  # Can use higher LR!
    
    # Compare results
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(losses_raw, label='Without Normalization', color='red', alpha=0.7)
    plt.plot(losses_norm, label='With Normalization', color='blue', alpha=0.7)
    plt.title('Training Comparison: Loss Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.subplot(1, 2, 2)
    plt.plot(losses_raw[-200:], label='Without Normalization', color='red', alpha=0.7)
    plt.plot(losses_norm[-200:], label='With Normalization', color='blue', alpha=0.7)
    plt.title('Final 200 Epochs (Convergence Detail)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\n🍜 Cook Oh-Pai-Timizer's Analysis:")
    print(f"  Final loss without normalization: {losses_raw[-1]:.4f}")
    print(f"  Final loss with normalization: {losses_norm[-1]:.4f}")
    print(f"  Improvement: {((losses_raw[-1] - losses_norm[-1]) / losses_raw[-1] * 100):.1f}%")

# TRIAL: Compare normalized vs non-normalized training
# SUCCESS: Normalized training converges faster and more stably
compare_training_with_without_normalization()

### Extension 2: He-Ao-World's Measurement Mishap
*"These old eyes have been recording measurements for decades, but..."*

*He-Ao-World shuffles over looking particularly apologetic*

"Oh dear! I've been helping record the candle data, but I'm afraid I've made some... inconsistencies. Some temperature readings are in Fahrenheit instead of Celsius, some humidity measurements might be absolute instead of relative, and I may have double-counted some wax hardness values. The data is messier now, but perhaps this teaches us about real-world conditions?"

**NEW CONCEPTS**: Outlier detection, robust training, data validation  
**DIFFICULTY**: +25% (still Dan 1, but with noisy data)
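
To see how a Z-score flags a mis-recorded value, consider the toy example below. It is a sketch under stated assumptions: the readings are random stand-ins, and a single temperature is deliberately logged in Fahrenheit.

```python
import torch

torch.manual_seed(0)

# Hypothetical clean readings: columns are [wax_hardness, temperature, humidity]
readings = torch.stack([
    torch.rand(100) * 60 + 20,
    torch.rand(100) * 15 + 15,
    torch.rand(100) * 40 + 30,
], dim=1)
readings[7, 1] = 20.0 * 9 / 5 + 32   # one temperature logged in Fahrenheit (68.0)

# Z-score per column: how many standard deviations each value sits from its mean
z_scores = (readings - readings.mean(dim=0)) / readings.std(dim=0)

# Flag a row when any of its columns is more than 3 standard deviations out
is_outlier = (z_scores.abs() > 3.0).any(dim=1)
print(is_outlier.nonzero().squeeze(1))
```

Note that extreme values inflate the mean and standard deviation they are measured against, so heavily corrupted data may hide some outliers from this simple test.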

In [None]:
def introduce_measurement_chaos(features: torch.Tensor, burn_times: torch.Tensor, 
                               chaos_probability: float = 0.2) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Introduce realistic measurement errors that He-Ao-World might cause.
    
    Args:
        chaos_probability: Fraction of measurements that contain errors
    
    Returns:
        Tuple of (corrupted_features, corrupted_burn_times)
    """
    corrupted_features = features.clone()
    corrupted_burn_times = burn_times.clone()
    
    n_samples = len(features)
    n_corrupted = int(n_samples * chaos_probability)
    
    # Randomly select samples to corrupt
    corrupted_indices = torch.randperm(n_samples)[:n_corrupted]
    
    for idx in corrupted_indices:
        # Different types of measurement errors He-Ao-World might make
        error_type = torch.randint(0, 4, (1,)).item()
        
        if error_type == 0:  # Temperature in Fahrenheit instead of Celsius
            celsius_temp = corrupted_features[idx, 1]
            fahrenheit_temp = celsius_temp * 9/5 + 32
            corrupted_features[idx, 1] = fahrenheit_temp
        elif error_type == 1:  # Double-recorded wax hardness
            corrupted_features[idx, 0] *= 2
        elif error_type == 2:  # Humidity as absolute instead of relative
            corrupted_features[idx, 2] *= 0.3  # Simulate conversion error
        else:  # Burn time measurement error
            corrupted_burn_times[idx] *= torch.normal(1.0, 0.3, (1,))  # multiplicative noise, 30% std
    
    return corrupted_features, corrupted_burn_times

def detect_outliers(features: torch.Tensor, burn_times: torch.Tensor, 
                   threshold: float = 3.0) -> torch.Tensor:
    """
    Detect outliers using the Z-score method.
    
    Returns:
        Boolean tensor indicating which samples are outliers
    """
    # TODO: Implement outlier detection
    # Hint: Calculate Z-scores for each feature and burn time
    # Hint: Z-score = (value - mean) / std
    # A sample is an outlier if ANY feature has |Z-score| > threshold
    
    outliers = torch.zeros(len(features), dtype=torch.bool)
    
    # Check each feature column
    for i in range(features.shape[1]):
        # TODO: Calculate Z-scores for feature i
        # TODO: Mark samples with |Z-score| > threshold as outliers
        pass
    
    # TODO: Also check burn times for outliers
    
    return outliers

def robust_training_comparison():
    """Compare training on clean vs corrupted data, with and without outlier removal."""
    
    # Generate corrupted data
    corrupted_features, corrupted_burn_times = introduce_measurement_chaos(
        candle_features, burn_times, chaos_probability=0.15
    )
    
    # Detect outliers
    outliers = detect_outliers(corrupted_features, corrupted_burn_times)
    clean_mask = ~outliers
    
    print(f"🔍 He-Ao-World's Measurement Analysis:")
    print(f"  Original samples: {len(candle_features)}")
    print(f"  Detected outliers: {outliers.sum().item()}")
    print(f"  Clean samples remaining: {clean_mask.sum().item()}")
    
    # Train three models: clean, corrupted, and cleaned
    models = {
        'Original Clean': (candle_features, burn_times),
        'With Corruption': (corrupted_features, corrupted_burn_times),
        'Outliers Removed': (corrupted_features[clean_mask], corrupted_burn_times[clean_mask])
    }
    
    results = {}
    
    for name, (feats, targets) in models.items():
        model = CandleBurnPredictor(input_features=3)
        print(f"\nTraining on {name} data...")
        losses = train_multi_variable(model, feats, targets, epochs=1000, learning_rate=1e-4)  # small step for unscaled features
        results[name] = {'model': model, 'losses': losses}
    
    # Visualize training comparison
    plt.figure(figsize=(12, 8))
    
    plt.subplot(2, 2, 1)
    for name, result in results.items():
        plt.plot(result['losses'], label=name, alpha=0.8)
    plt.title('Training Loss Comparison')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Show data distribution comparison
    plt.subplot(2, 2, 2)
    plt.hist(candle_features[:, 1].numpy(), alpha=0.5, label='Original Temp', bins=20)
    plt.hist(corrupted_features[:, 1].numpy(), alpha=0.5, label='Corrupted Temp', bins=20)
    plt.title('Temperature Distribution: Before/After Corruption')
    plt.xlabel('Temperature')
    plt.ylabel('Frequency')
    plt.legend()
    
    plt.subplot(2, 2, 3)
    plt.scatter(candle_features[:, 0], burn_times, alpha=0.5, label='Original', s=20)
    plt.scatter(corrupted_features[outliers, 0], corrupted_burn_times[outliers], 
               alpha=0.8, label='Outliers', s=30, color='red', marker='x')
    plt.title('Outliers in Wax Hardness vs Burn Time')
    plt.xlabel('Wax Hardness')
    plt.ylabel('Burn Time')
    plt.legend()
    
    plt.subplot(2, 2, 4)
    final_losses = [result['losses'][-1] for result in results.values()]
    plt.bar(results.keys(), final_losses, color=['green', 'red', 'blue'], alpha=0.7)
    plt.title('Final Training Loss Comparison')
    plt.ylabel('Final Loss')
    plt.xticks(rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\n🧹 He-Ao-World's Wisdom:")
    print(f"  'Sometimes the best way to clean data is to recognize what doesn't belong.'")
    for name, result in results.items():
        print(f"  {name}: Final loss = {result['losses'][-1]:.4f}")

# TRIAL: Handle corrupted measurements and outliers
# SUCCESS: Model trained on cleaned data performs better than corrupted data
robust_training_comparison()

### Extension 3: Master Pai-Torch's Gradient Wisdom
*"The path of learning is not always straight, young grasshopper."*

*Master Pai-Torch materializes beside you in contemplative silence*

"I observe that you train your networks with steady, unchanging steps. But consider the mountain climber - they take bold strides on gentle slopes, careful steps on steep terrain, and pause to rest when tired. The wise student learns to adjust their pace based on the terrain of the loss landscape."

**NEW CONCEPTS**: Learning rate scheduling, momentum, adaptive optimization  
**DIFFICULTY**: +35% (still Dan 1, but with advanced optimization)
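
The mountain-climber idea corresponds to a learning rate schedule. The sketch below pairs SGD with momentum and `StepLR` on a small synthetic problem; the data, the starting rate of 0.1, and the halving every 100 epochs are all illustrative assumptions rather than recommended settings.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(3, 1)
x = torch.randn(64, 3)                                   # hypothetical standardized features
y = x @ torch.tensor([[2.0], [0.5], [-1.0]]) + 3.0       # hypothetical linear target

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# Halve the learning rate every 100 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
criterion = nn.MSELoss()

lrs = []
for epoch in range(300):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()                                     # call after optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs[0], lrs[99], lrs[199], lrs[299])
print(f"final loss: {loss.item():.6f}")
```

Recording the learning rate each epoch, as done here, makes the staircase shape of the schedule easy to plot alongside the loss.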

In [None]:
def advanced_optimization_comparison():
    """Compare different optimization strategies for multi-variable regression."""
    
    # Different optimization strategies
    # Plain and momentum SGD need small steps on these unscaled features; Adam rescales its own updates
    optimizers_config = {
        'SGD (Basic)': {'optimizer': 'sgd', 'lr': 1e-4, 'momentum': 0},
        'SGD + Momentum': {'optimizer': 'sgd', 'lr': 1e-4, 'momentum': 0.9},
        'Adam (Adaptive)': {'optimizer': 'adam', 'lr': 0.01},
        'SGD + Schedule': {'optimizer': 'sgd_schedule', 'lr': 1e-4, 'momentum': 0.9}
    }
    
    results = {}
    
    for name, config in optimizers_config.items():
        print(f"\nTraining with {name}...")
        model = CandleBurnPredictor(input_features=3)
        
        # TODO: Create different optimizers based on config
        if config['optimizer'] == 'sgd':
            optimizer = None  # TODO: optim.SGD with momentum
        elif config['optimizer'] == 'adam':
            optimizer = None  # TODO: optim.Adam
        elif config['optimizer'] == 'sgd_schedule':
            optimizer = None  # TODO: optim.SGD with momentum
            scheduler = None  # TODO: optim.lr_scheduler.StepLR
        
        # Training loop with advanced optimization
        criterion = nn.MSELoss()
        losses = []
        learning_rates = []
        
        for epoch in range(1000):
            # Standard training step
            optimizer.zero_grad()
            predictions = model(candle_features)
            loss = criterion(predictions, burn_times)
            loss.backward()
            optimizer.step()
            
            # TODO: Update learning rate if using scheduler
            if config['optimizer'] == 'sgd_schedule':
                pass  # TODO: scheduler.step()
            
            losses.append(loss.item())
            learning_rates.append(optimizer.param_groups[0]['lr'])
            
            if (epoch + 1) % 200 == 0:
                print(f'Epoch [{epoch+1}/1000], Loss: {loss.item():.4f}, LR: {optimizer.param_groups[0]["lr"]:.6f}')
        
        results[name] = {
            'model': model,
            'losses': losses,
            'learning_rates': learning_rates
        }
    
    # Visualize optimization comparison
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss curves
    ax = axes[0, 0]
    for name, result in results.items():
        ax.plot(result['losses'], label=name, alpha=0.8)
    ax.set_title('Training Loss Comparison')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Learning rate schedules
    ax = axes[0, 1]
    for name, result in results.items():
        ax.plot(result['learning_rates'], label=name, alpha=0.8)
    ax.set_title('Learning Rate Over Time')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Learning Rate')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Final loss comparison
    ax = axes[1, 0]
    final_losses = [result['losses'][-1] for result in results.values()]
    bars = ax.bar(results.keys(), final_losses, color=['red', 'orange', 'green', 'blue'], alpha=0.7)
    ax.set_title('Final Loss Comparison')
    ax.set_ylabel('Final Loss')
    ax.tick_params(axis='x', rotation=45)
    
    # Convergence speed (first epoch at which the loss is within 10% of its final value)
    ax = axes[1, 1]
    convergence_epochs = []
    for name, result in results.items():
        final_loss = result['losses'][-1]
        target_loss = final_loss * 1.1  # 10% above final loss
        convergence_epoch = next((i for i, loss in enumerate(result['losses']) if loss <= target_loss), 999)
        convergence_epochs.append(convergence_epoch)
    
    ax.bar(results.keys(), convergence_epochs, color=['red', 'orange', 'green', 'blue'], alpha=0.7)
    ax.set_title('Convergence Speed (Epochs to Within 10% of Final Loss)')
    ax.set_ylabel('Epochs')
    ax.tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\n🧙 Master Pai-Torch's Analysis:")
    print(f"  'Each optimization path teaches different lessons about the gradient landscape.'")
    for (name, result), epochs_needed in zip(results.items(), convergence_epochs):
        print(f"  {name}: Final loss = {result['losses'][-1]:.4f}, Convergence = {epochs_needed} epochs")

# TRIAL: Compare different optimization strategies
# SUCCESS: Understand how different optimizers affect training dynamics
advanced_optimization_comparison()

### Extension 4: Suki's Multi-Dimensional Purring Oracle
*"The temple cat understands patterns that span multiple dimensions."*

*Suki sits majestically, then performs an elaborate sequence of meows*

*Master Pai-Torch translates: "The sacred cat says your linear wisdom across multiple dimensions is sound, but true understanding comes from seeing how all variables dance together in harmony. Can you predict not just burn time, but also understand which combination of conditions creates the most sacred flames?"*

**NEW CONCEPTS**: Multi-dimensional visualization, feature interaction analysis, model interpretation  
**DIFFICULTY**: +45% (still Dan 1, but with deep analysis)
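
Evaluating a model over every combination of conditions is easiest with a grid. The sketch below builds one with `torch.meshgrid` and, as a stand-in for a trained model, scores it with the notebook's data-generation formula; the coarse grid sizes are arbitrary choices for readability.

```python
import torch

wax = torch.linspace(20, 80, 4)
temp = torch.linspace(15, 30, 3)
hum = torch.linspace(30, 70, 5)

# indexing="ij" keeps the grid axes in the same order as the inputs
W, T, H = torch.meshgrid(wax, temp, hum, indexing="ij")

# Flatten the grid into a feature matrix: one row per combination
grid = torch.stack([W, T, H], dim=-1).reshape(-1, 3)

# Stand-in for a trained model: the notebook's data-generation formula
weights = torch.tensor([2.5, 0.3, -0.8])
burn = grid @ weights + 45.0

best = grid[burn.argmax()]    # conditions with the longest predicted burn
worst = grid[burn.argmin()]   # conditions with the shortest predicted burn
print(best, worst)
```

With a trained model, replacing the formula line by a call to the model inside `torch.no_grad()` gives the same analysis on learned weights.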

In [None]:
def analyze_feature_interactions(model: CandleBurnPredictor):
    """Analyze how different feature combinations affect candle burn time."""
    
    # Create a grid of feature combinations
    wax_range = torch.linspace(20, 80, 20)
    temp_range = torch.linspace(15, 30, 20)
    humidity_range = torch.linspace(30, 70, 20)
    
    # TODO: Generate predictions for all combinations
    # Hint: Use torch.meshgrid to create all combinations
    # Hint: Reshape to create a feature matrix for prediction
    
    # Find optimal conditions
    print("🔍 Suki's Multi-Dimensional Analysis:")
    
    # TODO: Find the feature combination that gives maximum burn time
    # TODO: Find the feature combination that gives minimum burn time
    # TODO: Analyze how each feature affects the others
    
    pass

def create_sacred_flame_predictor(model: CandleBurnPredictor):
    """Create an interactive tool to predict candle burn time for any combination."""
    
    def predict_burn_time(wax_hardness: float, temperature: float, humidity: float) -> float:
        """Predict burn time for specific candle conditions."""
        features = torch.tensor([[wax_hardness, temperature, humidity]])
        with torch.no_grad():
            prediction = model(features)
        return prediction.item()
    
    print("🕯️ SACRED FLAME PREDICTOR")
    print("Predicted burn times for sample candle conditions:")
    
    # Test some interesting combinations
    test_conditions = [
        (80, 15, 30, "Optimal for long meditation"),
        (20, 30, 70, "Quick evening prayers"),
        (50, 22, 50, "Balanced conditions"),
        (70, 20, 40, "Winter temple setting"),
        (40, 28, 60, "Summer temple setting")
    ]
    
    for wax, temp, hum, description in test_conditions:
        burn_time = predict_burn_time(wax, temp, hum)
        print(f"  {description}: {burn_time:.1f} minutes")
        print(f"    (Wax: {wax}%, Temp: {temp}°C, Humidity: {hum}%)")
    
    return predict_burn_time

def visualize_multi_dimensional_wisdom(model: CandleBurnPredictor):
    """Create advanced visualizations of the multi-dimensional relationship."""
    
    # Create 3D visualization of feature interactions
    fig = plt.figure(figsize=(15, 12))
    
    # 3D scatter plot of actual data
    ax1 = fig.add_subplot(221, projection='3d')
    scatter = ax1.scatter(candle_features[:, 0], candle_features[:, 1], candle_features[:, 2], 
                         c=burn_times.squeeze(), cmap='viridis', alpha=0.6)
    ax1.set_xlabel('Wax Hardness (%)')
    ax1.set_ylabel('Temperature (°C)')
    ax1.set_zlabel('Humidity (%)')
    ax1.set_title('3D Candle Data (Color = Burn Time)')
    plt.colorbar(scatter, ax=ax1, shrink=0.5)
    
    # Heatmap of burn time vs two features (fixing third)
    ax2 = fig.add_subplot(222)
    
    # TODO: Create a heatmap showing burn time for different wax/temperature combinations
    # Fix humidity at 50% and vary wax hardness and temperature
    
    # Feature importance visualization
    ax3 = fig.add_subplot(223)
    weights = model.linear.weight.data.squeeze().abs()
    feature_names = ['Wax Hardness', 'Temperature', 'Humidity']
    bars = ax3.bar(feature_names, weights, color=['brown', 'red', 'blue'], alpha=0.7)
    ax3.set_title('Feature Importance (Absolute Weight Values)')
    ax3.set_ylabel('Absolute Weight')
    
    # Residual analysis
    ax4 = fig.add_subplot(224)
    with torch.no_grad():
        predictions = model(candle_features)
    residuals = (predictions - burn_times).squeeze()
    ax4.scatter(predictions.squeeze(), residuals, alpha=0.6)
    ax4.axhline(y=0, color='red', linestyle='--', alpha=0.7)
    ax4.set_xlabel('Predicted Burn Time')
    ax4.set_ylabel('Residuals (Predicted - Actual)')
    ax4.set_title('Residual Analysis')
    ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("\n🐱 Suki's Final Wisdom:")
    print("  *Purrs approvingly while demonstrating multi-dimensional understanding*")
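    ss_res = residuals.pow(2).sum()                          # residual sum of squares
    ss_tot = (burn_times - burn_times.mean()).pow(2).sum()   # total sum of squares
    print(f"  Model R² score: {1 - (ss_res / ss_tot):.3f}")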
    print(f"  Mean absolute error: {torch.mean(torch.abs(residuals)):.2f} minutes")

# TRIAL: Analyze multi-dimensional feature relationships
# SUCCESS: Understand how all features work together to predict burn time
analyze_feature_interactions(model)
predictor = create_sacred_flame_predictor(model)
visualize_multi_dimensional_wisdom(model)

# MASTERY: Create your own candle condition and predict its burn time
print("\n🎓 MASTERY CHALLENGE:")
print("Create your own candle condition and predict its burn time using the predictor function!")
print("Example: predictor(60, 25, 45)")

# Test your understanding
custom_prediction = predictor(60, 25, 45)
print(f"Your custom candle (60% wax, 25°C, 45% humidity): {custom_prediction:.1f} minutes")

## 🔥 CORRECTING YOUR FORM: A STANCE IMBALANCE

Master Pai-Torch observes your multi-variable training ritual with a careful eye. "Your eager mind grasps the complexity of multiple dimensions, grasshopper, but I sense a disturbance in your gradient flow. Your stance wavers when handling the sacred multi-dimensional tensors."

A previous disciple left this flawed multi-variable training ritual. The form has become unsteady across multiple dimensions - can you restore proper technique?
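
One autograd behavior is worth recalling before inspecting the ritual: each call to `backward()` adds to the stored gradient instead of replacing it. The minimal sketch below uses a single made-up parameter to show the effect.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)

# Two backward passes without clearing: the gradients add up
(3.0 * w).backward()
(3.0 * w).backward()
accumulated = w.grad.item()

# Clearing the stored gradient, the role optimizer.zero_grad() plays for all parameters
w.grad.zero_()
(3.0 * w).backward()
fresh = w.grad.item()

print(accumulated, fresh)
```

In a training loop, a stale gradient like `accumulated` means every update mixes in the directions of all earlier epochs.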

In [None]:
def unsteady_multi_variable_training(model, features, target, epochs=1000):
    """This multi-variable training stance has lost its balance - your form needs correction! 🥋"""
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4)  # small step so the unscaled features alone do not cause divergence
    
    for epoch in range(epochs):
        # Forward pass with multiple features
        predictions = model(features)
        loss = criterion(predictions, target)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        if epoch % 200 == 0:
            print(f'Epoch {epoch}: Loss = {loss.item():.4f}')
    
    return model

# Test the flawed training - notice how the loss behaves strangely
print("🚨 Testing the unsteady training ritual...")
flawed_model = CandleBurnPredictor(input_features=3)
flawed_model = unsteady_multi_variable_training(flawed_model, candle_features, burn_times)

print("\n🧙 Master Pai-Torch's Guidance:")
print("'The undisciplined mind accumulates old thoughts across all dimensions,'")
print("'just as the untrained gradient accumulates old directions from multiple features.'")
print("\n🔍 DEBUGGING CHALLENGE:")
print("Can you spot the critical error in this multi-variable training ritual?")
print("HINT: The Gradient Spirits from ALL features are not being properly dismissed between cycles")
print("HINT: In multi-variable regression, accumulated gradients affect ALL weight updates")
print("\n💡 SOLUTION: Add the missing gradient clearing step and observe the difference!")