# In [None]:
# -*- coding: utf-8 -*-
"""Gemma3_(4B)_DPO.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(4B).ipynb
"""

# Commented out IPython magic to ensure Python compatibility.
# %%capture
# import os
# if "COLAB_" not in "".join(os.environ.keys()):
#     print("Not in Colab. Assuming Unsloth is already installed.")
#     print("If not, please install with: pip install unsloth[colab-new]")
# else:
#     # Do this only in Colab notebooks! Otherwise use pip install unsloth
#     print("Installing Unsloth and dependencies for Colab...")
#     !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
#     !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
#     !pip install --no-deps "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
#     print("Installation complete.")

from unsloth import FastLanguageModel # Changed from FastModel for DPO consistency
import torch
from trl import DPOTrainer, DPOConfig
from datasets import load_dataset
from unsloth.chat_templates import get_chat_template

# Maximum token length shared by model loading, LoRA setup and DPOConfig below.
max_seq_length = 2048
# Gradient checkpointing mode: True, or "unsloth" (Unsloth's variant, intended
# to allow longer contexts / larger batch sizes with LoRA).
use_gradient_checkpointing = "unsloth"

# --- Load the instruction-tuned Gemma-3 4B checkpoint, quantized to 4 bit ---
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma-3-4b-it",
    max_seq_length = max_seq_length,
    load_in_4bit = True,            # 4-bit weights to cut GPU memory
    # token = "hf_...",             # only needed for gated models
    # dtype = torch.bfloat16,       # optional, on GPUs with bf16 support
)

# --- Wrap the model with LoRA adapters ---
# Adapters go on every attention projection (q/k/v/o) and every MLP
# projection (gate/up/down).
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,                         # adapter rank; 8-64 is a common range
    lora_alpha = 32,                # conventionally 2 * r
    lora_dropout = 0.05,            # 0 is the Unsloth-optimized setting
    bias = "none",                  # "none" is the Unsloth-optimized setting
    target_modules = [f"{proj}_proj" for proj in ("q", "k", "v", "o", "gate", "up", "down")],
    use_gradient_checkpointing = use_gradient_checkpointing,
    random_state = 3407,
    max_seq_length = max_seq_length,
)

# --- Chat template for Gemma-3 ---
# Installs Unsloth's "gemma-3" chat template on the tokenizer. It is used by
# `tokenizer.apply_chat_template` at inference time below. During training,
# TRL's DPOTrainer applies it ONLY to conversational-format rows (lists of
# {"role": ..., "content": ...} messages); plain-string prompt/chosen/rejected
# columns are not passed through the template, so such data must already
# contain any Gemma-3 turn markers you want the model to see.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma-3",
    # map_eos_token = True,  # optional; Gemma-3 already defines an EOS token
)
# Padding is needed to batch sequences; fall back to EOS if no pad token exists.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


# --- Data Prep for DPO ---
# DPOTrainer requires 'prompt', 'chosen' and 'rejected' columns.
# Public datasets you can try instead of your own data:
#   - "HuggingFaceH4/ultrafeedback_binarized" (split "train_prefs"; large, high quality)
#   - "argilla/ultrafeedback-binarized-preferences-cleaned"
#   - "trl-internal-testing/hh-rlhf-trl-style" (small, for testing)
#
# Loading your own data:
#   dpo_dataset = load_dataset("your_username/your_dpo_dataset_name", split="train")
#   dpo_dataset = load_dataset("json", data_files="path/to/your/dpo_data.jsonl", split="train")
from pathlib import Path

USE_LOCAL_DATASET = True

# Columns every DPO example must provide.
DPO_REQUIRED_COLUMNS = ("prompt", "chosen", "rejected")


def _has_dpo_fields(example):
    """Return True when every required DPO column is present and non-empty.

    Missing keys, None, empty strings and empty message lists all count as
    absent, so the check works for plain-string and message-list rows alike.
    """
    return all(example.get(col) for col in DPO_REQUIRED_COLUMNS)


if USE_LOCAL_DATASET:
    # Point this at the DPO pairs file generated by auto_unslop.
    dpo_file = Path("./results/[experiment_id]/dpo_pairs_dataset.jsonl")
    # Fail early with a clear message: a missing file (e.g. the placeholder
    # above left unreplaced; `data_files` may also be interpreted as a glob
    # pattern, where "[...]" is a character class) otherwise surfaces as a
    # much less obvious error from `load_dataset`.
    if not dpo_file.is_file():
        raise FileNotFoundError(
            f"DPO dataset file not found: {dpo_file}. "
            "Set `dpo_file` to the dpo_pairs_dataset.jsonl produced by your auto_unslop run."
        )
    print(f"Loading DPO dataset from {dpo_file}")
    dpo_dataset = load_dataset("json", data_files=str(dpo_file), split="train")

    # Sanity check: drop rows missing any of the three required fields.
    before = len(dpo_dataset)
    dpo_dataset = dpo_dataset.filter(_has_dpo_fields)
    after = len(dpo_dataset)
    if after == 0:
        raise ValueError("All rows were filtered out – check dataset contents.")
    if after < before:
        print(f"Filtered out {before - after} malformed rows; {after} remain.")

    print(f"DPO dataset ready with {after} samples.")
