## ⚙️ Environment Reset & Library Reinstallation  

This step resets the Colab environment to remove any conflicting library versions and installs a stable set of dependencies tested for **Unsloth RL/DPO** training.  
It ensures all required frameworks — such as PyTorch, Transformers, TRL, PEFT, and Datasets — are correctly aligned for compatibility and performance.  

By reinstalling specific versions, this configuration prevents runtime errors (e.g., missing attributes or mismatched CUDA builds) and guarantees smooth execution of Unsloth-based reinforcement learning workflows.

In [2]:
# 0) make sure hash-checking isn't forcing pip to fail
import os
os.environ.pop("PIP_REQUIRE_HASHES", None)
print("PIP_REQUIRE_HASHES =", os.environ.get("PIP_REQUIRE_HASHES"))

# 1) Remove conflicting wheels (ok if some aren't installed)
%pip -q uninstall -y torch torchvision torchaudio xformers triton unsloth unsloth_zoo transformers datasets peft accelerate bitsandbytes || true

# 2) Install a known-good combo for Unsloth RL/DPO on Colab
%pip install --upgrade --force-reinstall --no-cache-dir \
  "torch==2.9.0" \
  "torchvision==0.24.0" \
  "xformers==0.0.33.post1" \
  "triton==3.5.0" \
  "unsloth==2025.11.3" \
  "unsloth_zoo==2025.11.4" \
  "transformers==4.57.1" \
  "trl==0.23.0" \
  "accelerate==1.11.0" \
  "peft==0.17.1" \
  "datasets==4.3.0" \
  "bitsandbytes==0.48.2"


PIP_REQUIRE_HASHES = None
[0mCollecting torch==2.9.0
  Downloading torch-2.9.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting torchvision==0.24.0
  Downloading torchvision-0.24.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)
Collecting xformers==0.0.33.post1
  Downloading xformers-0.0.33.post1-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (1.2 kB)
Collecting triton==3.5.0
  Downloading triton-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.7 kB)
