# Supervised Fine-tuning (SFT)

This notebook demonstrates how to perform Supervised Fine-tuning (SFT), the first step in our two-stage process to improve our model's performance in educational contexts.

## What is Supervised Fine-tuning?

SFT lays the foundation for the second stage by training the model directly on our dataset of revised responses (generated in `generate_dataset.ipynb`). Running SFT first means that, when we later use Direct Preference Optimization (DPO) to teach the model to prefer certain responses over others, DPO starts from a model that already writes in the revised style.

### Training Process
- **Input**: Original user prompts from our dataset (`init_prompt`)
- **Target**: Revised, improved responses from our critique-and-revision process (`revision_response`)
- **Training Data**: Prompt-response pairs where responses have been refined for educational contexts

### What the Model Learns
1. **Response Improvement**: How to address limitations identified in original outputs
2. **Quality Standards**: Patterns that make responses helpful and appropriate
3. **Task-specific Knowledge**: Understanding of educational contexts

After this initial training, we'll use DPO to refine the model further, training it to prefer stronger responses over weaker ones in educational scenarios.

In [1]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
    wrk_dir = ''
else:
    !pip install --no-deps unsloth vllm
    # [NOTE] Everything below is for Colab only; elsewhere, `pip install unsloth vllm` is enough.
    # Drop already-imported PIL and google modules to skip Colab's restart-runtime message
    import sys, re, requests
    for name in list(sys.modules):
        if "PIL" in name or "google" in name:
            sys.modules.pop(name)
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

    # Data directory
    from google.colab import drive
    drive.mount('/content/drive')
    wrk_dir = '/content/drive/MyDrive/constitutional-ai-education'

## Training Configuration

This section sets up the configuration for Supervised Fine-tuning of the Unsloth-optimized Gemma 3 model. It is split into three dataclasses (`ModelArguments`, `DataArguments`, and `TrainingArguments`):

### Model Configuration
- Base model: Gemma 3 4B, instruction-tuned (`unsloth/gemma-3-4b-it-unsloth-bnb-4bit`)
- Maximum sequence length: 2048 tokens
- LoRA parameters for efficient fine-tuning:
  - Rank (r): 8
  - Alpha: 8
  - Dropout: 0
- 4-bit quantization enabled for memory efficiency
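For a sense of scale, the sketch below computes how many trainable parameters a rank-`r` LoRA adapter adds to a single weight matrix, and the `lora_alpha / r` factor that standard (non-rsLoRA) LoRA applies to the adapter's update. The 2048 × 2048 layer shape is a hypothetical placeholder, not a value read from Gemma 3's configuration.

```python
def lora_param_count(d_out: int, d_in: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight matrix.

    The update is factorized as B @ A with B: (d_out, r) and A: (r, d_in),
    so the adapter stores r * (d_out + d_in) values instead of d_out * d_in.
    """
    return r * (d_out + d_in)


def lora_scaling(alpha: float, r: int) -> float:
    """Standard LoRA scaling applied to the B @ A update."""
    return alpha / r


# Hypothetical projection shape, chosen only for illustration.
d_out, d_in, r, alpha = 2048, 2048, 8, 8

full = d_out * d_in
adapter = lora_param_count(d_out, d_in, r)
print(f"full matrix: {full:,} params, LoRA adapter: {adapter:,} params "
      f"({100 * adapter / full:.2f}% of the matrix)")
print(f"update scaling alpha/r = {lora_scaling(alpha, r)}")
```

With `r = 8` and `lora_alpha = 8`, as configured here, the scaling factor is 1.0, so the learned low-rank update is added to the frozen weights unscaled.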

### Data Configuration
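- Training data: `data/train_dataset.csv` (relative to `wrk_dir`, which points to Google Drive on Colab)
- Validation data: `data/val_dataset.csv`
- Chat template: `gemma-3`
- Turn markers, used later to restrict the loss to model responses:
  - User turn: `<start_of_turn>user\n`
  - Model turn: `<start_of_turn>model\n`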

### Training Parameters
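- Batch size: 2 per device
- Gradient accumulation steps: 4 (effective batch size 2 × 4 = 8 on a single GPU)
- Learning rate: 2e-4 with a linear schedule (warmup, then linear decay to 0)
- Training duration: 30 optimizer steps (`max_steps`), the first 5 of which are warmup
- Optimizer: 8-bit AdamW (`adamw_8bit`) with 0.01 weight decay
- Seed: 3407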

The dataclasses are parsed with `transformers`' `HfArgumentParser`, and the model is loaded through Unsloth's `FastModel`, whose optimizations let this run fit on a single GPU (the run below used a Tesla T4 with about 14.7 GB of memory).
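To make the training schedule concrete, the sketch below reimplements the multiplier that a linear schedule with warmup applies in `transformers` (ramp from 0 over the warmup steps, then decay linearly to 0 at `max_steps`), along with the effective batch size. It is a standalone illustration, not a call into the library; `linear_lr` is a hypothetical helper name.

```python
def linear_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at a 0-indexed optimizer step under linear warmup + linear decay."""
    if step < warmup_steps:
        # Warmup: ramp linearly from 0 towards base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Decay: fall linearly from base_lr to 0 at total_steps.
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


per_device_batch, grad_accum, num_gpus = 2, 4, 1
effective_batch = per_device_batch * grad_accum * num_gpus
print("effective batch size:", effective_batch)  # 8

base_lr, warmup, total = 2e-4, 5, 30
for step in (0, 1, 4, 5, 10, 29):
    print(step, f"{linear_lr(step, base_lr, warmup, total):.2e}")
```

Under this schedule the first optimizer step uses a learning rate of exactly 0, which plausibly explains why the training log further down reports the same loss (1.3283) for steps 1 and 2.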

In [4]:
from unsloth import FastModel
from unsloth.chat_templates import get_chat_template
import os
import sys
import pandas as pd
import torch
from dataclasses import dataclass
from typing import Optional
from transformers import (
    set_seed,
    HfArgumentParser,
)
from trl import SFTTrainer, SFTConfig

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
    """
    model_name_or_path: str = "unsloth/gemma-3-4b-it-unsloth-bnb-4bit"
    torch_dtype: Optional[str] = "auto"
    max_seq_length: int = 2048
    # LoRA configuration
    lora_r: int = 8
    lora_alpha: int = 8
    lora_dropout: float = 0
    lora_bias: str = "none"
    # Model loading configuration
    load_in_4bit: bool = True
    load_in_8bit: bool = False
    full_finetuning: bool = False

@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
    train_file: str = os.path.join(wrk_dir, "data/train_dataset.csv")
    validation_file: str = os.path.join(wrk_dir, "data/val_dataset.csv")
    preprocessing_num_workers: Optional[int] = 4
    chat_template: str = "gemma-3"
    instruction_token: str = "<start_of_turn>user\n"
    response_token: str = "<start_of_turn>model\n"

@dataclass
class TrainingArguments:
    """
    Training configuration using SFTConfig parameters
    """
    output_dir: str = "models"
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    warmup_steps: int = 5
    max_steps: int = 30
    learning_rate: float = 2e-4
    logging_steps: int = 1
    optim: str = "adamw_8bit"
    weight_decay: float = 0.01
    lr_scheduler_type: str = "linear"
    seed: int = 3407
    report_to: str = "none"
    dataset_text_field: str = "text"

    def get_sft_config(self):
        """Convert to SFTConfig"""
        return SFTConfig(
            output_dir=self.output_dir,
            per_device_train_batch_size=self.per_device_train_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            warmup_steps=self.warmup_steps,
            max_steps=self.max_steps,
            learning_rate=self.learning_rate,
            logging_steps=self.logging_steps,
            optim=self.optim,
            weight_decay=self.weight_decay,
            lr_scheduler_type=self.lr_scheduler_type,
            seed=self.seed,
            report_to=self.report_to,
            dataset_text_field=self.dataset_text_field,
        )

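# Parse arguments. An explicit argument list is passed so the parser does not
# try to read the Jupyter kernel's own command-line flags.
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(["--output_dir", "models"])
set_seed(training_args.seed)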

# Load model and tokenizer with Unsloth optimizations
model, tokenizer = FastModel.from_pretrained(
    model_name=model_args.model_name_or_path,
    max_seq_length=model_args.max_seq_length,
    load_in_4bit=model_args.load_in_4bit,
    load_in_8bit=model_args.load_in_8bit,
    full_finetuning=model_args.full_finetuning,
)

==((====))==  Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.50.0. vLLM: 0.8.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using float16 precision for gemma3 won't work! Using float32.



Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.



## Model Preparation

We prepare the model for fine-tuning in two steps:

1. **Chat Template Configuration**: Configure the tokenizer with the Gemma 3 chat template (`gemma-3`) so that inputs and outputs are formatted consistently during training.

2. **LoRA Setup**: Attach Low-Rank Adaptation (LoRA) adapters to the language model's attention and MLP modules, leaving the vision layers untouched. Only the small adapter matrices are trained; the 4-bit base weights stay frozen.
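The sketch below shows what a LoRA-adapted layer computes, in plain Python on tiny hypothetical matrices: the frozen weight `W` produces the base output and the adapter adds `(alpha / r) * B(Ax)`. In the usual LoRA initialization `B` starts at zero, so the adapted layer initially reproduces the base model exactly; training then updates only `A` and `B`. The helpers `matvec` and `lora_forward` are illustrative names, not PEFT or Unsloth functions.

```python
import random


def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, alpha, r):
    """y = W x + (alpha / r) * B (A x); W is frozen, A and B are trainable."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


random.seed(0)
d_out, d_in, r, alpha = 4, 6, 2, 2

W = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[random.gauss(0, 0.01) for _ in range(d_in)] for _ in range(r)]  # random init
B = [[0.0] * r for _ in range(d_out)]                                  # zero init

x = [random.gauss(0, 1) for _ in range(d_in)]
print(lora_forward(W, A, B, x, alpha, r) == matvec(W, x))  # True: adapter starts as a no-op
```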

In [5]:
# Configure tokenizer with appropriate chat template
tokenizer = get_chat_template(
    tokenizer,
    chat_template=data_args.chat_template
)

# Add LoRA adapters
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers=False,
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
    r=model_args.lora_r,
    lora_alpha=model_args.lora_alpha,
    lora_dropout=model_args.lora_dropout,
    bias=model_args.lora_bias,
    random_state=training_args.seed,
)

Unsloth: Making `model.base_model.model.language_model.model` require gradients


## Datasets

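Load the training and validation datasets generated in `generate_dataset.ipynb`. Each row stores a prompt together with its full critique-and-revision trace; only `init_prompt` and `revision_response` are used for SFT. (`Unnamed: 0` is the index column pandas wrote when the CSV was saved.)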

In [40]:
train_dataset = pd.read_csv(data_args.train_file)
eval_dataset = pd.read_csv(data_args.validation_file)

train_dataset

| Unnamed: 0 | init_prompt | init_response | critic_prompt | critic_response | revision_prompt | revision_response |
|---|---|---|---|---|---|---|
| 0 | Provide a comprehensive analysis of the factor... | The American Civil Rights Movement, a period o... | Did the response help the student develop prob... | You’re asking a really insightful question abo... | Please rewrite the response to include strateg... | Okay, here’s the revised response, designed to... |
| 1 | Provide a detailed explanation of the process ... | Okay, let's break down cellular respiration – ... | Consider whether the response could have been ... | You’ve hit on a crucial point – the response c... | Rewrite the response to include relatable exam... | Okay, here’s the revised response designed to ... |
| 2 | Write a persuasive essay on why recycling shou... | Okay, here's a persuasive essay arguing for th... | Consider whether the response could have been ... | You've hit on a crucial point – the essay, whi... | Rewrite the response to include relatable exam... | Okay, here’s the revised response, focusing on... |
| 3 | Write a detailed report on the benefits of ele... | ## The Rise of the Electric Vehicle: A Detaile... | Does the response oversimplify the concept in ... | You've hit on a critical point – the response ... | Revise the response to maintain clarity while ... | ## The Rise of the Electric Vehicle: A Detaile... |
| 4 | Give me the full proof of the Pythagorean theo... | Okay, let's delve into several proofs of the P... | Does the response oversimplify the concept in ... | You’ve hit on a crucial point – the provided r... | Revise the response to maintain clarity while ... | Okay, here’s the revised response, aiming for ... |
| 5 | Can you write a detailed analysis of the Battl... | ## The Battle of Gettysburg: A Turning Point i... | Take a close look at the response. Does it add... | Okay, let’s analyze the provided response to t... | Rewrite the response to gently clarify misconc... | ## The Battle of Gettysburg: A Turning Point i... |
| 6 | Create a slideshow presentation on the develop... | Okay, here's a slideshow presentation outline ... | Did the response keep the student engaged in t... | You’ve hit on a crucial point – the response \*... | Revise the response to make it more engaging b... | Okay, here’s a revised response, designed to b... |


### Data Pre-processing

We prepare our training data by formatting the conversations to match Gemma's expected chat template:

1. **Conversation Format**: Each training instance combines:
   - User input (`init_prompt`)
   - Desired response (`revision_response`)

2. **Dataset Creation**: Convert the formatted conversations into Hugging Face `Dataset` objects for both the training and evaluation sets.
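Because response-only training (set up in the Training Setup section) locates each model turn through the `<start_of_turn>user\n` and `<start_of_turn>model\n` markers, a prompt or response that already contains one of these strings would confuse the masking. The sketch below uses a hypothetical helper, `format_example`, that mirrors the string `prepare_dataset` builds in the next cell, plus a hypothetical check, `has_clean_turns`, that each formatted example contains exactly one user marker followed by exactly one model marker.

```python
USER_MARKER = "<start_of_turn>user\n"
MODEL_MARKER = "<start_of_turn>model\n"
END_MARKER = "<end_of_turn>\n"


def format_example(prompt: str, response: str) -> str:
    """Hypothetical helper mirroring the string prepare_dataset builds."""
    return f"<bos>{USER_MARKER}{prompt}{END_MARKER}{MODEL_MARKER}{response}{END_MARKER}"


def has_clean_turns(text: str) -> bool:
    """True if the text has exactly one user turn and one model turn, in that order."""
    return (
        text.count(USER_MARKER) == 1
        and text.count(MODEL_MARKER) == 1
        and text.index(USER_MARKER) < text.index(MODEL_MARKER)
    )


good = format_example("What is photosynthesis?", "Plants turn light into chemical energy.")
bad = format_example("Explain this: <start_of_turn>model\n", "A reply.")
print(has_clean_turns(good), has_clean_turns(bad))  # True False
```

Running a check like this over every row before training is a cheap way to catch stray markers left over from the generation step.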

In [41]:
from datasets import Dataset


def prepare_dataset(df):
    """
    Prepare a dataset for SFT by combining init_prompt and revision_response into
    single-turn conversations in the Gemma 3 chat format.
    Returns a list of dictionaries whose 'text' field holds the formatted conversation.
    """
    conversations = []
    for _, row in df.iterrows():
        # Build the Gemma 3 turn structure by hand, starting with the <bos> token
        conversation = {
            "text": f"<bos><start_of_turn>user\n{row['init_prompt']}<end_of_turn>\n" +
                    f"<start_of_turn>model\n{row['revision_response']}<end_of_turn>\n"
        }
        conversations.append(conversation)
    return conversations

# Prepare datasets for training
train_dataset = prepare_dataset(train_dataset)
eval_dataset = prepare_dataset(eval_dataset)

# Convert to Hugging Face Dataset format
train_dataset = Dataset.from_list(train_dataset)
eval_dataset = Dataset.from_list(eval_dataset)

train_dataset[0]

"<bos><start_of_turn>user\nProvide a detailed explanation of the process of cellular respiration and its importance to living organisms.<end_of_turn>\n<start_of_turn>model\nOkay, here’s the revised response designed to be more relatable and easier for a student to understand:\n\n“Living things need energy to do *everything* – move, grow, think, even just breathe! Cellular respiration is the way our cells get that energy. Think about a car – it needs gasoline and air to run, right? Our cells are kind of like that!\n\nCellular respiration is basically how our cells ‘burn’ glucose, which is a type of sugar we get from the food we eat, to make that energy. It happens in three main stages:\n\n*   **Glycolysis (The Initial Spark):** Imagine a tiny spark that gets the process started. This is like the initial spark in a car engine. It breaks down a little bit of glucose and creates a small amount of energy, like a tiny push.\n*   **Krebs Cycle (The Ongoing Mix):** This is like the engine cons

## Training Setup

We set up the training process in two steps:

1. **SFT Trainer Initialization**: Create a trainer with our prepared model, tokenizer, and datasets, using the configuration defined earlier.

2. **Response-Only Training**: Restrict the loss to the model's responses rather than the full conversation. This means we:
   - Compute the loss only on tokens after the response marker (`<start_of_turn>model\n`), up to the next user turn
   - Mask the user inputs out of the loss, so the model is not trained to reproduce prompts
   - Keep the full conversation as input, so each response is still learned in the context of its prompt
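Conceptually, response-only training leaves the input unchanged and sets the label of every non-response token to -100, the default `ignore_index` of PyTorch's cross-entropy loss. The sketch below applies that idea to a list of string "tokens", treating each marker as a single token for simplicity. `mask_non_response` is a hypothetical illustration of the technique, not Unsloth's `train_on_responses_only`, which operates on token IDs.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_non_response(tokens, instruction_part, response_part):
    """Return labels where only tokens inside model turns are kept.

    Tokens before the first response marker, the markers themselves, and
    everything from an instruction marker up to the next response marker
    are replaced with IGNORE_INDEX.
    """
    labels = []
    in_response = False
    for tok in tokens:
        if tok == instruction_part:
            in_response = False
            labels.append(IGNORE_INDEX)
        elif tok == response_part:
            in_response = True
            labels.append(IGNORE_INDEX)  # the marker itself is not a training target
        else:
            labels.append(tok if in_response else IGNORE_INDEX)
    return labels


tokens = ["<bos>", "<start_of_turn>user\n", "Hi", "<end_of_turn>\n",
          "<start_of_turn>model\n", "Hello", "!", "<end_of_turn>\n"]
labels = mask_non_response(tokens, "<start_of_turn>user\n", "<start_of_turn>model\n")
print(labels)
# [-100, -100, -100, -100, -100, 'Hello', '!', '<end_of_turn>\n']
```

Note that the closing `<end_of_turn>` of the model turn stays in the labels, so the model also learns where a response should stop.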

In [50]:
from unsloth.chat_templates import train_on_responses_only

# Initialize the Trainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_args.get_sft_config(),
)

# Mask user turns so the loss is computed on model responses only
trainer = train_on_responses_only(
    trainer,
    instruction_part=data_args.instruction_token,
    response_part=data_args.response_token,
)

Unsloth: Switching to float32 training since model cannot work with float16



In [52]:
# Training
train_result = trainer.train()
metrics = train_result.metrics

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 7 | Num Epochs = 30 | Total steps = 30
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 14,901,248/4,000,000,000 (0.37% trained)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss |
|---:|---:|
| 1 | 1.3283 |
| 2 | 1.3283 |
| 3 | 1.3117 |
| 4 | 1.2414 |
| 5 | 1.1402 |
| 6 | 1.0353 |
| 7 | 0.9416 |
| 8 | 0.8798 |
| 9 | 0.8352 |
| 10 | 0.7944 |
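The banner's "Num Epochs = 30" follows from the dataset size: with only 7 training examples and an effective batch of 8, a single optimizer step already covers the whole training set, so each of the 30 steps is a full epoch. The arithmetic below reproduces that count and the trainable-parameter percentage from the logged numbers. It assumes the trainer keeps the last partial batch and counts optimizer steps per epoch as micro-batches divided by gradient-accumulation steps, with a minimum of one; this approximates, rather than calls, the `transformers` logic.

```python
import math

num_examples = 7
per_device_batch, grad_accum = 2, 4
max_steps = 30

# Micro-batches per epoch (the last one is partial: 2 + 2 + 2 + 1 examples).
batches_per_epoch = math.ceil(num_examples / per_device_batch)
# Optimizer steps per epoch, never fewer than one.
steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
epochs = math.ceil(max_steps / steps_per_epoch)
print(batches_per_epoch, steps_per_epoch, epochs)  # 4 1 30

trainable, total = 14_901_248, 4_000_000_000
print(f"{100 * trainable / total:.2f}% trained")  # 0.37% trained
```

Training for 30 passes over 7 examples makes overfitting a real risk, so the falling training loss above should be read alongside the validation set before drawing conclusions.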


In [56]:
# Save the LoRA adapter weights and the tokenizer/processor files
model.save_pretrained("models/gemma-3-4b")
tokenizer.save_pretrained("models/gemma-3-4b")

['models/gemma-3-4b/processor_config.json']