# Zenith Scientific Test 1: Gradient Stability

**Objective:**
Validate that Zenith's optimizations (Graph Capture, Fusion) do not cause numerical instability (Gradient Explosion/Vanishing) during extended training runs.

**Methodology:**
1.  Train a simple ConvNet (two Conv-BatchNorm-ReLU stages followed by a linear classifier) on dummy data for 500 steps.
2.  Use a relatively high Learning Rate (`1e-2`, SGD with momentum `0.9`) to induce stress.
3.  **Monitor:** Gradient Norm (`L2 Norm`) of the model parameters at every step.
4.  **Comparison:** Zenith vs PyTorch Native.

**Success Criteria:**
*   No `NaN` or `Inf` values.
*   Gradient Norm curve of Zenith should be comparable (smoothness/magnitude) to PyTorch.


## 1. Environment & Setup

In [None]:
!pip install -U pyzenith torch torchvision matplotlib numpy

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import time
import zenith

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

# Reproducibility
def set_seed(seed=42):
    """Seed the PyTorch, PyTorch-CUDA and NumPy random number generators.

    Args:
        seed: Integer seed handed to every generator. Defaults to 42.
    """
    # Same seeding calls, in the same order, expressed as a single loop.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
        seed_fn(seed)

# Seed every RNG once at notebook start so the random tensors created in
# later cells are reproducible from a fresh kernel.
set_seed(42)

## 2. Model & Synthetic Data
We use a synthetic workload to isolate compiler behavior from data noise.

In [None]:
# Simple ConvNet for stability testing
class StabilityNet(nn.Module):
    """Small two-stage convolutional classifier used as the stability workload.

    Each stage is ``Conv2d`` (3x3 kernel, padding 1) -> ``BatchNorm2d`` ->
    ``ReLU``, widening the channels 3 -> 64 -> 128. Adaptive average pooling
    then collapses every feature map to 1x1, and a linear layer maps the 128
    pooled features to 10 logits.
    """

    def __init__(self):
        super().__init__()
        # Modules are instantiated in execution order; keeping that order
        # fixed keeps seeded weight initialisation reproducible.
        stage_one = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        ]
        stage_two = [
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        ]
        global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.features = nn.Sequential(*stage_one, *stage_two, global_pool)
        self.classifier = nn.Linear(128, 10)

    def forward(self, x):
        """Return class logits of shape (N, 10) for a 3-channel image batch."""
        pooled = self.features(x)
        # Drop the trailing 1x1 spatial dims: (N, 128, 1, 1) -> (N, 128).
        return self.classifier(pooled.flatten(start_dim=1))

# Fixed synthetic batch, created once and reused for every training step:
# random 3x32x32 inputs plus random integer labels drawn from [0, 10).
BATCH_SIZE = 128
INPUT_SHAPE = (BATCH_SIZE, 3, 32, 32)
dummy_input = torch.randn(*INPUT_SHAPE, device=device)
dummy_target = torch.randint(low=0, high=10, size=(BATCH_SIZE,), device=device)

## 3. The Stress Test Function

