# Automatic Prompt Engineering for classification

Given only a labeled dataset of (text -> label) pairs, this notebook generates and optimizes the system prompt and user prompts for a classifier.

Classification is performed with the following pipeline.
- (system prompt, user prompt prefix + text + user prompt suffix) -Haiku-> bot response -extract-> label
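
The pipeline above can be sketched without any API call. This is a minimal illustration, not the notebook's implementation: `build_user_message` and `extract_label` are hypothetical helper names, and `fake_response` stands in for what Haiku might return.

```python
import re

def build_user_message(prefix: str, text: str, suffix: str) -> str:
    """Concatenate prefix + text + suffix into the user message."""
    return prefix + text + suffix

def extract_label(response: str):
    """Return the text inside the first <label>...</label> pair, or None."""
    match = re.search(r"<label>(.*?)</label>", response, re.DOTALL)
    return match.group(1).strip() if match else None

user_message = build_user_message(
    "Classify the following text <text>",
    "What is quicksort?",
    "</text>\nEnd with Label: <label>{label}</label>",
)

# Hypothetical model output, written by hand for this example.
fake_response = "Reasoning: asks for a factual explanation.\nLabel: <label>0</label>"
label = extract_label(fake_response)
```

If the response has no `<label>` tags, `extract_label` returns `None`, which the tuning loop can treat as a separate kind of mistake.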

The notebook will produce
- the system prompt
- the user prompt prefix
- the user prompt suffix

To run this notebook as-is, you only need
- an Anthropic API key and an OpenAI API key

If you want to change the classification task, you will need to
- provide a dataset (text -> label)

This is how prompt tuning is done
- Sample from the full dataset.
- Haiku takes in (system prompt, user prompt prefix + text + user prompt suffix) and produces model_response.
- Extract the label from the model_response.
- Sample from the mistakes and the correct results.
- An OpenAI o1 model (`o1-preview` in the code in this notebook) summarizes the mistakes and updates the prompts (the "model parameters").
- Repeat.
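
The steps above can be sketched as a plain loop. This is a simplified outline under stated assumptions: `tune_prompts`, `toy_classify`, and `toy_propose` are hypothetical names, the toy functions replace the Haiku and o1 calls, and the per-label sampling steps are omitted.

```python
def tune_prompts(dataset, prompts, classify, propose_prompts, num_iterations):
    """Run the classify -> score -> propose loop; return final prompts and history."""
    history = []
    for iteration in range(num_iterations):
        results = [(text, classify(text, prompts), label) for text, label in dataset]
        accuracy = sum(pred == label for _, pred, label in results) / len(results)
        history.append((dict(prompts), accuracy))
        if iteration == num_iterations - 1:
            break  # the last iteration only evaluates
        mistakes = [r for r in results if r[1] != r[2]]
        correct = [r for r in results if r[1] == r[2]]
        prompts = propose_prompts(prompts, mistakes, correct)
    return prompts, history


# Toy stand-ins (hypothetical): a keyword "classifier" and an "optimizer"
# that copies the last word of the first misclassified text.
def toy_classify(text, prompts):
    return "1" if prompts["keyword"] in text else "0"

def toy_propose(prompts, mistakes, correct):
    if not mistakes:
        return prompts
    return {"keyword": mistakes[0][0].split()[-1].rstrip("?")}

toy_dataset = [("what is 2+2?", "0"), ("how do you feel?", "1")]
final_prompts, history = tune_prompts(
    toy_dataset, {"keyword": "zzz"}, toy_classify, toy_propose, num_iterations=2
)
```

In the toy run the first pass misclassifies one of two texts, the "optimizer" changes the keyword, and the second pass classifies both correctly.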

You will need to have these Python modules installed
- pandas
- scikit-learn
- anthropic
- openai

In [1]:
import os
import re
import random
import inspect
import textwrap
import collections
import itertools
import concurrent.futures
import pandas as pd
import html
from collections import defaultdict

from IPython.display import display, HTML
from sklearn.metrics import precision_score, recall_score

import anthropic
from openai import OpenAI



# Use your Anthropic API key and OpenAI API key here

In [2]:
anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
anthropic_client = anthropic.Anthropic(api_key=anthropic_api_key)

openai_api_key = os.environ.get("OPENAI_API_KEY")
openai_client = OpenAI(api_key=openai_api_key)

In [3]:
print(anthropic_client.api_key[:16])

sk-ant-api03-Mbs


In [4]:
print(openai_client.api_key[:16])

sk-T0uvUYsHoQOag


In [5]:
NUM_PARALLEL_FORWARD_PASS_API_CALLS = 100  # see https://docs.anthropic.com/claude/reference/rate-limits
NUM_SAMPLES_FORWARD_PASS_FOR_EACH_LABEL = 100
NUM_SAMPLES_MISTAKE_GRADIENT_CALCULATION_FOR_EACH_LABEL = 20
NUM_SAMPLES_CORRECT_GRADIENT_CALCULATION_FOR_EACH_LABEL = 10
NUM_ITERATIONS = 5

# Define the dataset here

The outcome of this section should be the variable `dataset: list[tuple[str, str]]`
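
As an illustration of that shape, the sketch below builds a `dataset` from records whose labels are not strings. The `raw_rows` records and their field names are made up for this example; the only requirement taken from the notebook is that both elements of each tuple are strings.

```python
# Hypothetical input records; labels arrive as integers here.
raw_rows = [
    {"text": "What is a heap?", "label": 0},
    {"text": "How was your day?", "label": 1},
]

# Cast both fields to str so that labels compare equal to the extracted label text.
dataset = [(str(row["text"]), str(row["label"])) for row in raw_rows]

assert all(isinstance(text, str) and isinstance(label, str) for text, label in dataset)
```

Casting matters because the label extracted from a model response is always a string, so an integer label would never match it.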

In [6]:
tsv = """
What are some examples of sorting algorithms that require more conditional statements?	TRUE
Can you elaborate on how radix sort and counting sort functions work?	TRUE
What are some specialized sorting algorithms?	TRUE
Could you explain why sorting is equivalent to discovering the permutation of items?	TRUE
What are some real-world applications of sorting algorithms?	TRUE
Do you think some billionaires intentionally underreport their wealth?	TRUE
Could the media's focus on billionaires be detrimental in any way?	TRUE
Do you think that billionaire wealth lists provide any meaningful insights into the global economy?	TRUE
What are some activities that Japanese students could do to be more proactive in their learning?	TRUE
Is the trend of students becoming more "obedient and good kids" a positive or negative development for Japan?	FALSE
What are the implications of Japanese students becoming less likely to read for pleasure?	TRUE
What are some ways that teachers can help students transition from being students to graduates?	TRUE
How might this "good kid" syndrome affect Japanese society in the long term?	FALSE
Do you think DeSantis would have been a good president?	TRUE
Do you think the sympathy wave would have been enough to secure a win for DeSantis?	FALSE
What are some other factors that you think would have impacted the election outcome?	FALSE
Why do you think it's useful to think of dying as a process rather than an event?	TRUE
What other human cell types have a low metabolic rate?	FALSE
How long does it take for muscle and skin cells to die after circulation stops?	TRUE
What are some other factors that forensic scientists use to estimate time of death?	FALSE
What is the most surprising thing about how the human body decomposes?	TRUE
What are some examples of how apathy can affect a project?	TRUE
What is the trade-off between shipping updates quickly and releasing quality software?	TRUE
Is this apathy issue specific to software development, or is it a broader problem?	FALSE
What are the average costs associated with burying someone in a standard casket compared to an extra-large casket?	TRUE
Do cultural or religious beliefs influence the choice of casket size or burial practices?	TRUE
Are extra charges associated with cremation of obese individuals justified?	TRUE
Do you think the price of clothing for larger people is unfairly inflated?	TRUE
How does this perspective hold up when looking at specific Fighting-type Pokémon that are not typically heroic?	FALSE
Does the Fairy type fill the role of a "Light" type?	FALSE
What are some other positive qualities besides fighting and righteousness that the Fighting type embodies?	FALSE
Why do you think it's important to condemn violence in politics?	TRUE
What do you mean by "Trumpism would turn from a political movement to a religion"?	FALSE
Why do you hope the perpetrator was not a Democrat?	FALSE
How does obtaining guns easily make it easier to take a shot at a president?	TRUE
What are the most important factors when judging the success of a prime minister?	TRUE
What was it like living through the time of Harold Wilson?	TRUE
Did your dad ever explain why he acted that way?	FALSE
Do you think this experience changed your relationship with your father?	FALSE
How do you feel about your father now?	FALSE
Is it common for fathers to act this way?	FALSE
What's the most thoughtful thing your dad ever did for you?	TRUE
What are some benefits you received during your deployment?	TRUE
What were some of the challenges of living on a remote combat outpost in Afghanistan?	TRUE
How did you manage your money during your deployment?	TRUE
Do other cultures use stars and constellations to represent concepts like the afterlife?	FALSE
What are some names for Canis Minor in Sanskrit?	TRUE
Are there other examples of using celestial figures as a way to navigate or remember things?	FALSE
Is there any evidence that ancient Hindus actually used these constellations for navigating the path of the departed souls?	FALSE
Does the concept of Pitriloka and the path of the departed souls play a significant role in modern Hinduism?	TRUE
What are the implications of the current situation between the US and China?	FALSE
Why does China not need nuclear-fueled aircraft carriers?	TRUE
What is the strategy behind China's approach of building smaller ships?	TRUE
Did you think the employees at the first Lowe's thought you were going to spend a lot of money and just didn’t want to deal with it?	FALSE
Did you use financing for all $12,000?	FALSE
Why would someone choose to invest in real estate?	TRUE
Should your colleague factor in the interest on the loan into his calculations?	FALSE
What are some downsides to investing in real estate?	TRUE
What homework should people do before investing in real estate?	TRUE
Is your son's reaction common in boys his age?	FALSE
What could you do to help your son feel more comfortable expressing his emotions?	TRUE
What are the societal pressures that contribute to men being discouraged from expressing their emotions?	TRUE
What are your thoughts on the idea that men should "toughen up" and not cry?	TRUE
What does this story say about the nature of desire?	FALSE
What are some other stories of divine play (leela) in Hinduism?	FALSE
What are some other examples of the "divine play" in Hindu mythology?	FALSE
Are there other similar stories in other religions or mythologies that involve deities taking on human forms?	FALSE
Do you think Omaha Steaks could be considered a luxury brand?	TRUE
What are your thoughts on companies that incentivize large purchases with discounts and bundling?	TRUE
Why don't Omaha Steaks show photos of their raw steaks?	TRUE
What other information would you need to feel comfortable purchasing from Omaha Steaks?	FALSE
What is the difference between welfare and food stamps?	TRUE
What are some ways companies can create a more flexible work environment?	TRUE
What are some ways employees can communicate effectively with their teams about their schedule?	TRUE
How has the importance of being present in the office changed in recent years?	TRUE
Are there situations where being physically present in the office is still important?	TRUE
What are some companies that have a culture of flexible work hours?	TRUE
What are some examples of children who exhibit similar behavior?	FALSE
Is there any connection between E.'s behavior and his early childhood experiences?	FALSE
What are the long-term effects of "Defiance and Anger Management Disorder"?	FALSE
What are some of the challenges of dealing with children with "Defiance and Anger Management Disorder"?	FALSE
""".strip()

In [7]:
dataset = [row.split('\t') for row in tsv.split("\n")]
dataset = [(text, "0" if label == "TRUE" else "1") for text, label in dataset]

In [8]:
for text, target in dataset:
    # note that you need to cast the target into a string
    assert isinstance(text, str)
    assert isinstance(target, str)

In [9]:
# check that the number of distinct labels is small
collections.Counter(label for _, label in dataset)

Counter({'0': 48, '1': 33})

In [10]:
dataset[0]

('What are some examples of sorting algorithms that require more conditional statements?',
 '0')

# Model definition
If you want to use a different model, you can configure it in this section.

Otherwise, you can just run the rest of the code and see which prompts o1 produces for you.

In [11]:
# usually o1 is good enough to produce working prompts from nothing
model_parameters = {
    "system_prompt": f"Classify into one of the following labels {set(label for _, label in dataset)}",
    "user_prompt_prefix": "Classify the following text <text>",
    # plain string (not an f-string), so single braces give the literal placeholder {label}
    "user_prompt_suffix": "</text>\nYour response should produce a brief reasoning and end with Label: <label>{label}</label>",
}

token_counts = defaultdict(int)
token_costs = defaultdict(float)

In [12]:
# the source of this function is passed to o1 (via inspect.getsource) so it can see how the prompts are used
def compute_model_response(text, model_parameters):    
    user_message = model_parameters["user_prompt_prefix"] + text + model_parameters["user_prompt_suffix"]
    
    message = anthropic_client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=2000,
        temperature=0,
        messages=[
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": "Reasoning:"},
        ],
        system=model_parameters["system_prompt"],
        timeout=10
    )
    token_counts["haiku_input"] += message.usage.input_tokens
    token_counts["haiku_output"] += message.usage.output_tokens
    token_costs["haiku_input"] += message.usage.input_tokens * 0.25 * 1e-6
    token_costs["haiku_output"] += message.usage.output_tokens * 1.25 * 1e-6

    return message.content[0].text

In [13]:
def extract_from_model_response(text):
    # Return the text inside the first <label>...</label> pair, or None if absent.
    pattern = r'<label>(.*?)</label>'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return None

# Optimization functions

In [14]:
def get_texts_and_labels(dataset):
    # Shuffle a copy so the caller's dataset order is left untouched.
    dataset = list(dataset)
    random.shuffle(dataset)
    label_set = set(label for _, label in dataset)

    # Take at most NUM_SAMPLES_FORWARD_PASS_FOR_EACH_LABEL examples per label.
    sampled_dataset = []
    for target_label in label_set:
        dataset_with_label = [(text, label) for text, label in dataset if label == target_label]
        sampled_dataset += dataset_with_label[:NUM_SAMPLES_FORWARD_PASS_FOR_EACH_LABEL]
    random.shuffle(sampled_dataset)

    return [text for text, _ in sampled_dataset], [label for _, label in sampled_dataset]
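
The per-label cap can be checked on toy data. The sketch below is a seeded variant written for illustration only: `sample_per_label`, its `seed` argument, and the toy rows are assumptions, not part of the notebook.

```python
import random
from collections import Counter

def sample_per_label(dataset, max_per_label, seed=0):
    """Shuffle, keep at most max_per_label examples of each label, shuffle again."""
    rng = random.Random(seed)  # local generator, so the global random state is untouched
    shuffled = list(dataset)
    rng.shuffle(shuffled)
    sampled = []
    for target in sorted({label for _, label in shuffled}):
        sampled += [(text, label) for text, label in shuffled if label == target][:max_per_label]
    rng.shuffle(sampled)
    return sampled

# Five examples of label "0" and one of label "1".
toy = [(f"text {i}", "0") for i in range(5)] + [("text x", "1")]
sampled = sample_per_label(toy, max_per_label=2)
counts = Counter(label for _, label in sampled)
```

A label with fewer examples than the cap contributes everything it has, so the sample can still be imbalanced when one label is rare.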

In [15]:
def forward_pass(texts, model_parameters):
    # Call the model in parallel; executor.map returns results in input order.
    with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_PARALLEL_FORWARD_PASS_API_CALLS) as executor:
        model_response = list(executor.map(compute_model_response, texts, [model_parameters] * len(texts)))

    # Label extraction is a cheap regex search, so no thread pool is needed.
    predicted_labels = [extract_from_model_response(response) for response in model_response]

    return model_response, predicted_labels
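
The forward pass relies on `executor.map` returning results in the order of its inputs, even when calls finish out of order. A small sketch of that property, where `slow_upper` is a hypothetical stand-in for the API call:

```python
import concurrent.futures
import time

def slow_upper(text):
    # Shorter inputs sleep longer, so they finish last.
    time.sleep(0.01 * (3 - len(text)))
    return text.upper()

texts = ["a", "bb", "ccc"]
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    results = list(executor.map(slow_upper, texts))
```

Because the order is preserved, each response can be zipped back with its text and correct label. Note that an exception raised inside any worker (for example a timeout) surfaces when the results are consumed.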

In [16]:
def update_model_parameters(
    texts, model_responses, predicted_labels, correct_labels,
    model_parameters, metrics,
    parameters_update_response_history, metrics_history,
):
    conversation_history = []
    assert len(parameters_update_response_history) == len(metrics_history[:-1])
    for parameters_update_response_past, metrics_past in zip(
        parameters_update_response_history, metrics_history[:-1]
    ):
        # the latest metrics will be given in the last message
        conversation_history.append(
            {
                "role": "user",
                "content": str(metrics_past),  
            }
        )
        conversation_history.append(
            {
                "role": "assistant",
                "content": parameters_update_response_past,
            },
        )
    
    mistake_counts = collections.defaultdict(int)
    correct_counts = collections.defaultdict(int)

    user_message = textwrap.dedent(
        f"""
        You are given
        - a set of (text, model response, predicted label, correct label)
            - predicted label may be None if it is not found in the model response
        - the current set of prompts (which may be empty) for an LLM

        You will improve the prompts so that the LLM will predict the correct label.
        Please spend some time to think what each label means, based on the examples.

        The LLM input has the following parameters.
        - {list(model_parameters.keys())}
        
        This is how the parameters are used by the LLM
        {inspect.getsource(compute_model_response)}

        Please ensure that the prompts contain the following instructions
        - The label should have the same exact text as the actual labels
        - The information on what the set of labels are
            - Some information on what is considered invalid reasons for a label
        - A small number of example inputs and outputs
        - The LLM should only classify the text. The LLM should not respond to the text or decline classifying the text.
        - The LLM should provide a concise reasoning. The reasoning should happen before the label.
        - The response should always end with Label: <label>{{label}}</label>
            - Note that the label needs exactly the text in the correct label
            - Note that you need the html tags
        
        The current metrics are {str(metrics)}.
        If the metrics are especially bad (i.e. accuracy is worse than random guessing), you likely need to rethink the interpretation of the labels.
        """
    ) + "\n\n\n"
    
    correct_labels_set = set(correct_labels)  # computed once, outside the loop
    for text, model_response, predicted_label, correct_label in zip(
        texts, model_responses, predicted_labels, correct_labels
    ):
        correctness_verdict = ""
        # Each bucket is capped at its NUM_SAMPLES_* limit (>= avoids taking one extra).
        if predicted_label == correct_label:
            correctness_verdict = "This prediction is correct."
            if correct_counts[correct_label] >= NUM_SAMPLES_CORRECT_GRADIENT_CALCULATION_FOR_EACH_LABEL:
                continue
            correct_counts[correct_label] += 1
        elif predicted_label not in correct_labels_set:
            correctness_verdict = "This predicted label could not be extracted or does not belong to one of the actual labels."
            if mistake_counts[None] >= NUM_SAMPLES_MISTAKE_GRADIENT_CALCULATION_FOR_EACH_LABEL:
                continue
            mistake_counts[None] += 1
        else:
            correctness_verdict = "This prediction is incorrect."
            if mistake_counts[correct_label] >= NUM_SAMPLES_MISTAKE_GRADIENT_CALCULATION_FOR_EACH_LABEL:
                continue
            mistake_counts[correct_label] += 1
            
        user_message += textwrap.dedent(
            f"""
            <datapoint>
                {correctness_verdict}
            
                <text>{text}</text>

                <model_response>{model_response}</model_response>

                <predicted_label>{predicted_label}</predicted_label>

                <correct_label>{correct_label}</correct_label>
            
            </datapoint>
            """
        ) + "\n\n"
    
    user_message += "\n\n\nThis is the current set of prompts\n"
    
    for model_parameter_key, current_model_parameter_value in model_parameters.items():
        user_message += textwrap.dedent(f"""
        <{model_parameter_key}>
        {current_model_parameter_value}
        </{model_parameter_key}>
        """) + "\n\n"

    user_message += textwrap.dedent(f"""
    Your reply (not the LLM you are tuning prompts for) should include the following
    
    Your informed interpretation of the labels
        
    The prompt parameters in the following format within the xml tags
    (please make sure each prompt parameter has some meaningful content)    
    """) + "\n\n"

    for model_parameter_key in model_parameters.keys():
        user_message += textwrap.dedent(f"""
        <{model_parameter_key}>
        the new {model_parameter_key} here
        </{model_parameter_key}>
        """) + "\n\n"        
        
    user_message += "\n\nPlease spend time to check that your improved prompts will actually fix the mistakes."
        
    conversation_history.append({"role": "user", "content": user_message})
    
    response = openai_client.chat.completions.create(
        model="o1-preview",
        messages=conversation_history
    )
    
    token_counts["o1_input"] += response.usage.prompt_tokens
    token_counts["o1_output"] += response.usage.completion_tokens
    token_costs["o1_input"] += response.usage.prompt_tokens * 3 * 1e-6
    token_costs["o1_output"] += response.usage.completion_tokens * 15 * 1e-6

    parameters_update_response = response.choices[0].message.content

    for model_parameter_key in model_parameters.keys():
        groups = re.search(
            fr'<{model_parameter_key}>(.*?)</{model_parameter_key}>',
            parameters_update_response, re.DOTALL
        )
        model_parameters[model_parameter_key] = groups.group(1) if groups else ""
    
    return model_parameters, parameters_update_response
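
The tag parsing at the end of this function can be tried in isolation. The sketch below is a variation, not the notebook's behavior: the hypothetical `parse_parameters` keeps the previous value when a tag is missing (the notebook sets it to an empty string) and strips surrounding whitespace.

```python
import re

def parse_parameters(response, current_parameters):
    """Read <key>...</key> blocks from the response; fall back to the current value."""
    updated = {}
    for key, current_value in current_parameters.items():
        match = re.search(rf"<{re.escape(key)}>(.*?)</{re.escape(key)}>", response, re.DOTALL)
        updated[key] = match.group(1).strip() if match else current_value
    return updated

# Hand-written response that only contains one of the two expected tags.
response = "Interpretation...\n<system_prompt>\nClassify questions.\n</system_prompt>"
current = {"system_prompt": "old", "user_prompt_prefix": "<text>"}
updated = parse_parameters(response, current)
```

Falling back to the previous value is one possible design choice; it avoids sending an empty prompt to the classifier when the optimizer omits a tag.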

# Display functions

In [17]:
def calculate_metrics(predicted_labels, correct_labels):
    metrics = {}
    correct_labels_set = set(correct_labels)
    for label in sorted(correct_labels_set):
        metrics[f"{label}_recall"] = recall_score(
            [correct_label == label for correct_label in correct_labels],
            [predicted_label == label for predicted_label in predicted_labels],
        )
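    # accuracy as a fraction of all predictions, on the same 0-1 scale as recall and "missing"
    metrics["accuracy"] = sum(
        predicted_label == correct_label for predicted_label, correct_label in zip(predicted_labels, correct_labels)
    ) / len(predicted_labels)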
    metrics["missing"] = sum(
        [predicted_label not in correct_labels_set for predicted_label in predicted_labels]
    ) / len(predicted_labels)
    return metrics
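
A small worked example of these metrics, written in plain Python so it runs without scikit-learn. The helper `per_label_recall` and the four toy predictions are made up for illustration; accuracy is expressed here as a fraction of all predictions.

```python
def per_label_recall(predicted, correct, label):
    """Fraction of examples with this correct label that were predicted as it."""
    relevant = [p for p, c in zip(predicted, correct) if c == label]
    return sum(p == label for p in relevant) / len(relevant)

predicted = ["0", "0", "1", None]  # None: no label could be extracted
correct = ["0", "1", "1", "1"]

recall_0 = per_label_recall(predicted, correct, "0")
recall_1 = per_label_recall(predicted, correct, "1")
accuracy = sum(p == c for p, c in zip(predicted, correct)) / len(correct)
missing = sum(p not in set(correct) for p in predicted) / len(predicted)
```

Here label "0" has recall 1.0 (its single example is found), label "1" has recall 1/3, accuracy is 0.5, and the missing rate is 0.25.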

In [18]:
def save_and_display_prompt_history(model_parameters_history, parameters_update_response_history, metrics_history):

    iteration_data_all = []

    for model_parameters, parameters_update_response, metrics in itertools.zip_longest(
        model_parameters_history, parameters_update_response_history, metrics_history, fillvalue={}
    ):
        iteration_data = {}
        for k,v in model_parameters.items():
            iteration_data[k] = v
        for k,v in metrics.items():
            iteration_data[k] = f"{v:.3f}"
        if parameters_update_response:
            iteration_data["parameters_update_response"] = parameters_update_response
        iteration_data_all.append(iteration_data)

    df = pd.DataFrame(iteration_data_all).fillna("")

    html_prefix = '''
    <meta charset="UTF-8">
    <style>
    table {
        border-collapse: collapse;
    }
    td, th {
        border: 1px solid black;
        padding: 5px;
        vertical-align: top;
    }
    td {
        white-space: pre-wrap;
        font-family: monospace;
    }
    </style>
    '''
    
    os.makedirs("html_output", exist_ok=True)
    prompt_info_file_name = "html_output/prompt-history-classification.html"
    with open(prompt_info_file_name, 'w') as f:
        f.write(
            html_prefix + df.replace(
                {r'\n': '__NEWLINE__'}, regex=True
            ).applymap(str).applymap(html.escape).replace(
                {'__NEWLINE__': '<br>'}, regex=True
            ).style.set_table_styles(
                [
                    dict(selector="tr:nth-child(even)", props=[("background-color", "#f2f2f2")]),
                    dict(selector="tr:nth-child(odd)", props=[("background-color", "white")]),
                ]
            ).to_html()  # Styler.render() was removed in pandas 2.0
        )

    link = f'<a href="{prompt_info_file_name}" target="_blank">{prompt_info_file_name}</a>'
    display(HTML(link))
        
    
def save_and_display_current_iteration(iteration_idx, texts, model_response, predicted_labels, correct_labels):
    
    df = pd.DataFrame({
        "text": texts,
        "model_response": model_response,
        "predicted_label": predicted_labels,
        "correct_label": correct_labels,
    })
    
    def highlight_diff(row):
        if row['predicted_label'] == row['correct_label']:
            return ['background-color: #90EE90'] * len(row)  # green
        return ['background-color: #FFB6C1'] * len(row)  # red

    html_prefix = '''
    <meta charset="UTF-8">
    <style>
    table {
        border-collapse: collapse;
    }
    td, th {
        border: 1px solid black;
        padding: 5px;
        vertical-align: top;
    }
    td {
        white-space: pre-wrap;
        font-family: monospace;
    }
    </style>
    '''
    
    os.makedirs("html_output", exist_ok=True)
    iteration_info_file_name = f"html_output/iteration-classification-{iteration_idx:03}.html"
    with open(iteration_info_file_name, 'w') as f:
        f.write(
            html_prefix + df.replace(
                {r'\n': '__NEWLINE__'}, regex=True
            ).applymap(str).applymap(html.escape).replace(
                {'__NEWLINE__': '<br>'}, regex=True
            ).style.apply(highlight_diff, axis=1).to_html()  # Styler.render() was removed in pandas 2.0
        )
    
    link = f'<a href="{iteration_info_file_name}" target="_blank">{iteration_info_file_name}</a>'
    display(HTML(link))   
    
    os.makedirs("html_output", exist_ok=True)
    iteration_info_file_name = f"html_output/iteration-classification-{iteration_idx:03}-diff.html"
    with open(iteration_info_file_name, 'w') as f:
        f.write(
            html_prefix + df[df["predicted_label"] != df["correct_label"]].sort_values("correct_label").replace(
                {r'\n': '__NEWLINE__'}, regex=True
            ).applymap(str).applymap(html.escape).replace(
                {r'__NEWLINE__': '<br>'}, regex=True
            ).style.set_table_styles(
                [
                    dict(selector="tr:nth-child(even)", props=[("background-color", "#f2f2f2")]),
                    dict(selector="tr:nth-child(odd)", props=[("background-color", "white")]),                    
                ]
            ).to_html()  # Styler.render() was removed in pandas 2.0
        )

    link = f'<a href="{iteration_info_file_name}" target="_blank">{iteration_info_file_name}</a>'
    display(HTML(link))

# Execution

In [19]:
model_parameters_history = [{k:v for k,v in model_parameters.items()}]
parameters_update_response_history = []
metrics_history = []

for iteration_idx in range(1, NUM_ITERATIONS+1):
    samples, correct_labels = get_texts_and_labels(dataset)

    model_responses, predicted_labels = forward_pass(samples, model_parameters)
    metrics = calculate_metrics(predicted_labels, correct_labels)
    metrics_history.append(metrics)

    if iteration_idx != NUM_ITERATIONS:  # don't update parameters for the last iteration
        model_parameters, parameters_update_response = update_model_parameters(
            samples, model_responses,
            predicted_labels, correct_labels,
            model_parameters, metrics,
            parameters_update_response_history, metrics_history,
        )
        parameters_update_response_history.append(parameters_update_response)
        model_parameters_history.append({k:v for k,v in model_parameters.items()})

    save_and_display_prompt_history(model_parameters_history, parameters_update_response_history, metrics_history)
    save_and_display_current_iteration(iteration_idx, samples, model_responses, predicted_labels, correct_labels)

# Cost tracking

In [20]:
token_counts

defaultdict(int,
            {'haiku_input': 240596,
             'haiku_output': 20808,
             'o1_input': 37424,
             'o1_output': 31341})

In [21]:
print(f"Cost: ${sum(token_costs.values()):.2f}")

Cost: $0.67
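
The total can be reproduced by hand from the token counts shown above and the per-token rates hard-coded in this notebook. The rates are copied from the code cells, not from a current price list, so treat them as assumptions to re-check before relying on the figure.

```python
# Token counts from the output of the cost-tracking cell.
token_counts = {
    "haiku_input": 240596,
    "haiku_output": 20808,
    "o1_input": 37424,
    "o1_output": 31341,
}

# USD per token, as hard-coded in the notebook's code cells.
usd_per_token = {
    "haiku_input": 0.25e-6,
    "haiku_output": 1.25e-6,
    "o1_input": 3e-6,
    "o1_output": 15e-6,
}

cost_by_kind = {kind: token_counts[kind] * usd_per_token[kind] for kind in token_counts}
total = sum(cost_by_kind.values())
print(f"Cost: ${total:.2f}")
```

Most of the total comes from the o1 output tokens (about $0.47 of the $0.67), even though Haiku processes far more tokens.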
