In [None]:
# TensorZero configuration and the function / variant being optimized.
CONFIG_PATH = "../../examples/haiku_hidden_preferences/config/tensorzero.toml"
FUNCTION_NAME = "write_haiku"
BASE_VARIANT_NAME = "initial_prompt_gpt4o_mini"

# Feedback metric used to select training examples from ClickHouse.
METRIC_NAME = "haiku_score"
MAX_SAMPLES = 1000  # LIMIT applied to the ClickHouse query
FLOAT_METRIC_THRESHOLD = 0.5  # threshold used only when the metric type is "float"

In [None]:
import json
import os
from pathlib import Path
from typing import Dict, Optional
from urllib.parse import urljoin, urlparse

import dspy
import toml
from clickhouse_driver import Client

Load the TensorZero configuration file.


In [None]:
config_path = Path(CONFIG_PATH)

# Fail early with a readable message if the configured path is wrong.
assert config_path.exists(), f"{CONFIG_PATH} does not exist"
assert config_path.is_file(), f"{CONFIG_PATH} is not a file"

# Parse the TOML file into a nested dict.
config = toml.loads(config_path.read_text())

Retrieve the function configuration for the function we are optimizing.


In [None]:
# Validate up front so a typo in FUNCTION_NAME gives a readable error instead of a bare KeyError.
assert "functions" in config, "No `[functions]` section found in config"
assert (
    FUNCTION_NAME in config["functions"]
), f"No function named `{FUNCTION_NAME}` found in config"

function_config = config["functions"][FUNCTION_NAME]

In [None]:
def is_flat_schema(schema: dict):
    """Return True if `schema` is an object schema with no object/array properties.

    A dict whose `type` is "object" and that has no `properties` key counts as
    flat. Anything that is not a dict, or whose top-level `type` is not
    "object", is not flat.
    """
    if not isinstance(schema, dict) or schema.get("type") != "object":
        return False

    # A tuple (not a set) keeps the membership test safe when a property's
    # `type` is a list such as ["string", "null"].
    nested_types = ("object", "array")
    return not any(
        field.get("type") in nested_types
        for field in schema.get("properties", {}).values()
    )

In [None]:
def preprocess_json_schema(schema: dict) -> Dict[str, str]:
    """
    Preprocess a flat JSON schema to create a mapping of field names to their types.

    Args:
    schema (dict): A flat JSON schema (as accepted by `is_flat_schema`).

    Returns:
    Dict[str, str]: A dictionary mapping field names to their types (number, string, bool, or integer).
    Note that JSON schema "boolean" is reported as "bool". Properties with any
    other type (e.g. "null", a list of types, or no `type` at all) are skipped,
    and a warning is emitted for each one so the omission is not silent.

    Raises:
    AssertionError: If `schema` is not flat.
    """
    import warnings

    assert is_flat_schema(schema), f"JSON schema is not flat: {schema}"

    # JSON schema type -> label stored in the result.
    type_labels = {
        "number": "number",
        "string": "string",
        "boolean": "bool",
        "integer": "integer",
    }

    result = {}
    for field_name, field_info in schema.get("properties", {}).items():
        field_type = field_info.get("type", "")
        # `type` may be a list (union type); lists are unhashable and unsupported here.
        label = type_labels.get(field_type) if isinstance(field_type, str) else None
        if label is None:
            warnings.warn(
                f"Skipping field `{field_name}` with unsupported JSON schema type {field_type!r}"
            )
            continue
        result[field_name] = label

    return result

In [None]:
def jsonschema_type_to_python_type(field_type: str) -> str:
    """Map a JSON schema primitive type to the Python type name used in a DSPy signature.

    Accepts the JSON schema names "number", "string", "boolean" and "integer",
    plus the "bool" label that `preprocess_json_schema` produces for boolean fields.

    Args:
        field_type: The JSON schema (or preprocessed) type name.

    Returns:
        One of "float", "str", "bool" or "int".

    Raises:
        ValueError: If `field_type` is not a supported primitive type.
    """
    python_types = {
        "number": "float",
        "string": "str",
        "boolean": "bool",
        # `preprocess_json_schema` emits "bool" (not "boolean") for boolean
        # fields; without this alias every boolean field would be rejected.
        "bool": "bool",
        "integer": "int",
    }
    try:
        return python_types[field_type]
    except (KeyError, TypeError):  # TypeError: unhashable input such as a list
        raise ValueError(f"Unsupported field type: {field_type}") from None

