# Automatic Prompt Engineering for classification

Given only (text -> label), this notebook generates and optimizes system and user prompts.

This is how classification is intended to be done.
- (system prompt, user prompt prefix + text + user prompt suffix) -Haiku-> bot response -extract-> label

The notebook will produce
- the system prompt
- the user prompt prefix
- the user prompt suffix

You can simply run this notebook with just
- an Anthropic API key and an OpenAI API key

If you want to change the classification task, you will need to
- provide a dataset (text -> label)

This is how prompt tuning is done
- Sample from the full dataset.
- Haiku takes in (system prompt, user prompt prefix + text + user prompt suffix) and produces model_response
- Extract the label from the model_response.
- Sample from the mistakes and the correct results.
- o1-mini summarizes the mistakes and updates the prompts (model parameters).
- Repeat.

You will need to have these Python modules installed
- pandas
- scikit-learn
- anthropic
- openai

In [1]:
import os
import re
import random
import inspect
import textwrap
import collections
import itertools
import concurrent.futures
import pandas as pd
import html
from collections import defaultdict

from IPython.display import display, HTML
from sklearn.metrics import precision_score, recall_score

import anthropic
from openai import OpenAI



# Use your Anthropic and OpenAI API keys here

In [2]:
# Both keys come from the environment so that no secret has to live in the notebook itself.
# To paste a key manually instead, assign it here (and never commit it):
# anthropic_api_key = "sk-ant-"
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
anthropic_client = anthropic.Anthropic(api_key=anthropic_api_key)

openai_api_key = os.getenv("OPENAI_API_KEY")
openai_client = OpenAI(api_key=openai_api_key)

In [3]:
# Confirm that a key was picked up without echoing secret characters into the
# saved notebook output: only the public "sk-ant-" prefix is ever printed.
_anthropic_key = anthropic_client.api_key or ""
print(_anthropic_key[:7] + "..." if _anthropic_key else "ANTHROPIC_API_KEY is not set")

sk-ant-...


In [4]:
# Confirm that a key was picked up without echoing secret characters into the
# saved notebook output: only the public "sk-" prefix is ever printed.
_openai_key = openai_client.api_key or ""
print(_openai_key[:3] + "..." if _openai_key else "OPENAI_API_KEY is not set")

sk-...


In [5]:
# Size of the thread pool used by forward_pass for concurrent classification calls.
NUM_PARALLEL_FORWARD_PASS_API_CALLS = 100  # see https://docs.anthropic.com/claude/reference/rate-limits
# Upper bound on how many examples of each label get_texts_and_labels draws per iteration.
NUM_SAMPLES_FORWARD_PASS_FOR_EACH_LABEL = 100
# Limits how many misclassified examples per label are shown to the optimizer model.
NUM_SAMPLES_MISTAKE_GRADIENT_CALCULATION_FOR_EACH_LABEL = 10
# Limits how many correctly classified examples per label are shown to the optimizer model.
NUM_SAMPLES_CORRECT_GRADIENT_CALCULATION_FOR_EACH_LABEL = 5
# Number of evaluate/update rounds; the final round only evaluates and does not update the prompts.
NUM_ITERATIONS = 5

# Define the dataset here
You will need to edit this if your task is different.

In [6]:
# Source: https://www.kaggle.com/c/quora-insincere-questions-classification/data
# Labels are stored as strings because labels extracted from model responses are strings too.
df = pd.read_csv("qiqc_truncated.csv").astype({"target": str})
df["target"].value_counts()

0    1000
1    1000
Name: target, dtype: int64

In [7]:
# Balance the classes: keep at most 100 randomly chosen rows per label, then
# shuffle the combined frame. `Series.items()` is used instead of
# `Series.iteritems()`, which was removed in pandas 2.0; `.items()` exists in
# every pandas version.
df = pd.concat([
    df[df["target"] == value].sample(min(count, 100), random_state=42)
    for value, count in df["target"].value_counts().items()
], ignore_index=True).sample(frac=1, random_state=0)

df["target"].value_counts()

0    100
1    100
Name: target, dtype: int64

In [8]:
# Everything downstream only needs a list of (text, label) pairs,
# so this list can equally be written out by hand.
dataset = [
    (question, target)
    for question, target in zip(df["question_text"], df["target"])
]

In [9]:
# Sanity check: the number of distinct labels should be small.
# Descriptive label names are preferable so the model does not have to guess their meaning.
collections.Counter([label for _, label in dataset])

Counter({'0': 100, '1': 100})

In [10]:
dataset[0]  # peek at one example; every item should be a (text, label) tuple

