In [1]:
from pathlib import Path
from dotenv import load_dotenv

# Credentials are read from ../.env.test, resolved relative to the notebook's
# working directory. override=True means values in that file replace any that
# are already set in the environment. Alternatively, provide these variables
# yourself (values in the file still win if it defines them):
#
#   import os
#   os.environ["OPENAI_API_KEY"] = "..."
#   os.environ["LANGFUSE_HOST"] = "..."
#   os.environ["LANGFUSE_PUBLIC_KEY"] = "..."
#   os.environ["LANGFUSE_SECRET_KEY"] = "..."
load_dotenv(dotenv_path=Path("..") / ".env.test", override=True)

True

In [2]:
import nest_asyncio

# Patch asyncio so the event loop Jupyter is already running can be re-entered.
# NOTE(review): presumably required because library helpers used later may
# drive an event loop from synchronous calls — confirm which cell needs it.
nest_asyncio.apply()

In [3]:
from openai import OpenAI
from langfuse import Langfuse

# Langfuse client used below to fetch datasets (get_dataset) and to record
# scores (score). With no arguments it presumably reads the LANGFUSE_*
# variables loaded from the env file — confirm against the Langfuse SDK docs.
langfuse = Langfuse()
# Check the credentials up front. NOTE(review): the return value is discarded;
# confirm auth_check raises on invalid keys rather than returning False.
langfuse.auth_check()

# NOTE(review): this rebinds the name `openai`, shadowing the package, and the
# client is not used elsewhere in this notebook (both chains construct their
# own ChatOpenAI instances).
openai = OpenAI()

In [6]:
from langfuse.decorators import observe

@observe()
def langchain_chain(inputs):
    """Baseline: answer a math question zero-shot (no demos) with gpt-3.5-turbo.

    `inputs` is handed straight to the prompt template, so it must provide the
    "question" template variable. Returns the model reply as a plain string.
    """
    from langchain_openai import ChatOpenAI
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.output_parsers import StrOutputParser

    system_prompt = (
        "You are an expert math solver. Your answer must be just the number "
        "with no separators, and nothing else. Follow the format of the examples."
    )
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_prompt), ("user", "{question}")]
    )
    llm = ChatOpenAI(model="gpt-3.5-turbo")
    parser = StrOutputParser()

    pipeline = prompt | llm | parser
    return pipeline.invoke(inputs)

def score_answer(answer: str, expected_output: dict):
    """The first argument is the return value from the `langchain_chain` function above.

    Compares the final answer in `answer` with the one in
    `expected_output["answer"]`. For each, only the text after the last
    "#### " marker is used, with surrounding whitespace removed. The 0/1
    result is recorded as a "correctness" score on the module-level
    `langfuse` client.

    Returns:
        {"score": 1} on an exact match of the normalized final answers,
        otherwise {"score": 0}.
    """
    def final_answer(text: str) -> str:
        # Strip whitespace so a stray newline/space in the model output
        # (e.g. "72\n") does not turn a correct answer into a miss.
        return text.split("#### ")[-1].strip()

    score = int(final_answer(answer) == final_answer(expected_output["answer"]))
    # NOTE(review): the trace being scored is created by the `@observe()`
    # decorator on `langchain_chain`, not by this `langfuse` client — confirm
    # `langfuse.get_trace_id()` actually returns that trace's id.
    langfuse.score(
        name="correctness",
        value=score,
        trace_id=langfuse.get_trace_id(),
    )
    return {"score": score}

In [7]:
evalset = langfuse.get_dataset("gsm8k-evalset")

# Baseline: run the zero-shot chain over every eval item and score it.
# NOTE: this relies on the two-argument `score_answer` defined alongside
# `langchain_chain`. A later cell redefines `score_answer` with a different
# signature, so re-run the earlier cell first if executing this one out of order.
scores = []
for item in evalset.items:
    answer = langchain_chain(item.input)
    result = score_answer(answer, item.expected_output)  # avoid shadowing builtin `eval`
    scores.append(result["score"])

if not scores:
    # Fail with a clear message instead of a ZeroDivisionError below.
    raise ValueError("Dataset 'gsm8k-evalset' has no items to evaluate")
print("Average score", sum(scores) / len(scores))

Average score 0.4


In [12]:
from zenbase.types import LMRequest, LMDemo, deflm

