# Zenith Real-World Test: Fine-Tuning TinyLlama 1.1B

Notebook ini dirancang untuk menguji performa **Zenith** dalam skenario *Real-World* (Fine-Tuning LLM). Kita akan melatih model **TinyLlama 1.1B** (Open Source, No Token Required) menggunakan LoRA dan mengaktifkan backend kompilasi Zenith untuk melihat dampaknya terhadap penggunaan VRAM dan kecepatan training.

## 1. Setup Environment & Install Dependencies
Langkah ini akan menyiapkan environment Colab, clone repository Zenith terbaru, dan menginstall dependensi yang diperlukan.

In [None]:
# Cek GPU yang didapatkan
!nvidia-smi

import os
import sys

# Install core libraries
print("Installing dependencies (this may take a minute)...")
!pip install -q -U torch transformers peft trl accelerate bitsandbytes psutil datasets

# Clone & Install Zenith dari Source
print("Cloning & Installing Zenith...")
# Hapus jika ada sisa instalasi sebelumnya
!rm -rf zenith_repo
!git clone https://github.com/vibeswithkk/ZENITH.git zenith_repo

# Install Zenith dalam mode editable
!pip install -e zenith_repo

# Paksa tambahkan ke path agar langsung terbaca tanpa restart
if os.path.abspath("zenith_repo") not in sys.path:
    sys.path.append(os.path.abspath("zenith_repo"))

print("Setup Complete! Zenith installed and added to path.")

## 2. Fine-Tuning Script with Zenith Integration
Kode di bawah ini adalah implementasi *Fine-Tuning* dengan opsi untuk mengaktifkan **Zenith**. 
Kita menggunakan model **TinyLlama-1.1B** yang lebih ringan dan tidak memerlukan login Hugging Face.

**Perbaikan Terbaru (Foolproof):** 
1. Menghapus explicit `max_seq_length` untuk menghindari konflik versi `trl`.
2. Panjang sekuens maksimum mengikuti nilai default bawaan `trl` (`SFTConfig`), bukan nilai yang ditulis eksplisit di notebook ini.

In [None]:
import sys
import os

# Make sure the cloned Zenith checkout is importable from this cell: add it
# to sys.path when the directory exists and is not listed there yet.
zenith_path = os.path.abspath("zenith_repo")
if zenith_path not in sys.path and os.path.exists(zenith_path):
    sys.path.append(zenith_path)

import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig 
from peft import LoraConfig, get_peft_model, TaskType
import psutil

# === ZENITH BACKEND REGISTRATION ===
from torch import _dynamo


def zenith_backend(gm: torch.fx.GraphModule, example_inputs):
    """TorchDynamo backend registered under the name ``"zenith"``.

    Args:
        gm: Graph module captured by TorchDynamo.
        example_inputs: Example inputs supplied by TorchDynamo (unused here).

    Returns:
        The callable that will run the captured graph.

    NOTE(review): this implementation is a pass-through. It logs that a graph
    was captured and returns ``gm.forward`` unchanged; the place where Zenith's
    optimization logic would run is marked below.
    """
    print("\n[Zenith] Compiling Graph...")
    # Zenith's actual optimization logic would run here.
    passthrough = gm.forward
    return passthrough


# Reset Dynamo state before (re-)registering, then register the backend only
# if the name is not already listed, so re-running the cell does not register
# it twice.
_dynamo.reset()
if "zenith" not in _dynamo.list_backends():
    _dynamo.register_backend(name="zenith", compiler_fn=zenith_backend)
# ===================================

# Report whether the Zenith package itself can be imported. Only the import
# sits inside the `try`, so the handler covers exactly the failure it names.
try:
    import zenith
except ImportError as e:
    print(f"Warning: Could not import zenith package directly: {e}")
else:
    # A package without `__version__` should not abort the cell with an
    # AttributeError, so fall back to a placeholder.
    print(f"Zenith version: {getattr(zenith, '__version__', 'unknown')} loaded successfully.")
