# Zenith Convergence Accuracy Test

Notebook ini bertujuan untuk memvalidasi **Akurasi Numerik** dari Zenith Optimizations.
Kita akan melatih model yang sama persis sebanyak dua kali dan membandingkan **Loss Curve**-nya.

1.  **Run 1:** Baseline (Standard PyTorch)
2.  **Run 2:** Zenith Backend

**Success Criteria:** Kurva Loss Zenith harus hampir identik (berhimpitan) dengan Baseline; secara kuantitatif, MSE antara kedua kurva harus di bawah `1e-3`. Jika menyimpang jauh, berarti ada masalah presisi (numerical instability).

## 1. Setup & Dependencies

In [None]:
# Show the GPU(s) visible to this runtime before doing anything else.
!nvidia-smi
import os
import sys

# Install/upgrade the training stack used by the cells below (-q: quiet, -U: upgrade).
print("Installing dependencies...")
!pip install -q -U torch transformers peft trl accelerate bitsandbytes psutil datasets matplotlib

# Remove any previous checkout first so re-running this cell starts from a fresh clone,
# then install the repository in editable mode (-e).
print("Cloning & Installing Zenith...")
!rm -rf zenith_repo
!git clone https://github.com/vibeswithkk/ZENITH.git zenith_repo
!pip install -e zenith_repo

# Put the cloned checkout on sys.path (by absolute path, at most once).
_zenith_repo_path = os.path.abspath("zenith_repo")
if _zenith_repo_path not in sys.path:
    sys.path.append(_zenith_repo_path)

import torch
from torch import _dynamo

# Register Backend
def zenith_backend(gm: torch.fx.GraphModule, example_inputs):
    """Backend callable registered with TorchDynamo under the name "zenith".

    Args:
        gm: The FX graph module handed to the backend.
        example_inputs: Example inputs supplied alongside the graph; not used here.

    Returns:
        ``gm.forward``, unchanged.

    NOTE(review): this is a pass-through -- nothing from the cloned Zenith
    repository is called here, so the graph runs exactly as captured. Confirm
    whether this should delegate to the backend shipped in the repository.
    """
    passthrough_fn = gm.forward
    return passthrough_fn

# Reset TorchDynamo state, then register the backend only when the name is not
# already listed, so re-running this cell does not attempt a second registration.
_dynamo.reset()
_known_backends = _dynamo.list_backends()
if "zenith" not in _known_backends:
    _dynamo.register_backend(compiler_fn=zenith_backend, name="zenith")

print("Ready for Convergence Check!")

## 2. Define Training Engine

In [None]:
import gc
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig, get_peft_model, TaskType
import matplotlib.pyplot as plt

# Single fixed seed shared by every run: it is passed to set_seed() and to
# SFTConfig(seed=..., data_seed=...) in get_loss_history.
SEED = 42

def clean_memory() -> None:
    """Run Python garbage collection, then empty the CUDA caching allocator.

    The collection runs first so that tensors owned by unreachable objects are
    released before the cache is emptied.
    """
    gc.collect()
    torch.cuda.empty_cache()

def get_loss_history(use_zenith=False, steps=50):
    """Fine-tune TinyLlama with LoRA for ``steps`` steps and return the logged losses.

    Both the baseline and the Zenith run go through this function with the same
    seed, model, data slice and hyper-parameters; the only difference is whether
    the model is compiled with the "zenith" backend.

    Args:
        use_zenith: If True, compile the model with ``backend="zenith"`` before
            training. If False, train the uncompiled model.
        steps: Value used for ``max_steps``; also sizes the dataset slice.

    Returns:
        A list with the ``loss`` value of every trainer log entry that has one
        (``logging_steps=1``, so one entry per step is expected).
    """
    print(f"\n{'='*10} {'ZENITH' if use_zenith else 'PYTORCH'} RUN {'='*10}")
    clean_memory()
    set_seed(SEED)  # Seed before the model/LoRA weights are created.

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    batch_size = 2
    grad_accum = 2

    # Pre-bind so the cleanup in `finally` is valid even if setup fails midway.
    model = trainer = dataset = None
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )

        # Standard LoRA Config
        peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1)
        model = get_peft_model(model, peft_config)

        if use_zenith:
            print("Activating Zenith Backend...")
            # Compile the PEFT wrapper in place. Assigning a compiled copy to
            # `model.model` would only register an extra submodule on the
            # wrapper: its forward dispatches through `base_model`, so that
            # copy would never be executed.
            model.compile(backend="zenith")

        # Enough examples for `steps` optimizer steps at this effective batch size.
        dataset = load_dataset(
            "tatsu-lab/alpaca",
            split=f"train[:{steps * batch_size * grad_accum}]",
        )

        args = SFTConfig(
            output_dir="./tmp_trainer",
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=grad_accum,
            learning_rate=2e-4,
            logging_steps=1,  # Log EVERY step for detailed curve
            max_steps=steps,
            fp16=True,
            report_to="none",
            packing=False,
            seed=SEED,
            data_seed=SEED
        )

        # The model is already a PEFT model, so `peft_config` is intentionally
        # not passed again (redundant, and rejected by some TRL versions).
        trainer = SFTTrainer(
            model=model,
            train_dataset=dataset,
            args=args,
            processing_class=tokenizer,
            formatting_func=lambda x: f"### Instruction:\n{x['instruction']}\n\n### Response:\n{x['output']}"
        )

        trainer.train()

        # Extract Loss Values
        return [entry['loss'] for entry in trainer.state.log_history if 'loss' in entry]
    finally:
        # Always drop the large objects and free GPU memory, even when training
        # fails, so the next run starts from a clean device.
        del model, trainer, dataset
        clean_memory()

## 3. Execute Comparison

In [None]:
import numpy as np

STEPS = 50
MSE_TOLERANCE = 1e-3  # Maximum curve MSE still reported as a pass.

print("Starting Baseline Run...")
loss_baseline = get_loss_history(use_zenith=False, steps=STEPS)

print("Starting Zenith Run...")
loss_zenith = get_loss_history(use_zenith=True, steps=STEPS)

# PLOTTING
plt.figure(figsize=(10, 6))
plt.plot(loss_baseline, label='PyTorch Baseline', linestyle='--', color='gray', linewidth=2)
plt.plot(loss_zenith, label='Zenith Optimized', linestyle='-', color='blue', alpha=0.7)

plt.title(f'Convergence Check: Training Loss ({STEPS} Steps)')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# Calculate Mean Squared Error between curves.
# Compare over the common prefix so that runs which logged a different number
# of losses are reported instead of failing with a broadcasting error.
n_common = min(len(loss_baseline), len(loss_zenith))
if len(loss_baseline) != len(loss_zenith):
    print(
        f"NOTE: runs logged a different number of losses "
        f"({len(loss_baseline)} vs {len(loss_zenith)}); comparing the first {n_common}."
    )

if n_common == 0:
    # Nothing to compare; NaN fails the threshold check below.
    mse = float("nan")
else:
    diff = np.asarray(loss_baseline[:n_common]) - np.asarray(loss_zenith[:n_common])
    mse = np.mean(diff ** 2)

print(f"\nCurve Divergence (MSE): {mse:.6f}")
if mse < MSE_TOLERANCE:
    print("RESULT: PASSED! Loss curves are identical. Zenith is numerically stable.")
else:
    print("RESULT: WARNING! Curves diverge significantly. Check kernel precision.")