# Try It Yourself!

**Hands-on tutorial: Train and evaluate your first fine-tuned model**

## Welcome to Hands-On Post-Training!

This guide walks you through training your first fine-tuned language model. By the end, you'll have:

- A GPT-2 model fine-tuned on instructions
- Hands-on experience with SFT, evaluation, and generation
- An understanding of how to apply these techniques to your own projects

**Time required:** 30-60 minutes (depending on hardware)

## 1. Setup

First, let's verify our environment and import the necessary libraries.

In [1]:
# Check environment
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Device: {torch.cuda.get_device_name(0)}")
else:
    print("Device: CPU (training will be slower)")

PyTorch version: 2.10.0.dev20251124+rocm7.1
CUDA available: True
Device: Radeon RX 7900 XTX


In [2]:
# Import libraries
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
from datasets import load_dataset
from tqdm import tqdm
import numpy as np

print("All imports successful!")

All imports successful!


## 2. Load Model and Tokenizer

In [3]:
# Load GPT-2 (small, 124M parameters)
model_name = "gpt2"

print(f"Loading {model_name}...")
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set padding token
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id

# Move to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print("Model loaded!")
print(f"  Parameters: {total_params:,}")
print(f"  Device: {device}")

Loading gpt2...


Model loaded!
  Parameters: 124,439,808
  Device: cuda
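
As a rough guide to whether full fine-tuning will fit on your hardware, the parameter count above can be turned into a memory estimate. The sketch below assumes fp32 values (4 bytes each) and an AdamW-style optimizer that keeps two extra values per parameter. It ignores activations and framework overhead, so treat the result as a lower bound rather than a measurement. The function name `training_memory_gib` is made up for this illustration.

```python
# Back-of-the-envelope memory estimate for full fine-tuning.
# Assumptions (not measured): fp32 everywhere, AdamW with two state tensors per parameter.
NUM_PARAMS = 124_439_808   # parameter count printed above
BYTES_PER_VALUE = 4        # fp32


def training_memory_gib(num_params: int, bytes_per_value: int = BYTES_PER_VALUE) -> float:
    """Weights + gradients + two optimizer states, in GiB (activations excluded)."""
    copies = 4  # weights, gradients, first-moment and second-moment estimates
    return num_params * bytes_per_value * copies / 2**30


weights_gib = NUM_PARAMS * BYTES_PER_VALUE / 2**30
print(f"Weights only: {weights_gib:.2f} GiB")
print(f"Weights + grads + optimizer state: {training_memory_gib(NUM_PARAMS):.2f} GiB")
```

Activation memory grows with batch size and sequence length, which is why those two settings are the usual knobs to turn when a run does not fit.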


## 3. Test Base Model (Before Fine-Tuning)

Let's see how the base model handles instructions before we fine-tune it.

In [4]:
def generate_response(model, tokenizer, instruction, max_new_tokens=100):
    """Generate a response to an instruction."""
    # Format as Alpaca-style prompt
    prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
"""
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = full_text.split("### Response:\n")[-1].strip()
    
    return response

# Test base model
test_instructions = [
    "What is the capital of France?",
    "Write a haiku about programming.",
    "Explain machine learning in one sentence.",
]

print("Base Model Responses (BEFORE fine-tuning):")
print("=" * 60)
for instruction in test_instructions:
    print(f"\nInstruction: {instruction}")
    response = generate_response(model, tokenizer, instruction)
    print(f"Response: {response[:200]}..." if len(response) > 200 else f"Response: {response}")
    print("-" * 60)

Base Model Responses (BEFORE fine-tuning):

Instruction: What is the capital of France?




Response: In the United States, capital letters are used for the capital letters, not the letters used for the letter "a".

### Example:

# Create an account with your friends. Username:

Password:

### User na...
------------------------------------------------------------

Instruction: Write a haiku about programming.


Response: Write
------------------------------------------------------------

Instruction: Explain machine learning in one sentence.


Response: This is a simple example of a machine
------------------------------------------------------------


Notice that the base model doesn't follow instructions well: it treats the prompt as ordinary text to continue, in the same style, rather than as a question to answer.
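
One detail of `generate_response` is worth knowing when reading these outputs: it keeps only the text after the *last* `### Response:` marker. If the model repeats the template in its own output, everything before the repeated marker is hidden. The sketch below uses a made-up decoded string to show the difference between that behavior and taking everything after the *first* marker. The helper names and the sample text are hypothetical.

```python
MARKER = "### Response:\n"


def after_last_marker(full_text: str) -> str:
    """Mirror of the tutorial's extraction: text after the LAST marker."""
    return full_text.split(MARKER)[-1].strip()


def after_first_marker(full_text: str) -> str:
    """Alternative: everything after the FIRST marker (the whole continuation)."""
    _, found, rest = full_text.partition(MARKER)
    return rest.strip() if found else full_text.strip()


# Hypothetical decoded output in which the model repeats the prompt template.
sample = (
    "### Instruction:\nWrite a haiku about programming.\n\n"
    "### Response:\nCode flows like water\n\n"
    "### Instruction:\nWrite a poem.\n\n"
    "### Response:\nWrite"
)

print(repr(after_last_marker(sample)))
print(repr(after_first_marker(sample)))
```

Taking the text after the first marker shows the full continuation, at the cost of also showing any template text the model invents.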

## 4. Prepare Training Data

We'll use a cleaned version of the Alpaca dataset (`yahma/alpaca-cleaned`), which contains instruction-response pairs, some of them with an additional input field.

In [5]:
# Load Alpaca dataset
print("Loading Alpaca dataset...")
raw_dataset = load_dataset("yahma/alpaca-cleaned", split="train")

# Take the first num_samples examples for quick training
num_samples = 500  # Adjust based on your time/hardware
raw_dataset = raw_dataset.select(range(num_samples))

print(f"Dataset loaded: {len(raw_dataset)} samples")
print("\nExample:")
print(f"  Instruction: {raw_dataset[0]['instruction'][:100]}...")
print(f"  Input: {raw_dataset[0]['input'][:50]}..." if raw_dataset[0]['input'] else "  Input: (none)")
print(f"  Output: {raw_dataset[0]['output'][:100]}...")

Loading Alpaca dataset...


Dataset loaded: 500 samples

Example:
  Instruction: Give three tips for staying healthy....
  Input: (none)
  Output: 1. Eat a balanced and nutritious diet: Make sure your meals are inclusive of a variety of fruits and...


In [6]:
class InstructionDataset(Dataset):
    """Dataset for instruction fine-tuning with proper loss masking."""
    
    def __init__(self, data, tokenizer, max_length=256):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def format_example(self, example):
        """Format example in Alpaca style."""
        if example['input']:
            prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{example['instruction']}

### Input:
{example['input']}

### Response:
"""
        else:
            prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{example['instruction']}

### Response:
"""
        return prompt, example['output']
    
    def __getitem__(self, idx):
        example = self.data[idx]
        prompt, response = self.format_example(example)
        
        # Tokenize prompt and response separately
        prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=True)
        response_tokens = self.tokenizer.encode(response, add_special_tokens=False)
        
        # Combine
        input_ids = prompt_tokens + response_tokens + [self.tokenizer.eos_token_id]
        
        # Create labels: -100 for prompt tokens (ignored in loss)
        labels = [-100] * len(prompt_tokens) + response_tokens + [self.tokenizer.eos_token_id]
        
        # Truncate if too long (this can cut off the end of the response, including EOS)
        if len(input_ids) > self.max_length:
            input_ids = input_ids[:self.max_length]
            labels = labels[:self.max_length]
        
        # Pad to max_length
        padding_length = self.max_length - len(input_ids)
        input_ids = input_ids + [self.tokenizer.pad_token_id] * padding_length
        labels = labels + [-100] * padding_length  # Ignore padding in loss
        attention_mask = [1] * (self.max_length - padding_length) + [0] * padding_length
        
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask),
            'labels': torch.tensor(labels),
        }

# Create dataset and dataloader
train_dataset = InstructionDataset(raw_dataset, tokenizer, max_length=256)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

print(f"Created dataset with {len(train_dataset)} samples")
print(f"Batches per epoch: {len(train_loader)}")

Created dataset with 500 samples
Batches per epoch: 125
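
To make the loss masking in `__getitem__` concrete, here is a framework-free sketch of the same steps on toy token IDs. The function name `build_example` and the ID values are invented for illustration. The value `-100` is the label that PyTorch's cross-entropy loss ignores by default.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy by default


def build_example(prompt_ids, response_ids, eos_id, pad_id, max_length):
    """Return (input_ids, labels, attention_mask) as plain lists of length max_length."""
    input_ids = prompt_ids + response_ids + [eos_id]
    # Only response tokens (and the final EOS) contribute to the loss.
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + [eos_id]

    # Truncate, then pad on the right.
    input_ids = input_ids[:max_length]
    labels = labels[:max_length]
    pad = max_length - len(input_ids)

    attention_mask = [1] * len(input_ids) + [0] * pad
    input_ids = input_ids + [pad_id] * pad
    labels = labels + [IGNORE_INDEX] * pad
    return input_ids, labels, attention_mask


# Toy IDs: a 3-token prompt, a 2-token response, and EOS == PAD == 0 as in the tutorial.
ids, labels, mask = build_example([11, 12, 13], [21, 22], eos_id=0, pad_id=0, max_length=8)
print(ids)     # [11, 12, 13, 21, 22, 0, 0, 0]
print(labels)  # [-100, -100, -100, 21, 22, 0, -100, -100]
print(mask)    # [1, 1, 1, 1, 1, 1, 0, 0]
```

Because the pad token is the same ID as EOS here, the labels and attention mask are what separate the real end-of-sequence token (label `0`, mask `1`) from padding (label `-100`, mask `0`).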


## 5. Training Loop

Now let's train the model on our instruction data, using AdamW with a linear warmup-then-decay learning-rate schedule and gradient clipping.

In [7]:
# Training configuration
learning_rate = 5e-5
num_epochs = 1
warmup_steps = 50
max_grad_norm = 1.0

# Setup optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

print("Training Configuration:")
print(f"  Learning rate: {learning_rate}")
print(f"  Epochs: {num_epochs}")
print(f"  Total steps: {total_steps}")
print(f"  Warmup steps: {warmup_steps}")

Training Configuration:
  Learning rate: 5e-05
  Epochs: 1
  Total steps: 125
  Warmup steps: 50
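
With these settings the learning rate is not constant: it ramps up linearly over the first 50 steps and then decays linearly to zero at step 125. The sketch below is a hand-written re-implementation of that shape for illustration, not the library's source; `linear_warmup_decay_factor` is a hypothetical name.

```python
def linear_warmup_decay_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        # Linear ramp from 0 to 1 during warmup.
        return step / max(1, warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps.
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


BASE_LR = 5e-5
WARMUP_STEPS = 50
TOTAL_STEPS = 125

for step in (0, 25, 50, 100, 125):
    lr = BASE_LR * linear_warmup_decay_factor(step, WARMUP_STEPS, TOTAL_STEPS)
    print(f"step {step:3d}: lr = {lr:.2e}")
```

Here warmup takes 50 of 125 steps (40% of the run), so the model spends only a short time near the peak learning rate. If you increase `num_samples` or `num_epochs`, you may want to revisit `warmup_steps` rather than keep it fixed.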


In [8]:
# Training loop
print("\nStarting training...")
model.train()

for epoch in range(num_epochs):
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for step, batch in enumerate(progress_bar):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Forward pass
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        loss = outputs.loss
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        
        # Optimizer step
        optimizer.step()
        scheduler.step()
        
        # Track loss
        total_loss += loss.item()
        avg_loss = total_loss / (step + 1)
        
        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'avg_loss': f'{avg_loss:.4f}',
            'ppl': f'{np.exp(avg_loss):.2f}'
        })
    
    print(f"\nEpoch {epoch+1} complete!")
    print(f"  Average loss: {avg_loss:.4f}")
    print(f"  Perplexity: {np.exp(avg_loss):.2f}")

print("\nTraining complete!")


Starting training...


Epoch 1/1:   0%|                                        | 0/125 [00:00<?, ?it/s]

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.



Epoch 1/1: 100%|█| 125/125 [00:15<00:00,  8.03it/s, loss=2.3073, avg_loss=2.3759


Epoch 1 complete!
  Average loss: 2.3759
  Perplexity: 10.76

Training complete!
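
The perplexity reported above is the exponential of the running average of per-batch losses. The sketch below reproduces that calculation and, using invented numbers, shows how a simple mean over batches can differ from a mean weighted by the number of supervised tokens. The helper names are hypothetical.

```python
import math


def perplexity(mean_loss: float) -> float:
    """Perplexity is exp of the mean cross-entropy (in nats)."""
    return math.exp(mean_loss)


def token_weighted_loss(batches):
    """batches: list of (mean_loss, num_supervised_tokens) pairs."""
    total_tokens = sum(n for _, n in batches)
    return sum(loss * n for loss, n in batches) / total_tokens


print(f"Perplexity for loss 2.3759: {perplexity(2.3759):.2f}")

# Hypothetical batches: (mean loss, number of unmasked label tokens).
toy_batches = [(2.0, 10), (3.0, 30)]
simple_mean = sum(loss for loss, _ in toy_batches) / len(toy_batches)
weighted_mean = token_weighted_loss(toy_batches)
print(f"Simple mean: {simple_mean}, token-weighted mean: {weighted_mean}")
```

Keep in mind that this is a training-set figure averaged over steps while the weights were still changing, so it is a progress indicator rather than a held-out evaluation.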





## 6. Test Fine-Tuned Model

Now let's see how the model performs after fine-tuning!

In [9]:
# Set model to eval mode
model.eval()

print("Fine-Tuned Model Responses (AFTER fine-tuning):")
print("=" * 60)
for instruction in test_instructions:
    print(f"\nInstruction: {instruction}")
    response = generate_response(model, tokenizer, instruction)
    print(f"Response: {response}")
    print("-" * 60)

Fine-Tuned Model Responses (AFTER fine-tuning):

Instruction: What is the capital of France?
Response: France is the capital of France, with its capital being Paris.
------------------------------------------------------------

Instruction: Write a haiku about programming.


Response: The haiku describes a task that is one that is a combination of the three basic types of language.

First, the verb "hōshi" refers to the verb "to write" and "to make", while the verb "hōshirō" refers to the verb "to write". The verb "hōshirō" is also used to refer to "to write" and "to make", while the verb "hōshirō" is used to
------------------------------------------------------------

Instruction: Explain machine learning in one sentence.


Response: Machine learning is a powerful tool that enables computers to understand and predict human behavior and behaviors. In this paper, we present machine learning algorithms that can generate predictions about the future behavior of human beings. The algorithms are built on real-time data, and are used to predict the future behavior of humans, such as how quickly they will move, how long they will be able to run, and how they will perform in a given situation. Machine learning algorithms are widely used in machine learning applications, including machine
------------------------------------------------------------


In [10]:
# Test with more instructions
additional_tests = [
    "List three benefits of exercise.",
    "What is Python used for?",
    "Explain what a neural network is in simple terms.",
    "Write a short poem about the ocean.",
]

print("Additional Tests:")
print("=" * 60)
for instruction in additional_tests:
    print(f"\nInstruction: {instruction}")
    response = generate_response(model, tokenizer, instruction)
    print(f"Response: {response}")
    print("-" * 60)

Additional Tests:

Instruction: List three benefits of exercise.


Response: 1. Exercise increases energy expenditure, reduces fatigue, and improves health.

2. Exercise increases brain function, reduces depression, and improves mental well-being.

3. Exercise increases physical activity, reduces stress, and improves mental health.

4. Exercise increases cognitive function, reduces anxiety, and improves cognitive function.

5. Exercise increases physical activity, improves sleep, and improves physical performance.

6. Exercise increases physical activity, improves sleep, and improves physical performance
------------------------------------------------------------

Instruction: What is Python used for?


Response: Python is a powerful and versatile language that can be used for many purposes, from scripting, to visual editing, to building applications, to creating custom applications. Python is also widely used for building web applications, and is a powerful language for building web applications in various browsers, including Chrome, Firefox, and Safari.

There are many uses for Python, including:

Documentation

Automation

Automation can be a powerful tool in a complex and complex world. Python is a
------------------------------------------------------------

Instruction: Explain what a neural network is in simple terms.


Response: A neural network is a type of computer system that learns about the world through data, information, and models. It is a network of interconnected computers that is responsible for processing, analyzing, and interpreting data. It is the backbone of computer science and engineering, and is used to develop algorithms and algorithms to solve problems. Neural networks are thought to be a key part of the computer science and engineering field, and have been used in many fields including robotics, artificial intelligence, and artificial intelligence.

The
------------------------------------------------------------

Instruction: Write a short poem about the ocean.
Response: The ocean is an important and beautiful place to visit, but it is not always a safe place. For that reason, it is important to understand the importance of staying in the ocean.
------------------------------------------------------------
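
Several responses above stop mid-sentence because generation runs until the token limit. One way to tidy the display is to trim each response back to its last complete sentence. The helper below is a rough sketch and not part of the notebook's `generate_response`; the name `trim_to_last_sentence` and the regex heuristic are assumptions made for illustration.

```python
import re

def trim_to_last_sentence(text: str) -> str:
    """Drop a trailing fragment that does not end in '.', '!' or '?'."""
    text = text.strip()
    # End positions of sentence-final punctuation (optionally followed by a
    # closing quote or bracket) that is followed by whitespace or end of string
    ends = [m.end() for m in re.finditer(r'[.!?]["\')\]]*(?=\s|$)', text)]
    # If no complete sentence is found, return the text unchanged
    return text[:ends[-1]] if ends else text

raw = "Python is a versatile language. Automation can be a powerful tool. Python is a"
print(trim_to_last_sentence(raw))
```

This heuristic treats any period followed by whitespace as a sentence end, so numbered lists such as `6. Exercise ...` can be cut right after the number. It only cleans up the output; it does not teach the model when to stop.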


## 7. Evaluate Model Quality

Let's compute some quantitative metrics.

In [11]:
def compute_perplexity(model, dataloader, device):
    """Compute perplexity on a dataset."""
    model.eval()
    total_loss = 0
    total_tokens = 0
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Computing perplexity"):
            batch = {k: v.to(device) for k, v in batch.items()}
            
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']
            )
            
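            # Count target tokens. The model shifts labels by one position
            # internally, so position 0 is never a prediction target.
            num_tokens = (batch['labels'][:, 1:] != -100).sum().item()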
            total_loss += outputs.loss.item() * num_tokens
            total_tokens += num_tokens
    
    avg_loss = total_loss / total_tokens
    perplexity = np.exp(avg_loss)
    
    return perplexity, avg_loss

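# Compute perplexity on the training data.
# Note: this measures fit to data the model has already seen; a held-out
# split would give a fairer estimate of generalization.
perplexity, loss = compute_perplexity(model, train_loader, device)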
print(f"\nFinal Metrics:")
print(f"  Loss: {loss:.4f}")
print(f"  Perplexity: {perplexity:.2f}")

Computing perplexity: 100%|███████████████████| 125/125 [00:08<00:00, 14.54it/s]


Final Metrics:
  Loss: 1.9576
  Perplexity: 7.08
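
The function above multiplies each batch loss by its token count before averaging. As a minimal sketch of why that matters, the toy numbers below (hypothetical, not taken from this run) compare a plain mean of batch losses with the token-weighted mean, then confirm that the perplexities reported in this notebook are simply `exp(loss)`.

```python
import math

# Hypothetical per-batch results: (mean loss over target tokens, number of target tokens)
batches = [(2.0, 100), (3.0, 10)]

# Unweighted: every batch counts equally, however few tokens it contains
mean_of_means = sum(loss for loss, _ in batches) / len(batches)

# Token-weighted: recover each batch's summed loss, then divide by total tokens
total_loss = sum(loss * n for loss, n in batches)
total_tokens = sum(n for _, n in batches)
token_weighted = total_loss / total_tokens

print(f"mean of batch means: {mean_of_means:.4f}")
print(f"token-weighted mean: {token_weighted:.4f}")
print(f"perplexity:          {math.exp(token_weighted):.2f}")

# Sanity check against the losses reported in this notebook
print(f"exp(2.3759) = {math.exp(2.3759):.2f}")  # training-loop average loss
print(f"exp(1.9576) = {math.exp(1.9576):.2f}")  # loss from compute_perplexity
```

The loss here (1.9576) is lower than the training-loop average (2.3759) most likely because the training average includes the early, higher-loss steps of the epoch, and because dropout is disabled in eval mode.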





In [12]:
def compute_diversity(responses):
    """Compute diversity metrics for generated responses."""
    all_unigrams = []
    all_bigrams = []
    
    for response in responses:
        tokens = response.lower().split()
        all_unigrams.extend(tokens)
        all_bigrams.extend(zip(tokens[:-1], tokens[1:]))
    
    distinct_1 = len(set(all_unigrams)) / len(all_unigrams) if all_unigrams else 0
    distinct_2 = len(set(all_bigrams)) / len(all_bigrams) if all_bigrams else 0
    
    return distinct_1, distinct_2

# Generate responses for diversity analysis
diversity_prompts = [
    "Tell me about machine learning.",
    "Explain artificial intelligence.",
    "What is deep learning?",
    "Describe natural language processing.",
    "Explain what data science is.",
]

responses = [generate_response(model, tokenizer, p) for p in diversity_prompts]
d1, d2 = compute_diversity(responses)

print("\nDiversity Metrics:")
print(f"  Distinct-1 (unique unigrams): {d1:.2%}")
print(f"  Distinct-2 (unique bigrams): {d2:.2%}")
print("\nInterpretation:")
print("  > 40% distinct-1: Good diversity")
print("  < 20% distinct-1: May indicate mode collapse")


Diversity Metrics:
  Distinct-1 (unique unigrams): 39.40%
  Distinct-2 (unique bigrams): 76.77%

Interpretation:
  > 40% distinct-1: Good diversity
  < 20% distinct-1: May indicate mode collapse
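
To build intuition for these numbers, here is a small self-contained sketch that applies the same idea (unique n-grams divided by total n-grams) to two toy inputs. The function name `distinct_n` and the example strings are illustrative assumptions, not part of the notebook.

```python
def distinct_n(texts, n):
    """Unique n-grams divided by total n-grams, pooled across all texts."""
    ngrams = []
    for text in texts:
        tokens = text.lower().split()
        # zip over n shifted copies of the token list yields consecutive n-grams
        ngrams.extend(zip(*(tokens[i:] for i in range(n))))
    return len(set(ngrams)) / len(ngrams) if ngrams else 0.0

repetitive = ["the cat sat the cat sat the cat sat"]
varied = ["a quick brown fox jumps over lazy dogs today"]

for name, texts in [("repetitive", repetitive), ("varied", varied)]:
    d1 = distinct_n(texts, 1)
    d2 = distinct_n(texts, 2)
    print(f"{name}: distinct-1={d1:.2f}, distinct-2={d2:.2f}")
```

The repetitive string scores low on both metrics, while the varied one scores 1.0. Because the counts are pooled over all responses, distinct-1 also tends to fall as the total amount of text grows: common words such as "the" and "is" repeat even in healthy output, so compare scores only between runs of similar length.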


## 8. Save Your Model

Save the fine-tuned model for later use.

In [13]:
# Save model and tokenizer
save_path = "./my_finetuned_model"

print(f"Saving model to {save_path}...")
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
print("Model saved!")

# Show saved files
import os
print(f"\nSaved files:")
for f in os.listdir(save_path):
    size = os.path.getsize(os.path.join(save_path, f)) / 1e6
    print(f"  {f}: {size:.1f} MB")

Saving model to ./my_finetuned_model...


Model saved!

Saved files:
  tokenizer_config.json: 0.0 MB
  config.json: 0.0 MB
  tokenizer.json: 3.6 MB
  merges.txt: 0.5 MB
  special_tokens_map.json: 0.0 MB
  generation_config.json: 0.0 MB
  vocab.json: 0.8 MB
  model.safetensors: 497.8 MB
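
The size of `model.safetensors` can be sanity-checked from the parameter count printed when the model was loaded. The sketch below assumes the weights are stored as 32-bit floats (4 bytes each), which is an assumption about the checkpoint rather than something the notebook states.

```python
num_params = 124_439_808   # parameter count printed in the model-loading cell
bytes_per_param = 4        # assumption: weights are saved as float32

# Use the same MB convention as the cell above (1 MB = 1e6 bytes)
size_mb = num_params * bytes_per_param / 1e6
print(f"Estimated weight size: {size_mb:.1f} MB")
```

The estimate matches the reported file size to one decimal place, which suggests the file is almost entirely raw float32 weights plus a small header.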


In [14]:
# Test loading the saved model
print("Testing model loading...")

loaded_model = AutoModelForCausalLM.from_pretrained(save_path)
loaded_tokenizer = AutoTokenizer.from_pretrained(save_path)
loaded_model = loaded_model.to(device)
loaded_model.eval()

test_instruction = "What is the meaning of life?"
response = generate_response(loaded_model, loaded_tokenizer, test_instruction)

print(f"\nTest with loaded model:")
print(f"Instruction: {test_instruction}")
print(f"Response: {response}")

Testing model loading...



Test with loaded model:
Instruction: What is the meaning of life?
Response: Life is a journey of discovery, transformation, and renewal. It is a journey that is both deeply rewarding and rewarding.


## Summary

Congratulations! You've successfully:

1. **Loaded** a pre-trained GPT-2 model
2. **Tested** the base model on instructions (and saw it doesn't follow them well)
3. **Prepared** training data with proper loss masking
4. **Trained** the model using supervised fine-tuning (SFT)
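5. **Tested** the fine-tuned model (and saw it answer instructions more directly, although after one epoch it still rambles, overruns requested lengths, and misses formats such as haiku)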
6. **Evaluated** using perplexity and diversity metrics
7. **Saved** the model for later use

## Next Steps

Now that you've mastered the basics, try:

1. **Train longer** - Increase epochs or use more data
2. **Try LoRA** - More efficient training with fewer parameters
3. **Try DPO** - Align model with human preferences
4. **Use larger models** - Try GPT-2 Medium or Llama
5. **Custom data** - Fine-tune on your own instruction dataset
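
To make "fewer parameters" concrete for the LoRA suggestion, here is a back-of-the-envelope sketch. It assumes a rank of 8 and that only the fused query/key/value projection in each GPT-2 block is adapted; both choices are illustrative, and the helper `lora_param_count` is hypothetical rather than a library function.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters for one LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)

hidden = 768                 # GPT-2 small hidden size
n_layers = 12                # GPT-2 small transformer blocks
rank = 8                     # hypothetical LoRA rank, chosen for illustration
total_params = 124_439_808   # parameter count printed when the model was loaded

# Assumption: adapt only the fused QKV projection (hidden -> 3 * hidden) in each block
per_layer = lora_param_count(hidden, 3 * hidden, rank)
trainable = per_layer * n_layers

print(f"LoRA trainable parameters: {trainable:,}")
print(f"Share of full model: {trainable / total_params:.2%}")
```

Under these assumptions only a fraction of a percent of the weights would receive gradients, which is where the memory savings come from. Actual counts depend on the rank and on which layers you choose to adapt.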

Happy fine-tuning!