In [40]:
!pip install datasets==2.20.0 fireworks-ai

In [132]:
from collections import Counter
import json
import requests
import os
import time

from datasets import load_dataset, load_from_disk
import fireworks.client as fc
from fireworks.client import Fireworks

In [3]:
# Update the default below to your account id, or set the FIREWORKS_ACCOUNT_ID environment variable.
# It is used to build the fully qualified model names (accounts/<account_id>/models/<model_id>).
account_id = os.environ.get('FIREWORKS_ACCOUNT_ID', 'sdkramer10-5e98cb')

# Uncomment the line below and set the value to your account's API key. Alternatively, set the FIREWORKS_API_KEY environment variable.
# Prefer the environment variable so the key never ends up in the saved notebook.
# fc.api_key = '<API KEY>'

client = Fireworks()

In [133]:
# This is a dataset containing human preferences from the chatbot arena. When a human types a message, they are sent responses from two
# different chatbots. The human then votes on which response they prefer. Throughout this course, I am going to fine-tune a model to predict
# human chatbot preferences. For this week's assignment, I will perform prompt engineering to get a baseline of how well llama3-8b-instruct
# performs at this task before performing fine-tuning
# For more details on the dataset: https://huggingface.co/datasets/lmsys/chatbot_arena_conversations
dataset_dir = "./chatbot_arena_dataset/"

# Download and cache a local copy on the first run, so the notebook also works from a clean checkout.
# NOTE(review): the Hub dataset may be gated (terms acceptance / login required) — confirm before relying on this.
if not os.path.exists(dataset_dir):
    load_dataset("lmsys/chatbot_arena_conversations").save_to_disk(dataset_dir)

# Fixed seed, so the train/validation/test slices taken below are reproducible.
dataset = load_from_disk(dataset_dir)['train'].shuffle(seed=42)

In [134]:
# For simplicity, I am only going to look at single-turn chats, where the user declared a winner after a single response from the bot.
examples = [example for example in dataset if example['turn'] == 1]

# The query the user sent to both bots should be exactly the same, so that we are fairly judging the responses. This should always be
# the case for this dataset. This line just acts as a sanity check.
examples = [example for example in examples if example['conversation_a'][0]['content'] == example['conversation_b'][0]['content']]

# We take different examples for the train/validation/test sets:
# - training: the first 2000
# - test / validation: the last 2000, split 1000/1000
# These slices are only disjoint when there are at least 4000 examples; otherwise evaluation data leaks into training.
assert len(examples) >= 4000, f"Need at least 4000 examples for disjoint splits, got {len(examples)}"
training_examples = examples[:2000]
validation_examples = examples[-1000:]
test_examples = examples[-2000:-1000]

In [16]:
# System prompt shared by the fine-tuning data (create_messages) and inference (get_results).
# A plain string rather than an f-string, so the JSON braces are literal and need no {{ }} escaping.
sys_msg = '''Choose the better chatbot response between model_a and model_b.

Your response MUST be ONLY the JSON object {"winner": XXX}. XXX can only equal "model_a", "model_b", "tie", or "tie (bothbad)".'''

# The helpers below convert the training examples to the conversation format expected by Fireworks. See https://readme.fireworks.ai/docs/fine-tuning-models#conversation
def get_user_msg(example):
    """Build the user prompt that shows the query and both candidate responses.

    Reads the first turn (the user query) and second turn (the bot reply) of
    example['conversation_a'], plus the second turn of example['conversation_b'].
    The three labelled sections are separated by blank lines.
    """
    conv_a = example['conversation_a']
    conv_b = example['conversation_b']
    sections = [
        f"user query: {conv_a[0]['content']}",
        f"model_a response: {conv_a[1]['content']}",
        f"model_b response: {conv_b[1]['content']}",
    ]
    return "\n\n".join(sections)

