In [None]:
# Number of shuffled training examples to run inference on (and to send feedback for).
NUM_PREDICTIONS = 500
# Size of the asyncio semaphore that bounds how many inference requests are in flight at once.
MAX_CONCURRENT_REQUESTS = 50

In [None]:
import asyncio
import json
import os
from typing import Dict, List, Optional

import pandas as pd
from clickhouse_driver import Client
from tensorzero import AsyncTensorZeroGateway, InferenceResponse
from tqdm.asyncio import tqdm_asyncio

In [None]:
# Async client for the TensorZero gateway; assumes a gateway is reachable at localhost:3000.
tensorzero_client = AsyncTensorZeroGateway("http://localhost:3000")

In [None]:
# Load the dataset from a CSV in the working directory and preview the first rows.
# Later cells read its `split`, `input`, and `output` columns.
df = pd.read_csv("conllpp.csv")
df.head()

In [None]:
# Partition the dataset on its integer `split` column: 0 -> train, 1 -> validation, 2 -> test.
split_id = df["split"]

# Shuffle the training rows with a fixed seed (reproducible order) and renumber the index from 0.
train_df = df.loc[split_id == 0].sample(frac=1, random_state=42).reset_index(drop=True)

val_df = df.loc[split_id == 1]
test_df = df.loc[split_id == 2]

print(f"Train data shape: {train_df.shape}")
print(f"Validation data shape: {val_df.shape}")
print(f"Test data shape: {test_df.shape}")

In [None]:
# Inspect one raw `output` value. A later cell parses these with json.loads,
# so each value is presumably a JSON-encoded string — confirm against the CSV.
train_df["output"].iloc[0]

In [None]:
async def get_entities(
    text: str, client: AsyncTensorZeroGateway
) -> Optional[InferenceResponse]:
    """Run the `extract_entities` function on a single piece of text.

    Args:
        text: Sent as the content of a single user message.
        client: Gateway client used to perform the inference call.

    Returns:
        The inference response, or None if the call raised any exception
        (the error is printed rather than propagated, so one failed request
        does not abort a whole batch).
    """
    user_message = {"role": "user", "content": text}
    try:
        return await client.inference(
            function_name="extract_entities",
            input={"messages": [user_message]},
        )
    except Exception as err:
        # Best-effort: report the failure and signal it to the caller with None.
        print(f"Error: {err}")
        return None

In [None]:
def exact_match(predicted: Dict[str, List[str]], gold: Dict[str, List[str]]) -> bool:
    """Return True when `predicted` and `gold` contain the same entities.

    Both mappings must have exactly the same keys. For every key, the two
    entity lists are compared as sets of lower-cased strings, so letter case,
    ordering, and duplicates are ignored.
    """
    if predicted.keys() != gold.keys():
        return False

    def _normalize(entities: List[str]) -> set:
        # Case-insensitive, order-insensitive, duplicate-insensitive view of a list.
        return {entity.lower() for entity in entities}

    return all(
        _normalize(predicted[label]) == _normalize(expected)
        for label, expected in gold.items()
    )

In [None]:
# Shared semaphore that caps the number of concurrently running inference calls.
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)


async def make_inference(text: str, client: AsyncTensorZeroGateway):
    """Call `get_entities` for `text` while holding a slot of the shared semaphore.

    Returns whatever `get_entities` returns (a response, or None on failure).
    """
    async with semaphore:
        entities_response = await get_entities(text, client)
    return entities_response


responses = await tqdm_asyncio.gather(
    *[
        make_inference(text, tensorzero_client)
        for text in train_df["input"][:NUM_PREDICTIONS]
    ]
)

In [None]:
async def evaluate_send_feedback(
    response: Optional[InferenceResponse], gold_data: Dict[str, List[str]]
) -> None:
    """Score one inference against its gold label and report it to TensorZero.

    Args:
        response: Result of `get_entities`; None when that inference failed.
        gold_data: Expected entities, keyed by entity type.

    A None `response` is skipped: there is no inference ID to attach feedback
    to, and dereferencing it would raise and abort the surrounding gather.
    Otherwise the parsed output is compared with `gold_data` via `exact_match`
    (a missing/empty parsed output counts as a non-match) and the boolean
    result is sent as `exact_match` feedback for that inference.
    """
    if response is None:
        return

    predicted = response.output.parsed
    # `parsed` may be falsy when the model output could not be parsed; score it as a miss.
    matched = exact_match(predicted, gold_data) if predicted else False
    await tensorzero_client.feedback(
        metric_name="exact_match",
        value=matched,
        inference_id=response.inference_id,
    )

In [None]:
await asyncio.gather(
    *[
        evaluate_send_feedback(response, json.loads(gold))
        for response, gold in zip(responses, train_df["output"][:NUM_PREDICTIONS])
    ]
);

In [None]:
# Connect to ClickHouse over the native protocol. The URL comes from the
# environment; fail fast with a clear message instead of handing None to the
# driver when the variable is missing.
clickhouse_url = os.getenv("CLICKHOUSE_NATIVE_URL")
if not clickhouse_url:
    raise RuntimeError(
        "CLICKHOUSE_NATIVE_URL environment variable is not set; "
        "it must contain the ClickHouse native-protocol URL."
    )
clickhouse_client = Client.from_url(clickhouse_url)

In [None]:
# Query the inferences and their feedback from the database, joined on the inference ID.
# Only `exact_match` feedback is selected so that other boolean metrics attached to the
# same inferences cannot leak into the per-variant average computed from `value`.
# NOTE: this rebinds `df`, which previously held the raw dataset.
df = clickhouse_client.query_dataframe(
    """
    SELECT
        i.variant_name,
        i.input,
        i.output,
        b.value
    FROM
        Inference i
    JOIN
        BooleanMetricFeedback b ON i.id = b.target_id
    WHERE
        i.function_name = 'extract_entities'
        AND b.metric_name = 'exact_match'
    """
)
df.head()

In [None]:
# Show the average feedback `value` for each variant, i.e. the share of joined
# inference/feedback rows whose feedback is true (assuming `value` is boolean/0-1).
df.groupby("variant_name")["value"].mean()