# Zenith Inference Speed Test: Tokens Per Second (TPS)

Notebook ini dirancang untuk mengukur kecepatan inferensi (generation speed) antara:
1.  **Baseline:** PyTorch Native `model.generate()`
2.  **Zenith:** Optimized `torch.compile(model, backend='zenith')`

Metrics:
*   **Time to First Token (TTFT):** Latensi hingga token pertama dihasilkan (prefill prompt + 1 token).
*   **Tokens Per Second (TPS):** Kecepatan total output teks.
*   **Total Inference Time:** Waktu total satu panggilan `generate()`.

## 1. Setup Environment

In [None]:
# Show the GPU(s) visible to this runtime; the benchmark below runs on CUDA.
!nvidia-smi
import os
import sys

print("Installing dependencies...")
# NOTE(review): `-U torch` may upgrade the PyTorch build already imported by this
# kernel; presumably a runtime restart is then needed before the new version is
# used — confirm on the target platform (e.g. Colab).
!pip install -q -U torch transformers accelerate bitsandbytes psutil matplotlib

print("Cloning & Installing Zenith...")
# Always start from a fresh clone: any existing ./zenith_repo directory is deleted.
!rm -rf zenith_repo
!git clone https://github.com/vibeswithkk/ZENITH.git zenith_repo
!pip install -e zenith_repo

# Also put the clone directly on sys.path so it can be imported in this kernel
# even if the editable install is not picked up yet (presumably it otherwise
# needs a restart — TODO confirm). The guard avoids duplicate entries on re-run.
if os.path.abspath("zenith_repo") not in sys.path:
    sys.path.append(os.path.abspath("zenith_repo"))

import torch
from torch import _dynamo

# Register Dummy Backend if not present (for test without full kernel build)
def zenith_backend(gm: torch.fx.GraphModule, example_inputs):
    """Pass-through TorchDynamo compiler used as a stand-in "zenith" backend.

    No optimization is applied: the FX graph captured by TorchDynamo is run
    exactly as traced. This lets ``torch.compile(..., backend="zenith")`` work
    even when the real Zenith kernels are not available.

    Args:
        gm: The FX graph module captured by TorchDynamo.
        example_inputs: Example inputs for the graph; ignored here.

    Returns:
        The graph module's own ``forward`` callable.
    """
    run_graph = gm.forward
    return run_graph

# Reset TorchDynamo's compile caches/state before (re)registering backends.
_dynamo.reset()
# Register the pass-through stand-in only if no "zenith" backend is known yet.
# `list_backends()` hides backends tagged "debug"/"experimental" by default, so
# pass `exclude_tags=()` to see every name; otherwise a tagged "zenith" backend
# would be missed and `register_backend` would fail on the duplicate name.
if "zenith" not in _dynamo.list_backends(exclude_tags=()):
    _dynamo.register_backend(compiler_fn=zenith_backend, name="zenith")
    # Make it obvious that ZENITH numbers then reflect graph capture only.
    print("NOTE: using the pass-through 'zenith' stand-in backend (no Zenith kernels).")

print("Ready for Inference Testing!")

In [None]:
import time
import gc
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import numpy as np

# Hugging Face Hub checkpoint under test (load_model loads it in float16 on CUDA).
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Prompt used unchanged for every generation, so all runs share one input length.
PROMPT = "The future of Artificial Intelligence is"
# Upper bound on new tokens per timed generation; the benchmark divides the
# number of tokens actually produced by the elapsed time.
MAX_NEW_TOKENS = 100

def load_model():
    """Load the benchmark checkpoint ``MODEL_ID`` together with its tokenizer.

    Returns:
        ``(model, tokenizer)``: the causal LM in float16 placed on CUDA, and the
        tokenizer that matches it.
    """
    print(f"Loading {MODEL_ID}...")
    load_options = {
        "torch_dtype": torch.float16,
        "device_map": "cuda",
    }
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    lm = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_options)
    return lm, tok

