# Zenith Comprehensive Verification

**Purpose:** Verify Zenith on a GPU: the PyTorch `torch.compile` backend, native CUDA kernels, the TensorFlow and JAX adapters, and the MLPerf-style benchmark suite.

**Requirements:**
- Google Colab with GPU runtime
- Select **Runtime > Change runtime type > GPU** (T4 or better)

## 1. Environment Setup

In [None]:
# Check GPU availability
!nvidia-smi
print("\n" + "="*50)
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Install Zenith
!pip install pyzenith -q
!pip install tensorflow jax jaxlib -q

import zenith
print(f"Zenith Version: {zenith.__version__}")

## 2. PyTorch + CUDA Verification

In [None]:
import torch
import zenith
from zenith.torch import compile as ztorch_compile
import time

# Simple model
class SimpleModel(torch.nn.Module):
    """Two-layer MLP: Linear(512 -> 1024) -> ReLU -> Linear(1024 -> 512).

    Used as a small workload to compare eager PyTorch against the
    ``torch.compile`` Zenith backend.
    """

    def __init__(self):
        super(SimpleModel, self).__init__()
        # Submodules are registered in this order (fc1, fc2, relu); keep it,
        # since it fixes both parameter init order and state_dict layout.
        self.fc1 = torch.nn.Linear(512, 1024)
        self.fc2 = torch.nn.Linear(1024, 512)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Map a (..., 512) input to a (..., 512) output."""
        return self.fc2(self.relu(self.fc1(x)))

model = SimpleModel().cuda()
x = torch.randn(32, 512).cuda()

NUM_ITERS = 100   # timed forward passes per measurement
NUM_WARMUP = 10   # untimed forward passes run before timing starts


def _time_cuda(fn, inp, iters=NUM_ITERS, warmup=NUM_WARMUP):
    """Return total wall-clock seconds for ``iters`` calls of ``fn(inp)``.

    ``warmup`` untimed calls run first so one-time costs are excluded from
    the measurement. Those costs include CUDA/cuBLAS initialisation for
    eager mode, and the lazy graph capture plus compilation that
    ``torch.compile`` performs on its first call. ``torch.cuda.synchronize``
    brackets the timed loop because CUDA kernel launches are asynchronous.
    """
    with torch.no_grad():
        for _ in range(warmup):
            fn(inp)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            fn(inp)
        torch.cuda.synchronize()
        return time.perf_counter() - start


# Baseline (eager PyTorch)
baseline_time = _time_cuda(model, x)
print(f"Baseline: {baseline_time*1000:.2f} ms")

# With torch.compile + Zenith backend
try:
    compiled = torch.compile(model, backend='zenith')
    zenith_time = _time_cuda(compiled, x)

    print(f"Zenith: {zenith_time*1000:.2f} ms")
    print(f"Speedup: {baseline_time/zenith_time:.2f}x")

    # A speedup only counts if the compiled model computes the same result.
    # The tolerance is loose enough to admit reduced-precision kernels;
    # outputs are cast to float32 so allclose never sees mixed dtypes.
    with torch.no_grad():
        expected = model(x).float()
        actual = compiled(x).float()
    max_diff = (actual - expected).abs().max().item()
    print(f"Max abs diff vs eager: {max_diff:.2e}")
    if not torch.allclose(actual, expected, rtol=1e-2, atol=1e-2):
        raise AssertionError(
            f"compiled output differs from eager (max abs diff {max_diff:.2e})"
        )
    print("\n✓ PyTorch + CUDA: VERIFIED")
except Exception as e:
    print(f"✗ PyTorch + CUDA: {e}")

## 3. CUDA Kernels Verification

In [None]:
# Check Zenith native CUDA kernels. The section only reports VERIFIED when
# every expected op resolves to a kernel.
print("=== CUDA Kernel Registry ===")

try:
    from zenith.runtime import KernelRegistry

    registry = KernelRegistry()
    registry.initialize()

    print(f"Initialized: {registry.is_initialized}")
    print("Available kernels:")

    missing = []
    for op in ['MatMul', 'ReLU', 'GELU', 'Softmax', 'LayerNorm']:
        # If the lookup for one op raises, record that op as missing and
        # keep checking the rest instead of aborting the whole section.
        try:
            kernel = registry.get_kernel(op)
            reason = ''
        except Exception as lookup_err:
            kernel = None
            reason = f" ({lookup_err})"
        status = '✓' if kernel else '✗'
        print(f"  {status} {op}{reason}")
        if not kernel:
            missing.append(op)

    if missing:
        print(f"\n✗ CUDA Kernels: missing {', '.join(missing)}")
    else:
        print("\n✓ CUDA Kernels: VERIFIED")
except Exception as e:
    print(f"✗ CUDA Kernels: {e}")

## 4. TensorFlow Adapter Verification

In [None]:
print("=== TensorFlow Adapter ===")

try:
    import tensorflow as tf
    from zenith.adapters import TensorFlowAdapter

    print(f"TensorFlow: {tf.__version__}")
    gpus = tf.config.list_physical_devices('GPU')
    print(f"GPU Devices: {gpus}")

    # By default TensorFlow maps nearly all GPU memory on first use. That
    # starves PyTorch and JAX, which share this GPU in the same notebook.
    # Memory growth has to be configured before TF initialises the GPUs.
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as growth_err:
            # Raised when the GPU was already initialised (e.g. cell re-run).
            print(f"  (memory growth not enabled: {growth_err})")

    # Create simple Keras model. A distinct name avoids clobbering the
    # notebook-global PyTorch `model`.
    tf_model = tf.keras.Sequential([
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    tf_model.build((None, 64))

    # Test adapter
    adapter = TensorFlowAdapter()
    print(f"Adapter available: {adapter.is_available}")

    # Convert to GraphIR
    sample_input = tf.random.normal((1, 64))
    graph_ir = adapter.from_model(tf_model, sample_input)

    print(f"GraphIR nodes: {graph_ir.num_nodes()}")
    print("\n✓ TensorFlow Adapter: VERIFIED")
except Exception as e:
    print(f"✗ TensorFlow Adapter: {e}")

## 5. JAX Adapter Verification

In [None]:
print("=== JAX Adapter ===")

import os

# JAX preallocates 75% of total GPU memory when its GPU backend starts. That
# fails or starves PyTorch/TensorFlow sharing this GPU in the same notebook.
# The flag is read when the backend initialises, so it must be set before
# the first JAX device query or operation. setdefault keeps any value the
# user exported themselves.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

try:
    import jax
    import jax.numpy as jnp
    from zenith.adapters import JAXAdapter

    print(f"JAX: {jax.__version__}")
    print(f"JAX Devices: {jax.devices()}")

    # Simple JAX function: two tanh layers with all-ones weights,
    # mapping (batch, 64) -> (batch, 10).
    def simple_mlp(x):
        x = jnp.tanh(x @ jnp.ones((64, 128)))
        x = jnp.tanh(x @ jnp.ones((128, 10)))
        return x

    # Test adapter
    adapter = JAXAdapter()
    print(f"Adapter available: {adapter.is_available}")

    # Convert to GraphIR
    sample_input = jnp.ones((1, 64))
    graph_ir = adapter.from_model(simple_mlp, sample_input)

    print(f"GraphIR nodes: {graph_ir.num_nodes()}")
    print("\n✓ JAX Adapter: VERIFIED")
except Exception as e:
    print(f"✗ JAX Adapter: {e}")

## 6. Benchmark on GPU

In [None]:
print("=== GPU Benchmark ===")

try:
    # NOTE(review): `benchmarks.mlperf_suite` looks like a repository
    # directory rather than part of the `pyzenith` wheel. It may need the
    # repo cloned onto sys.path; confirm.
    from benchmarks.mlperf_suite import (
        ZenithBenchmark,
        BenchmarkConfig,
        generate_results_table,
    )
    import torch

    # Workload: 256 -> 512 -> 256 MLP resident on the GPU.
    model = torch.nn.Sequential(
        torch.nn.Linear(256, 512),
        torch.nn.ReLU(),
        torch.nn.Linear(512, 256),
    ).cuda()

    @torch.no_grad()
    def model_fn(x):
        """Inference-only forward pass (no autograd graph is recorded)."""
        return model(x)

    def input_gen(batch_size, seq_len):
        """Return a random (batch_size, 256) GPU tensor.

        ``seq_len`` is accepted (presumably the harness passes it) but is
        unused by this sequence-free MLP.
        """
        return torch.randn(batch_size, 256).cuda()

    config = BenchmarkConfig(
        model_name='MLP-256',
        scenario='single-stream',
        batch_sizes=[1, 8, 32],
        num_warmup=10,
        num_runs=100,
    )

    benchmark = ZenithBenchmark(device='cuda')
    results = benchmark.run(config, model_fn, input_gen)

    print(generate_results_table(results))
    print("\n✓ GPU Benchmark: VERIFIED")
except Exception as e:
    print(f"✗ GPU Benchmark: {e}")

## 7. Summary

In [None]:
# Re-import so this cell also works when run on its own (both modules are
# already loaded if the earlier cells ran; re-importing is a no-op).
import torch
import zenith

# Sections exercised above. Whether each one passed is the ✓/✗ line that
# section printed; this cell does not re-check them.
COMPONENTS = (
    "PyTorch Integration",
    "CUDA Kernels",
    "TensorFlow Adapter",
    "JAX Adapter",
    "MLPerf Benchmark Suite",
)

print("="*50)
print("ZENITH VERIFICATION SUMMARY")
print("="*50)
print()
print(f"Zenith Version: {zenith.__version__}")
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print()
print("Components checked (pass/fail is the ✓/✗ line in each section above):")
for component in COMPONENTS:
    print(f"  - {component}")
print()
print("="*50)