<a href="https://colab.research.google.com/github/withpi/cookbook-withpi/blob/main/colabs/DSPy_Optimization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://withpi.ai"><img src="https://withpi.ai/logoFullBlack.svg" width="240"></a>

<a href="https://code.withpi.ai"><font size="4">Documentation</font></a>

<a href="https://play.withpi.ai"><font size="4">Technique Catalog</font></a>

# DSPy Optimization

This Colab is the companion to the DSPy Optimization Playground.  [DSPy](https://dspy.ai/) is a toolkit that can optimize an application's system prompt.  It first establishes a baseline by scoring the application's responses to a set of **Inputs** against a **Contract**, then experiments with variations of the application prompt to try to improve the contract scores over those inputs.
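As a mental model, the baseline-then-experiment loop can be sketched in a few lines. This is a minimal illustration, not DSPy's actual implementation: `score` is a hypothetical stand-in for "generate a response with this prompt, then score it against the contract", and `candidates` stands in for the prompt variations the optimizer proposes.

```python
from statistics import mean
from typing import Callable, Sequence


def optimize_prompt(
    baseline: str,
    candidates: Sequence[str],
    inputs: Sequence[str],
    score: Callable[[str, str], float],
) -> tuple[str, float]:
    """Return the best prompt and its mean score over `inputs`.

    `score(prompt, user_input)` is hypothetical: it represents generating a
    response with `prompt` and scoring it against the contract.
    """
    # Baseline: score the starting prompt over every input.
    best_prompt = baseline
    best_score = mean(score(baseline, x) for x in inputs)
    # Experiment: keep a candidate only if it strictly beats the current best.
    for candidate in candidates:
        candidate_score = mean(score(candidate, x) for x in inputs)
        if candidate_score > best_score:
            best_prompt, best_score = candidate, candidate_score
    return best_prompt, best_score
```

Because the baseline is scored first, the loop never returns a prompt that scores worse on the inputs than the one you started with.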

See [Key Concepts](https://code.withpi.ai/key-concepts) if you want more details about Contracts.

This Colab continues with the `Aesop AI` example and its input set, but any contract and input set will do.

## Install and initialize SDK

Connect to a regular CPU Python 3 runtime.  You won't need GPUs for this notebook.

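You'll need a `WITHPI_API_KEY` from https://play.withpi.ai.  Add it to your notebook secrets (the key symbol on the left).  The inference comparison at the end of the notebook also calls Gemini, so add a `GOOGLE_API_KEY` secret as well.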
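The notebook reads its keys with `userdata.get(...)`, which only works inside Colab. If you want the same code to run elsewhere, a small fallback to environment variables helps. The helper below, `get_secret`, is hypothetical and not part of the notebook or the Pi SDK; it is a sketch assuming you export the same variable names in your shell.

```python
import os
from typing import Optional


def get_secret(name: str) -> Optional[str]:
    """Return a secret from Colab's secret store, falling back to os.environ.

    Hypothetical helper: the notebook itself calls userdata.get() directly.
    """
    try:
        # Only importable inside Colab.
        from google.colab import userdata
        value = userdata.get(name)
        if value:
            return value
    except Exception:
        # Not running in Colab, or the secret is missing or not shared with
        # this notebook: fall back to the environment.
        pass
    return os.environ.get(name)
```

Note that `os.environ[...] = get_secret(...)` raises a `TypeError` if the secret is missing everywhere, so check for `None` before assigning.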

Run the cell below to install packages, load the SDK, and define some helper functions.

In [None]:
%%capture

%pip install withpi litellm httpx datasets jinja2 tqdm

# Import a bunch of useful libraries for later.
from concurrent.futures import ThreadPoolExecutor
import json
import os
from pathlib import Path
import re

import datasets
from google.colab import files, userdata
import httpx
import jinja2
from litellm import completion
from tqdm.notebook import tqdm
from withpi import PiClient
from withpi.types import Contract

# Load the notebook secret into the environment so the Pi Client can access it.
os.environ["WITHPI_API_KEY"] = userdata.get('WITHPI_API_KEY')

client = PiClient()

def print_contract(contract: Contract):
  """print_contract pretty-prints a contract"""
  for dimension in contract.dimensions:
    print(dimension.label)
    for sub_dimension in dimension.sub_dimensions:
      print(f"\t{sub_dimension.description}")

def generate(system: str, user: str, model: str) -> str:
  """generate passes the provided system and user prompts into the given model
  via LiteLLM"""
  messages = [
    {
      "content": system,
      "role": "system"
    },
    {
      "content": user,
      "role": "user"
    }
  ]
  return completion(model=model,
                    messages=messages).choices[0].message.content

class printer(str):
  """printer makes strings with embedded newlines print more nicely"""
  def __repr__(self):
    return self
def print_response(response: str):
  """print_response pretty-prints an LLM response, respecting newlines"""
  display(printer(response))

def print_scores(pi_scores):
  """print_scores pretty-prints a Pi Score response as a table."""
  for dimension_name, dimension_scores in pi_scores.dimension_scores.items():
    print(f"{dimension_name}: {dimension_scores.total_score}")
    for subdimension_name, subdimension_score in dimension_scores.subdimension_scores.items():
      print(f"\t{subdimension_name}: {subdimension_score}")
    print("\n")
  print("---------------------")
  print(f"Total score: {pi_scores.total_score}")

def save_file(filename: str, contents: str):
  """save_file writes contents to filename and offers it as a download"""
  Path(filename).write_text(contents)
  files.download(filename)

def load_contract(url: str) -> Contract:
  """load_contract pulls a Contract JSON blob locally with validation."""
  resp = httpx.get(url)
  return Contract.model_validate_json(resp.content)

def load_and_split_dataset(url: str) -> datasets.DatasetDict:
  """load_and_split_dataset pulls in the Parquet file at url and does a 90/10 split"""
  return datasets.load_dataset('parquet', data_files=url, split="train").train_test_split(test_size=0.1)

def do_bulk_inference(dataset, system, model):
  """do_bulk_inference performs inference on the 'input' column of dataset, using
  the provided system prompt.  The model identified will be used via LiteLLM"""

  def do_generate(user, pbar):
    result = generate(system, user, model)
    pbar.update(1)
    return result

  futures = []
  pbar = tqdm(total=len(dataset))
  with ThreadPoolExecutor(max_workers=4) as executor:
    for row in dataset:
      futures.append(executor.submit(do_generate, row["input"], pbar))
  return [future.result() for future in futures]

def do_bulk_templated_inference(dataset, optimized, model):
  """do_bulk_templated_inference performs inference on the 'input' column of dataset,
  using the provided optimized prompt.  It should be a Jinja2 template as returned
  by DSPy"""
  prompt_template = jinja2.Template(optimized)
  result_extractor = re.compile(r".*\[\[ ## response ## \]\](.*)\[\[ ## completed ## \]\]", re.DOTALL)

  def do_generate(prompt: str, pbar) -> str:
    messages = json.loads(prompt_template.render(input=prompt))
    result = completion(model=model,
                        messages=messages).choices[0].message.content

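    pbar.update(1)
    match = result_extractor.match(result)
    # If the model ignored DSPy's field markers, fall back to the raw completion.
    return (match.group(1) if match else result).strip()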

  futures = []
  pbar = tqdm(total=len(dataset))
  with ThreadPoolExecutor(max_workers=4) as executor:
    for row in dataset:
      futures.append(executor.submit(do_generate, row["input"], pbar))
  return [future.result() for future in futures]
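Both bulk-inference helpers submit one task per row and then collect `future.result()` in submission order, so the outputs line up with the dataset rows even though the requests finish out of order. `ThreadPoolExecutor.map` gives the same ordering guarantee more compactly. In this sketch `fake_generate` is a hypothetical stand-in for the LiteLLM call that sleeps a random amount so calls complete out of order.

```python
import random
import time
from concurrent.futures import ThreadPoolExecutor


def fake_generate(user: str) -> str:
    """Stand-in for generate(system, user, model); sleeps a random amount so
    calls complete out of order."""
    time.sleep(random.uniform(0, 0.05))
    return user.upper()


def bulk_map(inputs: list[str], max_workers: int = 4) -> list[str]:
    # Executor.map yields results in input order regardless of which worker
    # finishes first, so outputs stay aligned with the dataset rows.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(fake_generate, inputs))
```

The explicit-futures version in the notebook is handy when you also want to update a progress bar from inside each task; `map` is simpler when you don't.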



## Load Contract and Dataset

Load the `Aesop AI` contract and example set from the Pi Labs cookbook, or edit the URLs below to load different ones.


In [None]:
aesop_contract = load_contract("https://raw.githubusercontent.com/withpi/cookbook-withpi/refs/heads/main/contracts/aesop_ai.json")
aesop = load_and_split_dataset("https://raw.githubusercontent.com/withpi/cookbook-withpi/refs/heads/main/datasets/aesop_ai_examples.parquet")
aesop
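`load_and_split_dataset` calls `train_test_split(test_size=0.1)` without a seed, so the train and test rows change on every run. If you want reproducible comparisons, `Dataset.train_test_split` also accepts a `seed` argument. The pure-Python sketch below only illustrates what a seeded 90/10 split does; it is not the `datasets` implementation, and `split_90_10` is a hypothetical name.

```python
import math
import random


def split_90_10(rows: list, seed: int = 42) -> tuple[list, list]:
    """Shuffle rows with a fixed seed, then hold out 10% as the test split."""
    shuffled = rows[:]  # copy so the caller's list is not mutated
    random.Random(seed).shuffle(shuffled)
    n_test = math.ceil(len(shuffled) * 0.1)  # round the test share up
    return shuffled[n_test:], shuffled[:n_test]
```

With 25 rows this yields 22 training rows and 3 test rows, and the same seed always yields the same split.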


## Optimize your prompt

Kick off a prompt optimization run.  This starts a background job and returns immediately with a job ID, which the next cell uses to follow its progress.

In [None]:
prompt_optimization_status = client.prompt.optimize(
    contract=aesop_contract,
    dspy_optimization_type="COPRO",
    examples=[{"llm_input": row["input"], "llm_output": row["output"]} for row in aesop['train']],
    initial_system_instruction=aesop_contract.description,
    model_id="gpt-4o-mini",
    tuning_algorithm="DSPY",
)
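The `examples` argument pairs each training input with its reference output, renaming the dataset's `input`/`output` columns to `llm_input`/`llm_output`. Pulled out into a function and run on made-up rows, the same reshaping looks like this (`to_examples` is a hypothetical helper; the notebook does it inline):

```python
def to_examples(rows):
    """Map dataset rows with 'input'/'output' keys to the example dicts passed
    to client.prompt.optimize ('llm_input'/'llm_output')."""
    return [{"llm_input": r["input"], "llm_output": r["output"]} for r in rows]


# Toy rows, not taken from the Aesop dataset.
toy_rows = [
    {"input": "A fox and a crow", "output": "Once upon a time..."},
    {"input": "A tortoise and a hare", "output": "Slow and steady..."},
]
examples = to_examples(toy_rows)
```

A row missing either column raises a `KeyError`, which surfaces data problems before the optimization job is submitted.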

## Check for completion

The following cell streams the job's logs while optimization proceeds.  Expect it to take on the order of **10 minutes**.

In [None]:
import json

while True:
  optimized_response = client.prompt.get_status(job_id=prompt_optimization_status.job_id)
  # Stop once the job is no longer waiting or running.
  if optimized_response.state not in ('QUEUED', 'RUNNING'):
    break

  # Print log lines until the stream ends, then re-check the job status.
  with client.prompt.with_streaming_response.stream_messages(
      job_id=prompt_optimization_status.job_id, timeout=None) as response:
    for line in response.iter_lines():
      print(line)

# The optimized prompt is a list of chat messages; serialize it to JSON text.
optimized = json.dumps(optimized_response.optimized_prompt_messages, indent=2)
print(optimized)
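The cell above alternates between checking the job status and streaming logs. If you only need to block until the job finishes (for example, in a script), a plain polling loop with a timeout is enough. In this sketch `get_state` is a hypothetical zero-argument callable, such as `lambda: client.prompt.get_status(job_id=...).state`; only `QUEUED` and `RUNNING` are treated as in-progress states, matching the cell above.

```python
import time
from typing import Callable


def wait_for_job(get_state: Callable[[], str], poll_seconds: float = 5.0,
                 timeout_seconds: float = 1800.0) -> str:
    """Poll until the job leaves QUEUED/RUNNING and return its final state."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        state = get_state()
        if state not in ("QUEUED", "RUNNING"):
            return state
        # Give up rather than wait forever on a stuck job.
        if time.monotonic() >= deadline:
            raise TimeoutError(f"job still {state} after {timeout_seconds}s")
        time.sleep(poll_seconds)
```

Check the returned state before reading `optimized_prompt_messages`, since a job can also leave the running states without succeeding.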

## Save the new system prompt template

Save the template so you can reuse it later without rerunning the optimization.

In [None]:
save_file('aesop_ai_dspy_prompt.json.jinja', optimized)
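When you reload a saved template, a quick sanity check catches truncated or hand-edited files before you spend API calls on them. The sketch below assumes the Jinja placeholders (such as `{{ input }}`) appear only inside JSON string values, which holds for the output of `json.dumps` in the cell above, so the unrendered file is still valid JSON. `load_prompt_template` is a hypothetical helper, not part of the Pi SDK.

```python
import json
from pathlib import Path


def load_prompt_template(path: str) -> str:
    """Read a saved template and check that it is a JSON list of chat messages.

    Assumes Jinja placeholders only appear inside JSON string values, so the
    unrendered template still parses as JSON.
    """
    text = Path(path).read_text(encoding="utf-8")
    messages = json.loads(text)
    if not isinstance(messages, list):
        raise ValueError("expected a JSON list of messages")
    for i, message in enumerate(messages):
        if not isinstance(message, dict) or not {"role", "content"} <= message.keys():
            raise ValueError(f"message {i} is missing 'role' or 'content'")
    return text
```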

## (If resuming) Load the system prompt template

If you don't want to wait for the optimization to finish, load a pre-optimized template from the cookbook instead.

In [None]:
optimized = httpx.get("https://raw.githubusercontent.com/withpi/cookbook-withpi/refs/heads/main/prompts/aesop_ai_dspy_prompt.json.jinja").text

In [None]:
print(optimized)

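## Run inference on the test split

DSPy emits a Jinja2 template for the chat messages, so inference requires rendering the template with each input and then extracting the response from the model's output.  The cell below runs both the original and the optimized prompt on the held-out test split and scores every output against the contract.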
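The extraction step relies on DSPy's chat format, which wraps each output field in `[[ ## field ## ]]` markers and ends with `[[ ## completed ## ]]`; in this template the output field is named `response`. The sketch reuses the regex from `do_bulk_templated_inference` on a made-up completion, with a fallback to the raw text when the markers are missing.

```python
import re

# Same pattern as do_bulk_templated_inference: capture everything between the
# last "response" marker and the "completed" marker.
RESULT_EXTRACTOR = re.compile(
    r".*\[\[ ## response ## \]\](.*)\[\[ ## completed ## \]\]", re.DOTALL)


def extract_response(raw: str) -> str:
    """Return the response field, or the stripped raw completion if the model
    ignored the field markers."""
    match = RESULT_EXTRACTOR.match(raw)
    return (match.group(1) if match else raw).strip()


# A made-up completion in DSPy's marker format.
sample = (
    "[[ ## reasoning ## ]]\nThe user wants a fable.\n\n"
    "[[ ## response ## ]]\nOnce upon a time...\n\n"
    "[[ ## completed ## ]]"
)
```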

In [None]:
# LiteLLM reads the Gemini key from GEMINI_API_KEY.
os.environ["GEMINI_API_KEY"] = userdata.get('GOOGLE_API_KEY')

inference_model = "gemini/gemini-1.5-flash-8b"
test_split = aesop['test']

# Generate outputs with the original system prompt and with the optimized template.
original_outputs = do_bulk_inference(test_split, aesop_contract.description, inference_model)
optimized_outputs = do_bulk_templated_inference(test_split, optimized, inference_model)

aesop_updated = (test_split
                 .add_column('original_output', original_outputs)
                 .add_column('optimized_output', optimized_outputs))
display(aesop_updated)

# Score each output against the contract.
original_scores = [client.contracts.score(
    contract=aesop_contract,
    llm_input=row["input"],
    llm_output=row["original_output"],
).total_score for row in aesop_updated]

print("Original Scores")
print(original_scores)

optimized_scores = [client.contracts.score(
    contract=aesop_contract,
    llm_input=row["input"],
    llm_output=row["optimized_output"],
).total_score for row in aesop_updated]

print("Optimized Scores")
print(optimized_scores)
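Two raw lists of scores are hard to compare by eye. Because both lists are scored on the same test rows in the same order, you can summarize them pairwise. `compare_scores` is a hypothetical helper; call it as `compare_scores(original_scores, optimized_scores)`.

```python
from statistics import mean


def compare_scores(original: list[float], optimized: list[float]) -> dict:
    """Summarize paired per-example contract scores from the same test rows."""
    if len(original) != len(optimized):
        raise ValueError("score lists must be paired row by row")
    # Positive delta means the optimized prompt scored higher on that row.
    deltas = [b - a for a, b in zip(original, optimized)]
    return {
        "original_mean": mean(original),
        "optimized_mean": mean(optimized),
        "mean_delta": mean(deltas),
        "wins": sum(d > 0 for d in deltas),
        "losses": sum(d < 0 for d in deltas),
        "ties": sum(d == 0 for d in deltas),
    }
```

With only a 10% holdout, a small mean difference may be noise; the win/loss counts give a rough sense of whether the improvement is consistent across rows.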

## Next Steps

You now have a prompt optimized on a small sample set.  You could deploy it as is, but improving the training set or the contract will likely give you better performance.  Check out the rest of the playgrounds to continue from here.