# Fine-tune GPT-4o-mini for classification


In [1]:
!pip install openai --quiet

In [2]:
# System prompt placed first in every fine-tuning example and in every inference
# request; it limits the answer to one of the star ratings 1-5.
SYSTEM_PROMPT = """You are a review classifier. Given a review, you need to assign it a number of stars. 1 is the lowest rating, and 5 is the highest rating. Assign the review to the correct number of stars. The amount of stars you can pick from is:\n - 1\n - 2\n - 3\n - 4\n - 5\n\n
Do not try to answer the question."""

In [3]:
# Read train.csv and val.csv. For every row, the `question` column and the
# `class_name` column are later turned into a chat-format record that looks like this:
# {
#         "messages": [
#             {"role": "system", "content":SYSTEM_PROMPT},
#             {"role": "user", "content": << column question>>},
#             {"role": "assistant",
#              "content": "<< column class_name >>"}
#         ]
# }
import pandas as pd
df_train = pd.read_csv('train.csv')
df_val = pd.read_csv('val.csv')

def convert_to_messages(df):
    """Turn every row of ``df`` into one chat-format fine-tuning example.

    Each example is a dict with a single ``"messages"`` key holding three
    messages: the shared system prompt, the row's ``question`` as the user
    turn, and the row's ``class_name`` as the assistant turn. Both column
    values are converted with ``str``.

    Returns:
        A list with one example per row, in row order.
    """
    def build_example(record):
        # One system / user / assistant triple per data row.
        return {
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": str(record['question'])},
                {"role": "assistant", "content": str(record['class_name'])},
            ]
        }

    return [build_example(record) for _, record in df.iterrows()]

# Build the chat-format example lists for the training and validation splits.
data_train = convert_to_messages(df_train)
data_val = convert_to_messages(df_val)
    

In [4]:
import json

def to_openai_format(data, file_path):
    """Write ``data`` to ``file_path`` as JSON Lines (one JSON document per line).

    The file is opened for writing as UTF-8, replacing any existing content.
    Each record is serialized with ``json.dumps`` and terminated by a newline.
    """
    # Lazy generator: records are serialized one at a time while writing.
    serialized_lines = (json.dumps(example) + "\n" for example in data)
    with open(file_path, "w", encoding="utf-8") as out:
        out.writelines(serialized_lines)

# Paths of the JSONL files produced from the train / validation examples.
TRAIN_OPENAI_FILE = "train-openai.jsonl"
VAL_OPENAI_FILE = "val-openai.jsonl"
# Serialize both example lists, one JSON record per line.
to_openai_format(data=data_train, file_path=TRAIN_OPENAI_FILE)
to_openai_format(data=data_val, file_path=VAL_OPENAI_FILE)


In [5]:

from collections import defaultdict
from pathlib import Path


