# Zenith v0.3.0: The Real-World Benchmark

This is the **real-world test** for **Zenith v0.3.0**.
Unlike previous tests, which used a *dummy backend*, this notebook will:
1.  Install the actual `pyzenith` library.
2.  Let Zenith register its `torch.compile` backend internally (on `import zenith`).
3.  Measure the end-to-end performance of the kernel optimizations in v0.3.0 against eager PyTorch. Training times include Zenith's one-time compilation cost.

**Objective:**
*   **Training:** Fine-tune TinyLlama on the Alpaca dataset — wall-clock time and peak VRAM.
*   **Inference:** Text generation — tokens per second (TPS).

**Recommended hardware:** NVIDIA T4 GPU (standard Google Colab runtime). A CUDA GPU is required.

## 1. Environment Setup

In [None]:
!nvidia-smi

print("Installing dependencies (Transformers, PEFT, TRL)...")
!pip install -q -U torch transformers peft trl accelerate bitsandbytes psutil datasets matplotlib

print("Installing ZENITH v0.3.0...")
# Option 1: Install from PyPI (If available)
!pip install -U pyzenith

# Option 2: Fallback to Github (If PyPI is pending)
# !pip install git+https://github.com/vibeswithkk/ZENITH.git

import torch
print(f"PyTorch Version: {torch.__version__}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Verify Zenith Integration (The Moment of Truth)
Here we check whether `import zenith` automatically registers the `zenith` backend with `torch._dynamo`.

In [None]:
# Importing the package is expected to register the "zenith" torch.compile backend
# as a side effect; no explicit registration call is made in this notebook.
import zenith
from torch import _dynamo

print("Checking Registered Backends...")
backends = _dynamo.list_backends()
print(f"Available Backends: {backends}")

# Guard clause: a real-world benchmark without the real backend is meaningless,
# so report the problem and stop the notebook here.
if "zenith" not in backends:
    print("WARNING: Zenith Backend NOT found. Did v0.3.0 install correctly?")
    raise RuntimeError("Zenith Backend not found! Cannot proceed with Real World Benchmark.")

print("SUCCESS: Zenith Backend is registered natively!")

## 3. Training Benchmark (Fine-Tuning)

In [None]:
import time
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig, get_peft_model, TaskType
import matplotlib.pyplot as plt

# Benchmark knobs shared by both training runs: RNG seed passed to set_seed(),
# and the number of optimizer steps (also sizes the dataset slice below).
SEED, STEPS = 42, 50

def clean_memory():
    """Free as much host and GPU memory as possible between benchmark runs.

    Runs Python's garbage collector first, then asks PyTorch's CUDA caching
    allocator to release its unoccupied cached blocks. Peak-memory statistics
    are left untouched.
    """
    # Order matters: collecting garbage first drops unreachable tensors, so their
    # cached blocks are unoccupied by the time empty_cache() runs.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

def run_training(use_zenith=False):
    """Fine-tune TinyLlama with LoRA for ``STEPS`` steps and measure the run.

    The eager baseline and the Zenith run go through exactly the same model,
    LoRA, dataset and trainer setup. The only difference is whether the
    transformer wrapped by LoRA is compiled with ``backend="zenith"``.

    Args:
        use_zenith: If True, compile the LoRA-wrapped transformer with the
            Zenith ``torch.compile`` backend before training.

    Returns:
        ``(total_time, max_mem)``: wall-clock seconds spent in
        ``trainer.train()``, and the peak CUDA memory allocated during this
        run, in GB (1e9 bytes).
    """
    mode = "ZENITH" if use_zenith else "PYTORCH"
    print(f"\n{'='*10} STARTING: {mode} {'='*10}")
    clean_memory()
    # max_memory_allocated() reports the peak since the last reset. Without this,
    # the second run would inherit the first run's peak and its VRAM would be wrong.
    torch.cuda.reset_peak_memory_stats()
    set_seed(SEED)

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto"
    )

    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1)
    # LoRA is applied exactly once, here. The config is deliberately NOT also
    # passed to SFTTrainer: the model is already a PeftModel, and handing TRL the
    # config again can make it re-process the adapters (version-dependent).
    model = get_peft_model(model, peft_config)

    # STEPS optimizer steps * (2 per device * 2 accumulation) = STEPS*4 examples,
    # i.e. a single pass over this slice on one GPU.
    dataset = load_dataset("tatsu-lab/alpaca", split=f"train[:{STEPS*4}]")

    args = SFTConfig(
        output_dir=f"./tmp_{mode}",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        learning_rate=2e-4,
        max_steps=STEPS,
        fp16=True,
        report_to="none",
        packing=False,
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=args,
        processing_class=tokenizer,
        formatting_func=lambda x: f"### Instruction:\n{x['instruction']}\n\n### Response:\n{x['output']}"
    )

    # === THE REAL WORLD INTEGRATION ===
    # No dummy functions here. Just pure torch.compile.
    if use_zenith:
        print("Activating Zenith Optimization (Native)...")
        # Compile the module that LoRA's forward actually calls:
        # PeftModel -> base_model (LoraModel) -> .model (the transformer).
        # Assigning `model.model = ...` on the PeftModel would NOT replace it:
        # nn.Module.__setattr__ registers a new, unused attribute on the
        # PeftModel, and training would silently run eagerly.
        # This is done after SFTTrainer is built, so trainer setup cannot swap it out.
        lora_model = trainer.model.base_model
        lora_model.model = torch.compile(lora_model.model, backend="zenith")
    # ==================================

    torch.cuda.synchronize()
    start_t = time.perf_counter()
    # For the Zenith run this includes the one-time compilation triggered by the
    # first forward/backward passes (torch.compile is lazy).
    trainer.train()
    torch.cuda.synchronize()
    total_time = time.perf_counter() - start_t

    max_mem = torch.cuda.max_memory_allocated() / 1e9

    print(f"DONE {mode}. Time: {total_time:.2f}s | VRAM: {max_mem:.2f}GB")

    del model, trainer, dataset
    clean_memory()
    return total_time, max_mem

