In [16]:
NUM_PREDICTIONS = 1000
MAX_CONCURRENT_REQUESTS = 100

In [17]:
import asyncio
import json
import os
from typing import Dict, List, Optional

import pandas as pd
from clickhouse_driver import Client
from tensorzero import AsyncTensorZeroGateway, InferenceResponse
from tqdm.asyncio import tqdm_asyncio

In [18]:
tensorzero_client = AsyncTensorZeroGateway("http://localhost:3000")

In [19]:
df = pd.read_csv("conllpp.csv")
df.head()

Unnamed: 0,raw_id,raw_split,split,input,output
0,0,train,0,EU rejects German call to boycott British lamb .,"{""person"": [], ""organization"": [""EU""], ""locati..."
1,1,train,0,Peter Blackburn,"{""person"": [""Peter Blackburn""], ""organization""..."
2,2,train,0,BRUSSELS 1996-08-22,"{""person"": [], ""organization"": [], ""location"":..."
3,3,train,0,The European Commission said on Thursday it di...,"{""person"": [], ""organization"": [""European Comm..."
4,4,train,0,Germany 's representative to the European Unio...,"{""person"": [""Werner Zwingmann""], ""organization..."


In [20]:
train_df = df[df["split"] == 0]
# Shuffle the training data
train_df = train_df.sample(frac=1, random_state=42).reset_index(drop=True)

val_df = df[df["split"] == 1]
test_df = df[df["split"] == 2]

print(f"Train data shape: {train_df.shape}")
print(f"Validation data shape: {val_df.shape}")
print(f"Test data shape: {test_df.shape}")

Train data shape: (14041, 5)
Validation data shape: (3250, 5)
Test data shape: (3453, 5)


In [21]:
async def get_entities(
    text: str, client: AsyncTensorZeroGateway
) -> Optional[InferenceResponse]:
    try:
        response: InferenceResponse = await client.inference(
            function_name="extract_entities",
            input={"messages": [{"role": "user", "content": text}]},
        )
    except Exception as e:
        print(f"Error: {e}")
        return None
    return response

In [22]:
def exact_match(predicted: Dict[str, List[str]], gold: Dict[str, List[str]]) -> bool:
    if predicted.keys() != gold.keys():
        return False
    for key, value in gold.items():
        if set(item.lower() for item in predicted[key]) != set(
            item.lower() for item in value
        ):
            return False
    return True

In [23]:
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)


async def make_inference(text: str, client: AsyncTensorZeroGateway):
    async with semaphore:
        return await get_entities(text, client)


responses = await tqdm_asyncio.gather(
    *[
        make_inference(text, tensorzero_client)
        for text in train_df["input"][:NUM_PREDICTIONS]
    ]
)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:13<00:00, 73.31it/s]


In [24]:
async def evaluate_send_feedback(
    response: InferenceResponse, gold_data: Dict[str, List[str]]
):
    predicted = response.output.parsed
    valid_json = predicted is not None
    await tensorzero_client.feedback(
        metric_name="valid_json",
        value=valid_json,
        inference_id=response.inference_id,
    )
    matched = exact_match(predicted, gold_data) if predicted else False
    await tensorzero_client.feedback(
        metric_name="exact_match",
        value=matched,
        inference_id=response.inference_id,
    )

In [25]:
await tqdm_asyncio.gather(
    *[
        evaluate_send_feedback(response, json.loads(gold))
        for response, gold in zip(responses, train_df["output"][:NUM_PREDICTIONS])
    ]
);

100%|██████████| 1000/1000 [00:18<00:00, 54.08it/s]


In [26]:
clickhouse_client = Client.from_url(os.getenv("CLICKHOUSE_NATIVE_URL"))

In [36]:
metric_name = "exact_match"

In [37]:
# Query the inferences and feedback from the database and join them on the inference ID
df = clickhouse_client.query_dataframe("""SELECT 
    i.variant_name, 
    i.input, 
    i.output, 
    b.value
FROM 
    Inference i
JOIN 
    BooleanMetricFeedback b ON i.id = b.target_id
WHERE 
    i.function_name = 'extract_entities'
    AND b.metric_name = %(metric_name)s""",
    {"metric_name": metric_name})
df.head()

Unnamed: 0,variant_name,input,output,value
0,gpt4o_initial_prompt,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[],\""organization\"":[],\""l...",True
1,gpt4o_initial_prompt,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[\""Richard\"",\""Edberg\""],\...",True
2,gpt4o_initial_prompt,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[],\""organization\"":[],\""l...",True
3,claude_sonnet_initial_prompt,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[],\""organization\"":[],\""l...",False
4,claude_sonnet_initial_prompt,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[],\""organization\"":[\""NAT...",False


In [38]:
# Print the average score for each variant
df.groupby("variant_name")["value"].mean()

variant_name
claude_sonnet_initial_prompt    0.160000
gpt4o_initial_prompt            0.494737
llama_405b_initial_prompt       0.343220
llama_8b_initial_prompt         0.195572
mistral_large_initial_prompt    0.236842
Name: value, dtype: float64