else:
    # Small slice of a public DPO dataset for a quick demo.
    dpo_dataset = load_dataset("trl-internal-testing/hh-rlhf-trl-style", split="train[:1%]")
    # Drop rows whose prompt, chosen or rejected field is missing or empty.
    dpo_dataset = dpo_dataset.filter(_has_dpo_fields)
    if len(dpo_dataset) == 0:
        raise ValueError("Filtered dataset is empty. Check the original data or filtering logic.")
    print(f"Loaded DPO dataset with {len(dpo_dataset)} samples.")
    print("First DPO sample:")
    print(dpo_dataset[0])


# --- Train the model with DPO ---
# `ref_model = None`: with a PEFT/LoRA model, TRL does not need a separate
# reference copy; it computes the reference log-probs with the adapters
# disabled. An eval_dataset is optional but useful for monitoring.

# Guard against an empty dataset (e.g. everything filtered out above).
if len(dpo_dataset) == 0:
    raise ValueError("DPO dataset is empty. Please provide a valid dataset.")

dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,  # reference = base model with LoRA adapters disabled
    train_dataset = dpo_dataset,
    # eval_dataset = YOUR_EVAL_DPO_DATASET_HERE, # Optional
    # TRL >= 0.12 takes the tokenizer as `processing_class`; the old
    # `tokenizer=` keyword is deprecated (this notebook pins trl==0.15.2).
    processing_class = tokenizer,
    args = DPOConfig(
        per_device_train_batch_size = 1, # Adjust based on your GPU memory
        gradient_accumulation_steps = 4, # Effective batch size = 1 * 4 = 4
        warmup_ratio = 0.1, # Or warmup_steps
        num_train_epochs = 1, # For a quick demo. Set to 1-3 for a real run.
        # max_steps = 60, # Alternatively, use max_steps for a fixed number of steps
        # 5e-5 is on the high side for DPO (many recipes use 5e-7 to 5e-6);
        # lower it if the rewards/margins diverge.
        learning_rate = 5e-5,
        logging_steps = 10,
        optim = "adamw_8bit", # 8-bit AdamW (bitsandbytes) to save optimizer memory
        seed = 42,
        output_dir = "outputs_gemma3_4b_dpo",
        max_length = max_seq_length,             # Max length of prompt + response
        max_prompt_length = max_seq_length // 2, # Max length of the prompt alone
        beta = 0.1, # DPO beta: strength of the pull toward the reference model
        report_to = "none", # "wandb" or "tensorboard"
        lr_scheduler_type = "linear",
        # bf16 = True, # Set to True if your GPU supports bfloat16
        # fp16 = False, # Set to True for mixed precision if bf16 is not available
    ),
)

print("Starting DPO training...")
trainer_stats = dpo_trainer.train()
print("DPO training finished.")


# --- Show GPU memory and runtime stats (optional) ---
if torch.cuda.is_available():
    _BYTES_PER_GIB = 1024 ** 3
    gpu_stats = torch.cuda.get_device_properties(0)
    max_memory = round(gpu_stats.total_memory / _BYTES_PER_GIB, 3)
    # NOTE: despite its name, this is measured AFTER training. It is the memory
    # currently reserved by PyTorch's caching allocator (a snapshot, not the
    # peak; the peak is reported from max_memory_reserved() below).
    start_gpu_memory = round(torch.cuda.memory_reserved() / _BYTES_PER_GIB, 3)
    print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
    print(f"{start_gpu_memory} GB of memory reserved after training.")

    # Show runtime and peak memory when the trainer reported its metrics.
    # `or {}` guards against a missing or None metrics attribute.
    _train_metrics = getattr(trainer_stats, "metrics", None) or {}
    if "train_runtime" in _train_metrics:
        # Peak reserved memory since program start (no reset_peak_memory_stats
        # call is made here), so it covers model loading as well as training.
        used_memory = round(torch.cuda.max_memory_reserved() / _BYTES_PER_GIB, 3)
        used_percentage = round(used_memory / max_memory * 100, 3)
        _train_runtime = _train_metrics["train_runtime"]
        print(f"{_train_runtime:.2f} seconds used for training.")
        print(f"{_train_runtime/60:.2f} minutes used for training.")
        print(f"Peak reserved memory = {used_memory} GB.")
        print(f"Peak reserved memory % of max memory = {used_percentage} %.")
else:
    print("CUDA not available. Memory stats not shown.")