In [None]:
def function_config_to_dspy_signature(
    function_name: str, function_config: dict, config_dir: Optional[Path] = None
):
    """
    Build a DSPy signature from a TensorZero function's JSON schemas.

    The user schema's fields become the signature's inputs. If the function has
    an `output_schema` with at least one supported field, those fields become
    the outputs; otherwise a single `output` field is used.

    Args:
    function_name (str): Name given to the generated signature class.
    function_config (dict): The function's table from the TensorZero config.
    config_dir (Optional[Path]): Directory that schema paths are resolved against.
        Defaults to the directory of the loaded config file (`config_path.parent`).

    Returns:
    The signature returned by `dspy.make_signature`.

    Raises:
    AssertionError: If the function defines a system or assistant schema, or has no user schema.
    """
    assert (
        "system_schema" not in function_config
    ), "System schema not supported by DSPy recipe"
    # The TensorZero config key is `assistant_schema` (singular); the plural
    # spelling is rejected as well.
    assert (
        "assistant_schema" not in function_config
        and "assistants_schema" not in function_config
    ), "Assistant schema not supported by DSPy recipe"
    assert "user_schema" in function_config, "User schema not found in function config"

    base_dir = config_path.parent if config_dir is None else Path(config_dir)

    def load_schema(relative_path) -> Dict[str, str]:
        # Schema paths in the config are relative to the config directory.
        with (base_dir / relative_path).open("r") as f:
            return preprocess_json_schema(json.load(f))

    def to_signature_fields(schema: Dict[str, str]) -> str:
        # DSPy string-signature syntax: "name:type,name:type".
        return ",".join(
            f"{field_name}:{jsonschema_type_to_python_type(field_type)}"
            for field_name, field_type in schema.items()
        )

    user_schema = load_schema(function_config["user_schema"])
    output_schema_path = function_config.get("output_schema")
    output_schema = load_schema(output_schema_path) if output_schema_path else None

    input_signature = to_signature_fields(user_schema)
    # Fall back to a single untyped `output` field when there is no output
    # schema, or when it yields no supported fields.
    output_signature = to_signature_fields(output_schema) if output_schema else "output"

    string_signature = f"{input_signature} -> {output_signature}"
    return dspy.make_signature(string_signature, signature_name=function_name)

In [None]:
function_signature = function_config_to_dspy_signature(
    function_name=FUNCTION_NAME,
    function_config=function_config,
)

Retrieve the ClickHouse database name from the config, falling back to `tensorzero` if it is not set.


In [None]:
# Start from the default and override it only if the config sets
# `[clickhouse] database = ...`.
database_name = "tensorzero"
if "clickhouse" in config and "database" in config["clickhouse"]:
    database_name = config["clickhouse"]["database"]

Initialize the ClickHouse client.


In [None]:
assert (
    "CLICKHOUSE_NATIVE_URL" in os.environ
), "CLICKHOUSE_NATIVE_URL environment variable not set"

# `urljoin` only resolves relative references for schemes listed in
# `urllib.parse.uses_relative`. `clickhouse://` is not one of them, so
# `urljoin("clickhouse://host:9000", "tensorzero")` returns just "tensorzero".
# Replace the URL path explicitly instead; credentials, host, port and query
# options from the original URL are preserved.
url_with_database = (
    urlparse(os.environ["CLICKHOUSE_NATIVE_URL"])
    ._replace(path=f"/{database_name}")
    .geturl()
)

clickhouse_client = Client.from_url(url_with_database)

Retrieve the metric configuration.


In [None]:
# Look up the metric's config table, failing with a readable message if it is missing.
assert "metrics" in config, "No `[metrics]` section found in config"
assert METRIC_NAME in config["metrics"], f"No metric named `{METRIC_NAME}` found in config"
metric = config["metrics"][METRIC_NAME]

# Display the metric configuration.
metric

Determine which ClickHouse feedback table stores values for this metric's type.


In [None]:
# Validate before indexing so a malformed metric gives a clear message rather than a KeyError.
assert "type" in metric, "Metric is missing the `type` field"

# Each metric type has its own feedback table.
feedback_table_name = {
    "float": "FloatMetricFeedback",
    "boolean": "BooleanMetricFeedback",
}.get(metric["type"])

if feedback_table_name is None:
    raise ValueError(f"Unsupported metric type: {metric['type']}")

In [None]:
assert "type" in metric, "Metric is missing the `type` field"
assert "optimize" in metric, "Metric is missing the `optimize` field"

threshold = FLOAT_METRIC_THRESHOLD if metric["type"] == "float" else 0.5
comparison_operator = ">=" if metric["optimize"] == "max" else "<="

# Feedback `target_id` refers to an inference for inference-level metrics and
# to an episode for episode-level metrics. A missing `level` is treated as
# "inference" here.
inference_join_key = {"inference": "id", "episode": "episode_id"}.get(
    metric.get("level", "inference")
)
if inference_join_key is None:
    raise ValueError(f"Unsupported metric level: {metric['level']}")

# Identifiers (database/table/column names) and the operator cannot be bound as
# query parameters, so they are interpolated into the SQL text. The table name,
# join key and operator come from fixed lookups; the database name comes from
# the config file, so make sure it is a plain identifier.
assert database_name.isidentifier(), f"Unexpected ClickHouse database name: {database_name!r}"

# NOTE: an inference that received several feedback values for this metric
# appears once per matching value.
query = f"""
SELECT
    i.variant_name,
    i.input,
    i.output,
    f.value
FROM
    {database_name}.Inference i
JOIN
    {database_name}.{feedback_table_name} f ON i.{inference_join_key} = f.target_id
WHERE
    i.function_name = %(function_name)s
    AND f.metric_name = %(metric_name)s
    AND f.value {comparison_operator} %(threshold)s
LIMIT %(max_samples)s
"""

params = {
    "function_name": FUNCTION_NAME,
    "metric_name": METRIC_NAME,
    "threshold": threshold,
    "max_samples": MAX_SAMPLES,
}

df = clickhouse_client.query_dataframe(query, params)

df.head()