def validate_dataset_format(dataset_path: Path):
    """Check a chat fine-tuning JSONL file for structural problems and print a report.

    Every non-blank line of ``dataset_path`` is parsed as JSON. For each example
    the function counts, per error kind: examples that are not dicts, examples
    without a non-empty ``messages`` list, messages that are not dicts, messages
    lacking ``role``/``content``, messages with unknown keys or roles, messages
    with missing or non-string content, and examples without any assistant
    message. The counts (or "No errors found") are printed at the end.

    Args:
        dataset_path: path of the JSONL file to validate.

    Raises:
        AssertionError: as soon as an example's last message is not an
            assistant message (validation stops at that example).
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    print(f"Validating dataset format for {dataset_path}")
    # Format error checks
    format_errors = defaultdict(int)

    # Load the dataset; blank lines (e.g. an extra newline at the end of the
    # file) carry no example and would make json.loads fail, so skip them.
    with open(dataset_path, 'r', encoding='utf-8') as f:
        dataset = [json.loads(line) for line in f if line.strip()]

    for ex in dataset:
        if not isinstance(ex, dict):
            format_errors["data_type"] += 1
            continue

        messages = ex.get("messages", None)
        if not messages or not isinstance(messages, list):
            format_errors["missing_messages_list"] += 1
            continue

        for message in messages:
            # A message that is not a dict cannot be inspected with .get();
            # count it and move on instead of crashing with AttributeError.
            if not isinstance(message, dict):
                format_errors["message_data_type"] += 1
                continue

            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1

            if any(k not in ("role", "content", "name", "function_call", "weight") for k in message):
                format_errors["message_unrecognized_key"] += 1

            if message.get("role", None) not in ("system", "user", "assistant", "function"):
                format_errors["unrecognized_role"] += 1

            content = message.get("content", None)
            function_call = message.get("function_call", None)

            if (not content and not function_call) or not isinstance(content, str):
                format_errors["missing_content"] += 1

        if not any(
            isinstance(message, dict) and message.get("role", None) == "assistant"
            for message in messages
        ):
            format_errors["example_missing_assistant_message"] += 1

        # Explicit raise (not `assert`) so the check also runs under `python -O`.
        last_message = messages[-1]
        if not isinstance(last_message, dict) or last_message.get("role") != "assistant":
            raise AssertionError(f"{messages} \n Last message should be assistant message")

    if format_errors:
        print("Found errors:")
        for k, v in format_errors.items():
            print(f"{k}: {v}")
    else:
        print("No errors found")

# Validate both generated JSONL files; each call prints its own report.
validate_dataset_format(Path(TRAIN_OPENAI_FILE))
validate_dataset_format(Path(VAL_OPENAI_FILE))

Validating dataset format for train-openai.jsonl
No errors found
Validating dataset format for val-openai.jsonl
No errors found


In [18]:
import openai

# API client shared by the cells below. No credentials are passed here, so the
# library has to obtain them itself (presumably from the OPENAI_API_KEY
# environment variable — confirm).
client = openai.OpenAI()


In [23]:

def create_file(file_path):
    """Upload ``file_path`` with purpose ``"fine-tune"`` and return the new file's id.

    The file is opened in binary mode inside a ``with`` block so the handle is
    closed as soon as the upload call returns or raises, instead of leaking.

    Args:
        file_path: path of the local file to upload.

    Returns:
        The ``id`` attribute of the response returned by ``client.files.create``.
    """
    with open(file_path, "rb") as file_handle:
        response = client.files.create(
            file=file_handle,
            purpose="fine-tune"
        )
    return response.id

# Upload both JSONL files and keep the ids returned for them.
train_id = create_file(file_path=TRAIN_OPENAI_FILE)
val_id = create_file(file_path=VAL_OPENAI_FILE)

# Create a fine-tuning job on the base model from the two uploaded files.
# No hyperparameters are passed, so the service chooses them.
response = client.fine_tuning.jobs.create(
    training_file=train_id,
    validation_file=val_id,
    model="gpt-4o-mini-2024-07-18"
)
# Show the returned job object as the cell output.
response


FineTuningJob(id='ftjob-qSU0sUSPt4GQ55w5DYdGseg6', created_at=1726908084, error=Error(code=None, message=None, param=None), fine_tuned_model=None, finished_at=None, hyperparameters=Hyperparameters(n_epochs='auto', batch_size='auto', learning_rate_multiplier='auto'), model='gpt-4o-mini-2024-07-18', object='fine_tuning.job', organization_id='org-bBHdSNHiFHSbkdpwXvV7ClOQ', result_files=[], seed=1884865724, status='validating_files', trained_tokens=None, training_file='file-xx3hiLBWSsPiP2MvMtLCkUmn', validation_file='file-KnUIViROvxQM3t1urZsiRV4j', estimated_finish=None, integrations=[], user_provided_suffix=None)

In [60]:

# Id of the fine-tuned model that call_finetuned_model queries.
# NOTE(review): hard-coded; presumably copied from the finished fine-tuning job — confirm.
MODEL_ID = "ft:gpt-4o-mini-2024-07-18:drift::A9qEalcA"

# Function to make a request to your fine-tuned model
def call_finetuned_model(prompt):
    """Send ``prompt`` to the fine-tuned model and return the first choice.

    The request uses the shared system prompt followed by ``prompt`` as the
    user message, asks for log-probabilities (top 4 alternatives per token)
    and runs at temperature 0.

    Returns:
        The first element of the response's ``choices`` on success. If anything
        raises along the way, the exception is not propagated: the string
        ``"Error: <exception text>"`` is returned instead, so callers receive
        either a choice object or a ``str``.
    """
    try:
        send_request = client.chat.completions.create
        request = {
            "model": MODEL_ID,  # the fine-tuned model id
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            "logprobs": True,
            "top_logprobs": 4,
            "temperature": 0,
        }
        # Indexing stays inside the try so an empty `choices` is reported
        # through the same error-string path.
        first_choice = send_request(**request).choices[0]
    except Exception as err:
        return f"Error: {str(err)}"
    return first_choice

In [99]:
# Sample review used to try out the fine-tuned model on a single input.
prompt = """Oh. My. God. \n\nCan this place really be real? The food here is absolutely INCREDIBLE! So delicious with EXTRA LARGE portions...like, share with a friend. Seriously. We ordered Andy's Sage Fried Chicken Benedict and the Chicken and Waffles and both dishes were to die for. Bacon strips placed inside of waffles?! How can you not be ok with that?? Bes thing ever. No complaints just super full bellies and an awkward waddle out the door. We had a 45 minute wait to be seated for a party of 5 but we just headed to O'Shea's downstairs to pass the time. Went by pretty quick and definitely worth the wait. The place is open 24 hours which is also a cool touch as it makes for great drunk and/or hungover food. Or just good food in general. Oh and don't forget to try the BLT Bloody Mary!!"""

# Query the model, then show the stripped answer text together with the
# logprob of the first entry in the returned logprobs content.
response=call_finetuned_model(prompt)
response.message.content.strip(), response.logprobs.content[0].logprob

('5', -3.1281633e-07)

In [61]:
# For each row in test.csv, pass the `question` column to call_finetuned_model
# and store the returned object in a new column called `prediction`.
df_test = pd.read_csv('test.csv')
df_test['prediction'] = df_test['question'].apply(call_finetuned_model)
df_test

Unnamed: 0,class_name,question,id,prediction
0,5,Oh. My. God. \n\nCan this place really be real...,427563,"Choice(finish_reason='stop', index=0, logprobs..."
1,3,The Golden Dragon has an east to find location...,353823,"Choice(finish_reason='stop', index=0, logprobs..."
2,1,It was my first time boarding my dogs of 6yrs....,306553,"Choice(finish_reason='stop', index=0, logprobs..."
3,4,"Since my daughters favorite food is sushi, thi...",349159,"Choice(finish_reason='stop', index=0, logprobs..."
4,2,I came in here on a Saturday morning a little ...,438984,"Choice(finish_reason='stop', index=0, logprobs..."
...,...,...,...,...
520,5,Just reviewing this club is distracting me eno...,365676,"Choice(finish_reason='stop', index=0, logprobs..."
521,2,i like their chicken biryani. That's the only ...,303420,"Choice(finish_reason='stop', index=0, logprobs..."
522,5,"If you want authentic Greek food, this place i...",439190,"Choice(finish_reason='stop', index=0, logprobs..."
523,5,"I love, love, love how the California casino/h...",120345,"Choice(finish_reason='stop', index=0, logprobs..."


In [100]:
# From each stored response object, extract the stripped answer text and the
# logprob of the first logprobs entry into their own columns.
df_test['predicted_class_name'] = df_test['prediction'].apply(lambda x: x.message.content.strip())
df_test['predicted_class_logprob'] = df_test['prediction'].apply(lambda x: x.logprobs.content[0].logprob)
df_test

Unnamed: 0,class_name,question,id,prediction,predicted_class_name,predicted_class_logprob
0,5,Oh. My. God. \n\nCan this place really be real...,427563,"Choice(finish_reason='stop', index=0, logprobs...",5,-4.320200e-07
1,3,The Golden Dragon has an east to find location...,353823,"Choice(finish_reason='stop', index=0, logprobs...",3,-2.682780e-02
2,1,It was my first time boarding my dogs of 6yrs....,306553,"Choice(finish_reason='stop', index=0, logprobs...",1,-7.609112e-04
3,4,"Since my daughters favorite food is sushi, thi...",349159,"Choice(finish_reason='stop', index=0, logprobs...",4,-5.231245e-01
4,2,I came in here on a Saturday morning a little ...,438984,"Choice(finish_reason='stop', index=0, logprobs...",1,-4.277308e-02
...,...,...,...,...,...,...
520,5,Just reviewing this club is distracting me eno...,365676,"Choice(finish_reason='stop', index=0, logprobs...",5,-3.888926e-06
521,2,i like their chicken biryani. That's the only ...,303420,"Choice(finish_reason='stop', index=0, logprobs...",2,-2.287880e-03
522,5,"If you want authentic Greek food, this place i...",439190,"Choice(finish_reason='stop', index=0, logprobs...",5,-3.650519e-06
523,5,"I love, love, love how the California casino/h...",120345,"Choice(finish_reason='stop', index=0, logprobs...",5,-3.362966e-04


In [101]:
# Convert the predicted labels to integers in a single chained expression.
# Predictions that are not numeric become NaN and their rows are dropped, so
# those rows are excluded from anything computed afterwards.
# Building the frame through the chain (rather than assigning a column onto
# the result of dropna) avoids pandas' SettingWithCopyWarning.
df_test = (
    df_test
    .assign(predicted_class_name=lambda frame: pd.to_numeric(frame["predicted_class_name"], errors="coerce"))
    .dropna(subset=["predicted_class_name"])
    .astype({"predicted_class_name": int})
)


In [102]:
# Sanity check: list the distinct predicted labels.
df_test["predicted_class_name"].unique()

array([5, 3, 1, 4, 2])

In [107]:
# Accuracy = share of rows whose predicted label equals the true label.
# Computed with pandas directly: `datasets.load_metric`, which this cell used
# before, is deprecated and no longer available in current `datasets` releases.
# `accuracy` keeps the {"accuracy": value} shape the metric used to return.
is_correct = df_test['predicted_class_name'] == df_test['class_name']
accuracy = {"accuracy": float(is_correct.mean())}

print(f"Accuracy: {accuracy['accuracy']}")

Accuracy: 0.6780952380952381
