In [None]:
import asyncio
import json
import os
from collections import Counter
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import neatplot
import numpy as np
import pandas as pd
import scipy.stats as stats
from clickhouse_connect import get_client
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tensorzero import AsyncTensorZeroGateway, InferenceResponse
from tqdm.asyncio import tqdm_asyncio

In [None]:
neatplot.set_style("notex")

# Connection string for the ClickHouse database, e.g. "http://localhost:8123/tensorzero"
# (general form: "https://user:password@host:port/database").
CLICKHOUSE_URL = os.environ.get("CLICKHOUSE_URL")
assert CLICKHOUSE_URL is not None, "CLICKHOUSE_URL is not set"

In [None]:
# Async client for the TensorZero gateway, expected to be reachable at localhost:3000.
# NOTE(review): `timeout=5` is presumably in seconds — confirm against the tensorzero client docs.
tensorzero_client = AsyncTensorZeroGateway("http://localhost:3000", timeout=5)

Read the dataset from the provided CSV file (`conllpp.csv`).


In [None]:
# Load the dataset. The `output` column stores the gold entities as JSON strings,
# so decode it into Python dicts before use.
df = pd.read_csv("conllpp.csv")
df["output"] = df["output"].apply(json.loads)
# Last expression of the cell, so the (already decoded) frame is actually displayed
df.head()

In [None]:
# Partition on the precomputed `split` column (0 = train, 1 = validation, 2 = test),
# then shuffle each partition reproducibly (fixed random_state) and reset to a 0..n-1 index.
def _shuffled_split(split_id: int) -> pd.DataFrame:
    """Return the rows of `df` belonging to `split_id`, shuffled with a fixed seed."""
    return (
        df.loc[df["split"] == split_id]
        .sample(frac=1, random_state=42)
        .reset_index(drop=True)
    )


train_df, val_df, test_df = (_shuffled_split(split_id) for split_id in (0, 1, 2))

print(f"Train data shape: {train_df.shape}")
print(f"Validation data shape: {val_df.shape}")
print(f"Test data shape: {test_df.shape}")

The cell below defines the function that we'll actually use to extract entities from text.


In [None]:
# Retry failed inferences (any exception, e.g. a timeout) up to 3 attempts with
# randomized exponential backoff. `reraise=True` makes tenacity re-raise the last
# underlying exception instead of a generic `RetryError`, so the message printed by
# `get_entities` below names the real cause.
@retry(
    stop=stop_after_attempt(3),
    wait=wait_random_exponential(multiplier=1, max=10),
    reraise=True,
)
async def _get_entities(
    text: str,
    client: AsyncTensorZeroGateway,
    variant_name: Optional[str] = None,
    dryrun: bool = False,
) -> InferenceResponse:
    """Run one `extract_entities` inference on `text` through the gateway (with retries).

    Args:
        text: The sentence to extract entities from, sent as a single user message.
        client: The async TensorZero gateway client.
        variant_name: If given, pin the inference to this variant.
        dryrun: Forwarded to the gateway's `dryrun` flag.
    """
    return await client.inference(
        function_name="extract_entities",
        input={"messages": [{"role": "user", "content": text}]},
        dryrun=dryrun,
        variant_name=variant_name,
    )


# Call this function to get the entities from the text
async def get_entities(
    text: str,
    client: AsyncTensorZeroGateway,
    variant_name: Optional[str] = None,
    dryrun: bool = False,
) -> Optional[InferenceResponse]:
    """Best-effort wrapper around `_get_entities`.

    Returns the inference response, or None if every retry failed. The error is
    printed rather than raised, so one bad example does not abort a concurrent batch.
    """
    try:
        return await _get_entities(text, client, variant_name, dryrun)
    except Exception as e:
        print(f"Error: {e}")
        return None

In the next code blocks we define a small helper that flattens entity dictionaries, followed by two metrics for evaluating an NER model: Exact Match and Jaccard Similarity.
We will use these metrics to evaluate the performance of each variant of our model.
Jaccard similarity gives partial credit for correctly extracted entities, so it is more lenient than exact match.


In [None]:
def flatten_dict(d: Dict[str, List[str]]) -> List[str]:
    """Flatten an entity dict into tagged strings of the form ``__TYPE__::entity``.

    The entity type is upper-cased, so ``{"person": ["Ada"]}`` becomes
    ``["__PERSON__::Ada"]``. Output order follows the dict's key order, then each
    list's order; duplicates are kept. Raises AssertionError if a value is not a list.
    """
    tagged: List[str] = []
    for entity_type, entities in d.items():
        assert isinstance(entities, list)
        prefix = f"__{entity_type.upper()}__::"
        tagged.extend(f"{prefix}{entity}" for entity in entities)
    return tagged

