# Week 5.2: Instruction Fine-Tuning with PyTorch

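**Resource Required**: GPU with at least 24GB VRAM. The settings below may need more (a 40GB or 80GB card, or a smaller `batch_size`): the Llama-2-7B weights alone take roughly 13.5GB in bf16, before gradients, Adam states, and activations for the trainable layers.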

## Objective:
* Understand the concept of instruction fine-tuning (IFT)
* Learn how to prepare data for instruction fine-tuning
* Implement a pure PyTorch training loop for fine-tuning a pre-trained language model
* Monitor training metrics and evaluate the fine-tuned model

💡 **NOTE**: In this notebook, we'll implement instruction fine-tuning using pure PyTorch. Later in the course, we'll demonstrate how to do the same with HuggingFace libraries, which can simplify the process significantly.

## Introduction: What is Instruction Fine-Tuning?

Instruction Fine-Tuning (IFT) is the technique that turns a base language model into a model that follows human instructions. It is a key step in building assistants such as ChatGPT, Claude, and Gemini, which typically add further alignment stages (such as RLHF) on top.

### Key concepts:

1. **Base Models vs. Instruction-Tuned Models**:
   - **Base models** are pre-trained on a large corpus of text, but only learn to predict the next token in a sequence
   - **Instruction-tuned models** are further trained, still with the next-token objective, on instruction-response examples, so they learn to understand and follow instructions

2. **Instruction Tuning Process**:
   - Start with a pre-trained language model
   - Fine-tune using a dataset of instruction-response pairs
   - Train the model to generate helpful responses to instructions

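3. **Benefits**:
   - Makes models far more useful as assistants: they answer the request instead of merely continuing the text
   - Adds instruction-following with a comparatively small dataset (tens of thousands of examples) and a short training run relative to pre-training
   - Reduces off-topic or nonsensical continuations; carefully curated data can also reduce harmful outputs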

## Setup and Dependencies

Let's install the necessary packages for this notebook:

In [None]:
# Install required packages
# accelerate is required for device_map="auto" when loading the model
!pip install transformers==4.40.1 accelerate torch==2.2.0 tqdm==4.66.1

Import the libraries we'll need:

In [None]:
import json
import random
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    default_data_collator,
    get_cosine_schedule_with_warmup,
    GenerationConfig
)
from tqdm.auto import tqdm
from types import SimpleNamespace
from pathlib import Path

Check if CUDA is available and set the device:

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set seed for reproducibility
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

## 1. Prepare the Instruction Dataset

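We'll use the GPT-4 version of the Alpaca dataset (`alpaca_gpt4_data.json`): the roughly 52K Alpaca instructions and optional inputs, paired with responses generated by GPT-4. Each record has `instruction`, `input`, and `output` fields, which makes it a convenient dataset for instruction tuning.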

In [None]:
# The dataset is already downloaded at the following path
dataset_file = "data/alpaca_gpt4_data.json"

# Load the dataset
with open(dataset_file, "r") as f:
    alpaca = json.load(f)

In [None]:
# Inspect the dataset structure
print(f"Total examples: {len(alpaca)}")
print("\nExample data point:")
print(json.dumps(alpaca[0], indent=2))
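Because the prompt template depends on whether `input` is empty, it can help to know how the data splits between the two cases. A minimal sketch; `count_input_types` is a hypothetical helper (not part of the original notebook), shown here on a tiny made-up sample:

```python
from collections import Counter

def count_input_types(records):
    """Count Alpaca-style records with and without a non-empty `input` field."""
    counts = Counter("with_input" if r.get("input", "") != "" else "no_input" for r in records)
    return dict(counts)

# Hypothetical sample records; in the notebook you would pass `alpaca` instead
sample = [
    {"instruction": "Name a color.", "input": "", "output": "Blue."},
    {"instruction": "Translate to French.", "input": "Hello", "output": "Bonjour"},
    {"instruction": "Give a synonym for happy.", "input": "", "output": "Joyful."},
]
print(count_input_types(sample))
```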

### Split the Dataset into Train and Evaluation Sets

In [None]:
# Shuffle the dataset
random.seed(SEED)
random.shuffle(alpaca)

# Split into train and eval
train_dataset = alpaca[:-1000]  # Use all but the last 1000 examples for training
eval_dataset = alpaca[-1000:]   # Use the last 1000 examples for evaluation

print(f"Training examples: {len(train_dataset)}")
print(f"Evaluation examples: {len(eval_dataset)}")

### Format the Prompts for Instruction Fine-Tuning

We format every example with the standard Alpaca prompt template. There are two variants: one for examples whose `input` field is empty, and one that adds an `### Input:` section for examples that provide extra context.
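To make the template concrete, here is a sketch of one complete training string for a made-up row without input. The template text matches `prompt_no_input` below; the row itself is hypothetical:

```python
# Same template text as the notebook's prompt_no_input
TEMPLATE_NO_INPUT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:\n"
)

# Hypothetical toy row in the Alpaca format
row = {"instruction": "List three primary colors.", "input": "", "output": "Red, yellow, and blue."}

prompt = TEMPLATE_NO_INPUT.format_map(row)       # extra keys such as "output" are ignored
training_text = prompt + row["output"] + "</s>"  # the model learns to continue after "### Response:\n" and stop
print(training_text)
```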

In [None]:
def prompt_no_input(row):
    """Format a prompt for examples without additional input."""
    return ("Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Response:\n").format_map(row)

def prompt_with_input(row):
    """Format a prompt for examples with additional input context."""
    return ("Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n").format_map(row)

def create_alpaca_prompt(row):
    """Create a prompt based on whether the example has input or not."""
    return prompt_no_input(row) if row["input"] == "" else prompt_with_input(row)

In [None]:
# After shuffling, alpaca[0] may or may not have an input, so search for one of each
example_no_input = next((ex for ex in alpaca if ex["input"] == ""), None)
example_with_input = next((ex for ex in alpaca if ex["input"] != ""), None)

if example_no_input:
    print("Example prompt without input:")
    print(prompt_no_input(example_no_input))

if example_with_input:
    print("\nExample prompt with input:")
    print(prompt_with_input(example_with_input))

### Prepare the Dataset for Training

We need to:
1. Format all examples with our prompt templates
2. Append EOS token to all outputs
3. Combine prompts and outputs for training

In [None]:
# Generate prompts for all examples
train_prompts = [create_alpaca_prompt(row) for row in train_dataset]
eval_prompts = [create_alpaca_prompt(row) for row in eval_dataset]

# Helper function to append the EOS token to every output, so the model learns where a response ends
def pad_eos(dataset):
    EOS_TOKEN = "</s>"  # LLaMA's EOS token; the tokenizer maps this literal string to the EOS id
    return [f"{row['output']}{EOS_TOKEN}" for row in dataset]

# Add EOS token to outputs
train_outputs = pad_eos(train_dataset)
eval_outputs = pad_eos(eval_dataset)

# Combine prompts and outputs for training
train_examples = [{"prompt": p, "output": o, "combined": p + o} 
                  for p, o in zip(train_prompts, train_outputs)]
eval_examples = [{"prompt": p, "output": o, "combined": p + o} 
                 for p, o in zip(eval_prompts, eval_outputs)]

In [None]:
# Show an example of the complete instruction-response pair
print("Complete instruction-response example:")
print(train_examples[0]["combined"])

## 2. Tokenization and Dataset Preparation

We'll now load the tokenizer and convert our text data into tokenized inputs for the model.

In [None]:
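# Define the model to use. Llama-2-7B is a gated model: accept its license on the
# Hugging Face Hub and authenticate (e.g. `huggingface-cli login`) before downloading it.
model_id = 'meta-llama/Llama-2-7b-hf'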

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Set padding token to EOS token

### Implement Data Packing

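To train more efficiently, we'll pack multiple examples into fixed-length sequences of `max_sequence_len` tokens. Packing removes padding, so every token in a batch contributes to training. This simple version comes with two simplifications: examples packed into the same sequence can attend to one another, and the loss is computed on every token, prompt included, not only on the response. The final incomplete chunk is dropped.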
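The packing logic can be seen on toy data. This sketch uses the same chunking rule as `pack_examples` below (chunks of `max_seq_len + 1` tokens, labels shifted left by one, incomplete tail dropped), but operates on made-up token ids instead of real tokenizer output; `pack_token_ids` is a hypothetical name:

```python
def pack_token_ids(tokenized_examples, max_seq_len):
    """Concatenate token-id lists and cut them into (input_ids, labels) chunks."""
    all_ids = [tok for ex in tokenized_examples for tok in ex]
    packed = []
    for i in range(0, len(all_ids), max_seq_len + 1):
        chunk = all_ids[i : i + max_seq_len + 1]
        if len(chunk) == max_seq_len + 1:  # drop the incomplete tail
            # labels[t] is the token that should follow input_ids[t]
            packed.append({"input_ids": chunk[:-1], "labels": chunk[1:]})
    return packed

# Three toy "examples" of fake token ids (1 = BOS, 2 = EOS, as in the LLaMA tokenizer)
examples = [[1, 10, 11, 2], [1, 20, 21, 22, 2], [1, 30, 2]]
for seq in pack_token_ids(examples, max_seq_len=4):
    print(seq)
```

Note that the second chunk starts in the middle of an example (its BOS token ended up as the last label of the first chunk), and the trailing `[30, 2]` is discarded.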

In [None]:
# Define maximum sequence length
max_sequence_len = 1024

def pack_examples(dataset, max_seq_len=max_sequence_len):
    """Pack multiple examples into fixed-length sequences."""
    # Tokenize all examples
    tokenized_inputs = tokenizer([ex["combined"] for ex in dataset])["input_ids"]
    
    # Concatenate all tokenized inputs
    all_token_ids = []
    for tokenized_input in tokenized_inputs:
        all_token_ids.extend(tokenized_input)
    
    print(f"Total number of tokens: {len(all_token_ids)}")
    
    # Pack tokens into fixed-length sequences
    packed_dataset = []
    for i in range(0, len(all_token_ids), max_seq_len+1):
        input_ids = all_token_ids[i : i + max_seq_len+1]
        if len(input_ids) == (max_seq_len+1):
            # Create input_ids and labels (shifted by 1 for next-token prediction)
            packed_dataset.append({
                "input_ids": input_ids[:-1], 
                "labels": input_ids[1:]
            })
    
    return packed_dataset

In [None]:
# Pack the datasets
train_ds_packed = pack_examples(train_examples)
eval_ds_packed = pack_examples(eval_examples)

print(f"Number of packed training sequences: {len(train_ds_packed)}")
print(f"Number of packed evaluation sequences: {len(eval_ds_packed)}")

### Create DataLoaders

DataLoaders provide batched data during training.

In [None]:
batch_size = 8  # Adjust based on your GPU memory

train_dataloader = DataLoader(
    train_ds_packed,
    batch_size=batch_size,
    collate_fn=default_data_collator,
    shuffle=True,
)

eval_dataloader = DataLoader(
    eval_ds_packed,
    batch_size=batch_size,
    collate_fn=default_data_collator,
    shuffle=False,
)

Let's inspect a batch to verify our data preparation:

In [None]:
# Examine a batch
sample_batch = next(iter(train_dataloader))
print(f"Batch keys: {sample_batch.keys()}")
print(f"Input shape: {sample_batch['input_ids'].shape}")
print(f"Labels shape: {sample_batch['labels'].shape}")

# Decode the first example in the batch to see what it looks like
print("\nSample input text (first 250 chars):")
print(tokenizer.decode(sample_batch["input_ids"][0])[:250])

## 3. Model Setup and Training Configuration

Now we'll load the pre-trained model and set up our training configuration.

In [None]:
# Define training configuration
config = SimpleNamespace(
    model_id='meta-llama/Llama-2-7b-hf',
    precision="bf16",         # bf16 is faster and uses less memory than fp32
    n_freeze=24,              # Number of layers to freeze (out of 32 for LLaMA-7B)
    learning_rate=2e-4,
    n_eval_samples=5,         # Number of samples to generate during evaluation
    max_seq_len=max_sequence_len,
    epochs=3,
    gradient_accumulation_steps=2,  # Simulate larger batch sizes
    batch_size=batch_size,
    gradient_checkpointing=True,    # Save memory at the cost of speed
    freeze_embeddings=True,         # Freeze the embedding layer
    seed=SEED,
)

# Calculate total optimizer steps: `step` restarts every epoch, so count full accumulation windows per epoch
config.total_train_steps = config.epochs * (len(train_dataloader) // config.gradient_accumulation_steps)
print(f"Total training steps: {config.total_train_steps}")
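The training loop steps the optimizer when `(step + 1) % gradient_accumulation_steps == 0`, and `step` restarts at zero every epoch, so each epoch contributes `floor(batches / accumulation_steps)` updates. A small arithmetic sketch with hypothetical batch counts shows that flooring once over all epochs can overcount when the batch count is not divisible by the accumulation steps:

```python
def optimizer_steps(num_batches_per_epoch, epochs, grad_accum_steps):
    """Number of optimizer updates when the step counter resets every epoch."""
    return epochs * (num_batches_per_epoch // grad_accum_steps)

# Hypothetical numbers: 1,001 batches per epoch, 3 epochs, accumulation of 2
per_epoch_floor = optimizer_steps(1001, 3, 2)
single_floor = 3 * 1001 // 2
print(per_epoch_floor, single_floor)
```

If the scheduler is told about more steps than actually happen, the cosine decay simply never quite reaches zero; the difference is small but easy to avoid.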

### Load the Pre-trained Model

In [None]:
# Load the model
model = AutoModelForCausalLM.from_pretrained(
    config.model_id,
    device_map="auto",  # Automatically determine best device mapping
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    use_cache=False,  # Disable KV cache for training
)

In [None]:
# Helper function to count parameters
def param_count(model):
    total_params = sum(p.numel() for p in model.parameters()) / 1_000_000
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
    print(f"Total params: {total_params:.2f}M, Trainable: {trainable_params:.2f}M")
    return total_params, trainable_params

# Count parameters before freezing
print("Parameter count before freezing:")
params, trainable_params = param_count(model)

### Freeze Parts of the Model

To reduce the computational and memory requirements, we'll freeze most of the model and only fine-tune the last `32 - n_freeze = 8` transformer layers plus the output (LM head) layer.
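As a rough cross-check for `param_count`, the parameter counts can be computed by hand. This sketch assumes the shape values from Llama-2-7B's published model config (hidden size 4096, MLP size 11008, vocabulary 32000, 32 layers, untied LM head); they are not defined anywhere in this notebook:

```python
# Llama-2-7B shape values (assumed from its published config)
HIDDEN = 4096
INTERMEDIATE = 11008
VOCAB = 32000
N_LAYERS = 32
N_FREEZE = 24  # matches config.n_freeze

def decoder_layer_params(hidden, intermediate):
    attention = 4 * hidden * hidden       # q, k, v, o projections (no bias)
    mlp = 3 * hidden * intermediate       # gate, up, down projections (no bias)
    norms = 2 * hidden                    # two RMSNorm weight vectors
    return attention + mlp + norms

per_layer = decoder_layer_params(HIDDEN, INTERMEDIATE)
embeddings = VOCAB * HIDDEN
lm_head = VOCAB * HIDDEN                  # separate from the embeddings in Llama-2
final_norm = HIDDEN

total = embeddings + N_LAYERS * per_layer + final_norm + lm_head
trainable = (N_LAYERS - N_FREEZE) * per_layer + lm_head

print(f"per layer: {per_layer:,}")
print(f"total:     {total:,}")
print(f"trainable: {trainable:,} ({trainable / total:.1%})")
```

Roughly a quarter of the model stays trainable, and Adam keeps two extra state tensors for each of those parameters, which is where most of the memory beyond the frozen weights goes.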

In [None]:
# Freeze all parameters first
for param in model.parameters():
    param.requires_grad = False
    
# Unfreeze the output layer (LM head)
for param in model.lm_head.parameters():
    param.requires_grad = True
    
# Unfreeze the last N transformer layers
for param in model.model.layers[config.n_freeze:].parameters():
    param.requires_grad = True

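# Keep the embeddings frozen (already frozen above; this makes the choice explicit)
if config.freeze_embeddings:
    model.model.embed_tokens.weight.requires_grad_(False)

# Enable gradient checkpointing to save memory. Use the non-reentrant variant: with the
# reentrant default, the unfrozen layers receive inputs that do not require grad (they come
# from frozen layers), so their gradients would silently stay None.
if config.gradient_checkpointing:
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})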

# Count parameters after freezing
print("Parameter count after freezing:")
params, trainable_params = param_count(model)

### Setup Optimizer and Learning Rate Scheduler

In [None]:
# Set up the optimizer over the trainable parameters only
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=config.learning_rate,
    betas=(0.9, 0.99),
    eps=1e-5
)

# Set up the learning rate scheduler
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_training_steps=config.total_train_steps,
    num_warmup_steps=config.total_train_steps // 10,  # 10% of steps for warmup
)
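The scheduler multiplies the base learning rate by a factor that rises linearly during warmup and then follows a half cosine down to zero. The function below is a sketch intended to mirror the multiplier `get_cosine_schedule_with_warmup` uses (with its default `num_cycles=0.5`), written in plain Python so the shape is easy to inspect; the step counts are hypothetical:

```python
import math

def cosine_with_warmup_factor(step, num_warmup_steps, num_training_steps, num_cycles=0.5):
    """LR multiplier: linear warmup from 0 to 1, then cosine decay to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

# Hypothetical run: 1,000 optimizer steps, 10% warmup, peak LR 2e-4
peak_lr, total_steps, warmup = 2e-4, 1000, 100
for s in (0, 50, 100, 550, 1000):
    print(s, peak_lr * cosine_with_warmup_factor(s, warmup, total_steps))
```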

In [None]:
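# Define the loss function. The labels were already shifted by one during packing,
# so logits[:, t] is compared directly with labels[:, t] (no further shift here).
def loss_fn(logits, labels):
    """Cross-entropy loss for next-token prediction; positions labelled -100 are ignored."""
    return torch.nn.functional.cross_entropy(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))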
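Because the packed data computes the loss on prompt tokens too, a common variant trains only on the response. `F.cross_entropy` skips targets equal to `-100` by default (`ignore_index=-100`), so prompt positions can simply be relabelled. The sketch below assumes unpacked data with one example per row and a known prompt length per row; `masked_lm_loss` and `prompt_lengths` are hypothetical, not part of this notebook's pipeline:

```python
import torch
import torch.nn.functional as F

def masked_lm_loss(logits, labels, prompt_lengths):
    """Cross-entropy that ignores prompt positions (labels already shifted by one).

    prompt_lengths[b] is the number of leading label positions in row b that
    belong to the prompt; they are set to -100 so cross_entropy skips them.
    """
    labels = labels.clone()
    for b, n in enumerate(prompt_lengths):
        labels[b, :n] = -100
    return F.cross_entropy(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))

torch.manual_seed(0)
logits = torch.randn(2, 5, 7)           # (batch, seq_len, vocab), random toy values
labels = torch.randint(0, 7, (2, 5))
loss = masked_lm_loss(logits, labels, prompt_lengths=[2, 3])
print(loss)
```

The `labels != -100` mask in `TokenAccuracy` below is what would keep such masked positions out of the accuracy as well.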

## 4. Model Evaluation Utilities

Let's create functions to evaluate our model during training.
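The evaluation loss is a mean per-token cross-entropy in nats, which is often reported as perplexity, its exponential. A one-line sketch (the `perplexity` helper is hypothetical, not used elsewhere in the notebook):

```python
import math

def perplexity(mean_token_loss):
    """Perplexity corresponding to a mean per-token cross-entropy (natural log)."""
    return math.exp(mean_token_loss)

print(perplexity(2.0))
```

A perplexity of 10 means the model is, on average, about as uncertain as a uniform choice among 10 tokens.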

In [None]:
# Setup generation configuration
gen_config = GenerationConfig.from_pretrained(config.model_id)
gen_config.max_new_tokens = 256
gen_config.temperature = 0.7
gen_config.top_p = 0.9
gen_config.do_sample = True

In [None]:
def generate_response(prompt, max_new_tokens=256):
    """Generate a response from the model for a given prompt."""
    # Tokenize the prompt; keep the attention mask so generate() does not have to infer it
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    # Generate a response
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            generation_config=gen_config
        )
    
    # Decode and return only the new tokens (not the prompt)
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

In [None]:
# Helper function to move batches to the device
def to_device(batch, device):
    """Move a batch of tensors to the specified device."""
    return {k: v.to(device) for k, v in batch.items()}

In [None]:
# Create a simple token-level accuracy metric
class TokenAccuracy:
    """Track token-level prediction accuracy."""
    def __init__(self):
        self.count = 0
        self.correct = 0.0
    
    def update(self, logits, labels):
        """Update accuracy with predictions from a batch."""
        predictions = logits.argmax(dim=-1).view(-1).cpu()
        labels = labels.view(-1).cpu()
        
        # Ignore positions labelled -100 (none exist with packed data, where every position has a real label)
        mask = labels != -100
        filtered_predictions = predictions[mask]
        filtered_labels = labels[mask]
        
        # Convert to a Python number so the running totals and returned values are plain floats
        correct = (filtered_predictions == filtered_labels).sum().item()
        self.count += len(filtered_labels)
        self.correct += correct
        
        # Return batch accuracy
        return correct / len(filtered_labels) if len(filtered_labels) > 0 else 0.0
    
    def compute(self):
        """Compute the overall accuracy."""
        return self.correct / self.count if self.count > 0 else 0.0

In [None]:
@torch.no_grad()
def evaluate_model(model, dataloader, device, num_samples=5):
    """Evaluate the model on the validation set and generate sample outputs."""
    model.eval()
    eval_accuracy = TokenAccuracy()
    total_loss = 0.0
    num_batches = 0
    
    # Compute loss and accuracy on validation set
    for batch in tqdm(dataloader, desc="Evaluating"):
        batch = to_device(batch, device)
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
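            # Labels are pre-shifted, so compute the loss ourselves instead of passing them to the model
            outputs = model(input_ids=batch["input_ids"])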
            loss = loss_fn(outputs.logits, batch["labels"])
        
        total_loss += loss.item()
        eval_accuracy.update(outputs.logits, batch["labels"])
        num_batches += 1
    
    # Generate sample outputs
    samples = []
    for i in range(min(num_samples, len(eval_dataset))):
        prompt = create_alpaca_prompt(eval_dataset[i])  # raw Alpaca records have no "prompt" key
        target = eval_dataset[i]["output"]
        generated = generate_response(prompt)
        samples.append({
            "prompt": prompt,
            "target": target,
            "generated": generated
        })
    
    # Return metrics and samples
    metrics = {
        "eval_loss": total_loss / num_batches,
        "eval_accuracy": eval_accuracy.compute()
    }
    
    # Set model back to training mode
    model.train()
    
    return metrics, samples

Let's test our model's initial generation capabilities before fine-tuning:

In [None]:
# Check model's initial generation (before fine-tuning)
test_prompt = create_alpaca_prompt(eval_dataset[0])  # raw Alpaca records have no "prompt" key
print("Test prompt:")
print(test_prompt)

print("\nGenerated response (before fine-tuning):")
test_response = generate_response(test_prompt)
print(test_response)

## 5. Training Loop Implementation

Now we'll implement the PyTorch training loop for instruction fine-tuning.
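The loop below divides each micro-batch loss by `gradient_accumulation_steps` before calling `backward()`. This sketch checks, on a one-parameter model with a hand-written gradient, that summing those scaled micro-batch gradients reproduces the gradient of one large batch. The equivalence assumes equal-sized micro-batches, which holds for the packed data since every sequence has `max_seq_len` tokens; the data values are made up:

```python
def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over a batch, for the model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 3.9, 6.2, 8.1]
w = 0.5
accum_steps = 2

# Gradient of one large batch of 4 examples
full_batch_grad = mse_grad(w, xs, ys)

# Two micro-batches of 2, each scaled by 1 / accum_steps and summed (as .backward() accumulates)
accumulated = 0.0
for i in range(accum_steps):
    mb_x, mb_y = xs[2 * i : 2 * i + 2], ys[2 * i : 2 * i + 2]
    accumulated += mse_grad(w, mb_x, mb_y) / accum_steps

print(full_batch_grad, accumulated)
```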

In [None]:
def train_model(model, train_dataloader, eval_dataloader, optimizer, scheduler, config, device):
    """Train the model using a PyTorch training loop."""
    # Initialize metrics tracking
    accuracy_tracker = TokenAccuracy()
    best_eval_loss = float('inf')
    train_step = 0
    
    # Dictionary to store metrics
    metrics_history = {
        "train_loss": [],
        "train_accuracy": [],
        "eval_loss": [],
        "eval_accuracy": [],
        "learning_rate": []
    }
    
    # Training loop
    model.train()
    print("Starting training...")
    for epoch in range(config.epochs):
        # Training phase
        print(f"\nEpoch {epoch+1}/{config.epochs}")
        accuracy_tracker = TokenAccuracy()      # reset so the epoch accuracy is not cumulative
        optimizer.zero_grad(set_to_none=True)   # drop gradients left over from an incomplete window
        epoch_loss = 0.0
        num_batches = 0
        window_loss = 0.0  # scaled losses summed over one accumulation window
        
        # Process batches
        for step, batch in enumerate(tqdm(train_dataloader, desc=f"Training epoch {epoch+1}")):
            # Move batch to device
            batch = to_device(batch, device)
            
            # Forward pass with mixed precision (labels are pre-shifted, so we compute the loss ourselves)
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                outputs = model(input_ids=batch["input_ids"])
                loss = loss_fn(outputs.logits, batch["labels"]) / config.gradient_accumulation_steps
            
            # Backward pass (gradients accumulate across micro-batches)
            loss.backward()
            window_loss += loss.item()
            batch_accuracy = accuracy_tracker.update(outputs.logits, batch["labels"])
            
            # Update weights after accumulating gradients
            if (step + 1) % config.gradient_accumulation_steps == 0:
                # Update parameters
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                
                # Update metrics (window_loss is the mean loss over the accumulation window)
                batch_loss = window_loss
                window_loss = 0.0
                epoch_loss += batch_loss
                
                # Print metrics
                if train_step % 50 == 0:  # Print every 50 steps
                    print(f"Step {train_step}: Loss = {batch_loss:.4f}, Accuracy = {batch_accuracy:.4f}, LR = {scheduler.get_last_lr()[0]:.8f}")
                
                # Store metrics
                metrics_history["train_loss"].append(batch_loss)
                metrics_history["train_accuracy"].append(batch_accuracy)
                metrics_history["learning_rate"].append(scheduler.get_last_lr()[0])
                
                train_step += 1
                num_batches += 1
        
        # Compute epoch metrics
        epoch_loss /= num_batches if num_batches > 0 else 1
        epoch_accuracy = accuracy_tracker.compute()
        print(f"Epoch {epoch+1} completed: Loss = {epoch_loss:.4f}, Accuracy = {epoch_accuracy:.4f}")
        
        # Evaluation phase
        print("\nRunning evaluation...")
        eval_metrics, samples = evaluate_model(model, eval_dataloader, device, config.n_eval_samples)
        
        # Store evaluation metrics
        metrics_history["eval_loss"].append(eval_metrics["eval_loss"])
        metrics_history["eval_accuracy"].append(eval_metrics["eval_accuracy"])
        
        # Print evaluation metrics
        print(f"Evaluation: Loss = {eval_metrics['eval_loss']:.4f}, Accuracy = {eval_metrics['eval_accuracy']:.4f}")
        
        # Print sample generations
        print("\nSample generations:")
        for i, sample in enumerate(samples):
            print(f"\nSample {i+1}:")
            print(f"Prompt: {sample['prompt'][:100]}...")
            print(f"Generated: {sample['generated'][:100]}...")
        
        # Save best model
        if eval_metrics["eval_loss"] < best_eval_loss:
            best_eval_loss = eval_metrics["eval_loss"]
            print(f"\nNew best model found! Eval loss: {best_eval_loss:.4f}")
            
            # We're not saving the model in this notebook to save space
            # But you can uncomment the following lines to save the model
            # model_save_path = Path(f"models/alpaca_ft_epoch_{epoch+1}")
            # model_save_path.mkdir(parents=True, exist_ok=True)
            # model.save_pretrained(model_save_path, safe_serialization=True)
            # tokenizer.save_pretrained(model_save_path)
    
    return model, metrics_history

## 6. Run Training

Now let's run the training loop to fine-tune our model on the Alpaca dataset.

In [None]:
# Start the training process
model, metrics_history = train_model(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    config=config,
    device=device
)

## 7. Evaluate the Fine-tuned Model

Let's see how our fine-tuned model performs on some test examples.

In [None]:
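# Test the fine-tuned model on a few examples
model.eval()  # evaluate_model leaves the model in training mode; switch back for generation
print("Testing fine-tuned model:\n")
for i in range(5):
    # Get a random example from the evaluation set
    example = random.choice(eval_dataset)
    prompt = create_alpaca_prompt(example)  # raw Alpaca records have no "prompt" key
    target = example["output"]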
    
    # Generate a response
    generated = generate_response(prompt)
    
    # Print the results
    print(f"Example {i+1}:")
    print(f"Prompt: {prompt}")
    print(f"\nTarget response: {target}")
    print(f"\nGenerated response: {generated}")
    print("\n" + "-"*80 + "\n")

## 8. Plot Training Metrics

Let's visualize the training progress.

In [None]:
# Try to import matplotlib for plotting
try:
    import matplotlib.pyplot as plt
    
    # Plot training loss
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(metrics_history["train_loss"])
    plt.title("Training Loss")
    plt.xlabel("Step")
    plt.ylabel("Loss")
    
    # Plot training and evaluation accuracy
    plt.subplot(1, 2, 2)
    plt.plot(metrics_history["train_accuracy"], label="Train")
    
    # Evaluation runs at the end of each epoch, so place each point at that epoch's last step
    steps_per_epoch = len(metrics_history["train_accuracy"]) // len(metrics_history["eval_accuracy"])
    eval_x = [(i + 1) * steps_per_epoch - 1 for i in range(len(metrics_history["eval_accuracy"]))]
    plt.plot(eval_x, metrics_history["eval_accuracy"], 'o-', label="Eval")
    
    plt.title("Accuracy")
    plt.xlabel("Step")
    plt.ylabel("Accuracy")
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    
    # Plot learning rate schedule
    plt.figure(figsize=(10, 4))
    plt.plot(metrics_history["learning_rate"])
    plt.title("Learning Rate Schedule")
    plt.xlabel("Step")
    plt.ylabel("Learning Rate")
    plt.show()
    
except ImportError:
    print("Matplotlib not installed. Skipping plots.")
    print("\nFinal metrics:")
    print(f"Training loss: {metrics_history['train_loss'][-1]:.4f}")
    print(f"Training accuracy: {metrics_history['train_accuracy'][-1]:.4f}")
    print(f"Evaluation loss: {metrics_history['eval_loss'][-1]:.4f}")
    print(f"Evaluation accuracy: {metrics_history['eval_accuracy'][-1]:.4f}")

## 9. Conclusion and Next Steps

In this notebook, we've implemented instruction fine-tuning using pure PyTorch loops. We've covered:

1. **Data Preparation**: Formatting instruction-response pairs for training
2. **Efficient Training**: Using techniques like sequence packing, gradient accumulation, mixed precision, gradient checkpointing, and parameter freezing
3. **Model Evaluation**: Tracking metrics and sample generations during training

### Key Insights:

- Instruction fine-tuning can dramatically improve a model's ability to follow directions and generate helpful responses
- Training only a subset of the model parameters (partial fine-tuning, as here, or LoRA) can significantly reduce computational requirements
- Data formatting and prompt engineering are crucial for effective instruction tuning
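To see why LoRA is so much cheaper than unfreezing whole layers, compare parameter counts. LoRA keeps a `d_out x d_in` weight frozen and trains two low-rank factors, `B` (`d_out x r`) and `A` (`r x d_in`). This arithmetic sketch assumes Llama-2-7B's hidden size and layer count, adapters on the query and value projections only, and an arbitrary rank of 8; none of these choices come from this notebook:

```python
def lora_params(d_in, d_out, rank):
    """Trainable parameters LoRA adds to one frozen d_out x d_in weight matrix."""
    return rank * (d_in + d_out)

hidden, n_layers, rank = 4096, 32, 8   # assumed Llama-2-7B sizes; rank 8 is an arbitrary choice

full_qv = n_layers * 2 * hidden * hidden                    # q_proj and v_proj fully trainable
lora_qv = n_layers * 2 * lora_params(hidden, hidden, rank)  # LoRA adapters on the same matrices

print(f"full q/v: {full_qv:,}  LoRA q/v (r={rank}): {lora_qv:,}")
```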

### Next Steps:

In the next notebook, we'll explore how to achieve similar results using the HuggingFace Trainer API, which simplifies the process and adds features like distributed training.

You could improve this implementation by:
- Using gradient clipping to prevent exploding gradients
- Implementing early stopping
- Using techniques like LoRA (Low-Rank Adaptation) for more efficient fine-tuning
- Exploring different instruction formats or datasets
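Early stopping fits naturally into the per-epoch evaluation that `train_model` already performs. A minimal sketch; `EarlyStopping`, `patience`, and `min_delta` are hypothetical names, and the loss values are made up:

```python
class EarlyStopping:
    """Signal a stop after `patience` evaluations without beating the best loss by `min_delta`."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_evals = 0

    def should_stop(self, eval_loss):
        if eval_loss < self.best - self.min_delta:
            self.best = eval_loss   # improvement: remember it and reset the counter
            self.bad_evals = 0
        else:
            self.bad_evals += 1     # no improvement this evaluation
        return self.bad_evals >= self.patience

stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate([1.30, 1.21, 1.22, 1.25, 1.19]):
    if stopper.should_stop(loss):
        print(f"Stopping after epoch {epoch + 1}; best eval loss {stopper.best:.2f}")
        break
```

In `train_model`, this would mean calling `stopper.should_stop(eval_metrics["eval_loss"])` after each evaluation and breaking out of the epoch loop when it returns `True`.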