# Zenith vs PyTorch: The Benchmark Arena

Notebook ini adalah arena pengujian **Head-to-Head** untuk membandingkan performa training antara:
1.  **Baseline:** PyTorch Native (Standard)
2.  **Challenger:** PyTorch + **Zenith Backend**

Metric yang diukur:
*   **Training Speed:** Waktu total untuk N steps.
*   **VRAM Usage:** Penggunaan memori GPU puncak (peak) selama training; penggunaan rata-rata belum diukur oleh notebook ini.
*   **Startup Overhead:** Waktu kompilasi awal (saat ini ikut terhitung di dalam total waktu training dan belum dilaporkan secara terpisah).

## 1. Setup Environment
Instalasi Zenith dan library pendukung.

In [None]:
!nvidia-smi
import os
import sys

print("Installing dependencies...")
!pip install -q -U torch transformers peft trl accelerate bitsandbytes psutil datasets matplotlib

print("Cloning & Installing Zenith...")
!rm -rf zenith_repo
!git clone https://github.com/vibeswithkk/ZENITH.git zenith_repo
!pip install -e zenith_repo

# Force path update
if os.path.abspath("zenith_repo") not in sys.path:
    sys.path.append(os.path.abspath("zenith_repo"))

print("Ready for Battle!")

## 2. Benchmark Engine Definition
Di sini kita mendefinisikan fungsi benchmark yang bersih dan adil. Setiap ronde akan membersihkan memori GPU agar hasil tidak bias.

In [None]:
import time
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig, get_peft_model, TaskType
import psutil
import matplotlib.pyplot as plt

# === ZENITH REGISTRATION ===
from torch import _dynamo
def zenith_backend(gm: torch.fx.GraphModule, example_inputs):
    """TorchDynamo backend registered under the name ``"zenith"``.

    This implementation is a pass-through: it performs no graph
    transformation and hands the captured graph's own ``forward`` back as the
    callable to run.

    Args:
        gm: The FX graph module captured by TorchDynamo.
        example_inputs: Example inputs supplied by TorchDynamo; unused here.

    Returns:
        ``gm.forward``, unchanged.
    """
    # Placeholder for integration testing; real optimization logic would
    # transform ``gm`` before returning a callable.
    return gm.forward


# Drop any cached compilation state before (re-)registering the backend.
_dynamo.reset()

# Register only once. When a backend named "zenith" is already present (for
# example after re-running this cell), the earlier registration is kept and
# the function defined above is NOT swapped in.
_zenith_registered = "zenith" in _dynamo.list_backends()
if not _zenith_registered:
    _dynamo.register_backend(compiler_fn=zenith_backend, name="zenith")
# ===========================