def print_memory_usage(step):
    """Print host RAM and GPU memory figures, each line labelled with *step*.

    Args:
        step: Label shown in square brackets at the start of every line
            (for example ``"Pre-Train"``).

    The resident set size of the current process is always printed. When CUDA
    is available, the memory currently allocated through PyTorch and the peak
    allocation recorded so far are printed as well. The peak matters for
    post-training readings, because the current allocation no longer shows how
    much memory the training steps needed.
    """
    gib = 1024 ** 3
    process = psutil.Process(os.getpid())
    print(f"\n[{step}] RAM Usage: {process.memory_info().rss / gib:.2f} GB")
    if torch.cuda.is_available():
        print(f"[{step}] VRAM Usage: {torch.cuda.memory_allocated() / gib:.2f} GB")
        print(f"[{step}] Peak VRAM Usage: {torch.cuda.max_memory_allocated() / gib:.2f} GB")

def run_training(use_zenith=True, model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
    """Run a short LoRA fine-tuning job and print timing and memory figures.

    Args:
        use_zenith: When true, compile the model with the ``"zenith"``
            ``torch.compile`` backend before training.
        model_name: Model id or path handed to ``from_pretrained`` for both
            the tokenizer and the model.

    Returns:
        None. Progress, memory usage and the elapsed training time are
        printed. If the tokenizer or model cannot be loaded, the error is
        printed and the function returns early.
    """
    print(f"\n{'='*40}")
    print(f"Starting Training | Model: {model_name}")
    print(f"Zenith Optimization: {'ENABLED' if use_zenith else 'DISABLED'}")
    print(f"{'='*40}")

    # Decide the precision once so the model dtype and the trainer flags always
    # agree, and so bf16 support is only queried when a CUDA device exists.
    cuda_available = torch.cuda.is_available()
    use_bf16 = cuda_available and torch.cuda.is_bf16_supported()
    use_fp16 = cuda_available and not use_bf16

    print("Loading model...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Reuse the end-of-sequence token for padding.
        tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
            device_map="auto"
        )
    except Exception as e:
        print(f"\nERROR loading model: {e}")
        return

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1
    )
    model = get_peft_model(model, peft_config)

    if use_zenith:
        print("\n>>> INJECTING ZENITH BACKEND...")
        try:
            # Compile the PEFT wrapper in place. Assigning a compiled module to
            # `model.model` only registers an extra submodule on the wrapper;
            # the wrapper's forward pass goes through `model.base_model`, so a
            # module attached that way is never called and the backend never
            # runs. In-place compilation also keeps `model` a PeftModel.
            model.compile(backend="zenith")
            # Compilation is lazy: the backend is invoked on the first forward
            # pass, so this message only confirms the backend was accepted.
            print(">>> SUCCESS: Zenith Backend Attached!")
        except Exception as e:
            print(f">>> ERROR: Failed to attach Zenith: {e}")

    dataset = load_dataset("tatsu-lab/alpaca", split="train[:60]")

    def format_prompt(sample):
        """Render one dataset record as an instruction/response prompt."""
        return f"### Instruction:\n{sample['instruction']}\n\n### Response:\n{sample['output']}"

    # No explicit maximum sequence length is passed, to stay compatible across
    # TRL versions; the library's own default applies.
    training_args = SFTConfig(
        output_dir="./zenith_results",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        logging_steps=5,
        num_train_epochs=1,
        # A positive max_steps takes precedence over num_train_epochs.
        max_steps=20,
        fp16=use_fp16,
        bf16=use_bf16,
        report_to="none",
        dataset_text_field="text",
        packing=False
    )

    # The model is already a PEFT model, so `peft_config` is not passed again.
    # NOTE(review): depending on the TRL version, passing both makes the
    # trainer re-wrap or reject the already-wrapped model - confirm against the
    # installed release.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=training_args,
        processing_class=tokenizer,
        formatting_func=format_prompt,
    )

    print_memory_usage("Pre-Train")
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments. With compilation enabled, the time spent
    # compiling during the first steps is part of this total.
    start_time = time.perf_counter()
    trainer.train()
    end_time = time.perf_counter()
    print_memory_usage("Post-Train")

    print(f"\nDONE! Total Time: {end_time - start_time:.2f} seconds")

## 3. Run Experiment
Jalankan cell di bawah ini untuk memulai proses.
Anda bisa mengubahnya menjadi `use_zenith=False` untuk membandingkan dengan baseline PyTorch biasa.

In [None]:
# Run WITH Zenith (pass use_zenith=False instead for the plain PyTorch baseline)
run_training(use_zenith=True)