In [None]:
import logging
import sys

def configure_trulens_logging(
    level=logging.DEBUG,
    stream=None,
    handler_name="trulens-notebook-stdout",
):
    """Install exactly one stream handler for ``trulens`` loggers on the root logger.

    Notebook cells get re-executed; blindly calling ``addHandler`` each time
    stacks handlers and prints every log line once per execution. The handler
    is therefore tagged with ``handler_name`` and any previously installed
    handler carrying the same name is removed before the new one is attached.

    Args:
        level: Level applied to both the root logger and the handler.
        stream: Output stream; defaults to the current ``sys.stdout``.
        handler_name: Name used to recognise handlers installed by this helper.

    Returns:
        The newly installed ``logging.StreamHandler``.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(level)

    # Remove what an earlier run of this cell installed (iterate over a copy,
    # because removeHandler mutates root_logger.handlers).
    for existing in list(root_logger.handlers):
        if existing.get_name() == handler_name:
            root_logger.removeHandler(existing)

    new_handler = logging.StreamHandler(sys.stdout if stream is None else stream)
    new_handler.set_name(handler_name)
    new_handler.setLevel(level)
    # Only records from the "trulens" logger hierarchy pass this handler.
    new_handler.addFilter(logging.Filter("trulens"))
    new_handler.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
    )
    root_logger.addHandler(new_handler)
    return new_handler


# Module-level names kept for anything that still refers to them.
root = logging.getLogger()
handler = configure_trulens_logging()
formatter = handler.formatter

In [None]:
# Create snowpark session.
import json

from snowflake.snowpark import Session
from trulens.connectors.snowflake import SnowflakeConnector

# Secrets must not be stored in notebook source (it gets committed and shared).
# The password is taken from the SNOWFLAKE_PASSWORD environment variable, with
# an interactive prompt as a fallback when the variable is unset or empty.
import getpass
import os

snowflake_connection_parameters = {
    "account": "cortexsearch.qa6.us-west-2.aws",
    "user": "ADMIN",
    "database": "SALES_INTELLIGENCE",
    "schema": "AGENTS",
    "password": os.environ.get("SNOWFLAKE_PASSWORD")
    or getpass.getpass("Snowflake password for ADMIN: "),
    # NOTE(review): ACCOUNTADMIN looks broader than read-only trace queries
    # need — confirm whether a narrower role works.
    "role": "ACCOUNTADMIN",
    "warehouse": "ADMIN_WH",
}

# Session used for the observability queries and handed to the TruLens connector.
snowpark_session = Session.builder.configs(
    snowflake_connection_parameters
).create()

sf_connector = SnowflakeConnector(snowpark_session=snowpark_session)

In [None]:
# Application/agent name: bound as the `?` parameter of the
# GET_AI_OBSERVABILITY_EVENTS queries and used as the TruApp app_name.
APP_NAME = "SALES_INTELLIGENCE_AGENT"

In [None]:
# Unfiltered dump of every observability event for APP_NAME, materialised as a
# pandas DataFrame. The app name is passed as a bind parameter, not interpolated.
_ALL_EVENTS_SQL = """
    SELECT
        *
    FROM TABLE(SNOWFLAKE.LOCAL.GET_AI_OBSERVABILITY_EVENTS(
        'SNOWFLAKE_INTELLIGENCE', 
        'AGENTS', 
        ?, 
        'CORTEX AGENT'
    ))
