# Zenith Scientific Test 3: The Dynamic Shape Torture Test

**Objective:**
Real-world AI workloads (such as chatbots) have variable input lengths. Static compilers often handle this poorly, triggering a slow recompilation for every new shape.
This test checks whether Zenith handles **Dynamic Shapes** without that penalty.

**Methodology:**
1.  Define a small model: two Linear layers with a GELU activation between them.
2.  Run an inference loop where `batch_size` is fixed (32), but `seq_len` changes randomly (between 128 and 1024) **every single iteration**.
3.  **Metric:** Measure latency of each step. Watch for "Spikes".
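
One way to make "watch for spikes" concrete is to divide each latency by its sequence length and flag steps whose per-token cost sits far above the median. The sketch below is an illustration, not part of Zenith: the helper name `find_spikes`, the `warmup` count, and the `factor` threshold are all assumptions, and the per-token normalization assumes fixed per-step overhead is small compared with the compute time.

```python
import statistics

def find_spikes(shapes, latencies, warmup=3, factor=3.0):
    """Return indices of steps whose per-token latency exceeds factor x median.

    Hypothetical helper: `warmup` and `factor` are illustrative defaults.
    """
    # Per-token cost removes the expected dependence on sequence length
    per_token = [lat / s for s, lat in zip(shapes, latencies)]
    # Skip the first few steps, which may include one-time compilation
    steady = per_token[warmup:]
    baseline = statistics.median(steady)
    return [i + warmup for i, v in enumerate(steady) if v > factor * baseline]

# Synthetic data: 0.01 ms per token, with one simulated recompilation at step 5
shapes = [128, 256, 512, 1024, 300, 700, 900, 400]
latencies = [s * 0.01 for s in shapes]
latencies[5] += 500.0

spikes = find_spikes(shapes, latencies)
print("Spike steps:", spikes)
```

The same call can be applied to the `shapes` and latency lists returned by the test loop. An empty result after warmup is consistent with no recompilation pauses.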

**Success Criteria:**
*   **No compiling pauses:** The first few steps might be slow (warmup), but subsequent steps must be fast regardless of shape changes.
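*   **Linear Scaling:** Latency should grow roughly linearly with sequence length, since each Linear layer is applied to every token independently. Superlinear growth or scattered outliers indicate a problem.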
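
The linear-scaling criterion can be checked numerically with a least-squares line fit of latency against sequence length. This is a minimal sketch under stated assumptions: `linear_fit_r2` is a hypothetical helper name, the data here is synthetic, and what counts as a "good" R² is a judgment call, not something this test defines.

```python
import numpy as np

def linear_fit_r2(shapes, latencies):
    """Fit latency = slope * seq_len + intercept and return (slope, intercept, R^2)."""
    x = np.asarray(shapes, dtype=float)
    y = np.asarray(latencies, dtype=float)
    # Degree-1 polyfit returns coefficients highest power first
    slope, intercept = np.polyfit(x, y, 1)
    residuals = y - (slope * x + intercept)
    ss_res = float(np.sum(residuals ** 2))
    ss_tot = float(np.sum((y - y.mean()) ** 2))
    r2 = 1.0 - ss_res / ss_tot
    return float(slope), float(intercept), r2

# Synthetic, perfectly linear data: 0.5 ms fixed cost plus 0.01 ms per token
shapes = [128, 256, 384, 512, 640, 768, 896, 1024]
latencies = [0.5 + 0.01 * s for s in shapes]

slope, intercept, r2 = linear_fit_r2(shapes, latencies)
print(f"slope={slope:.4f} ms/token, intercept={intercept:.2f} ms, R^2={r2:.4f}")
```

An R² close to 1 on the measured latencies suggests linear scaling. A low value points to outliers (possible recompilation) or to nonlinear growth, and the scatter plot helps tell the two apart.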

In [None]:
!pip install -q -U pyzenith torch numpy matplotlib

import torch
import torch.nn as nn
import time
import random
import numpy as np
import matplotlib.pyplot as plt
import zenith

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

In [None]:
# A simple scalable model
class DynamicNet(nn.Module):
    def __init__(self, hidden_dim=4096):
        super().__init__()
        self.layer1 = nn.Linear(hidden_dim, hidden_dim)
        self.act = nn.GELU()
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.layer2(self.act(self.layer1(x)))

model = DynamicNet().to(device)
# Some compilers need explicit hints to treat a dimension as dynamic.
# Zenith aims to be automatic, so we provide no hints.

In [None]:
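def run_dynamic_test(model_fn, iterations=50, seed=0):
    # A fixed seed gives every call the same shape sequence, so the shapes
    # returned by one run line up with the latencies of another run.
    rng = random.Random(seed)
    use_cuda = device == "cuda"
    latencies = []
    shapes = []

    print("Starting Dynamic Loop...")
    for i in range(iterations):
        # Random sequence length between 128 and 1024 (inclusive)
        seq_len = rng.randint(128, 1024)
        batch_size = 32

        x = torch.randn(batch_size, seq_len, 4096, device=device)

        # Synchronize only on GPU; torch.cuda.synchronize() raises on CPU-only setups
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()

        with torch.no_grad():
            out = model_fn(x)

        if use_cuda:
            torch.cuda.synchronize()
        # Record time in ms
        latencies.append((time.perf_counter() - start) * 1000)
        shapes.append(seq_len)

        if i % 10 == 0:
            print(f"Step {i}: SeqLen={seq_len} -> {latencies[-1]:.2f}ms")

    return shapes, latencies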

In [None]:
# 1. PyTorch Eager (Baseline)
# Eager mode has no compilation step, so changing shapes never triggers recompilation.
print("--- Baseline: PyTorch Eager ---")
shapes, lat_py = run_dynamic_test(model)

In [None]:
# 2. Zenith Optimized
print("\n--- Zenith Optimized ---")
# We compile once. 
# If Zenith supports dynamic shapes, it won't recompile on every new shape.
opt_model = torch.compile(model, backend="zenith")

# Warmup with a fixed shape, under no_grad so it runs in the same mode as the timed loop
dummy = torch.randn(32, 512, 4096, device=device)
with torch.no_grad():
    opt_model(dummy)
print("Warmup Done.")

shapes_zen, lat_zen = run_dynamic_test(opt_model)

In [None]:
# Visualization
plt.figure(figsize=(10, 6))
plt.scatter(shapes, lat_py, alpha=0.6, label="PyTorch Eager", color="gray")
# Plot Zenith latencies against the shapes recorded during the Zenith run
plt.scatter(shapes_zen, lat_zen, alpha=0.8, label="Zenith", color="blue")

plt.title("Dynamic Shape Performance: SeqLen vs Latency")
plt.xlabel("Sequence Length (Input Size)")
plt.ylabel("Latency (ms)")
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig("zenith_dynamic_shape.png")
plt.show()

print("Interpretation: Points should form a clean line. Outliers high above the line indicate re-compilation events.")