# Working with Fine-Tuning

# Part 1 - Getting Started

## Universal Code Used for the Entire Notebook

Let's set up our libraries and client

In [41]:
# Import necessary libraries and modules

# OpenAI library for API interaction and event handling
from openai import OpenAI  

# JSON library for handling JSON data
import json

# TikToken library for token counting
import tiktoken

# NumPy library for numerical operations
import numpy as np

# defaultdict from the collections module for dictionary with default values
from collections import defaultdict

# Math library for mathematical operations
import math

import json
import random
from pathlib import Path


In [None]:
# Initialize the OpenAI client.
# No arguments are passed, so credentials are presumably picked up from the
# environment (e.g. OPENAI_API_KEY) — confirm it is set before running.
client = OpenAI()  

## Training File Format Validation

### Data Loading

In [None]:
# Define the path to the dataset file
data_path = "./artifacts/marv_fine_tune.jsonl"

# Load the dataset: one JSON object per line (JSONL).
# Blank lines (for example an extra newline at the end of the file) are
# skipped, because json.loads raises JSONDecodeError on an empty string.
with open(data_path, 'r', encoding='utf-8') as file:
    dataset = [json.loads(line) for line in file if line.strip()]

# Print initial dataset statistics
print("Number of examples:", len(dataset))

# Print messages from the first example in the dataset.
# Guarded so that an empty file reports the problem instead of raising IndexError.
if dataset:
    print("First example:")
    for message in dataset[0]["messages"]:
        print(message)
else:
    print("Dataset is empty")


### Format Validation

We can perform a variety of error checks to validate that each conversation in the dataset adheres to the format expected by the fine-tuning API. Errors are categorized based on their nature for easier debugging.

1. **Data Type Check**: Checks whether each entry in the dataset is a dictionary (dict). Error type: `data_type`.

2. **Presence of Message List**: Checks if a `messages` list is present in each entry. Error type: `missing_messages_list`.

3. **Message Keys Check**: Validates that each message in the `messages` list contains the keys `role` and `content`. Error type: `message_missing_key`.

4. **Unrecognized Keys in Messages**: Logs if a message has keys other than `role`, `content`, `weight`, `function_call`, and `name`. Error type: `message_unrecognized_key`.

5. **Role Validation**: Ensures the `role` is one of "system", "user", "assistant", or "function". Error type: `unrecognized_role`.

6. **Content Validation**: Verifies that `content` has textual data and is a string. Error type: `missing_content`.

7. **Assistant Message Presence**: Checks that each conversation has at least one message from the assistant. Error type: `example_missing_assistant_message`.


In [None]:
# Dictionary to track format errors (error type -> number of occurrences)
format_errors = defaultdict(int)

# Iterate through each example in the dataset
for ex in dataset:
    # Check if the example is a dictionary
    if not isinstance(ex, dict):
        format_errors["data_type"] += 1
        continue

    # Retrieve the messages list from the example.
    # A value that is missing, empty, or not a list is reported here; iterating
    # a non-list (e.g. a dict or a string) would otherwise crash further down.
    messages = ex.get("messages", None)
    if not messages or not isinstance(messages, list):
        format_errors["missing_messages_list"] += 1
        continue

    # Check each message in the messages list
    for message in messages:
        # A message that is not a dict cannot be inspected with .get();
        # count it as a data type error and move on instead of raising.
        if not isinstance(message, dict):
            format_errors["data_type"] += 1
            continue

        # Check if required keys are present in the message
        if "role" not in message or "content" not in message:
            format_errors["message_missing_key"] += 1

        # Check for any unrecognized keys in the message
        if any(k not in ("role", "content", "name", "function_call", "weight") for k in message):
            format_errors["message_unrecognized_key"] += 1

        # Validate the role value in the message
        if message.get("role", None) not in ("system", "user", "assistant", "function"):
            format_errors["unrecognized_role"] += 1

        # Check content and function_call in the message.
        # Note: content must be a string in every case, so a message that has a
        # function_call but non-string content is still counted as missing_content.
        content = message.get("content", None)
        function_call = message.get("function_call", None)
        if (not content and not function_call) or not isinstance(content, str):
            format_errors["missing_content"] += 1

    # Ensure at least one message from the assistant is present
    if not any(
        isinstance(message, dict) and message.get("role", None) == "assistant"
        for message in messages
    ):
        format_errors["example_missing_assistant_message"] += 1

