# Zenith x PyTorch: Machine Learning Demo

## Zenith as a Complementary Framework, Not a Replacement

This notebook demonstrates how **Zenith works alongside PyTorch** to enhance your ML workflow:

1. **Train with PyTorch** - Use familiar PyTorch APIs for model definition and training
2. **Optimize with Zenith** - Apply Zenith's torch.compile backend for faster inference
3. **Export to ONNX** - Convert the trained model to ONNX (using PyTorch's exporter) and verify it with ONNX Runtime for production deployment
4. **Benchmark** - Compare performance before and after Zenith optimization

---

## 1. Setup

In [None]:
# Install dependencies.
# NOTE: lines starting with `!` are IPython/Jupyter shell escapes, not plain
# Python, so this cell only runs inside a notebook kernel (e.g. Colab).
# Zenith is installed from its GitHub repository; onnx and onnxruntime are
# imported by the ONNX export/verification cells later on (onnxscript is
# presumably needed by the torch.onnx exporter -- confirm).
!pip install git+https://github.com/vibeswithkk/ZENITH.git -q
!pip install onnx onnxruntime onnxscript -q

# The messages below ask the user to restart the runtime before continuing,
# presumably so the freshly installed packages become importable.
print("Setup complete!")
print("IMPORTANT: Please restart runtime now (Runtime > Restart runtime)")
print("Then run cells starting from the next cell (cell 2)")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time
import numpy as np

# Zenith imports
import zenith
import zenith.torch as ztorch
from zenith.backends import is_cuda_available, get_device

# Report library versions and the accelerator this session will use.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"Zenith version: {zenith.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

---
## 2. Define Model (Pure PyTorch)

We use standard PyTorch to define our model - **no Zenith code here**.

In [None]:
class MNISTClassifier(nn.Module):
    """A simple CNN for MNIST classification - Pure PyTorch.

    Two conv(3x3, padding=1) -> ReLU -> 2x2 max-pool stages, followed by a
    128-unit hidden linear layer with dropout and a 10-way output layer.

    Expects input of shape (batch, 1, 28, 28): the two pooling stages reduce
    the spatial size to 7x7, which is what ``fc1`` is sized for. Returns raw
    logits of shape (batch, 10) (no softmax; pair with CrossEntropyLoss).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) image tensor to (batch, 10) logits."""
        x = self.pool(self.relu(self.conv1(x)))  # 28x28 -> 14x14
        x = self.pool(self.relu(self.conv2(x)))  # 14x14 -> 7x7
        # Flatten every dimension except the batch axis. Unlike
        # ``view(-1, 64 * 7 * 7)`` this never folds extra spatial elements
        # into the batch axis (a wrong-sized input now fails loudly in fc1
        # instead of silently producing extra rows), and it also works on
        # non-contiguous activations.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Create the model on the GPU when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = MNISTClassifier().to(device)

n_params = sum(param.numel() for param in model.parameters())
print(f"Model created on: {device}")
print(f"Total parameters: {n_params:,}")

---
## 3. Load Data (Pure PyTorch)

Standard PyTorch data loading.

In [None]:
# Data transforms: convert images to tensors, then normalize the single
# grayscale channel with the usual MNIST mean/std values.
mnist_mean, mnist_std = (0.1307,), (0.3081,)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mnist_mean, mnist_std),
])

