In [None]:
from dotenv import load_dotenv

In [None]:
# Pull LANGFUSE_* settings from the .env one directory up (a relative path,
# so it is resolved against the kernel's working directory). The return value
# is bound to `_` so the cell does not echo it.
_ = load_dotenv(dotenv_path="../.env")

In [None]:
import ast
import os
from collections import defaultdict

from langfuse import Langfuse

In [None]:
# Langfuse client configured entirely from environment variables; any variable
# that is unset is passed through as None.
langfuse = Langfuse(
    public_key=os.environ.get("LANGFUSE_PUBLIC_KEY"),
    secret_key=os.environ.get("LANGFUSE_SECRET_KEY"),
    host=os.environ.get("LANGFUSE_HOST"),
)

In [None]:
# Traces under evaluation: this user's runs tagged with the test-set name and date.
trace_filters = {"user_id": "srm", "tags": ["kba-test", "2025-05-20"]}

# fetch_traces returns a single page of results. Walk the remaining pages and
# append them to the first response, so every matching trace ends up in
# `traces.data` for the scoring loop. `traces.meta` still describes the first
# request (page 1) but reports the overall totals.
traces = langfuse.fetch_traces(**trace_filters)
for page_number in range(2, traces.meta.total_pages + 1):
    traces.data.extend(langfuse.fetch_traces(**trace_filters, page=page_number).data)

In [None]:
# Display the metadata of the fetch_traces response (presumably pagination
# info such as page/limit/total counts — confirm against the Langfuse SDK).
traces.meta

In [None]:
def extract_tool_call_details(trace):
    """Collect the tool calls recorded in a trace, keyed by tool name.

    Scans ``trace.output["messages"]``:

    * AI messages (``type == "ai"``) whose ``content`` is a list of chunks are
      searched for ``tool_use`` chunks; each chunk's ``id`` and ``input`` are
      recorded under the chunk's ``name``.
    * Tool messages (``type == "tool"``) whose ``tool_call_id`` matches a
      recorded id contribute their ``content`` as that tool's ``output``.

    If the same tool is called more than once, only its LAST call is kept
    (later calls overwrite earlier ones), and only that call's output is
    attached.

    Args:
        trace: A Langfuse trace. Its ``output`` is expected to be a dict with a
            ``"messages"`` list (schema assumed from usage — confirm). A trace
            whose output is missing or not a dict yields an empty result.

    Returns:
        ``defaultdict(dict)`` mapping tool name to
        ``{"id": ..., "input": ...[, "output": ...]}``. ``"output"`` is absent
        when no matching tool message was found.
    """
    tool_call_details = defaultdict(dict)

    # Traces that errored may have no (or non-dict) output; treat them as
    # having made no tool calls instead of crashing.
    output = trace.output if isinstance(trace.output, dict) else {}
    messages = output.get("messages") or []

    for message in messages:
        if message.get("type") != "ai" or not isinstance(message.get("content"), list):
            continue
        for chunk in message["content"]:
            if isinstance(chunk, dict) and chunk.get("type") == "tool_use":
                tool_call_details[chunk["name"]]["id"] = chunk["id"]
                tool_call_details[chunk["name"]]["input"] = chunk["input"]

    # Index the surviving (latest) call id of each tool, so every tool message
    # is matched with one dict lookup instead of a scan over all tools.
    name_by_call_id = {details["id"]: name for name, details in tool_call_details.items()}

    for message in messages:
        if message.get("type") != "tool":
            continue
        tool_name = name_by_call_id.get(message.get("tool_call_id"))
        if tool_name is not None:
            tool_call_details[tool_name]["output"] = message["content"]

    return tool_call_details

def score_tool_calls(trace):
    """Return 1 if every tool listed in the trace metadata was called, else 0.

    ``trace.metadata["tools"]`` is a comma-separated list of expected tool
    names. Whitespace around each name is ignored and empty entries are
    dropped, so ``"a, b"`` expects ``{"a", "b"}`` and ``""`` expects no tools
    (which scores 1). Tools that were called but not listed do not lower the
    score.

    Raises:
        KeyError: if the trace metadata has no ``"tools"`` entry.
    """
    tools_called = set(extract_tool_call_details(trace))
    expected_tools = {
        name.strip() for name in trace.metadata["tools"].split(",") if name.strip()
    }
    return 0 if expected_tools - tools_called else 1

def score_location_tool(trace):
    """Check the location tool's call against the expected values in the trace metadata.

    Returns True only when ``location-tool`` was called with exactly
    ``trace.metadata["location_tool_input"]`` AND the first element of its
    (parsed) output equals the first element of
    ``trace.metadata["location_tool_output"]``.

    Returns False when the tool was not called, has no recorded output, or
    its output cannot be parsed or has no first element.
    """
    details = extract_tool_call_details(trace).get("location-tool")
    if details is None or "output" not in details:
        return False
    if details["input"] != trace.metadata["location_tool_input"]:
        return False

    parsed_output = details["output"]
    # The output is text produced by a tool/model run, so it is untrusted:
    # parse it as a Python literal only, never execute it with eval().
    if isinstance(parsed_output, str):
        try:
            parsed_output = ast.literal_eval(parsed_output)
        except (ValueError, TypeError, SyntaxError, RecursionError):
            return False

    try:
        first_result = parsed_output[0]
    except (IndexError, KeyError, TypeError):
        return False
    return bool(first_result == trace.metadata["location_tool_output"][0])

In [None]:
# Sanity check: print the name of every tool detected in each trace, followed
# by a "===" separator line per trace.
for trace in traces.data:
    for tool_name in extract_tool_call_details(trace):
        print(tool_name)
    print("===")

In [None]:
# Scorers applied to every fetched trace, in this order; each key is the
# score name recorded in Langfuse.
scorers = {
    "score_tool_calls": score_tool_calls,
    "score_location_tool": score_location_tool,
}

for trace in traces.data:
    for score_name, scorer in scorers.items():
        langfuse.score(
            trace_id=trace.id,
            name=score_name,
            value=scorer(trace),
        )

# Per the Langfuse SDK docs, events are sent asynchronously in batches.
# Flush so every queued score is delivered before the notebook moves on or
# the kernel exits.
langfuse.flush()