In [None]:
# Run Benchmarks: the eager PyTorch baseline first, then the Zenith-compiled run.
# Both calls use the same seed and training configuration inside run_training().
# Kept as two sequential assignments on purpose: if the Zenith run raises,
# time_py / mem_py from the baseline remain available in the kernel.
time_py, mem_py = run_training(use_zenith=False)
time_zen, mem_zen = run_training(use_zenith=True)

## 4. Inference Benchmark (TPS)

In [None]:
import numpy as np

def run_inference(use_zenith=False, runs=5, max_new_tokens=100):
    """Measure text-generation throughput (tokens per second) for TinyLlama.

    Args:
        use_zenith: If True, compile the model's forward with the Zenith
            ``torch.compile`` backend before measuring.
        runs: Number of timed ``generate`` calls to average; must be >= 1.
        max_new_tokens: Upper bound on tokens generated per call (the warmup
            uses the same value so compilation stays out of the timed loop).

    Returns:
        Mean TPS over ``runs`` timed calls. Each call's TPS is the number of
        tokens actually generated divided by that call's latency.

    Raises:
        ValueError: If ``runs`` is smaller than 1.
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")

    mode = "ZENITH" if use_zenith else "PYTORCH"
    print(f"\n{'='*10} INFERENCE: {mode} {'='*10}")
    clean_memory()

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="cuda")

    if use_zenith:
        print("Compiling with Zenith...")
        # `torch.compile(model).generate(...)` would resolve `generate` on the
        # original module, which calls its own (uncompiled) forward, so decoding
        # would run eagerly. Compiling `forward` itself routes every decoding
        # step through Zenith.
        model.forward = torch.compile(model.forward, backend="zenith")

    # Keep the attention mask alongside input_ids instead of letting generate() infer it.
    inputs = tokenizer("The future of AI is", return_tensors="pt").to("cuda")
    prompt_len = inputs["input_ids"].shape[-1]

    # Warmup with the same generation length as the timed runs, so compilation
    # (and any recompilation as the sequence grows) happens here, not below.
    model.generate(**inputs, max_new_tokens=max_new_tokens)

    tps_list = []
    for _ in range(runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        out = model.generate(**inputs, max_new_tokens=max_new_tokens)
        torch.cuda.synchronize()
        lat = time.perf_counter() - start
        # Generation may stop early at EOS, so count the tokens actually
        # produced instead of assuming max_new_tokens.
        new_tokens = out.shape[-1] - prompt_len
        tps = new_tokens / lat
        tps_list.append(tps)
        print(f"Run: {tps:.2f} TPS")

    return np.mean(tps_list)

In [None]:
# Eager PyTorch baseline first, then the Zenith-compiled model (default runs=5).
# Two sequential assignments on purpose: tps_py survives if the Zenith run fails.
tps_py = run_inference(use_zenith=False)
tps_zen = run_inference(use_zenith=True)

## 5. Final Report

In [None]:
# Both "Speedup" figures are throughput gains ("how much faster Zenith is"), so they
# are directly comparable. Training throughput is inversely proportional to time,
# hence time_py / time_zen. A plain time-reduction % would understate large gains:
# halving the time is a +100% speedup, not +50%.
train_speedup = (time_py / time_zen - 1) * 100
infer_speedup = ((tps_zen - tps_py) / tps_py) * 100
vram_change = (mem_zen / mem_py - 1) * 100

print(f"\n{'='*40}")
print("ZENITH v0.3.0 REAL WORLD RESULTS")
print(f"{'='*40}")
print(f"Training Time : PyTorch {time_py:.2f}s | Zenith {time_zen:.2f}s | Speedup: {train_speedup:+.2f}%")
print(f"Peak VRAM     : PyTorch {mem_py:.2f}GB | Zenith {mem_zen:.2f}GB | Change: {vram_change:+.2f}%")
print(f"Inference TPS : PyTorch {tps_py:.2f}  | Zenith {tps_zen:.2f}  | Speedup: {infer_speedup:+.2f}%")

fig, ax = plt.subplots(1, 3, figsize=(18, 5))

# Chart 1: Training Time (Lower is better)
ax[0].bar(['PyTorch', 'Zenith'], [time_py, time_zen], color=['gray', 'blue'])
ax[0].set_title('Training Time (Lower is Better)')
ax[0].set_ylabel('Seconds')

# Chart 2: Peak training VRAM (Lower is better)
ax[1].bar(['PyTorch', 'Zenith'], [mem_py, mem_zen], color=['gray', 'green'])
ax[1].set_title('Peak Training VRAM (Lower is Better)')
ax[1].set_ylabel('GB')

# Chart 3: Inference TPS (Higher is better)
ax[2].bar(['PyTorch', 'Zenith'], [tps_py, tps_zen], color=['gray', 'orange'])
ax[2].set_title('Inference TPS (Higher is Better)')
ax[2].set_ylabel('Tokens / Sec')

fig.suptitle('Zenith v0.3.0 vs. eager PyTorch')
fig.tight_layout()
plt.show()