In [None]:
# Exact match between the predicted and gold entities (the sharpest metric we use to evaluate NER)
def exact_match(predicted: Dict[str, List[str]], gold: Dict[str, List[str]]) -> bool:
    """Return True iff predicted and gold contain the same set of tagged entities.

    Entities are compared as sets of ``__TYPE__::entity`` strings (see `flatten_dict`),
    so ordering and duplicate entries are ignored.
    """
    predicted_entities = set(flatten_dict(predicted))
    gold_entities = set(flatten_dict(gold))
    return predicted_entities == gold_entities

In [None]:
# Jaccard similarity between the predicted and gold entities
# (a more lenient metric that gives partial credit for correct entities)
# NOTE: This is a different implementation from the original code by Predibase, so the metrics won't be directly comparable.
def jaccard_similarity(
    predicted: Dict[str, List[str]], gold: Dict[str, List[str]]
) -> float:
    """Multiset Jaccard similarity between predicted and gold tagged entities.

    Each side is flattened into ``__TYPE__::entity`` strings and counted, so repeated
    entities matter: the score is |P ∩ G| / |P ∪ G| using per-entity counts.

    Returns:
        A float in [0, 1]; 1.0 when both sides have no entities at all.
    """
    target_count = Counter(flatten_dict(gold))
    pred_count = Counter(flatten_dict(predicted))
    # Counter `&` keeps the per-key minimum and `|` the per-key maximum, i.e. the
    # multiset intersection and union.
    intersection = sum((target_count & pred_count).values())
    union = sum((target_count | pred_count).values())
    if union == 0:
        # Nothing to extract and nothing extracted: a perfect match.
        return 1.0
    return intersection / union

First, we'll run inference using TensorZero on the training set to collect data for training future variants.
We will evaluate the predictions on the training set and send the feedback to TensorZero as well.


In [None]:
# Tunables: how many training examples to run inference on, and the maximum number of
# requests in flight at once (lower it to respect your provider's rate limits).
NUM_TRAIN_PREDICTIONS = 100
MAX_CONCURRENT_REQUESTS = 20

# Shared by the inference and feedback coroutines below to bound concurrency.
semaphore = asyncio.Semaphore(value=MAX_CONCURRENT_REQUESTS)

In [None]:
async def make_inference(text: str, client: AsyncTensorZeroGateway):
    async with semaphore:
        return await get_entities(text, client)


# We run inference in parallel to speed things up
responses = await tqdm_asyncio.gather(
    *[
        make_inference(text, tensorzero_client)
        for text in train_df["input"][:NUM_TRAIN_PREDICTIONS]
    ]
)

In [None]:
def evaluate(response: Optional[InferenceResponse], gold_data: Dict[str, List[str]]):
    """Score a single inference against the gold entities.

    Args:
        response: The gateway response, or None if the inference failed.
        gold_data: Gold entities, mapping entity type -> list of entity strings.

    Returns:
        ``(valid_json, matched, jaccard)``: whether the response has a parsed output
        (``output.parsed is not None``), whether it exactly matches the gold
        entities, and its Jaccard similarity to them. A failed inference or a
        missing parsed output scores ``(False, False, 0)``.
    """
    predicted = response.output.parsed if response is not None else None
    valid_json = predicted is not None
    if not valid_json:
        return valid_json, False, 0
    # Score every parsed output, including an empty dict: it is a legitimate
    # prediction and must not be treated as a failure just because it is falsy.
    matched = exact_match(predicted, gold_data)
    jaccard = jaccard_similarity(predicted, gold_data)
    return valid_json, matched, jaccard


async def evaluate_send_feedback(
    response: Optional[InferenceResponse], gold_data: Dict[str, List[str]]
):
    """Score one inference and send its feedback to TensorZero.

    Sends ``valid_json``, ``exact_match`` and ``jaccard_similarity`` feedback for the
    inference. When the prediction is not an exact match, the gold entities are also
    sent as a ``demonstration`` for that inference.

    Args:
        response: The gateway response, or None if the inference failed. A None
            response has no inference ID to attach feedback to, so it is skipped.
        gold_data: Gold entities, mapping entity type -> list of entity strings.
    """
    if response is None:
        return

    valid_json, matched, jaccard = evaluate(response, gold_data)
    inference_id = response.inference_id

    # One semaphore slot covers the whole batch of feedback requests for this inference.
    async with semaphore:
        feedback_tasks = [
            tensorzero_client.feedback(
                metric_name="valid_json",
                value=valid_json,
                inference_id=inference_id,
            ),
            tensorzero_client.feedback(
                metric_name="exact_match",
                value=matched,
                inference_id=inference_id,
            ),
            tensorzero_client.feedback(
                metric_name="jaccard_similarity",
                value=jaccard,
                inference_id=inference_id,
            ),
        ]

        if not matched:
            # Send the gold entities (a dict, not a serialized string) as the
            # demonstration: the output the model should have produced for this input.
            feedback_tasks.append(
                tensorzero_client.feedback(
                    metric_name="demonstration",
                    value=gold_data,
                    inference_id=inference_id,
                )
            )

        await asyncio.gather(*feedback_tasks)