def benchmark_inference(model, tokenizer, use_zenith=False, runs=5):
    """Measure greedy-generation speed of ``model`` on ``PROMPT``.

    One untimed warmup generation is followed by ``runs`` timed generations of
    up to ``MAX_NEW_TOKENS`` tokens. Per-run tokens/sec, latency and an
    approximate time-to-first-token (TTFT) are printed, plus their averages.

    Args:
        model: Hugging Face causal LM already placed on CUDA (see ``load_model``).
        tokenizer: Tokenizer matching ``model``.
        use_zenith: If True, compile ``model.forward`` with the ``"zenith"``
            ``torch.compile`` backend before benchmarking. NOTE: this modifies
            ``model`` in place, so benchmark the eager baseline first.
        runs: Number of timed generations; must be >= 1.

    Returns:
        ``(avg_tps, tokens_per_sec, model)``: mean tokens/sec, the list of
        per-run tokens/sec, and the (possibly compiled) model.

    Raises:
        ValueError: If ``runs`` < 1.
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")

    mode = "ZENITH" if use_zenith else "PYTORCH"
    print(f"\n{'='*10} BENCHMARK: {mode} {'='*10}")

    if use_zenith:
        print("Compiling model with Zenith backend...")
        # `torch.compile(model)` returns a wrapper whose `.generate` resolves to
        # the ORIGINAL module's method, which then calls the uncompiled
        # `forward` — nothing would actually be compiled. Compiling `forward`
        # itself makes every decoding step inside `generate()` use the graph.
        model.forward = torch.compile(model.forward, backend="zenith")

    encoded = tokenizer(PROMPT, return_tensors="pt")
    input_ids = encoded.input_ids.cuda()
    # Pass the mask explicitly instead of letting generate() infer it.
    attention_mask = encoded.attention_mask.cuda()
    prompt_len = input_ids.shape[-1]

    def _generate(max_new_tokens):
        # Greedy decoding so every run performs the same computation.
        return model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    def _timed_generate(max_new_tokens):
        # CUDA work is asynchronous: synchronize on both sides so the measured
        # wall-clock interval covers all GPU kernels of this generation.
        torch.cuda.synchronize()
        start = time.perf_counter()
        result = _generate(max_new_tokens)
        torch.cuda.synchronize()
        return result, time.perf_counter() - start

    # Warm up with the same generation length as the timed runs so that
    # compilation (including recompiles triggered as the sequence grows)
    # happens here as far as possible, not inside the measurements.
    print("Warming up... (This compiles the graph if Zenith is on)")
    _generate(MAX_NEW_TOKENS)

    latencies = []
    ttfts = []
    tokens_per_sec = []

    print(f"Running {runs} generations...")
    for i in range(runs):
        # TTFT approximated by a 1-token generation (prompt prefill + first token).
        _, ttft = _timed_generate(1)
        output, latency = _timed_generate(MAX_NEW_TOKENS)
        # Count tokens actually produced; generation can end before
        # MAX_NEW_TOKENS (e.g. on an EOS token).
        num_tokens = output.shape[-1] - prompt_len
        del output
        tps = num_tokens / latency

        latencies.append(latency)
        ttfts.append(ttft)
        tokens_per_sec.append(tps)
        print(f"Running {i+1}: {tps:.2f} tokens/sec ({latency:.4f}s, TTFT {ttft * 1000:.1f} ms)")

    avg_tps = np.mean(tokens_per_sec)
    print(f"AVG TPS ({mode}): {avg_tps:.2f}")
    print(f"AVG latency ({mode}): {np.mean(latencies):.4f}s | AVG TTFT: {np.mean(ttfts) * 1000:.1f} ms")

    # Release cached GPU memory before the next benchmark.
    gc.collect()
    torch.cuda.empty_cache()

    return avg_tps, tokens_per_sec, model

## 2. Run Comparison

In [None]:
# 1. Load the model once; both benchmarks below reuse it.
model, tokenizer = load_model()

# 2. Benchmark eager PyTorch. Run it BEFORE the Zenith benchmark so the baseline
#    is guaranteed to measure the uncompiled model.
tps_baseline, _, model = benchmark_inference(model, tokenizer, use_zenith=False)

# 3. Benchmark Zenith on the SAME weights (compiled inside benchmark_inference).
#    For a fully isolated comparison, reload with load_model() first.
tps_zenith, _, _ = benchmark_inference(model, tokenizer, use_zenith=True)

# 4. Results
print(f"\n{'='*40}")
print("INFERENCE SCOREBOARD")
print("=" * 40)
print(f"PyTorch: {tps_baseline:.2f} TPS")
print(f"Zenith : {tps_zenith:.2f} TPS")
# Relative change vs. the baseline; guard against a zero baseline.
if tps_baseline > 0:
    delta = (tps_zenith - tps_baseline) / tps_baseline * 100
    print(f"Improvement: {delta:+.2f}%")
else:
    print("Improvement: n/a (baseline TPS is 0)")

# Plot
labels = ['PyTorch', 'Zenith']
values = [tps_baseline, tps_zenith]

plt.figure(figsize=(8, 6))
bars = plt.bar(labels, values, color=['gray', 'blue'])
plt.title('Inference Speed (Tokens Per Second) - Higher is Better')
plt.ylabel('TPS')

# Label each bar with its exact TPS value, just above the top of the bar.
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width() / 2, height, f'{height:.2f}', ha='center', va='bottom')

plt.show()