"""
_all_events = snowpark_session.sql(_ALL_EVENTS_SQL, params=[APP_NAME])
all_df = _all_events.to_pandas()

In [None]:
# Record ids whose traces are fetched and assembled by get_agent_traces below.
RECORD_IDS = ["c389923a-9ac1-4127-93f4-5f45ebae2fe3"]

In [None]:
from typing import Dict, List

import pandas as pd


def _retrieve_trace_for_record_id(
    snowpark_session,
    app_name=APP_NAME,
    thread_id=None,
    record_id=None,
):
    """Query SPAN events for an app and return them as a pandas DataFrame.

    Args:
        snowpark_session: Object exposing ``sql(query, params=...)`` whose
            result has ``to_pandas()``.
        app_name: Bound to the ``?`` placeholder inside the
            GET_AI_OBSERVABILITY_EVENTS table function call.
        thread_id: When truthy, restricts rows to this thread id (bound as
            ``str(thread_id)``).
        record_id: When truthy, restricts rows to this record id (bound as-is).

    Returns:
        The matching rows ordered by ``START_TIMESTAMP`` ascending.
    """
    sql_text = """
        SELECT
            *
        FROM TABLE(SNOWFLAKE.LOCAL.GET_AI_OBSERVABILITY_EVENTS(
            'SNOWFLAKE_INTELLIGENCE', 
            'AGENTS', 
            ?, 
            'CORTEX AGENT'
        ))
        WHERE RECORD_TYPE = 'SPAN'
    """
    bind_values = [app_name]

    # Each optional filter extends the existing WHERE clause and contributes
    # one more positional bind value, in the same order as its placeholder.
    if thread_id:
        sql_text += (
            " AND "
            'RECORD_ATTRIBUTES:"snow.ai.observability.agent.thread_id" = ?'
        )
        bind_values.append(str(thread_id))

    if record_id:
        # request id is actually record id in otel
        sql_text += (
            " AND "
            'RECORD_ATTRIBUTES:"ai.observability.record_id" = ?'
        )
        bind_values.append(record_id)

    # Ordering by start time so the spans can be assembled into a trace.
    sql_text += " ORDER BY START_TIMESTAMP ASC"

    result = snowpark_session.sql(sql_text, params=bind_values)
    return result.to_pandas()


def build_span_hierarchy_with_content(trace_df: pd.DataFrame):
    """Print the span tree of a trace and return the parsed spans.

    Each row of ``trace_df`` must carry JSON text in the ``TRACE``, ``RECORD``
    and ``RECORD_ATTRIBUTES`` columns plus a ``START_TIMESTAMP`` value. The
    span id is read from ``TRACE["span_id"]``; the parent id and span name are
    read from ``RECORD["parent_span_id"]`` and ``RECORD["name"]``.

    A span is treated as a root when it has no parent id or when its parent is
    not part of ``trace_df``. Roots and siblings are printed in ascending
    ``START_TIMESTAMP`` order, children indented two spaces per level, and
    every non-empty attribute is listed under its span.

    Args:
        trace_df: Span rows, e.g. the result of ``_retrieve_trace_for_record_id``.

    Returns:
        Dict mapping span id to a dict with keys ``name``, ``parent_id``,
        ``start_time``, ``record``, ``attributes`` and ``full_content``
        (``full_content`` is the same object as ``attributes``).
    """
    spans = {}

    for _, row in trace_df.iterrows():
        trace_info = json.loads(row["TRACE"])
        record_info = json.loads(row["RECORD"])
        record_attrs = json.loads(row["RECORD_ATTRIBUTES"])

        spans[trace_info.get("span_id")] = {
            "name": record_info.get("name", "unknown"),
            "parent_id": record_info.get("parent_span_id"),
            "start_time": row["START_TIMESTAMP"],
            "record": record_info,
            "attributes": record_attrs,
            "full_content": record_attrs,  # Store all attributes for concatenation
        }

    # Classify every span once: either a root or a child of a known parent.
    # Indexing children up front avoids rescanning all spans for each node,
    # and keeps a parentless span whose own id is None from being matched as
    # its own child (None == None), which would recurse without bound.
    root_spans = []
    children_by_parent = {}
    for span_id, span in spans.items():
        parent_id = span["parent_id"]
        if parent_id is None or parent_id not in spans:
            root_spans.append(span_id)
        else:
            children_by_parent.setdefault(parent_id, []).append(span_id)

    def _start_time(span_id):
        return spans[span_id]["start_time"]

    def _print_hierarchy_with_content(span_id, level=0):
        span = spans[span_id]
        indent = "  " * level

        lines = [f"{indent}{span['name']} ({span_id}) - {span['start_time']}"]
        for key, value in span["attributes"].items():
            if value is not None and str(value).strip():  # Skip empty values
                lines.append(f"{indent}  {key}: {value}")
        print("\n".join(lines))

        for child_id in sorted(
            children_by_parent.get(span_id, ()), key=_start_time
        ):
            _print_hierarchy_with_content(child_id, level + 1)

    print("\nSpan Hierarchy with Full Content:")
    print("=" * 80)
    for root_id in sorted(root_spans, key=_start_time):
        _print_hierarchy_with_content(root_id)
        print("-" * 80)

    return spans


def _get_concatenated_trace_content(trace_df):
    """Get all record_attributes content concatenated in chronological order"""
    all_content = []

    for idx, row in trace_df.iterrows():
        record_attrs = json.loads(row["RECORD_ATTRIBUTES"])
        record_info = json.loads(row["RECORD"])

        span_name = record_info.get("name", "unknown")
        timestamp = row["START_TIMESTAMP"]

        content_block = f"\n[{timestamp}] {span_name}:\n"

        for key, value in record_attrs.items():
            if value is not None and str(value).strip():
                content_block += f"  {key}: {value}\n"

        all_content.append(content_block)

    return "\n".join(all_content)


def get_agent_traces(snowpark_session, record_ids: List[str]) -> Dict[str, str]:
    """Fetch and flatten the trace of every requested record.

    For each id, the spans are retrieved with ``_retrieve_trace_for_record_id``
    (default app name, no thread filter) and rendered to text with
    ``_get_concatenated_trace_content``.

    Args:
        snowpark_session: Session forwarded to the retrieval helper.
        record_ids: Record ids to look up, processed in the given order.

    Returns:
        Mapping of record id to its concatenated trace text.
    """
    return {
        rid: _get_concatenated_trace_content(
            _retrieve_trace_for_record_id(snowpark_session, record_id=rid)
        )
        for rid in record_ids
    }


# Build the {record id: concatenated trace text} mapping for every id in RECORD_IDS.
record_id_to_trace = get_agent_traces(snowpark_session, RECORD_IDS)

# # Usage:
# trace_df = _retrieve_trace_for_record_id(snowpark_session, record_id='ae105c3c-1735-4696-8a71-333f34eb01d0')

# # Hierarchical view with all content
# span_hierarchy = build_span_hierarchy_with_content(trace_df)

# # Linear concatenated view
# assembled_trace = _get_concatenated_trace_content(trace_df)
# print("\nConcatenated Trace Content:")
# print("=" * 80)
# print(assembled_trace) # THIS IS THE TRACE TO BE USED IN THE EVALUATION PROMPT

In [None]:
# Display the assembled trace text of the first configured record. Indexing
# through RECORD_IDS keeps this cell working when the id list is edited,
# instead of failing with KeyError on a stale hard-coded id.
record_id_to_trace[RECORD_IDS[0]]

In [None]:
import uuid

from trulens.apps.app import TruApp
from trulens.core.run import RunConfig

APP_VERSION = "V1"

# TruApp handle for APP_NAME / APP_VERSION, backed by the Snowflake connector.
tru_app = TruApp(
    app=None,  # No app object needed for virtual runs
    app_name=APP_NAME,
    app_version=APP_VERSION,
    connector=sf_connector,
)

# Create run config with dataset specification
# The uuid4 suffix makes run_name different on every execution of this cell.
run_name = f"DANIEL_GT_TEST_RUN_{uuid.uuid4()}"
# NOTE(review): run_config is only consumed by the commented-out add_run call
# below, and its dataset_spec is empty because both mappings are commented out.
run_config = RunConfig(
    run_name=run_name,
    dataset_name="AGENT_TRACES",
    source_type="TABLE",
    dataset_spec={
        # "record_root.input": "AGENT_TRACE",
        # "input_id": "RECORD_ID",
    },
)

# NOTE(review): with add_run disabled, get_run is asked for the name generated
# just above, which this cell never registers. Presumably that lookup cannot
# succeed on a fresh kernel — confirm, and either re-enable add_run or pin
# run_name to an existing run.
# virtual_run = tru_app.add_run(run_config=run_config)
virtual_run = tru_app.get_run(run_name=run_name)

# NOTE(review): the message says "Created" although the active code path only
# looks the run up.
print(f"Created virtual run: {run_name}")

In [None]:
# Start the virtual run - this will create OTEL spans from existing data
# virtual_run.start(virtual=True)

In [None]:
# Separate Snowpark session, later handed to the Cortex provider that scores
# the traces. No password is configured; the "externalbrowser" authenticator
# is used instead (presumably a browser-based SSO login — confirm).
tru_snowflake_connection_parameters = dict(
    account="SNOWHOUSE",
    user="dhuang",
    database="SNOWFLAKE_INTELLIGENCE",
    schema="AGENTS",
    authenticator="externalbrowser",
)
_tru_session_builder = Session.builder.configs(
    tru_snowflake_connection_parameters
)
tru_snowpark_session = _tru_session_builder.create()

In [None]:
from trulens.core.feedback.custom_metric import MetricConfig
from trulens.core.feedback.selector import Selector
from trulens.otel.semconv.trace import SpanAttributes
from trulens.providers.cortex import Cortex

# Cortex-backed provider whose *_with_cot_reasons methods implement the metrics.
trace_eval_provider = Cortex(
    model_engine="claude-4-sonnet", snowpark_session=tru_snowpark_session
)


def _record_root_trace_metric(metric_name, metric_implementation):
    """Build a MetricConfig that feeds the record-root input to ``trace``.

    Every metric in this cell takes a single ``trace`` argument selected from
    the RECORD_ROOT span's INPUT attribute; only the display name and the
    provider method differ. A new Selector is created on each call.
    """
    trace_selector = Selector(  # Parameter name in the function is "trace"
        span_type=SpanAttributes.SpanType.RECORD_ROOT,
        span_attribute=SpanAttributes.RECORD_ROOT.INPUT,
    )
    return MetricConfig(
        metric_implementation=metric_implementation,
        metric_name=metric_name,
        selectors={"trace": trace_selector},
    )


f_logical_consistency = _record_root_trace_metric(
    "Logical Consistency",
    trace_eval_provider.logical_consistency_with_cot_reasons,
)

f_execution_efficiency = _record_root_trace_metric(
    "Execution Efficiency",
    trace_eval_provider.execution_efficiency_with_cot_reasons,
)

f_tool_calling = _record_root_trace_metric(
    "Tool Calling",
    trace_eval_provider.tool_calling_with_cot_reasons,
)

f_tool_selection = _record_root_trace_metric(
    "Tool Selection",
    trace_eval_provider.tool_selection_with_cot_reasons,
)

# Only the metrics listed here are computed; the others stay defined so they
# can be switched back on by uncommenting.
metrics_to_compute = [
    # f_execution_efficiency,
    # f_logical_consistency,
    # f_tool_calling,
    f_tool_selection,
]

In [None]:
import time

# Tunables for the polling loop below.
RUN_POLL_INTERVAL_S = 3
RUN_POLL_TIMEOUT_S = 30 * 60


def wait_for_run_status(
    run,
    target_status="INVOCATION_COMPLETED",
    poll_interval_s=RUN_POLL_INTERVAL_S,
    timeout_s=RUN_POLL_TIMEOUT_S,
):
    """Poll ``run.get_status()`` until it equals ``target_status``.

    An unbounded ``while`` loop hangs the notebook forever when the run never
    reaches the target status (for example because it was never started or
    ended in some other state), so polling stops after ``timeout_s`` seconds.

    Args:
        run: Object exposing ``get_status()``.
        target_status: Status value to wait for.
        poll_interval_s: Seconds to sleep between polls.
        timeout_s: Maximum number of seconds to keep polling.

    Returns:
        The status value that matched ``target_status``.

    Raises:
        TimeoutError: If the target status is not observed within ``timeout_s``.
    """
    # monotonic() is unaffected by wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout_s
    while True:
        status = run.get_status()
        if status == target_status:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Run did not reach {target_status!r} within {timeout_s}s; "
                f"last status was {status!r}"
            )
        time.sleep(poll_interval_s)


wait_for_run_status(virtual_run)

virtual_run.compute_metrics(metrics_to_compute)

In [None]:
# Show the run's current status as this cell's output.
virtual_run.get_status()

In [None]:
# Keep the run's record details for export in the next cell; get_records()
# is left below as a commented-out alternative.
gpa_evals_df = virtual_run.get_record_details()
# gpa_evals_df = virtual_run.get_records()

In [None]:
# Write the evaluation results to a JSON file; the path is relative, so the
# file lands in the kernel's current working directory.
gpa_evals_df.to_json("build_demo_gpa_evals.json")