def create_messages(example):
    """Return one fine-tuning record: system prompt, user prompt, and the labelled answer.

    The assistant turn is the JSON-encoded {"winner": <label>} taken from example['winner'].
    """
    prompt = get_user_msg(example)
    answer = json.dumps({"winner": example['winner']})

    conversation = [
        dict(role="system", content=sys_msg),
        dict(role="user", content=prompt),
        dict(role="assistant", content=answer),
    ]
    return {"messages": conversation}

def training_examples_to_json(examples):
    """Map every example to its chat-format record (via create_messages), preserving order."""
    return [create_messages(example) for example in examples]

# Writes the data to a file so that it can be uploaded to Fireworks
def dataset_to_jsonl(examples, dataset_file_name):
    """Write `examples` to `dataset_file_name` as JSONL, one chat-format record per line.

    Args:
        examples: the examples to serialize (each converted with training_examples_to_json).
        dataset_file_name: path of the output file; it is overwritten if it already exists.
    """
    # Serialize the `examples` argument, not the global `training_examples`.
    records = training_examples_to_json(examples)

    with open(dataset_file_name, 'w') as f:
        for record in records:
            json.dump(record, f)
            f.write('\n')

In [17]:
# Write the training examples to a local JSONL file.
# dataset_id is the name under which the firectl cells below register the uploaded dataset.
dataset_file_name = 'chatbot_arena_training_data.jsonl'
dataset_id = 'chatbot-arena-td-v1'
dataset_to_jsonl(training_examples, dataset_file_name)

In [18]:
# Follow instructions here to first install the firectl CLI - https://readme.fireworks.ai/docs/fine-tuning-models#installing-firectl
# Then run this command to upload the file to Fireworks.
# IPython substitutes {dataset_id} and {dataset_file_name} with the Python variables of the same name.
!firectl create dataset {dataset_id} {dataset_file_name}

4.27 MiB / 4.27 MiB [--------------------------------] 100.00% 1.94 MiB p/s 2.4s


In [41]:
# Creates a fine-tuning job on the uploaded dataset, using the hyperparameters in chatbot_arena_training_v1.yaml
# (per the original note: increased rank, learning rate, and epochs).
# NOTE: the original note says this command prints the API key to stdout. The line is active (not commented out),
# so clear this cell's output before sharing the notebook.
!firectl create fine-tuning-job --settings-file chatbot_arena_training_v1.yaml --display-name {dataset_id} --dataset {dataset_id} 

In [47]:
# ID of the v1 fine-tuning job, presumably copied from the output of the create command above.
# The cells below use the same ID both for the job and for the resulting model.
model_v1_id = 'db6e8954716049e9aae361d59384dd7f'

In [51]:
# Wait until the fine-tuning job has finished running: re-run this cell to check the job's status.
# NOTE: the original note says this prints the API key to stdout. The line is active, so clear the output before sharing.
!firectl get fine-tuning-job {model_v1_id}

In [53]:
# Deploy the first (v1) fine-tuned model to a serverless endpoint, so get_results can query it below.
!firectl deploy {model_v1_id}

In [135]:
# Show the v1 model's details (presumably to confirm the deployment is ready before querying it).
!firectl get model {model_v1_id}