Collecting unsloth==2025.11.3
  Downloading unsloth-2025.11.3-py3-none-any.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.8/61.8 kB[0m [31m57.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting unsloth_zoo==2025.11.4
  Downloading unsloth_zoo-2025.11.4-py3-none-any.whl.metadata (32 kB)
Collecting transformers==4.57.1
  Downloading transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4

## ✅ Verifying Library Imports  

This step confirms that all essential libraries are installed correctly and working together.  
By importing **Unsloth** first, we ensure its dependency patches are applied before loading other frameworks like Transformers or TRL.  
A successful printout of library versions indicates that the environment is properly configured and ready for model training and fine-tuning.

In [2]:
# 3) Verify imports (and always import Unsloth FIRST in your notebook)
import unsloth, torch, transformers, trl, datasets, peft
from unsloth import FastLanguageModel
print("OK:", unsloth.__version__, transformers.__version__, datasets.__version__, torch.__version__, trl.__version__)

OK: 2025.11.3 4.57.1 4.3.0 2.9.0+cu128 0.23.0


## 🧠 Initial Setup & Configuration  

This step initializes the core libraries and defines essential parameters for the **Direct Preference Optimization (DPO)** fine-tuning process.  
Unsloth is imported **first** to ensure it safely patches underlying dependencies before other frameworks are loaded.  
A fixed random seed guarantees reproducible training results across runs.  

Key settings include:  
- **POLICY_BASE / REF_BASE:** Base models for the policy (trainable) and reference (frozen) networks.  
- **MAX_LEN:** Maximum sequence length for input prompts.  
- **SUBSET:** Number of samples loaded from the preference dataset for faster training.  
- **MAX_STEPS:** Total training steps in DPO fine-tuning.  
- **MAX_TARGET:** Maximum number of tokens generated in the target output.  
- **DO_MERGE:** Option to merge LoRA adapters into a full-precision model after training (kept `False` for faster runs).  

This configuration sets up a lightweight yet effective training environment suitable for DPO fine-tuning in Colab.

In [3]:
# Import Unsloth FIRST so it patches dependencies safely
import unsloth
from unsloth import FastLanguageModel

# Then other libs
from datasets import load_dataset
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer
from peft import PeftModel
import wandb
import random, torch

SEED = 3407
random.seed(SEED); torch.manual_seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"

# FAST MODE knobs
POLICY_BASE = "HuggingFaceTB/SmolLM2-135M-Instruct"
REF_BASE    = "HuggingFaceTB/SmolLM2-135M-Instruct"
MAX_LEN     = 768      # was 1024
SUBSET      = 1000     # was 3000
MAX_STEPS   = 200      # was 400
MAX_TARGET  = 128      # was 256
DO_MERGE    = False    # keep False for speed; True to create merged fp16 checkpoint

print({"POLICY_BASE": POLICY_BASE, "REF_BASE": REF_BASE, "MAX_LEN": MAX_LEN, "SUBSET": SUBSET, "MAX_STEPS": MAX_STEPS, "MAX_TARGET": MAX_TARGET})


{'POLICY_BASE': 'HuggingFaceTB/SmolLM2-135M-Instruct', 'REF_BASE': 'HuggingFaceTB/SmolLM2-135M-Instruct', 'MAX_LEN': 768, 'SUBSET': 1000, 'MAX_STEPS': 200, 'MAX_TARGET': 128}


## ⚙️ Loading the Policy and Reference Models  

This step loads two models required for **Direct Preference Optimization (DPO)** fine-tuning:  

1. **Policy Model** — the trainable model that learns to generate preferred responses.  
2. **Reference Model** — a frozen model used as a comparison baseline during training.  

Both models are loaded in **4-bit precision** using Unsloth’s `FastLanguageModel` for efficient GPU memory usage.  
The policy model is further wrapped with **LoRA adapters**, enabling parameter-efficient fine-tuning without modifying the base weights.  

Key configurations include:  
- `r`, `lora_alpha`, `lora_dropout` — LoRA hyperparameters controlling adapter size, scaling, and regularization.  
- `target_modules` — specific transformer layers where LoRA adapters are injected.  
- `use_gradient_checkpointing` — saves VRAM during training by recomputing intermediate activations.  
- The reference model’s gradients are disabled to keep it **frozen** throughout training.  

This setup establishes a compact, memory-efficient dual-model architecture ready for preference-based reinforcement learning.

In [4]:
# ---------- Load POLICY (4-bit + LoRA) and REFERENCE (4-bit frozen) ----------
policy, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = POLICY_BASE,
    max_seq_length = MAX_LEN,
    dtype          = None,
    load_in_4bit   = True,
)

policy = FastLanguageModel.get_peft_model(
    policy,
    r=16, lora_alpha=32, lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing=True,
    random_state=SEED,
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
)
if hasattr(policy, "print_trainable_parameters"):
    policy.print_trainable_parameters()

reference, _ = FastLanguageModel.from_pretrained(
    model_name     = REF_BASE,
    max_seq_length = MAX_LEN,
    dtype          = None,
    load_in_4bit   = True,
)
for p in reference.parameters():
    p.requires_grad_(False)


==((====))==  Unsloth 2025.11.3: Fast Llama patching. Transformers: 4.57.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu128. CUDA: 7.5. CUDA Toolkit: 12.8. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/655 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

HuggingFaceTB/SmolLM2-135M-Instruct does not have a padding token! Will use pad_token = <|endoftext|>.


Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
Unsloth 2025.11.3 patched 30 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


trainable params: 4,884,480 || all params: 139,399,488 || trainable%: 3.5039
==((====))==  Unsloth 2025.11.3: Fast Llama patching. Transformers: 4.57.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu128. CUDA: 7.5. CUDA Toolkit: 12.8. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
HuggingFaceTB/SmolLM2-135M-Instruct does not have a padding token! Will use pad_token = <|endoftext|>.


## 📚 Loading the Preference Dataset  

This step loads the **UltraFeedback Binarized** dataset from Hugging Face, which is specifically designed for **preference-based fine-tuning**.  
Each data sample includes a *prompt*, a *chosen* (preferred) response, and a *rejected* (less-preferred) response — making it ideal for Direct Preference Optimization (DPO).  

A smaller subset of the dataset is used (`SUBSET`), allowing for faster experimentation and reduced GPU load in Colab environments.  

After loading, the script prints the number of samples and the dataset’s column structure to verify successful import.

In [5]:
# ===============================================================
# 📘 Load the raw dataset
# ===============================================================
from datasets import load_dataset

raw = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split=f"train_prefs[:{SUBSET}]")
print("Dataset loaded:", len(raw), "samples")
print("Columns:", raw.column_names)



