<a href="https://colab.research.google.com/github/withpi/cookbook-withpi/blob/main/colabs/Cascading.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://withpi.ai"><img src="https://withpi.ai/logoFullBlack.svg" width="240"></a>

<a href="https://code.withpi.ai"><font size="4">Documentation</font></a>

<a href="https://play.withpi.ai"><font size="4">Technique Catalog</font></a>

# Cascading

This Colab is the companion to the Cascading playground.

Let's say you have two models: one expensive and slow but high quality, the other fast and cheap but relatively lower quality.

You can improve both cost and latency by using the faster model for simple tasks and the slower model for complex ones. The trick is to know which input is which. Contracts help you do that.

## Install and initialize SDK

Connect to a regular CPU Python 3 runtime.  You won't need GPUs for this notebook.

You'll need a WITHPI_API_KEY from https://play.withpi.ai.  Add it to your notebook secrets (the key symbol) on the left.

Run the cell below to install packages and load the SDK.

In [None]:
%%capture

%pip install withpi litellm httpx datasets jinja2 tqdm

# Import a bunch of useful libraries for later.
from concurrent.futures import ThreadPoolExecutor
import json
import os
from pathlib import Path
import re

import datasets
from google.colab import files, userdata
import httpx
import jinja2
from litellm import completion
from tqdm.notebook import tqdm
from withpi import PiClient
from withpi.types import Contract

# Load the notebook secret into the environment so the Pi Client can access it.
os.environ["WITHPI_API_KEY"] = userdata.get('WITHPI_API_KEY')

client = PiClient()

def print_contract(contract: Contract):
  """print_contract pretty-prints a contract"""
  for dimension in contract.dimensions:
    print(dimension.label)
    for sub_dimension in dimension.sub_dimensions:
      print(f"\t{sub_dimension.description}")

def generate(system: str, user: str, model: str) -> str:
  """generate passes the provided system and user prompts into the given model
  via LiteLLM"""
  messages = [
    {
      "content": system,
      "role": "system"
    },
    {
      "content": user,
      "role": "user"
    }
  ]
  return completion(model=model,
                    messages=messages).choices[0].message.content

class printer(str):
  """printer makes strings with embedded newlines print more nicely"""
  def __repr__(self):
    return self

def print_response(response: str):
  """print_response pretty-prints an LLM response, respecting newlines"""
  display(printer(response))

def print_scores(pi_scores):
  """print_scores pretty-prints a Pi Score response as a table."""
  for dimension_name, dimension_scores in pi_scores.dimension_scores.items():
    print(f"{dimension_name}: {dimension_scores.total_score}")
    for subdimension_name, subdimension_score in dimension_scores.subdimension_scores.items():
      print(f"\t{subdimension_name}: {subdimension_score}")
    print("\n")
  print("---------------------")
  print(f"Total score: {pi_scores.total_score}")

def save_file(filename: str, model: str):
  """save_file offers to download the model with the given filename"""
  Path(filename).write_text(model)
  files.download(filename)

def load_contract(url: str) -> Contract:
  """load_contract pulls a Contract JSON blob locally with validation."""
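  resp = httpx.get(url)
  # Surface HTTP errors (e.g. a 404) instead of failing later during validation.
  resp.raise_for_status()
  return Contract.model_validate_json(resp.content)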

def load_and_split_dataset(url: str) -> datasets.DatasetDict:
  """load_and_split_dataset pulls in the Parquet file at url and does a 90/10 split"""
  return datasets.load_dataset('parquet', data_files=url, split="train").train_test_split(test_size=0.1)

def do_bulk_inference(dataset, system, model):
  """do_bulk_inference performs inference on the 'input' column of dataset, using
  the provided system prompt.  The model identified will be used via LiteLLM"""

  def do_generate(user, pbar):
    result = generate(system, user, model)
    pbar.update(1)
    return result

  futures = []
  pbar = tqdm(total=len(dataset))
  with ThreadPoolExecutor(max_workers=4) as executor:
    for row in dataset:
      futures.append(executor.submit(do_generate, row["input"], pbar))
  return [future.result() for future in futures]

def do_bulk_templated_inference(dataset, optimized, model):
  """do_bulk_templated_inference performs inference on the 'input' column of dataset,
  using the provided optimized prompt.  It should be a Jinja2 template as returned
  by DSPy"""
  prompt_template = jinja2.Template(optimized)
  result_extractor = re.compile(r".*\[\[ ## response ## \]\](.*)\[\[ ## completed ## \]\]", re.DOTALL)

  def do_generate(prompt: str, pbar) -> str:
    messages = json.loads(prompt_template.render(input=prompt))
    result = completion(model=model,
                        messages=messages).choices[0].message.content

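    pbar.update(1)
    # Fail with a clear message if the model did not emit the expected markers.
    match = result_extractor.match(result)
    if match is None:
      raise ValueError("response is missing the [[ ## response ## ]] / [[ ## completed ## ]] markers")
    return match.group(1)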

  futures = []
  pbar = tqdm(total=len(dataset))
  with ThreadPoolExecutor(max_workers=4) as executor:
    for row in dataset:
      futures.append(executor.submit(do_generate, row["input"], pbar))
  return [future.result() for future in futures]

def stream_response(job_id: str, method):
  """stream_response streams messages from the provided method

  method should be a Pi client object with `retrieve` and `stream_messages`
  endpoints.  This is primarily for convenience."""

  while True:
    response = method.retrieve(job_id=job_id)
    if (response.state != 'QUEUED') and (response.state != 'RUNNING'):
      return response

    # Use a separate name so the job status in `response` is not shadowed.
    with method.with_streaming_response.stream_messages(
        job_id=job_id, timeout=None) as stream:
      for line in stream.iter_lines():
        print(line)


## Load a contract and data

We'll keep using a pre-built contract with sample inputs, but feel free to bring your own.


In [None]:
aesop_contract = load_contract("https://raw.githubusercontent.com/withpi/cookbook-withpi/refs/heads/main/contracts/aesop_ai.json")
aesop = load_and_split_dataset("https://raw.githubusercontent.com/withpi/cookbook-withpi/refs/heads/main/datasets/aesop_ai_examples.parquet")
aesop

## Evaluate the Contract on different models

Let's try generating responses from a "big" model and a "small" one to compare scores.

To use a different model, change the model names below and supply the matching API key. The LiteLLM docs at https://docs.litellm.ai/docs/ list the supported providers.

You can import a Google Gemini key from AI Studio using the secrets pane on the left, which populates a GOOGLE_API_KEY secret. Getting a key this way is essentially free.

In [None]:
os.environ["GEMINI_API_KEY"] = userdata.get('GOOGLE_API_KEY')

aesop_updated = aesop['test'].add_column(
    'big_model_output', do_bulk_inference(aesop['test'], aesop_contract.description, "gemini/gemini-2.0-flash")
).add_column(
    'small_model_output', do_bulk_inference(aesop['test'], aesop_contract.description, "gemini/gemini-1.5-flash-8b"))

big_scores = [client.contracts.score(
    contract=aesop_contract,
    llm_input=row["input"],
    llm_output=row["big_model_output"],
).total_score for row in aesop_updated]


small_scores = [client.contracts.score(
    contract=aesop_contract,
    llm_input=row["input"],
    llm_output=row["small_model_output"],
).total_score for row in aesop_updated]


# Pad the first column so the header and the rows line up.
print(f"{'Big model':<20} | Small model")
for big_score, small_score in zip(big_scores, small_scores):
  print(f"{str(big_score):<20} | {small_score}")
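
To read the side-by-side scores at a glance, it can help to reduce them to a few numbers. The sketch below is not part of the Pi SDK: `summarize_scores` and its `threshold` parameter are hypothetical, and the sample lists are made up so that the block runs on its own. It also assumes that a higher score means a better response. In the notebook you would pass in the `big_scores` and `small_scores` lists computed above.

```python
from statistics import mean


def summarize_scores(big_scores, small_scores, threshold=0.8):
  """Summarize paired scores (hypothetical helper, not part of the Pi SDK)."""
  if len(big_scores) != len(small_scores) or not big_scores:
    raise ValueError("expected two non-empty lists of equal length")
  # Inputs whose small-model score clears the threshold would not need the big model.
  small_ok = sum(1 for s in small_scores if s >= threshold)
  return {
      "big_mean": mean(big_scores),
      "small_mean": mean(small_scores),
      "small_good_enough": small_ok / len(small_scores),
  }


# Made-up scores for illustration only; substitute big_scores and small_scores.
summary = summarize_scores([1.0, 0.5, 1.0, 0.5], [1.0, 0.25, 0.75, 0.5], threshold=0.75)
print(summary)
```

The `small_good_enough` fraction is a rough estimate of how many inputs the small model could handle alone at the chosen threshold.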

## Next Steps

From here, you can build a simple cascade: always run the cheap model first, score its output, and fall back to the big model only when the score is below a threshold.
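
A threshold-based fallback of this kind can be sketched as follows. This is a minimal illustration under stated assumptions: `cascade`, its parameter names, and the default `threshold` of 0.8 are hypothetical, and the toy callables at the bottom stand in for real models and a real scorer so that the block runs by itself.

```python
from typing import Callable, Tuple


def cascade(
    prompt: str,
    generate_small: Callable[[str], str],
    generate_big: Callable[[str], str],
    score: Callable[[str, str], float],
    threshold: float = 0.8,  # Assumed cutoff; tune it on your own data.
) -> Tuple[str, str]:
  """Try the cheap model first and fall back to the big one on a low score.

  Returns the chosen output and which model produced it ("small" or "big").
  """
  small_output = generate_small(prompt)
  if score(prompt, small_output) >= threshold:
    return small_output, "small"
  return generate_big(prompt), "big"


# Toy stand-ins so the sketch runs without API keys (hypothetical behavior).
def toy_small(prompt: str) -> str:
  return f"small answer to: {prompt}"


def toy_big(prompt: str) -> str:
  return f"big answer to: {prompt}"


def toy_score(prompt: str, output: str) -> float:
  # Pretend the small model only does well on short prompts.
  return 0.9 if len(prompt) < 20 else 0.3


easy = cascade("Tell a short fable", toy_small, toy_big, toy_score)
hard = cascade("Tell a long fable with three morals", toy_small, toy_big, toy_score)
print(easy)
print(hard)
```

In this notebook, `generate_small` and `generate_big` could wrap the `generate` helper with the two Gemini model names, and `score` could return `client.contracts.score(...).total_score` for the Aesop contract.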

A further improvement is to train a classifier that predicts whether a given prompt is likely to need the more powerful model, so that you can start with the big model when warranted. That is covered in a follow-up Colab.