# SFT Training Loop

**Complete implementation with best practices**

## Training Pipeline Overview

```
1. Load pre-trained model and tokenizer
2. Prepare dataset with instruction formatting
3. Create data loader with proper collation
4. Setup optimizer and learning rate scheduler
5. Training loop with loss masking
6. Evaluation and checkpointing
7. Save fine-tuned model
```

In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm
import numpy as np

# Pick the compute device: CUDA when PyTorch can see a GPU, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


## Dataset Class

In [2]:
ALPACA_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
{response}"""

class SFTDataset(Dataset):
    """Dataset for supervised fine-tuning with prompt-masked labels.

    Each item is rendered with ``ALPACA_TEMPLATE``, tokenized, and padded to
    ``max_length``. Labels are a copy of the input ids in which every prompt
    position and every padding position is set to ``-100`` so that only the
    response tokens contribute to the loss.

    Args:
        data: Indexable collection of records with ``'instruction'`` and
            ``'output'`` keys.
        tokenizer: Callable tokenizer returning ``'input_ids'`` and
            ``'attention_mask'`` tensors of shape ``(1, seq_len)`` when
            called with ``return_tensors='pt'``.
        max_length: Length every example is truncated/padded to.
        append_eos: When true (default) and the tokenizer exposes a
            non-empty ``eos_token``, that token is appended to the response
            so the model is trained to emit an end-of-sequence marker.
            The marker is lost if the example is truncated.
    """

    def __init__(self, data, tokenizer, max_length=512, append_eos=True):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.append_eos = append_eos

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one record.

        If the prompt alone fills ``max_length`` tokens, every label ends up
        masked and the example contributes nothing to the loss.
        """
        item = self.data[idx]

        # Full training text: prompt followed by the reference response
        formatted = ALPACA_TEMPLATE.format(
            instruction=item['instruction'],
            response=item['output']
        )

        # Same template with an empty response; its token count marks where
        # the response starts inside the full sequence.
        prompt = ALPACA_TEMPLATE.format(
            instruction=item['instruction'],
            response=''
        )

        eos_token = getattr(self.tokenizer, 'eos_token', None)
        if self.append_eos and eos_token and not formatted.endswith(eos_token):
            formatted += eos_token

        full_tokens = self.tokenizer(
            formatted,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        # No padding here: only the real prompt length is needed
        prompt_tokens = self.tokenizer(
            prompt,
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt'
        )

        input_ids = full_tokens['input_ids'].squeeze(0)
        attention_mask = full_tokens['attention_mask'].squeeze(0)
        response_start = prompt_tokens['input_ids'].shape[1]

        # NOTE(review): indexing from position 0 assumes the tokenizer pads on
        # the right - confirm if a left-padding tokenizer is ever used.
        labels = input_ids.clone()
        labels[:response_start] = -100  # Mask prompt tokens

        # Mask padding via the attention mask rather than by token id: when
        # the pad token is the same id as EOS, comparing ids would also mask
        # genuine EOS tokens inside the response.
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

## Training Configuration

In [3]:
from dataclasses import dataclass

@dataclass
class SFTConfig:
    """Configuration for SFT training.

    Values are validated on construction so that an invalid setting fails
    immediately instead of part-way through a training run.

    Raises:
        ValueError: If a count that must be at least 1 is smaller, if
            ``num_epochs`` or ``warmup_steps`` is negative, or if
            ``learning_rate`` is not positive.
    """
    model_name: str = "gpt2"  # checkpoint identifier
    max_length: int = 512  # tokens per example after truncation/padding
    batch_size: int = 4  # examples per micro-batch
    learning_rate: float = 2e-4  # peak learning rate passed to the optimizer
    num_epochs: int = 3  # full passes over the training set
    warmup_steps: int = 100  # scheduler warmup, counted in optimizer steps
    gradient_accumulation_steps: int = 4  # micro-batches per optimizer step
    max_grad_norm: float = 1.0  # gradient clipping threshold
    logging_steps: int = 10  # optimizer steps between progress-bar updates
    eval_steps: int = 100  # not read by train_sft in this file
    save_steps: int = 500  # not read by train_sft in this file
    output_dir: str = "./sft_output"  # checkpoints are written beneath this path

    def __post_init__(self):
        # These are used as divisors/moduli or sizes; zero would crash later
        for name in ("max_length", "batch_size",
                     "gradient_accumulation_steps", "logging_steps"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")

        # Zero is meaningful here (no training / no warmup); negatives are not
        for name in ("num_epochs", "warmup_steps"):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"{name} must be >= 0, got {value}")

        if self.learning_rate <= 0:
            raise ValueError(
                f"learning_rate must be > 0, got {self.learning_rate}"
            )

# Build the default configuration and echo each field, one per line
config = SFTConfig()
print("Training configuration:")
for field_name in vars(config):
    field_value = getattr(config, field_name)
    print(f"  {field_name}: {field_value}")

Training configuration:
  model_name: gpt2
  max_length: 512
  batch_size: 4
  learning_rate: 0.0002
  num_epochs: 3
  warmup_steps: 100
  gradient_accumulation_steps: 4
  max_grad_norm: 1.0
  logging_steps: 10
  eval_steps: 100
  save_steps: 500
  output_dir: ./sft_output


## The Training Loop

In [4]:
def train_sft(model, tokenizer, train_dataset, eval_dataset, config):
    """Fine-tune ``model`` on ``train_dataset`` with gradient accumulation.

    Runs ``config.num_epochs`` epochs of AdamW with a linear warmup/decay
    schedule, clips gradients to ``config.max_grad_norm``, evaluates on
    ``eval_dataset`` at the end of every epoch, and saves the model and
    tokenizer to ``{config.output_dir}/best`` whenever the eval loss improves.

    ``config.eval_steps`` and ``config.save_steps`` are not used here;
    evaluation and checkpointing happen once per epoch.

    Args:
        model: Model whose forward pass accepts ``input_ids``,
            ``attention_mask`` and ``labels`` and returns an object with a
            ``loss`` attribute. Batches are moved to the device of its
            parameters.
        tokenizer: Tokenizer saved alongside the best checkpoint.
        train_dataset: Dataset yielding dicts of tensors for training.
        eval_dataset: Dataset yielding dicts of tensors for evaluation.
        config: Object exposing the hyper-parameters read below.

    Returns:
        The same ``model`` instance, after training.
    """
    accum_steps = config.gradient_accumulation_steps

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0
    )

    eval_loader = DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        shuffle=False
    )

    # Follow the model's own placement instead of relying on a global
    device = next(model.parameters()).device

    # Setup optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=0.01
    )

    # Learning rate scheduler. A trailing group of fewer than `accum_steps`
    # micro-batches still gets its own optimizer step, so round up per epoch.
    num_batches = len(train_loader)
    tail_size = num_batches % accum_steps
    updates_per_epoch = (num_batches + accum_steps - 1) // accum_steps
    total_steps = updates_per_epoch * config.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=total_steps
    )

    model.train()
    optimizer.zero_grad()  # start from clean gradients
    global_step = 0
    best_eval_loss = float('inf')

    for epoch in range(config.num_epochs):
        epoch_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")

        for step, batch in enumerate(progress_bar):
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}

            # Forward pass
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']
            )

            # Average over the micro-batches that actually share this
            # optimizer step; the final group of an epoch may be smaller.
            in_tail = tail_size > 0 and step >= num_batches - tail_size
            group_size = tail_size if in_tail else accum_steps
            loss = outputs.loss / group_size

            # Backward pass
            loss.backward()

            epoch_loss += outputs.loss.item()

            # Step after every full group and also on the last micro-batch,
            # so leftover gradients never leak into the next epoch or get
            # dropped at the end of training.
            is_last_batch = (step + 1) == num_batches
            if (step + 1) % accum_steps == 0 or is_last_batch:
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    config.max_grad_norm
                )

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                global_step += 1

                # Logging
                if global_step % config.logging_steps == 0:
                    avg_loss = epoch_loss / (step + 1)
                    progress_bar.set_postfix({
                        'loss': f'{avg_loss:.4f}',
                        'lr': f'{scheduler.get_last_lr()[0]:.2e}'
                    })

        # End of epoch evaluation
        eval_loss = evaluate(model, eval_loader, device)
        # max(..., 1) keeps the summary printable for an empty loader
        train_loss = epoch_loss / max(num_batches, 1)
        print(f"\nEpoch {epoch+1} - Train Loss: {train_loss:.4f}, Eval Loss: {eval_loss:.4f}")

        # Save best model
        if eval_loss < best_eval_loss:
            best_eval_loss = eval_loss
            model.save_pretrained(f"{config.output_dir}/best")
            tokenizer.save_pretrained(f"{config.output_dir}/best")
            print(f"Saved best model with eval loss: {eval_loss:.4f}")

    return model