README.md: 0.00B [00:00, ?B/s]

data/train_prefs-00000-of-00001.parquet:   0%|          | 0.00/226M [00:00<?, ?B/s]

data/test_prefs-00000-of-00001.parquet:   0%|          | 0.00/7.29M [00:00<?, ?B/s]

data/test_sft-00000-of-00001.parquet:   0%|          | 0.00/3.72M [00:00<?, ?B/s]

data/train_gen-00000-of-00001.parquet:   0%|          | 0.00/184M [00:00<?, ?B/s]

data/test_gen-00000-of-00001.parquet:   0%|          | 0.00/3.02M [00:00<?, ?B/s]

Generating train_prefs split:   0%|          | 0/61135 [00:00<?, ? examples/s]

Generating train_sft split:   0%|          | 0/61135 [00:00<?, ? examples/s]

Generating test_prefs split:   0%|          | 0/2000 [00:00<?, ? examples/s]

Generating test_sft split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating train_gen split:   0%|          | 0/61135 [00:00<?, ? examples/s]

Generating test_gen split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset loaded: 1000 samples
Columns: ['prompt', 'prompt_id', 'chosen', 'rejected', 'messages', 'score_chosen', 'score_rejected']


## ✂️ Preprocessing & Truncation for DPO  

This step prepares the dataset for **Direct Preference Optimization (DPO)** by applying strict token length limits to all text fields.  
Since DPO training involves comparing *preferred* and *rejected* responses for the same prompt, maintaining consistent input lengths is essential for stable optimization.

### 🔹 Key Operations
- **Tokenizer Settings:**  
  - `model_max_length` ensures prompts never exceed the model’s maximum context window (`MAX_LEN`).  
  - `padding_side` is set to *right* for uniform padding, and `truncation_side` to *left* to preserve the most recent context.  

- **Helper Functions:**  
  - `_truncate_text()` – trims text to a fixed token limit.  
  - `to_chat_prompt()` – formats prompts using the model’s chat template and truncates to `MAX_LEN`.  
  - `trim_answer()` – restricts both chosen and rejected answers to `MAX_TARGET` tokens.  

- **Dataset Mapping:**  
  The `map_dpo()` function rebuilds the dataset with tokenized and truncated `prompt`, `chosen`, and `rejected` fields.  

### ✅ Sanity Check
Finally, the code verifies that no prompt or response exceeds its length cap, ensuring each sequence fits perfectly into the model’s input size during training.

This preprocessing guarantees consistent, memory-safe training behavior across all DPO batches.

In [6]:
# ---- Hard length guards for DPO ----
# Ensure tokenizer knows the caps
tokenizer.model_max_length = MAX_LEN          # e.g., 768
tokenizer.padding_side = "right"
tokenizer.truncation_side = "left"            # trim from the left for long prompts

def _truncate_text(txt: str, max_tokens: int) -> str:
    ids = tokenizer(
        txt,
        add_special_tokens=False,
        truncation=True,
        max_length=max_tokens,
        return_attention_mask=False,
        return_token_type_ids=False,
    )["input_ids"]
    return tokenizer.decode(ids, skip_special_tokens=False)

def to_chat_prompt(prompt_text: str) -> str:
    # Minimal system to keep prompt short
    messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user",   "content": str(prompt_text).strip()},
    ]
    s = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,   # policy will generate assistant
    )
    # Truncate the prompt to MAX_LEN tokens
    return _truncate_text(s, MAX_LEN)

def trim_answer(ans: str) -> str:
    # Truncate targets to MAX_TARGET tokens to keep sequence <= MAX_LEN + MAX_TARGET
    return _truncate_text(str(ans), MAX_TARGET)

