In [None]:
from typing import Any, Optional

from tensorzero import AsyncTensorZeroGateway, ToolCall, ToolResult

from utils import BeerQA, get_wikipedia_full_text, get_wikipedia_summary

In [None]:
# Path to the BeerQA dev split, relative to the notebook's working directory.
BEERQA_DEV_PATH = "data/beerqa_dev_v1.0.json"

beerqa = BeerQA(BEERQA_DEV_PATH)

In [None]:
async def execute_tool_call(tool_call: ToolCall) -> Any:
    """
    Execute a Wikipedia lookup tool call requested by the model.

    Supported tools are `get_summary` (dispatched to `get_wikipedia_summary`)
    and `get_full_text` (dispatched to `get_wikipedia_full_text`); both read
    the `article_title` argument.

    Args:
        tool_call: The tool call emitted by the model.

    Returns:
        Whatever the matching Wikipedia helper returns for the article title.

    Raises:
        ValueError: If `tool_call.name` is not a supported tool.
        KeyError: If a supported tool call has no `article_title` argument.
    """
    handlers = {
        "get_summary": get_wikipedia_summary,
        "get_full_text": get_wikipedia_full_text,
    }
    handler = handlers.get(tool_call.name)
    # Validate the tool name before reading arguments: an unknown tool (which
    # may use a different argument schema) should surface as
    # "Unknown tool call", not as a misleading KeyError on `article_title`.
    if handler is None:
        raise ValueError(f"Unknown tool call: {tool_call.name}")
    return await handler(tool_call.arguments["article_title"])

In [None]:
async def solve_beerqa(
    client: AsyncTensorZeroGateway,
    question: str,
    query_budget: int = 3,
    verbose: bool = False,
) -> Optional[str]:
    """
    Solve a BeerQA question using Wikipedia.

    Runs an agent loop against the TensorZero function `beerqa_solver`. On each
    turn the model is told how many turns it has left (system template variable
    `queries_remaining`) and may call Wikipedia tools. Their results are sent
    back to it as tool results on the next turn. The loop stops as soon as the
    model calls `submit_answer`.

    Args:
        client: An open TensorZero gateway client.
        question: The question to answer.
        query_budget: Maximum number of inference turns.
        verbose: If True, print each tool result as it is fetched.

    Returns:
        The `answer` argument of the model's `submit_answer` call, or None if
        the budget is exhausted (or non-positive) before an answer is submitted.
    """
    messages = [{"role": "user", "content": question}]
    for queries_remaining in range(query_budget, 0, -1):
        response = await client.inference(
            function_name="beerqa_solver",
            input={
                "system": {"queries_remaining": queries_remaining},
                "messages": messages,
            },
        )
        # The model's own turn must be recorded as the assistant; labelling it
        # "user" would make the model see its tool calls as user input.
        messages.append({"role": "assistant", "content": response.content})

        tool_results = []
        for block in response.content:
            if block.type != "tool_call":
                continue
            if block.name == "submit_answer":
                return block.arguments["answer"]
            result = await execute_tool_call(block)
            if verbose:
                print(result)
            # Feed every lookup back to the model; otherwise it never sees what
            # it retrieved and each query is wasted.
            tool_results.append(
                ToolResult(name=block.name, result=str(result), id=block.id)
            )

        if tool_results:
            messages.append({"role": "user", "content": tool_results})

    # Budget exhausted without a `submit_answer` call.
    return None

In [None]:
question = beerqa.get_question(0)
async with AsyncTensorZeroGateway("http://localhost:3000") as client:
    answer = await solve_beerqa(client, question)