# Print the results of the error checks
if format_errors:
    print("Found errors:")
    for key, value in format_errors.items():
        print(f"{key}: {value}")
else:
    print("No errors found")


### Token Counting Utilities

A few helpful utilities to be used in the rest of the notebook.

In [None]:

# Automatically get the encoding for a specific model.
# `encoding` is a module-level object that the token-counting helpers in this
# cell call via encoding.encode(...), so it must be defined before they run.
encoding = tiktoken.encoding_for_model("gpt-4o")


def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
    """
    Count the tokens used by a whole conversation.

    Every message contributes a fixed overhead plus the encoded length of the
    string form of each of its values; a 'name' field adds a further surcharge.
    A flat 3 tokens are added once per conversation at the end.

    Args:
        messages (list): Message dictionaries making up the conversation.
        tokens_per_message (int): Fixed overhead charged for each message.
        tokens_per_name (int): Extra tokens charged when a message has a 'name' key.

    Returns:
        int: Token count for the conversation.

    Raises:
        Exception: Whatever the encoder raises; the offending key and value
            are printed before the exception is re-raised.
    """
    total = 0
    for msg in messages:
        total += tokens_per_message
        for field, field_value in msg.items():
            try:
                # Values are stringified first, so non-string values are counted too
                total += len(encoding.encode(str(field_value)))
            except Exception as err:
                print(f"Error encoding key: {field}, value: {field_value}, type: {type(field_value)}")
                print(f"Error message: {str(err)}")
                raise
            if field == "name":
                total += tokens_per_name
    # Fixed per-conversation overhead of 3 tokens
    return total + 3

def num_assistant_tokens_from_messages(messages):
    """
    Count the tokens contained in the assistant turns of a conversation.

    Only the 'content' of messages whose 'role' is "assistant" is counted;
    the content is converted to a string before encoding.

    Args:
        messages (list): Message dictionaries making up the conversation.

    Returns:
        int: Combined token count of all assistant message contents.
    """
    return sum(
        len(encoding.encode(str(msg["content"])))
        for msg in messages
        if msg["role"] == "assistant"
    )

# Function to print the distribution of values
def print_distribution(values, name):
    """
    Print the distribution statistics of a list of values.

    Prints a header followed by min/max, mean/median and the 5th/95th
    percentiles. When `values` is empty, a short notice is printed instead,
    because min()/max() would otherwise raise ValueError.

    Args:
        values (list): List of numerical values.
        name (str): Description of the values.
    """
    print(f"\n#### Distribution of {name}:")
    # Guard against an empty dataset: there is nothing to summarize
    if len(values) == 0:
        print("no values to summarize")
        return
    print(f"min / max: {min(values)} / {max(values)}")
    print(f"mean / median: {np.mean(values)} / {np.median(values)}")
    print(f"p5 / p95: {np.quantile(values, 0.05)} / {np.quantile(values, 0.95)}")


### Data Warnings and Token Counts

With some lightweight analysis we can identify potential issues in the dataset, like missing messages, and provide statistical insights into message and token counts.

1. **Missing System/User Messages**: Counts the number of conversations missing a "system" or "user" message. Such messages are critical for defining the assistant's behavior and initiating the conversation.

2. **Number of Messages Per Example**: Summarizes the distribution of the number of messages in each conversation, providing insight into dialogue complexity.

3. **Total Tokens Per Example**: Calculates and summarizes the distribution of the total number of tokens in each conversation. Important for understanding fine-tuning costs.

4. **Tokens in Assistant's Messages**: Calculates the number of tokens in the assistant's messages per conversation and summarizes this distribution. Useful for understanding the assistant's verbosity.

5. **Token Limit Warnings**: Checks if any examples exceed the maximum token limit (16,385 tokens), as such examples will be truncated during fine-tuning, potentially resulting in data loss.


In [None]:
# Warnings and tokens counts
n_missing_system, n_missing_user = 0, 0
n_messages, convo_lens, assistant_message_lens = [], [], []

for i, ex in enumerate(dataset):
    messages = ex["messages"]

    # Tally conversations that lack a system turn and/or a user turn
    n_missing_system += int(not any(message["role"] == "system" for message in messages))
    n_missing_user += int(not any(message["role"] == "user" for message in messages))

    n_messages.append(len(messages))

    # Token totals; on failure, show the offending example before re-raising
    try:
        total_tokens = num_tokens_from_messages(messages)
        convo_lens.append(total_tokens)
        assistant_tokens = num_assistant_tokens_from_messages(messages)
        assistant_message_lens.append(assistant_tokens)
    except Exception as e:
        print(f"Error processing example {i}:")
        print(f"Messages: {messages}")
        print(f"Error: {str(e)}")
        raise

