[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ruliana/pytorch-katas/blob/main/dan_1/kata_02_temple_cat_feeding_predictor_unrevised.ipynb)

## 🏮 The Ancient Scroll Unfurls 🏮

THE SACRED HUNGER ORACLE: MASTERING SUKI'S FEEDING PROPHECY

Dan Level: 1 (Temple Sweeper) | Time: 45 minutes | Sacred Arts: Linear Regression, Gradient Descent, Training Loops

## 📜 THE CHALLENGE

In the quiet hours before dawn, as you sweep the temple courtyard, you notice something peculiar. **Suki**, the sacred temple cat, appears at the feeding bowl with uncanny precision - never too early, never too late, always exactly when their hunger reaches the perfect threshold.

**Master Pai-Torch** emerges from the shadows, observing your curious gaze.

*"Ah, young grasshopper, you have discovered the first hidden technique. The wise observe patterns where others see chaos. Tell me, what do you see in the sacred cat's behavior?"*

You share your observations: "Master, Suki always appears when they've been without food for a certain time. It's as if there's a mathematical relationship between hours and hunger."

*Master Pai-Torch nods slowly.* "Indeed. The ancient temple keepers knew this secret - **hunger grows linearly with time**. But the true art lies not in knowing this truth, but in teaching a neural network to discover it. This is your first lesson in the sacred art of **gradient descent**."

He gestures toward the feeding bowl. "Your task is to create a mystical artifact - a neural network that can predict Suki's exact hunger level based on hours since their last meal. Master this, and you will understand the fundamental flow of gradients through the sacred tensor realms."

## 🎯 THE SACRED OBJECTIVES

- [ ] Create a linear neural network that learns the relationship between time and hunger
- [ ] Implement proper gradient descent training with loss computation
- [ ] Achieve convergence with final loss below 50
- [ ] Understand how gradients flow backward through your network
- [ ] Visualize the learned relationship and validate predictions
- [ ] Master the sacred ritual of `optimizer.zero_grad()`

## 🎲 THE SACRED IMPORTS

First, we gather the mystical tools needed for our neural arts:

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple

# For reproducible mystical experiences
torch.manual_seed(42)
np.random.seed(42)

print("🔮 The sacred tools are ready, grasshopper!")
print(f"PyTorch version: {torch.__version__}")

## 🐱 THE SACRED DATA GENERATION SCROLL

Master Pai-Torch whispers the ancient secret: *"The sacred relationship follows the pattern: `hunger_level = 2.5 * hours_since_last_meal + 20`. When hunger exceeds 70, Suki appears at the bowl."*
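The scroll's formula already answers the practical question before any training. Setting hunger equal to the threshold gives 70 = 2.5 h + 20, so h = 20. On the noiseless line, Suki should turn up about 20 hours after a meal. The noise added below will blur this a little. Here is a tiny check; the variable names are illustrative:

```python
# Invert hunger = 2.5 * hours + 20 to find when hunger reaches the feeding threshold
scroll_slope = 2.5        # hunger gained per hour (from the scroll)
scroll_base = 20.0        # hunger right after a meal (from the scroll)
feeding_threshold = 70.0  # Suki appears above this hunger level

hours_until_appearance = (feeding_threshold - scroll_base) / scroll_slope
print(f"On the ideal line, Suki appears after {hours_until_appearance:.1f} hours")  # 20.0
```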

In [None]:
def generate_cat_feeding_data(n_observations: int = 100, chaos_level: float = 0.1,
                             sacred_seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate observations of Suki's feeding patterns.

    Ancient wisdom suggests: hunger_level = 2.5 * hours_since_last_meal + 20
    When hunger_level > 70, Suki appears at the food bowl.

    Args:
        n_observations: Number of Suki sightings to simulate
        chaos_level: Amount of feline unpredictability (0.0 = perfectly predictable cat, 1.0 = pure chaos)
        sacred_seed: Ensures consistent randomness for reproducible wisdom

    Returns:
        Tuple of (hours_since_last_meal, hunger_level) as sacred tensors
    """
    torch.manual_seed(sacred_seed)
    np.random.seed(sacred_seed)

    # Suki can go 0-30 hours between meals (a very dramatic cat)
    hours_since_meal = torch.rand(n_observations, 1) * 30

    # The sacred relationship known to ancient cat scholars
    base_hunger = 20
    hunger_per_hour = 2.5

    hunger_levels = hunger_per_hour * hours_since_meal.squeeze() + base_hunger

    # Add feline chaos (cats are unpredictable creatures)
    chaos = torch.randn(n_observations) * chaos_level * hunger_levels.std()
    hunger_levels = hunger_levels + chaos

    # Even mystical cats have limits
    hunger_levels = torch.clamp(hunger_levels, 0, 100)

    return hours_since_meal, hunger_levels.unsqueeze(1)

def visualize_cat_wisdom(hours: torch.Tensor, hunger: torch.Tensor,
                        predictions: torch.Tensor = None):
    """Display the sacred patterns of Suki's appetite."""
    plt.figure(figsize=(12, 7))
    plt.scatter(hours.numpy(), hunger.numpy(), alpha=0.6, color='purple',
                label='Suki\'s Actual Hunger Levels')

    if predictions is not None:
        sorted_indices = torch.argsort(hours.squeeze())
        sorted_hours = hours[sorted_indices]
        sorted_predictions = predictions[sorted_indices]
        plt.plot(sorted_hours.numpy(), sorted_predictions.detach().numpy(),
                'gold', linewidth=3, label='Your Mystical Predictions')

    plt.axhline(y=70, color='red', linestyle='--', alpha=0.7,
                label='Sacred Feeding Threshold (Suki Appears!)')
    plt.xlabel('Hours Since Last Meal')
    plt.ylabel('Suki\'s Hunger Level')
    plt.title('🐱 The Mysteries of Temple Cat Appetite 🐱')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.ylim(0, 100)
    plt.show()

# Generate the sacred data
hours, hunger = generate_cat_feeding_data(n_observations=150, chaos_level=0.1)
visualize_cat_wisdom(hours, hunger)

print(f"📊 Generated {len(hours)} observations of Suki's feeding patterns")
print(f"🕐 Hour range: {hours.min():.1f} to {hours.max():.1f}")
print(f"🍽️ Hunger range: {hunger.min():.1f} to {hunger.max():.1f}")

## 💃 FIRST MOVEMENTS: THE HUNGER PREDICTION ARTIFACT

Master Pai-Torch gestures toward an empty scroll. *"Now, young grasshopper, create the sacred artifact that will learn Suki's patterns. A simple linear network - one that transforms hours into hunger through the mystical art of matrix multiplication."*
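Here is what "matrix multiplication" means for a single feature. `nn.Linear(1, 1)` holds a 1×1 weight matrix W and a bias b. For a batch x of shape (N, 1), it computes `x @ W.T + b`.

The sketch below plugs the scroll's values (weight 2.5, bias 20) in by hand rather than learning them, just to show the forward pass. The `demo_` names are illustrative. `torch.random.fork_rng()` is used so the layer's random initialization does not shift the random numbers later cells draw.

```python
import torch
import torch.nn as nn

# fork_rng saves and restores the global RNG, so this layer's random init
# does not change the random numbers used by later notebook cells
with torch.random.fork_rng():
    demo_layer = nn.Linear(in_features=1, out_features=1)

demo_hours = torch.tensor([[0.0], [10.0], [20.0]])  # shape (3, 1): 3 observations, 1 feature

with torch.no_grad():
    # Overwrite the random init with the scroll's values
    demo_layer.weight.fill_(2.5)   # weight has shape (out_features, in_features) = (1, 1)
    demo_layer.bias.fill_(20.0)    # bias has shape (out_features,) = (1,)
    demo_hunger = demo_layer(demo_hours)                               # shape (3, 1)
    demo_manual = demo_hours @ demo_layer.weight.T + demo_layer.bias   # same computation, by hand

print(demo_hunger.squeeze(1).tolist())  # [20.0, 45.0, 70.0]
```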

In [None]:
class CatHungerPredictor(nn.Module):
    """A mystical artifact for understanding feline appetite patterns."""

    def __init__(self, input_features: int = 1):
        super(CatHungerPredictor, self).__init__()
        # TODO: Create the Linear layer
        # Hint: torch.nn.Linear transforms input energy into output wisdom
        # It needs to know: input_features -> output_features (1 for hunger level)
        self.linear = None

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Channel your understanding through the mystical network."""
        # TODO: Pass the input through your Linear layer
        # Remember: even cats follow mathematical laws
        return None

# Create your mystical artifact
model = CatHungerPredictor(input_features=1)
print("🏛️ Your hunger prediction artifact has been forged!")
print(f"Model parameters: {sum(p.numel() for p in model.parameters())} sacred numbers")

## 🌊 THE SACRED TRAINING RITUAL

Master Pai-Torch closes his eyes in concentration. *"Now comes the most crucial lesson, grasshopper. The network must learn through the sacred flow of gradients. Each cycle, it makes a prediction, measures its error, and adjusts its understanding. This is the essence of all neural wisdom."*
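The cycle described above (predict, measure the error, adjust) can be written out by hand. This minimal sketch uses two plain tensors instead of `nn.Linear` and `optim.SGD`. It assumes full-batch gradient descent on three noiseless points with a learning rate of 0.001, and the `demo_` names are hypothetical.

Here the gradients are cleared at the end of each cycle. This does the same job as calling `optimizer.zero_grad()` at the start of the kata's loop.

```python
import torch

# Three noiseless observations on the scroll's line
demo_x = torch.tensor([[0.0], [10.0], [20.0]])
demo_y = 2.5 * demo_x + 20.0                 # [[20.], [45.], [70.]]

demo_w = torch.zeros(1, requires_grad=True)  # hand-made "weight"
demo_b = torch.zeros(1, requires_grad=True)  # hand-made "bias"
demo_lr = 0.001

demo_losses = []
for _ in range(3):
    demo_pred = demo_x * demo_w + demo_b              # 1. predict
    demo_loss = ((demo_pred - demo_y) ** 2).mean()    # 2. measure the error (MSE)
    demo_loss.backward()                              # 3. gradients land in .grad
    with torch.no_grad():                             # 4. adjust, outside autograd tracking
        demo_w -= demo_lr * demo_w.grad
        demo_b -= demo_lr * demo_b.grad
    demo_w.grad.zero_()                               # 5. clear gradients for the next cycle
    demo_b.grad.zero_()
    demo_losses.append(demo_loss.item())

print(demo_losses)  # strictly decreasing
```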

In [None]:
def train_hunger_predictor(model: nn.Module, features: torch.Tensor, target: torch.Tensor,
                          epochs: int = 1000, learning_rate: float = 0.002) -> list:
    """
    Train the cat hunger prediction model through the sacred ritual of gradient descent.

    Args:
        model: Your mystical hunger predictor
        features: Hours since last meal (input)
        target: Actual hunger levels (what we want to predict)
        epochs: Number of training cycles
        learning_rate: Step size for each update. On raw hours (0-30), plain SGD
            overshoots and diverges above roughly 0.003, so keep it small.

    Returns:
        List of loss values during training
    """
    # TODO: Choose your loss calculation method
    # Hint: Mean Squared Error is favored by the ancient masters
    # It measures the average squared difference between predictions and truth
    criterion = None

    # TODO: Choose your parameter updating method
    # Hint: SGD (Stochastic Gradient Descent) is the traditional path
    # It needs to know: which parameters to update and how fast to learn
    optimizer = None

    losses = []

    for epoch in range(epochs):
        # TODO: CRITICAL - Clear the gradient spirits from previous cycle
        # Hint: The spirits accumulate if not banished properly
        # Use: optimizer.zero_grad()

        # TODO: Forward pass - get predictions from your model
        # Hint: Pass the features through your model
        predictions = None

        # TODO: Compute the loss
        # Hint: Use your criterion to compare predictions with target
        loss = None

        # TODO: Backward pass - compute gradients
        # Hint: loss.backward() summons the gradient spirits

        # TODO: Update parameters
        # Hint: optimizer.step() applies the gradient wisdom

        losses.append(loss.item())

        # Report progress to Master Pai-Torch
        if (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
            if loss.item() < 10:
                print("💫 The Gradient Spirits smile upon your progress!")

    return losses

# Begin the sacred training ritual
print("🧘 Beginning the sacred training ritual...")
# Raw hours span 0-30, so plain SGD needs a small step size; lr=0.01 makes the loss blow up
loss_history = train_hunger_predictor(model, hours, hunger, epochs=5000, learning_rate=0.002)

# Visualize the learning journey
plt.figure(figsize=(10, 6))
plt.plot(loss_history)
plt.title('🌊 The Sacred Loss Journey')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)
plt.show()

print(f"✨ Training complete! Final loss: {loss_history[-1]:.4f}")

## 🔮 REVEALING THE LEARNED WISDOM

Master Pai-Torch nods approvingly. *"Now, grasshopper, let us see what your network has learned. The true test is not just low loss, but understanding the sacred relationship itself."*
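The model is a single linear layer, so the best possible weight and bias for a dataset can also be computed in closed form with ordinary least squares. Comparing against that answer separates two problems: "the optimizer has not converged yet" and "the noise limits how close anyone can get to 2.5 and 20".

The sketch uses synthetic stand-in data from its own seeded generator (noise std 2.0 is an assumption), so the notebook's random state is untouched. The `demo_` names are illustrative.

```python
import torch

# Synthetic stand-in for the notebook's data, drawn from a private generator
demo_gen = torch.Generator().manual_seed(0)
demo_hours = torch.rand(150, 1, generator=demo_gen) * 30
demo_hunger = 2.5 * demo_hours + 20 + torch.randn(150, 1, generator=demo_gen) * 2.0

# Design matrix [hours, 1]: its least-squares solution is [slope, intercept]
demo_design = torch.cat([demo_hours, torch.ones_like(demo_hours)], dim=1)  # (150, 2)
demo_solution = torch.linalg.lstsq(demo_design, demo_hunger).solution      # (2, 1)
demo_slope = demo_solution[0, 0].item()
demo_intercept = demo_solution[1, 0].item()
print(f"best-fit slope {demo_slope:.3f}, intercept {demo_intercept:.3f}")
```

Least squares minimizes the same mean squared error that training minimizes. Running the same two lines on the notebook's own `hours` and `hunger` therefore gives the best weight and bias any linear model can reach on that data, which makes a useful yardstick for `learned_weight` and `learned_bias`.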

In [None]:
# Make predictions with your trained model
model.eval()  # Set to evaluation mode
with torch.no_grad():
    predictions = model(hours)

# Visualize the learned relationship
visualize_cat_wisdom(hours, hunger, predictions)

# Examine the learned parameters
learned_weight = model.linear.weight.item()
learned_bias = model.linear.bias.item()

print(f"🎯 Your network learned:")
print(f"   Weight (hunger per hour): {learned_weight:.3f}")
print(f"   Bias (base hunger): {learned_bias:.3f}")
print(f"\n📚 The ancient truth:")
print(f"   True weight: 2.5")
print(f"   True bias: 20")
print(f"\n🎉 Accuracy: Weight is {abs(learned_weight - 2.5):.3f} away from truth")
print(f"           Bias is {abs(learned_bias - 20):.3f} away from truth")

## ⚡ THE TRIALS OF MASTERY

Master Pai-Torch strokes his beard thoughtfully. *"Before you can advance, you must prove your understanding through the sacred trials."*

In [None]:
def test_your_wisdom(model):
    """Master Pai-Torch's evaluation of your understanding."""
    model.eval()
    
    # Test 1: Shape consistency
    test_features = torch.tensor([[5.0], [10.0], [20.0]])
    with torch.no_grad():
        predictions = model(test_features)
    assert predictions.shape == (3, 1), f"Expected shape (3, 1), got {predictions.shape}"
    print("✅ Shape test passed - your tensors align with the sacred geometry!")

    # Test 2: Parameter validation
    weight = model.linear.weight.item()
    bias = model.linear.bias.item()
    
    assert 2.0 <= weight <= 3.0, f"Weight {weight:.2f} seems off - cats are more predictable!"
    assert 15 <= bias <= 25, f"Bias {bias:.2f} - even well-fed cats have base hunger!"
    print("✅ Parameter test passed - your network understands cat nature!")

    # Test 3: Logical predictions
    with torch.no_grad():
        pred_5h = model(torch.tensor([[5.0]])).item()
        pred_10h = model(torch.tensor([[10.0]])).item()
        pred_20h = model(torch.tensor([[20.0]])).item()
    
    assert pred_5h < pred_10h < pred_20h, "Hunger should increase with time!"
    print("✅ Logic test passed - longer waits mean hungrier cats!")

    print("\n🎉 Master Pai-Torch nods with approval - your understanding grows!")
    print(f"   Predicted hunger after 5 hours: {pred_5h:.1f}")
    print(f"   Predicted hunger after 10 hours: {pred_10h:.1f}")
    print(f"   Predicted hunger after 20 hours: {pred_20h:.1f}")

# Test your wisdom
test_your_wisdom(model)

# Final mastery check
final_loss = loss_history[-1]
print(f"\n📊 MASTERY EVALUATION:")
print(f"   Final Loss: {final_loss:.4f} {'✅' if final_loss < 50 else '❌'} (Target: < 50)")
print(f"   Weight Accuracy: {abs(learned_weight - 2.5):.3f} {'✅' if abs(learned_weight - 2.5) < 0.5 else '❌'} (Target: < 0.5)")
print(f"   Bias Accuracy: {abs(learned_bias - 20):.3f} {'✅' if abs(learned_bias - 20) < 5 else '❌'} (Target: < 5)")

if final_loss < 50 and abs(learned_weight - 2.5) < 0.5 and abs(learned_bias - 20) < 5:
    print("\n🏆 CONGRATULATIONS! You have mastered the first sacred art!")
    print("🐱 Suki purrs approvingly - your network understands cat wisdom!")
else:
    print("\n🔄 Not quite there yet, grasshopper. Review your training ritual.")

## 🌸 THE FOUR PATHS OF MASTERY: PROGRESSIVE EXTENSIONS

Master Pai-Torch gestures toward four different paths leading from the courtyard. *"Your foundation is strong, but true mastery requires walking different paths. Each will teach you new aspects of the sacred arts."*

### Extension 1: Cook Oh-Pai-Timizer's Batch Cooking Wisdom
*"A good cook knows that batch size affects the final dish!"*

Cook Oh-Pai-Timizer bustles over, wooden spoon in hand.

*"Ah, grasshopper! I see you've mastered feeding one cat at a time. But what happens when you need to predict hunger for multiple cats simultaneously? In my kitchen, efficiency comes from preparing multiple servings at once!"*

**NEW CONCEPTS:** Batch processing, tensor shapes, vectorized operations  
**DIFFICULTY:** +15% (still Dan 1, but with batches)
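Here is Cook Oh-Pai-Timizer's point in tensor terms. A linear layer computes `x @ W.T + b` for every row of an (N, in_features) input at once, so a bigger batch is just more rows.

This small sketch uses `torch.nn.functional.linear`, the function `nn.Linear` calls in its forward pass. The scroll's weight 2.5 and bias 20 are plugged in by hand, and the `demo_` names are illustrative.

```python
import torch
import torch.nn.functional as F

# The scroll's line as an explicit weight and bias (no random init involved)
demo_weight = torch.tensor([[2.5]])   # shape (out_features, in_features) = (1, 1)
demo_bias = torch.tensor([20.0])      # shape (out_features,) = (1,)

demo_batch = torch.tensor([[1.0], [4.0], [12.0], [30.0]])  # four cats' hours, shape (4, 1)

# One call handles the whole batch...
demo_batch_out = F.linear(demo_batch, demo_weight, demo_bias)  # shape (4, 1)

# ...and matches feeding the rows through one at a time
demo_row_by_row = torch.cat(
    [F.linear(row.unsqueeze(0), demo_weight, demo_bias) for row in demo_batch]
)
print(demo_batch_out.squeeze(1).tolist())  # [22.5, 30.0, 50.0, 95.0]
```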

In [None]:
def generate_multi_cat_data(n_cats: int = 5, observations_per_cat: int = 50, 
                           sacred_seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate feeding data for multiple temple cats at once.
    
    Returns:
        Tuple of (batch_hours, batch_hunger_levels)
        Shape: (n_cats * observations_per_cat, 1) for both tensors
    """
    torch.manual_seed(sacred_seed)
    
    all_hours = []
    all_hunger = []
    
    for cat_id in range(n_cats):
        # Each cat is a little less predictable than the last (chaos 0.05, 0.07, 0.09, ...)
        cat_chaos = 0.05 + (cat_id * 0.02)
        hours, hunger = generate_cat_feeding_data(
            n_observations=observations_per_cat, 
            chaos_level=cat_chaos,
            sacred_seed=sacred_seed + cat_id
        )
        all_hours.append(hours)
        all_hunger.append(hunger)
    
    # Combine all cats into one large batch
    batch_hours = torch.cat(all_hours, dim=0)
    batch_hunger = torch.cat(all_hunger, dim=0)
    
    return batch_hours, batch_hunger

# Generate multi-cat data
batch_hours, batch_hunger = generate_multi_cat_data(n_cats=5, observations_per_cat=50)

print(f"🐱 Generated data for 5 temple cats")
print(f"📊 Batch shape: {tuple(batch_hours.shape)} ({batch_hours.shape[0]} total observations)")

# Test your existing model on batch data
model.eval()
with torch.no_grad():
    batch_predictions = model(batch_hours)

# Visualize batch predictions
visualize_cat_wisdom(batch_hours, batch_hunger, batch_predictions)

# Calculate batch performance
batch_loss = nn.MSELoss()(batch_predictions, batch_hunger)
print(f"\n🍽️ Batch prediction loss: {batch_loss.item():.4f}")
print(f"🎯 SUCCESS: Your model handles all {batch_hours.shape[0]} observations in a single forward pass!")
print(f"   nn.Linear applies the same weight and bias to every row of an (N, 1) batch with one matrix multiply.")

### Extension 2: He-Ao-World's Measurement Mix-up
*"These old eyes sometimes read the measuring scrolls incorrectly..."*

He-Ao-World shuffles over, looking apologetic.

*"Oh dear! I was recording Suki's feeding times and... well, I might have mixed up some of the measurements. Some are in minutes instead of hours, and others might be twice what they should be. The data looks a bit... chaotic now."*

**NEW CONCEPTS:** Data normalization, feature scaling, handling inconsistent units  
**DIFFICULTY:** +25% (still Dan 1, but messier data)
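A model trained on normalized data speaks in normalized units, so its weight and bias must be mapped back before they mean "hunger per hour" and "base hunger".

This sketch z-scores a perfectly linear toy dataset the same way `normalize_feeding_data` does. On such data the ideal normalized-space model is weight 1 and bias 0; that is assumed here rather than trained. The sketch then undoes the scaling. The `demo_` names are illustrative.

```python
import torch

demo_hours = torch.linspace(0, 30, 7).unsqueeze(1)  # (7, 1): 0, 5, ..., 30 hours
demo_hunger = 2.5 * demo_hours + 20

# Z-score both columns: subtract the mean, divide by the standard deviation
demo_h_mean, demo_h_std = demo_hours.mean(), demo_hours.std()
demo_y_mean, demo_y_std = demo_hunger.mean(), demo_hunger.std()
demo_norm_hours = (demo_hours - demo_h_mean) / demo_h_std
demo_norm_hunger = (demo_hunger - demo_y_mean) / demo_y_std

# On a perfect line, normalized hunger equals normalized hours,
# so the ideal normalized-space model is weight 1, bias 0 (assumed, not trained)
demo_w_n, demo_b_n = 1.0, 0.0

# Undo the scaling: hunger = y_mean + y_std * (w_n * (hours - h_mean) / h_std + b_n)
demo_slope = demo_w_n * demo_y_std / demo_h_std
demo_intercept = demo_y_mean + demo_y_std * demo_b_n - demo_slope * demo_h_mean
print(f"slope {demo_slope.item():.3f}, intercept {demo_intercept.item():.3f}")  # 2.500, 20.000
```

The same scaling applies to the loss. An MSE measured in normalized units, multiplied by `hunger_std ** 2`, gives the MSE in hunger units.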

In [None]:
def create_messy_data(clean_hours: torch.Tensor, clean_hunger: torch.Tensor,
                     contamination_rate: float = 0.3) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    He-Ao-World's 'accidents' create measurement inconsistencies.
    
    Args:
        contamination_rate: Fraction of data to corrupt
    
    Returns:
        Tuple of (messy_hours, messy_hunger)
    """
    messy_hours = clean_hours.clone()
    messy_hunger = clean_hunger.clone()
    
    n_contaminated = int(len(messy_hours) * contamination_rate)
    contaminated_indices = torch.randperm(len(messy_hours))[:n_contaminated]
    
    for idx in contaminated_indices:
        accident_type = torch.randint(0, 3, (1,)).item()
        
        if accident_type == 0:
            # Measured in minutes instead of hours
            messy_hours[idx] = messy_hours[idx] * 60
        elif accident_type == 1:
            # Double measurement error
            messy_hours[idx] = messy_hours[idx] * 2
        else:
            # Random extra noise
            messy_hours[idx] = messy_hours[idx] + torch.randn(1) * 5
    
    return messy_hours, messy_hunger

def normalize_feeding_data(hours_since_meal: torch.Tensor, 
                          hunger_levels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, dict]:
    """
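    Standardize (z-score) the feeding data: subtract the mean, divide by the standard deviation.

    This keeps wildly scaled inputs from destabilizing gradient descent. It does not repair
    individual mis-recorded values; after scaling they are still outliers.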
    
    Returns:
        Tuple of (normalized_hours, normalized_hunger, normalization_params)
    """
    # Calculate normalization parameters
    hours_mean = hours_since_meal.mean()
    hours_std = hours_since_meal.std()
    hunger_mean = hunger_levels.mean()
    hunger_std = hunger_levels.std()
    
    # Normalize (subtract mean, divide by std)
    normalized_hours = (hours_since_meal - hours_mean) / hours_std
    normalized_hunger = (hunger_levels - hunger_mean) / hunger_std
    
    # Store parameters for denormalization later
    params = {
        'hours_mean': hours_mean,
        'hours_std': hours_std,
        'hunger_mean': hunger_mean,
        'hunger_std': hunger_std
    }
    
    return normalized_hours, normalized_hunger, params

# Create messy data
messy_hours, messy_hunger = create_messy_data(hours, hunger, contamination_rate=0.3)

print("🤦 He-Ao-World's measurement mishaps:")
print(f"   Original hour range: {hours.min():.1f} to {hours.max():.1f}")
print(f"   Messy hour range: {messy_hours.min():.1f} to {messy_hours.max():.1f}")

# Visualize the chaos
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.scatter(hours.numpy(), hunger.numpy(), alpha=0.6, color='blue', label='Clean Data')
plt.title('Original Clean Data')
plt.xlabel('Hours')
plt.ylabel('Hunger')
plt.legend()

plt.subplot(1, 2, 2)
plt.scatter(messy_hours.numpy(), messy_hunger.numpy(), alpha=0.6, color='red', label='Messy Data')
plt.title('He-Ao-World\'s Messy Data')
plt.xlabel('Hours (inconsistent units!)')
plt.ylabel('Hunger')
plt.legend()
plt.tight_layout()
plt.show()

# Normalize the messy data
norm_hours, norm_hunger, norm_params = normalize_feeding_data(messy_hours, messy_hunger)

# Train a new model on normalized data
normalized_model = CatHungerPredictor(input_features=1)
normalized_losses = train_hunger_predictor(normalized_model, norm_hours, norm_hunger, 
                                         epochs=800, learning_rate=0.01)

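# MSE in normalized units times hunger_std**2 equals MSE in original hunger units
normalized_loss_hunger_units = normalized_losses[-1] * norm_params['hunger_std'].item() ** 2
model.eval()
with torch.no_grad():
    clean_model_messy_loss = nn.MSELoss()(model(messy_hours), messy_hunger).item()

print(f"\n🧹 Normalized training complete!")
print(f"   Final loss (normalized units): {normalized_losses[-1]:.4f}")
print(f"   Same loss in hunger units: {normalized_loss_hunger_units:.4f}")
print(f"   Clean-data model evaluated on messy data (hunger units): {clean_model_messy_loss:.4f}")
print(f"\n🎯 SUCCESS: Normalization keeps gradient descent stable on wildly scaled inputs!")
print(f"   It rescales the data but does not fix mis-recorded values; those remain outliers.")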

### Extension 3: Master Pai-Torch's Patience Teaching
*"The eager student trains too quickly and learns too little."*

Master Pai-Torch sits in contemplative silence.

*"Young grasshopper, I observe your training ritual rushes like a mountain stream. But wisdom comes to those who vary their pace. Sometimes we must step boldly, sometimes cautiously, sometimes we must rest entirely."*

**NEW CONCEPTS:** Learning rate scheduling, early stopping, training patience  
**DIFFICULTY:** +35% (still Dan 1, but smarter training)
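"Stepping boldly" has a ceiling. For plain gradient descent on MSE with a linear model, the loss is a quadratic whose curvature comes from the inputs. Steps larger than 2/λ_max overshoot and diverge, where λ_max is the largest eigenvalue of the loss Hessian.

With hours spread over 0-30, E[x²] is about 300, which puts the ceiling at roughly 0.003. This sketch estimates the ceiling on synthetic hours drawn from its own seeded generator, then runs gradient descent at half and at one and a half times that value. All `demo_` names are hypothetical.

```python
import math
import torch

demo_gen = torch.Generator().manual_seed(0)
demo_hours = torch.rand(150, 1, generator=demo_gen) * 30
demo_hunger = 2.5 * demo_hours + 20

# For loss = mean((w*x + b - y)^2), the Hessian w.r.t. (w, b) is 2 * mean([[x^2, x], [x, 1]])
demo_X = torch.cat([demo_hours, torch.ones_like(demo_hours)], dim=1)  # (150, 2)
demo_hessian = 2 * demo_X.T @ demo_X / demo_X.shape[0]
demo_lam_max = torch.linalg.eigvalsh(demo_hessian).max().item()
demo_lr_limit = 2 / demo_lam_max
print(f"largest stable learning rate is roughly {demo_lr_limit:.4f}")

def demo_sgd_final_loss(lr: float, steps: int = 200) -> float:
    """Plain full-batch gradient descent from w = b = 0; returns the last computed loss."""
    w = torch.zeros(1, requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    for _ in range(steps):
        loss = ((demo_hours * w + b - demo_hunger) ** 2).mean()
        loss.backward()
        with torch.no_grad():
            w -= lr * w.grad
            b -= lr * b.grad
        w.grad.zero_()
        b.grad.zero_()
    return loss.item()

demo_safe = demo_sgd_final_loss(0.5 * demo_lr_limit)  # shrinks steadily
demo_bold = demo_sgd_final_loss(1.5 * demo_lr_limit)  # overshoots until it overflows
print(f"0.5x limit: {demo_safe:.2f}   1.5x limit: {demo_bold}")
```

The same bound explains why a learning rate of 0.01 is too large for raw hours. Either use a smaller rate, or normalize the inputs as in Extension 2; normalization brings the curvature down to around 2.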

In [None]:
def patient_training_ritual(model: nn.Module, features: torch.Tensor, target: torch.Tensor,
                           epochs: int = 2000, initial_lr: float = 0.0025, 
                           patience: int = 100) -> Tuple[list, bool]:
    """
    Train with patience and adaptive learning rate.
    
    Args:
        patience: Stop training if loss doesn't improve for this many epochs
        initial_lr: Starting learning rate
    
    Returns:
        Tuple of (loss_history, stopped_early)
    """
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=initial_lr)
    
    # Learning rate scheduler: halves the LR after 50 epochs without improvement.
    # (No `verbose` flag: it is deprecated in recent PyTorch; the LR is printed below instead.)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=50
    )
    
    losses = []
    best_loss = float('inf')
    patience_counter = 0
    
    for epoch in range(epochs):
        optimizer.zero_grad()
        predictions = model(features)
        loss = criterion(predictions, target)
        loss.backward()
        optimizer.step()
        
        current_loss = loss.item()
        losses.append(current_loss)
        
        # Update learning rate based on loss
        scheduler.step(current_loss)
        
        # Check for improvement
        if current_loss < best_loss:
            best_loss = current_loss
            patience_counter = 0
        else:
            patience_counter += 1
        
        # Early stopping
        if patience_counter >= patience:
            print(f"\n⏰ Early stopping at epoch {epoch+1} - no improvement for {patience} epochs")
            print(f"   Best loss achieved: {best_loss:.4f}")
            return losses, True
        
        # Progress reporting
        if (epoch + 1) % 200 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {current_loss:.4f}, LR: {current_lr:.6f}')
    
    return losses, False

# Compare patient vs rushed training
print("🏃 Training with rush (original method):")
rushed_model = CatHungerPredictor(input_features=1)
rushed_losses = train_hunger_predictor(rushed_model, hours, hunger, epochs=1000, learning_rate=0.002)

print("\n🧘 Training with patience (adaptive method):")
patient_model = CatHungerPredictor(input_features=1)
patient_losses, stopped_early = patient_training_ritual(patient_model, hours, hunger,
                                                       epochs=2000, initial_lr=0.0025, patience=100)

# Compare results
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(rushed_losses, label='Rushed Training', color='red')
plt.plot(patient_losses, label='Patient Training', color='blue')
plt.title('Training Comparison')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
# Show final predictions
rushed_model.eval()
patient_model.eval()
with torch.no_grad():
    rushed_pred = rushed_model(hours)
    patient_pred = patient_model(hours)

plt.scatter(hours.numpy(), hunger.numpy(), alpha=0.4, color='gray', label='True Data')
plt.scatter(hours.numpy(), rushed_pred.numpy(), alpha=0.6, color='red', label='Rushed', s=10)
plt.scatter(hours.numpy(), patient_pred.numpy(), alpha=0.6, color='blue', label='Patient', s=10)
plt.title('Prediction Comparison')
plt.xlabel('Hours')
plt.ylabel('Hunger')
plt.legend()
plt.tight_layout()
plt.show()

print(f"\n📊 PATIENCE COMPARISON:")
print(f"   Rushed final loss: {rushed_losses[-1]:.4f}")
print(f"   Patient final loss: {patient_losses[-1]:.4f}")
print(f"   Epochs used: {len(rushed_losses)} vs {len(patient_losses)}")
print(f"   Stopped early: {stopped_early}")
print(f"\n🎯 SUCCESS: Patient training often achieves better results!")
print(f"   Learning rate adaptation and early stopping are essential skills.")

### Extension 4: Suki's Feeding Threshold Mystery
*"Understanding when the cat appears is as important as predicting hunger."*

Suki sits majestically, then meows once.

Master Pai-Torch translates: *"The sacred cat says your linear wisdom is sound, but the true test is knowing when hunger becomes action. At what point does prediction become decision?"*

**NEW CONCEPTS:** Threshold analysis, decision boundaries, model interpretation  
**DIFFICULTY:** +45% (still Dan 1, but thinking beyond prediction)
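The difference between prediction and decision is easiest to see when the truth is held fixed. Suki actually appears when real hunger exceeds 70, and only the cut-off applied to the model's predictions moves.

This toy example uses six made-up sightings; the numbers are hypothetical, not the notebook's data. Lowering the decision threshold from 70 to 67 catches every real appearance (recall rises) but raises more false alarms (precision falls).

```python
import torch

# Hypothetical predicted vs. actual hunger for six sightings
demo_pred = torch.tensor([65.0, 72.0, 68.0, 80.0, 75.0, 67.5])
demo_true = torch.tensor([60.0, 74.0, 71.0, 85.0, 69.0, 62.0])
demo_actually_appears = demo_true > 70  # ground truth stays fixed at 70

def demo_precision_recall(decision_threshold: float):
    """Precision and recall when we announce 'Suki will appear' for predictions above the threshold."""
    said_appears = demo_pred > decision_threshold
    tp = (said_appears & demo_actually_appears).sum().item()
    fp = (said_appears & ~demo_actually_appears).sum().item()
    fn = (~said_appears & demo_actually_appears).sum().item()
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    return precision, recall

for demo_t in (70.0, 67.0):
    demo_p, demo_r = demo_precision_recall(demo_t)
    print(f"decision threshold {demo_t}: precision={demo_p:.2f}, recall={demo_r:.2f}")
# decision threshold 70.0: precision=0.67, recall=0.67
# decision threshold 67.0: precision=0.60, recall=1.00
```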

In [None]:
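def analyze_feeding_threshold(model: nn.Module, features: torch.Tensor, target: torch.Tensor,
                            threshold_candidates: tuple = (60, 65, 70, 75, 80),
                            appearance_threshold: float = 70.0) -> dict:
    """
    Analyze how well each candidate decision threshold predicts when Suki will actually appear.

    The ground truth stays fixed: Suki appears when actual hunger exceeds
    appearance_threshold (70, per the ancient scroll). Only the cut-off applied
    to the model's predictions changes from candidate to candidate.
    
    Returns:
        Dictionary of {threshold: accuracy_score}
    """
    model.eval()
    with torch.no_grad():
        predictions = model(features)
    
    # Ground truth: does Suki actually appear? (independent of the candidate threshold)
    true_appears = (target > appearance_threshold).float()
    
    threshold_results = {}
    
    for threshold in threshold_candidates:
        # Model's binary prediction: will Suki appear?
        pred_appears = (predictions > threshold).float()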
        
        # Calculate accuracy
        correct_predictions = (pred_appears == true_appears).float()
        accuracy = correct_predictions.mean().item()
        
        threshold_results[threshold] = accuracy
    
    return threshold_results

def visualize_decision_boundary(model: nn.Module, features: torch.Tensor, 
                               target: torch.Tensor, best_threshold: float,
                               appearance_threshold: float = 70.0):
    """
    Show where your model draws the line between "not yet" and "will appear".

    The model's decision uses best_threshold; what actually happened uses
    appearance_threshold (Suki appears when real hunger exceeds 70).
    """
    model.eval()
    with torch.no_grad():
        predictions = model(features)
    
    # Create decision regions
    pred_appears = predictions > best_threshold
    true_appears = target > appearance_threshold
    
    # Classify predictions
    true_positives = (pred_appears & true_appears).squeeze()
    false_positives = (pred_appears & ~true_appears).squeeze()
    true_negatives = (~pred_appears & ~true_appears).squeeze()
    false_negatives = (~pred_appears & true_appears).squeeze()
    
    plt.figure(figsize=(14, 8))
    
    # Plot the different prediction types
    plt.scatter(features[true_positives].numpy(), target[true_positives].numpy(), 
               color='green', alpha=0.7, label='True Positive (Correctly predicted appearance)', s=50)
    plt.scatter(features[false_positives].numpy(), target[false_positives].numpy(), 
               color='orange', alpha=0.7, label='False Positive (Predicted appearance wrongly)', s=50)
    plt.scatter(features[true_negatives].numpy(), target[true_negatives].numpy(), 
               color='lightblue', alpha=0.7, label='True Negative (Correctly predicted no appearance)', s=50)
    plt.scatter(features[false_negatives].numpy(), target[false_negatives].numpy(), 
               color='red', alpha=0.7, label='False Negative (Missed appearance)', s=50)
    
    # Plot model predictions as a line
    sorted_indices = torch.argsort(features.squeeze())
    sorted_features = features[sorted_indices]
    sorted_predictions = predictions[sorted_indices]
    plt.plot(sorted_features.numpy(), sorted_predictions.numpy(), 'black', linewidth=2, label='Model Predictions')
    
    # Plot decision threshold
    plt.axhline(y=best_threshold, color='red', linestyle='--', linewidth=2, 
                label=f'Decision Threshold = {best_threshold}')
    
    plt.xlabel('Hours Since Last Meal')
    plt.ylabel('Hunger Level')
    plt.title('🐱 Decision Boundary Analysis: When Will Suki Appear?')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Analyze different thresholds
threshold_results = analyze_feeding_threshold(model, hours, hunger, 
                                            threshold_candidates=[60, 65, 70, 75, 80])

# Find the best threshold
best_threshold = max(threshold_results.keys(), key=lambda k: threshold_results[k])
best_accuracy = threshold_results[best_threshold]

print("🎯 THRESHOLD ANALYSIS RESULTS:")
for threshold, accuracy in threshold_results.items():
    marker = "⭐" if threshold == best_threshold else "  "
    print(f"{marker} Threshold {threshold}: {accuracy:.3f} accuracy")

print(f"\n🏆 Best threshold: {best_threshold} with {best_accuracy:.3f} accuracy")

# Visualize the decision boundary
visualize_decision_boundary(model, hours, hunger, best_threshold)

# Calculate detailed metrics
model.eval()
with torch.no_grad():
    predictions = model(hours)
    
pred_appears = predictions > best_threshold
true_appears = hunger > 70  # Suki actually appears when real hunger exceeds 70

true_positives = (pred_appears & true_appears).sum().item()
false_positives = (pred_appears & ~true_appears).sum().item()
true_negatives = (~pred_appears & ~true_appears).sum().item()
false_negatives = (~pred_appears & true_appears).sum().item()

precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0

print(f"\n📊 DETAILED METRICS:")
print(f"   Precision: {precision:.3f} (When we predict Suki will appear, how often are we right?)")
print(f"   Recall: {recall:.3f} (When Suki actually appears, how often do we catch it?)")
print(f"   True Positives: {true_positives} (Correctly predicted appearances)")
print(f"   False Positives: {false_positives} (False alarms)")
print(f"   True Negatives: {true_negatives} (Correctly predicted no appearance)")
print(f"   False Negatives: {false_negatives} (Missed appearances)")

print(f"\n🎯 SUCCESS CRITERIA:")
print(f"   Accuracy > 0.8: {'✅' if best_accuracy > 0.8 else '❌'} ({best_accuracy:.3f})")
print(f"   Precision > 0.7: {'✅' if precision > 0.7 else '❌'} ({precision:.3f})")
print(f"   Recall > 0.7: {'✅' if recall > 0.7 else '❌'} ({recall:.3f})")

if best_accuracy > 0.8 and precision > 0.7 and recall > 0.7:
    print("\n🏆 MASTERY ACHIEVED!")
    print("🐱 You understand that prediction and decision-making are different skills!")
    print("🎯 You've learned to optimize thresholds for practical applications!")
else:
    print("\n🔄 Good progress! Understanding decision boundaries is advanced knowledge.")
    print("   Try adjusting your model training or exploring different thresholds.")

## 🔥 CORRECTING YOUR FORM: A STANCE IMBALANCE

Master Pai-Torch observes your training ritual with a careful eye. *"Your eager mind races ahead of your disciplined form, grasshopper. See how your gradient flow stance wavers?"*

A previous disciple left this flawed training ritual. Your form has become unsteady - can you restore proper technique?
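The Master's hint about "old directions" can be seen on a single number: PyTorch adds each new gradient into `.grad` instead of replacing it. This minimal sketch uses one hypothetical parameter `demo_w` and the loss (2w)², whose gradient 8w equals 24 at w = 3.

```python
import torch

demo_w = torch.tensor(3.0, requires_grad=True)

# loss = (2w)^2 = 4w^2, so d(loss)/dw = 8w = 24 at w = 3
(2 * demo_w).pow(2).backward()
demo_first_grad = demo_w.grad.item()        # 24.0

# Same loss, same w, but the old gradient was never cleared
(2 * demo_w).pow(2).backward()
demo_accumulated_grad = demo_w.grad.item()  # 48.0: the new gradient was ADDED to the old one

# Clear the stored gradient (optimizer.zero_grad() does the equivalent for every
# parameter it manages; by default it sets .grad to None)
demo_w.grad.zero_()
(2 * demo_w).pow(2).backward()
demo_cleared_grad = demo_w.grad.item()      # 24.0 again

print(demo_first_grad, demo_accumulated_grad, demo_cleared_grad)
```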

In [None]:
def unsteady_training(model: nn.Module, features: torch.Tensor, target: torch.Tensor, 
                     epochs: int = 1000) -> list:
    """This training stance has lost its balance - your form needs correction! 🥋"""
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.002)  # small enough to be stable on raw hours once fixed
    
    losses = []
    
    for epoch in range(epochs):
        # Forward pass
        predictions = model(features)
        loss = criterion(predictions, target)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        
        if epoch % 100 == 0:
            print(f'Epoch {epoch}: Loss = {loss.item():.4f}')
    
    return losses

# Test the flawed training
print("🥋 Testing the previous disciple's flawed training technique...")
flawed_model = CatHungerPredictor(input_features=1)
flawed_losses = unsteady_training(flawed_model, hours, hunger, epochs=500)

# Visualize the problem
plt.figure(figsize=(10, 6))
plt.plot(flawed_losses, color='red', linewidth=2, label='Flawed Training')
plt.title('🔥 The Flawed Training Ritual')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print(f"\n🚨 DEBUGGING CHALLENGE:")
print(f"   The loss seems to be {'climbing' if flawed_losses[-1] > flawed_losses[0] else 'unstable'}!")
print(f"   Final loss: {flawed_losses[-1]:.4f}")
print(f"   This is much worse than your proper training: {loss_history[-1]:.4f}")

print(f"\n🧐 MASTER'S WISDOM:")
print(f"   'The undisciplined mind accumulates old thoughts,'")
print(f"   'just as the untrained gradient accumulates old directions.'")
print(f"\n🔍 HINT: Look carefully at the training loop.")
print(f"   What crucial step is missing between epochs?")
print(f"   The Gradient Spirits are not being properly dismissed!")

# The solution (commented out - let students figure it out)
# THE MISSING LINE: optimizer.zero_grad() before loss.backward()
# Without this, gradients accumulate across epochs, causing instability

print(f"\n🎯 YOUR MISSION:")
print(f"   Can you spot the critical error in the 'unsteady_training' function?")
print(f"   Fix it and compare the training curves!")
print(f"   Understanding this error is crucial for all future neural network training.")

## 🏮 THE SACRED COMPLETION

Master Pai-Torch bows deeply as you complete the trials.

*"Well done, grasshopper. You have learned the first sacred art - the Linear Transformation. You understand how gradients flow backward through your network, updating the sacred parameters that encode wisdom. This foundation will serve you well in all future trials."*

**🎓 WHAT YOU HAVE MASTERED:**
- Linear neural networks and forward passes
- Gradient descent and backpropagation
- The sacred ritual of `optimizer.zero_grad()`
- Loss functions and convergence
- Batch processing and data normalization
- Learning rate scheduling and early stopping
- Decision boundaries and threshold optimization
- Debugging gradient accumulation problems

**🔮 LOOKING AHEAD:**
In your next trials, you will learn about multiple layers, activation functions, and more complex architectures. But remember - all neural wisdom builds upon these linear foundations.

**🐱 SUKI'S FINAL WISDOM:**
*Meow.* (Translation: "The journey of a thousand neural networks begins with a single linear layer.")

**Continue your journey with Dan 2 to learn the arts of the Temple Guardian!**