# Zenith Tutorial: ML Optimization Framework

Zenith adalah framework optimasi ML yang bekerja sebagai **pelengkap** (bukan pengganti) framework seperti PyTorch dan JAX.

**Apa yang akan dipelajari:**
1. Instalasi dan setup
2. Zenith + PyTorch
3. Zenith + JAX
4. Fitur advanced (Triton mock deployment, Auto-tuning, Load Testing)

---
## Part 1: Getting Started

In [None]:
# 1.1 Install Zenith from GitHub
# Remove any previous checkout so the clone below starts from a clean directory.
!rm -rf ZENITH 2>/dev/null
!git clone https://github.com/vibeswithkk/ZENITH.git
# Change into the checkout; the editable install below targets this directory.
%cd ZENITH
!pip install -q -e .

In [None]:
# 1.2 Install dependencies
# PyTorch and JAX are the frameworks paired with Zenith in Parts 2 and 3.
!pip install -q torch jax jaxlib

In [None]:
# 1.3 Verify installation
import zenith
from zenith import backends
import numpy as np  # used by the calibration and quantization cells below

# Print the installed version and what the CPU/CUDA backend probes report.
print(f"Zenith version: {zenith.__version__}")
print(f"CPU Available: {backends.is_cpu_available()}")
print(f"CUDA Available: {backends.is_cuda_available()}")
print("\nSetup complete!")

---
## Part 2: Zenith + PyTorch

Zenith menyediakan Quantization-Aware Training (QAT) untuk mengoptimasi model PyTorch.