# Rebuild the mapped dataset with strict truncation
def map_dpo(batch):
    prompts   = [to_chat_prompt(p) for p in batch["prompt"]]
    chosens   = [trim_answer(c)     for c in batch["chosen"]]
    rejecteds = [trim_answer(r)     for r in batch["rejected"]]
    return {"prompt": prompts, "chosen": chosens, "rejected": rejecteds}

dpo_ds = raw.map(
    map_dpo,
    batched=True,
    num_proc=2,
    remove_columns=raw.column_names,
)

# Quick sanity: verify no prompt exceeds MAX_LEN and no target exceeds MAX_TARGET
def _count_toks(s): return len(tokenizer(s, add_special_tokens=False)["input_ids"])
print("Sanity (first row):",
      _count_toks(dpo_ds[0]["prompt"]), _count_toks(dpo_ds[0]["chosen"]), _count_toks(dpo_ds[0]["rejected"]))


Map (num_proc=2):   0%|          | 0/1000 [00:00<?, ? examples/s]

Sanity (first row): 29 128 128


## 🧩 DPO Configuration & Trainer Wiring (the *correct* way)

This block builds a **Direct Preference Optimization** training loop with a
trainable **LoRA policy** and a **frozen 4-bit reference** model.

### 🔧 Why this setup works
- **`DPOConfig` (trainer args only):** Holds optimization/runtime knobs  
  — batch size, grad accumulation, `max_steps`, LR schedule, warmup, logging,
  precision (`fp16`/`bf16`), workers, and seed.  
  *Important:* Do **not** put `beta` or `max_target_length` here.
- **`DPOTrainer` (algorithm args):** Where DPO-specific pieces live:
  - `model` = the **LoRA policy** (trainable)
  - `ref_model` = the **frozen** 4-bit reference for KL anchoring
  - `beta=0.1` = strength of preference vs. KL regularization
  - `max_length` = max prompt length (input context)
  - `max_target_length` = cap on generated targets (chosen/rejected)
  - Column mapping: `prompt`, `chosen`, `rejected` to your preprocessed dataset

### 🚀 Fast-training choices
- **Small per-device batch + grad accumulation** keeps memory low while preserving effective batch size.
- **`save_strategy="no"`** skips checkpoints to maximize throughput (good for quick experiments).
- **Precision auto-switch (`bf16` if supported, else `fp16`)** for speed + memory savings.
- **`report_to=[]`** disables external loggers to reduce overhead (flip on later if needed).

### ✅ Outcome
Creates a ready-to-train DPO loop with clean separation between:
- **Runtime/optimization config** (in `DPOConfig`), and  
- **DPO algorithm parameters & data wiring** (in `DPOTrainer`).

In [7]:
# --- Correct DPO setup: config + trainer (no max_target_length in DPOConfig) ---
from trl import DPOConfig, DPOTrainer

dpo_args = DPOConfig(
    output_dir="smollm2_dpo_rl_fast",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    max_steps=MAX_STEPS,               # e.g., 200
    learning_rate=5e-6,
    lr_scheduler_type="linear",
    warmup_steps=25,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    logging_steps=50,
    save_strategy="no",                # fastest: no checkpoints
    report_to=[],                      # avoid wandb
    dataloader_num_workers=2,
    seed=SEED,
    # NOTE: do NOT put max_target_length or beta here
)

trainer = DPOTrainer(
    model=policy,                      # LoRA policy (trainable)
    ref_model=reference,               # frozen 4-bit reference
    args=dpo_args,
    tokenizer=tokenizer,
    train_dataset=dpo_ds,
    beta=0.1,                          # <-- put beta here
    max_length=MAX_LEN,                # e.g., 768
    max_target_length=MAX_TARGET,      # e.g., 128  (pass to trainer, not config)
    prompt_column="prompt",
    chosen_column="chosen",
    rejected_column="rejected",
)

print("DPOTrainer ready (config fixed).")


Extracting prompt in train dataset (num_proc=6):   0%|          | 0/1000 [00:00<?, ? examples/s]