# Report the missing-role counts, then one distribution per collected metric
print("Num examples missing system message:", n_missing_system)
print("Num examples missing user message:", n_missing_user)
for values, label in (
    (n_messages, "num_messages_per_example"),
    (convo_lens, "num_total_tokens_per_example"),
    (assistant_message_lens, "num_assistant_tokens_per_example"),
):
    print_distribution(values, label)

# Number of conversations longer than the 16,385-token limit
n_too_long = sum(1 for length in convo_lens if length > 16385)
print(f"\n{n_too_long} examples may be over the 16,385 token limit, they will be truncated during fine-tuning")

### Cost Estimation

Finally, we estimate the total number of tokens that will be used for fine-tuning, which allows us to approximate the cost. It is worth noting that the duration of the fine-tuning jobs will also increase with the token count.

In [None]:
# Constants
# NOTE(review): confirm that this per-example token limit applies to the model
# actually being fine-tuned.
MAX_TOKENS_PER_EXAMPLE = 16385
# Epoch count used when the dataset size needs no adjustment
TARGET_EPOCHS = 3
# Bounds on (examples x epochs) that trigger an adjusted epoch count
MIN_TARGET_EXAMPLES = 100
MAX_TARGET_EXAMPLES = 25000
# Clamp applied to the adjusted epoch count
MIN_DEFAULT_EPOCHS = 1
MAX_DEFAULT_EPOCHS = 25

def calculate_epochs(n_train_examples):
    """
    Calculate the number of epochs based on the number of training examples.

    Small datasets get more epochs so that examples x epochs reaches
    MIN_TARGET_EXAMPLES (capped at MAX_DEFAULT_EPOCHS); large datasets get
    fewer so that it stays within MAX_TARGET_EXAMPLES (at least
    MIN_DEFAULT_EPOCHS). Otherwise TARGET_EPOCHS is returned.

    Args:
        n_train_examples (int): Number of training examples; must be positive.

    Returns:
        int: Estimated number of epochs.

    Raises:
        ValueError: If n_train_examples is not positive. Zero would otherwise
            cause a ZeroDivisionError and negative values a negative epoch count.
    """
    if n_train_examples <= 0:
        raise ValueError(
            f"n_train_examples must be a positive integer, got {n_train_examples}"
        )
    if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
        return min(MAX_DEFAULT_EPOCHS, math.ceil(MIN_TARGET_EXAMPLES / n_train_examples))
    elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
        return max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)
    return TARGET_EPOCHS

def calculate_billing_tokens(convo_lens):
    """
    Sum the per-conversation token counts, capping each one at
    MAX_TOKENS_PER_EXAMPLE.

    Args:
        convo_lens (list): Token count of each conversation.

    Returns:
        int: Total token count after the per-example cap is applied.
    """
    billable = 0
    for n_tokens in convo_lens:
        # Anything at or above the cap contributes exactly the cap
        billable += n_tokens if n_tokens < MAX_TOKENS_PER_EXAMPLE else MAX_TOKENS_PER_EXAMPLE
    return billable

def print_dataset_statistics(n_train_examples, convo_lens):
    """
    Print a short report: example count, capped token total, the computed
    epoch count, and the token total multiplied by the epoch count.

    Args:
        n_train_examples (int): Number of training examples.
        convo_lens (list): Token count of each conversation.
    """
    epochs = calculate_epochs(n_train_examples)
    billable = calculate_billing_tokens(convo_lens)

    report_lines = [
        "Dataset Statistics:",
        f"- Number of training examples: {n_train_examples}",
        f"- Approximate billable tokens: {billable}",
        f"- Default number of epochs: {epochs}",
        f"- Estimated total billable tokens: {epochs * billable}",
    ]
    print("\n".join(report_lines))

# Print the dataset statistics.
# Every loaded example is counted as a training example here; convo_lens comes
# from the token-counting cell above and must already be populated.
n_train_examples = len(dataset)
print_dataset_statistics(n_train_examples, convo_lens)

## Train / Test Split