In [None]:
def run_stability_test(use_zenith=False, steps=500, lr=1e-2):
    """Train a fresh ``StabilityNet`` on the fixed dummy batch and log stability metrics.

    The RNGs are re-seeded with 42 before the model is built, so every call
    starts from the same initial weights regardless of ``use_zenith``.

    Args:
        use_zenith: When True, wrap the model with
            ``torch.compile(model, backend="zenith")``; otherwise run the
            model eagerly.
        steps: Number of optimizer steps to run.
        lr: Learning rate for SGD (momentum is fixed at 0.9).

    Returns:
        A tuple ``(loss_history, grad_norm_history, duration)``: two lists of
        Python floats with one entry per step, and the wall-clock seconds
        spent in the training loop. For the compiled run the duration also
        covers whatever compilation happens on the first forward/backward
        calls, because those calls are inside the timed loop.

    Raises:
        ValueError: If the loss or the global gradient norm becomes NaN or
            Inf at any step.
    """
    import math

    set_seed(42)
    model = StabilityNet().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    if use_zenith:
        print("Compiling with Zenith...")
        model = torch.compile(model, backend="zenith")
    else:
        print("Running Native PyTorch...")

    loss_history = []
    grad_norm_history = []

    # perf_counter is monotonic and higher resolution than time.time().
    start_time = time.perf_counter()

    for step in range(steps):
        optimizer.zero_grad()
        output = model(dummy_input)
        loss = criterion(output, dummy_target)

        # Catch divergence early: the success criteria forbid both NaN and
        # Inf, so test for any non-finite loss rather than NaN alone.
        if not torch.isfinite(loss):
            raise ValueError(
                f"Loss became non-finite ({loss.item()}) at step {step}!"
            )

        loss.backward()

        # Global L2 norm over all parameter gradients. Per-parameter norms
        # stay on-device and are reduced in one call, so there is a single
        # host sync per step; float64 accumulation matches the precision of
        # summing squared Python floats.
        per_param_norms = [
            p.grad.detach().norm(2)
            for p in model.parameters()
            if p.grad is not None
        ]
        if per_param_norms:
            total_norm = torch.linalg.vector_norm(
                torch.stack(per_param_norms), ord=2, dtype=torch.float64
            ).item()
        else:
            total_norm = 0.0

        if not math.isfinite(total_norm):
            raise ValueError(
                f"Gradient norm became non-finite ({total_norm}) at step {step}!"
            )

        grad_norm_history.append(total_norm)
        loss_history.append(loss.item())

        optimizer.step()

    # The final optimizer.step() may still be running asynchronously on the
    # GPU; wait for it so the measured duration covers all queued work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    duration = time.perf_counter() - start_time
    print(f"Done. Duration: {duration:.4f}s")

    return loss_history, grad_norm_history, duration

In [None]:
# Run both configurations with the same step budget: eager PyTorch first,
# then the Zenith-compiled model.
STEPS = 500

print("--- Run 1: PyTorch Native ---")
native_run = run_stability_test(steps=STEPS, use_zenith=False)
loss_py, grad_py, time_py = native_run

print("\n--- Run 2: Zenith Optimized ---")
zenith_run = run_stability_test(steps=STEPS, use_zenith=True)
loss_zen, grad_zen, time_zen = zenith_run

## 4. Visual Analysis (The Evidence)

In [None]:
# Two side-by-side panels: training loss on the left, gradient norm on the right.
fig, ax = plt.subplots(1, 2, figsize=(15, 6))
loss_ax, grad_ax = ax

# Left panel: loss per step, PyTorch as a dashed baseline under the Zenith curve.
loss_ax.plot(loss_py, linestyle="--", alpha=0.7, color="gray", label="PyTorch")
loss_ax.plot(loss_zen, linewidth=1.5, color="blue", label="Zenith")
loss_ax.set_title("Training Loss Stability")
loss_ax.set_xlabel("Step")
loss_ax.set_ylabel("Loss")

# Right panel: global gradient L2 norm per step (the exploding-gradient signal).
grad_ax.plot(grad_py, linestyle="--", alpha=0.7, color="gray", label="PyTorch")
grad_ax.plot(grad_zen, linewidth=1.5, color="red", label="Zenith")
grad_ax.set_title("Gradient Norm Monitoring")
grad_ax.set_xlabel("Step")
grad_ax.set_ylabel("L2 Norm of Gradients")

# Shared cosmetics for both panels.
for panel in (loss_ax, grad_ax):
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
# Save before show(): some backends clear the figure once it is displayed.
plt.savefig("zenith_stability_chart.png")
plt.show()

# Headline numbers: wall-clock time and the largest gradient norm seen in each run.
print(f"\nTotal Time Comparison: PyTorch ({time_py:.2f}s) vs Zenith ({time_zen:.2f}s)")
print(f"Max Gradient Norm: PyTorch ({max(grad_py):.4f}) vs Zenith ({max(grad_zen):.4f})")