Applying chat template to train dataset (num_proc=6):   0%|          | 0/1000 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=6):   0%|          | 0/1000 [00:00<?, ? examples/s]

DPOTrainer ready (config fixed).


### 🚀 Training the Model with DPO

This cell executes the **Direct Preference Optimization (DPO)** training loop.

- We first clear any cached GPU memory to maximize available VRAM.
- `trainer.train()` handles the entire fine-tuning process using the:
  - **trainable LoRA policy model**
  - **frozen reference model**
  - **preprocessed DPO dataset** with `prompt`, `chosen`, and `rejected` pairs.
- The script also tracks how long training takes for better performance benchmarking.

> ⚙️ DPO encourages the model to generate responses closer to the “chosen” answers while diverging from the “rejected” ones, improving alignment without reinforcement rollouts.


## 🚀 Starting DPO Training  

This step launches the **Direct Preference Optimization (DPO)** training process for the LoRA-enhanced policy model.  
Before training begins, the script clears the GPU cache and triggers garbage collection to free up memory, ensuring maximum efficiency during the optimization loop.

### 🔹 Key Actions  
- **Memory cleanup:** Frees up unused CUDA memory to prevent allocation errors.  
- **Timer initialization:** Tracks the total duration of the training process.  
- **Trainer execution:** Calls `trainer.train()` to begin optimization based on the DPO configuration.  
- **Runtime logging:** Displays total elapsed time and basic training output metrics once training completes.  

This marks the main stage where the model learns from preference data by comparing “chosen” and “rejected” responses to optimize alignment and response quality.

In [8]:
import gc, time
# ---------- Train ----------
gc.collect()
if torch.cuda.is_available(): torch.cuda.empty_cache()
start = time.time()
train_out = trainer.train()
elapsed = time.time() - start
print("Train out:", train_out)
print(f"Elapsed: {elapsed/60:.1f} min")



