In [1]:
import json

from datasets import load_dataset
from fireworks.client import Fireworks
from pydantic import BaseModel, Field

In [2]:
# Make sure you have the FIREWORKS_API_KEY set to your account's key!
client = Fireworks()

In [3]:
# This is a datset containing human preferences from the chatbot arena. When a human types a message, they are sent responses from two
# different chatbots. The human then votes on which response they prefer. Throughout this course, I am going to fine-tune a model to predict
# human chatbot preferences. For this week's assignment, I will perform prompt engineering to get a baseline of how well llama3-8b-instruct
# performs at this task before performing fine-tuning
# For more details on the dataset: https://huggingface.co/datasets/lmsys/chatbot_arena_conversations
dataset = load_dataset("lmsys/chatbot_arena_conversations")['train']

In [4]:
# For simplicity, I am only going to look at single-turn chats, where the user declared a winner after a single response from the bot.
examples = [example for example in dataset if example['turn'] == 1]

# The query the user sent to both bots should be exactly the same, so that we are fairly judging the responses. This should be always be
# the case for this dataset. This line just acts as a sanity check.
examples = [example for example in examples if example['conversation_a'][0]['content'] == example['conversation_b'][0]['content']]

# I am using the first 1000 responses to get an estimate of how well the baseline model performs. Using the entire dataset will take longer
# and cost more, and is not needed.
# Note that you DO NOT need 1,000 examples for your dataset for this assignment (but you can if you want). Just use at least 10 examples 
# so that you can get an idea of how the base model performs.
training_examples = examples[:1000]

# To improve the prompt, I am going to use few-shot learning. We want the examples used in few-shot learning to be different than the
# examples we used to evaluate the model quality
few_shot_learning_examples = examples[1000:1003]

In [9]:
# My initial prompt is very simple, and does not include any prompt engineering techniques
initial_sys_prompt = f'''Choose the better chatbot response between model_a and model_b.

Your response MUST be ONLY the JSON object {{"winner": XXX}}. XXX can only equal "model_a", "model_b", or "tie".'''

In [11]:
winners = list()

for i, example in enumerate(training_examples):
    if i % 100 == 0:
        print(f"Finished generating {i} examples")
    
    user_query = example['conversation_a'][0]['content']
    model_a_response = example['conversation_a'][1]['content']
    model_b_response = example['conversation_b'][1]['content']

    user_prompt = f"""user query: {user_query}
    
model_a response: {model_a_response}

model_b response: {model_b_response}"""

    response = client.chat.completions.create(
        model="accounts/fireworks/models/llama-v3-8b-instruct",
        messages=[
            {
                "role": "system",
                "content": initial_sys_prompt
            },
            {
                "role": "user",
                "content": user_prompt
            }
        ],
        # setting temperature to 0 for this use case, so that responses are as deterministic as possible
        temperature=0, 
    )
    content = response.choices[0].message.content

    try:
        winner = json.loads(content.split('\n')[-1])["winner"]
        winners.append((i, winner))
    except json.JSONDecodeError as jde:
        print(f"Failed to parse JSON for example {i}.")

Finished generating 0 examples
Finished generating 100 examples
Finished generating 200 examples
Finished generating 300 examples
Finished generating 400 examples
Finished generating 500 examples
Finished generating 600 examples
Finished generating 700 examples
Finished generating 800 examples
Failed to parse JSON for example 801.
Failed to parse JSON for example 804.
Finished generating 900 examples


In [30]:
num_correct = sum([training_examples[winner[0]]['winner'] == winner[1] for winner in winners])
print(f"Initial Prompt Accuracy: {100 * num_correct / len(winners)}%")

Initial Prompt Accuracy: 58.41683366733467%


In [31]:
# I first perform prompt engineering improving clarity and specificity of the system prompt.
improved_sys_prompt = f'''You are evaluating chatbot responses.

In the conversation below, a human sent a query to two different chatbots. The human then voted for whether they preferred the response from model_a or model_b. Your task is to predict which response the human preferred.

Choose the better chatbot response between model_a and model_b.

Your response MUST be ONLY the JSON object {{"winner": XXX}}. XXX can only equal "model_a", "model_b", or "tie".'''

In [42]:
# I will also use the "few-shot learning" prompt engineering method. This method involves including a few examples in the prompt, so that
# the model can start to get an idea of how it should respond.
def create_example_message(example):

    example_user_msg = f"""user query: {example['conversation_a'][0]['content']}
        
model_a response: {example['conversation_a'][1]['content']}

model_b response: {example['conversation_b'][1]['content']}"""
    
    example_asst_msg = json.dumps({"winner": example['winner']})

    return (example_user_msg, example_asst_msg)

example_0 = create_example_message(few_shot_learning_examples[0])
example_1 = create_example_message(few_shot_learning_examples[1])
example_2 = create_example_message(few_shot_learning_examples[2])

In [46]:
winners = list()

for i, example in enumerate(training_examples):
    if i % 100 == 0:
        print(f"Finished generating {i} examples")
    
    user_query = example['conversation_a'][0]['content']
    model_a_response = example['conversation_a'][1]['content']
    model_b_response = example['conversation_b'][1]['content']

    user_prompt = f"""user query: {user_query}
    
model_a response: {model_a_response}

model_b response: {model_b_response}"""

    response = client.chat.completions.create(
        model="accounts/fireworks/models/llama-v3-8b-instruct",
        messages=[
            # few-shot learning
            {"role": "system", "content": improved_sys_prompt},
            {"role": "user", "content": example_0[0]},
            {"role": "assistant", "content": example_0[1]},
            {"role": "user", "content": example_1[0]},
            {"role": "assistant", "content": example_1[1]},
            {"role": "user", "content": example_2[0]},
            {"role": "assistant", "content": example_2[1]},
            {"role": "user", "content": user_prompt}
        ],
        # setting temperature to 0 for this use case, so that responses are as deterministic as possible
        temperature=0, 
    )
    content = response.choices[0].message.content

    try:
        winner = json.loads(content.split('\n')[-1])["winner"]
        winners.append((i, winner))
    except:
        print(f"Failed to parse JSON for example {i}.")

Finished generating 0 examples
Finished generating 100 examples
Failed to parse JSON for example 129.
Failed to parse JSON for example 137.
Failed to parse JSON for example 198.
Finished generating 200 examples
Failed to parse JSON for example 260.
Finished generating 300 examples
Finished generating 400 examples
Failed to parse JSON for example 465.
Failed to parse JSON for example 467.
Finished generating 500 examples
Failed to parse JSON for example 504.
Finished generating 600 examples
Failed to parse JSON for example 648.
Finished generating 700 examples
Finished generating 800 examples
Failed to parse JSON for example 844.
Finished generating 900 examples
Failed to parse JSON for example 922.
Failed to parse JSON for example 956.


In [47]:
# I was not able to improve the accuracy through prompt engineering. This shows that prompt engineering is unsufficient for this task,
# and I will need to perform fine-tuning to get better results. Next week, I will fine-tune my first model for this task to get
# better results
num_correct = sum([training_examples[winner[0]]['winner'] == winner[1] for winner in winners])
print(f"Improved Prompt Accuracy: {100 * num_correct / len(winners)}%")

Improved Prompt Accuracy: 52.881698685540954%