# --- Inference after DPO ---
# The tokenizer already carries the Gemma-3 chat template (set before
# training), and `model` still has the trained LoRA adapters attached.

# Example prompt
messages = [{
    "role": "user",
    "content": "What are the pros and cons of pineapple on pizza?",
}]

# Format the conversation with the chat template. `add_generation_prompt=True`
# appends the opening of the model's turn (for Gemma-3: `<start_of_turn>model\n`)
# so the model answers instead of continuing the user turn.
# Inputs are placed on whatever device the model lives on.
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize = True,
    add_generation_prompt = True,
    return_tensors = "pt",
).to(model.device)

from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer, skip_prompt = True)

print("\n--- Generating DPO model response (streaming) ---")
_ = model.generate(
    inputs,
    max_new_tokens = 128, # Increase for longer outputs!
    # Without do_sample=True, temperature/top_p/top_k are ignored unless the
    # model's generation_config already enables sampling.
    do_sample = True,
    # Recommended Gemma-3 sampling settings (slightly lower temperature after DPO).
    temperature = 0.7,
    top_p = 0.95,
    top_k = 64,
    streamer = text_streamer,
    pad_token_id = tokenizer.eos_token_id, # avoids the missing-pad-token warning
)
print("\n--- End of DPO model response ---")

# --- Persist the DPO result: LoRA adapters plus the tokenizer ---
dpo_model_save_path = "gemma-3-4b-dpo-lora"
# The trainer's model is the LoRA-wrapped model, so this writes the adapters.
dpo_trainer.save_model(dpo_model_save_path)
# Keep the tokenizer (with its chat template) next to the adapters.
tokenizer.save_pretrained(dpo_model_save_path)
print("DPO LoRA adapters and tokenizer saved to ./" + dpo_model_save_path)

# --- Reload the saved LoRA adapters for inference (optional) ---
if False: # Set to True to test loading
    # Pointing model_name at the adapter directory saved above loads the base
    # model together with the DPO LoRA adapters.
    loaded_model, loaded_tokenizer = FastLanguageModel.from_pretrained(
        model_name = dpo_model_save_path,
        max_seq_length = max_seq_length,
        load_in_4bit = True,
    )
    # Switch Unsloth's model into inference mode before generating.
    FastLanguageModel.for_inference(loaded_model)

    # Ensure the reloaded tokenizer has the Gemma-3 chat template and a pad token.
    loaded_tokenizer = get_chat_template(
        loaded_tokenizer,
        chat_template = "gemma-3",
    )
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

    print("\n--- Generating response from loaded DPO LoRA model (streaming) ---")
    _ = loaded_model.generate(
        inputs.to(loaded_model.device), # same prompt as above, on the loaded model's device
        max_new_tokens = 128,
        do_sample = True, # required for temperature/top_p/top_k to take effect
        temperature = 0.7, top_p = 0.95, top_k = 64,
        streamer = TextStreamer(loaded_tokenizer, skip_prompt = True),
        pad_token_id = loaded_tokenizer.eos_token_id,
    )
    print("\n--- End of loaded DPO model response ---")


# --- Optional: standalone merged 16-bit export (e.g. for vLLM) ---
if False: # Change to True to save merged finetune!
    # Fold the LoRA adapters into the base weights and write a full model,
    # rather than just the adapter files.
    merged_model_path = "gemma-3-4b-dpo-merged"
    model.save_pretrained_merged(
        merged_model_path,
        tokenizer,
        save_method = "merged_16bit",
    )
    print(f"Merged 16-bit DPO model saved to ./{merged_model_path}")
    # To publish instead of (or in addition to) saving locally:
    # model.push_to_hub_merged("YOUR_HF_USERNAME/gemma-3-4b-dpo-merged", tokenizer, save_method = "merged_16bit", token = "YOUR_HF_TOKEN")
    # For GGUF, convert this merged model or use Unsloth's direct GGUF export below.

# --- GGUF / llama.cpp Conversion ---
if False: # Change to True to save to GGUF
    # Unsloth merges the LoRA adapters into the base model, then converts to GGUF.
    # `quantization_method` takes Unsloth's lowercase names, e.g.
    # "q8_0", "f16", "bf16", "q4_k_m", "q5_k_m".
    gguf_model_path = "gemma-3-4b-dpo-gguf"
    model.save_pretrained_gguf(gguf_model_path, tokenizer, quantization_method = "q8_0")
    # The exact .gguf filename is chosen by Unsloth (and has varied between
    # versions), so report the output location rather than a guessed filename.
    print(f"GGUF (Q8_0) DPO model saved under ./{gguf_model_path}")
    # model.push_to_hub_gguf("YOUR_HF_USERNAME/gemma-3-4b-dpo-gguf", tokenizer, quantization_method = "q8_0", token = "YOUR_HF_TOKEN")

print("\nScript finished. Remember to replace placeholder DPO dataset with your actual data.")