In [59]:
def get_results(examples, model_id):
    """Ask `model_id` to pick the winner for each example and score the predictions.

    One chat-completion request is made per example, using the shared system prompt.
    Responses whose final line cannot be parsed as {"winner": ...} are reported and skipped.

    Args:
        examples: sequence of examples with 'conversation_a', 'conversation_b' and 'winner'.
        model_id: fully qualified Fireworks model name to query.

    Returns:
        (winners, num_correct): winners is a list of (index, predicted_winner) pairs for the
        parsable responses, where index points into `examples`; num_correct is how many of
        those predictions equal the example's 'winner' label.
    """
    winners = []

    for i, example in enumerate(examples):
        user_msg = get_user_msg(example)

        response = client.chat.completions.create(
            model=model_id,
            messages=[
                {"role": "system", "content": sys_msg},
                {"role": "user", "content": user_msg},
            ],
            # setting temperature to 0 for this use case, so that responses are as deterministic as possible
            temperature=0,
        )
        content = response.choices[0].message.content.replace("<|start_header_id|>", "")

        # Only the last non-empty line is parsed (strip() drops a trailing newline that would otherwise
        # leave an empty last line); presumably the model may emit text before the JSON.
        last_line = content.strip().split('\n')[-1]
        try:
            winner = json.loads(last_line)["winner"]
        except (ValueError, KeyError, TypeError):
            # ValueError: not JSON (json.JSONDecodeError subclasses it); KeyError: no "winner" key;
            # TypeError: valid JSON that is not an object. Anything else (e.g. Ctrl-C) propagates.
            print(f"Failed to parse JSON for example {i}.")
            continue
        winners.append((i, winner))

    num_correct = sum(1 for idx, predicted in winners if predicted == examples[idx]['winner'])
    return winners, num_correct

In [60]:
# Determine how the v1 fine-tuned model performs on the validation set.
# NOTE(review): the original comment said "default fine-tuning params", but the job was created with
# chatbot_arena_training_v1.yaml — confirm which hyperparameters were actually used.
model_name = f'accounts/{account_id}/models/{model_v1_id}'

validation_results, validation_num_correct = get_results(validation_examples, model_name)
print(f'Validation Set Correct: {validation_num_correct}')

Validation Set Correct: 587


In [136]:
# Compare the distribution of predicted labels against the true labels on the validation set.
counts = Counter(predicted for _, predicted in validation_results)
print(f'Predicted Counts: {counts.most_common()}')

counts = Counter(example['winner'] for example in validation_examples)
print(f'Actual Counts: {counts.most_common()}')

Predicted Counts: [('model_a', 431), ('model_b', 423), ('tie (bothbad)', 97), ('tie', 48), ('data', 1)]
Actual Counts: [('model_b', 355), ('model_a', 351), ('tie (bothbad)', 186), ('tie', 108)]


In [63]:
# Language distribution of the training examples, most frequent first (displayed as the cell output).
counts = Counter(ex['language'] for ex in training_examples)
counts.most_common()

[('English', 1786),
 ('German', 46),
 ('Spanish', 33),
 ('French', 32),
 ('Russian', 18),
 ('Portuguese', 13),
 ('Italian', 11),
 ('Dutch', 6),
 ('Polish', 5),
 ('Chinese', 5),
 ('Danish', 4),
 ('unknown', 4),
 ('Japanese', 4),
 ('Finnish', 3),
 ('Indonesian', 3),
 ('Vietnamese', 3),
 ('Hungarian', 3),
 ('Korean', 3),
 ('Latin', 2),
 ('Malay', 2),
 ('Ukrainian', 2),
 ('Slovenian', 2),
 ('Galician', 2),
 ('Scots', 1),
 ('Malagasy', 1),
 ('Persian', 1),
 ('Greek', 1),
 ('Hebrew', 1),
 ('Swedish', 1),
 ('Interlingua', 1),
 ('Bangla', 1)]

In [88]:
def calculate_pct_correct(examples, results):
    """Print the number of results, how many are correct, and the percentage correct.

    Args:
        examples: sequence of examples, each with a 'winner' label.
        results: list of (index, predicted_winner) pairs, where index points into `examples`
            (the format returned by get_results).

    The percentage is rounded to 2 decimals. An empty `results` list reports "N/A"
    instead of raising ZeroDivisionError.
    """
    num_correct = sum(1 for idx, predicted in results if predicted == examples[idx]['winner'])
    print(f'Num Examples: {len(results)}')
    print(f'Num Correct: {num_correct}')
    if results:
        pct_correct = round(100 * num_correct / len(results), 2)
        print(f'Pct Correct: {pct_correct}')
    else:
        print('Pct Correct: N/A')

