# Supervised Fine-tuning (SFT)

This notebook demonstrates how to perform Supervised Fine-tuning (SFT), the first step in our two-stage process to improve our model's performance in providing student support.

## What is Supervised Fine-tuning?

SFT is a crucial first step in model customization because it teaches the model to generate responses that align with specific desired behaviors. In our educational context, this means:

1. **Base Knowledge Adaptation**: We start with a pre-trained language model (Llama 3.2) that has broad knowledge but isn't specifically tuned for educational support
2. **Task-Specific Learning**: Through SFT, we teach the model to generate responses that embody effective teaching practices
3. **Foundation for Further Refinement**: This creates a foundation that will be further improved through Direct Preference Optimization (DPO)

### Our Application

In our case, we will be **fine-tuning the model on our dataset of revised responses** (generated in `generate_dataset.ipynb`). This dataset contains:

- **Input**: Original user prompts from our dataset
- **Target**: Revised, improved responses that demonstrate effective teaching practices like:
  - Encouraging critical thinking
  - Providing appropriate scaffolding
  - Fostering curiosity and exploration
  - Making connections to broader concepts

After this initial training, we'll use DPO to further refine the model, training it to prefer stronger educational responses over weaker ones.

In [1]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os

if "COLAB_" not in "".join(os.environ.keys()):
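    !pip install unsloth
    # Local run: assume data/ and models/ live under the current working directory
    wrk_dir = os.getcwd()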
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
    !pip install --no-deps unsloth

    # Data directory
    from google.colab import drive
    drive.mount('/content/drive')
    wrk_dir = '/content/drive/MyDrive/constitutional-ai-education'

## Training Configuration

This section sets up the configuration for our Supervised Fine-tuning using the Unsloth-optimized Llama 3.2 model. We've made specific choices in our configuration to balance training efficiency with model performance:

### Model Configuration
- **Base Model**: Llama 3.2 (1B parameter version)
  - Why this model?
    - Smaller size (1B parameters) makes it practical for fine-tuning on limited hardware
    - Built on Meta's Llama architecture, known for good performance on instruction-following tasks
    - Optimized by Unsloth for efficient training
- **LoRA (Low-Rank Adaptation) Parameters**:
  - Rank (r): 8 - The inner dimension of the low-rank update matrices; a higher rank adds trainable parameters and capacity
  - Alpha: 8 - Scales the LoRA update by alpha / r, here 8 / 8 = 1
  - Dropout: 0 - We disable dropout because the run is short (50 steps)
  - Why LoRA? It allows efficient fine-tuning by modifying only a small number of parameters while maintaining model quality
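The rank and alpha settings above can be illustrated with a minimal pure-Python sketch of a LoRA-adapted linear layer, using the standard LoRA formulation `h = W x + (alpha / r) * B (A x)`. The helper names and toy matrices are hypothetical; PEFT and Unsloth implement this with GPU tensors, not Python lists. The sketch also shows why training starts from the unchanged base model: `B` is zero-initialized, so the adapter contributes nothing until it is trained.

```python
def matvec(matrix, vec):
    """Multiply a matrix (a list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]


def lora_forward(W, A, B, x, alpha):
    """Compute h = W x + (alpha / r) * B (A x), with r = number of rows of A."""
    r = len(A)                    # LoRA rank
    base = matvec(W, x)           # frozen pre-trained projection
    low_rank = matvec(A, x)       # A (r x d_in) projects x down to r dimensions
    update = matvec(B, low_rank)  # B (d_out x r) projects back up to d_out
    scale = alpha / r             # e.g. alpha = 8, r = 8 gives a scale of 1.0
    return [b + scale * u for b, u in zip(base, update)]


# Toy example (hypothetical values): d_in = 3, d_out = 2, rank r = 1
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, -1.0]]
A = [[0.5, 0.5, 0.5]]    # trainable, randomly initialized in practice
B_init = [[0.0], [0.0]]  # trainable, zero-initialized
x = [1.0, 2.0, 3.0]

print(lora_forward(W, A, B_init, x, alpha=1))          # [7.0, -1.0]: same as W x
print(lora_forward(W, A, [[1.0], [2.0]], x, alpha=2))  # [13.0, 11.0]: adapter active
```

With this notebook's settings (`lora_alpha = 8`, `r = 8`) the scale factor is 1, so the learned update is added at full strength.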

### Memory Optimization
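- 4-bit quantization enabled (pre-quantized bitsandbytes `bnb-4bit` checkpoint)
  - Why? Storing weights in 4 bits instead of 16 cuts weight memory by roughly 75%, while retaining most of the model's performance
  - Enables training on a single consumer or free-tier GPU (this run used a Tesla T4 with about 14.7 GB of usable memory)
  - Combined with LoRA, this is the QLoRA recipe: the quantized base model stays frozen and only the adapters are trained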
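The ~75% figure follows from simple arithmetic on weight storage. The sketch below is a back-of-the-envelope estimate under stated assumptions: it uses the ~1B parameter count reported in the training log and ignores quantization constants, layers that bitsandbytes keeps in higher precision, activations, optimizer states, and the LoRA adapters themselves.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate storage for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


num_params = 1_000_000_000  # ~1B, as reported in the training log (assumption for illustration)
fp16_gb = weight_memory_gb(num_params, 16)
int4_gb = weight_memory_gb(num_params, 4)
saving = 1 - int4_gb / fp16_gb

print(f"16-bit: {fp16_gb:.2f} GB, 4-bit: {int4_gb:.2f} GB, saving: {saving:.0%}")
```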

### Data Configuration
- Training data: `data/train_dataset.csv`
- Validation data: `data/val_dataset.csv`
- Special tokens for structured conversations:
  - User input: `<|start_header_id|>user<|end_header_id|>`
  - Model response: `<|start_header_id|>assistant<|end_header_id|>`
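As an illustration of how these header tokens delimit turns, here is a small hypothetical helper, `render_llama3`, that renders a list of messages into the same `<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>` layout that `prepare_dataset` builds later in this notebook. It is a sketch, not Unsloth's `llama-3` template itself; `<|begin_of_text|>` is omitted because the tokenizer adds it during tokenization.

```python
def render_llama3(messages):
    """Render [{'role': ..., 'content': ...}, ...] into Llama 3 header/eot text."""
    parts = []
    for msg in messages:
        # Each turn: role header, blank line, content, end-of-turn token
        parts.append(
            f"<|start_header_id|>{msg['role']}<|end_header_id|>\n\n"
            f"{msg['content']}<|eot_id|>"
        )
    return "".join(parts)


text = render_llama3([
    {"role": "user", "content": "Solve 3x + 7 = 22."},
    {"role": "assistant", "content": "What happens if we subtract 7 from both sides?"},
])
print(text)
```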

### Training Parameters
These parameters are chosen to provide stable training while preventing overfitting:
- Batch size: 2 per device
- Gradient accumulation steps: 4
  - Gives an effective batch size of 2 × 4 = 8 sequences per optimizer step
- Learning rate: 2e-4 with a linear scheduler
  - A common default for LoRA fine-tuning; warmup and linear decay keep the earliest and latest updates small
- Training duration: 50 steps with 3 warmup steps, logging the loss every 3 steps
- Optimizer: 8-bit AdamW with 0.01 weight decay
  - Memory-efficient variant of AdamW that stores its optimizer states in 8 bits
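These settings determine how much data the model actually sees. The sketch below reproduces the arithmetic, assuming it mirrors how the Hugging Face `Trainer` derives the epoch count when `max_steps` is set; the helper `training_schedule` is hypothetical. With 291 training examples, 50 steps at an effective batch of 8 cover 400 sequences, roughly 1.4 passes over the data, which is consistent with the training log reporting 2 (partial) epochs.

```python
import math


def training_schedule(num_examples, per_device_batch, grad_accum, max_steps, num_gpus=1):
    """Return (effective batch, updates per epoch, epochs, sequences seen)."""
    effective_batch = per_device_batch * grad_accum * num_gpus
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_gpus))
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    epochs = math.ceil(max_steps / updates_per_epoch)  # last epoch may be partial
    sequences_seen = max_steps * effective_batch
    return effective_batch, updates_per_epoch, epochs, sequences_seen


# 291 training examples, as in the training log
print(training_schedule(291, per_device_batch=2, grad_accum=4, max_steps=50))
```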

In [2]:
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template, train_on_responses_only
import os
import sys
import pandas as pd
import torch
from dataclasses import dataclass
from typing import Optional
from transformers import (
    set_seed,
    HfArgumentParser,
)
from trl import SFTTrainer, SFTConfig
from datasets import Dataset

## Training Configuration
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
    """
    model_name_or_path: str = "unsloth/Llama-3.2-1B-bnb-4bit"
    torch_dtype: Optional[str] = "auto"
    max_seq_length: int = 2048
    # LoRA configuration
    lora_r: int = 8
    lora_alpha: int = 8
    lora_dropout: float = 0
    lora_bias: str = "none"
    # Model loading configuration
    load_in_4bit: bool = True
    load_in_8bit: bool = False
    full_finetuning: bool = False

@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
    train_file: str = os.path.join(wrk_dir, "data/train_dataset.csv")
    validation_file: str = os.path.join(wrk_dir, "data/val_dataset.csv")
    preprocessing_num_workers: Optional[int] = 4
    chat_template: str = "llama-3"
    instruction_token: str = '<|start_header_id|>user<|end_header_id|>\n\n'
    response_token: str = '<|start_header_id|>assistant<|end_header_id|>\n\n'

@dataclass
class TrainingArguments:
    """
    Training configuration using SFTConfig parameters
    """
    output_dir: str = os.path.join(wrk_dir,'models')
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    warmup_steps: int = 3
    max_steps: int = 50
    learning_rate: float = 2e-4
    logging_steps: int = 3
    optim: str = "adamw_8bit"
    weight_decay: float = 0.01
    lr_scheduler_type: str = "linear"
    seed: int = 3407
    report_to: str = "none"
    dataset_text_field: str = "text"

    def get_sft_config(self):
        """Convert to SFTConfig"""
        return SFTConfig(
            output_dir=self.output_dir,
            per_device_train_batch_size=self.per_device_train_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            warmup_steps=self.warmup_steps,
            max_steps=self.max_steps,
            learning_rate=self.learning_rate,
            logging_steps=self.logging_steps,
            optim=self.optim,
            weight_decay=self.weight_decay,
            lr_scheduler_type=self.lr_scheduler_type,
            seed=self.seed,
            report_to=self.report_to,
            dataset_text_field=self.dataset_text_field,
        )

# Parse arguments
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses([])



🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Failed to patch Gemma3ForConditionalGeneration.
🦥 Unsloth Zoo will now patch everything to make training faster!


## Model Preparation

We prepare the model for fine-tuning through two critical steps that enable efficient and effective training:

### 1. Chat Template Configuration
- Configure the tokenizer with the Llama 3 chat template (`chat_template="llama-3"`)
- Why? This ensures:
  - Consistent formatting of inputs and outputs
  - Proper handling of special tokens (`<|start_header_id|>`, `<|end_header_id|>`, `<|eot_id|>`)
  - Clear separation between user and assistant messages

### 2. LoRA Setup
- Add Low-Rank Adaptation (LoRA) layers to the model
- Why LoRA?
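  - Memory Efficiency: The base weights stay frozen and only the small adapter matrices are trained (5,636,096 parameters, reported as 0.56% in the training log)
  - Speed: Fewer trainable parameters mean fewer gradients and optimizer states to compute and store than in full fine-tuning
  - Adaptability: The adapters can be saved separately, merged into the base model, or removed
  - Quality: The pre-trained weights are left untouched, which helps preserve general capabilities while the adapters specialize the model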
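The trainable-parameter count in the training log can be reproduced by hand: a LoRA adapter on a linear layer mapping `d_in` to `d_out` features adds `r × (d_in + d_out)` parameters (matrices `A` and `B`). The model dimensions below are assumptions taken from the public Llama 3.2 1B configuration (hidden size 2048, MLP size 8192, 16 layers, 8 key/value heads of size 64), not from this notebook; with `r = 8`, `bias="none"`, and the seven target modules they give the 5,636,096 trainable parameters Unsloth reports.

```python
def lora_params(d_in, d_out, r):
    """A LoRA adapter on a d_in -> d_out linear layer adds A (r x d_in) and B (d_out x r)."""
    return r * (d_in + d_out)


# Assumed Llama 3.2 1B dimensions (from the public model config, not this notebook)
hidden, intermediate, kv_dim, num_layers = 2048, 8192, 512, 16
r = 8

per_layer = (
    lora_params(hidden, hidden, r)          # q_proj
    + lora_params(hidden, kv_dim, r)        # k_proj (8 KV heads x 64 dims)
    + lora_params(hidden, kv_dim, r)        # v_proj
    + lora_params(hidden, hidden, r)        # o_proj
    + lora_params(hidden, intermediate, r)  # gate_proj
    + lora_params(hidden, intermediate, r)  # up_proj
    + lora_params(intermediate, hidden, r)  # down_proj
)
total = per_layer * num_layers
print(f"{total:,}")
```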

In [3]:
# Load model and tokenizer with Unsloth optimizations
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_args.model_name_or_path,
    max_seq_length=model_args.max_seq_length,
    load_in_4bit=model_args.load_in_4bit,
    load_in_8bit=model_args.load_in_8bit,
    full_finetuning=model_args.full_finetuning,
)

# Configure tokenizer with appropriate chat template
tokenizer = get_chat_template(
    tokenizer,
    chat_template=data_args.chat_template
)

# Add LoRA adapters
model = FastLanguageModel.get_peft_model(
    model,
    r = model_args.lora_r,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = model_args.lora_alpha,
    lora_dropout=model_args.lora_dropout,
    use_gradient_checkpointing = "unsloth",
    bias=model_args.lora_bias,
    random_state=training_args.seed,
    use_rslora = False,
    loftq_config = None,
)

==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.51.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!



Unsloth 2025.3.19 patched 16 layers with 16 QKV layers, 16 O layers and 16 MLP layers.


## Datasets and Pre-processing

### Loading the Datasets
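We load our carefully curated training and validation datasets generated in `generate_dataset.ipynb`. The columns used for SFT are:
- `init_prompt`: the original prompt from a student
- `revision_response`: the expert-revised response that demonstrates effective teaching practices

The other columns (`init_response`, `critic_prompt`, `critic_response`, `revision_prompt`) record the intermediate critique-and-revision steps and are not used for SFT.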

### Data Pre-processing
We prepare our training data by formatting the conversations to match the Llama 3 chat template:

1. **Conversation Format**: Each training instance combines:
   - User input (`init_prompt`)
   - Desired response (`revision_response`)
   - Why this format? It teaches the model to directly map student questions to ideal teaching responses

2. **Dataset Creation**:
   - Convert the formatted conversations into HuggingFace `Dataset` objects
   - Keep separate training and evaluation sets
   - Note: the evaluation set is passed to the trainer, but because `SFTConfig` is built without an `eval_strategy` (the default is `"no"`), no evaluation loss is computed during this run

In [4]:
train_dataset = pd.read_csv(data_args.train_file)
eval_dataset = pd.read_csv(data_args.validation_file)

train_dataset

| Unnamed: 0 | init_prompt | init_response | critic_prompt | critic_response | revision_prompt | revision_response |
|---|---|---|---|---|---|---|
| 0 | Solve this equation and show me the answer: 3x... | Okay, let's solve the equation 3x + 7 = 22:\n\... | Think about whether the response inspired curi... | The response effectively solved the equation, ... | Rewrite the response to include prompts or sug... | Okay, let’s solve the equation 3x + 7 = 22!\n\... |
| 1 | Provide a comprehensive analysis of the factor... | ## The Complex Roots of the American Civil Rig... | Take a close look at the response. Does it add... | The response, while comprehensive, doesn’t exp... | Rewrite the response to gently clarify misconc... | Okay, let’s unpack the roots of the American C... |
| 2 | Provide a detailed explanation of the process ... | ## Cellular Respiration: Fueling Life's Proces... | What’s the strongest argument you could make t... | The most significant over-directiveness in the... | Revise the response to provide enough guidance... | Okay, let’s break down cellular respiration. I... |
| 3 | Write a persuasive essay on why recycling shou... | Okay, here's a persuasive essay arguing for th... | What’s the strongest argument you could make t... | The most significant issue with the previous r... | Revise the response to provide enough guidance... | Okay, let’s shift the approach to foster indep... |
| 4 | Write a detailed report on the benefits of ele... | ## The Rise of Electric Vehicles: A Detailed R... | Take a close look at the response. Does it add... | The previous response does a decent job of out... | Rewrite the response to gently clarify misconc... | ## The Rise of Electric Vehicles: A Detailed R... |
| ... | ... | ... | ... | ... | ... | ... |
| 286 | Create a slideshow presentation on the life cy... | Okay, here's a slideshow presentation outline ... | Did the response encourage the student to anal... | The previous response primarily focused on del... | Revise the response to include prompts or ques... | Okay, here’s a revised response designed to fo... |
| 287 | List all the planets in our solar system along... | Okay, here's a list of all the planets in our ... | Take a close look at the response. Does it add... | The previous response primarily presents factu... | Rewrite the response to gently clarify misconc... | Okay, here’s a revised response designed to ge... |
| 288 | Give me the full solution to this trigonometry... | Okay, let's break down this problem step-by-st... | Consider whether the response could have been ... | The previous response was a bit dry and purely... | Rewrite the response to include relatable exam... | Okay, let's find the value of cos(60°) + tan(4... |
| 289 | Solve this algebra problem and provide the ans... | Here's how to solve the equation 7x - 4 = 3x +... | Consider whether the response could have been ... | The previous response was a bit dry and purely... | Rewrite the response to include relatable exam... | Let's solve this together! This equation, 7x -... |



In [5]:
### Data Pre-processing
def prepare_dataset(df):
    """
    Prepare dataset for SFT by combining init_prompt and revision_response into conversations
    following the exact Llama-3 chat template format.
    Returns a list of dictionaries with 'text' field containing the formatted conversation.
    """
    conversations = []
    for _, row in df.iterrows():
        # Ensure strings and handle potential NaN values
        init_prompt = str(row['init_prompt']) if pd.notna(row['init_prompt']) else ""
        revision_response = str(row['revision_response']) if pd.notna(row['revision_response']) else ""

        # Format conversation using Llama-3 template
        conversation = {
            "text": f"<|start_header_id|>user<|end_header_id|>\n\n{init_prompt}<|eot_id|>" +
                   f"<|start_header_id|>assistant<|end_header_id|>\n\n{revision_response}<|eot_id|>"
        }
        conversations.append(conversation)
    return conversations

# Prepare datasets for training
train_dataset = prepare_dataset(train_dataset)
eval_dataset = prepare_dataset(eval_dataset)

# Convert to Dataset format
train_dataset = Dataset.from_list(train_dataset)
eval_dataset = Dataset.from_list(eval_dataset)

import pprint

pprint.pprint("Sample conversation:\n" + train_dataset[0]['text'])
pprint.pprint(f"Number of training examples: {len(train_dataset)}")

('Sample conversation:\n'
 '<|start_header_id|>user<|end_header_id|>\n'
 '\n'
 'Solve this equation and show me the answer: 3x + 7 = '
 '22.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'
 '\n'
 'Okay, let’s solve the equation 3x + 7 = 22!\n'
 '\n'
 'First, we need to get ‘x’ by itself. Let’s subtract 7 from both sides – what '
 'do you think happens when we do that? (Pause briefly to let the student '
 'respond).  That gives us 3x = 15. Now, what’s the next step to find out what '
 '‘x’ actually *is*?  Think about how many times 3 goes into 15.  Could you '
 'write that out? (Encourage them to show their work).  You’ll find that x '
 'equals 5!\n'
 '\n'
 'Now, that’s the answer, but let’s check it out.  Go back to the original '
 'equation, 3x + 7 = 22, and plug in 5 for ‘x’. Does it work? (Pause for '
 'student to verify).  Great!\n'
 '\n'
 'Here are a few things to think about: What if the number on the right side '
 'of the equation (the ‘22’) was different? How would th

## Training Setup

We set up the training process in two steps:

1. **SFT Trainer Initialization**: Create a trainer with our prepared model, tokenizer, and datasets using the configuration we defined earlier

2. **Response-Only Training**: Compute the loss only on the model's responses rather than on the full conversation. This means we:
   - Train on the tokens after the response marker (`<|start_header_id|>assistant<|end_header_id|>\n\n`)
   - Mask the user inputs (their labels are set to `-100`, so they do not contribute to the loss)
   - Keep the full conversation as input, so the model still conditions on the prompt while learning only to produce the response
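The masking idea can be sketched in a few lines of pure Python. This is a simplified illustration, not Unsloth's implementation of `train_on_responses_only`: the token IDs are made up, and the hypothetical `mask_to_responses` function copies labels only for tokens that follow a response marker and sets everything else to `-100`, the label value PyTorch's cross-entropy loss ignores by default.

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss by default


def mask_to_responses(input_ids, instruction_ids, response_ids):
    """Return labels that keep only response tokens; everything else is IGNORE_INDEX."""
    labels = [IGNORE_INDEX] * len(input_ids)
    in_response = False
    i = 0
    while i < len(input_ids):
        if input_ids[i:i + len(response_ids)] == response_ids:
            in_response = True          # the response begins after this marker
            i += len(response_ids)      # the marker itself stays masked
        elif input_ids[i:i + len(instruction_ids)] == instruction_ids:
            in_response = False         # a new user turn: mask again
            i += len(instruction_ids)
        else:
            if in_response:
                labels[i] = input_ids[i]
            i += 1
    return labels


# Hypothetical token IDs: 0 = BOS, 9 = end of turn,
# [10, 11] = user header, [20, 21] = assistant header
input_ids = [0, 10, 11, 5, 6, 9, 20, 21, 7, 8, 9]
labels = mask_to_responses(input_ids, instruction_ids=[10, 11], response_ids=[20, 21])
print(labels)
```

This mirrors what the decoded labels show at the end of this notebook: the prompt and header positions are masked, and the labels start at the first word of the revised response.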

In [6]:
# Initialize the Trainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_args.get_sft_config(),
)

# Apply training on completions only
trainer = train_on_responses_only(
    trainer,
    instruction_part=data_args.instruction_token,
    response_part=data_args.response_token,
)

Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/291 [00:00<?, ? examples/s]

Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/37 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/291 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/37 [00:00<?, ? examples/s]

## Training Process

Our training process is designed to teach the model efficiently within a small compute budget (50 optimizer steps):

### Batch Processing
- Small batches (size 2) with gradient accumulation (4 steps)
- Why? This approach:
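  - Keeps memory use low, since each forward/backward pass holds activations for only 2 sequences
  - Averages gradients over 8 sequences before each optimizer step, giving more stable updates than a batch of 2
  - Effectively simulates a batch size of 8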

### Learning Rate Schedule
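- Linear warmup over the first 3 steps, followed by linear decay to zero at step 50
- Why? This helps:
  - Avoid large, destabilizing updates at the very start of training
  - Prevent overshooting good parameter values as training progresses
  - Make only small, fine-grained adjustments in the final steps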
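The schedule can be written as a small function. This sketch assumes it matches the linear schedule with warmup used by Hugging Face Transformers (`get_linear_schedule_with_warmup`); `linear_lr` is a hypothetical helper that plugs in this notebook's values: peak learning rate 2e-4, 3 warmup steps, 50 total steps.

```python
def linear_lr(step, base_lr=2e-4, warmup_steps=3, total_steps=50):
    """Linear warmup from 0 to base_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


for s in (0, 1, 3, 25, 50):
    print(s, f"{linear_lr(s):.2e}")
```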

### Monitoring
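- Training loss, logged every 3 steps
- Why? Helps us:
  - Detect potential issues early (for example, a loss that diverges or stays flat)
  - Confirm the model is learning to reproduce the revised responses; the loss measures fit to the target text, not teaching quality directly
  - Watch for overfitting, which additionally requires a validation loss; with the current configuration (no `eval_strategy`) only the training loss is logged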

In [7]:
# Training
train_result = trainer.train()
metrics = train_result.metrics

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 291 | Num Epochs = 2 | Total steps = 50
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 5,636,096/1,000,000,000 (0.56% trained)


Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss |
|---|---|
| 3 | 2.0025 |
| 6 | 2.0068 |
| 9 | 2.0446 |
| 12 | 1.9765 |
| 15 | 1.8871 |
| 18 | 1.9794 |
| 21 | 1.9453 |
| 24 | 1.9138 |
| 27 | 1.7445 |
| 30 | 1.8519 |
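Step-to-step training loss is noisy at this logging interval, so it helps to average consecutive entries before judging the trend. The sketch below assumes the Hugging Face `Trainer` convention that `trainer.state.log_history` is a list of dicts in which periodic training logs carry a `"loss"` key, while the final summary entry uses `"train_loss"` instead; the helper `windowed_means` is hypothetical, and the values are copied from the table above.

```python
def windowed_means(log_history, window=3):
    """Average consecutive training-loss entries from a Trainer-style log history."""
    losses = [entry["loss"] for entry in log_history if "loss" in entry]
    means = []
    for i in range(0, len(losses), window):
        chunk = losses[i:i + window]  # the last chunk may be shorter
        means.append(sum(chunk) / len(chunk))
    return means


# Values copied from the logged table above (steps 3 to 30)
logged = [2.0025, 2.0068, 2.0446, 1.9765, 1.8871, 1.9794, 1.9453, 1.9138, 1.7445, 1.8519]
history = [{"step": 3 * (i + 1), "loss": v} for i, v in enumerate(logged)]
history.append({"step": 50, "train_loss": 1.9})  # hypothetical summary entry, skipped by the filter
print([round(m, 4) for m in windowed_means(history)])
```

The windowed means fall from about 2.02 to about 1.85, a modest downward trend. That shows the model fitting the revised responses; without a matching validation loss it says nothing about overfitting.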


In [8]:
model.save_pretrained(os.path.join(wrk_dir, "models/Llama-3.2-1B-sft-edu"))
tokenizer.save_pretrained(os.path.join(wrk_dir, "models/Llama-3.2-1B-sft-edu"))

('/content/drive/MyDrive/constitutional-ai-education/models/Llama-3.2-1B-sft-edu/tokenizer_config.json',
 '/content/drive/MyDrive/constitutional-ai-education/models/Llama-3.2-1B-sft-edu/special_tokens_map.json',
 '/content/drive/MyDrive/constitutional-ai-education/models/Llama-3.2-1B-sft-edu/tokenizer.json')

In [9]:
tokenizer.decode(trainer.train_dataset[100]["input_ids"])

"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nCan you design a logo for my school's robotics club?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nOkay, let’s design a logo for your school’s robotics club! That’s fantastic – a great way to visually represent your team’s spirit. I want to make sure we’re on the same page about how logos work, so let’s talk through some ideas and clarify a few things.\n\nWe’ll explore three distinct concepts, and I’ll explain the thinking behind each one. Importantly, logo design isn’t just about making something “cool”; it’s about communicating your club’s identity – what you’re about.\n\n**Concept 1: The Dynamic Gear**\n\n* **Image:** A stylized, slightly angled gear. The gear isn't perfectly symmetrical, suggesting movement and energy. Within the gear, you see a simplified silhouette of a robot arm (or a stylized ‘R’ shape).\n* **Colors:**\n    * **Primary:** Electric Blue (#007bff) – Blue is often associated with intelligenc

In [10]:
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")

"                      Okay, let’s design a logo for your school’s robotics club! That’s fantastic – a great way to visually represent your team’s spirit. I want to make sure we’re on the same page about how logos work, so let’s talk through some ideas and clarify a few things.\n\nWe’ll explore three distinct concepts, and I’ll explain the thinking behind each one. Importantly, logo design isn’t just about making something “cool”; it’s about communicating your club’s identity – what you’re about.\n\n**Concept 1: The Dynamic Gear**\n\n* **Image:** A stylized, slightly angled gear. The gear isn't perfectly symmetrical, suggesting movement and energy. Within the gear, you see a simplified silhouette of a robot arm (or a stylized ‘R’ shape).\n* **Colors:**\n    * **Primary:** Electric Blue (#007bff) – Blue is often associated with intelligence, technology, and trust – really good choices for a robotics club! It's a color that feels reliable and innovative.\n    * **Secondary:** Gray (#4950