# Import the Zenbase Library

In [None]:
import sys
import subprocess

def install_package(package):
    """Install ``package`` with pip, using the interpreter running this notebook.

    A failed install is reported on stdout and the original
    ``subprocess.CalledProcessError`` is re-raised.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    try:
        subprocess.check_call(pip_command)
    except subprocess.CalledProcessError as err:
        print(f"Failed to install {package}: {err}")
        raise

def install_packages(packages):
    """Install every entry of ``packages``, in order, via ``install_package``."""
    for name in packages:
        install_package(name)

# A successful import of `google.colab` is taken as the signal that the
# notebook is running in Google Colab.
try:
    import google.colab
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

if IN_COLAB:
    # The released package could be installed instead:
    # install_package('zenbase')
    # Install the zenbase package from the `main` branch on GitHub.
    install_package("git+https://github.com/zenbase-ai/lib.git@main#egg=zenbase&subdirectory=py")

    # Other packages this notebook needs in Google Colab.
    additional_packages = [
        "python-dotenv",
        "arize-phoenix[evals]",
        "openai",
        "langchain",
        "langchain_openai",
    ]
    install_packages(additional_packages)

# Import zenbase; report the failure before re-raising if it is unavailable.
try:
    import zenbase
except ImportError as e:
    print("Failed to import zenbase: ", e)
    raise

# Configure the Environment

In [None]:
from pathlib import Path
from dotenv import load_dotenv

# Alternative to the .env file: set the key directly in the process environment.
# import os
#
# os.environ["OPENAI_API_KEY"] = "..."

# Load variables from ../../.env.test. override=True is passed so that values
# from the file replace ones already set (per python-dotenv's documented behaviour).
load_dotenv(Path("../../.env.test"), override=True)

In [None]:
import nest_asyncio

# Patch asyncio so an already-running event loop can be re-entered
# (presumably needed because the notebook kernel already runs one — confirm).
nest_asyncio.apply()

# Initial Setup


In [None]:
# Start the Phoenix app.
import phoenix as px
px.launch_app()
# Create the Phoenix client; it is handed to ZenArizeAdaptor and used to fetch datasets below.
arize_phoenix = px.Client()

In [None]:
from openai import OpenAI
# OpenAI() is called without arguments, so credentials presumably come from the
# environment loaded earlier (OPENAI_API_KEY) — confirm.
# NOTE(review): this binds the name `openai` to a client instance, not the module.
openai = OpenAI()

In [None]:
from zenbase.utils import ksuid
from zenbase.adaptors.arize import ZenArizeAdaptor
# Adaptor wrapping the Phoenix client; used below to create the datasets.
zen_arize_adaptor = ZenArizeAdaptor(arize_phoenix)

# Set up datasets: load GSM8K (configuration "main").
import datasets
gsm8k_dataset = datasets.load_dataset("gsm8k", "main")
# Number of examples taken for each split (presumably kept small so the
# notebook runs quickly — confirm).
TESTSET_SIZE = 2
TRAINSET_SIZE = 5
VALIDATIONSET_SIZE = 2


def create_dataset_with_examples(zen_arize_adaptor: ZenArizeAdaptor, prefix: str, item_set: list) -> str:
    """Upload ``item_set`` as a freshly named dataset and return that name.

    Args:
        zen_arize_adaptor: Adaptor whose ``add_examples_to_dataset`` receives the examples.
        prefix: Prefix passed to ``ksuid`` when generating the dataset name.
        item_set: Records that each provide a ``"question"`` and an ``"answer"``.

    Returns:
        The generated dataset name.
    """
    name = ksuid(prefix=prefix)

    # Inputs and expected outputs are parallel lists, one entry per record.
    questions = [{"question": record["question"]} for record in item_set]
    answers = [{"answer": record["answer"]} for record in item_set]

    zen_arize_adaptor.add_examples_to_dataset(name, questions, answers)
    return name

# Training examples: the first TRAINSET_SIZE rows of the GSM8K "train" split.
train_set = create_dataset_with_examples(
        zen_arize_adaptor,
        "GSM8K_train_set",
        list(gsm8k_dataset["train"].select(range(TRAINSET_SIZE))),
    )

# Validation examples: VALIDATIONSET_SIZE rows of the "train" split that come after the training rows.
# NOTE(review): the range starts at TRAINSET_SIZE + 1, so row TRAINSET_SIZE itself
# is used by neither set — confirm whether that gap is intended.
validation_set = create_dataset_with_examples(
        zen_arize_adaptor,
        "GSM8K_validation_set",
        list(gsm8k_dataset["train"].select(range(TRAINSET_SIZE + 1, TRAINSET_SIZE + VALIDATIONSET_SIZE + 1))),
    )

# Test examples: the first TESTSET_SIZE rows of the GSM8K "test" split.
test_set = create_dataset_with_examples(
        zen_arize_adaptor,
        "GSM8K_test_set",
        list(gsm8k_dataset["test"].select(range(TESTSET_SIZE))),
    )

# Now, you probably already have some LLM code.

It could use the OpenAI SDK, LangChain, or anything really. But it looks something like this:

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

def solver(inputs):
    """Answer a math question by planning, naming the operation, then solving.

    Args:
        inputs: Dict with a ``"question"`` entry.

    Returns:
        ``{"answer": <model output string>}``.
    """
    system_prompt = """You are an expert math solver. Solve the given problem using the provided plan and operations.
        Return only the final numerical answer, without any additional text or explanation."""

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("user", "Question: {question}"),
            ("user", "Plan: {plan}"),
            ("user", "Mathematical Operation: {operation}"),
            ("user", "Provide the final numerical answer:"),
        ]
    )
    answer_chain = prompt | ChatOpenAI(model="gpt-3.5-turbo") | StrOutputParser()

    # The plan feeds the operation finder; both feed the final prompt.
    plan_text = planner_chain(inputs)["plan"]
    operation_text = operation_finder({"plan": plan_text, "question": inputs["question"]})["operation"]

    final_answer = answer_chain.invoke(
        {
            "question": inputs["question"],
            "plan": plan_text,
            "operation": operation_text,
        }
    )
    return {"answer": final_answer}

def planner_chain(inputs):
    """Ask the model for a step-by-step plan for the given problem.

    Args:
        inputs: Dict with a ``"question"`` entry, used to fill the prompt.

    Returns:
        ``{"plan": <model output string>}``.
    """
    system_prompt = """You are an expert math solver. Create a step-by-step plan to solve the given problem.
        Be clear and concise in your steps."""

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("user", "Problem: {question}\n\nProvide a step-by-step plan to solve this problem:"),
        ]
    )
    plan_chain = prompt | ChatOpenAI(model="gpt-3.5-turbo") | StrOutputParser()
    return {"plan": plan_chain.invoke(inputs)}

def operation_finder(inputs):
    """Ask the model which mathematical operation the plan mainly relies on.

    Args:
        inputs: Dict with ``"question"`` and ``"plan"`` entries, used to fill the prompt.

    Returns:
        ``{"operation": <model output string>}``.
    """
    system_prompt = """You are an expert math solver. Identify the overall mathematical operation needed to solve the
             problem
        based on the given plan. Use simple operations like addition, subtraction, multiplication, and division."""

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("user", "Question: {question}"),
            ("user", "Plan: {plan}"),
            ("user", "Identify the primary mathematical operation needed:"),
        ]
    )
    operation_chain = prompt | ChatOpenAI(model="gpt-3.5-turbo") | StrOutputParser()
    return {"operation": operation_chain.invoke(inputs)}


In [None]:
# Try the untraced pipeline on a trivial question.
solver({"question": "What is 2 + 2?"})

## And let's say you have an eval function like this

In [None]:
def score_answer(output: dict, expected: dict) -> int:
    """Score one prediction against its reference answer.

    Args:
        output: Return value of ``solver``: a dict whose ``"answer"`` entry is
            the model's final answer.
        expected: Reference record whose ``"answer"`` string carries the final
            answer after the last ``"#### "`` marker.

    Returns:
        1 when the predicted answer equals the reference's final answer, else 0.
    """
    # Only the text after the last "#### " marker is the final answer.
    reference = expected["answer"].split("#### ")[-1].strip()
    # Surrounding whitespace in the model output must not turn a correct
    # answer into a miss.
    predicted = str(output["answer"]).strip()
    return int(predicted == reference)


## Then you're probably evaluating like this

In [None]:
from phoenix.experiments import run_experiment

# Run `solver` against the test dataset fetched from Phoenix, with
# `score_answer` as the only evaluator.
experiment = run_experiment(
                arize_phoenix.get_dataset(name=test_set),
                solver,
                experiment_name="Experiment-Name",
                evaluators=[score_answer],
            )

# Now, how can we optimize this score of 0.6?

## First, initialize the `ZenbaseTracer` and the Arize Phoenix adaptor

In [None]:
from zenbase.core.managers import ZenbaseTracer
# Tracer used below as a decorator on the LLM functions and passed to the optimizer.
zenbase_tracer = ZenbaseTracer()

# NOTE(review): the adaptor was already imported and created in the dataset
# setup cell; this repeats it with the same Phoenix client.
from zenbase.adaptors.arize import ZenArizeAdaptor
zen_arize_adaptor = ZenArizeAdaptor(arize_phoenix)

## Hook up Zenbase to your functions

1. Decorate the function with `@zenbase_tracer`.
2. Change the function's parameter from `inputs` to `request` (an `LMRequest`), and read the original inputs from `request.inputs`.
3. Use `request.zenbase.task_demos` to get the few-shot examples for the task, and add them to your prompt however you like.
4. If you only need a few examples, slice the list: `request.zenbase.task_demos[:2]` gives the first two.

In [None]:
from zenbase.types import LMRequest

@zenbase_tracer  # (1) trace this function with Zenbase
def solver(request: LMRequest):  # (2) accept an LMRequest instead of a plain inputs dict
    """Solve a math question, adding Zenbase few-shot demos to the prompt.

    The question is first passed to ``planner_chain`` and ``operation_finder``;
    the final prompt then asks the model for the numerical answer.

    Args:
        request: ``request.inputs["question"]`` holds the question and
            ``request.zenbase.task_demos`` the few-shot demos.

    Returns:
        ``{"answer": <model output string>}``.
    """

    def as_literal(text):
        # Demo text is data, not a template: double the braces so that
        # ChatPromptTemplate does not read them as placeholders.
        return str(text).replace("{", "{{").replace("}", "}}")

    messages = [
        (
            "system",
            """You are an expert math solver. Solve the given problem using the provided plan and operations.
        Return only the final numerical answer, without any additional text or explanation.""",
        ),
    ]

    # (3) add the few-shot demos to the prompt;
    # (4) slice task_demos (e.g. [:2]) here to use only some of them.
    for demo in request.zenbase.task_demos:
        demo_input = as_literal(demo.inputs["question"])
        demo_output = as_literal(demo.outputs["answer"])

        messages += [
            ("user", f"Example Question: {demo_input}"),
            ("assistant", f"Example Answer: {demo_output}"),
        ]

    # The real question comes after the demos and is filled in at invoke time.
    messages.extend(
        [
            ("user", "Question: {question}"),
            ("user", "Plan: {plan}"),
            ("user", "Mathematical Operation: {operation}"),
            ("user", "Provide the final numerical answer:"),
        ]
    )

    chain = ChatPromptTemplate.from_messages(messages) | ChatOpenAI(model="gpt-3.5-turbo") | StrOutputParser()

    plan = planner_chain(request.inputs)
    operation = operation_finder({"plan": plan["plan"], "question": request.inputs["question"]})

    inputs_to_answer = {
        "question": request.inputs["question"],
        "plan": plan["plan"],
        "operation": operation["operation"],
    }
    answer = chain.invoke(inputs_to_answer)
    return {"answer": answer}

@zenbase_tracer  # (1) trace this function with Zenbase
def planner_chain(request: LMRequest):  # (2) accept an LMRequest instead of a plain inputs dict
    """Ask the model for a step-by-step plan, using up to two few-shot demos.

    Args:
        request: ``request.inputs`` fills the ``{question}`` placeholder and
            ``request.zenbase.task_demos`` provides the few-shot demos.

    Returns:
        ``{"plan": <model output string>}``.
    """

    def as_literal(text):
        # Demo text is data, not a template: double the braces so that
        # ChatPromptTemplate does not read them as placeholders.
        return str(text).replace("{", "{{").replace("}", "}}")

    messages = [
        (
            "system",
            """You are an expert math solver. Create a step-by-step plan to solve the given problem.
        Be clear and concise in your steps.""",
        ),
    ]

    if request.zenbase.task_demos:  # (3) add the few-shot demos to the prompt
        for demo in request.zenbase.task_demos[:2]:  # (4) only the first two demos
            messages += [
                ("user", as_literal(demo.inputs["question"])),
                ("assistant", as_literal(demo.outputs["plan"])),
            ]

    # The real problem goes last, after the demos, so the conversation ends on
    # the question the model is supposed to answer.
    messages.append(("user", "Problem: {question}\n\nProvide a step-by-step plan to solve this problem:"))

    chain = ChatPromptTemplate.from_messages(messages) | ChatOpenAI(model="gpt-3.5-turbo") | StrOutputParser()
    plan = chain.invoke(request.inputs)
    return {"plan": plan}

@zenbase_tracer  # (1) trace this function with Zenbase
def operation_finder(request: LMRequest):  # (2) accept an LMRequest instead of a plain inputs dict
    """Ask the model for the primary operation, using up to two few-shot demos.

    Args:
        request: ``request.inputs`` fills the ``{question}`` and ``{plan}``
            placeholders and ``request.zenbase.task_demos`` provides the
            few-shot demos.

    Returns:
        ``{"operation": <model output string>}``.
    """

    def as_literal(text):
        # Demo text is data, not a template: double the braces so that
        # ChatPromptTemplate does not read them as placeholders.
        return str(text).replace("{", "{{").replace("}", "}}")

    messages = [
        (
            "system",
            """You are an expert math solver. Identify the overall mathematical operation needed to solve the
             problem
        based on the given plan. Use simple operations like addition, subtraction, multiplication, and division.""",
        ),
    ]

    if request.zenbase.task_demos:  # (3) add the few-shot demos to the prompt
        for demo in request.zenbase.task_demos[:2]:  # (4) only the first two demos
            messages += [
                ("user", as_literal(demo.inputs["question"])),
                ("user", as_literal(demo.inputs["plan"])),
                ("assistant", as_literal(demo.outputs["operation"])),
            ]

    # The real question and plan go last, after the demos, so the conversation
    # ends on the request the model is supposed to answer.
    messages += [
        ("user", "Question: {question}"),
        ("user", "Plan: {plan}"),
        ("user", "Identify the primary mathematical operation needed:"),
    ]

    chain = ChatPromptTemplate.from_messages(messages) | ChatOpenAI(model="gpt-3.5-turbo") | StrOutputParser()
    operation = chain.invoke(request.inputs)
    return {"operation": operation}


In [None]:
# Try the traced pipeline on a trivial question and keep its result.
return_langchain = solver({"question": "What is 2 + 2?"})

## Now we can optimize!

### Set up your optimizer:

In [None]:
from zenbase.optim.metric.bootstrap_few_shot import BootstrapFewShot

SHOTS = 2
SAMPLES = 2

# Evaluation settings handed to the optimizer: the test dataset fetched from
# Phoenix and the scoring function.
evaluator_kwargs = {
    "dataset": arize_phoenix.get_dataset(name=test_set),
    "evaluators": [score_answer],
}

# Optimizer configured with the three dataset names created earlier.
bootstrap_few_shot = BootstrapFewShot(
    shots=SHOTS,
    training_set=train_set,
    test_set=test_set,
    validation_set=validation_set,
    evaluator_kwargs=evaluator_kwargs,
    zen_adaptor=zen_arize_adaptor,
)


### Do the optimization

In [None]:
# Run the optimization on the traced `solver`; `perform` returns the best
# function found together with the candidates.
best_fn, candidates = bootstrap_few_shot.perform(
    solver,
    samples=SAMPLES,
    rounds=1,
    trace_manager=zenbase_tracer,
)

### Use your optimized function


In [None]:
# Flush the tracer first (presumably so that only the call below is
# recorded — confirm), then call the optimized function.
zenbase_tracer.flush()
best_fn({"question": "What is 2+2?"})

### Introspect function traces

In [None]:
# Take the first entry of the tracer's `all_traces` mapping and keep its "optimized" item.
function_traces = [v for k, v in zenbase_tracer.all_traces.items()][0]["optimized"]


### Check the optimized parameters for planner_chain

In [None]:
from pprint import pprint

# Few-shot demos attached to the request of the traced `planner_chain` call.
pprint(function_traces["planner_chain"]["args"]["request"].zenbase.task_demos)


### Check the optimized parameters for operation_finder chain

In [None]:
from pprint import pprint

# Few-shot demos attached to the request of the traced `operation_finder` call.
pprint(function_traces["operation_finder"]["args"]["request"].zenbase.task_demos)


### Check the optimized parameters for solver

In [None]:
from pprint import pprint

# Few-shot demos attached to the request of the traced `solver` call.
pprint(function_traces["solver"]["args"]["request"].zenbase.task_demos)


## How to save the function and load it later

### Save the optimized function args to a file

In [None]:
# Write the optimizer's arguments to a file in the current working directory.
bootstrap_few_shot.save_optimizer_args("bootstrap_few_shot_args.zenbase")

### Load the optimized function args with the function

In [None]:
# NOTE(review): this save repeats the one in the previous cell; presumably
# only the load below is needed here — confirm.
bootstrap_few_shot.save_optimizer_args("bootstrap_few_shot_args.zenbase")

# Rebuild the optimized function from the saved file, the traced `solver` and the tracer.
optimized_function = bootstrap_few_shot.load_optimizer_and_function("bootstrap_few_shot_args.zenbase", solver, zenbase_tracer)

### Use the loaded function and make sure it loaded the demos.


In [None]:
# Flush the tracer, call the loaded function once, then read the traces back.
zenbase_tracer.flush()
optimized_function({"question": "If I have 30% of shares, and Mo has 24.5% of shares, how many of our 10M shares are unassigned?"})
# Take the first entry of the tracer's `all_traces` mapping and keep its "optimized" item.
function_traces = [v for k, v in zenbase_tracer.all_traces.items()][0]["optimized"]
from pprint import pprint

# Print the demos attached to each traced function's request to check they were loaded.
pprint(function_traces["solver"]["args"]["request"].zenbase.task_demos)
pprint(function_traces["planner_chain"]["args"]["request"].zenbase.task_demos)
pprint(function_traces["operation_finder"]["args"]["request"].zenbase.task_demos)