In [89]:
# Break v1 validation accuracy down by whether the example's 'language' field is English.
print('English')
english_results = [res for res in validation_results if validation_examples[res[0]]['language'] == 'English']
calculate_pct_correct(validation_examples, english_results)

print('\nNon-English')
non_english_results = [res for res in validation_results if validation_examples[res[0]]['language'] != 'English']
calculate_pct_correct(validation_examples, non_english_results)

English
Num Examples: 897
Num Correct: 525
Pct Correct: 58.53

Non-English
Num Examples: 103
Num Correct: 62
Pct Correct: 60.19


In [67]:
# Print, for manual inspection, every validation example whose true label is 'tie (bothbad)'
# but which the model predicted differently.
for i, winner in validation_results:
    example = validation_examples[i]
    if example['winner'] == 'tie (bothbad)' and winner != example['winner']:
        print(f"Example {i}: Predicted Winner: {winner}, Actual Winner: {example['winner']}\n")
        user_message = example['conversation_a'][0]
        print(f"User: {user_message['content']}\n")
        print(f"Model A: {example['conversation_a'][1]['content']}\n")
        print(f"Model B: {example['conversation_b'][1]['content']}\n")
        print("=" * 120)

In [29]:
def is_in_question_category(examples, sys_msg, model_id):
    """Classify each example's user query with `model_id`, using `sys_msg` as the instruction.

    One chat-completion request is made per example. Only the user query (the first turn of
    'conversation_a') is sent. A reply counts as True only if, after trimming whitespace and
    surrounding quotes and lowercasing, it is exactly "true".

    Returns:
        list of bools, one per example, in the same order as `examples`.
    """
    def classify(example):
        query = example['conversation_a'][0]['content']
        reply = client.chat.completions.create(
            model=model_id,
            messages=[
                {"role": "system", "content": sys_msg},
                {"role": "user", "content": query},
            ],
            # setting temperature to 0 for this use case, so that responses are as deterministic as possible
            temperature=0,
        )
        answer = reply.choices[0].message.content.strip().strip('"\'').lower()
        return answer == "true"

    return [classify(example) for example in examples]


# Instruction for the zero-shot "is this a math question?" classifier used with is_in_question_category.
# A plain string: it has no placeholders, so an f-string prefix is unnecessary.
math_sys_msg = '''Determine whether the user is asking a math question.

Your response MUST be ONLY the value "true" or "false" and NOTHING ELSE!'''

In [49]:
# Flag which validation queries are math questions, using llama-v3-70b-instruct as a zero-shot classifier.
# Makes one API call per validation example; math_questions[i] corresponds to validation_examples[i].
math_questions = is_in_question_category(validation_examples, math_sys_msg, 'accounts/fireworks/models/llama-v3-70b-instruct')

In [92]:
# Compare v1 validation accuracy on math vs. non-math questions.
print('Math')
math_results = [res for res in validation_results if math_questions[res[0]]]
calculate_pct_correct(validation_examples, math_results)

print('\nNot Math')
# NOTE(review): despite its name, this holds every non-math result; no coding questions are filtered out here.
non_math_coding_results = [res for res in validation_results if not math_questions[res[0]]]
calculate_pct_correct(validation_examples, non_math_coding_results)

Math
Num Examples: 174
Num Correct: 99
Pct Correct: 56.9

Not Math
Num Examples: 826
Num Correct: 488
Pct Correct: 59.08


In [None]:
# Optional: undeploy the model (shouldn't cost anything extra to leave deployed).
# !firectl undeploy {model_v1_id}

In [30]:
# Classify the examples immediately after the original training slice, so the math questions among
# them can be added to the training set.
orig_num_training_examples = 2000
num_addtl_candidates = 2000

# Guard against leakage: the candidate slice must end before the held-out validation/test examples,
# which are taken from the end of `examples`.
assert orig_num_training_examples + num_addtl_candidates <= len(examples) - len(validation_examples) - len(test_examples), \
    "Additional training candidates would overlap the validation/test sets"