def clean_memory():
    """Free unreferenced objects and reset CUDA memory bookkeeping.

    Runs the Python garbage collector first, then empties the CUDA caching
    allocator and resets the peak-memory counters so that a subsequent
    ``torch.cuda.max_memory_allocated()`` reading starts from a clean slate.

    The CUDA calls are skipped when no CUDA device is available, so on a
    CPU-only machine this only performs the garbage-collection pass instead
    of raising.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def run_round(use_zenith, steps=30, model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
    """Run one LoRA fine-tuning round and report its runtime and peak VRAM.

    Loads ``model_name``, wraps it with a LoRA adapter, optionally applies
    ``torch.compile`` with the ``"zenith"`` backend, trains for ``steps``
    optimizer steps on a slice of the Alpaca dataset, and collects metrics.

    Args:
        use_zenith: If true, apply ``torch.compile(..., backend="zenith")``
            before training; otherwise train the model as loaded (baseline).
        steps: Passed to the trainer as ``max_steps``. ``steps * 2`` dataset
            rows are loaded.
        model_name: Hugging Face model identifier to load.

    Returns:
        A dict with keys ``"mode"``, ``"total_time"`` (seconds),
        ``"steps_per_sec"`` and ``"peak_vram_gb"``, or ``None`` when the
        ``torch.compile`` call itself raised.
    """
    mode_name = "ZENITH" if use_zenith else "PYTORCH (BASELINE)"
    print(f"\n{'='*20} ROUND START: {mode_name} {'='*20}")

    # Start from an emptied cache and zeroed peak counters so the VRAM reading
    # taken at the end of this function covers this round only.
    clean_memory()

    # Query once; decides both the weight dtype and the trainer precision flags.
    use_bf16 = torch.cuda.is_bf16_supported()

    # Load Model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token (presumably the checkpoint ships without
    # a dedicated pad token -- confirm for other models).
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
        device_map="auto",
    )

    # Apply LoRA
    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
    model = get_peft_model(model, peft_config)

    # Apply Optimization
    if use_zenith:
        print("Activating Zenith Compilation...")
        try:
            # NOTE(review): `model` is a PeftModel here. Reading `model.model`
            # appears to resolve through the wrapper to the inner network, but
            # assigning a module to `model.model` likely registers a NEW
            # submodule on the wrapper instead of replacing the module its
            # forward pass calls -- verify the compiled module is really
            # executed during training.
            # NOTE(review): torch.compile is lazy; backend/compile errors
            # normally surface on the first forward pass inside
            # trainer.train(), not here, so this handler only covers failures
            # of the wrapping call itself.
            model.model = torch.compile(model.model, backend="zenith")
        except Exception as e:
            print(f"Zenith Failed: {e}")
            return None

    # Dataset
    dataset = load_dataset("tatsu-lab/alpaca", split=f"train[:{steps*2}]")

    def format_prompt(sample):
        return f"### Instruction:\n{sample['instruction']}\n\n### Response:\n{sample['output']}"

    # Trainer Config
    args = SFTConfig(
        output_dir=f"./results_{mode_name}",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        logging_steps=10,
        max_steps=steps,
        # Mixed precision mirrors the dtype the weights were loaded in.
        fp16=(torch.cuda.is_available() and not use_bf16),
        bf16=use_bf16,
        report_to="none",
        packing=False,
    )

    # The model is already LoRA-wrapped by get_peft_model above, so the config
    # is deliberately NOT passed again as `peft_config=`: doing so is redundant
    # and can be rejected by the trainer when it receives an already-wrapped
    # PeftModel.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=args,
        processing_class=tokenizer,
        formatting_func=format_prompt,
    )

    # MEASURED RUN. There is no separate warm-up pass: any one-time
    # compilation overhead is included in the runtime reported below.
    print(" warmup...")
    trainer.train(resume_from_checkpoint=False)

    # METRICS
    # The final log entry is expected to be the trainer's end-of-training
    # summary (train_runtime, train_steps_per_second, ...).
    train_metrics = trainer.state.log_history[-1]
    peak_mem = torch.cuda.max_memory_allocated() / 1024**3

    results = {
        "mode": mode_name,
        "total_time": train_metrics.get('train_runtime', 0),
        # Read the steps/s metric so the value matches the key name.
        "steps_per_sec": train_metrics.get('train_steps_per_second', 0),
        "peak_vram_gb": peak_mem,
    }

    print(f"ROUND FINISHED: {round(results['total_time'], 2)}s | VRAM: {round(peak_mem, 2)} GB")

    # Drop the large objects before clearing the cache so the next round
    # starts with the GPU memory released.
    del model, trainer, dataset
    clean_memory()
    return results

## 3. Run The Fight!
Jalankan cell ini untuk memulai pertarungan.

In [None]:
STEPS = 50


def _percent_saved(baseline, challenger):
    """Return how much smaller ``challenger`` is than ``baseline``, in percent.

    A positive value means the challenger used less (time or memory).
    Returns NaN when ``baseline`` is zero, where the relative change is
    undefined and a plain division would raise ``ZeroDivisionError``.
    """
    if not baseline:
        return float("nan")
    return ((baseline - challenger) / baseline) * 100


print("Round 1: Baseline (PyTorch)...")
res_baseline = run_round(use_zenith=False, steps=STEPS)

print("\nRound 2: Challenger (Zenith)...")
res_zenith = run_round(use_zenith=True, steps=STEPS)

if res_baseline is None or res_zenith is None:
    # run_round signals a failed round by returning None; without both result
    # dicts there is nothing to compare, so report that instead of crashing
    # on a subscript of None.
    missing = [
        label
        for label, res in (("PyTorch", res_baseline), ("Zenith", res_zenith))
        if res is None
    ]
    print(f"\nBenchmark incomplete: no results for {', '.join(missing)}. Skipping scoreboard and plots.")
else:
    # --- REPORT ---
    print(f"\n{'='*40}")
    print(f"FINAL SCOREBOARD ({STEPS} Steps)")
    print(f"{'='*40}")
    print(f"{'Metric':<20} | {'PyTorch':<10} | {'Zenith':<10} | {'Delta'}")
    print("-"*60)

    t_base = res_baseline['total_time']
    t_zen = res_zenith['total_time']
    v_base = res_baseline['peak_vram_gb']
    v_zen = res_zenith['peak_vram_gb']

    # Positive delta = Zenith needed less than the baseline.
    diff_time = _percent_saved(t_base, t_zen)
    diff_vram = _percent_saved(v_base, v_zen)

    print(f"{'Time (s)':<20} | {t_base:<10.2f} | {t_zen:<10.2f} | {diff_time:+.2f}% {'(Faster)' if diff_time > 0 else ''}")
    print(f"{'Peak VRAM (GB)':<20} | {v_base:<10.2f} | {v_zen:<10.2f} | {diff_vram:+.2f}% {'(Lighter)' if diff_vram > 0 else ''}")

    # VISUALIZATION
    labels = ['PyTorch', 'Zenith']
    times = [t_base, t_zen]
    vrams = [v_base, v_zen]

    # Two side-by-side bar charts: runtime on the left, peak VRAM on the right.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

    ax1.bar(labels, times, color=['gray', 'blue'])
    ax1.set_title('Training Time (Lower is Better)')
    ax1.set_ylabel('Seconds')

    ax2.bar(labels, vrams, color=['gray', 'green'])
    ax2.set_title('Peak VRAM (Lower is Better)')
    ax2.set_ylabel('GB')

    plt.tight_layout()
    plt.show()