## Example: NER + SFT + DPO

### Setup

In [None]:
import asyncio
import json
import os
import tempfile
import time
import warnings
from collections import Counter
from pathlib import Path
from pprint import pprint
from typing import Any, Dict, List, Optional, Tuple

import altair as alt
import numpy as np
import openai
import pandas as pd
import toml
from clickhouse_connect import get_client
from IPython.display import clear_output
from minijinja import Environment
from tensorzero import AsyncTensorZeroGateway, InferenceResponse
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

####  IMPORTANT: Update the gateway URL below if you're not using the standard setup provided in this example

In [None]:
# Base URL of the TensorZero gateway; passed to `AsyncTensorZeroGateway.build_http`.
TENSORZERO_GATEWAY_URL = "http://localhost:3000"

### Load the Dataset

In [None]:
# Select only a subset of the dataset to speed things up
# (maximum number of rows kept from the train / validation splits by `load_dataset`).
NUM_TRAIN_DATAPOINTS = 500
NUM_VAL_DATAPOINTS = 500

In [None]:
def load_dataset(
    path: str,
    num_train: Optional[int] = None,
    num_val: Optional[int] = None,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load the NER CSV and return shuffled (train, validation) subsets.

    The CSV must have an ``output`` column of JSON strings and a ``split`` column
    where 0 marks training rows and 1 marks validation rows.

    Args:
        path: Path to the CSV file.
        num_train: Maximum number of training rows to keep. Defaults to
            ``NUM_TRAIN_DATAPOINTS`` (looked up at call time).
        num_val: Maximum number of validation rows to keep. Defaults to
            ``NUM_VAL_DATAPOINTS`` (looked up at call time).

    Returns:
        ``(train_df, val_df)``: each split shuffled with a fixed seed, re-indexed
        from 0, and truncated to the requested size. ``output`` holds decoded dicts.
    """
    if num_train is None:
        num_train = NUM_TRAIN_DATAPOINTS
    if num_val is None:
        num_val = NUM_VAL_DATAPOINTS

    df = pd.read_csv(path)
    # Ground-truth entities are stored as JSON strings; decode them.
    df["output"] = df["output"].apply(json.loads)

    # Split the dataset into train and validation sets
    train_df = df[df["split"] == 0]
    val_df = df[df["split"] == 1]

    # Shuffle the splits deterministically so the selected subset is reproducible
    train_df = train_df.sample(frac=1, random_state=0).reset_index(drop=True)
    val_df = val_df.sample(frac=1, random_state=0).reset_index(drop=True)

    # Select only a subset of the dataset to speed things up
    return train_df.iloc[:num_train], val_df.iloc[:num_val]

In [None]:
# Load the shuffled train/validation subsets and report their shapes.
train_df, val_df = load_dataset("data/conllpp.csv")

print(f"Train data shape: {train_df.shape}")
print(f"Validation data shape: {val_df.shape}")

### Extract Entities
IMPORTANT: Reduce the number of concurrent requests if you're running into rate limits.

In [None]:
# Upper bound on in-flight gateway requests (enforced by the asyncio.Semaphore below).
MAX_CONCURRENT_REQUESTS = 10

In [None]:
# Shared async HTTP client for the TensorZero gateway.
# NOTE(review): `timeout=15` is presumably seconds — confirm against the tensorzero client docs.
tensorzero_client = await AsyncTensorZeroGateway.build_http(
    gateway_url=TENSORZERO_GATEWAY_URL, timeout=15
)
# Caps concurrent requests issued through `get_entities`.
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)

In [None]:
async def get_entities(
    text: str,
    variant_name: Optional[str] = None,
    dryrun: bool = False,
) -> Optional[InferenceResponse]:
    """Call the ``extract_entities`` function on *text* through the TensorZero gateway.

    Args:
        text: The sentence to analyse; sent as a single user message.
        variant_name: Variant to pin the request to; None means no pin is requested.
        dryrun: Forwarded unchanged to the gateway's ``inference`` call.

    Returns:
        The inference response, or None if the request raised (the error is printed).
    """
    request_input = {"messages": [{"role": "user", "content": text}]}

    # The shared semaphore limits how many requests are in flight at once.
    async with semaphore:
        try:
            response = await tensorzero_client.inference(
                function_name="extract_entities",
                input=request_input,
                dryrun=dryrun,
                variant_name=variant_name,
            )
        except Exception as e:
            # Best effort: report the failure and let the caller skip this example.
            print(f"Error occurred: {type(e).__name__}: {e}")
            return None
    return response

In [None]:
# Run inference in parallel to speed things up (concurrency is still capped by the
# semaphore inside `get_entities`). Failed requests come back as None.
# `dryrun` is left at its default (False), unlike the validation run further below.
responses = await tqdm_asyncio.gather(
    *[get_entities(text) for text in train_df["input"]]
)

### Evaluate the Performance

In [None]:
def flatten_dict(d: Dict[str, List[str]]) -> List[str]:
    """Flatten an entity-type -> entities mapping into tagged strings.

    Every entity ``e`` listed under key ``k`` becomes ``"__K__::e"`` with the key
    upper-cased, preserving dict order and within-list order.

    Raises:
        AssertionError: if any value is not a list.
    """
    flattened: List[str] = []
    for entity_type, entities in d.items():
        assert isinstance(entities, list)
        tag = entity_type.upper()
        flattened.extend(f"__{tag}__::{entity}" for entity in entities)
    return flattened


# Exact match between the predicted and ground truth entities (the sharpest metric we use to evaluate NER)
def compute_exact_match(
    predicted: Dict[str, List[str]], ground_truth: Dict[str, List[str]]
) -> bool:
    """Return True iff both dicts contain exactly the same set of tagged entities.

    Entities are compared as ``__TYPE__::entity`` strings produced by
    ``flatten_dict``, so both text and type must agree; order and duplicates are
    ignored.
    """
    predicted_entities = set(flatten_dict(predicted))
    ground_truth_entities = set(flatten_dict(ground_truth))
    return predicted_entities == ground_truth_entities


# Jaccard similarity between the predicted and ground_truth entities
# (a more lenient metric that gives partial credit for correct entities)
# This is a different implementation from the original code by Predibase, so the metrics won't be directly comparable.
def compute_jaccard_similarity(
    predicted: Dict[str, List[str]], ground_truth: Dict[str, List[str]]
) -> float:
    """Multiset Jaccard similarity between predicted and ground-truth entities.

    Entities are compared as ``__TYPE__::entity`` strings produced by
    ``flatten_dict`` and counted with multiplicity: the score is
    |intersection| / |union|, where the intersection keeps the minimum count of
    each entity and the union keeps the maximum.

    Returns:
        A float in [0, 1]; 1.0 when neither side contains any entity.
    """
    target_count = Counter(flatten_dict(ground_truth))
    pred_count = Counter(flatten_dict(predicted))

    # Counter `|` keeps the max count per key; `&` keeps the min count per key.
    union_size = sum((target_count | pred_count).values())
    if union_size == 0:
        return 1.0
    intersection_size = sum((target_count & pred_count).values())
    return intersection_size / union_size

In [None]:
def evaluate_response(
    response: Optional[InferenceResponse], ground_truth_data: Dict[str, List[str]]
) -> Tuple[bool, bool, float]:
    """Score one inference response against the ground-truth entities.

    Args:
        response: The inference response, or None if the request failed.
        ground_truth_data: Mapping of entity type to the list of expected entities.

    Returns:
        ``(valid_output, exact_match, jaccard_similarity)``. When there is no
        parsed output (failed request, or output that did not match the schema),
        this is ``(False, False, 0.0)``.
    """
    predicted = response.output.parsed if response is not None else None

    # `predicted` is None if the model failed to return a valid JSON that complies with the output schema
    valid_output = predicted is not None
    if not valid_output:
        return False, False, 0.0

    # Gate on validity (not truthiness) so a valid-but-empty prediction such as `{}`
    # is still scored against the ground truth instead of counted as a miss.
    exact_match = compute_exact_match(predicted, ground_truth_data)
    jaccard_similarity = compute_jaccard_similarity(predicted, ground_truth_data)

    return valid_output, exact_match, jaccard_similarity

In [None]:
for response, ground_truth in tqdm(
    zip(responses, train_df["output"]), total=len(responses)
):
    # Don't send feedback if the request failed completely
    if response is None:
        continue

    # Evaluate the example
    valid_output, exact_match, jaccard_similarity = evaluate_response(
        response, ground_truth
    )

    # Send the metrics feedback to TensorZero
    await tensorzero_client.feedback(
        metric_name="valid_output",
        value=valid_output,
        inference_id=response.inference_id,
    )

    await tensorzero_client.feedback(
        metric_name="exact_match",
        value=exact_match,
        inference_id=response.inference_id,
    )

    await tensorzero_client.feedback(
        metric_name="jaccard_similarity",
        value=jaccard_similarity,
        inference_id=response.inference_id,
    )

    # Send the demonstration feedback to TensorZero
    await tensorzero_client.feedback(
        metric_name="demonstration",
        value=ground_truth,
        inference_id=response.inference_id,
    )

### Validation Set
IMPORTANT: Update the list below when you create new variants in `tensorzero.toml`.

In [None]:
# Include the variants in `tensorzero.toml` that we want to evaluate
# (each name is pinned via `variant_name=` in the validation loop below).
VARIANTS_TO_EVALUATE = [
    "gpt_4o",
    "gpt_4o_mini",
    "gpt_4o_mini_sft_fine_tuned",
    "gpt_4o_mini_sft_dpo_fine_tuned",
]

In [None]:
scores = {}  # variant_name => (valid_output, exact_match, jaccard_similarity)

for variant_name in VARIANTS_TO_EVALUATE:
    # Run inference on the validation set
    responses = await tqdm_asyncio.gather(
        *[
            get_entities(
                text,
                variant_name=variant_name,  # pin to the specific variant we want to evaluate
                dryrun=True,  # don't store results to avoid leaking data
            )
            for text in val_df["input"]
        ],
        desc=f"Evaluating variant: {variant_name}",
    )

    # Evaluate the performance of the variant
    valid_output_scores = []
    exact_match_scores = []
    jaccard_similarity_scores = []

    for response, ground_truth in zip(responses, val_df["output"]):
        valid_output, exact_match, jaccard_similarity = evaluate_response(
            response, ground_truth
        )
        valid_output_scores.append(valid_output)
        exact_match_scores.append(exact_match)
        jaccard_similarity_scores.append(jaccard_similarity)

    scores[variant_name] = {
        "valid_output": valid_output_scores,
        "exact_match": exact_match_scores,
        "jaccard_similarity": jaccard_similarity_scores,
    }

    # Print the performance of the variant
    print(f"Valid Output: {sum(valid_output_scores) / len(valid_output_scores):.1%}")
    print(f"Exact Match: {sum(exact_match_scores) / len(exact_match_scores):.1%}")
    print(
        f"Jaccard Similarity (mean): {sum(jaccard_similarity_scores) / len(jaccard_similarity_scores):.1%}"
    )
    print()

### Plot Results

In [None]:
# Long-format table for Altair: one row per (variant, metric) holding the mean score.
scores_df = pd.DataFrame(
    [
        {
            "Variant": variant_name,
            "Metric": metric,
            "Score": sum(variant_scores[metric]) / len(variant_scores[metric]),
        }
        for variant_name, variant_scores in scores.items()
        for metric in ("exact_match", "jaccard_similarity")
    ]
)

In [None]:
# Horizontal grouped bars (one bar per metric within each variant), labelled with the score.
base_chart = alt.Chart(scores_df).encode(
    x=alt.X("Score:Q", axis=alt.Axis(format="%"), scale=alt.Scale(domain=[0, 1])),
    y="Variant:N",
    yOffset="Metric:N",
    color="Metric:N",
    text=alt.Text("Score:Q", format=".1%"),
).properties(title="Metrics by Variant")

bars = base_chart.mark_bar()
labels = base_chart.mark_text(align="left", dx=2)
chart = bars + labels

chart

In [None]:
CONFIG_PATH = "config/tensorzero.toml"
FUNCTION_NAME = "extract_entities"
# Variant whose prompt templates are reused to render the fine-tuning data.
TEMPLATE_VARIANT_NAME = "gpt_4o_mini"
# NOTE(review): the templates come from `gpt_4o_mini`, and the evaluated variant list
# mentions `gpt_4o_mini_sft_dpo_fine_tuned`, yet the base model here is gpt-4o.
# Confirm which base model (e.g. the SFT fine-tuned gpt-4o-mini) is intended.
MODEL_NAME = "gpt-4o-2024-08-06"
VAL_FRACTION = 0.2  # fraction of episodes held out for fine-tuning validation
MAX_SAMPLES = 500  # LIMIT applied to the ClickHouse query

assert "OPENAI_API_KEY" in os.environ, "OPENAI_API_KEY environment variable not set"
assert "TENSORZERO_CLICKHOUSE_URL" in os.environ, (
    "TENSORZERO_CLICKHOUSE_URL environment variable not set"
)

### STEP 1: SFT

### STEP 2: DPO

In [None]:
config_path = Path(CONFIG_PATH)

# Fail early with a readable message if the config path is wrong.
assert config_path.exists(), f"{CONFIG_PATH} does not exist"
assert config_path.is_file(), f"{CONFIG_PATH} is not a file"

config = toml.loads(config_path.read_text())

Ensure that the function and variant being fine-tuned are present in the provided config.

In [None]:
assert "functions" in config, "No `[functions]` section found in config"
# Check the function itself first so a typo surfaces as a clear message, not a KeyError.
assert FUNCTION_NAME in config["functions"], (
    f"No function named `{FUNCTION_NAME}` found in config"
)
assert "variants" in config["functions"][FUNCTION_NAME], (
    f"No variants section found for function `{FUNCTION_NAME}`"
)
assert TEMPLATE_VARIANT_NAME in config["functions"][FUNCTION_NAME]["variants"], (
    f"No variant named `{TEMPLATE_VARIANT_NAME}` found in function `{FUNCTION_NAME}`"
)

Retrieve the configuration for the variant with the templates we will use for fine-tuning.

In [None]:
# The function's type (e.g. "json") and the config of the variant whose templates
# are reused below to render the fine-tuning data.
function_type = config["functions"][FUNCTION_NAME]["type"]
variant = config["functions"][FUNCTION_NAME]["variants"][TEMPLATE_VARIANT_NAME]

In [None]:
# Load the template variant's prompt templates, keyed by role ("assistant",
# "system", "user"). Template paths are resolved relative to the config file.
templates = {}

# Bind all three path names up front so code that checks e.g.
# `system_template_path` does not hit a NameError when the variant lacks it.
assistant_template_path = None
system_template_path = None
user_template_path = None

if "assistant_template" in variant:
    assistant_template_path = config_path.parent / variant["assistant_template"]
    templates["assistant"] = assistant_template_path.read_text()

if "system_template" in variant:
    system_template_path = config_path.parent / variant["system_template"]
    templates["system"] = system_template_path.read_text()

if "user_template" in variant:
    user_template_path = config_path.parent / variant["user_template"]
    templates["user"] = user_template_path.read_text()

env = Environment(templates=templates)

Initialize the ClickHouse client.

In [None]:
# The ClickHouse DSN is read from the environment.
assert "TENSORZERO_CLICKHOUSE_URL" in os.environ, (
    "TENSORZERO_CLICKHOUSE_URL environment variable not set"
)

clickhouse_client = get_client(dsn=os.environ["TENSORZERO_CLICKHOUSE_URL"])

Determine the ClickHouse table name for the function.

In [None]:
# Map the function type to the ClickHouse table that stores its inferences.
# Only "json" functions are accepted here; anything else stops the notebook.
inference_table_name = {"json": "JsonInference"}.get(function_type)

if inference_table_name is None:
    raise ValueError(f"Unsupported function type: {function_type}")

Query ClickHouse for the function's inferences, each paired with its demonstration feedback.

In [None]:
# ---------------------------
# Query ClickHouse for data
# ---------------------------
# Pair each stored inference of FUNCTION_NAME (the model's output, used as the
# non-preferred output) with its demonstration feedback (the preferred output).
# The table name cannot be bound as a query parameter, so it is interpolated; it
# only comes from the fixed mapping that defines `inference_table_name`.
query = f"""
SELECT
    i.variant_name AS variant,
    i.episode_id AS episode_id,
    i.input AS input,
    i.output AS non_preferred_output,
    d.value AS preferred_output
FROM
    {inference_table_name} AS i
INNER JOIN DemonstrationFeedback AS d ON i.id = d.inference_id
WHERE
    (i.function_name = %(function_name)s)
LIMIT %(max_samples)s
"""

params = {"function_name": FUNCTION_NAME, "max_samples": MAX_SAMPLES}
df = clickhouse_client.query_df(query, params)
print(df)

Render the stored inferences into OpenAI's DPO (preference) fine-tuning format.


In [None]:
def render_message(message: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
    """Render one stored input message into OpenAI chat-format messages.

    Block handling:
    - Text blocks whose ``value`` is not a string are rendered through the
      minijinja template registered for the message's role in ``env``.
    - ``raw_text`` blocks are passed through unchanged.
    - ``thought`` blocks are wrapped in ``<think>`` tags.
    - Assistant ``tool_call`` blocks become OpenAI ``tool_calls`` entries.
    - User ``tool_result`` blocks become separate ``"tool"`` role messages.

    Args:
        message: Dict with a ``role`` ("user" or "assistant") and a list of
            ``content`` blocks.

    Returns:
        A list of messages: any ``"tool"`` messages first, then at most one
        message for ``role`` holding the collected content and/or tool calls.
        Returns None (after a UserWarning) if a block has an unsupported type, or
        a tool block appears under the wrong role, so the caller can drop the example.

    Raises:
        AssertionError: if ``role`` is not "user" or "assistant".
    """
    role = message["role"]
    assert role in ["user", "assistant"], f"Invalid role: {role}"
    content: List[Dict[str, Any]] = []
    tool_calls: List[Dict[str, Any]] = []
    rendered_messages: List[Dict[str, Any]] = []

    for content_block in message["content"]:
        if content_block["type"] == "text":
            parsed_content = content_block["value"]
            # Structured (non-string) values are template arguments for this role.
            if not isinstance(parsed_content, str):
                parsed_content = env.render_template(role, **parsed_content)
            content.append({"type": "text", "text": parsed_content})
        elif content_block["type"] == "raw_text":
            content.append({"type": "text", "text": content_block["value"]})
        elif content_block["type"] == "thought":
            content.append(
                {"type": "text", "text": f"<think>{content_block['text']}</think>"}
            )
        elif content_block["type"] == "tool_call" and role == "assistant":
            tool_calls.append(
                {
                    "function": {
                        "arguments": json.dumps(content_block["arguments"]),
                        "name": content_block["name"],
                    },
                    "id": content_block["id"],
                    "type": "function",
                }
            )
        elif content_block["type"] == "tool_result" and role == "user":
            # Tool results get priority so that they follow the tool call in the conversation.
            # Any other "user" content will be appended in another message below.
            rendered_messages.append(
                {
                    "role": "tool",
                    "tool_call_id": content_block["id"],
                    "content": content_block["result"],
                }
            )
        else:
            warnings.warn(
                f"We do not support content block type: {content_block['type']}, dropping example.",
                UserWarning,
            )
            return None

    # Emit a message for `role` only if something besides tool results was collected.
    if content or tool_calls:
        role_message: Dict[str, Any] = {"role": role}
        if content:
            role_message["content"] = content
        if tool_calls:
            role_message["tool_calls"] = tool_calls
        rendered_messages.append(role_message)

    return rendered_messages


def render_output(
    output: Any,
) -> Optional[Dict[str, Any]]:
    """Convert a stored output into a single OpenAI assistant message.

    For ``json`` functions, *output* is a dict; its ``"raw"`` string becomes the
    message content. For ``chat`` functions, *output* is a list of content blocks
    (``text``, ``thought``, ``tool_call``) merged into one assistant message.

    Returns:
        The assistant message, or None (after a UserWarning) if a chat output
        contains an unsupported block type, so the caller can drop the example.

    Raises:
        ValueError: if the module-level ``function_type`` is neither "json" nor "chat".
    """
    if function_type == "json":
        return {"role": "assistant", "content": output["raw"]}
    if function_type != "chat":
        raise ValueError(f"Unsupported function type: {function_type}")

    content: List[Dict[str, Any]] = []
    tool_calls: List[Dict[str, Any]] = []

    for content_block in output:
        block_type = content_block["type"]
        if block_type == "text":
            content.append({"type": "text", "text": content_block["text"]})
        elif block_type == "thought":
            content.append(
                {"type": "text", "text": f"<think>{content_block['text']}</think>"}
            )
        elif block_type == "tool_call":
            tool_calls.append(
                {
                    "function": {
                        "arguments": json.dumps(content_block["arguments"]),
                        "name": content_block["name"],
                    },
                    "id": content_block["id"],
                    "type": "function",
                }
            )
        else:
            warnings.warn(
                f"We do not support content block type: {content_block['type']}, dropping example.",
                UserWarning,
            )
            return None

    # Once we finish collecting all blocks, create one assistant message.
    output_message: Dict[str, Any] = {"role": "assistant"}
    if content:
        output_message["content"] = content
    if tool_calls:
        output_message["tool_calls"] = tool_calls

    return output_message


def sample_to_openai_messages(sample) -> Optional[Dict[str, Any]]:
    """Convert one ClickHouse row into an OpenAI DPO fine-tuning record.

    Args:
        sample: A row with JSON-encoded ``input``, ``preferred_output`` (the
            demonstration) and ``non_preferred_output`` (the stored inference output).

    Returns:
        A dict shaped like::

            {"input": {"messages": [...], "tools": [], "parallel_tool_calls": True},
             "preferred_output": [assistant_message],
             "non_preferred_output": [assistant_message]}

        or None if any message or output contains an unsupported content block,
        in which case the whole example is dropped.
    """
    function_input = json.loads(sample["input"])
    messages: List[Dict[str, Any]] = []

    # System message: render it through the variant's system template when one is
    # loaded; otherwise pass a non-empty system value through as-is.
    system = function_input.get("system")
    if "system" in templates:
        system_message = env.render_template("system", **(system or {}))
        messages.append({"role": "system", "content": system_message})
    elif system:
        messages.append({"role": "system", "content": system})

    # Add the input messages; render each only once.
    for message in function_input["messages"]:
        rendered_message = render_message(message)
        if rendered_message is None:
            # `render_message` returns None for unknown or unsupported content blocks.
            return None
        messages.extend(rendered_message)

    # The demonstration is the preferred output.
    rendered_preferred_output = render_output(json.loads(sample["preferred_output"]))
    if rendered_preferred_output is None:
        return None

    # The original inference output is the non-preferred output.
    rendered_non_preferred_output = render_output(
        json.loads(sample["non_preferred_output"])
    )
    if rendered_non_preferred_output is None:
        return None

    return {
        "input": {"messages": messages, "tools": [], "parallel_tool_calls": True},
        "preferred_output": [rendered_preferred_output],
        "non_preferred_output": [rendered_non_preferred_output],
    }


df["openai_messages"] = df.apply(sample_to_openai_messages, axis=1)

# `sample_to_openai_messages` returns None for examples it could not render; drop them.
df = df.dropna(subset=["openai_messages"])

df.head()

Split data into training and validation sets for fine-tuning

In [None]:
# Split by episode (not by row) so no episode contributes to both sets. The
# distinct episode ids are shuffled in place with a fixed seed for reproducibility.
unique_episode_ids = df["episode_id"].unique()
np.random.seed(42)
np.random.shuffle(unique_episode_ids)

split_index = int(len(unique_episode_ids) * (1 - VAL_FRACTION))
train_episode_ids, val_episode_ids = (
    unique_episode_ids[:split_index],
    unique_episode_ids[split_index:],
)

train_df = df.loc[df["episode_id"].isin(train_episode_ids)]
val_df = df.loc[df["episode_id"].isin(val_episode_ids)]

print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")
print(f"Actual validation fraction: {len(val_df) / len(df):.2f}")

Upload the prepared datasets to OpenAI.

In [None]:
def upload_dataset_to_openai(df, openai_client) -> str:
    """Write ``df["openai_messages"]`` as JSONL to a temp file and upload it to OpenAI.

    The temp file is created with ``delete=False`` and its path is printed, so it
    stays on disk for inspection. Returns the id of the uploaded file (created with
    ``purpose="fine-tune"``).
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as jsonl_file:
        jsonl_file.write(
            "".join(json.dumps(record) + "\n" for record in df["openai_messages"])
        )
        jsonl_file.flush()

        print(f"File persisted on path [{jsonl_file.name}]")

        with open(jsonl_file.name, "rb") as upload_handle:
            uploaded_file = openai_client.files.create(
                file=upload_handle, purpose="fine-tune"
            )

    return uploaded_file.id


openai_client = openai.OpenAI()

# Upload the DPO training and validation splits; their file ids feed the fine-tuning job.
dpo_fine_tuning_object_id = upload_dataset_to_openai(train_df, openai_client)
val_file_object_id = upload_dataset_to_openai(val_df, openai_client)

Launch the fine-tuning job and wait for it to complete.

In [None]:
fine_tuning_job = openai_client.fine_tuning.jobs.create(
    training_file=dpo_fine_tuning_object_id,
    validation_file=val_file_object_id,
    model=MODEL_NAME,
    method={
        "type": "dpo",
        "dpo": {
            "hyperparameters": {"beta": 0.2},
        },
    },
)

# Job statuses after which polling stops.
TERMINAL_JOB_STATUSES = ("succeeded", "failed", "cancelled")

while True:
    clear_output(wait=True)

    try:
        job_status = openai_client.fine_tuning.jobs.retrieve(fine_tuning_job.id)
        pprint(job_status.to_dict())
        if job_status.status in TERMINAL_JOB_STATUSES:
            break
    except Exception as e:
        # Errors are printed and polling continues rather than aborting the loop.
        print(f"Error: {e}")

    time.sleep(10)

print(f"The fine-tuning job has completed with result {job_status.status}")

TODO: Add the fine-tuned model to the config file

In [None]:
# OpenAI leaves `fine_tuned_model` null unless the job produced a model; refuse to
# build a config keyed by a missing model name.
assert job_status.status == "succeeded", (
    f"Fine-tuning job did not succeed (status: {job_status.status}); no model to add"
)

fine_tuned_model = job_status.fine_tuned_model
model_config = {
    "models": {
        fine_tuned_model: {
            "routing": ["openai"],
            "providers": {"openai": {"type": "openai", "model_name": fine_tuned_model}},
        }
    }
}

print(toml.dumps(model_config))

TODO: Add a new variant to your function that uses the fine-tuned model

In [None]:
variant_config = {"type": "chat_completion", "model": fine_tuned_model}

system_template = variant.get("system_template")
user_template = variant.get("user_template")
assistant_template = variant.get("assistant_template")

# Reuse whichever templates the template variant defines (system, user, assistant order).
# NOTE(review): other template-variant settings are not copied — confirm none are needed.
variant_config.update(
    {
        key: value
        for key, value in (
            ("system_template", system_template),
            ("user_template", user_template),
            ("assistant_template", assistant_template),
        )
        if value
    }
)

full_variant_config = {
    "functions": {FUNCTION_NAME: {"variants": {fine_tuned_model: variant_config}}}
}

print(toml.dumps(full_variant_config))