# Example: Optimizing LLMs to Satisfy a Judge with Hidden Preferences


## Setup

In [None]:
import asyncio
import random

import altair as alt
import pandas as pd
from tensorzero import AsyncTensorZeroGateway
from tqdm.asyncio import tqdm_asyncio

> **IMPORTANT:** Update the gateway URL below if you're not using the standard setup provided in this example

In [None]:
# Base URL of the TensorZero gateway; the client below sends every `write_haiku`,
# `judge_haiku`, and feedback request here.
TENSORZERO_GATEWAY_URL = "http://localhost:3000"

## Load the Dataset

In [None]:
# Number of shuffled topics used to generate training data (inference + feedback)
NUM_TRAIN_DATAPOINTS = 500
# Number of topics held out for evaluating variants; taken right after the training slice
NUM_VAL_DATAPOINTS = 500

In [None]:
random.seed(0)  # Set seed for reproducibility


# Read one topic per line. Blank lines (e.g. a trailing newline at the end of the
# file) are skipped so that no empty-string topic ends up in either split.
with open("data/nounlist.txt", "r", encoding="utf-8") as file:
    stripped_lines = (line.strip() for line in file)
    topics = [topic for topic in stripped_lines if topic]

# Shuffle once with the seeded global RNG so the train/validation split is reproducible
random.shuffle(topics)

print(f"There are {len(topics)} topics in the list of haiku topics.")

# Make a too-short topic list visible instead of silently producing smaller splits
if len(topics) < NUM_TRAIN_DATAPOINTS + NUM_VAL_DATAPOINTS:
    print(
        f"Warning: only {len(topics)} topics available, fewer than the "
        f"{NUM_TRAIN_DATAPOINTS + NUM_VAL_DATAPOINTS} requested for training and validation."
    )

# Disjoint, consecutive slices of the shuffled list
train_topics = topics[:NUM_TRAIN_DATAPOINTS]
val_topics = topics[NUM_TRAIN_DATAPOINTS : NUM_TRAIN_DATAPOINTS + NUM_VAL_DATAPOINTS]

print(
    f"Using {len(train_topics)} topics for training and {len(val_topics)} topics for validation."
)

## Inference: Write and Judge Haikus

> **IMPORTANT:** Reduce the number of concurrent requests if you're running into rate limits

In [None]:
# Upper bound on in-flight gateway requests; enforced via a shared asyncio.Semaphore
MAX_CONCURRENT_REQUESTS = 10

In [None]:
# Async client shared by all inference and feedback calls in this notebook.
# NOTE(review): timeout=15.0 is presumably seconds per request — confirm against the client docs.
tensorzero_client = AsyncTensorZeroGateway(TENSORZERO_GATEWAY_URL, timeout=15.0)

In [None]:
async def write_judge_haiku(topic, variant_name):
    """Write a haiku about ``topic`` and have the ``judge_haiku`` function score it.

    Args:
        topic: Sent as the ``topic`` template variable of ``write_haiku`` and,
            together with the extracted haiku, to ``judge_haiku``.
        variant_name: Variant of ``write_haiku`` to pin; None when no variant is
            pinned (as in the training run).

    Returns:
        ``(inference_id, score)``, where ``inference_id`` is the ``write_haiku``
        inference's ID and ``score`` is the ``score`` field of the judge's parsed
        output. Returns None if any step fails; the error is printed rather than
        raised, so one bad datapoint does not abort a batched ``gather``.
    """
    # Generate a haiku about the given topic
    try:
        write_result = await tensorzero_client.inference(
            function_name="write_haiku",
            variant_name=variant_name,  # only used during validation
            input={
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "text", "value": {"topic": topic}}],
                    }
                ]
            },
        )
    except Exception as e:
        print(f"Error occurred: {type(e).__name__}: {e}")
        return None

    # The LLM is instructed to conclude with the haiku, so we extract the last 3 lines.
    # In a real application, you'll want more sophisticated validation and parsing logic.
    # Parsing is guarded too: an empty response (IndexError) or a non-text/None block
    # (AttributeError) would otherwise escape and abort the whole batch.
    try:
        response_text = write_result.content[0].text.strip()
    except (IndexError, AttributeError) as e:
        print(f"Error occurred: {type(e).__name__}: {e}")
        return None
    haiku_text = "\n".join(response_text.split("\n")[-3:])

    # Judge the haiku using a separate TensorZero function
    # We use the same episode_id to associate these inferences
    try:
        judge_result = await tensorzero_client.inference(
            function_name="judge_haiku",
            input={
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "value": {"topic": topic, "haiku": haiku_text},
                            }
                        ],
                    }
                ]
            },
            episode_id=write_result.episode_id,
        )

        # `parsed` may be missing/invalid; the resulting error is handled below
        score = judge_result.output.parsed["score"]
    except Exception as e:
        print(f"Error occurred: {type(e).__name__}: {e}")
        return None

    return (write_result.inference_id, score)

