# Fine-Tuning a Small Language Model for Tool Selection using SFT

## Introduction

This document explains the Python script `main.py`, which demonstrates how to fine-tune a relatively small pre-trained language model (LM) for a specific task: **tool selection**.

**Supervised Fine-Tuning (SFT)** is a crucial technique for adapting large, general-purpose language models (LLMs) to perform well on specific downstream tasks. It leverages a pre-trained model, which has already learned vast amounts of information about language structure and world knowledge during its initial, often unsupervised, training phase. The "supervised" aspect means we provide the model with labeled examples, typically consisting of an input prompt and the desired output or completion. Unlike the broad pre-training stage, SFT uses these explicit input-output pairs to guide the model's learning towards a specific behavior or response format.

During SFT, the model's internal parameters (weights) are further adjusted by training on this smaller, task-specific dataset. This process fine-tunes the model's capabilities, making it more adept at the target task without needing to retrain it from scratch, which would be computationally prohibitive. The goal is to minimize the difference between the model's generated output and the provided target completion for each example in the fine-tuning dataset. This allows us to steer the model's behavior, improve its accuracy for certain types of questions, or teach it to follow specific instructions or output formats.

In this tutorial, for instance, we use SFT to teach the model a very specific skill: mapping a user's query to the name of the most appropriate helper function (like a calculator, weather API, or reminder tool), using examples where the prompt is the user query and the completion is the correct tool name.


## 📺 Watch the Tutorial

Prefer a video walkthrough? Check out the accompanying tutorial on YouTube:

[Fine-Tuning a Small LLM for Tool Selection (SFT)](https://youtu.be/Ain269vmeZg)

## Core Concepts

* **Language Model (LM)** – learns statistical patterns of language and can generate text.
* **Causal LM** – predicts the next token given everything before it (e.g. GPT-style models).
* **Fine-Tuning / SFT** – continues training a pre-trained model on a smaller, task-specific corpus under *supervision* (we know the desired answer for every prompt).
* **Tokenizer** – maps text ↔ integer IDs; every LM has its own tokenizer.
* **Hugging Face `transformers` / `trl`** – high-level libraries that spare us from boiler-plate training code.
* **Special Token** – a sentinel string (here `<my_tool_selection>`) added to the vocabulary to mark where the model should output the tool name.

## Script Breakdown

### 🚀 Environment Setup

Run the command below **once** at the very start of your notebook to install (or upgrade) all required libraries. You can comment it out after the first run.

    !pip install -q --upgrade torch datasets transformers trl accelerate

### 1  Imports

```python
import random           # utilities for reproducible shuffling / sampling
import torch            # underlying deep-learning framework (PyTorch)
from datasets import Dataset  # HF library for fast, memory-mapped datasets
from transformers import (
    AutoModelForCausalLM,   # generic loader for any causal LM
    AutoTokenizer,          # matching tokenizer loader
    pipeline,               # convenience wrapper for inference
)
from trl import SFTConfig, SFTTrainer  # higher-level SFT helpers
from gen_dataset import generate_raw_examples  # custom data generator
```
Nothing is configurable here, but remember that the model and tokenizer **must be loaded from the same checkpoint ID**, otherwise their token IDs will not match.

### 2  Data Generation & Preparation

```python
raw_examples = generate_raw_examples(10000)  # ➜ List[Tuple[str, str]]
special = "<my_tool_selection>"             # sentinel delimiter
```
* **`generate_raw_examples(n)`** (custom): returns `n` `(query, tool)` pairs, where `n` (`int`) is the number of examples to create. In this tutorial the function produces *synthetic* queries.
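
The `gen_dataset` module itself is not shown in this document. As a rough idea of what such a generator could look like, here is a minimal sketch. The tool names, query templates, fill-in values, and the extra `seed` parameter are all hypothetical placeholders, not the actual contents of `gen_dataset`.

```python
import random

# Hypothetical tool names and query templates, for illustration only.
TEMPLATES = {
    "calculator": ["What is {a} times {b}?", "Add {a} and {b} for me."],
    "get_weather": ["Will it rain in {city} tomorrow?", "Weather in {city}?"],
    "set_reminder": ["Remind me to {task} at {hour} pm.", "Don't let me forget to {task}."],
}
CITIES = ["Paris", "Tokyo", "Lima"]
TASKS = ["call mom", "water the plants", "submit the report"]


def generate_raw_examples(n, seed=0):
    """Return n synthetic (query, tool) pairs drawn from the templates above."""
    rng = random.Random(seed)  # local RNG keeps the output reproducible
    examples = []
    for _ in range(n):
        tool = rng.choice(sorted(TEMPLATES))
        template = rng.choice(TEMPLATES[tool])
        # Unused keyword arguments are simply ignored by str.format.
        query = template.format(
            a=rng.randint(1, 99),
            b=rng.randint(1, 99),
            city=rng.choice(CITIES),
            task=rng.choice(TASKS),
            hour=rng.randint(1, 11),
        )
        examples.append((query, tool))
    return examples


for query, tool in generate_raw_examples(3):
    print(f"{tool:>12} <- {query}")
```

A template-based generator like this is easy to extend, but it also makes the limits of synthetic data visible: the model only ever sees the phrasings you wrote templates for.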

```python
dataset = Dataset.from_list([
    {
        "prompt":     f"User: {q} {special}\nAssistant:",
        "completion": f" {tool}"
    }
    for q, tool in raw_examples
])
```
* **`Dataset.from_list(list_of_dicts)`** builds a HF Dataset.  Each dict must contain all the columns that the trainer will later reference (`prompt`, `completion`).
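
To make the record format concrete, the sketch below builds a single record with the same f-strings as the list comprehension above and prints it. The query and the tool name `calculator` are made-up sample values, and `to_record` is a hypothetical helper name.

```python
special = "<my_tool_selection>"  # same sentinel as in the script


def to_record(query, tool):
    """Format one (query, tool) pair as a prompt/completion record."""
    return {
        "prompt": f"User: {query} {special}\nAssistant:",
        "completion": f" {tool}",  # leading space separates it from "Assistant:"
    }


record = to_record("What is 12 times 7?", "calculator")  # made-up sample
print(repr(record["prompt"]))
print(repr(record["completion"]))
# Roughly the text the model sees during training
# (the trainer may additionally append an EOS token).
print(repr(record["prompt"] + record["completion"]))
```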

```python
splits = dataset.train_test_split(test_size=0.20, seed=42)
train_ds, eval_ds = splits["train"], splits["test"]
```
* **`test_size`** – fraction (or absolute count) reserved for validation.  
* **`seed`** – ensures deterministic shuffling so future reruns get the exact same split.
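
Conceptually, a seeded split shuffles the row indices with a fixed random seed and cuts them at the requested fraction. The plain-Python sketch below illustrates only that idea. It is not how `datasets` implements `train_test_split`, and it will not reproduce the same row assignment; `seeded_split` is a hypothetical helper.

```python
import random


def seeded_split(rows, test_size=0.20, seed=42):
    """Shuffle row indices deterministically, then cut off the test part."""
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)
    n_test = round(len(rows) * test_size)
    test = [rows[i] for i in indices[:n_test]]
    train = [rows[i] for i in indices[n_test:]]
    return train, test


rows = list(range(10_000))  # stand-in for the 10,000 generated examples
train, test = seeded_split(rows)
print(len(train), len(test))
print(seeded_split(rows) == (train, test))  # same seed, same split
```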

### 3  Model & Tokenizer Loading

```python
model_name = "HuggingFaceTB/SmolLM2-135M"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto"   # ← let HF dispatch layers across all GPUs / CPU
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # safe default
```
* **`device_map="auto"`** – Hugging Face (via `accelerate`) inspects the available hardware and places the weights automatically, splitting them across GPUs / CPU RAM when a model does not fit on one device. This matters mostly for big models on limited VRAM; a 135M-parameter model fits comfortably on a single device.
* **`pad_token`** – token used to pad sequences inside a batch so they all have the same length. Some tokenizers do not define one, so we fall back to `eos_token`.

### 4  Adding the Special Delimiter

```python
tokenizer.add_special_tokens({"additional_special_tokens": [special]})
model.resize_token_embeddings(len(tokenizer))
```
* **`add_special_tokens`** – extends the vocabulary and returns how many were added.  The new **ID** is accessible via `tokenizer.convert_tokens_to_ids(special)`.
* **`resize_token_embeddings(new_size)`** – resizes the model’s embedding / output matrices so the extra token gets its own learnable vector.  Must be called **after** adding tokens.
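
A quick sanity check after these two calls can catch the most common mistakes: the sentinel being split into several pieces, or the embedding matrix not having been resized. The helper below is a sketch; `check_special_token` is a hypothetical name (not part of `transformers`), and it assumes the tokenizer encodes a registered special token to exactly one ID.

```python
def check_special_token(tokenizer, model, token):
    """Raise AssertionError if `token` does not look correctly registered."""
    token_id = tokenizer.convert_tokens_to_ids(token)
    # A registered special token should encode to exactly one ID.
    ids = tokenizer.encode(token, add_special_tokens=False)
    assert ids == [token_id], f"{token!r} is split into {ids}"
    # The embedding matrix needs at least one row per vocabulary entry.
    rows = model.get_input_embeddings().weight.shape[0]
    assert rows >= len(tokenizer), f"{rows} rows < {len(tokenizer)} tokens"
    return token_id


# Usage with the objects from the script (assumed to be in scope):
# special_id = check_special_token(tokenizer, model, special)
```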

### 5  `SFTConfig` – every hyper-parameter explained

```python
sft_args = SFTConfig(
    output_dir="./tool_choice_sft",      # where checkpoints / logs go
    num_train_epochs=2,                   # scan full dataset twice
    per_device_train_batch_size=4,        # effective batch = 4 × gradient_accumulation_steps × number of devices
    gradient_accumulation_steps=1,        # accumulate gradients over N mini-batches before each weight update
    learning_rate=5e-5,                  # AdamW step size
    warmup_ratio=0.10,                   # first 10 % of total steps = LR warm-up
    logging_steps=5,                     # log train loss every 5 optimisation steps
    eval_strategy="steps",              # run evaluation every *N* steps (not epochs)
    eval_steps=20,                       # evaluate on `eval_ds` every 20 steps
    save_steps=50,                       # save checkpoint every 50 steps
    report_to=["none"],                 # disable WandB / TensorBoard
)
```
**Additional notes**
* `gradient_accumulation_steps` lets you simulate a larger batch without extra VRAM: gradients are accumulated locally and the optimiser runs only every *k* mini-batches.
* `learning_rate` pairs with the *AdamW* optimiser under the hood (the default for `SFTTrainer`).
* The **total number of optimiser steps** = `(train_examples / batch) × epochs / gradient_accumulation_steps` – warm-up ratio is applied over that count.
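
Plugging this tutorial's numbers into that formula gives a feel for how long training runs and how often evaluation and checkpointing fire. The sketch assumes a single device; exact rounding at epoch boundaries can differ slightly between library versions.

```python
import math

train_examples = 8_000   # 80 % of the 10,000 generated examples
batch_size = 4           # per_device_train_batch_size
grad_accum = 1           # gradient_accumulation_steps
epochs = 2               # num_train_epochs
warmup_ratio = 0.10
num_devices = 1          # assumption: single GPU or CPU

batches_per_epoch = math.ceil(train_examples / (batch_size * num_devices))
updates_per_epoch = math.ceil(batches_per_epoch / grad_accum)
total_steps = updates_per_epoch * epochs
warmup_steps = math.ceil(total_steps * warmup_ratio)

print(f"optimiser steps: {total_steps}")
print(f"warm-up steps:   {warmup_steps}")
print(f"evaluation runs: {total_steps // 20}")  # eval_steps=20
print(f"checkpoints:     {total_steps // 50}")  # save_steps=50
```

With these settings evaluation runs quite often relative to the length of training, so raising `eval_steps` and `save_steps` is a reasonable way to cut wall-clock time and disk usage.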

### 6  `SFTTrainer` initialisation parameters

```python
trainer = SFTTrainer(
    model=model,                 # the LM we just loaded / resized
    args=sft_args,               # all hyper-params
    train_dataset=train_ds,      # 80 % split
    eval_dataset=eval_ds,        # 20 % split (never back-prop through)
    processing_class=tokenizer,  # for smart batching & padding (older TRL releases call this `tokenizer=`)
)
```
* **`model`** – a `PreTrainedModel` instance that supports generation (here the causal LM we just loaded and resized).
* **`processing_class`** – the tokenizer; ensures the trainer uses the right pad / special tokens. Recent TRL releases renamed this argument from `tokenizer`, so use whichever name your installed version accepts.
* **`train_dataset` / `eval_dataset`** – Hugging Face `Dataset` objects with `prompt` / `completion` columns; the trainer tokenises them and wraps them in `DataLoader`s with automatic collation.

### 7  Running the Training Loop

```python
trainer.train()  # heavy lifting happens here
```
Under the hood `SFTTrainer` performs:
1. Epoch → batch iteration, tokenisation & padding.
2. Forward pass → compute loss **only** on `completion` tokens (via an internal label-mask).
3. Back-prop & AdamW update respecting `gradient_accumulation_steps`.
4. Callbacks for logging, evaluation, checkpointing at the configured step intervals.
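
The label mask in step 2 is worth seeing in isolation. The general idea, sketched below with made-up token IDs, is to copy the input IDs into the labels and overwrite every prompt position with `-100`, the value that PyTorch's cross-entropy loss ignores by default. This illustrates the concept only; it is not TRL's actual implementation, and whether the mask is applied can depend on the TRL version and configuration.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def build_labels(prompt_ids, completion_ids):
    """Return (input_ids, labels) with the prompt positions masked out."""
    input_ids = list(prompt_ids) + list(completion_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    return input_ids, labels


# Made-up token IDs, purely for illustration.
prompt_ids = [101, 102, 103, 104]  # "User: ... <my_tool_selection>\nAssistant:"
completion_ids = [201, 202]        # " calculator"
input_ids, labels = build_labels(prompt_ids, completion_ids)
print(input_ids)  # [101, 102, 103, 104, 201, 202]
print(labels)     # [-100, -100, -100, -100, 201, 202]
```

Because only the completion positions carry real labels, the gradient comes entirely from how well the model predicts the tool name, not from how well it reproduces the user's query.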

### 8  Saving artefacts

```python
model.save_pretrained("./tool_choice_final")
tokenizer.save_pretrained("./tool_choice_final")
```
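* Saves the **weights** (`model.safetensors` in recent `transformers` releases, `pytorch_model.bin` in older ones), the **config** (`config.json`), and the tokenizer files into one directory, so everything can be re-loaded with `from_pretrained()` calls anywhere.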

### 9  Inference `pipeline` parameters

```python
gen = pipeline(
    "text-generation",      # task selector – chooses correct pipeline class
    model="./tool_choice_final",  # path or HF-Hub ID
    tokenizer="./tool_choice_final",
    max_new_tokens=5,         # safety cap – tool names are short
    eos_token_id=tokenizer.eos_token_id,  # stop when model emits EOS
    do_sample=False,          # deterministic greedy decoding
)
```
| Argument | Effect |
| -------- | ------ |
| `model` / `tokenizer` | Can be local path or remote repo ID; they *must match each other* or token IDs will diverge. |
| `max_new_tokens` | Hard upper-bound on generated length; prevents runaway text. |
| `eos_token_id` | Allows the pipeline to stop early if EOS appears before hitting the length cap. |
| `do_sample` | `False` ➜ greedy decoding.  `True` would enable nucleus / temperature sampling for stochastic outputs (not desired for routers). |
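
By default the text-generation pipeline returns the prompt followed by the generated text, so the tool name still has to be cut out of the result. The sketch below shows one way to do that. `build_prompt` and `extract_tool` are hypothetical helper names, the model output is faked so the snippet runs without a trained model, and it assumes tool names contain no whitespace.

```python
special = "<my_tool_selection>"  # same sentinel as in the script


def build_prompt(query):
    """Format a query exactly like the training prompts."""
    return f"User: {query} {special}\nAssistant:"


def extract_tool(generated_text, prompt):
    """Strip the echoed prompt and return the first word of the completion."""
    if generated_text.startswith(prompt):
        completion = generated_text[len(prompt):]
    else:
        completion = generated_text
    words = completion.split()
    return words[0] if words else ""


prompt = build_prompt("What is 12 times 7?")
# With the real pipeline this would be: gen(prompt)[0]["generated_text"]
fake_output = prompt + " calculator\nUser:"  # made-up model output
print(extract_tool(fake_output, prompt))
```

Taking only the first word guards against the model emitting extra tokens after the tool name before it reaches EOS or the `max_new_tokens` cap.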

### 🔍 Why *synthetic* data, and its trade-offs

Synthetic data is quick, cheap, and lets us demonstrate technique without exposing real user logs. But be mindful of its limitations:

* **Distribution shift.** If the phrasing of real queries differs from our synthetic ones, performance may drop in production.
* **Bias reinforcement.** If our generator under-represents certain language styles, the fine-tuned model will likewise under-perform for those users.
* **Evaluation realism.** Always test on *real* hold-out data before shipping.

A common compromise is **mixed-source datasets**: seed the model with real anonymised queries (after consent & PII scrubbing) and pad with synthetic ones for coverage.


### 📊 Evaluation beyond token accuracy

We logged `mean_token_accuracy`, but for a 6-class decision task we can use richer metrics:

| Metric | Why it matters |
| ------ | -------------- |
| **Exact-match accuracy** | “Did we pick the right tool?”—simple, interpretable. |
| **Confusion matrix** | Reveals systematic mix-ups (e.g., `search_web` vs `translate_text`). |
| **Macro-F1** | Balances precision & recall per class, helpful if class distribution is skewed. |
| **Calibration error** | Tells us if the softmax probabilities are reliable enough to gate fallback rules (“only call the tool if p > 0.8”). |

After training, run a small script to compute and visualise these—mistakes jump out immediately.
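
As a starting point, the first three metrics need nothing beyond the standard library. The sketch below computes them from two lists of tool names. The predictions are made up for illustration, and `calculator` is a hypothetical tool name.

```python
from collections import Counter


def evaluate(y_true, y_pred):
    """Return exact-match accuracy, confusion counts, and macro-F1."""
    labels = sorted(set(y_true) | set(y_pred))
    confusion = Counter(zip(y_true, y_pred))  # (true, predicted) -> count
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

    f1_scores = []
    for label in labels:
        tp = confusion[(label, label)]
        fp = sum(confusion[(other, label)] for other in labels if other != label)
        fn = sum(confusion[(label, other)] for other in labels if other != label)
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    macro_f1 = sum(f1_scores) / len(f1_scores)
    return accuracy, confusion, macro_f1


# Made-up labels and predictions, for illustration only.
y_true = ["search_web", "search_web", "translate_text", "calculator"]
y_pred = ["search_web", "translate_text", "translate_text", "calculator"]
accuracy, confusion, macro_f1 = evaluate(y_true, y_pred)
print(f"exact match: {accuracy:.2f}")
print(f"macro-F1:    {macro_f1:.2f}")
for (true, pred), count in sorted(confusion.items()):
    print(f"{true:>15} -> {pred:<15} {count}")
```

In practice `y_pred` would come from running the inference pipeline over the prompts in `eval_ds` and `y_true` from the matching `completion` column with surrounding whitespace stripped.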


### 🏋️‍♀️ Classification head *vs.* next-token generation

An alternative to next-token generation is to attach a **classification head** on top of the language model. Here is a quick side-by-side:

| Aspect | Next-token generation | Classification head |
| ------ | -------------------- | ------------------- |
| **Architecture change** | None – reuse the LM exactly as is. | Add a small feed-forward layer mapping hidden state → 6 logits. |
| **Training objective** | Cross-entropy over *tokens* (predict the tool name token-by-token). | Cross-entropy over *classes* (predict one of six tools). |
| **Speed** | Slower at inference (needs multiple decoding steps). | Single forward pass – faster. |
| **Confidence scores** | Harder to interpret (need to read log-probs of whole string). | Softmax directly gives per-tool probability. |
| **Flexibility** | Can generalise to unseen tools if vocabulary covers them. | Fixed to the predefined class set. |
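
To illustrate the "confidence scores" row: with generation, the probability of a tool name is the product of its per-token probabilities, which is the exponential of the summed token log-probs. The sketch below applies that to a fallback gate with a 0.8 threshold. The log-probabilities and tool names are made up, and obtaining real ones (for example by scoring each candidate name with the model) is assumed to happen elsewhere.

```python
import math


def sequence_probability(token_logprobs):
    """Probability of a whole completion = product of its token probabilities."""
    return math.exp(sum(token_logprobs))


def route(candidates, threshold=0.8):
    """Pick the most likely tool, or None when the model is not confident enough."""
    best_tool = max(candidates, key=lambda tool: sequence_probability(candidates[tool]))
    confidence = sequence_probability(candidates[best_tool])
    return (best_tool if confidence > threshold else None), confidence


# Made-up per-token log-probabilities for two hypothetical candidate tool names.
candidates = {
    "calculator": [math.log(0.95), math.log(0.9)],
    "get_weather": [math.log(0.04), math.log(0.5)],
}
tool, confidence = route(candidates)
print(tool, round(confidence, 3))
```

Note that multi-token names are penalised by the product, so longer tool names tend to get lower scores than short ones. That is one reason a classification head gives cleaner confidence values.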

For small closed-set routers, the classification head is usually the pragmatic choice. We used generation here to stay closer to core SFT mechanics and keep the script model-agnostic—but feel free to fork the repo and try the head-based variant!
