<a target="_blank" href="https://colab.research.google.com/github/parambharat/wandb-addons/blob/prompts/trace-api/docs/prompts/examples/Trace_QuickStart.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

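This quick-start example demonstrates how to use the `Trace` class, a high-level API for logging LLM calls to the W&B Prompts feature. It traces a LangChain chain, a direct OpenAI chat call, a Promptify NER call, and a Guidance program, and then shows how to nest traces into hierarchies (all sections except the LangChain one need an OpenAI API key).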

In [None]:
!pip install -qqq -U openai langchain wandb

In [None]:
import datetime

import wandb
from langchain.chains import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate

from wandb_addons.prompts import Trace

In [None]:
# W&B project that every run started in this notebook logs to.
PROJECT: str = "high_level_trace"

In [None]:
# Trace LangChain chains: run a chain twice and log one Trace per call.
# Using the run as a context manager guarantees it is finished even if a chain call raises
# (previously `run.finish()` was skipped on error, leaving the run open).
with wandb.init(project=PROJECT) as run:
    # FakeListLLM is given a fixed list of canned responses, so this cell makes no model API calls
    # (the OpenAI key is only requested in a later cell).
    llm = FakeListLLM(responses=[f"Fake response: {i}" for i in range(100)])
    prompt_template = "What is a good name for a company that makes {product}?"
    prompt = PromptTemplate(
        input_variables=["product"],
        template=prompt_template,
    )

    chain = LLMChain(llm=llm, prompt=prompt)

    for i in range(2):
        # Appending the current timestamp makes every input distinct, including across re-runs.
        product = f"q: {i} - {datetime.datetime.now().timestamp()}"
        start_time_ms = datetime.datetime.now().timestamp() * 1000
        response = chain(product)
        end_time_ms = datetime.datetime.now().timestamp() * 1000
        trace = Trace(
            name=f"fake_chain_{i}",
            kind="chain",
            status_code="success",
            metadata=None,
            start_time_ms=start_time_ms,
            end_time_ms=end_time_ms,
            # Log the fully rendered prompt rather than the raw template.
            inputs={"prompt": prompt_template.format(product=product)},
            outputs={"response": response["text"]},
        )
        trace.log(name=f"trace_{i}")

In [None]:
# Trace OpenAI API calls: configure the API key that this and the later cells use.
# Prefer the OPENAI_API_KEY environment variable so Run-All works without interaction, and
# only fall back to an interactive prompt (getpass does not echo the key) when it is unset.
import os
from getpass import getpass

import openai

openai.api_key = os.environ.get("OPENAI_API_KEY") or getpass(
    "Please enter your openai api key"
)

In [None]:
# Trace a single OpenAI chat completion call.
# NOTE: `openai.ChatCompletion` only exists in openai<1.0.
request_kwargs = dict(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who won the world series in 2020?"},
        {
            "role": "assistant",
            "content": "The Los Angeles Dodgers won the World Series in 2020.",
        },
        {"role": "user", "content": "Where was it played?"},
    ],
)

# The context manager finishes the run even if the API call raises.
with wandb.init(project=PROJECT) as run:
    start_time_ms = datetime.datetime.now().timestamp() * 1000
    response = openai.ChatCompletion.create(**request_kwargs)
    end_time_ms = datetime.datetime.now().timestamp() * 1000

    trace = Trace(
        name="openai_chat_completion",
        kind="llm",
        status_code="success",
        # Read the model from the request so the logged metadata cannot drift from what was called.
        metadata={"model": request_kwargs["model"]},
        start_time_ms=start_time_ms,
        end_time_ms=end_time_ms,
        inputs={"messages": request_kwargs["messages"]},
        outputs={"response": response.choices[0]["message"]["content"]},
    )

    trace.log(name="openai_trace")
display(run)

In [None]:
# use with promprtify
!pip install -qqq -U promptify

In [None]:
# Trace a Promptify named-entity-recognition (NER) call made through OpenAI.
from promptify import OpenAI
from promptify import Prompter

# NER example
sentence = "The patient is a 93-year-old female with a medical history of chronic right hip pain, osteoporosis, hypertension, depression, and chronic atrial fibrillation admitted for evaluation and management of severe nausea and vomiting and urinary tract infection"

model = OpenAI(openai.api_key)  # or `HubModel()` for Huggingface-based inference
nlp_prompter = Prompter(model)

# The context manager finishes the run even if the Promptify call raises.
with wandb.init(project=PROJECT) as run:
    start_time_ms = datetime.datetime.now().timestamp() * 1000
    result = nlp_prompter.fit(
        "ner.jinja", domain="medical", text_input=sentence, labels=None
    )
    end_time_ms = datetime.datetime.now().timestamp() * 1000

    trace = Trace(
        # Name the span after what it traces, distinct from the plain OpenAI chat span.
        name="promptify_ner",
        kind="llm",
        status_code="success",
        # Everything in the result except the generated text is logged as metadata.
        metadata={k: v for k, v in result.items() if k != "text"},
        start_time_ms=start_time_ms,
        end_time_ms=end_time_ms,
        inputs={"sentence": sentence},
        outputs={"entities": result["text"]},
    )
    trace.log(name="promptify_ner")
display(run)

In [None]:
!pip install -qqq -U guidance

In [None]:
# Trace a Guidance program that classifies a sentence as an anachronism or not.
import guidance

# define the model we will use (legacy guidance<0.1 API)
# NOTE(review): OpenAI has deprecated the legacy `text-davinci-003` completion model; confirm it is
# still served for your account or switch to a supported model.
guidance.llm = guidance.llms.OpenAI("text-davinci-003", api_key=openai.api_key)

# define the few shot examples
examples = [
    {
        "input": "I wrote about shakespeare",
        "entities": [
            {"entity": "I", "time": "present"},
            {"entity": "Shakespeare", "time": "16th century"},
        ],
        "reasoning": "I can write about Shakespeare because he lived in the past with respect to me.",
        "answer": "No",
    },
    {
        "input": "Shakespeare wrote about me",
        "entities": [
            {"entity": "Shakespeare", "time": "16th century"},
            {"entity": "I", "time": "present"},
        ],
        "reasoning": "Shakespeare cannot have written about me, because he died before I was born",
        "answer": "Yes",
    },
]

# define the guidance program; `gen`/`select` capture "entities", "Reasoning" and "answer"
structure_prompt = guidance(
    """Given a sentence tell me whether it contains an anachronism (i.e. whether it could have happened or not based on the time periods associated with the entities).
----

{{~! display the few-shot examples ~}}
{{~#each examples}}
Sentence: {{this.input}}
Entities and dates:{{#each this.entities}}
{{this.entity}}: {{this.time}}{{/each}}
Reasoning: {{this.reasoning}}
Anachronism: {{this.answer}}
---
{{~/each}}

{{~! place the real question at the end }}
Sentence: {{input}}
Entities and dates:
{{gen "entities"}}
Reasoning:{{gen "Reasoning"}}
Anachronism:{{#select "answer"}} Yes{{or}} No{{/select}}"""
)

# Only the traced execution needs an active run; the context manager finishes it even on error.
with wandb.init(project=PROJECT) as run:
    start_time_ms = datetime.datetime.now().timestamp() * 1000
    # execute the program
    result = structure_prompt(examples=examples, input="The T-rex bit my dog")
    end_time_ms = datetime.datetime.now().timestamp() * 1000

    # trace guidance: read the program's variables once instead of once per field
    program_vars = result.variables()
    trace = Trace(
        name="guidance_anachronism",
        kind="llm",
        status_code="success",
        metadata=None,
        start_time_ms=start_time_ms,
        end_time_ms=end_time_ms,
        inputs={"sentence": program_vars["input"]},
        outputs={
            "entities": program_vars["entities"],
            # The template generates a "Reasoning" variable; log it alongside the answer.
            "reasoning": program_vars["Reasoning"],
            "answer": program_vars["answer"],
        },
    )
    trace.log(name="guidance_anachronism")
display(run)

In [None]:
# Example hierarchies with the Trace class: build three spans (a root and two children) that the
# following cells wire together in different shapes.
import time

# Read the clock once so every span's start/end is a fixed offset from the same instant;
# separate time.time() calls per field let the intended 1000 ms gaps drift.
now_ms = int(round(time.time() * 1000))

root_trace = Trace(
    name="Parent Model",
    kind="LLM",
    status_code="SUCCESS",
    metadata={
        "attr_1": 1,
        "attr_2": 2,
    },
    start_time_ms=now_ms,
    end_time_ms=now_ms + 1000,
    inputs={"user": "How old is google?"},
    outputs={"assistant": "25 years old"},
    model_dict={"_kind": "openai", "api_type": "azure"},
)

first_child = Trace(
    name="Child 1 Model",
    kind="LLM",
    status_code="ERROR",
    metadata={
        "child1_attr_1": 1,
        "child1_attr_2": 2,
    },
    start_time_ms=now_ms + 2000,
    end_time_ms=now_ms + 3000,
    inputs={"user": "How old is google?"},
    outputs={"assistant": "25 years old"},
    model_dict={"_kind": "openai", "api_type": "child1_azure"},
)

second_child = Trace(
    name="Child 2 Model",
    kind="LLM",
    status_code="SUCCESS",
    metadata={
        "child2_attr_1": 1,
        "child2_attr_2": 2,
    },
    start_time_ms=now_ms + 4000,
    end_time_ms=now_ms + 5000,
    inputs={"user": "How old is google?"},
    outputs={"assistant": "25 years old"},
    model_dict={"_kind": "openai", "api_type": "child2_azure"},
)

In [None]:
# Simple hierarchy: root -> first child -> second child.
import copy

# `add_child` mutates the parent trace, so calling it on the shared traces defined above would
# make every re-run (and each later hierarchy cell) attach more children to them. Work on deep
# copies instead; copying all three in one call keeps any references shared between them intact.
root_copy, first_copy, second_copy = copy.deepcopy(
    (root_trace, first_child, second_child)
)

with wandb.init(project=PROJECT, job_type="simple_hierarchy") as run:
    root_copy.add_child(first_copy)
    first_copy.add_child(second_copy)

    root_copy.log(name="root_trace")
display(run)

In [None]:
# Nested hierarchy: the second child hangs off both the first child and the root.
import copy

# `add_child` mutates the parent trace, so calling it on the shared traces defined above would
# make every re-run (and the other hierarchy cells) attach more children to them. Work on deep
# copies; a single deepcopy call keeps `second_copy` one object, shared by both parents below.
root_copy, first_copy, second_copy = copy.deepcopy(
    (root_trace, first_child, second_child)
)

with wandb.init(project=PROJECT, job_type="nested_hierarchy") as run:
    root_copy.add_child(first_copy)
    first_copy.add_child(second_copy)
    root_copy.add_child(second_copy)

    root_copy.log(name="root_trace")
display(run)

In [None]:
# All traces: build root -> first child -> second child, then log every level as its own trace.
import copy

# `add_child` mutates the parent trace, so calling it on the shared traces defined above would
# make every re-run (and the other hierarchy cells) attach more children to them. Work on deep
# copies instead; copying all three in one call keeps any references shared between them intact.
root_copy, first_copy, second_copy = copy.deepcopy(
    (root_trace, first_child, second_child)
)

with wandb.init(project=PROJECT, job_type="all_traces") as run:
    root_copy.add_child(first_copy)
    first_copy.add_child(second_copy)

    # Leaf first, root last, each under its own log key.
    second_copy.log(name="second_child")
    first_copy.log(name="first_child")
    root_copy.log(name="root_trace")
display(run)