('Why does programming languages need to be conplicated and difficult to learn or even understand, why dont we make as simple as normal English?',
 '0')

# Define your task here
You will need to edit this if your task is different.

In [11]:
import re

def extract_from_model_response(text):
    """Pull the predicted label out of a raw model response.

    The label is whatever sits inside the first ``<label>...</label>`` pair
    (it may span several lines); surrounding whitespace is stripped.

    Returns:
        The label text, or ``None`` when the response contains no such tag pair.
    """
    found = re.compile(r'<label>(.*?)</label>', re.DOTALL).search(text)
    return found.group(1).strip() if found else None

In [12]:
# Start from empty prompts: usually o1 is good enough to produce working prompts from nothing.
model_parameters = dict(
    system_prompt="",
    user_prompt_prefix="",
    user_prompt_suffix="",
)

# Running totals of token usage and estimated cost, filled in by the API-calling functions.
token_counts = defaultdict(int)
token_costs = defaultdict(float)

# Model definition
If you want to use a different model, you can configure it in this section.

In [13]:

def compute_model_response(text, model_parameters):
    """Classify one text with the Anthropic model and return its raw reply.

    Args:
        text: the text to classify.
        model_parameters: dict with the keys "system_prompt",
            "user_prompt_prefix" and "user_prompt_suffix".

    Returns:
        The text of the first content block of the model's reply.

    Side effects:
        Adds the request's token usage and estimated cost to the module-level
        ``token_counts`` and ``token_costs``.

    Note:
        The source of this function is embedded verbatim in the optimizer
        prompt (via ``inspect.getsource``), so edits here also change that prompt.
    """
    
    # The text is sandwiched between the tunable prefix and suffix.
    user_message = model_parameters["user_prompt_prefix"] + text + model_parameters["user_prompt_suffix"]
    
    message = anthropic_client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=2000,
        temperature=0,
        messages=[{"role": "user", "content": [{"type": "text", "text": user_message}]}],
        system=model_parameters["system_prompt"],
        timeout=10
    )
    token_counts["haiku_input"] += message.usage.input_tokens
    token_counts["haiku_output"] += message.usage.output_tokens
    # Hard-coded rates of 0.25 and 1.25 per million tokens (presumably the USD list
    # prices for this model at the time of writing; verify before relying on the total).
    token_costs["haiku_input"] += message.usage.input_tokens * 0.25 * 1e-6
    token_costs["haiku_output"] += message.usage.output_tokens * 1.25 * 1e-6

    return message.content[0].text

# Optimization functions

In [14]:
def get_texts_and_labels(dataset):
    """Draw a shuffled, per-label-capped sample from ``dataset``.

    The input list is copied before shuffling, so the caller's list is left
    untouched. For every distinct label, at most
    ``NUM_SAMPLES_FORWARD_PASS_FOR_EACH_LABEL`` examples are kept.

    Args:
        dataset: iterable of (text, label) pairs.

    Returns:
        A ``(texts, labels)`` tuple of two equally long lists in matching order.
    """
    shuffled = list(dataset)
    random.shuffle(shuffled)

    sample = []
    for wanted_label in {label for _, label in shuffled}:
        matching = [(text, label) for text, label in shuffled if label == wanted_label]
        sample.extend(matching[:NUM_SAMPLES_FORWARD_PASS_FOR_EACH_LABEL])
    # Mix the labels again so the sample is not grouped by label.
    random.shuffle(sample)

    texts = [text for text, _ in sample]
    labels = [label for _, label in sample]
    return texts, labels