def evaluate(model, eval_loader, device):
    """Return the mean per-batch loss of ``model`` over ``eval_loader``.

    The model is switched to eval mode for the duration of the call and is
    put back into whatever mode it was in beforehand, even if a batch raises.

    Args:
        model: Module whose forward pass accepts ``input_ids``,
            ``attention_mask`` and ``labels`` and returns an object with a
            ``loss`` attribute.
        eval_loader: Iterable of dict batches mapping names to tensors.
            It does not need to support ``len()``.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        The average of the per-batch losses as a float, or ``nan`` when the
        loader yields no batches.
    """
    was_training = model.training
    model.eval()

    total_loss = 0.0
    num_batches = 0

    try:
        with torch.no_grad():
            for batch in eval_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels']
                )
                total_loss += outputs.loss.item()
                num_batches += 1
    finally:
        # Restore the caller's mode rather than forcing training mode
        model.train(was_training)

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

## Running Training

In [5]:
# Fetch the pre-trained checkpoint together with its tokenizer
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Reuse the EOS token as the padding token and tell the model about it
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

model.to(device)

# Keep only the first 1000 training records so the demo stays small
raw_data = load_dataset("yahma/alpaca-cleaned", split="train").select(range(1000))

# 90/10 split by position: leading records train, trailing records evaluate
train_size = int(0.9 * len(raw_data))
train_data = raw_data.select(range(train_size))
eval_data = raw_data.select(range(train_size, len(raw_data)))

