In [None]:
import asyncio
from typing import List

import numpy as np
from tensorzero import AsyncTensorZeroGateway
from tqdm.asyncio import trange
from utils import BeerQA, confidence_interval, execute_tool_call

In [None]:
# Dataset handle built from the BeerQA dev-set JSON file (path is relative to the
# notebook's working directory). Later cells read individual questions and their
# ground-truth answers from it via `get_question(i)` / `get_answers(i)`.
beerqa = BeerQA("data/beerqa_dev_v1.0.json")

In [None]:
# Caps the number of simultaneous `client.inference(...)` calls at 10; every
# inference call in the cells below is made while holding this semaphore.
tensorzero_semaphore = asyncio.Semaphore(10)
# NOTE(review): defined with the same limit but never acquired in the visible cells.
# Presumably meant to throttle the Wikipedia tool calls — confirm against
# `utils.execute_tool_call` before wiring it in or removing it.
wikipedia_semaphore = asyncio.Semaphore(10)

In [None]:
async def solve_beerqa(
    client: AsyncTensorZeroGateway,
    question: str,
    variant_name: str = "baseline",
    query_budget: int = 3,
) -> str:
    """
    Solve a BeerQA question using Wikipedia.

    Runs up to ``query_budget`` rounds of the ``beerqa_solver`` function. Each
    round sends the conversation so far together with the number of queries
    remaining. Tool calls in the response are handled in order:

    * ``submit_answer`` ends the search immediately and its ``answer``
      argument is returned;
    * any other tool call is run through ``execute_tool_call`` and its result
      is appended to the conversation as a user message.

    If the budget is exhausted without a submitted answer, the ``final_answer``
    function is called with the original question and the full conversation,
    and the ``answer`` field of its parsed output is returned.

    Args:
        client: Gateway client used for every inference call.
        question: The question text to answer.
        variant_name: Variant passed to both the solver and final-answer calls.
        query_budget: Maximum number of solver rounds; a value of 0 or less
            skips straight to the final-answer call.

    Returns:
        The answer string produced by the model.
    """
    messages = [{"role": "user", "content": question}]
    for queries_remaining in range(query_budget, 0, -1):
        async with tensorzero_semaphore:
            response = await client.inference(
                function_name="beerqa_solver",
                input={
                    "system": {"queries_remaining": queries_remaining},
                    "messages": messages,
                },
                variant_name=variant_name,
            )
        # The model's own turn (including its tool calls) must be recorded as an
        # "assistant" message; tool results that follow go back as "user" turns.
        messages.append({"role": "assistant", "content": response.content})
        for block in response.content:
            if block.type == "tool_call":
                if block.name == "submit_answer":
                    return block.arguments["answer"]
                # NOTE(review): `wikipedia_semaphore` exists at notebook level but is
                # not used here; confirm whether this call should be throttled by it.
                result = await execute_tool_call(block)
                messages.append({"role": "user", "content": [result]})
    # Budget exhausted without a submitted answer: ask for a final answer based
    # on everything gathered so far.
    async with tensorzero_semaphore:
        response = await client.inference(
            function_name="final_answer",
            input={"system": {"question": question}, "messages": messages},
            variant_name=variant_name,
        )
    return response.output.parsed["answer"]

In [None]:
async def grade_answer(
    client: AsyncTensorZeroGateway,
    question: str,
    gt_answer: List[str],
    submitted_answer: str,
) -> float:
    """
    Grade a submitted answer with the ``grade_answer`` function.

    Sends a single user message whose content holds the question, the
    ground-truth answers and the submitted answer, and returns the ``score``
    field of the parsed output.

    Args:
        client: Gateway client used for the inference call.
        question: The question that was asked.
        gt_answer: Ground-truth answers for the question.
        submitted_answer: The answer to be graded.

    Returns:
        The score reported by the grader.
    """
    grading_payload = {
        "question": question,
        "gt_answer": gt_answer,
        "submitted_answer": submitted_answer,
    }
    grader_input = {"messages": [{"role": "user", "content": grading_payload}]}

    # Hold the shared semaphore only for the duration of the inference call.
    async with tensorzero_semaphore:
        verdict = await client.inference(
            function_name="grade_answer",
            input=grader_input,
        )
    return verdict.output.parsed["score"]

In [None]:
async def solve_grade_question(
    client: AsyncTensorZeroGateway,
    question: str,
    gt_answer: List[str],
    variant_name="baseline",
    query_budget: int = 3,
) -> float:
    """
    Answer one question and grade the result.

    Calls ``solve_beerqa`` with the given variant and query budget, then passes
    the produced answer to ``grade_answer`` along with the ground truth.

    Args:
        client: Gateway client shared by the solve and grade steps.
        question: The question text to answer.
        gt_answer: Ground-truth answers used for grading.
        variant_name: Variant forwarded to ``solve_beerqa``.
        query_budget: Query budget forwarded to ``solve_beerqa``.

    Returns:
        The score returned by ``grade_answer``.
    """
    candidate = await solve_beerqa(
        client,
        question,
        variant_name=variant_name,
        query_budget=query_budget,
    )
    return await grade_answer(client, question, gt_answer, candidate)

In [None]:
num_questions = 100
scores = []

async with AsyncTensorZeroGateway("http://localhost:3000") as client:
    tasks = []
    for i in range(num_questions):
        question = beerqa.get_question(i)
        gt_answer = beerqa.get_answers(i)
        tasks.append(
            solve_grade_question(
                client, question, gt_answer, variant_name="baseline", query_budget=5
            )
        )

    progress_bar = trange(num_questions, desc="Solving questions")
    for task in asyncio.as_completed(tasks):
        score = await task
        scores.append(score)
        current = len(scores)
        ci_lower, ci_upper = confidence_interval(scores)
        progress_bar.update(1)
        progress_bar.set_postfix(
            {
                "Average Score": f"{np.mean(scores):.2f} CI: ({ci_lower:.2f}, {ci_upper:.2f})"
            },
            refresh=True,
        )
    progress_bar.close()

print(f"Average score: {np.mean(scores)}")

In [None]:
# Display the ground-truth answers stored for the first question (index 0),
# i.e. the value passed as `gt_answer` when that question is graded.
beerqa.get_answers(0)