@deflm
@observe()
def zen_chain(request: LMRequest):
    """Answer a math question with few-shot demos supplied by zenbase.

    Each demo in `request.zenbase.task_demos` becomes a user/assistant message
    pair between the system prompt and the final question. `request.inputs` is
    passed to the prompt template, so it must provide the "question" variable.
    Returns the model reply as a plain string.
    """
    from langchain_openai import ChatOpenAI
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.output_parsers import StrOutputParser

    def literal(text: str) -> str:
        # ChatPromptTemplate parses "{...}" in message strings as template
        # variables. Double the braces so demo text is rendered verbatim instead
        # of making invoke() fail on (or wrongly fill) a non-existent variable.
        return text.replace("{", "{{").replace("}", "}}")

    messages = [
        (
            "system",
            "You are an expert math solver. Your answer must be just the number with no separators, and nothing else. Follow the format of the examples.",
        )
    ]

    for demo in request.zenbase.task_demos:
        messages += [
            ("user", literal(demo.inputs["question"])),
            ("assistant", literal(demo.outputs["answer"])),
        ]

    # The real question stays a template variable, filled from request.inputs.
    messages.append(("user", "{question}"))

    chain = (
        ChatPromptTemplate.from_messages(messages)
        | ChatOpenAI(model="gpt-3.5-turbo")
        | StrOutputParser()
    )

    answer = chain.invoke(request.inputs)
    return answer

def score_answer(answer: str, demo: LMDemo, langfuse: Langfuse):
    """The first argument is the return value from the `zen_chain` function above.

    Compares the final answer in `answer` with the one in
    `demo.outputs["answer"]`. For each, only the text after the last "#### "
    marker is used, with surrounding whitespace removed. The 0/1 result is
    recorded as a "correctness" score through the given `langfuse` client.

    Returns:
        {"score": 1} on an exact match of the normalized final answers,
        otherwise {"score": 0}.
    """
    # NOTE: this redefinition replaces the two-argument `score_answer` used by
    # the baseline evaluation loop; re-running that loop afterwards requires
    # re-executing the earlier definition first.
    def final_answer(text: str) -> str:
        # Strip whitespace so a stray newline/space (e.g. "72\n") still matches.
        return text.split("#### ")[-1].strip()

    score = int(final_answer(answer) == final_answer(demo.outputs["answer"]))
    # NOTE(review): confirm `langfuse.get_trace_id()` refers to the trace of the
    # evaluated `zen_chain` call, which is created by the `@observe()` decorator.
    langfuse.score(
        name="correctness",
        value=score,
        trace_id=langfuse.get_trace_id(),
    )
    return {"score": score}

In [13]:
from zenbase.optim.metric.labeled_few_shot import LabeledFewShot
from zenbase.helpers.langfuse import ZenLangfuse

# Few-shot optimizer over demos drawn from the "gsm8k-demoset" dataset.
# NOTE(review): shots/samples/rounds semantics presumably mean "3 demos per
# candidate, 4 candidates, 1 round" — confirm against the zenbase docs.
demoset = ZenLangfuse.dataset_demos(langfuse.get_dataset("gsm8k-demoset"))
optimizer = LabeledFewShot(demoset=demoset, shots=3)

# Candidates of `zen_chain` are judged on `evalset` with `score_answer`.
evaluator = ZenLangfuse.metric_evaluator(
    evalset=evalset,
    evaluate=score_answer,
    langfuse=langfuse,
)

best_fn, candidate_results = optimizer.perform(
    zen_chain,
    evaluator=evaluator,
    samples=4,
    concurrency=1,
    rounds=1,
)

In [14]:
# The optimized function takes the same inputs dict as `zen_chain`.
sample_inputs = {"question": "What is 2+2?"}
output = best_fn(sample_inputs)
output

'4'

In [15]:
# You can even run your function asynchronously in a coroutine.
# `%autoawait` with no argument only prints IPython's current top-level-await
# status; it does not switch anything on.
%autoawait

await best_fn.coroutine({
  "question": "What is 2+2?"
})

IPython autoawait is `on`, and set to use `asyncio`


'4'

In [16]:
# You can also save the zenbase params for re-use
import pickle

# Serialize the optimized params (e.g. to store them on disk), then attach the
# restored copy to the original `zen_chain`. This mutates `zen_chain` in place.
# NOTE: pickle.loads can execute arbitrary code — only load bytes you produced.
pickled_zenbase = pickle.dumps(best_fn.zenbase)
restored_zenbase = pickle.loads(pickled_zenbase)
setattr(zen_chain, "zenbase", restored_zenbase)

zen_chain({"question": "What is 2 + 2?"})  # uses the best few-shot demos

'4'