In [47]:
# Train / Test Split Functions for JSONL Files
def split_jsonl_file(file_path, train_ratio=0.8, seed=None):
    """
    Shuffle a JSONL file and split it into a train file and a test file.

    The outputs are written next to the input file as
    ``<stem>_train<suffix>`` and ``<stem>_test<suffix>``; existing files with
    those names are overwritten.

    Args:
        file_path (str | Path): Path to the input JSONL file.
        train_ratio (float): Fraction of records that go to the train set.
            Must be between 0 and 1 inclusive.
        seed (int | None): Optional seed for a reproducible shuffle. When
            None, the global ``random`` state is used.

    Raises:
        ValueError: If train_ratio is outside the range [0, 1].
    """
    if not 0 <= train_ratio <= 1:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio}")

    # Read the input file, skipping blank lines (json.loads fails on them)
    file_path = Path(file_path)
    with file_path.open('r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]

    # Shuffle the data; a dedicated generator keeps a seeded shuffle
    # reproducible without touching the global random state
    if seed is None:
        random.shuffle(data)
    else:
        random.Random(seed).shuffle(data)

    # Calculate split index
    split_index = int(len(data) * train_ratio)

    # Split the data
    train_data = data[:split_index]
    test_data = data[split_index:]

    # Prepare output file paths
    train_file = file_path.with_name(f"{file_path.stem}_train{file_path.suffix}")
    test_file = file_path.with_name(f"{file_path.stem}_test{file_path.suffix}")

    # Write both splits, one JSON object per line
    for out_path, items in ((train_file, train_data), (test_file, test_data)):
        with out_path.open('w', encoding='utf-8') as f:
            for item in items:
                json.dump(item, f)
                f.write('\n')

    print(f"Train data saved to: {train_file}")
    print(f"Test data saved to: {test_file}")
    print(f"Train set size: {len(train_data)}")
    print(f"Test set size: {len(test_data)}")

# Usage
# Called without train_ratio, so the function's default ratio is used; the
# train/test files are written alongside the input file.
file_path = "./artifacts/marv_fine_tune.jsonl"
split_jsonl_file(file_path)

Train data saved to: artifacts\marv_fine_tune_train.jsonl
Test data saved to: artifacts\marv_fine_tune_test.jsonl
Train set size: 80
Test set size: 20


## Creating a Fine-Tuning Job

### Upload File

In [44]:
# Upload the training data to the OpenAI API.
# The file is opened in a context manager so the handle is closed as soon as
# the upload call returns instead of being left open for the kernel's lifetime.
# NOTE(review): this uploads the full dataset, not the *_train.jsonl split
# produced by split_jsonl_file — confirm that is intended.
with open("./artifacts/marv_fine_tune.jsonl", "rb") as training_fh:
    fine_tune_file = client.files.create(
        file=training_fh,
        purpose="fine-tune"
    )

### Create a Simple Fine-Tuning Job

In [46]:
# Create a fine-tuning job using the uploaded training data.
# Only the training file and the base model are specified. The returned job
# object is not assigned to a variable, so it is shown as the cell's output.
client.fine_tuning.jobs.create(
    training_file=fine_tune_file.id, 
    model="gpt-4o-mini-2024-07-18"
)

FineTuningJob(id='ftjob-OHgkMsoSksrEt0cV04mef81l', created_at=1722466982, error=Error(code=None, message=None, param=None), fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs='auto', batch_size='auto', learning_rate_multiplier='auto'), model='gpt-4o-mini-2024-07-18', object='fine_tuning.job', organization_id='org-SQH2HT1IvRszon9pdYwV1yvQ', result_files=[], seed=146873275, status='validating_files', trained_tokens=None, training_file='file-UMFTrHO0qLBbLoHUJpqr5jYq', validation_file=None, estimated_finish=None, integrations=[], user_provided_suffix=None)

### Create a Fine-Tuning Job with All Default Parameters

In [50]:
# Create a fine-tuning job using the uploaded training data.
# Same model and training file as the simple job, but every optional argument
# is spelled out explicitly and the result is kept in `my_ft_job`.
# Note: this submits a separate job in addition to the one created by the simple call.
my_ft_job = client.fine_tuning.jobs.create(
        model="gpt-4o-mini-2024-07-18",
        training_file=fine_tune_file.id, 
        validation_file=None,  # no validation set is supplied
        hyperparameters={
            # "auto" is passed for each hyperparameter rather than a fixed value
            "batch_size": "auto",
            "learning_rate_multiplier": "auto",
            "n_epochs": "auto",
        },
        suffix=None,  # no custom suffix
        integrations=None,  # no integrations
        seed=None,  # no explicit seed
    )

## Exploring Fine-Tuning Jobs

### Listing Jobs