In [1]:
import tabulate
from pydantic import BaseModel


def print_quick_table(rows:list[dict | BaseModel]):
    """Print ``rows`` as a grid-style table whose headers are the row keys.

    Args:
        rows: Items to display. Each item is either a plain dict, used as-is,
            or a pydantic ``BaseModel``, converted with ``model_dump()``.

    Returns:
        None. The rendered table is written to stdout.
    """
    normalized = [
        entry.model_dump() if isinstance(entry, BaseModel) else entry
        for entry in rows
    ]
    rendered = tabulate.tabulate(normalized, headers="keys", tablefmt="grid")
    print(rendered)

In [12]:
from weave.trace.refs import ObjectRef, TableRef
from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer
from weave.wandb_interface import wandb_api

# Entity/project that every object in this notebook is created under.
entity = "timssweeney"
project = "on_demand_eval_demo_4"
project_id = f"{entity}/{project}"

# Initialise the W&B API wrapper and fetch its context; the context's
# api_key is used below to authenticate against the trace server.
wandb_api.init()
wandb_context = wandb_api.get_wandb_api_context()

# Trace server on the loopback interface. Use the canonical dotted-quad
# form (127.0.0.1): shorthand such as "127.0.01" only resolves on platforms
# whose resolver accepts legacy inet_aton-style addresses.
server_client = RemoteHTTPTraceServer("http://127.0.0.1:6345")
# Auth tuple is (username, password): empty username, API key as password.
server_client.set_auth(("", wandb_context.api_key))

In [13]:
# Step 1.a: Create Dataset
#
# The dataset uploader & edit UIs facilitate these same low level calls.

# Upload the raw rows as a table; the response carries the table's digest.
country_rows = [
    {"input": country} for country in ("United States", "Canada", "Mexico")
]
table = server_client.table_create(
    {"table": {"project_id": project_id, "rows": country_rows}}
)
table_digest = table.digest
table_ref = TableRef(entity=entity, project=project, _digest=table_digest).uri()

# Create the Dataset object; its "rows" field is the URI of the table above.
dataset_val = {
    "_type": "Dataset",
    "_class_name": "Dataset",
    "_bases": ["Object", "BaseModel"],
    "rows": table_ref,
}
dataset = server_client.obj_create(
    {
        "obj": {
            "project_id": project_id,
            "object_id": "country_abbrev_dataset",
            "val": dataset_val,
        }
    }
)
dataset_digest = dataset.digest
# URI of the dataset version just created; consumed by the evaluation in Step 3.a.
dataset_ref = ObjectRef(
    entity=entity,
    project=project,
    name="country_abbrev_dataset",
    _digest=dataset_digest,
).uri()


In [None]:
# List the objects in the project whose base object class is "Dataset".
dataset_filter = {"base_object_classes": ["Dataset"]}
datasets = server_client.objs_query(
    {"project_id": project_id, "filter": dataset_filter}
)

print_quick_table(datasets.objs)

In [15]:
# Step 2.a: Create Scorer
#
# These low level calls are facilitated by the scorer builder UI

# Single source of truth for the scorer's object id: it is used both when
# creating the object and when building the ref that points back at it, so
# the two can never drift apart.
scorer_object_id = "abbreviation_correctness_scorer"

scorer = server_client.obj_create({
    "obj": {
        "project_id": project_id,
        "object_id": scorer_object_id,
        "builtin_object_class": "LLMJudgeScorer",
        "val": {
            "model": "gpt-4o-mini",
            "system_prompt": "Determine if the output is the correct abbreviation for the given country.",
            # JSON schema for the judge's response: a boolean verdict
            # ("passed") and a free-text explanation ("reason").
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "response",
                    "schema": {
                        "type": "object",
                        "properties": {
                            "passed": {"type": "boolean"},
                            "reason": {"type": "string"},
                        },
                    },
                },
            },
        },
    },
})

