# Zenith Scientific Test 1: Numerical Stability (Inference)

**Objective:**
Validate that Zenith's optimizations do not cause numerical drift or instability during extended inference runs.

**Note:** Zenith is primarily an INFERENCE optimizer. Training (backprop) is handled by native PyTorch.

**Methodology:**
1.  Run one forward pass and record its output as the baseline.
2.  Run the same input through the model 500 more times, comparing each output to the baseline.
3.  **Monitor:** Max Absolute Difference from baseline at every step.
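4.  **Comparison:** Run the identical procedure for Zenith and for native PyTorch, then compare their drift curves (each backend is measured against its own baseline).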

**Success Criteria:**
*   No `NaN` or `Inf` values.
*   Output should be deterministic (difference from baseline should be 0 or negligible).

## 1. Environment & Setup

In [None]:
!pip install -U pyzenith torch torchvision matplotlib numpy

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import time
import zenith

# Prefer the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

# Reproducibility
def set_seed(seed: int = 42) -> None:
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to ``np.random.seed``, ``torch.manual_seed`` and
            ``torch.cuda.manual_seed``.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


# Seed once up front so everything that follows starts from a known state.
set_seed(42)

## 2. Model & Synthetic Data

In [None]:
# Simple ConvNet for stability testing
class StabilityNet(nn.Module):
    """Two Conv-BatchNorm-ReLU stages, global average pooling, and a linear head.

    Maps a ``(N, 3, H, W)`` batch to ``(N, 10)`` outputs.
    """

    def __init__(self):
        super().__init__()
        # Build the stages in order so the Sequential indices (0..6) and the
        # order in which parameters are initialised stay fixed.
        stages = []
        in_channels = 3
        for out_channels in (64, 128):
            stages.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU())
            in_channels = out_channels
        stages.append(nn.AdaptiveAvgPool2d((1, 1)))

        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(128, 10)

    def forward(self, x):
        """Return the classifier outputs for the input batch ``x``."""
        pooled = self.features(x)
        # Collapse the pooled (N, 128, 1, 1) maps to (N, 128) for the linear head.
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)

# Generate Dummy Data (Fixed input for consistency check)
BATCH_SIZE = 128
INPUT_SHAPE = (BATCH_SIZE, 3, 32, 32)

# Re-seed immediately before sampling so the tensor does not depend on
# whatever random draws happened earlier in the session.
set_seed(42)
dummy_input = torch.randn(*INPUT_SHAPE, device=device)

## 3. The Stability Test Function (Inference Only)

In [None]:
def run_stability_test(use_zenith=False, steps=500):
    """Run repeated inference on ``dummy_input`` and track drift from a baseline.

    A freshly seeded ``StabilityNet`` is put in eval mode (optionally wrapped
    with ``torch.compile(backend="zenith")``), one baseline forward pass is
    recorded, and the same input is then evaluated ``steps`` more times.

    The baseline pass doubles as a warm-up and runs *before* the timer starts,
    so one-time costs (``torch.compile`` compiles on the first call) are not
    counted in the reported duration.

    Args:
        use_zenith: If true, compile the model with the ``zenith`` backend;
            otherwise run the eager PyTorch model.
        steps: Number of timed forward passes compared against the baseline.

    Returns:
        A ``(diff_history, output_means, duration)`` tuple: the per-step max
        absolute difference from the baseline, the per-step output mean, and
        the seconds spent in the timed loop.

    Raises:
        ValueError: If an output contains NaN or Inf.
    """
    set_seed(42)
    model = StabilityNet().to(device)
    model.eval()  # Inference mode

    if use_zenith:
        print("Compiling with Zenith...")
        model = torch.compile(model, backend="zenith")
    else:
        print("Running Native PyTorch...")

    diff_history = []
    output_means = []
    on_cuda = dummy_input.is_cuda

    with torch.no_grad():
        # Baseline output; kept outside the timed region because the first
        # call also pays for compilation / warm-up.
        baseline_output = model(dummy_input).clone()

        # CUDA kernels are launched asynchronously: drain pending work so the
        # clock starts from an idle device.
        if on_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()

        for step in range(steps):
            output = model(dummy_input)

            # Check for NaN/Inf
            if torch.isnan(output).any():
                raise ValueError(f"Output became NaN at step {step}!")
            if torch.isinf(output).any():
                raise ValueError(f"Output became Inf at step {step}!")

            # Calculate difference from baseline
            diff = torch.abs(output - baseline_output).max().item()
            diff_history.append(diff)
            output_means.append(output.mean().item())

        # Make sure all queued GPU work is finished before stopping the clock.
        if on_cuda:
            torch.cuda.synchronize()
        duration = time.perf_counter() - start_time

    print(f"Done. Duration: {duration:.4f}s")
    # default=0.0 keeps the report working when steps == 0 (no comparisons made).
    print(f"Max Drift from Baseline: {max(diff_history, default=0.0):.9f}")

    return diff_history, output_means, duration

In [None]:
# Execute Comparison
STEPS = 500

# Reference run: the model is executed without the Zenith backend.
print("--- Run 1: PyTorch Native ---")
diff_py, means_py, time_py = run_stability_test(steps=STEPS, use_zenith=False)

# Same test function and step count, this time with use_zenith=True.
print("\n--- Run 2: Zenith Optimized ---")
diff_zen, means_zen, time_zen = run_stability_test(steps=STEPS, use_zenith=True)

## 4. Visual Analysis (The Evidence)

In [None]:
# Two side-by-side panels: output mean per step (left) and drift from the
# baseline per step (right), each overlaying the PyTorch and Zenith runs.
fig, ax = plt.subplots(1, 2, figsize=(15, 6))

# Take the iteration count from the recorded data instead of hard-coding it,
# so the title stays correct when the run length changes.
n_steps = len(means_zen)

# Plot 1: Output Mean Stability
ax[0].plot(means_py, label="PyTorch", color="gray", alpha=0.7, linestyle="--")
ax[0].plot(means_zen, label="Zenith", color="blue", linewidth=1.5)
ax[0].set_title(f"Output Mean Stability Over {n_steps} Iterations")
ax[0].set_xlabel("Step")
ax[0].set_ylabel("Output Mean")
ax[0].legend()
ax[0].grid(True, alpha=0.3)

# Plot 2: Drift from Baseline (Should be 0 for deterministic execution)
ax[1].plot(diff_py, label="PyTorch", color="gray", alpha=0.7, linestyle="--")
ax[1].plot(diff_zen, label="Zenith", color="red", linewidth=1.5)
ax[1].set_title("Numerical Drift from Baseline")
ax[1].set_xlabel("Step")
ax[1].set_ylabel("Max Absolute Difference")
ax[1].legend()
ax[1].grid(True, alpha=0.3)

plt.tight_layout()
# Save before show(): some backends clear the figure once it has been shown.
plt.savefig("zenith_stability_chart.png")
plt.show()

# Timing summary for the two runs.
print(f"\nTotal Time Comparison: PyTorch ({time_py:.2f}s) vs Zenith ({time_zen:.2f}s)")
# Time saved by Zenith, expressed as a percentage of the PyTorch duration
# (positive means Zenith finished sooner).
time_saved = time_py - time_zen
speedup = (time_saved / time_py) * 100
print(f"Speedup: {speedup:+.2f}%")

# Verdict: only an exactly-zero worst-case drift is reported as deterministic.
max_drift_zen = max(diff_zen)
if max_drift_zen != 0.0:
    print(f"\nRESULT: Max drift detected: {max_drift_zen:.9f}")
else:
    print("\nRESULT: Zenith is DETERMINISTIC. No numerical drift detected.")