In [None]:
import asyncio
import os
import random
from typing import Optional, Tuple
from uuid import UUID

from clickhouse_connect import get_client
from tensorzero import AsyncTensorZeroGateway, InferenceResponse
from tqdm.asyncio import tqdm_asyncio

In [None]:
# ClickHouse connection string, taken from the environment.
# Example values: "http://localhost:8123/tensorzero" or "https://user:password@host:port/database"
CLICKHOUSE_URL = os.environ.get("CLICKHOUSE_URL")

# Fail fast in this cell rather than later, when the database client is created.
assert CLICKHOUSE_URL is not None, "CLICKHOUSE_URL is not set"

In [None]:
# Async client for the TensorZero gateway, pointed at a locally running instance.
tensorzero_client = AsyncTensorZeroGateway(
    "http://localhost:3000",
    timeout=15.0,
)

In [None]:
# Load the haiku topics: one noun per line. Blank lines are dropped so that an
# empty string is never used as a topic, and the encoding is pinned so the read
# does not depend on the platform default.
with open("nounlist.txt", "r", encoding="utf-8") as file:
    nouns = [noun for noun in (line.strip() for line in file) if noun]

# Shuffle once the file has been closed; the order decides which topics are used
# when only a prefix of the list is consumed.
random.shuffle(nouns)

print(f"There are {len(nouns)} nouns in the list of haiku topics.")

In [None]:
async def write_grade_haiku(
    topic: str,
    client: AsyncTensorZeroGateway,
) -> Optional[Tuple[str, bool, UUID]]:
    """Write a haiku about `topic` and have a second function judge it.

    Two inferences are issued through `client`: `write_haiku` produces the
    haiku and `judge_haiku` scores it. The judge call reuses the episode_id of
    the first inference so the two are associated.

    Args:
        topic: The subject passed to both functions.
        client: Gateway client used for both inference calls.

    Returns:
        A `(haiku_text, score, inference_id)` tuple, where `haiku_text` is the
        last three lines of the `write_haiku` output, `score` is the "score"
        field of the judge's parsed output, and `inference_id` is the id of the
        `write_haiku` inference. Returns None (after printing the reason) if
        either inference raises, if the haiku response carries no text, or if
        the judge output has no parsed "score".
    """
    # Infer against the TensorZero gateway using the write_haiku function
    # This will naturally sample from the variants configured in `tensorzero.toml`
    try:
        haiku_result: InferenceResponse = await client.inference(
            function_name="write_haiku",
            input={
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "text", "value": {"topic": topic}}],
                    }
                ]
            },
        )
    except Exception as e:
        print(e)
        return None

    # Guard against an empty content list or a first block without text; an
    # unhandled exception here would abort every other task in the batch.
    content = haiku_result.content
    raw_text = getattr(content[0], "text", None) if content else None
    stripped_text = raw_text.strip() if raw_text else ""
    if not stripped_text:
        print(f"Error occurred: write_haiku returned no text for topic {topic!r}")
        return None

    # The LLM is instructed to conclude with the haiku, so we extract the last 3 lines
    haiku_lines = stripped_text.split("\n")
    haiku_text = "\n".join(haiku_lines[-3:])
    inference_id = haiku_result.inference_id

    # Judge the haiku using a separate TensorZero function, but we can use the same episode_id to associate these inferences
    try:
        judge_result: InferenceResponse = await client.inference(
            function_name="judge_haiku",
            input={
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "value": {"topic": topic, "haiku": haiku_text},
                            }
                        ],
                    }
                ]
            },
            episode_id=haiku_result.episode_id,
        )
    except Exception as e:
        print(f"Error occurred: {type(e).__name__}: {e}")
        return None

    # The parsed output may be missing (e.g. the judge did not return usable
    # JSON) or may lack the "score" key; treat both as a failed grading.
    parsed = getattr(getattr(judge_result, "output", None), "parsed", None)
    if not isinstance(parsed, dict) or "score" not in parsed:
        print(f"Error occurred: judge_haiku returned no parsed score for topic {topic!r}")
        return None

    return haiku_text, parsed["score"], inference_id

In [None]:
# Run a bunch of haiku writing and grading tasks concurrently.
# `num_requests` caps how many topics are processed; `max_concurrent_requests`
# is the number of slots in the semaphore that gates the calls.
num_requests = 500
max_concurrent_requests = 50
semaphore = asyncio.Semaphore(value=max_concurrent_requests)


async def ratelimited_write_grade_haiku(noun, client):
    """Run `write_grade_haiku(noun, client)` while holding a slot of `semaphore`.

    The slot is released whether the wrapped call returns or raises, and the
    wrapped call's result is returned unchanged.
    """
    await semaphore.acquire()
    try:
        outcome = await write_grade_haiku(noun, client)
        return outcome
    finally:
        semaphore.release()


tasks = [
    ratelimited_write_grade_haiku(noun, tensorzero_client)
    for noun in nouns[:num_requests]
]
results = await tqdm_asyncio.gather(*tasks)

In [None]:
async def send_haiku_feedback(
    client: AsyncTensorZeroGateway, inference_id: UUID, score: bool
):
    """Report `score` as the "haiku_score" metric for the given inference.

    Args:
        client: Gateway client whose `feedback` method is awaited.
        inference_id: Id of the inference the feedback is attached to.
        score: Value submitted for the metric.
    """
    feedback_kwargs = {
        "metric_name": "haiku_score",
        "inference_id": inference_id,
        "value": score,
    }
    await client.feedback(**feedback_kwargs)

In [None]:
# Send feedback to the haiku grading function
results = [result for result in results if result is not None]
tasks = [
    send_haiku_feedback(tensorzero_client, result[2], result[1]) for result in results
]
await tqdm_asyncio.gather(*tasks);

In [None]:
# Connect to ClickHouse using the DSN held in CLICKHOUSE_URL.
clickhouse_client = get_client(
    dsn=CLICKHOUSE_URL,
)

In [None]:
# Query the inferences and feedback from the database and join them on the inference ID.
# Only `haiku_score` feedback is kept, so any other boolean metric recorded against
# the same inferences cannot skew the per-variant averages.
df = clickhouse_client.query_df("""SELECT
    i.variant_name,
    i.input,
    i.output,
    b.value
FROM
    ChatInference i
JOIN
    BooleanMetricFeedback b ON i.id = b.target_id
WHERE
    i.function_name = 'write_haiku'
    AND b.metric_name = 'haiku_score'""")

In [None]:
# Display the mean feedback value per variant
df["value"].groupby(df["variant_name"]).mean()