# Wrap both splits with the same tokenizer and sequence length
train_dataset, eval_dataset = (
    SFTDataset(split, tokenizer, max_length=256)
    for split in (train_data, eval_data)
)

print(f"Train samples: {len(train_dataset)}")
print(f"Eval samples: {len(eval_dataset)}")

Train samples: 900
Eval samples: 100


In [6]:
# Train! (uncomment to run)
# config.num_epochs = 1  # Quick test
# model = train_sft(model, tokenizer, train_dataset, eval_dataset, config)

## Generating with the Fine-Tuned Model

In [7]:
def generate_response(model, tokenizer, instruction, max_new_tokens=100):
    """Sample a response to ``instruction`` from ``model``.

    The instruction is rendered with ``ALPACA_TEMPLATE`` (empty response),
    and up to ``max_new_tokens`` tokens are sampled with temperature 0.7 and
    nucleus ``top_p=0.9``. Only the newly generated tokens are decoded, so
    the returned text is correct even if the model itself emits another
    "### Response:" marker.

    Generation runs in eval mode; the model's previous train/eval mode is
    restored afterwards.

    Args:
        model: Causal language model exposing ``generate`` whose output
            sequence starts with the prompt tokens.
        tokenizer: Tokenizer matching ``model``.
        instruction: Instruction text inserted into the template.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated continuation with surrounding whitespace stripped.
    """
    prompt = ALPACA_TEMPLATE.format(instruction=instruction, response='')

    # Place the inputs wherever the model's weights live
    target_device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors='pt').to(target_device)
    prompt_length = inputs['input_ids'].shape[1]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id
            )
    finally:
        model.train(was_training)

    # Drop the echoed prompt by position instead of splitting decoded text
    generated_ids = outputs[0][prompt_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return response.strip()

# Try the model on one instruction and show both prompt and reply
test_instruction = "Explain what machine learning is in simple terms."
print(f"Instruction: {test_instruction}")
generated_text = generate_response(model, tokenizer, test_instruction)
print(f"Response: {generated_text}")

Instruction: Explain what machine learning is in simple terms.


  attn_output = torch.nn.functional.scaled_dot_product_attention(
  attn_output = torch.nn.functional.scaled_dot_product_attention(


Response: How to answer the question.

### Method:

How to answer the question.

### Message:

A list of commands.

### Instruction:

Explain what machine learning is


## Next Steps

Now that we have a complete SFT training loop, let's learn about LoRA, which makes fine-tuning more efficient by training far fewer parameters.