Evaluate the predictions on the training set and send the feedback to TensorZero


In [None]:
await tqdm_asyncio.gather(
    *[
        evaluate_send_feedback(response, gold)
        for response, gold in zip(responses, train_df["output"][:NUM_TRAIN_PREDICTIONS])
        if response is not None
    ]
);

Now that we've collected data on the training set, we can query the database to see how well each variant performed.
You should see the performance of the GPT-4o mini variant for each metric.
First, we'll check the exact match metric for each variant.


In [None]:
# ClickHouse client used below to query the inferences and feedback TensorZero stored
clickhouse_client = get_client(dsn=CLICKHOUSE_URL)

In [None]:
# Boolean feedback metric to inspect in the next query (bound as a SQL parameter there)
metric_name = "exact_match"

In [None]:
# Join every `extract_entities` inference with its boolean feedback for `metric_name`.
# The metric name is passed as a bound query parameter rather than formatted into the SQL.
exact_match_sql = """SELECT 
    i.variant_name, 
    i.input, 
    i.output, 
    b.value
FROM 
    JsonInference i
JOIN 
    BooleanMetricFeedback b ON i.id = b.target_id
WHERE 
    i.function_name = 'extract_entities'
    AND b.metric_name = %(metric_name)s"""
df = clickhouse_client.query_df(exact_match_sql, {"metric_name": metric_name})

df.head()

In [None]:
# Mean exact-match rate for each variant (booleans average to a proportion)
df.groupby("variant_name")["value"].agg("mean")

Next, we'll check the jaccard similarity metric for each variant.


In [None]:
# Join every `extract_entities` inference with its float feedback for the Jaccard
# metric. The metric name is bound as a query parameter, the same way as in the
# exact-match query, instead of being hardcoded into the SQL text.
df = clickhouse_client.query_df(
    """SELECT
    i.variant_name,
    i.input,
    i.output,
    f.value
FROM
    JsonInference i
JOIN
    FloatMetricFeedback f ON i.id = f.target_id
WHERE
    i.function_name = 'extract_entities'
    AND f.metric_name = %(metric_name)s""",
    {"metric_name": "jaccard_similarity"},
)

df.head()

In [None]:
# Mean Jaccard similarity for each variant
df.groupby("variant_name")["value"].agg("mean")

At this point, you have accumulated a dataset of training "demonstrations": the gold entities for every training example whose prediction was not an exact match.
You should use the TensorZero recipe (at `recipes/supervised_fine_tuning/demonstrations/openai/`) to fine-tune a custom GPT-4o mini model on these demonstrations in order to improve the model's performance on the test set.
After you do so, paste the config output by the recipe into your `tensorzero.toml`, give the new variant a nonzero weight, and restart the gateway to begin testing the new variant and model!


In [None]:
# Variants that appear in the Jaccard-similarity feedback queried above (so a variant
# only shows up once it has inferences with feedback). Sorted so that the evaluation
# order and the bar order in the final plot are reproducible: `unique()` follows row
# order, which the query does not fix (it has no ORDER BY).
variants_to_evaluate = sorted(df["variant_name"].unique())

Now, we'll evaluate each variant on the test set. To avoid leaking any test data into the data collected for training future variants, we set the `dryrun` flag on these inference requests.

We will also "pin" the `variant_name` for each inference request, so every test example is evaluated with the same variant and the comparison between variants is fair.


In [None]:
# Evaluate on a fixed prefix of the (already shuffled) test split.
NUM_TEST_PREDICTIONS = 500
test_set = test_df.head(NUM_TEST_PREDICTIONS)


async def make_inference(
    text: str, client: AsyncTensorZeroGateway, variant_name: Optional[str] = None
):
    """Test-time inference: bounded by the shared semaphore, pinned to `variant_name`,
    and always sent with dryrun=True.

    This intentionally replaces the training-time `make_inference` defined earlier.
    """
    async with semaphore:
        # dryrun=True so the test set is not leaked into the data TensorZero collects
        response = await get_entities(
            text, client, variant_name=variant_name, dryrun=True
        )
    return response