In [None]:
# Run inference in parallel to speed things up
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)


async def ratelimited_write_judge_haiku(topic, variant_name=None):
    async with semaphore:
        return await write_judge_haiku(topic, variant_name=variant_name)


results = await tqdm_asyncio.gather(
    *[ratelimited_write_judge_haiku(topic) for topic in train_topics]
)

## Send Feedback

In [None]:
async def send_haiku_feedback(inference_id, score):
    """Record ``score`` as ``haiku_score`` feedback for the given inference.

    Runs under the shared ``semaphore``. Errors are printed and swallowed, matching
    ``write_judge_haiku``'s error handling, so a single failed request does not make
    the surrounding ``gather`` raise while the other feedback requests proceed.
    """
    async with semaphore:
        try:
            await tensorzero_client.feedback(
                metric_name="haiku_score", inference_id=inference_id, value=score
            )
        except Exception as e:
            print(f"Error occurred: {type(e).__name__}: {e}")

In [None]:
await tqdm_asyncio.gather(
    *[send_haiku_feedback(*result) for result in results if result is not None]
);

## Validation Set

> **IMPORTANT:** Update the list below when you create new variants in `tensorzero.toml`

In [None]:
# Include the variants in `tensorzero.toml` that we want to evaluate.
# Each name is passed as `variant_name` to the `write_haiku` function, and every entry
# is scored on the full validation set. Uncomment/add entries as new variants are created.
VARIANTS_TO_EVALUATE = [
    "gpt_4o_mini",
    # "gpt_4o_mini_fine_tuned",
]

In [None]:
scores = {}  # variant_name => score


for variant_name in VARIANTS_TO_EVALUATE:
    # Run inference on the validation set
    val_results = await tqdm_asyncio.gather(
        *[
            ratelimited_write_judge_haiku(
                topic,
                variant_name=variant_name,  # pin to the specific variant we want to evaluate
            )
            for topic in val_topics
        ],
        desc=f"Evaluating variant: {variant_name}",
    )

    # Compute the average score for the variant
    scores[variant_name] = sum(
        result[1] for result in val_results if result is not None
    ) / len(val_results)

## Plot Results

In [None]:
# Build a dataframe for plotting: one row per evaluated variant, in evaluation order
scores_df = pd.DataFrame(
    [
        {"Variant": name, "Metric": "haiku_score", "Score": value}
        for name, value in scores.items()
    ]
)

In [None]:
# Build the chart
# Horizontal bars: one per variant on a fixed 0–1 x-axis formatted as percentages,
# coloured by metric name.
chart = (
    alt.Chart(scores_df)
    .encode(
        x=alt.X("Score:Q", axis=alt.Axis(format="%"), scale=alt.Scale(domain=[0, 1])),
        y="Variant:N",
        color="Metric:N",
        text=alt.Text("Score:Q", format=".1%"),  # label content for the text layer below
    )
    .properties(title="Score by Variant")
)

# Layer the bars with left-aligned score labels nudged (dx=2) past each bar's end
chart = chart.mark_bar() + chart.mark_text(align="left", dx=2)

# Last expression of the cell, so the notebook renders the chart
chart