In [15]:
def forward_pass(texts, model_parameters):
    """Classify every text and extract the predicted label from each reply.

    The model calls are network-bound, so they run concurrently on a thread
    pool of ``NUM_PARALLEL_FORWARD_PASS_API_CALLS`` workers. Label extraction
    is a cheap regex search, so it runs inline rather than on a second pool.

    Args:
        texts: list of texts to classify.
        model_parameters: prompt dict passed unchanged to ``compute_model_response``.

    Returns:
        ``(model_responses, predicted_labels)``: two lists aligned with
        ``texts``. A predicted label is ``None`` when none could be extracted.

    Raises:
        Whatever ``compute_model_response`` raises for any single text.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_PARALLEL_FORWARD_PASS_API_CALLS) as executor:
        # Materialise inside the `with` so results (and any exception) surface
        # while the pool is still alive.
        model_response = list(
            executor.map(compute_model_response, texts, [model_parameters] * len(texts))
        )

    predicted_labels = [extract_from_model_response(response) for response in model_response]

    return model_response, predicted_labels

In [16]:
def _render_prompt_section(template, **fields):
    """Dedent ``template`` first and only then substitute ``fields``.

    Substituting before dedenting (as an f-string does) lets multi-line values
    such as model responses or source code change the common margin that
    ``textwrap.dedent`` strips, which leaves the prompt inconsistently indented.
    """
    return textwrap.dedent(template).format(**fields)


def _select_feedback_datapoints(texts, model_responses, predicted_labels, actual_labels):
    """Choose the capped set of datapoints that is shown to the optimizer model.

    Datapoints fall into three groups: label missing/unknown, correct, and
    incorrect. Unknown-label cases share one cap; correct and incorrect cases
    are capped separately per expected label.

    Returns:
        A list of ``(verdict, text, model_response, predicted_label, actual_label)``
        tuples in input order.
    """
    known_labels = set(actual_labels)  # computed once, not once per datapoint
    mistake_counts = collections.defaultdict(int)
    correct_counts = collections.defaultdict(int)

    selected = []
    for text, model_response, predicted_label, actual_label in zip(
        texts, model_responses, predicted_labels, actual_labels
    ):
        if predicted_label not in known_labels:
            verdict = "This label could not be extracted or does not belong to one of the actual labels."
            counts, key = mistake_counts, None
            cap = NUM_SAMPLES_MISTAKE_GRADIENT_CALCULATION_FOR_EACH_LABEL
        elif predicted_label == actual_label:
            verdict = "This prediction is correct."
            counts, key = correct_counts, actual_label
            cap = NUM_SAMPLES_CORRECT_GRADIENT_CALCULATION_FOR_EACH_LABEL
        else:
            verdict = "This prediction is incorrect."
            counts, key = mistake_counts, actual_label
            cap = NUM_SAMPLES_MISTAKE_GRADIENT_CALCULATION_FOR_EACH_LABEL

        # `>=` so that exactly `cap` examples are kept (`>` would admit cap + 1).
        if counts[key] >= cap:
            continue
        counts[key] += 1
        selected.append((verdict, text, model_response, predicted_label, actual_label))

    return selected


def update_model_parameters(texts, model_responses,
                            predicted_labels, actual_labels,
                            model_parameters, metrics):
    """Ask the optimizer model (o1-mini) for improved prompts.

    Builds one prompt containing the task instructions, a capped sample of
    correct and incorrect datapoints, the current prompts and the required
    reply format, sends it to the OpenAI client and parses the new prompt
    values out of the reply.

    Args:
        texts: the classified texts.
        model_responses: raw classifier replies, aligned with ``texts``.
        predicted_labels: extracted labels (``None`` when extraction failed).
        actual_labels: expected labels, aligned with ``texts``.
        model_parameters: dict of current prompts; updated in place.
        metrics: metrics dict for the current prompts, shown to the optimizer.

    Returns:
        ``(model_parameters, parameters_update_response)``: the same dict
        object that was passed in, and the optimizer's full reply text.

    Side effects:
        Adds the request's token usage and estimated cost to the module-level
        ``token_counts`` and ``token_costs``.
    """
    user_message = _render_prompt_section(
        """
        You are given
        - a set of (text, model response, extracted label, expected label)
            - extracted label may be None if it is not found in the model response
        - the current set of prompts (which may be empty) for an LLM

        You will improve the prompts so that the LLM will predict the expected label.
        You might need to guess what each label means.

        The LLM input has the following parameters. Make use of all the parameters.
        - {parameter_names}

        This is how the parameters are used by the LLM
        {model_source}

        Please ensure that the prompt contains these following instructions
        - The label should have the same exact text as the actual labels
        - The information on what the set of labels are
        - A small number of example inputs and outputs
        - The LLM should only classify the text. The LLM should not respond to the text or decline classifying the text.
        - The LLM should provide a concise reasoning. The reasoning should happen before the label.
        - The response should always end with Label: <label>{{label}}</label>
            - Note that the label needs exactly the text in the expected label
            - Note that you need the html tags

        The current metrics is {metrics}.
        Put more focus on the worst performing metric.
        """,
        parameter_names=list(model_parameters.keys()),
        model_source=inspect.getsource(compute_model_response),
        metrics=str(metrics),
    ) + "\n\n\n"

    for verdict, text, model_response, predicted_label, actual_label in _select_feedback_datapoints(
        texts, model_responses, predicted_labels, actual_labels
    ):
        user_message += _render_prompt_section(
            """
            <datapoint>
                {verdict}

                <text>{text}</text>

                <model_response>{model_response}</model_response>

                <predicted_label>{predicted_label}</predicted_label>

                <actual_label>{actual_label}</actual_label>

            </datapoint>
            """,
            verdict=verdict,
            text=text,
            model_response=model_response,
            predicted_label=predicted_label,
            actual_label=actual_label,
        ) + "\n\n"

    user_message += "\n\n\nThis the current set of prompts\n"

    for model_parameter_key, current_model_parameter_value in model_parameters.items():
        user_message += _render_prompt_section(
            """
            <{key}>
            {value}
            </{key}>
            """,
            key=model_parameter_key,
            value=current_model_parameter_value,
        ) + "\n\n"

    user_message += textwrap.dedent("""
    Your reply (not the LLM you are tuning prompts for) should include the following

    A summary of the current performance

    The key mistakes observed with some examples

    What are the proposed changes the prompt

    The prompts in the following format
    """) + "\n\n"

    for model_parameter_key in model_parameters.keys():
        user_message += _render_prompt_section(
            """
            <{key}>
            the new {key} here
            </{key}>
            """,
            key=model_parameter_key,
        ) + "\n\n"

    response = openai_client.chat.completions.create(
        model="o1-mini",
        messages=[{"role": "user", "content": user_message}]
    )

    token_counts["o1_input"] += response.usage.prompt_tokens
    token_counts["o1_output"] += response.usage.completion_tokens
    token_costs["o1_input"] += response.usage.prompt_tokens * 3 * 1e-6
    token_costs["o1_output"] += response.usage.completion_tokens * 15 * 1e-6

    parameters_update_response = response.choices[0].message.content or ""

    for model_parameter_key in list(model_parameters):
        groups = re.search(
            fr'<{re.escape(model_parameter_key)}>(.*?)</{re.escape(model_parameter_key)}>',
            parameters_update_response, re.DOTALL
        )
        # Keep the current prompt when the reply omits a tag, rather than
        # silently replacing a working prompt with an empty string.
        if groups:
            model_parameters[model_parameter_key] = groups.group(1)

    return model_parameters, parameters_update_response

# Display functions

In [17]:
def calculate_metrics(predicted_labels, actual_labels):
    """Compute one-vs-rest precision and recall per label plus the miss rate.

    Args:
        predicted_labels: labels extracted from the model responses; entries
            may be ``None`` or any value outside the known label set.
        actual_labels: expected labels, aligned with ``predicted_labels``.

    Returns:
        A dict with ``"<label>_precision"`` and ``"<label>_recall"`` for every
        label in ``actual_labels`` (in sorted order), followed by ``"missing"``:
        the fraction of predictions that are not one of the actual labels
        (``0.0`` when there are no predictions).
    """
    metrics = {}
    actual_labels_set = set(actual_labels)
    for label in sorted(actual_labels_set):
        # Binary masks are built once per label and shared by both scores.
        is_actual = [actual_label == label for actual_label in actual_labels]
        is_predicted = [predicted_label == label for predicted_label in predicted_labels]
        # zero_division=0: a label that is never predicted scores 0 without a warning.
        metrics[f"{label}_precision"] = precision_score(is_actual, is_predicted, zero_division=0)
        metrics[f"{label}_recall"] = recall_score(is_actual, is_predicted, zero_division=0)

    num_predictions = len(predicted_labels)
    num_missing = sum(
        predicted_label not in actual_labels_set for predicted_label in predicted_labels
    )
    # Guard the empty case instead of raising ZeroDivisionError.
    metrics["missing"] = num_missing / num_predictions if num_predictions else 0.0
    return metrics

In [18]:
# Shared <head> fragment for every report written by the functions below.
_HTML_TABLE_PREFIX = '''
    <meta charset="UTF-8">
    <style>
    table {
        border-collapse: collapse;
    }
    td, th {
        border: 1px solid black;
        padding: 5px;
        vertical-align: top;
    }
    td {
        white-space: pre-wrap;
        font-family: monospace;
    }
    </style>
    '''


def _zebra_row_styles():
    """Return fresh table styles that shade alternating rows."""
    return [
        dict(selector="tr:nth-child(even)", props=[("background-color", "#f2f2f2")]),
        dict(selector="tr:nth-child(odd)", props=[("background-color", "white")]),
    ]


def _escape_cells_for_html(frame):
    """Return a copy of ``frame`` whose cells are HTML-escaped strings with ``<br>`` line breaks."""
    def to_safe_html(value):
        return html.escape(str(value)).replace("\n", "<br>")

    # Series.map works on every pandas version (DataFrame.applymap is deprecated).
    return frame.apply(lambda column: column.map(to_safe_html))


def _styler_to_html(styler):
    """Render a pandas Styler to HTML on both old and new pandas versions."""
    # Styler.render was removed in pandas 2.0; Styler.to_html replaces it (pandas >= 1.3).
    to_html = getattr(styler, "to_html", None)
    return to_html() if to_html is not None else styler.render()


def _write_html_report(file_name, styler):
    """Write ``styler`` as an HTML file under html_output/ and display a link to it."""
    os.makedirs("html_output", exist_ok=True)
    # The page declares UTF-8, so the file must be written as UTF-8 regardless
    # of the platform's default encoding.
    with open(file_name, 'w', encoding="utf-8") as f:
        f.write(_HTML_TABLE_PREFIX + _styler_to_html(styler))

    link = f'<a href="{file_name}" target="_blank">{file_name}</a>'
    display(HTML(link))


def save_and_display_prompt_history(model_parameters_history, parameters_update_response_history, metrics_history):
    """Write the per-iteration prompt/metric history to HTML and show a link.

    Each table row combines, for one iteration, the prompts in use, the
    metrics measured (formatted to three decimals) and, when available, the
    optimizer's reply. The three histories may differ in length; missing
    entries are rendered as empty cells.

    Output file: ``html_output/prompt-history-classification.html``.
    """
    iteration_data_all = []

    for model_parameters, parameters_update_response, metrics in itertools.zip_longest(
        model_parameters_history, parameters_update_response_history, metrics_history, fillvalue={}
    ):
        iteration_data = dict(model_parameters)
        iteration_data.update({k: f"{v:.3f}" for k, v in metrics.items()})
        if parameters_update_response:
            iteration_data["parameters_update_response"] = parameters_update_response
        iteration_data_all.append(iteration_data)

    history = pd.DataFrame(iteration_data_all).fillna("")

    _write_html_report(
        "html_output/prompt-history-classification.html",
        _escape_cells_for_html(history).style.set_table_styles(_zebra_row_styles()),
    )


def save_and_display_current_iteration(iteration_idx, texts, model_response, predicted_labels, actual_labels):
    """Write two HTML reports for one iteration and show links to both.

    * ``html_output/iteration-classification-NNN.html``: every sample, with
      rows coloured green when the prediction matches and red otherwise.
    * ``html_output/iteration-classification-NNN-diff.html``: only the
      mismatching samples, sorted by expected label.
    """
    results = pd.DataFrame({
        "text": texts,
        "model_response": model_response,
        "predicted_label": predicted_labels,
        "actual_label": actual_labels,
    })

    def highlight_diff(row):
        if row['predicted_label'] == row['actual_label']:
            return ['background-color: #90EE90'] * len(row)  # green
        return ['background-color: #FFB6C1'] * len(row)  # red

    _write_html_report(
        f"html_output/iteration-classification-{iteration_idx:03}.html",
        _escape_cells_for_html(results).style.apply(highlight_diff, axis=1),
    )

    mistakes = results[results["predicted_label"] != results["actual_label"]].sort_values("actual_label")
    _write_html_report(
        f"html_output/iteration-classification-{iteration_idx:03}-diff.html",
        _escape_cells_for_html(mistakes).style.set_table_styles(_zebra_row_styles()),
    )

# Execution

In [19]:
# The history starts with a snapshot of the initial (possibly empty) prompts.
model_parameters_history = [dict(model_parameters)]
parameters_update_response_history = []
metrics_history = []

for iteration_idx in range(1, NUM_ITERATIONS+1):
    # Evaluate the current prompts on a fresh sample.
    samples, actual_labels = get_texts_and_labels(dataset)
    model_responses, predicted_labels = forward_pass(samples, model_parameters)

    metrics = calculate_metrics(predicted_labels, actual_labels)
    metrics_history.append(metrics)

    # The last iteration only measures the final prompts; it does not update them.
    if iteration_idx < NUM_ITERATIONS:
        model_parameters, parameters_update_response = update_model_parameters(
            samples, model_responses, predicted_labels, actual_labels, model_parameters, metrics
        )
        parameters_update_response_history.append(parameters_update_response)
        # Store a copy, because the parameters dict is updated in place.
        model_parameters_history.append(dict(model_parameters))

    save_and_display_prompt_history(model_parameters_history, parameters_update_response_history, metrics_history)
    save_and_display_current_iteration(iteration_idx, samples, model_responses, predicted_labels, actual_labels)

# Cost tracking

In [20]:
# Token usage accumulated by the classification and optimizer calls during this run.
token_counts

defaultdict(int,
            {'haiku_input': 462669,
             'haiku_output': 76423,
             'o1_input': 22164,
             'o1_output': 9344})

In [21]:
# Total estimated spend across both models, in dollars.
print("Cost: ${:.2f}".format(sum(token_costs.values())))

Cost: $0.42