# Load MNIST (download=True fetches it into the data root if missing).
data_root = './data'
train_dataset = datasets.MNIST(data_root, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(data_root, train=False, download=True, transform=transform)

# Shuffle only the training split; evaluate in large, fixed-order batches.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

print(f"Training samples: {len(train_dataset):,}")
print(f"Test samples: {len(test_dataset):,}")

---
## 4. Train Model (Pure PyTorch)

Training loop using standard PyTorch - **Zenith NOT involved in training**.

In [None]:
def train_epoch(model, loader, optimizer, criterion, device=None):
    """Run one standard PyTorch training epoch.

    Args:
        model: The ``nn.Module`` to train; switched to train mode here.
        loader: Iterable of ``(data, target)`` batches (e.g. a DataLoader).
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        criterion: Loss function. Assumed to return the *mean* loss over the
            batch (``reduction='mean'``, the default for ``nn.CrossEntropyLoss``);
            it is re-weighted by batch size so the result is a per-sample mean.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        ``(avg_loss, accuracy_percent)`` over all samples seen this epoch.
        The loss of each batch is recorded before that batch's optimizer step.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for data, target in loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # Weight each batch's mean loss by its size so a smaller final batch
        # does not skew the epoch average.
        loss_sum += loss.item() * batch_size
        correct += output.argmax(dim=1).eq(target).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / total, 100. * correct / total


def evaluate(model, loader, criterion, device=None):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: The ``nn.Module`` to evaluate; switched to eval mode here
            (disabling dropout) and left in eval mode.
        loader: Iterable of ``(data, target)`` batches (e.g. a DataLoader).
        criterion: Loss function. Assumed to return the *mean* loss over the
            batch (``reduction='mean'``, the default for ``nn.CrossEntropyLoss``);
            it is re-weighted by batch size so the result is a per-sample mean.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        ``(avg_loss, accuracy_percent)`` over every sample in ``loader``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_size = target.size(0)
            # Size-weighted so a smaller last batch does not skew the mean.
            loss_sum += criterion(output, target).item() * batch_size
            correct += output.argmax(dim=1).eq(target).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / total, 100. * correct / total

In [None]:
# Training configuration: cross-entropy on logits with the Adam optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 3  # Quick demo

rule = "=" * 50
print("Training with Pure PyTorch...")
print(rule)

train_start = time.time()

for epoch in range(1, epochs + 1):
    # One optimization pass over the training split, then a held-out evaluation.
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion)
    test_loss, test_acc = evaluate(model, test_loader, criterion)
    epoch_report = (
        f"Epoch {epoch}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, "
        f"Test Loss={test_loss:.4f}, Test Acc={test_acc:.2f}%"
    )
    print(epoch_report)

# Wall-clock time covers both training and the per-epoch evaluation.
train_time = time.time() - train_start
print(rule)
print(f"Training complete in {train_time:.2f}s")
print(f"Final Test Accuracy: {test_acc:.2f}%")

---
## 5. Benchmark: PyTorch Native Inference

Measure baseline inference speed.

In [None]:
def _cycle_batches(loader, count):
    """Yield the first ``count`` input tensors from ``loader``, re-iterating it as needed.

    Stops early (yielding fewer than ``count`` items) when the loader is
    empty, instead of looping forever.
    """
    produced = 0
    while produced < count:
        yielded_any = False
        for data, _ in loader:
            yielded_any = True
            yield data
            produced += 1
            if produced >= count:
                return
        if not yielded_any:
            return


def benchmark_inference(model_fn, test_loader, num_batches=50, warmup=5,
                        device=None, module=None):
    """Time ``model_fn`` on input batches drawn from ``test_loader``.

    First runs ``warmup`` untimed calls, so one-off costs such as first-call
    compilation are excluded. Then times ``num_batches`` calls. The loader is
    re-iterated as often as needed, so ``num_batches`` is honored even when
    the loader has fewer batches. For example, a batch size of 1000 over the
    10k-image MNIST test split gives only 10 batches.

    Args:
        model_fn: Callable taking one input batch tensor.
        test_loader: Iterable of ``(data, target)`` batches; targets are unused.
        num_batches: Number of timed calls (must be >= 1).
        warmup: Number of untimed warm-up calls.
        device: Device each batch is moved to. Defaults to CUDA when
            available, else CPU.
        module: ``nn.Module`` to switch to eval mode before timing. Defaults
            to ``model_fn`` itself when it is an ``nn.Module``. When
            ``model_fn`` is a plain function or lambda, the caller is
            responsible for eval mode.

    Returns:
        ``(mean, std, min, max)`` per-batch latency in milliseconds.

    Raises:
        ValueError: If ``num_batches < 1`` or ``test_loader`` yields no batches.
    """
    if num_batches < 1:
        raise ValueError("num_batches must be >= 1")
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device(device)
    # CUDA kernels run asynchronously; synchronize so timings cover the work.
    sync_cuda = device.type == 'cuda'

    # Put the module actually being timed into eval mode (dropout off). The
    # module is derived from the arguments, not from any notebook global.
    if module is None and isinstance(model_fn, nn.Module):
        module = model_fn
    if module is not None:
        module.eval()

    times = []
    with torch.no_grad():
        batches = _cycle_batches(test_loader, warmup + num_batches)
        for i, data in enumerate(batches):
            data = data.to(device)
            if i < warmup:
                model_fn(data)  # untimed warm-up
                continue

            if sync_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()

            model_fn(data)

            if sync_cuda:
                torch.cuda.synchronize()
            times.append((time.perf_counter() - start) * 1000)  # ms

    if not times:
        raise ValueError("test_loader yielded no batches to time")
    return np.mean(times), np.std(times), np.min(times), np.max(times)

# Baseline: eager-mode PyTorch inference, timed per batch.
print("Benchmarking PyTorch Native Inference...")
native_stats = benchmark_inference(model, test_loader)
native_mean, native_std, native_min, native_max = native_stats
print(f"PyTorch Native: {native_mean:.3f} +/- {native_std:.3f} ms/batch")
print(f"  Range: [{native_min:.3f}, {native_max:.3f}] ms")

---
## 6. Apply Zenith Optimization

Now we use **Zenith's torch.compile backend** to optimize inference.

**Key Point**: The model was trained with PyTorch. Zenith only optimizes the compiled version.

In [None]:
banner = "=" * 60
print(banner)
print("APPLYING ZENITH OPTIMIZATION")
print(banner)

# Method 1: torch.compile driven by a backend object created by Zenith.
# Falls back to the plain eager model when ztorch reports that torch.compile
# is unavailable, so `compiled_model` is always defined for later cells.
if not ztorch.has_torch_compile():
    print("torch.compile not available, using eager mode")
    compiled_model = model
else:
    print("\n[Method 1] torch.compile with Zenith backend...")

    # Target the same hardware choice as the model's device; keep fp32 precision.
    zenith_target = "cuda" if torch.cuda.is_available() else "cpu"
    zenith_backend = ztorch.create_backend(target=zenith_target, precision="fp32")

    compiled_model = torch.compile(model, backend=zenith_backend)
    print("Model compiled with Zenith backend")

    # Evaluating the compiled model checks that accuracy is preserved.
    _, compiled_acc = evaluate(compiled_model, test_loader, criterion)
    print(f"Compiled Model Accuracy: {compiled_acc:.2f}%")

In [None]:
# Method 2: wrap a plain forward function with the @ztorch.compile decorator.
print("\n[Method 2] Using @ztorch.compile decorator...")

decorator_target = "cuda" if torch.cuda.is_available() else "cpu"


@ztorch.compile(target=decorator_target)
def optimized_forward(x):
    # Delegates to the trained PyTorch model; the decorator presumably
    # compiles this call path for the chosen target.
    return model(x)


# Smoke-test on the first test batch (inputs only; labels are ignored).
first_inputs, _ = next(iter(test_loader))
sample = first_inputs.to(device)
output = optimized_forward(sample)
print(f"Decorated function output shape: {output.shape}")

---
## 7. Benchmark: Zenith-Optimized Inference

In [None]:
# Same measurement as the baseline, now through the Zenith-compiled model.
print("Benchmarking Zenith-Optimized Inference...")
zenith_stats = benchmark_inference(compiled_model, test_loader)
zenith_mean, zenith_std, zenith_min, zenith_max = zenith_stats
print(f"Zenith Optimized: {zenith_mean:.3f} +/- {zenith_std:.3f} ms/batch")
print(f"  Range: [{zenith_min:.3f}, {zenith_max:.3f}] ms")

In [None]:
# Performance comparison (mean latency per batch, in ms).
banner = "=" * 60
print("\n" + banner)
print("PERFORMANCE COMPARISON")
print(banner)
print(f"\nPyTorch Native:    {native_mean:.3f} ms/batch")
print(f"Zenith Optimized:  {zenith_mean:.3f} ms/batch")

# Guard both ratios the same way. The means come from np.mean, so dividing by
# zero would not raise; it would print inf/nan with a RuntimeWarning.
speedup = native_mean / zenith_mean if zenith_mean > 0 else 1.0
improvement = (native_mean - zenith_mean) / native_mean * 100 if native_mean > 0 else 0.0

print(f"\nSpeedup: {speedup:.2f}x")
print(f"Improvement: {improvement:.1f}%")

---
## 8. Export to ONNX

Export the trained model to ONNX for production deployment. This step uses PyTorch's standard `torch.onnx.export`; Zenith is not involved.

In [None]:
import io
import onnx

import copy
import os
import tempfile

print("=" * 60)
print("ONNX EXPORT")
print("=" * 60)

# Sample input for export: one 1x28x28 image on the CPU. The batch axis is
# declared dynamic below, so the exported graph accepts any batch size.
sample_input = torch.randn(1, 1, 28, 28)
# Resolve the output location via tempfile rather than hard-coding "/tmp",
# which does not exist on every platform (e.g. Windows).
onnx_path = os.path.join(tempfile.gettempdir(), "mnist_classifier.onnx")

# Export using standard PyTorch ONNX export (most reliable).
# Export a CPU *copy*: nn.Module.cpu() moves parameters in place, so calling it
# on `model` would relocate the live model used by the other cells, and leave
# it on the CPU if the export raised before it could be moved back.
model_cpu = copy.deepcopy(model).cpu()
model_cpu.eval()

torch.onnx.export(
    model_cpu,
    sample_input,
    onnx_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
    opset_version=14
)

# Get file size
onnx_size = os.path.getsize(onnx_path)
print(f"ONNX model saved to: {onnx_path}")
print(f"ONNX model size: {onnx_size:,} bytes")

# Validate the saved file with the ONNX checker.
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("ONNX validation: PASSED")

In [None]:
# Verify the ONNX model with ONNX Runtime: it must load, run, and agree
# numerically with the PyTorch model it was exported from.
print("\nVerifying ONNX model with ONNX Runtime...")

import copy

import onnxruntime as ort

# Create session. Pin the CPU provider explicitly: ONNX Runtime builds that
# ship several execution providers (e.g. onnxruntime-gpu) require `providers`.
sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# Run inference on a random MNIST-shaped input.
test_input = np.random.randn(1, 1, 28, 28).astype(np.float32)
output = sess.run(None, {sess.get_inputs()[0].name: test_input})

# PyTorch reference computed on a CPU copy in eval mode (dropout off). This
# leaves the live model's device and train/eval state untouched. It also keeps
# possible GPU reduced-precision math (such as TF32 convolutions) from causing
# spurious mismatches.
reference_model = copy.deepcopy(model).cpu().eval()
with torch.no_grad():
    reference = reference_model(torch.from_numpy(test_input)).numpy()
max_abs_diff = float(np.max(np.abs(output[0] - reference)))

print(f"ONNX Runtime output shape: {output[0].shape}")
print(f"Max |ONNX Runtime - PyTorch| difference: {max_abs_diff:.2e}")
if not np.allclose(output[0], reference, rtol=1e-3, atol=1e-4):
    raise AssertionError(
        f"ONNX Runtime output does not match PyTorch (max abs diff {max_abs_diff:.2e})"
    )
print("ONNX Runtime inference: PASSED")

---
## 9. Convert to Zenith GraphIR

GraphIR is Zenith's internal graph representation of a model, used for advanced optimization.

In [None]:
from zenith.adapters import PyTorchAdapter

print("=" * 60)
print("ZENITH GRAPHIR CONVERSION")
print("=" * 60)

adapter = PyTorchAdapter()
# Example input with the model's expected shape (batch=1, 1 channel, 28x28).
sample = torch.randn(1, 1, 28, 28)

# Convert to GraphIR on the CPU. nn.Module.cpu() moves `model` itself rather
# than a copy, so restore it to `device` in `finally`. Otherwise a failed
# conversion would leave the live model on the CPU for every later cell.
try:
    graph = adapter.from_model(model.cpu(), sample_input=sample)
finally:
    model = model.to(device)

print(f"Graph name: {graph.name}")
print(f"Input tensors: {len(graph.inputs)}")
print(f"Output tensors: {len(graph.outputs)}")

---
## 10. Summary: Zenith as PyTorch Companion

In [None]:
# Summary banner. The workflow checklist and takeaways below are a static,
# hard-coded string: the [CHECK] marks are printed unconditionally, not
# derived from the results of the earlier cells.
print("\n" + "=" * 70)
print("ZENITH x PYTORCH: MACHINE LEARNING DEMO SUMMARY")
print("=" * 70)

print("""
+------------------------------------------------------------------+
|                     WORKFLOW DEMONSTRATION                       |
+------------------------------------------------------------------+

  1. MODEL DEFINITION:   Pure PyTorch (nn.Module)           [CHECK]
  2. DATA LOADING:       Pure PyTorch (DataLoader)          [CHECK]
  3. TRAINING:           Pure PyTorch (optimizer.step())    [CHECK]
  4. OPTIMIZATION:       Zenith torch.compile backend       [CHECK]
  5. ONNX EXPORT:        PyTorch + ONNX                     [CHECK]
  6. GRAPHIR CONVERT:    Zenith PyTorchAdapter              [CHECK]

+------------------------------------------------------------------+
|                     KEY TAKEAWAYS                                |
+------------------------------------------------------------------+

  - Zenith does NOT replace PyTorch for training
  - Zenith ENHANCES inference performance
  - Zenith SIMPLIFIES production deployment (ONNX)
  - Zenith INTEGRATES seamlessly with existing code
  - Your PyTorch knowledge remains 100% applicable

""")

# Measured results from the earlier cells.
print(f"Model Accuracy:     {test_acc:.2f}%")
print(f"Native Inference:   {native_mean:.3f} ms/batch")
print(f"Zenith Inference:   {zenith_mean:.3f} ms/batch")
# Only report a gain when Zenith was actually faster: a speedup below 1.0
# means the compiled model ran slower than eager PyTorch.
if speedup >= 1.0:
    print(f"Performance Gain:   {speedup:.2f}x faster")
else:
    print(f"Performance Gain:   none ({speedup:.2f}x, slower than native)")

print("\n" + "=" * 70)
print("Zenith: Your PyTorch Companion for Production ML")
print("=" * 70)