In [None]:
# Path components of the StreamTable opened below, joined as
# "<ENTITY>/<PROJECT>/<STREAM>".
ENTITY = "timssweeney"
PROJECT = "weave"
STREAM = "custom_llm_monitoring_example"

In [None]:
import uuid
import time
import typing
from datetime import datetime 

from weave.stream_data_interfaces import LLMCompletionDict, _LLMCompletionInputs, _LLMCompletionOutput, _LLMCompletionSummary, _LLMCompletionMessage, _LLMCompletionChoice
import tiktoken

def count_chat_completion_tokens(
    model_name: typing.Optional[str] = None,
    prompt_input_messages: typing.Optional[typing.List[_LLMCompletionMessage]] = None,
    completion_choices: typing.Optional[typing.List[_LLMCompletionChoice]] = None,
) -> dict:
    """Estimate token usage of a chat completion with tiktoken.

    Args:
        model_name: Model whose tokenizer should be used. A falsy value, or a
            name tiktoken cannot map to a tokenizer (e.g. a custom or
            self-hosted model), falls back to the ``cl100k_base`` encoding.
        prompt_input_messages: Prompt messages; each message's ``"content"``
            is tokenized. Missing or ``None`` content counts as 0 tokens.
        completion_choices: Completion choices; each choice's
            ``["message"]["content"]`` is tokenized. Missing or ``None``
            content counts as 0 tokens.

    Returns:
        A dict with ``prompt_tokens``, ``completion_tokens`` and
        ``total_tokens``. Only message content is encoded (no per-message
        chat-format overhead), so these are content-token estimates.
    """
    # `None` defaults instead of `[]`: a mutable default is shared across calls.
    prompt_input_messages = prompt_input_messages or []
    completion_choices = completion_choices or []

    encoding = None
    if model_name:
        try:
            encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # tiktoken raises KeyError for model names it has no tokenizer
            # mapping for; fall through to the generic encoding below.
            encoding = None
    if encoding is None:
        encoding = tiktoken.get_encoding("cl100k_base")

    def _n_tokens(text: typing.Optional[str]) -> int:
        # Content may be None (e.g. tool/function-call messages) -> 0 tokens.
        # disallowed_special=() encodes strings such as "<|endoftext|>" as
        # ordinary text instead of raising ValueError on logged content.
        return len(encoding.encode(text, disallowed_special=())) if text else 0

    prompt_tokens = sum(_n_tokens(m.get("content")) for m in prompt_input_messages)
    completion_tokens = sum(
        _n_tokens((c.get("message") or {}).get("content")) for c in completion_choices
    )
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }

# Status codes accepted by `create_llm_record`.
_VALID_STATUS_CODES = ("SUCCESS", "ERROR", "UNSET")


def _normalize_llm_inputs(inputs) -> _LLMCompletionInputs:
    """Coerce the accepted `inputs` shapes into an `_LLMCompletionInputs` dict."""
    if isinstance(inputs, dict):
        # Assumed to already be in `_LLMCompletionInputs` format.
        return inputs
    if isinstance(inputs, str):
        return _LLMCompletionInputs(messages=[_LLMCompletionMessage(content=inputs)])
    if isinstance(inputs, list):
        messages = []
        for item in inputs:
            if isinstance(item, str):
                messages.append(_LLMCompletionMessage(content=item))
            elif isinstance(item, dict):
                # Assumed to already be in `_LLMCompletionMessage` format.
                messages.append(item)
            else:
                raise ValueError(f"Invalid type for item in inputs: {type(item)}")
        return _LLMCompletionInputs(messages=messages)
    raise ValueError(f"Invalid type for inputs: {type(inputs)}")


def _normalize_llm_output(output, model_name: str) -> _LLMCompletionOutput:
    """Coerce the accepted `output` shapes into an `_LLMCompletionOutput` dict."""
    if isinstance(output, dict):
        # Assumed to already be in `_LLMCompletionOutput` format.
        return output
    if isinstance(output, str):
        return _LLMCompletionOutput(
            model=model_name,
            choices=[_LLMCompletionChoice(message=_LLMCompletionMessage(content=output))],
        )
    raise ValueError(f"Invalid type for output: {type(output)}")


