In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import datetime
import hashlib
import json 

import wandb
from wandb.sdk.data_types import trace_tree

import weave
from weave.monitoring import StreamTable

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.callbacks.tracers import wandb as LCW

from uuid import uuid4

def wb_span_to_weave_spans(wb_span, trace_id=None, parent_id=None):
    """Flatten a W&B trace-tree span and all its descendants into Weave span dicts.

    ``wb_span`` is presumably a ``wandb.sdk.data_types.trace_tree.Span`` as
    produced by LangChain's ``RunProcessor.process_span`` (TODO confirm). Only
    the attributes read below are required: ``attributes``, ``span_kind``,
    ``results``, ``start_time_ms``, ``end_time_ms``, ``span_id``, ``name``,
    ``status_code``, ``status_message`` and ``child_spans``.

    Args:
        wb_span: The span to convert. ``attributes``, ``results`` and
            ``child_spans`` may be ``None``.
        trace_id: Trace id stamped on every emitted span. When falsy, a fresh
            UUID4 string is generated for this span and propagated to all of
            its descendants.
        parent_id: ``span_id`` of the parent span, or ``None`` for the root.

    Returns:
        A list of ``TraceSpanDict`` in pre-order: each span precedes its
        children, and children keep their original order.
    """
    # Copy, so the extra key does not leak into the caller's span.
    # A span may carry no attributes at all (None).
    attributes = {**(wb_span.attributes or {})}
    attributes['llm_span_kind'] = str(wb_span.span_kind)

    # Only the first Result's inputs/outputs are exported; any others are dropped.
    first_result = wb_span.results[0] if wb_span.results else None
    inputs = first_result.inputs if first_result is not None else None
    outputs = first_result.outputs if first_result is not None else None

    # W&B spans store milliseconds; Weave spans use seconds.
    start_time_s = wb_span.start_time_ms / 1000.
    end_time_s = wb_span.end_time_ms / 1000.

    span = weave.stream_data_interfaces.TraceSpanDict(
        start_time_s=start_time_s,
        end_time_s=end_time_s,
        span_id=wb_span.span_id,
        name=wb_span.name,
        status_code=str(wb_span.status_code),
        trace_id=trace_id or str(uuid4()),
        parent_id=parent_id,
        attributes=attributes,
        inputs=inputs,
        output=outputs,
        # Surface the status message as an exception object only when one is set.
        exception=Exception(wb_span.status_message) if wb_span.status_message is not None else None,
        summary={
            # Hack for now - need to make this not required (both perf and summary)
            "latency": end_time_s - start_time_s,
        },
    )

    spans = [span]
    for child in wb_span.child_spans or []:
        # Children share this span's trace id and point back at it as parent.
        spans.extend(wb_span_to_weave_spans(child, span['trace_id'], span['span_id']))
    return spans

def _hash_id(s: str) -> str:
    """Return a short deterministic id for ``s``.

    The id is the first 16 hex characters of the MD5 digest of ``s`` encoded
    as UTF-8. It is used as an identifier only (e.g. for the serialized model
    dict in ``WeaveTracer``), not for any security purpose.
    """
    hasher = hashlib.md5()
    hasher.update(s.encode("utf-8"))
    return hasher.hexdigest()[0:16]
    