addtl_math_questions = is_in_question_category(
    examples[orig_num_training_examples:orig_num_training_examples + num_addtl_candidates],
    math_sys_msg,
    'accounts/fireworks/models/llama-v3-70b-instruct'
)

In [101]:
# Classify a larger candidate range (examples[2000:6000]).
# NOTE(review): addtl_math_questions_v2 is not used by any later cell (the augmented training set is built
# from addtl_math_questions), and this range re-classifies examples[2000:4000] — roughly 4000 extra API calls.
# NOTE(review): the range only stays clear of the held-out validation/test examples
# (the last 2000 of `examples`) if len(examples) >= 8000 — confirm before using it for training.
addtl_math_questions_v2 = is_in_question_category(
    examples[2000:6000],
    math_sys_msg, 
    'accounts/fireworks/models/llama-v3-70b-instruct'
)

In [31]:
# Augmented training set: the original training slice, followed by every additional candidate that the
# classifier flagged as a math question (candidate k is examples[orig_num_training_examples + k]).
training_examples = examples[:orig_num_training_examples] + [
    examples[orig_num_training_examples + offset]
    for offset, is_math in enumerate(addtl_math_questions)
    if is_math
]

In [107]:
# Write the augmented training set (original examples plus the extra math questions) to a new JSONL file,
# registered under a new dataset id.
dataset_file_name = 'chatbot_arena_training_data_more_math.jsonl'
dataset_id = 'chatbot-arena-td-more-math-v2'
dataset_to_jsonl(training_examples, dataset_file_name)

In [108]:
# Upload the augmented dataset to Fireworks.
!firectl create dataset {dataset_id} {dataset_file_name}

5.20 MiB / 5.20 MiB [--------------------------------] 100.00% 1.94 MiB p/s 2.9s


In [126]:
# Fine-tune a second model on the augmented dataset, using the same settings file as v1.
# NOTE: this command may print the API key to stdout; clear the cell output before sharing.
!firectl create fine-tuning-job --settings-file chatbot_arena_training_v1.yaml --display-name {dataset_id} --dataset {dataset_id} 

In [110]:
# ID of the v2 fine-tuning job/model, presumably copied from the output of the create command above.
model_v2_id = '39b93365bd66462792d432b8064915c4'

In [121]:
# Check the v2 job's status; re-run until it has finished.
!firectl get fine-tuning-job {model_v2_id}

In [120]:
# Deploy the v2 model (same command used for v1), so get_results can query it.
!firectl deploy {model_v2_id}

In [127]:
# Show the v2 model's details (presumably to confirm the deployment is ready).
!firectl get model {model_v2_id}

In [123]:
# Evaluate the v2 (more-math) model on the same validation set used for v1.
model_name = f'accounts/{account_id}/models/{model_v2_id}'
validation_results_v2, validation_num_correct_v2 = get_results(validation_examples, model_name)
print(f'Validation Set Correct: {validation_num_correct_v2}')

Validation Set Correct: 597


In [125]:
# Compare v2 validation accuracy on math vs. non-math questions (same split as the v1 breakdown).
print('Math')
math_results = [res for res in validation_results_v2 if math_questions[res[0]]]
calculate_pct_correct(validation_examples, math_results)

print('\nNot Math')
# NOTE(review): despite its name, this holds every non-math result; no coding questions are filtered out here.
non_math_coding_results = [res for res in validation_results_v2 if not math_questions[res[0]]]
calculate_pct_correct(validation_examples, non_math_coding_results)

Math
Num Examples: 174
Num Correct: 100
Pct Correct: 57.47

Not Math
Num Examples: 826
Num Correct: 497
Pct Correct: 60.17


In [97]:
# Optional: undeploy the model (shouldn't cost anything extra to leave deployed).
# !firectl undeploy {model_v2_id}