<a href="https://colab.research.google.com/github/withpi/cookbook-withpi/blob/main/colabs/Input_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://withpi.ai"><img src="https://withpi.ai/logoFullBlack.svg" width="240"></a>

<a href="https://code.withpi.ai"><font size="4">Documentation</font></a>

<a href="https://play.withpi.ai"><font size="4">Technique Catalog</font></a>

# Input Generation

There is no Playground associated with this Colab, but it's coming soon!

Many techniques require input data to drive evaluation and training, but getting high-quality data can be painful and expensive.

Generating this data with AI support can give you a higher-quality set with much less effort. And it can be done with the same Contract that drives other techniques in Pi!

We will walk through the same `Aesop AI` example, but you can load any contract here. Let's dig in!

## Install and initialize SDK

Connect to a regular CPU Python 3 runtime.  You won't need GPUs for this notebook.

You'll need a WITHPI_API_KEY from https://play.withpi.ai.  Add it to your notebook secrets (the key symbol) on the left.

Run the cell below to install packages and load the SDK.

In [4]:
%%capture

import os
from google.colab import files, userdata

# Load the notebook secret into the environment so the Pi Client can access it.
os.environ["WITHPI_API_KEY"] = userdata.get('WITHPI_API_KEY')

%pip install withpi litellm httpx datasets jinja2 tqdm

# Import a bunch of useful libraries for later.
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
import time
import json
from pathlib import Path
import re

import datasets
import httpx
import litellm
import jinja2
from tqdm.notebook import tqdm
from withpi import PiClient
from withpi.types import Contract

from rich.console import Console
from rich.table import Table
from rich.live import Live

console = Console()

client = PiClient()

def print_contract(contract: Contract):
  """print_contract pretty-prints a contract"""
  for dimension in contract.dimensions:
    print(dimension.label)
    for sub_dimension in dimension.sub_dimensions:
      print(f"\t{sub_dimension.description}")

def generate(system: str, user: str, model: str) -> str:
  """generate passes the provided system and user prompts into the given model
  via LiteLLM"""
  messages = [
    {
      "content": system,
      "role": "system"
    },
    {
      "content": user,
      "role": "user"
    }
  ]
  return litellm.completion(model=model,
                            messages=messages).choices[0].message.content

class printer(str):
  """printer makes strings with embedded newlines print more nicely"""
  def __repr__(self):
    return self
def print_response(response: str):
  """print_response pretty-prints an LLM response, respecting newlines"""
  display(printer(response))

def print_scores(pi_scores):
  """print_scores pretty-prints a Pi Score response as a table."""
  for dimension_name, dimension_scores in pi_scores.dimension_scores.items():
    print(f"{dimension_name}: {dimension_scores.total_score}")
    for subdimension_name, subdimension_score in dimension_scores.subdimension_scores.items():
      print(f"\t{subdimension_name}: {subdimension_score}")
    print("\n")
  print("---------------------")
  print(f"Total score: {pi_scores.total_score}")

def save_file(filename: str, model: str):
  """save_file offers to download the model with the given filename"""
  Path(filename).write_text(model)
  files.download(filename)

def load_contract(url: str) -> Contract:
  """load_contract pulls a Contract JSON blob locally with validation."""
  resp = httpx.get(url)
  return Contract.model_validate_json(resp.content)

def load_and_split_dataset(url: str) -> datasets.DatasetDict:
  """load_and_split_dataset pulls in the Parquet file at url and does a 90/10 split"""
  return datasets.load_dataset('parquet', data_files=url, split="train").train_test_split(test_size=0.1)

def do_bulk_inference(dataset, system, model):
  """do_bulk_inference performs inference on the 'input' column of dataset, using
  the provided system prompt.  The model identified will be used via LiteLLM"""

  def do_generate(user, pbar):
    result = generate(system, user, model)
    pbar.update(1)
    return result

  futures = []
  pbar = tqdm(total=len(dataset))
  with ThreadPoolExecutor(max_workers=4) as executor:
    for row in dataset:
      futures.append(executor.submit(do_generate, row["input"], pbar))
  return [future.result() for future in futures]

def do_bulk_templated_inference(dataset, optimized, model):
  """do_bulk_templated_inference performs inference on the 'input' column of dataset,
  using the provided optimized prompt.  It should be a Jinja2 template as returned
  by DSPy"""
  prompt_template = jinja2.Template(optimized)
  result_extractor = re.compile(r".*\[\[ ## response ## \]\](.*)\[\[ ## completed ## \]\]", re.DOTALL)

  def do_generate(prompt: str, pbar) -> str:
    messages = json.loads(prompt_template.render(input=prompt))
    result = litellm.completion(model=model,
                                messages=messages).choices[0].message.content

    pbar.update(1)
    return result_extractor.match(result).group(1)

  futures = []
  pbar = tqdm(total=len(dataset))
  with ThreadPoolExecutor(max_workers=4) as executor:
    for row in dataset:
      futures.append(executor.submit(do_generate, row["input"], pbar))
  return [future.result() for future in futures]

def generate_table(
    job_id: str, training_data: dict, is_done: bool, additional_columns: dict[str, str]
):
    """Generate a training progress table dynamically."""
    table = Table(title=f"Training Status for {job_id}")

    # Define columns
    table.add_column("Step", justify="right", style="cyan")
    table.add_column("Epoch", justify="right", style="cyan")
    table.add_column("Learning Rate", justify="right", style="cyan")
    table.add_column("Train Loss", justify="right", style="magenta")
    table.add_column("Eval Loss", justify="right", style="green")
    for header in additional_columns.keys():
        table.add_column(header, justify="right", style="black")

    def format_num(num: float | None, digits: int = 4) -> str:
        if num is None:
            return "X"
        return format(num, f".{digits}f")

    for step, data in training_data.items():
        additional_columns_data = [
            format_num(data.get(column_name, None))
            for column_name in additional_columns.values()
        ]
        table.add_row(
            str(step),
            format_num(data.get("epoch", None)),
            format_num(data.get("learning_rate", None), digits=10),
            format_num(data.get("loss", None)),
            format_num(data.get("eval_loss", None)),
            *additional_columns_data,
        )

    if not is_done:
        # Placeholder row: one blank cell for each column after "Step".
        table.add_row("...", *([""] * (4 + len(additional_columns))))

    return table


def stream_response(job_id: str, method, additional_columns: dict[str, str]={}):
    """stream_response streams messages from the provided method

    method should be a Pi client object with `retrieve` and `stream_messages`
    endpoints.  This is primarily for convenience."""

    training_data = defaultdict(dict)
    is_log_console = False

    while True:
        response = method.retrieve(job_id=job_id)
        if (response.state != "QUEUED") and (response.state != "RUNNING"):
            if response.state == "DONE" and not is_log_console:
                for line in response.detailed_status:
                    try:
                        data_dict = json.loads(line)
                        training_data[data_dict["step"]].update(data_dict)
                    except Exception:
                        pass
                console.print(
                    generate_table(
                        job_id,
                        training_data,
                        is_done=True,
                        additional_columns=additional_columns,
                    )
                )
            return response

        with method.with_streaming_response.stream_messages(
            job_id=job_id, timeout=None
        ) as response:
            with Live(auto_refresh=True, console=console, refresh_per_second=4) as live:
                is_done = False
                for line in response.iter_lines():
                    if line == "DONE":
                        is_done = True
                    try:
                        data_dict = json.loads(line)
                        training_data[data_dict["step"]].update(data_dict)
                    except Exception:
                        pass
                    live.update(
                        generate_table(
                            job_id,
                            training_data,
                            is_done,
                            additional_columns=additional_columns,
                        )
                    )
                    is_log_console = True


## Load contract

Load the `Aesop AI` example from Pi Labs cookbooks, or edit below to load a different one.


In [5]:
aesop_contract = load_contract("https://raw.githubusercontent.com/withpi/cookbook-withpi/refs/heads/main/contracts/aesop_ai.json")

## Generate an Input Set

Given this structured description, let's build a Dataset containing a bunch of plausible moral lessons that could be used to exercise the contract.  This will take about 30 seconds to generate.

In [18]:
data_generation_status = client.data.inputs.generate_from_seeds.generate(
    application_description=aesop_contract.description,
    num_inputs_to_generate=10,
    seeds=[],
    batch_size=3,
    num_shots=3,
)

## Stream the messages as the inputs are generated

The messages provide detail about what is being done to generate the data.

In [None]:
while True:
    job_status = client.data.inputs.generate_from_seeds.retrieve(data_generation_status.job_id)
    if job_status.state in ["ERROR", "DONE"]:
        break

    # Stream messages
    with client.data.inputs.with_streaming_response.generate_from_seeds.stream_messages(data_generation_status.job_id) as response:
        for line in response.iter_lines():
            print(line)
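The loop above checks the job state again as soon as a message stream closes. If you would rather wait without streaming, a minimal polling sketch with a delay and a timeout could look like the following. `wait_for_job`, `poll_seconds`, and `timeout_seconds` are hypothetical names, not part of the Pi SDK; the only assumption is that the retrieve call returns an object with a `state` attribute, as used above.

```python
import time
from typing import Any, Callable

# Terminal states, matching the ones checked in the loop above.
TERMINAL_STATES = {"DONE", "ERROR"}


def wait_for_job(
    retrieve: Callable[[], Any],
    poll_seconds: float = 2.0,
    timeout_seconds: float = 300.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Any:
    """Hypothetical helper: call retrieve() until the job reaches a terminal state.

    retrieve is any zero-argument callable returning an object with a
    `state` attribute. Raises TimeoutError if the deadline passes first.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = retrieve()
        if status.state in TERMINAL_STATES:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Job still in state {status.state!r} after {timeout_seconds}s")
        # Pause between polls instead of hammering the API.
        sleep(poll_seconds)


# Usage in this notebook (requires a running job, so it is left commented out):
# job_status = wait_for_job(
#     lambda: client.data.inputs.generate_from_seeds.retrieve(data_generation_status.job_id)
# )
```

Passing `sleep` as a parameter keeps the helper easy to exercise with a fake in place of a real delay.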

You can also stream the generated inputs themselves, as shown in the cell below.

In [None]:
with client.data.inputs.with_streaming_response.generate_from_seeds.stream_data(data_generation_status.job_id) as response:
    for line in response.iter_lines():
        print(f"Generated Input: {line}")

## Take a look at the returned data

Once the job reaches the `DONE` state, the job status carries the full list of generated inputs.

In [None]:
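# Refresh the status so that re-running this cell picks up a finished job.
job_status = client.data.inputs.generate_from_seeds.retrieve(data_generation_status.job_id)

# Print all the data now that the job is complete.
if job_status.state not in ["ERROR", "DONE"]:
    print("Please wait for the job to finish and then run this cell again...")
elif job_status.state == "DONE":
    print("Printing all the generated inputs below...")
    assert job_status.data is not None
    # Use a name that does not shadow the built-in input().
    for generated_input in job_status.data:
        print(generated_input)
else:
    print("Job ended in error")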
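Before running inference on the generated inputs, it can be worth normalizing them. The sketch below strips surrounding whitespace, drops empty entries, and removes exact duplicates while keeping the original order. `clean_inputs` is a hypothetical helper, and it assumes the generated inputs are plain strings; the example strings are made up for illustration and are not real output from the job.

```python
def clean_inputs(raw_inputs):
    """Hypothetical helper: strip, drop blanks, and de-duplicate, preserving order."""
    seen = set()
    cleaned = []
    for text in raw_inputs:
        text = text.strip()
        # Keep only non-empty strings that have not been seen yet.
        if text and text not in seen:
            seen.add(text)
            cleaned.append(text)
    return cleaned


# Illustrative strings only, not real generated inputs.
example = clean_inputs(["  Honesty pays. ", "Honesty pays.", "", "Slow and steady wins."])
print(example)  # ['Honesty pays.', 'Slow and steady wins.']
```

In this notebook the call would be `clean_inputs(job_status.data)` once the job is `DONE`.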

## Augment with responses

Now run inference with your favorite LLM to generate responses to the default prompt. We will optimize this later.

The cell below uses LiteLLM with Gemini. You can get a Gemini key for free from the left side to try it out; store it in your notebook secrets as `GOOGLE_API_KEY`. You may need a small amount of money in your account for the cell to run to completion.

In [None]:
os.environ["GEMINI_API_KEY"] = userdata.get('GOOGLE_API_KEY')

input_set = datasets.Dataset.from_dict({'input': job_status.data})
input_set = input_set.add_column(
    'output', do_bulk_inference(input_set, aesop_contract.description, "gemini/gemini-1.5-flash-8b"))
input_set

## Save the set

We will come back to this in a future colab, so it's useful to capture.  Store it as a Parquet table, which you can download.

Alternatively, upload to Hugging Face.

In [None]:
from google.colab import files

filename = "aesop_ai_examples.parquet"
input_set.to_parquet(filename)
files.download(filename)

## Next Steps

This input set can drive many other techniques in Pi.  You can adjust the above methods to add seeds and steer the AI in different ways.
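As one way to add seeds, the sketch below builds the keyword arguments for the same `generate` call used earlier, with a non-empty `seeds` list. `seeded_generation_kwargs` is a hypothetical helper, the seed strings are invented examples, and passing seeds as plain strings is an assumption, since the notebook only ever passes `seeds=[]`. The `batch_size` and `num_shots` values are copied from the earlier cell.

```python
def seeded_generation_kwargs(description, seeds, num_inputs=10):
    """Hypothetical helper: keyword arguments for a seeded generation call.

    Parameter names mirror the generate() call earlier in this notebook.
    """
    return {
        "application_description": description,
        "num_inputs_to_generate": num_inputs,
        "seeds": list(seeds),  # Assumption: seeds are plain example inputs.
        "batch_size": 3,       # Same value as the unseeded call above.
        "num_shots": 3,        # Same value as the unseeded call above.
    }


# Invented seed examples in the spirit of the Aesop AI contract.
example_seeds = [
    "Slow and steady wins the race.",
    "Honesty is the best policy.",
]
kwargs = seeded_generation_kwargs("A storyteller that writes fables", example_seeds)

# In the notebook, pass aesop_contract.description and start the job
# (requires a WITHPI_API_KEY, so it is left commented out):
# data_generation_status = client.data.inputs.generate_from_seeds.generate(
#     **seeded_generation_kwargs(aesop_contract.description, example_seeds)
# )
```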