def create_llm_record(
    inputs: typing.Optional[
        typing.Union[_LLMCompletionInputs, typing.List[_LLMCompletionMessage], typing.List[str], str]
    ] = None,
    output: typing.Optional[typing.Union[_LLMCompletionOutput, str]] = None,
    model_name: str = "",
    span_id: typing.Optional[str] = None,
    name: typing.Optional[str] = None,
    trace_id: typing.Optional[str] = None,
    status_code: typing.Optional[str] = None,
    start_time_s: typing.Optional[float] = None,
    end_time_s: typing.Optional[float] = None,
    parent_id: typing.Optional[str] = None,
    attributes: typing.Optional[typing.Dict[str, typing.Any]] = None,
    summary: typing.Optional[_LLMCompletionSummary] = None,
    exception: typing.Optional[str] = None,
):
    """Build an `LLMCompletionDict` record, e.g. for `StreamTable.log`.

    Args:
        inputs: The prompt. One of:
            - an `_LLMCompletionInputs` dict (used as-is);
            - a string (wrapped as a single message's ``content``);
            - a list whose items are strings (each wrapped as a message) or
              message dicts (used as-is).
        output: The completion. Either an `_LLMCompletionOutput` dict (used
            as-is) or a string, wrapped as a single choice whose message
            content is the string and whose ``model`` is `model_name`.
        model_name: Required for cost analysis; also selects the tiktoken
            encoding used for the token counts in `summary`.
        span_id: Defaults to a random UUID4 string.
        name: Defaults to ``"llm_completion"``.
        trace_id: Defaults to `span_id`.
        status_code: One of ``"SUCCESS"``, ``"ERROR"``, ``"UNSET"``;
            defaults to ``"UNSET"``.
        start_time_s: Epoch seconds; defaults to ``time.time()``.
        end_time_s: Epoch seconds; defaults to ``start_time_s + 1``.
        parent_id: Passed through unchanged; may be None.
        attributes: Must contain only dicts, lists and primitives (i.e. be
            JSON serializable); defaults to ``{}``.
        summary: Extra summary fields. They are merged over the computed
            ``latency_s`` and token counts, so caller-supplied values (for
            example exact usage reported by the provider) take precedence.
        exception: Passed through unchanged; may be None.

    Returns:
        An `LLMCompletionDict` whose ``timestamp`` is derived from
        `start_time_s`.

    Raises:
        ValueError: If `status_code` is not a valid status, or `inputs` /
            `output` has an unsupported type.
    """
    # Validate before doing any work. A real exception instead of `assert`,
    # which is stripped when Python runs with -O.
    status_code = status_code or "UNSET"
    if status_code not in _VALID_STATUS_CODES:
        raise ValueError(
            f"Invalid status_code: {status_code!r}; expected one of {_VALID_STATUS_CODES}"
        )

    span_id = span_id or str(uuid.uuid4())
    name = name or "llm_completion"
    trace_id = trace_id or span_id
    # `is None` rather than `or`, so an explicit 0.0 is not treated as missing.
    if start_time_s is None:
        start_time_s = time.time()
    if end_time_s is None:
        end_time_s = start_time_s + 1
    attributes = attributes or {}

    inputs = _normalize_llm_inputs(inputs)
    output = _normalize_llm_output(output, model_name)

    # The summary always carries at least `latency_s` and token counts.
    computed_summary = {
        "latency_s": end_time_s - start_time_s,
        **count_chat_completion_tokens(model_name, inputs["messages"], output["choices"]),
    }
    summary = {**computed_summary, **(summary or {})}

    return LLMCompletionDict(
        span_id=span_id,
        name=name,
        trace_id=trace_id,
        status_code=status_code,
        start_time_s=start_time_s,
        end_time_s=end_time_s,
        parent_id=parent_id,
        attributes=attributes,
        inputs=inputs,
        output=output,
        summary=summary,
        exception=exception,
        # Set explicitly from the span start; otherwise it would be the time of
        # logging. fromtimestamp() without tz yields a naive local-time
        # datetime. NOTE(review): confirm StreamTable does not expect UTC.
        timestamp=datetime.fromtimestamp(start_time_s),
    )

In [None]:
from weave.legacy.weave.monitoring import StreamTable
# StreamTable addressed as "<ENTITY>/<PROJECT>/<STREAM>"; records are appended
# to it with `st.log(...)`.
st = StreamTable("/".join([ENTITY, PROJECT, STREAM]))

In [None]:
# Base case: plain-string prompt and completion, all other fields defaulted.
record = create_llm_record(inputs="hello", output="world")
st.log(record)