In [1]:
import weave
from weave.trace.refs import ObjectRef
from weave.trace_server.trace_server_interface import (
    EvaluateModelReq,
    EvaluateModelRes,
    ObjCreateReq,
    TableCreateReq,
)

In [2]:
import os

# Point the SDK at a locally running trace server and the test W&B deployment.
# They are set before ``weave.init`` — NOTE(review): assumes weave reads these
# variables at init time rather than at import time; confirm.
os.environ.update(
    {
        "WF_TRACE_SERVER_URL": "http://localhost:6345",
        "WANDB_BASE_URL": "https://api.wandb.test",
    }
)

weave.init("wandb/eval_test_project_2")


# Smoke test: one traced call confirms the client can log to the project.
@weave.op()
def simple_test():
    return 1


simple_test()

[36m[1mweave[0m: retry_attempt
[36m[1mweave[0m: retry_attempt
[36m[1mweave[0m: retry_failed
[36m[1mweave[0m: weave version 0.51.55-dev0 has been retired!  Please upgrade.
[36m[1mweave[0m: Logged in as Weights & Biases user: timssweeney.
[36m[1mweave[0m: View Weave data at https://app.wandb.test/wandb/eval_test_project_2/weave


1

[36m[1mweave[0m: 🍩 https://app.wandb.test/wandb/eval_test_project_2/r/call/0197f16c-92da-7a20-814f-113e2b49c596


In [3]:
from weave.trace.context.weave_client_context import require_weave_client


def create_model(
    model_object_id: str = "test_model_for_eval",
    llm_model_id: str = "gpt-4o-mini",
    system_prompt: str = "You are a helpful assistant",
) -> str:
    """Create a test LLM model object and return its reference URI.

    The model is published with ``builtin_object_class`` set to
    ``LLMStructuredCompletionModel`` in the project of the current Weave
    client. Calling with no arguments publishes exactly the same payload as
    before, so the resulting object digest is unchanged.

    Args:
        model_object_id: Object name to publish the model under.
        llm_model_id: Identifier stored in the model's ``llm_model_id`` field.
        system_prompt: Content of the single system message in the model's
            ``messages_template``.

    Returns:
        The ``weave:///...`` URI of the created object version.
    """
    client = require_weave_client()
    project_id = f"{client.entity}/{client.project}"
    llm_model_val = {
        "llm_model_id": llm_model_id,
        "default_params": {
            "messages_template": [
                {
                    "role": "system",
                    "content": system_prompt,
                },
            ],
            "response_format": "text",
        },
    }
    model_create_res = client.server.obj_create(
        ObjCreateReq.model_validate(
            {
                "obj": {
                    "project_id": project_id,
                    "object_id": model_object_id,
                    "val": llm_model_val,
                    "builtin_object_class": "LLMStructuredCompletionModel",
                }
            }
        )
    )
    # The server response only carries the digest; pair it with the object
    # name to build a fully-qualified ref URI.
    return ObjectRef(
        entity=client.entity,
        project=client.project,
        name=model_object_id,
        _digest=model_create_res.digest,
    ).uri()

In [4]:
# Publish the model under test; its URI is passed to the evaluation run later.
model_ref_uri: str = create_model()
print(model_ref_uri)

weave:///wandb/eval_test_project_2/object/test_model_for_eval:opsX9CNP3gzaPq6jU2aeXb3R35h5S9zEY2BT4X7hGXk


In [5]:
def create_dataset() -> str:
    """Publish a small three-row dataset and return its object ref URI.

    Two server writes happen. The rows are first stored as a table. A
    ``Dataset`` object named ``test_eval_dataset`` is then created, and its
    ``rows`` field points at that table by URI.
    """
    client = require_weave_client()
    project_id = f"{client.entity}/{client.project}"

    # Step 1: store the raw rows as a table; only its digest is needed later.
    rows = [
        {"user_input": "How are you?", "expected": "I'm doing well, thank you!"},
        {"user_input": "What's 2+2?", "expected": "4"},
        {
            "user_input": "Tell me a joke",
            "expected": "Why did the chicken cross the road?",
        },
    ]
    table_payload = {"table": {"project_id": project_id, "rows": rows}}
    table_digest = client.server.table_create(
        TableCreateReq.model_validate(table_payload)
    ).digest

    # Step 2: wrap the table in a Dataset object that references it.
    object_id = "test_eval_dataset"
    obj_payload = {
        "obj": {
            "project_id": project_id,
            "object_id": object_id,
            "val": {
                "_type": "Dataset",
                "_class_name": "Dataset",
                "_bases": ["Dataset", "Object", "BaseModel"],
                "rows": f"weave:///{project_id}/table/{table_digest}",
            },
        }
    }
    obj_digest = client.server.obj_create(
        ObjCreateReq.model_validate(obj_payload)
    ).digest

    ref = ObjectRef(
        entity=client.entity,
        project=client.project,
        name=object_id,
        _digest=obj_digest,
    )
    return ref.uri()

In [6]:
# Publish the dataset; the evaluation object created below references this URI.
dataset_ref_uri: str = create_dataset()
print(dataset_ref_uri)

weave:///wandb/eval_test_project_2/object/test_eval_dataset:X7H8wj2rDp5q93XJbKqXqX0rYugBW5T0MxjwU0i8U3E


In [7]:
def _publish_scorer_model(client, project_id: str) -> str:
    """Publish the judge LLM used by the scorer and return its ref URI."""
    object_id = "test_eval_scorer_model"
    payload = {
        "obj": {
            "project_id": project_id,
            "object_id": object_id,
            "val": {
                "llm_model_id": "gpt-4o-mini",
                "default_params": {
                    "messages_template": [
                        {
                            "role": "system",
                            "content": "You are an expert judge. Compare the model output to the expected output and return a score from 0 to 1. Please return the score in a JSON object with the key 'score'.",
                        },
                    ],
                    # Matches the system prompt's request for a JSON object reply.
                    "response_format": "json_object",
                },
            },
            "builtin_object_class": "LLMStructuredCompletionModel",
        }
    }
    digest = client.server.obj_create(ObjCreateReq.model_validate(payload)).digest
    return ObjectRef(
        entity=client.entity,
        project=client.project,
        name=object_id,
        _digest=digest,
    ).uri()


def create_scorer() -> str:
    """Create an LLM-as-a-judge scorer and return its reference URI.

    The scorer's judge model is published first (as
    ``test_eval_scorer_model``). The ``LLMAsAJudgeScorer`` object
    ``test_eval_llm_judge_scorer`` is then published and references that
    model by URI.
    """
    client = require_weave_client()
    project_id = f"{client.entity}/{client.project}"

    judge_model_ref = _publish_scorer_model(client, project_id)

    scorer_name = "test_eval_llm_judge_scorer"
    scorer_payload = {
        "obj": {
            "project_id": project_id,
            "object_id": scorer_name,
            "val": {
                "_type": "LLMAsAJudgeScorer",
                "_class_name": "LLMAsAJudgeScorer",
                "_bases": ["LLMAsAJudgeScorer", "Scorer", "Object", "BaseModel"],
                "model": judge_model_ref,
                # Placeholders presumably filled from dataset columns
                # (user_input, expected) and the model output — confirm.
                "scoring_prompt": "User input: {user_input}\nModel output: {output}\nExpected output: {expected}\n\nScore the similarity (0-1).",
            },
        }
    }
    scorer_digest = client.server.obj_create(
        ObjCreateReq.model_validate(scorer_payload)
    ).digest
    return ObjectRef(
        entity=client.entity,
        project=client.project,
        name=scorer_name,
        _digest=scorer_digest,
    ).uri()

In [8]:
# Publish the judge scorer; the evaluation object lists it in ``scorers``.
scorer_ref_uri: str = create_scorer()
print(scorer_ref_uri)

weave:///wandb/eval_test_project_2/object/test_eval_llm_judge_scorer:n1f572k0CDURbHAzYyhS0TYWB3YFUYL2PMOSUeKVXKs


In [9]:
def create_evaluation(dataset_ref: str, scorer_ref: str) -> str:
    """Publish an ``Evaluation`` object tying a dataset to one scorer.

    Args:
        dataset_ref: Ref URI of the dataset to evaluate on.
        scorer_ref: Ref URI of the single scorer to apply.

    Returns:
        The ref URI of the created ``test_evaluation`` object version.
    """
    client = require_weave_client()
    evaluation_name = "test_evaluation"
    request = ObjCreateReq.model_validate(
        {
            "obj": {
                "project_id": f"{client.entity}/{client.project}",
                "object_id": evaluation_name,
                "val": {
                    "_type": "Evaluation",
                    "_class_name": "Evaluation",
                    "_bases": ["Evaluation", "Object", "BaseModel"],
                    "dataset": dataset_ref,
                    "scorers": [scorer_ref],
                    # TODO(review): more fields may be needed depending on
                    # the Evaluation class structure.
                },
            }
        }
    )
    created = client.server.obj_create(request)
    return ObjectRef(
        entity=client.entity,
        project=client.project,
        name=evaluation_name,
        _digest=created.digest,
    ).uri()

In [10]:
# Combine the dataset and scorer published above into one Evaluation object.
evaluation_ref_uri: str = create_evaluation(dataset_ref_uri, scorer_ref_uri)
print(evaluation_ref_uri)

weave:///wandb/eval_test_project_2/object/test_evaluation:gObJe2APNUgjgX2JTvrHDPOmaVuYUhzCXys6MYCas90


In [11]:
def evaluate_model_harness(
    evaluation_ref: str,
    model_ref: str,
) -> EvaluateModelRes:
    """Ask the trace server to run an evaluation against a model.

    Args:
        evaluation_ref: Ref URI of a published ``Evaluation`` object.
        model_ref: Ref URI of the model to evaluate.

    Returns:
        The server's ``EvaluateModelRes``; in this notebook's output it
        carries the ``call_id`` of the evaluation run.
    """
    client = require_weave_client()
    request_body = {
        "project_id": f"{client.entity}/{client.project}",
        "evaluation_ref": evaluation_ref,
        "model_ref": model_ref,
    }
    return client.server.evaluate_model(EvaluateModelReq.model_validate(request_body))

In [15]:
# Kick off the evaluation server-side; the response identifies the run's call.
res: EvaluateModelRes = evaluate_model_harness(evaluation_ref_uri, model_ref_uri)
print(res)

call_id='0197f173-3f50-7b58-814e-7b4850829f53'
