# Zenith 0.3.3 Verification

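Run the cells in order. Cell 2 installs the package and then kills the runtime so that it restarts with a clean import state; once the runtime has reconnected, continue from cell 3.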

In [None]:
# 1. GPU Check
!nvidia-smi | head -20
import torch
print(f"\nPyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0)}" if torch.cuda.is_available() else "No GPU")

In [None]:
# 2. Install Zenith 0.3.3 (will restart runtime)
!pip uninstall pyzenith -y 2>/dev/null
!pip cache purge 2>/dev/null
!pip install pyzenith==0.3.3 --no-cache-dir -q

# Auto restart to clear cached imports
import os
os.kill(os.getpid(), 9)

In [None]:
# 3. Verify Version
# Confirm the restarted runtime imports the release pinned in the install cell.
import zenith

installed_version = zenith.__version__
print(f"Zenith: {installed_version}")
assert installed_version == '0.3.3', f"Wrong version: {installed_version}"
print("✓ Version OK")

In [None]:
# 4. Triton Check
# Report the Triton version, whether Zenith's Triton kernels are usable, and
# (if so) which kernels are registered.
from zenith.runtime.triton_kernels import is_available, get_version, get_triton_kernel_map

print(f"Triton: {get_version()}")
triton_ok = is_available()  # query once and reuse below
print(f"Available: {triton_ok}")
if triton_ok:
    print(f"Kernels: {list(get_triton_kernel_map().keys())}")
    print("✓ Triton OK")
else:
    # Give an explicit verdict instead of silently ending the cell, using the
    # same ⚠ marker as the rest of the notebook.
    print("⚠ Triton not available")

In [None]:
# 5. Triton Fused Kernel Benchmark
# Compare Zenith's fused Linear+GELU kernel against the separate ops.
# `is_available` is imported here too so this cell does not depend on the
# Triton check cell having been run.
from zenith.runtime.triton_kernels import benchmark_fused_linear_gelu, is_available

if not is_available():
    # The fused kernel lives in zenith.runtime.triton_kernels; skip the
    # benchmark rather than erroring out when Triton is unusable.
    print("⚠ Triton not available - skipping fused kernel benchmark")
else:
    result = benchmark_fused_linear_gelu(M=1024, N=4096, K=1024, runs=50)
    print(f"Fused: {result['fused_ms']:.2f} ms")
    print(f"Separate: {result['separate_ms']:.2f} ms")
    print(f"Speedup: {result['speedup']:.2f}x")

    if result['speedup'] > 1.0:
        print("\n✓ FUSED KERNEL FASTER!")
    else:
        print("\n⚠ No speedup")

In [None]:
# 6. torch.compile with Zenith
import torch
import time

class MLP(torch.nn.Module):
    """Small two-layer perceptron used as the torch.compile benchmark model.

    Maps 512 input features to 1024 hidden units (ReLU) and back to 512.
    """

    def __init__(self):
        super().__init__()
        # Creation order (fc1, then fc2) determines parameter registration and
        # random-init consumption; keep it fixed.
        self.fc1 = torch.nn.Linear(512, 1024)
        self.fc2 = torch.nn.Linear(1024, 512)

    def forward(self, x):
        """Apply fc1 -> ReLU -> fc2 to ``x`` (last dimension must be 512)."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

def _bench_ms(fn, inp, warmup=10, runs=100):
    """Return mean wall-clock milliseconds per call of ``fn(inp)``.

    Runs ``warmup`` untimed calls first (for a torch.compile'd module, the first
    call triggers compilation), then times ``runs`` calls. CUDA work is
    asynchronous, so the device is synchronized before starting and before
    stopping the clock.
    """
    with torch.no_grad():
        for _ in range(warmup):
            fn(inp)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(runs):
            fn(inp)
        torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - t0) * 1000
    return elapsed_ms / runs


model = MLP().cuda()
x = torch.randn(32, 512).cuda()

# Baseline (eager PyTorch)
baseline = _bench_ms(model, x)

# Zenith backend
compiled = torch.compile(model, backend='zenith')
zenith_time = _bench_ms(compiled, x)

# Values are per-iteration latencies; the ratio is unaffected by the averaging.
print(f"Baseline: {baseline:.3f} ms/iter")
print(f"Zenith: {zenith_time:.3f} ms/iter")
print(f"Speedup: {baseline/zenith_time:.2f}x")

In [None]:
# 7. Summary
# Imported here so the summary works no matter which earlier cells ran after
# the runtime restart.
import torch
import zenith

# get_device_name errors when CUDA is unavailable; fall back to the same label
# the GPU check cell uses.
device = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU"
print("=" * 40)
print(f"Zenith {zenith.__version__} on {device}")
print("=" * 40)