In [None]:
variant_responses = []
for variant_name in variants_to_evaluate:
    variant_task = tqdm_asyncio.gather(
        *[
            make_inference(text, tensorzero_client, variant_name=variant_name)
            for text in test_set["input"]
        ],
        desc=f"Evaluating variant: {variant_name}",
    )
    variant_result = await variant_task
    variant_responses.append(variant_result)

Finally, we'll evaluate the performance of each variant on the test set.


In [None]:
# Score every variant's test-set responses. The loop variable is deliberately NOT named
# `responses`: that name holds the training-set responses. Overwriting it would make a
# later re-run of the feedback cell pair test predictions with training gold labels.
variant_data = {}
for variant_name, test_responses in zip(variants_to_evaluate, variant_responses):
    print(f"Evaluating variant: {variant_name}")
    print(f"Number of responses: {len(test_responses)}")

    # One (valid_json, matched, jaccard) tuple per test example, in test-set order
    scores = [
        evaluate(response, gold)
        for response, gold in zip(test_responses, test_set["output"])
    ]
    well_formed_jsons = [valid_json for valid_json, _, _ in scores]
    exact_matches = [matched for _, matched, _ in scores]
    jaccards = [jaccard for _, _, jaccard in scores]

    variant_data[variant_name] = {
        "jaccard_similarity": jaccards,
        "well_formed_json": well_formed_jsons,
        "exact_match": exact_matches,
    }
    print(
        f"Average Well-formed JSON: {sum(well_formed_jsons) / len(well_formed_jsons):.1%}"
    )
    print(f"Average Jaccard Similarity: {sum(jaccards) / len(jaccards):.1%}")
    print(f"Average Exact Match: {sum(exact_matches) / len(exact_matches):.1%}")
    print()

In [None]:
def _confidence_interval(metric: str, values) -> tuple:
    """Return a confidence interval ``(low, high)`` for the mean of ``values``.

    ``exact_match`` values are booleans, so a binomial confidence interval on the
    success proportion is used (``scipy.stats.binomtest(...).proportion_ci()``, default
    confidence level 0.95). Other metrics (here: Jaccard similarity) use a 95%
    Student-t interval around the sample mean. If that interval is undefined (fewer
    than two samples, or a standard error of 0 because every score is equal), a
    zero-width interval at the mean is returned instead of NaNs.
    """
    n = len(values)
    if metric == "exact_match":
        ci_low, ci_high = stats.binomtest(int(sum(values)), n).proportion_ci()
        return ci_low, ci_high
    mean = np.mean(values)
    if n < 2:
        return mean, mean
    se = stats.sem(values)
    if not np.isfinite(se) or se == 0:
        return mean, mean
    ci_low, ci_high = stats.t.interval(0.95, n - 1, loc=mean, scale=se)
    return ci_low, ci_high


def plot_metrics_with_ci(variant_data):
    """Bar-plot each variant's mean exact match and Jaccard similarity with error bars.

    Args:
        variant_data: Mapping of variant name -> dict holding at least the per-example
            score lists ``"exact_match"`` and ``"jaccard_similarity"``.
    """
    if not variant_data:
        print("No variant data to plot.")
        return

    variant_names = list(variant_data)
    x = range(len(variant_names))
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    for metric, ax in [("exact_match", ax1), ("jaccard_similarity", ax2)]:
        means = []
        error_bars = []  # (distance below mean, distance above mean) per variant
        for name in variant_names:
            values = variant_data[name][metric]
            mean = np.mean(values)
            ci_low, ci_high = _confidence_interval(metric, values)
            means.append(mean)
            error_bars.append((mean - ci_low, ci_high - mean))

        # Asymmetric yerr must be shaped (2, N): row 0 = lower, row 1 = upper distances
        ax.bar(x, means, yerr=list(zip(*error_bars)), capsize=5, alpha=0.8)
        ax.set_xticks(x)
        ax.set_xticklabels(variant_names, rotation=45, ha="right")
        ax.set_title(f"Average {metric.replace('_', ' ').title()}")
        ax.set_ylabel("Mean score")
        ax.set_ylim(0, 1)

        # Label each bar with its mean, placed halfway up the bar
        for idx, mean in enumerate(means):
            ax.text(idx, mean / 2, f"{mean:.2f}", ha="center", va="bottom")

    plt.tight_layout()
    plt.show()


plot_metrics_with_ci(variant_data)