class WeaveTracer(BaseTracer):
    """LangChain tracer that logs finished runs to a Weave StreamTable.

    Each persisted run is converted to a W&B trace-tree span by LangChain's
    W&B ``RunProcessor``, annotated with a fingerprint of the model config,
    flattened with ``wb_span_to_weave_spans`` and logged one row per span.
    """

    def __init__(self, stream_uri: str, **kwargs) -> None:
        """Create the tracer.

        Args:
            stream_uri: StreamTable identifier passed straight to
                ``StreamTable`` (presumably "entity/project/table").
            **kwargs: Forwarded unchanged to ``BaseTracer``.
        """
        super().__init__(**kwargs)
        # The processor is handed the wandb module and trace_tree module it builds spans with.
        self.run_processor = LCW.RunProcessor(wandb, trace_tree)
        self._st = StreamTable(stream_uri)

    def _persist_run(self, run: Run) -> None:
        """Convert ``run`` to Weave spans and log each one to the StreamTable.

        This is the ``BaseTracer`` persistence hook (per LangChain, invoked for
        completed top-level runs).
        """
        root_span = self.run_processor.process_span(run)
        model_dict = self.run_processor.process_model(run)
        if root_span is None:
            # process_span is typed Optional[Span] in LangChain's W&B tracer and
            # returns None when conversion fails; nothing to log in that case.
            return
        model_str = json.dumps(model_dict)
        if root_span.attributes is None:
            root_span.attributes = {}
        # Attach the serialized model plus a short content hash usable as an id.
        root_span.attributes["model"] = {
            'id': _hash_id(model_str),
            'obj': model_str,
        }
        for span in wb_span_to_weave_spans(root_span):
            self._st.log(span)


In [None]:
# Zero-shot ReAct agent whose only tool is LangChain's "llm-math" tool
# (which itself needs the LLM). temperature=0 minimises sampling randomness.
# NOTE(review): OpenAI() presumably reads OPENAI_API_KEY from the environment.
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
tools = load_tools(["llm-math"], llm=llm)
agent = initialize_agent(
    tools,
    llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)

In [None]:
# Destination StreamTable for the traces (presumably "entity/project/table").
# NOTE(review): this points at a personal dev project; change it before reusing.
STREAM_URI = "timssweeney/monitor_dev_test/stream_11"
tracer = WeaveTracer(STREAM_URI)

In [None]:
# One-off generation step, kept for provenance: it asked the LLM for a JSON
# list of 10 math questions. Its (non-deterministic) result is hard-coded in
# the next cell, so this cell stays commented out and re-runs skip the API call.
# from langchain import PromptTemplate, LLMChain
# template = """Answer the following question: {question}."""

# prompt = PromptTemplate(template=template, input_variables=["question"])
# llm_chain = LLMChain(prompt=prompt, llm=llm)
# answer = llm_chain.run(
#     "Produce a JSON document containing a list of 10 questions to ask an AI assistant that specializes in math problems."+
#      " Each question must be a math problem. The format must be: `[{\"question\": QUESTION}, ...]`", callbacks=[tracer])

In [None]:
# Question list produced by the generation cell above, hard-coded so the
# notebook does not depend on a fresh LLM call. Originally:
# answer_as_json = json.loads(answer)
# answer_as_json
# `questions` is deliberately not bound here: the next cell derives it as a
# list of plain strings, and binding it to these dicts too reused one name for
# two different kinds of value.
answer_as_json = [
    {'question': 'What is the sum of 5 and 7?'},
    {'question': 'What is the product of 4 and 6?'},
    {'question': 'What is the square root of 64?'},
    {'question': 'What is the area of a circle with a radius of 5?'},
    {'question': 'What is the slope of the line y = 3x + 2?'},
    {'question': 'What is the value of x in the equation 3x + 4 = 10?'},
    {'question': 'What is the value of y in the equation y = 2x + 1 when x = 3?'},
    {'question': 'What is the equation of the line that passes through the points (2, 4) and (3, 6)?'},
    {'question': 'What is the volume of a cube with a side length of 5?'},
    # The superscript was lost when this was pasted ("x2"); the quadratic is x^2 + 3x - 4.
    {'question': 'What is the value of x in the equation x^2 + 3x - 4 = 0?'},
]

In [None]:
# Pull out the bare question strings that the next cell feeds to agent.run.
questions = [entry["question"] for entry in answer_as_json]

In [None]:
# Run every question through the agent, tracing each run via `tracer`.
# Best-effort: a failing question is reported (with the question text) and
# skipped so the rest of the batch still runs; failures are kept for review.
failed_questions = []
for question in questions:
    try:
        answer = agent.run(question, callbacks=[tracer])
    except Exception as e:
        failed_questions.append((question, e))
        print(f"FAILED: {question!r} -> {type(e).__name__}: {e}")
        continue
    print(answer)
print(f"{len(questions) - len(failed_questions)}/{len(questions)} questions answered")