In [None]:
# 2.1 Create PyTorch Model
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    """Small fully connected network used for the demo.

    Maps 784 input features to 10 outputs through two hidden layers of
    width 256 and 128, applying ReLU after each hidden layer.
    """

    def __init__(self):
        super().__init__()
        # The layers are exposed as ``fc1``/``fc2``/``fc3`` because the QAT
        # cell reads their ``weight`` attributes by these names.
        self.fc1 = nn.Linear(in_features=784, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through both hidden layers and return the ``fc3`` output."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        # No activation is applied after the final layer.
        return self.fc3(hidden)

# Instantiate the demo network and report how many parameters it holds.
model = SimpleNet()
param_count = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {param_count:,}")

In [None]:
# 2.2 Apply Zenith QAT
from zenith.optimization.qat import (
    QATConfig,
    FakeQuantize,
    QATTrainer,
    prepare_model_for_qat,
    convert_qat_to_quantized
)

# Collect each linear layer's weight matrix from the PyTorch model as a
# NumPy array, keyed by the layer's attribute name.
layer_weights = {
    name: getattr(model, name).weight.detach().numpy()
    for name in ("fc1", "fc2", "fc3")
}

# QAT configuration: 8-bit weights and activations, with the symmetric and
# per-channel weight options switched on.
config = QATConfig(
    weight_bits=8,
    activation_bits=8,
    symmetric_weights=True,
    per_channel_weights=True,
)

# Build the trainer from the layer names and the configuration above.
layer_names = list(layer_weights)
trainer = prepare_model_for_qat(layer_names, config)

module_count = len(trainer.modules)
print(f"QAT Trainer created with {module_count} modules")
print(f"Config: {config.weight_bits}-bit weights, {config.activation_bits}-bit activations")

In [None]:
# 2.3 Calibrate QAT with sample data
# Generate calibration data.
# NOTE(review): calibration_data is created here but is never passed to
# trainer.calibrate below - confirm whether it was meant to feed the first layer.
calibration_data = np.random.randn(100, 784).astype(np.float32)

# Calibrate every layer against freshly sampled synthetic activations.
for layer_name, weights in layer_weights.items():
    # One activation column per weight column (weights.shape[1]), 100 samples.
    activation_shape = (100, weights.shape[1])
    activations = np.random.randn(*activation_shape).astype(np.float32)
    trainer.calibrate(layer_name, weights, activations)

print("Calibration complete!")

In [None]:
# 2.4 Compare FP32 vs INT8
# Convert the calibrated trainer state into quantized weights.
quantized_weights = convert_qat_to_quantized(trainer, layer_weights)

# Size comparison.
# NOTE(review): the INT8 figure is derived as one quarter of the FP32 byte
# count; it is not measured from quantized_weights.
fp32_size = sum(w.nbytes for w in layer_weights.values())
int8_size = fp32_size / 4

print("\nModel Size Comparison:")
print(f"  FP32: {fp32_size / 1024:.2f} KB")
print(f"  INT8: {int8_size / 1024:.2f} KB")
print(f"  Reduction: {fp32_size / int8_size:.1f}x")

# Per-layer quantization error: mean absolute error and signal-to-noise ratio.
for name, original in layer_weights.items():
    quantized = quantized_weights[name]
    delta = original - quantized
    error = np.mean(np.abs(delta))
    # The 1e-10 term keeps the denominator non-zero when both arrays match.
    noise_power = np.mean(delta**2 + 1e-10)
    snr = 10 * np.log10(np.mean(original**2) / noise_power)
    print(f"  {name}: MAE={error:.6f}, SNR={snr:.1f}dB")

---
## Part 3: Zenith + JAX

Zenith juga dapat digunakan bersama JAX: bobot model JAX diekstrak sebagai array NumPy, lalu dikuantisasi dengan alur QAT yang sama seperti pada Part 2.

In [None]:
# 3.1 Create JAX function
import jax
import jax.numpy as jnp

def jax_mlp(params, x):
    """Apply a multilayer perceptron to ``x``.

    ``params`` is a sequence of ``(weight, bias)`` pairs. Every pair except
    the last is followed by a ReLU; the last pair is applied as a plain
    affine map and its result is returned.
    """
    hidden_layers = params[:-1]
    activations = x
    for weight, bias in hidden_layers:
        pre_activation = jnp.dot(activations, weight) + bias
        activations = jax.nn.relu(pre_activation)
    out_weight, out_bias = params[-1]
    return jnp.dot(activations, out_weight) + out_bias

# Build the parameter list: one (weight, bias) pair per layer. Weights are
# normal samples scaled by 0.01, each drawn with its own subkey; biases start
# at zero.
layer_shapes = [(784, 256), (256, 128), (128, 10)]
key = jax.random.PRNGKey(42)
keys = jax.random.split(key, 3)
params = [
    (jax.random.normal(keys[idx], shape) * 0.01, jnp.zeros(shape[1]))
    for idx, shape in enumerate(layer_shapes)
]

print("JAX MLP created!")
total_params = sum(w.size + b.size for w, b in params)
print(f"Params: {total_params:,}")

In [None]:
# 3.2 Apply Zenith QAT to JAX weights
# Copy each layer's weight matrix into NumPy; biases are left out.
jax_weights = {
    f'layer_{i}': np.array(layer[0]) for i, layer in enumerate(params)
}

# Build a trainer for these layer names using the QAT config defined earlier.
jax_layer_names = list(jax_weights)
jax_trainer = prepare_model_for_qat(jax_layer_names, config)

# Calibrate each layer against synthetic activations with one column per
# weight row (weights.shape[0]).
for layer_name, weights in jax_weights.items():
    activation_shape = (100, weights.shape[0])
    activations = np.random.randn(*activation_shape).astype(np.float32)
    jax_trainer.calibrate(layer_name, weights, activations)

# Convert the calibrated trainer state into quantized weights.
jax_quantized = convert_qat_to_quantized(jax_trainer, jax_weights)

print("\nJAX weights quantized!")
# Report the mean absolute error between original and quantized weights.
for name, original in jax_weights.items():
    quantized = jax_quantized[name]
    error = np.mean(np.abs(original - quantized))
    print(f"  {name}: MAE={error:.6f}")

In [None]:
# 3.3 Benchmark JAX inference
import time

# JIT compile
jax_mlp_jit = jax.jit(jax_mlp)

# Test data: a batch of 32 samples with 784 features each.
test_input = jax.random.normal(jax.random.PRNGKey(0), (32, 784))

# Warmup: the first call triggers tracing and compilation. Block until the
# result is ready so none of that work leaks into the timed section.
jax_mlp_jit(params, test_input).block_until_ready()

# Benchmark
# JAX dispatches work asynchronously and returns before the computation has
# finished, so every call is followed by block_until_ready(); without it the
# loop would time dispatch overhead rather than the inference itself.
start = time.perf_counter()
for _ in range(100):
    jax_mlp_jit(params, test_input).block_until_ready()
elapsed = (time.perf_counter() - start) * 1000  # milliseconds

print(f"\nJAX Inference Performance:")
print(f"  100 iterations: {elapsed:.2f} ms")
print(f"  Per iteration: {elapsed/100:.3f} ms")
print(f"  Throughput: {100*32/(elapsed/1000):.0f} samples/sec")

---
## Part 4: Advanced Features

In [None]:
# 4.1 Triton Mock Deployment
from zenith.serving.triton_client import MockTritonClient, InferenceInput, ModelMetadata

# Create the mock Triton client that stands in for a real server connection.
# NOTE(review): presumably no network call is made to localhost:8000 - confirm
# against MockTritonClient.
client = MockTritonClient("localhost:8000")

# Register model with custom handler
def inference_handler(inputs):
    """Simulate a model forward pass for the mock client.

    Reads ``.data`` from the first element of ``inputs``, multiplies it by a
    freshly sampled random float32 matrix with 10 columns, and applies tanh.
    The matrix is re-drawn on every call, so outputs differ between calls.

    Returns a dict with the single key ``"output"``.
    """
    features = inputs[0].data
    in_width = features.shape[-1]
    random_weights = np.random.randn(in_width, 10).astype(np.float32)
    projected = features @ random_weights
    return {"output": np.tanh(projected)}

# Register the handler under the name "my_model" together with its metadata.
model_metadata = ModelMetadata(name="my_model", platform="python", versions=["1"])
client.register_model("my_model", metadata=model_metadata, handler=inference_handler)

# Send one random 784-feature sample through the registered model.
test_input = np.random.randn(1, 784).astype(np.float32)
request_inputs = [InferenceInput(name="input", data=test_input)]
result = client.infer("my_model", request_inputs)

# Report readiness, success, the first output's shape, and the latency.
print(f"Triton Server Status: {client.is_server_ready()}")
print(f"Inference Success: {result.success}")
print(f"Output shape: {result.outputs[0].data.shape}")
print(f"Latency: {result.latency_ms:.3f} ms")

In [None]:
# 4.2 Auto-tuning
from zenith.optimization.autotuner import (
    KernelAutotuner,
    SearchSpace,
    TuningConfig
)

# Declare the candidate values for each tunable kernel parameter.
space = SearchSpace()
for param_name, candidates in {
    "block_size": [16, 32, 64, 128],
    "num_warps": [2, 4, 8],
}.items():
    space.define(param_name, candidates)

# Define benchmark function
def benchmark_kernel(config):
    """Return a synthetic cost for one candidate configuration.

    The cost is ``1000 / (block_size * sqrt(num_warps))``, read from the
    ``"block_size"`` and ``"num_warps"`` entries of ``config``. Lower values
    are better.
    """
    block_size = config["block_size"]
    warp_factor = config["num_warps"] ** 0.5
    return 1000 / (block_size * warp_factor)

# Create autotuner
autotuner = KernelAutotuner(search_space=space)

# Run tuning
# A dedicated name is used here: ``config`` already holds the QATConfig that
# the QAT cells in Parts 2 and 3 pass to prepare_model_for_qat, and rebinding
# it to a TuningConfig would break those cells when re-run after this one.
tuning_config = TuningConfig(op_type="matmul", input_shape=[1024, 1024])
result = autotuner.tune(tuning_config, benchmark_kernel, max_trials=12)

print(f"\nAuto-tuning Result:")
print(f"  Best config: {result.best_config}")
print(f"  Best time: {result.best_time:.2f}")
print(f"  Trials: {result.num_trials}")

In [None]:
# 4.3 Load Testing
# The helper module lives under ./tests/integration, resolved relative to the
# current working directory, so that directory is put on the import path first.
import sys
sys.path.insert(0, './tests/integration')
from triton_load_test import run_mock_load_test

# Run the mock load test: 100 requests across 10 concurrent workers.
load_test_options = {
    "model_name": "load_test",
    "num_requests": 100,
    "concurrent_workers": 10,
    "verbose": True,
}
result = run_mock_load_test(**load_test_options)

---
## Summary

Anda telah mempelajari:

| Feature | Status |
|---------|--------|
| Zenith Install | Verified |
| PyTorch + QAT | Verified |
| JAX + QAT | Verified |
| Triton Deployment (mock client) | Verified |
| Auto-tuning | Verified |
| Load Testing | Verified |

### Next Steps
- Deploy model ke production dengan Docker/Kubernetes
- Gunakan real Triton server
- Explore more optimization passes