scorer_digest = scorer.digest
# The ref's name must equal the object_id the scorer was created with
# (as the dataset, evaluation and model refs in this notebook do);
# otherwise the URI names an object that does not exist.
scorer_ref = ObjectRef(
    entity=entity,
    project=project,
    name=scorer_object_id,
    _digest=scorer_digest,
).uri()


In [None]:
# List the objects in the project whose base object class is "Scorer".
scorer_filter = {"base_object_classes": ["Scorer"]}
scorers = server_client.objs_query(
    {"project_id": project_id, "filter": scorer_filter}
)

print_quick_table(scorers.objs)

In [17]:
# Step 3.a: Create Evaluation
#
# The evaluation builder UI facilitates these same low level calls.

# The evaluation ties together the dataset and scorer created in the earlier
# steps by storing their ref URIs.
evaluation_val = {
    "_type": "Evaluation",
    "_class_name": "Evaluation",
    "_bases": ["Object", "BaseModel"],
    "dataset": dataset_ref,
    "scorers": [scorer_ref],
}
evaluation = server_client.obj_create(
    {
        "obj": {
            "project_id": project_id,
            "object_id": "country_abbrev_evaluation",
            "val": evaluation_val,
        }
    }
)
evaluation_digest = evaluation.digest
# URI of the evaluation version just created; passed to evaluate_stream later.
evaluation_ref = ObjectRef(
    entity=entity,
    project=project,
    name="country_abbrev_evaluation",
    _digest=evaluation_digest,
).uri()

In [None]:
# List the objects in the project whose base object class is "Evaluation".
evaluation_filter = {"base_object_classes": ["Evaluation"]}
evaluations = server_client.objs_query(
    {"project_id": project_id, "filter": evaluation_filter}
)

print_quick_table(evaluations.objs)

In [19]:
# Step 4.a: Create a model
#
# The model builder UI facilitates these same low level calls.

# Chat template: a fixed system instruction followed by a user message whose
# content is the literal placeholder "{input}".
model_messages = [
    {
        "role": "system",
        "content": "Determine the abbreviation for the given country.",
    },
    {"role": "user", "content": "{input}"},
]
model_val = {"model": "gpt-4o-mini", "messages_template": model_messages}

model = server_client.obj_create(
    {
        "obj": {
            "project_id": project_id,
            "object_id": "correctness_model",
            "builtin_object_class": "LiteLLMCompletionModel",
            "val": model_val,
        }
    }
)

model_digest = model.digest
# URI of the model version just created; passed to evaluate_stream later.
model_ref = ObjectRef(
    entity=entity,
    project=project,
    name="correctness_model",
    _digest=model_digest,
).uri()

In [None]:
# Step 5.a: Run the evaluation:

results = server_client.evaluate_stream({
    "project_id": project_id,
    "evaluation_ref": evaluation_ref,
    "model_ref": model_ref,
})

async for result in await results:
    print(result)

In [None]:
# Query the calls matching the Evaluation.evaluate op (any version, per the
# ":*" suffix) that took this evaluation's ref as an input, restricted to
# trace roots.
evaluate_op_name = f"weave:///{project_id}/op/Evaluation.evaluate:*"
evaluate_call_filter = {
    "op_names": [evaluate_op_name],
    "input_refs": [evaluation_ref],
    "trace_roots_only": True,
}
evaluate_calls = server_client.calls_query_stream(
    {"project_id": project_id, "filter": evaluate_call_filter}
)

# The stream is materialised into a list before being tabulated.
print_quick_table(list(evaluate_calls))

In [None]:
import weave
import os

# NOTE(review): WF_TRACE_SERVER_URL is presumably read by weave to locate the
# trace server - confirm. The address uses the canonical loopback form
# (127.0.0.1) rather than the "127.0.01" shorthand, which not every resolver
# accepts.
os.environ["WF_TRACE_SERVER_URL"] = "http://127.0.0.1:6345"

# Load the model published in Step 4.a through its ref URI (`model_ref`)
# instead of a hard-coded URI, so this cell always targets the project and
# digest produced by this notebook run. Requires the Step 4.a cell to have
# been executed first.
correctness_model = weave.ref(model_ref).get()
correctness_model.predict(input="United States")