## Tutorial: Optimizing a Prompt

![TextGrad](https://github.com/vinid/data/blob/master/logo_full.png?raw=true)

An autograd engine -- for textual gradients!

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zou-group/TextGrad/blob/main/examples/notebooks/Prompt-Optimization.ipynb)
[![GitHub license](https://img.shields.io/badge/License-MIT-blue.svg)](https://lbesson.mit-license.org/)
[![Arxiv](https://img.shields.io/badge/arXiv-2406.07496-B31B1B.svg)](https://arxiv.org/abs/2406.07496)
[![Documentation Status](https://readthedocs.org/projects/textgrad/badge/?version=latest)](https://textgrad.readthedocs.io/en/latest/?badge=latest)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/textgrad)](https://pypi.org/project/textgrad/)
[![PyPI](https://img.shields.io/pypi/v/textgrad)](https://pypi.org/project/textgrad/)

**Objectives:**

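* In this tutorial, we will use TextGrad to optimize the system prompt of `gpt-3.5-turbo-0125` on the `BBH_object_counting` task, with `gpt-4o` evaluating answers and proposing prompt updates.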

**Requirements:**

* You need an OpenAI API key to run this tutorial. Set it as the `OPENAI_API_KEY` environment variable, or put it in a `.env` file, which the first cell loads with `load_dotenv`.
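
Before running the cells below, it can help to confirm that the key is actually visible to the Python process. The following is a minimal sketch that uses only the standard library; `require_env` is a hypothetical helper and not part of TextGrad.

```python
import os


def require_env(name: str) -> str:
    """Return the value of environment variable `name`, or raise if it is unset or empty."""
    value = os.environ.get(name, "")
    if not value:
        raise RuntimeError(f"{name} is not set; export it or add it to a .env file")
    return value


# Usage (uncomment once the key is configured):
# require_env("OPENAI_API_KEY")
```

Failing early with a clear message is usually easier to debug than an authentication error raised in the middle of an evaluation run.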


In [7]:
!pip install textgrad # you might need to restart the notebook after installing textgrad

import argparse
import concurrent.futures
from dotenv import load_dotenv
from tqdm import tqdm
import textgrad as tg
from textgrad.tasks import load_task
import numpy as np
import random
load_dotenv(override=True)


True

Let's first define some support functions for seeding, evaluation, and validation-based prompt rollback.

In [8]:
def set_seed(seed):
    np.random.seed(seed)
    random.seed(seed)

In [9]:
def eval_sample(item, eval_fn, model):
    """
    Evaluate the model's answer to a single (question, answer) pair.

    Returns the evaluation result as an integer, which `eval_dataset` averages into an accuracy.
    """
    x, y = item
    x = tg.Variable(x, requires_grad=False, role_description="query to the language model")
    y = tg.Variable(y, requires_grad=False, role_description="correct answer for the query")
    response = model(x)
    try:
        # First try an evaluation function that takes named inputs.
        eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))
        return int(eval_output_variable.value)
    except Exception:
        # Otherwise fall back to one that takes a list of variables and needs its output parsed.
        eval_output_variable = eval_fn([x, y, response])
        eval_output_parsed = eval_fn.parse_output(eval_output_variable)
        return int(eval_output_parsed)

In [10]:
def eval_dataset(test_set, eval_fn, model, max_samples: int=None):
    """Evaluate `model` on up to `max_samples` items and return the per-sample scores."""
    if max_samples is None:
        max_samples = len(test_set)
    accuracy_list = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        futures = []
        for sample in test_set:
            future = executor.submit(eval_sample, sample, eval_fn, model)
            futures.append(future)
            if len(futures) >= max_samples:
                break

        # Collect results as they finish and show the running accuracy in the progress bar.
        tqdm_loader = tqdm(concurrent.futures.as_completed(futures), total=len(futures), position=0)
        for future in tqdm_loader:
            acc_item = future.result()
            accuracy_list.append(acc_item)
            tqdm_loader.set_description(f"Accuracy: {np.mean(accuracy_list)}")
    return accuracy_list
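
The pattern used by `eval_dataset` (submit one job per sample, then collect scores as they finish) can be tried without any API calls. The sketch below replaces `eval_sample` with a hypothetical stand-in, `fake_eval_sample`, that simply compares two strings; `eval_in_parallel` and the toy samples are likewise illustrative names and data, not part of TextGrad.

```python
import concurrent.futures


def fake_eval_sample(item):
    """Hypothetical stand-in for eval_sample: 1 if the prediction equals the label, else 0."""
    prediction, label = item
    return int(prediction == label)


def eval_in_parallel(samples, max_samples=None, max_workers=2):
    """Score up to `max_samples` items in a thread pool and return the list of scores."""
    if max_samples is None:
        max_samples = len(samples)
    scores = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(fake_eval_sample, s) for s in samples[:max_samples]]
        # as_completed yields futures in completion order, not submission order,
        # so only order-independent statistics (such as the mean) should be derived.
        for future in concurrent.futures.as_completed(futures):
            scores.append(future.result())
    return scores


# Made-up (prediction, label) pairs: three of the four match.
toy_samples = [("3", "3"), ("5", "4"), ("7", "7"), ("2", "2")]
toy_scores = eval_in_parallel(toy_samples)
toy_accuracy = sum(toy_scores) / len(toy_scores)
print(toy_accuracy)
```

Threads suit this workload because each real evaluation spends most of its time waiting on a network call rather than computing.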

In [11]:
def run_validation_revert(system_prompt: tg.Variable, results, model, eval_fn, val_set):
    """Evaluate the updated prompt on the validation set and revert it if accuracy dropped."""
    val_performance = np.mean(eval_dataset(val_set, eval_fn, model))
    previous_performance = np.mean(results["validation_acc"][-1])
    print("val_performance: ", val_performance)
    print("previous_performance: ", previous_performance)
    previous_prompt = results["prompt"][-1]

    if val_performance < previous_performance:
        # The new prompt is worse: restore the last accepted prompt and keep its score.
        print(f"rejected prompt: {system_prompt.value}")
        system_prompt.set_value(previous_prompt)
        val_performance = previous_performance

    results["validation_acc"].append(val_performance)
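
The accept-or-revert rule in `run_validation_revert` does not depend on TextGrad itself. The following is a minimal sketch with plain strings and floats, where `accept_or_revert` is a hypothetical name and the accuracies are made-up numbers. Unlike the function above, the sketch also records the kept prompt, which the tutorial does later in the training loop.

```python
def accept_or_revert(candidate_prompt, candidate_acc, history):
    """Keep the candidate only if it does not lower validation accuracy.

    `history` is a dict with "prompt" and "validation_acc" lists.
    Returns the prompt to continue optimizing from.
    """
    previous_prompt = history["prompt"][-1]
    previous_acc = history["validation_acc"][-1]
    if candidate_acc < previous_acc:
        # Worse than before: discard the candidate and carry the old score forward.
        kept_prompt, kept_acc = previous_prompt, previous_acc
    else:
        kept_prompt, kept_acc = candidate_prompt, candidate_acc
    history["prompt"].append(kept_prompt)
    history["validation_acc"].append(kept_acc)
    return kept_prompt


# Made-up history: one starting prompt with a validation accuracy of 0.70.
history = {"prompt": ["v0"], "validation_acc": [0.70]}
first = accept_or_revert("v1", 0.55, history)   # rejected, "v0" is kept
second = accept_or_revert("v2", 0.83, history)  # accepted
```

Because a rejected step carries the previous score forward, the recorded validation accuracy never decreases from one step to the next.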

In [12]:
set_seed(12)
llm_api_eval = tg.get_engine(engine_name="gpt-4o")
llm_api_test = tg.get_engine(engine_name="gpt-3.5-turbo-0125")
tg.set_backward_engine(llm_api_eval, override=True)

# Load the data and the evaluation function
train_set, val_set, test_set, eval_fn = load_task("BBH_object_counting", evaluation_api=llm_api_eval)
print("Train/Val/Test Set Lengths: ", len(train_set), len(val_set), len(test_set))
STARTING_SYSTEM_PROMPT = train_set.get_task_description()


Train/Val/Test Set Lengths:  83 83 84


This is the system prompt we are going to start from:

In [None]:
print(STARTING_SYSTEM_PROMPT)


In [14]:
train_loader = tg.tasks.DataLoader(train_set, batch_size=3, shuffle=True)


# Model backed by the evaluation engine (gpt-4o), e.g. for checking its 0-shot performance.
# It is not used in the rest of this tutorial.
system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT,
                            requires_grad=True,
                            role_description="system prompt to the language model")
model_evaluation = tg.BlackboxLLM(llm_api_eval, system_prompt)

# Model to optimize: the test engine (gpt-3.5-turbo-0125) with a trainable system prompt.
# This rebinds `system_prompt`, so the optimizer below only updates this variable.
system_prompt = tg.Variable(STARTING_SYSTEM_PROMPT,
                            requires_grad=True,
                            role_description="structured system prompt to a somewhat capable language model that specifies the behavior and strategies for the QA task")
model = tg.BlackboxLLM(llm_api_test, system_prompt)

optimizer = tg.TextualGradientDescent(engine=llm_api_eval, parameters=[system_prompt])

results = {"test_acc": [], "prompt": [], "validation_acc": []}
results["test_acc"].append(eval_dataset(test_set, eval_fn, model))
results["validation_acc"].append(eval_dataset(val_set, eval_fn, model))
results["prompt"].append(system_prompt.get_value())


Accuracy: 0.7142857142857143: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1438.83it/s]
Accuracy: 0.6867469879518072: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:00<00:00, 1318.42it/s]


In [15]:
for epoch in range(3):
    for steps, (batch_x, batch_y) in enumerate((pbar := tqdm(train_loader, position=0))):
        pbar.set_description(f"Training step {steps}. Epoch {epoch}")
        optimizer.zero_grad()
        losses = []
        for (x, y) in zip(batch_x, batch_y):
            x = tg.Variable(x, requires_grad=False, role_description="query to the language model")
            y = tg.Variable(y, requires_grad=False, role_description="correct answer for the query")
            response = model(x)
            try:
                eval_output_variable = eval_fn(inputs=dict(prediction=response, ground_truth_answer=y))
            except Exception:
                eval_output_variable = eval_fn([x, y, response])
            losses.append(eval_output_variable)
        # Combine the per-sample evaluations, backpropagate the textual feedback,
        # and let the optimizer rewrite the system prompt.
        total_loss = tg.sum(losses)
        total_loss.backward()
        optimizer.step()

        # Keep the new prompt only if validation accuracy did not drop.
        run_validation_revert(system_prompt, results, model, eval_fn, val_set)

        print("sys prompt: ", system_prompt)
        test_acc = eval_dataset(test_set, eval_fn, model)
        results["test_acc"].append(test_acc)
        results["prompt"].append(system_prompt.get_value())
        # Run only four steps (0 to 3) per epoch.
        if steps == 3:
            break

Accuracy: 0.5542168674698795: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:21<00:00,  3.94it/s]


val_performance:  0.5542168674698795
previous_performance:  0.6867469879518072
rejected prompt: You will answer a reasoning question. Provide the final numerical answer directly. Your response should be a single numerical value. Do not include any explanation or additional text, only the numerical answer.
sys prompt:  You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value.


Accuracy: 0.7142857142857143: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1201.24it/s]
Accuracy: 0.8313253012048193: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:43<00:00,  1.91it/s]


val_performance:  0.8313253012048193
previous_performance:  0.6867469879518072
sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of the question and focus on counting the items accurately. List each item you count and then verify the total number. Avoid adding any extra information or context that is not directly related to the total count. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.


Accuracy: 0.8214285714285714: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:43<00:00,  1.94it/s]
Accuracy: 0.5542168674698795: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:46<00:00,  1.80it/s]


val_performance:  0.5542168674698795
previous_performance:  0.8313253012048193
rejected prompt: You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of the question and focus on counting the items accurately. 

1. Restate the question to ensure clarity.
2. List each item you count on a new line using bullet points for clarity. If there are multiple items of the same type, list them together and indicate the quantity.
3. Ensure the numbering and items in the list are consistent and free from typographical errors.
4. After listing all items, count the total number of items and provide the final count in the specified format.
5. Verify the total number by cross-checking with the list.
6. If there are potential errors or ambiguities in the list, acknowledge them and request additional details if necessary.

The last line of your response should be of the following 

Accuracy: 0.8214285714285714: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1638.96it/s]
Accuracy: 0.6867469879518072: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:18<00:00,  4.43it/s]


val_performance:  0.6867469879518072
previous_performance:  0.8313253012048193
rejected prompt: You will answer a reasoning question that involves counting items in a list. Directly compute the total number of items and provide the final count. Ensure you understand the context of the question and focus on counting the items accurately. Avoid adding any extra information or context that is not directly related to the total count. Ensure your response follows the format 'Answer: $VALUE' with no additional text. If the input contains unexpected characters or is malformed, correct the input and provide a coherent response. Ensure the final line contains only the answer in the required format.
sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of the question and focus on counting the items accurately. List each item you count and then verify the tot

Accuracy: 0.8214285714285714: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1043.66it/s]
Training step 3. Epoch 0: : 3it [07:29, 149.70s/it]
Accuracy: 0.4819277108433735: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:34<00:00,  2.44it/s]


val_performance:  0.4819277108433735
previous_performance:  0.8313253012048193
rejected prompt: Answer a reasoning question that involves counting items in a list. Focus on counting only the items relevant to the question. Directly provide the total count without listing each item. Ensure the final line contains only the answer in the format 'Answer: $VALUE' where VALUE is a numerical value. Double-check your count for accuracy before providing the final number. If you encounter any ambiguous items or quantities, make a note of them and proceed with the calculation based on the clear items.
sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of the question and focus on counting the items accurately. List each item you count and then verify the total number. Avoid adding any extra information or context that is not directly related to the total co

Accuracy: 0.8214285714285714: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1660.49it/s]
Accuracy: 0.3132530120481928: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:16<00:00,  5.03it/s]


val_performance:  0.3132530120481928
previous_performance:  0.8313253012048193
rejected prompt: You will answer a reasoning question that involves counting items in a list. Ensure you understand the context of the question, such as identifying specific categories of items (e.g., vegetables) before counting. Count the items and provide the total number. Be cautious of items that might be similar or easily confused, and ensure you are counting the correct items based on the context. If the query is ambiguous or could be interpreted in multiple ways, provide a brief explanation of your reasoning or ask for clarification. Ensure the response is a single numerical value without any additional text or formatting.
sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of the question and focus on counting the items accurately. List each item you count and t

Accuracy: 0.8214285714285714: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1674.21it/s]
Accuracy: 0.891566265060241: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [01:00<00:00,  1.36it/s]


val_performance:  0.891566265060241
previous_performance:  0.8313253012048193
sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points or numbering with periods for listing items. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any ambiguity, provide reasoning for your choice or ask a clarifying question. After listing the items, verify the total count to ensure accuracy. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.


Accuracy: 0.8571428571428571: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [01:00<00:00,  1.38it/s]
Accuracy: 0.8313253012048193: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:43<00:00,  1.89it/s]


val_performance:  0.8313253012048193
previous_performance:  0.891566265060241
rejected prompt: You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points with a hyphen (-) for each item. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any ambiguity, provide reasoning for your choice or ask a clarifying question. After listing the items, state the total number of items in the format: 'The total number of items listed is X.' Ensure the final line contains only the answer in the required format: 'Answer: $VALUE' where VALUE is a numerical value. Double-check your arithmetic to ensure the sum of the counts is correct.
sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of yo

Accuracy: 0.8571428571428571: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1432.86it/s]
Training step 3. Epoch 1: : 3it [08:16, 165.36s/it]
Accuracy: 0.6144578313253012: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:18<00:00,  4.50it/s]


val_performance:  0.6144578313253012
previous_performance:  0.891566265060241
rejected prompt: You will answer a reasoning question that involves counting items in a list. Provide only the final count of items without listing intermediate steps. Ensure you understand the context of each item and its relevance to the total count. If there is any ambiguity, ask a clarifying question. Present the numerical value clearly and directly, without any surrounding text or context. Ensure the final answer is a single numerical value without any additional text. Do not repeat the final answer; provide it only once in the specified format. Avoid using ellipses, parentheses, or any other punctuation that could complicate the response. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the required format.
sys prompt:  You will answer a reasoning question that involves counting items in

Accuracy: 0.8571428571428571: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1608.54it/s]
Accuracy: 0.5301204819277109: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:16<00:00,  5.10it/s]


val_performance:  0.5301204819277109
previous_performance:  0.891566265060241
rejected prompt: Count the items and provide the total number as a single integer. Do not include any additional text, explanations, or lists in your response. Ensure the answer is a numerical value only.
sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points or numbering with periods for listing items. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any ambiguity, provide reasoning for your choice or ask a clarifying question. After listing the items, verify the total count to ensure accuracy. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value. Ensure the final line contains only the answer in the r

Accuracy: 0.8571428571428571: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1517.28it/s]
Accuracy: 0.8433734939759037: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [01:20<00:00,  1.03it/s]


val_performance:  0.8433734939759037
previous_performance:  0.891566265060241
rejected prompt: You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points or numbering with periods for listing items. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). Ensure that the numbering is sequential and does not skip any numbers. Avoid using ellipses or incomplete names; list each item explicitly. If there is any ambiguity, provide reasoning for your choice or ask a clarifying question. Ensure that each item listed is distinct and not a duplicate. After listing the items, count them explicitly and verify that the total matches the number of distinct items listed. Explicitly show all mathematical operations in your reasoning. For example, if you are adding quantities, write out the f

Accuracy: 0.8571428571428571: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1705.01it/s]
Accuracy: 0.7710843373493976: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 83/83 [00:27<00:00,  3.07it/s]


val_performance:  0.7710843373493976
previous_performance:  0.891566265060241
rejected prompt: You will answer a reasoning question that involves counting items in a list. Provide a concise answer without detailed explanations. List each item clearly and sequentially without using bullet points or numbering. Use consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any ambiguity, make a reasonable assumption and provide the answer directly. Ensure the final line contains only the answer in the following format: 'Answer: $VALUE' where VALUE is a numerical value.
sys prompt:  You will answer a reasoning question that involves counting items in a list. Think step by step, but provide a concise summary of your reasoning. Ensure you understand the context of each item and its relevance to the total count. Use bullet points or numbering with periods for listing items. Maintain consistent naming conventions for similar items (e.g., Bed 1, Bed 2). If there is any am

Accuracy: 0.8571428571428571: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 1062.32it/s]
Training step 3. Epoch 2: : 3it [07:08, 142.76s/it]
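
After training, `results` holds one entry per step in each of its lists, with the first `validation_acc` entry stored as a list of per-sample scores and later entries stored as floats. The sketch below shows one way to pick the prompt with the highest validation accuracy. `to_accuracy` and `best_step` are hypothetical helpers, and `toy_results` is made-up data rather than the output of the run above.

```python
def to_accuracy(entry):
    """Return the mean of a list of 0/1 scores, or the value itself if it is already a number."""
    if isinstance(entry, (list, tuple)):
        return sum(entry) / len(entry)
    return float(entry)


def best_step(results):
    """Return the index of the step with the highest validation accuracy (first one on ties)."""
    accuracies = [to_accuracy(entry) for entry in results["validation_acc"]]
    return max(range(len(accuracies)), key=lambda i: accuracies[i])


# Made-up results with the same layout as the `results` dict in this tutorial.
toy_results = {
    "prompt": ["p0", "p1", "p2"],
    "validation_acc": [[1, 0, 1, 1], 0.75, 0.9],
    "test_acc": [[1, 0], [1, 1], [1, 1]],
}
best = best_step(toy_results)
best_prompt = toy_results["prompt"][best]
best_test_acc = to_accuracy(toy_results["test_acc"][best])
```

Selecting by validation accuracy and only then reading off the test accuracy keeps the test set out of the selection decision.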
