In [1]:
# Load IPython's autoreload extension so edited modules can be picked up
# again without restarting the kernel.
%load_ext autoreload

In [2]:
# Mode 2: reload all imported modules before each cell is executed.
%autoreload 2

In [9]:
import datetime
import hashlib
import json 

import wandb
from wandb.sdk.data_types import trace_tree

import weave
from weave.monitoring import StreamTable

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.callbacks.tracers import wandb as LCW


def wb_span_to_weave_spans(wb_span, trace_id=None, parent_id=None):
    """Flatten a W&B trace-tree span and its descendants into Weave spans.

    Args:
        wb_span: Span-like object exposing ``attributes``, ``span_kind``,
            ``results``, ``start_time_ms``, ``end_time_ms``, ``span_id``,
            ``name``, ``status_code``, ``status_message`` and ``child_spans``.
            ``attributes``, ``results``, ``status_message`` and
            ``child_spans`` may be ``None``; both timestamps must be numbers
            (milliseconds since the epoch).
        trace_id: Trace identifier stored on the span built for ``wb_span``.
        parent_id: Identifier of the parent span; ``None`` for a root span.

    Returns:
        A list of ``TraceSpanStruct`` objects in pre-order: the span built
        from ``wb_span`` first, followed by each child subtree in order.
    """
    # Copy so the extra key never leaks into the caller's dict. A span without
    # attributes (None) is treated as empty instead of raising TypeError.
    attributes = dict(wb_span.attributes or {})
    attributes['llm_span_kind'] = str(wb_span.span_kind)

    # Only the first result is carried over; any further results are dropped.
    first_result = wb_span.results[0] if wb_span.results else None
    inputs = None if first_result is None else first_result.inputs
    outputs = None if first_result is None else first_result.outputs

    exception = None
    if wb_span.status_message is not None:
        exception = Exception(wb_span.status_message)

    # fromtimestamp() without a tz argument yields naive local-time datetimes.
    span = weave.monitoring.monitor.TraceSpanStruct(
        start_time=datetime.datetime.fromtimestamp(wb_span.start_time_ms / 1000.),
        end_time=datetime.datetime.fromtimestamp(wb_span.end_time_ms / 1000.),
        id=wb_span.span_id,
        name=wb_span.name,
        status_code=str(wb_span.status_code),
        trace_id=trace_id,
        parent_id=parent_id,
        attributes=attributes,
        inputs=inputs,
        output=outputs,
        exception=exception,
        summary=None,
    )

    spans = [span]
    for child in (wb_span.child_spans or []):
        # Children read the ids back from the constructed struct rather than
        # from the arguments, so whatever the struct stores is propagated.
        spans.extend(wb_span_to_weave_spans(child, span.trace_id, span.id))

    return spans

def _hash_id(s: str) -> str:
    return hashlib.md5(s.encode("utf-8")).hexdigest()[:16]
    
class WeaveTracer(BaseTracer):
    """LangChain tracer that logs each persisted run to a Weave StreamTable.

    A run is converted to a W&B span tree by LangChain's W&B ``RunProcessor``,
    flattened with ``wb_span_to_weave_spans`` and logged one row per span.
    """

    def __init__(self, stream_uri: str, **kwargs) -> None:
        """Create the tracer.

        Args:
            stream_uri: Passed unchanged to ``StreamTable``.
            **kwargs: Forwarded to ``BaseTracer.__init__``.
        """
        super().__init__(**kwargs)
        self.run_processor = LCW.RunProcessor(wandb, trace_tree)
        self._st = StreamTable(stream_uri)

    def _persist_run(self, run: Run) -> None:
        """Convert ``run`` to spans and log every span to the stream table.

        The serialized model description and a short hash of it are attached
        to the root span under the ``_model_dict`` attribute before the span
        tree is flattened.
        """
        root_span = self.run_processor.process_span(run)
        model_dict = self.run_processor.process_model(run)

        # process_span may return None when the run cannot be converted
        # (LangChain's own WandbTracer guards the same way); there is nothing
        # to log in that case, so skip instead of raising AttributeError.
        if root_span is None:
            return

        if root_span.attributes is None:
            root_span.attributes = {}

        model_str = json.dumps(model_dict)
        root_span.attributes["_model_dict"] = {
            'dumps': model_str,
            'hash': _hash_id(model_str)
        }

        for span in wb_span_to_weave_spans(root_span):
            self._st.log(span.asdict())


In [10]:
from langchain.agents import load_tools
from langchain.agents import initialize_agent
from langchain.agents import AgentType
from langchain.llms import OpenAI

# LLM shared by the agent and the math tool; temperature=0 asks for the least
# random completions. NOTE(review): presumably the OpenAI API key is read from
# the environment — confirm before running.
llm = OpenAI(temperature=0)
# Only the "llm-math" tool is loaded, backed by the same LLM.
tools = load_tools(["llm-math"], llm=llm)
# Zero-shot ReAct agent built from the tools and LLM above.
agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)

In [11]:
# Destination stream for the logged spans. NOTE(review): presumably of the form
# "<entity>/<project>/<stream>" — point it at your own entity before running.
STREAM_URI = "timssweeney/monitor_dev_test/stream_1"
tracer = WeaveTracer(STREAM_URI)

questions = [
    "Find the square root of 5.4.",
    "What is 3 divided by 7.34 raised to the power of pi?",
    "What is the sin of 0.47 radians, divided by the cube root of 27?",
]

# Run every question through the agent; the tracer callback logs the spans of
# each run to the stream table.
for question in questions:
    try:
        answer = agent.run(question, callbacks=[tracer])
    except Exception as e:
        # Best effort: report which question failed and keep going so one bad
        # run does not abort the remaining ones.
        print(f"{type(e).__name__} while answering {question!r}: {e}")
    else:
        print(answer)

2.3
0.005720801417544866
0.43737990984599