The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,000 | Num Epochs = 4 | Total steps = 200
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 8 x 1) = 16
 "-____-"     Trainable parameters = 4,884,480 of 139,399,488 (3.50% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,rewards / chosen,rewards / rejected,rewards / accuracies,rewards / margins,logps / chosen,logps / rejected,logits / chosen,logits / rejected,eval_logits / chosen,eval_logits / rejected,nll_loss
50,0.6932,0.001962,0.001927,0.48875,3.5e-05,-299.747101,-295.733429,7.339461,7.404482,0,0,0
100,0.6892,0.019915,0.01164,0.633838,0.008275,-302.061981,-299.768799,7.220024,7.277154,No Log,No Log,No Log
150,0.6846,0.037022,0.019629,0.728535,0.017393,-299.166382,-299.08313,7.377736,7.443317,No Log,No Log,No Log
200,0.681,0.048302,0.023581,0.724747,0.024721,-299.279541,-295.559143,7.356658,7.36257,No Log,No Log,No Log


Train out: TrainOutput(global_step=200, training_loss=0.6869774436950684, metrics={'train_runtime': 1175.3704, 'train_samples_per_second': 2.723, 'train_steps_per_second': 0.17, 'total_flos': 0.0, 'train_loss': 0.6869774436950684, 'epoch': 3.176})
Elapsed: 19.6 min


## 💾 Save Outputs & Quick Inference Sanity Checks

This step packages your training results and lets you sanity-check the model’s behavior.

### What happens here
- **Save LoRA adapters**: Persists the fine-tuned adapter weights to `smollm2_dpo_rl_fast/adapters` and the tokenizer files to `smollm2_dpo_rl_fast/tokenizer`, so you can reload or upload them later.
- **(Optional) Merge to fp16**: If `DO_MERGE=True`, the LoRA adapters are merged into the base model to produce a single fp16 checkpoint under `smollm2_dpo_rl_fast/merged`—useful for straightforward deployment/inference.
- **Lightweight chat helper**: A tiny function builds a short system+user prompt (via the tokenizer’s chat template) and runs **greedy decoding** for consistent outputs.
- **Smoke tests**: Runs two quick prompts (one explanatory, one coding) to verify the model generates sensible responses after DPO tuning.

### Why it’s useful
- Adapter saving keeps training artifacts small and upload-friendly.
- The optional merge creates a self-contained model for environments that don’t support LoRA.
- The smoke tests provide an immediate, low-cost sanity check before you invest time in full evaluation.

In [9]:
# ===============================================================
# 💾 Save adapters (fast path) and optionally merge to a single fp16 checkpoint
# ===============================================================
import os, torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Where to store outputs
ADAPTER_DIR = "smollm2_dpo_rl_fast/adapters"
TOKEN_DIR   = "smollm2_dpo_rl_fast/tokenizer"
os.makedirs(ADAPTER_DIR, exist_ok=True)

# Save LoRA adapters and tokenizer config
trainer.model.save_pretrained(ADAPTER_DIR)
tokenizer.save_pretrained(TOKEN_DIR)
print(f"✅ Saved adapters to: {ADAPTER_DIR}")
print(f"✅ Saved tokenizer to: {TOKEN_DIR}")

# ---------------------------------------------------------------
# 🔁 Optional: merge LoRA into a single fp16 model (set DO_MERGE=True)
# ---------------------------------------------------------------
MERGED_DIR = None
if DO_MERGE:
    MERGED_DIR = "smollm2_dpo_rl_fast/merged"
    os.makedirs(MERGED_DIR, exist_ok=True)

    # Load the base model in fp16 on available device(s)
    base_fp16 = AutoModelForCausalLM.from_pretrained(
        POLICY_BASE,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # Attach adapters and merge into the base weights
    peft_model = PeftModel.from_pretrained(base_fp16, ADAPTER_DIR)
    merged = peft_model.merge_and_unload()

    # Persist merged weights + tokenizer
    merged.save_pretrained(MERGED_DIR, safe_serialization=True)
    tokenizer.save_pretrained(MERGED_DIR)
    print(f"✅ Merged model saved to: {MERGED_DIR}")

# ===============================================================
# 🗣️ Tiny chat helper for quick sanity checks
# ===============================================================
def chat(prompt: str, max_new_tokens: int = 128):
    # Pick the right model for inference: merged (if created) or LoRA policy
    model_for_infer = merged if (DO_MERGE and "merged" in locals()) else policy
    model_for_infer.eval()

    # Make sure tokenizer has a pad token
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Minimal system + user message
    messages = [
        {"role": "system", "content": "You are a helpful, concise assistant."},
        {"role": "user",   "content": str(prompt).strip()},
    ]

    # Tokenize with the chat template
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(device)

    # Build a simple attention mask (no padding expected for a single example)
    attention_mask = torch.ones_like(inputs)

    # Greedy decoding for reproducible outputs
    with torch.inference_mode():
        outputs = model_for_infer.generate(
            input_ids=inputs,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy for consistency
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

    # Return only the newly generated continuation
    return tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)

# ---------------------------------------------------------------
# 🔎 Quick smoke tests
# ---------------------------------------------------------------
print("\n=== Inference (FAST MODE) ===")
print(chat("Explain the difference between a shallow copy and a deep copy in Python with a tiny example."))
print("-" * 80)
print(chat("Write a short Python function that checks if a string is a valid palindrome, ignoring non-alphanumerics."))


✅ Saved adapters to: smollm2_dpo_rl_fast/adapters
✅ Saved tokenizer to: smollm2_dpo_rl_fast/tokenizer

=== Inference (FAST MODE) ===
In Python, a shallow copy is a copy of an object that is created from an existing object, but not from an object that is created from a copy of an existing object. This means that if you create a shallow copy of an object from an existing object, you are essentially creating a copy of the object that was created from the original object, but not from the original object itself.

Here's a simple example:

```python
# Create a shallow copy of an object from an existing object
my_object = my_original_object

# Create a copy of an object from an existing object
my
--------------------------------------------------------------------------------
Here's a Python function that checks if a string is a valid palindrome:

```python
def is_palindrome(s):
    if not s:
        return False
    
    s = s.lower()
    return s == s[::-1]

# Example usage:
print(is_palin