<a href="https://colab.research.google.com/github/withpi/cookbook-withpi/blob/main/colabs/SFT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://withpi.ai"><img src="https://withpi.ai/logoFullBlack.svg" width="240"></a>

<a href="https://code.withpi.ai"><font size="4">Documentation</font></a>

<a href="https://play.withpi.ai"><font size="4">Technique Catalog</font></a>

# Supervised Fine-tuning (SFT) with Standard Gradient Descent

This notebook is the companion to the SFT playground.

Description: Train models to more deeply learn patterns from your data.

## Install and initialize SDK

Connect to a regular CPU Python 3 runtime.  You won't need GPUs for this notebook.

You'll need a `WITHPI_API_KEY` from https://play.withpi.ai. Add it to your notebook secrets (the key icon in the left sidebar).

Run the cell below to install packages and load the SDK.

In [None]:
%%capture
# Setup cell: loads the API key from Colab secrets, installs the Python
# packages this notebook uses, imports them, and builds shared helpers/clients.
# NOTE(review): %%capture hides this cell's output, including any %pip
# install errors — remove it temporarily if setup appears to fail silently.

import os
from google.colab import files, userdata

# Load the notebook secret into the environment so the Pi Client can access it.
os.environ["WITHPI_API_KEY"] = userdata.get('WITHPI_API_KEY')

%pip install withpi litellm httpx datasets jinja2 tqdm

# Import a bunch of useful libraries for later.
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
import time
import json
from pathlib import Path
import re

import datasets
import httpx
import litellm
import jinja2
from tqdm.notebook import tqdm
from withpi import PiClient
from withpi.types import Contract

# NOTE(review): `rich` is not in the %pip line above; this presumably relies
# on the runtime already providing it — confirm outside Colab.
from rich.console import Console
from rich.table import Table
from rich.live import Live

# Shared console that the training-progress table is printed/rendered to.
console = Console()

# Pi client; presumably reads WITHPI_API_KEY from the environment set above.
client = PiClient()

def print_contract(contract: Contract):
  """Pretty-print ``contract``.

  Prints each dimension's label on its own line, followed by one
  tab-indented line per sub-dimension description.
  """
  for dim in contract.dimensions:
    print(dim.label)
    for sub in dim.sub_dimensions:
      print(f"\t{sub.description}")

def generate(system: str, user: str, model: str) -> str:
  """Send one system + user chat turn to ``model`` through LiteLLM.

  Returns the text content of the first choice in the completion.
  """
  messages = [
    {"content": content, "role": role}
    for role, content in (("system", system), ("user", user))
  ]
  completion = litellm.completion(model=model, messages=messages)
  return completion.choices[0].message.content

class printer(str):
  """A ``str`` whose repr is the string itself.

  Notebook display uses ``repr``; returning the raw text means embedded
  newlines show up as real line breaks instead of escape sequences.
  """

  def __repr__(self):
    # The instance already is a str, so it can serve as its own repr.
    as_text = self
    return as_text
def print_response(response: str):
  """Display an LLM response with its newlines rendered as line breaks.

  The text is wrapped in ``printer`` so that ``display`` shows the raw string
  rather than its quoted, escaped repr.
  """
  wrapped = printer(response)
  display(wrapped)

def print_scores(pi_scores):
  """Pretty-print a Pi Score response.

  For each dimension, prints ``name: total_score`` followed by its
  sub-dimension scores (tab-indented) and then two newlines. Finishes with a
  dashed rule and the overall total score.
  """
  for name, scores in pi_scores.dimension_scores.items():
    print(f"{name}: {scores.total_score}")
    for sub_name, sub_score in scores.subdimension_scores.items():
      print(f"\t{sub_name}: {sub_score}")
    print("\n")
  print("---------------------")
  print(f"Total score: {pi_scores.total_score}")

def save_file(filename: str, model: str):
  """Write ``model`` to ``filename`` and offer it as a browser download.

  The text is written as UTF-8 explicitly so the saved bytes don't depend on
  the runtime's locale encoding.

  Args:
    filename: Local path to write; also the name of the downloaded file.
    model: Text content to save.
  """
  Path(filename).write_text(model, encoding="utf-8")
  files.download(filename)

def load_contract(url: str) -> Contract:
  """Fetch a Contract JSON document from ``url`` and validate it.

  Redirects are followed. A non-success HTTP status raises immediately, so a
  bad URL produces a clear HTTP error instead of a confusing validation error
  on the body of an error page.

  Raises:
    httpx.HTTPStatusError: if the final response is not a 2xx.
  """
  resp = httpx.get(url, follow_redirects=True)
  resp.raise_for_status()
  return Contract.model_validate_json(resp.content)

def load_and_split_dataset(
    url: str, test_size: float = 0.1, seed: int | None = None
) -> datasets.DatasetDict:
  """Load the Parquet file at ``url`` and split it into train/test sets.

  Args:
    url: Location of the Parquet file.
    test_size: Fraction of rows placed in the test split (default 0.1, i.e. a
      90/10 split).
    seed: Optional shuffle seed for a reproducible split; ``None`` keeps the
      library's default (unseeded) behavior.

  Returns:
    A DatasetDict with "train" and "test" splits.
  """
  dataset = datasets.load_dataset('parquet', data_files=url, split="train")
  return dataset.train_test_split(test_size=test_size, seed=seed)

def do_bulk_inference(dataset, system, model, max_workers: int = 4):
  """do_bulk_inference performs inference on the 'input' column of dataset, using
  the provided system prompt.  The model identified will be used via LiteLLM.

  Args:
    dataset: Iterable of rows supporting ``row["input"]``; must support ``len()``.
    system: System prompt sent with every request.
    model: LiteLLM model identifier.
    max_workers: Number of concurrent requests (default 4).

  Returns:
    A list of completions in the same order as ``dataset``.

  Raises:
    Whatever ``generate`` raised for the first failing row (in dataset order),
    re-raised after all submitted requests have finished.
  """

  def do_generate(user, pbar):
    result = generate(system, user, model)
    pbar.update(1)
    return result

  # The context manager closes the progress bar once all work is done (or fails).
  with tqdm(total=len(dataset)) as pbar:
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
      futures = [executor.submit(do_generate, row["input"], pbar) for row in dataset]
    # The executor block has waited for every future, so these calls don't block.
    return [future.result() for future in futures]

def do_bulk_templated_inference(dataset, optimized, model, max_workers: int = 4):
  """do_bulk_templated_inference performs inference on the 'input' column of dataset,
  using the provided optimized prompt.  It should be a Jinja2 template as returned
  by DSPy.

  The template is rendered with ``input=row["input"]``. The result must be JSON,
  which is passed as ``messages`` to ``model`` via LiteLLM. Each reply must
  contain ``[[ ## response ## ]]`` ... ``[[ ## completed ## ]]``. The text
  captured between those markers is returned without stripping.

  Args:
    dataset: Iterable of rows with an "input" field; must support ``len()``.
    optimized: Jinja2 template source.
    model: LiteLLM model identifier.
    max_workers: Number of concurrent requests (default 4).

  Returns:
    The extracted responses, in dataset order.

  Raises:
    ValueError: if a reply lacks the expected response/completed markers.
  """
  prompt_template = jinja2.Template(optimized)
  result_extractor = re.compile(r".*\[\[ ## response ## \]\](.*)\[\[ ## completed ## \]\]", re.DOTALL)

  def do_generate(prompt: str, pbar) -> str:
    messages = json.loads(prompt_template.render(input=prompt))
    result = litellm.completion(model=model,
                                messages=messages).choices[0].message.content

    pbar.update(1)
    # `content` may be None; treat that as an empty (and thus non-matching) reply.
    match = result_extractor.match(result or "")
    if match is None:
      raise ValueError(
          "Model reply is missing the '[[ ## response ## ]]' ... "
          f"'[[ ## completed ## ]]' markers: {result!r}"
      )
    return match.group(1)

  # The context manager closes the progress bar once all work is done (or fails).
  with tqdm(total=len(dataset)) as pbar:
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
      futures = [executor.submit(do_generate, row["input"], pbar) for row in dataset]
    return [future.result() for future in futures]

def generate_table(
    job_id: str, training_data: dict, is_done: bool, additional_columns: dict[str, str]
):
    """Generate a training progress table dynamically.

    Args:
        job_id: Job identifier shown in the table title.
        training_data: Maps each step to a dict of the metrics reported for
            it. The keys read are "epoch", "learning_rate", "loss",
            "eval_loss", plus the keys named in ``additional_columns``.
            Missing values render as "X".
        is_done: When False, a trailing "..." row signals that more data is
            still expected.
        additional_columns: Maps an extra column header to the key read from
            each step's metrics dict.

    Returns:
        The populated rich ``Table``.
    """
    table = Table(title=f"Training Status for {job_id}")

    # Define columns
    table.add_column("Step", justify="right", style="cyan")
    table.add_column("Epoch", justify="right", style="cyan")
    table.add_column("Learning Rate", justify="right", style="cyan")
    table.add_column("Train Loss", justify="right", style="magenta")
    table.add_column("Eval Loss", justify="right", style="green")
    for header in additional_columns.keys():
        table.add_column(header, justify="right", style="black")

    def format_num(num: float | None, digits: int = 4) -> str:
        if num is None:
            return "X"
        return format(num, f".{digits}f")

    for step, data in training_data.items():
        additional_columns_data = [
            format_num(data.get(column_name, None))
            for column_name in additional_columns.values()
        ]
        table.add_row(
            str(step),
            format_num(data.get("epoch", None)),
            format_num(data.get("learning_rate", None), digits=10),
            format_num(data.get("loss", None)),
            format_num(data.get("eval_loss", None)),
            *additional_columns_data,
        )

    if not is_done:
        # Pad the placeholder row to the real column count. A fixed six-cell
        # row only matched tables with exactly one additional column.
        table.add_row("...", *[""] * (len(table.columns) - 1))

    return table


def _ingest_status_line(training_data, line) -> None:
    """Merge one JSON status line into ``training_data``, keyed by its "step".

    Lines that aren't JSON objects with a hashable "step" field are ignored.
    Examples are the literal "DONE" sentinel and any other non-metric text.
    """
    try:
        data_dict = json.loads(line)
        training_data[data_dict["step"]].update(data_dict)
    except (ValueError, KeyError, TypeError):
        # Best-effort by design: non-metric lines are expected in the stream.
        # ValueError covers json.JSONDecodeError; KeyError a missing "step";
        # TypeError non-dict JSON values or non-str input.
        pass


def stream_response(job_id: str, method, additional_columns: dict[str, str]):
    """stream_response streams messages from the provided method

    method should be a Pi client object with `retrieve` and `stream_messages`
    endpoints.  This is primarily for convenience.

    The job is polled with ``method.retrieve``. While its state is "QUEUED" or
    "RUNNING", status lines are streamed and a live table is redrawn after
    each line. Once the job is in any other state, that final ``retrieve``
    response is returned. If the job ended "DONE" and nothing was streamed
    during this call, the table is rebuilt from ``detailed_status`` and
    printed once.

    Args:
        job_id: Job to follow.
        method: Pi client resource, e.g. ``client.model.sft``.
        additional_columns: Extra table columns, passed to ``generate_table``.

    Returns:
        The final job response from ``method.retrieve``.
    """
    training_data = defaultdict(dict)
    # Whether the live table has already rendered streamed data; if so, the
    # final static table is not printed a second time.
    streamed_any = False

    while True:
        job = method.retrieve(job_id=job_id)
        if job.state not in ("QUEUED", "RUNNING"):
            if job.state == "DONE" and not streamed_any:
                for line in job.detailed_status:
                    _ingest_status_line(training_data, line)
                console.print(
                    generate_table(
                        job_id,
                        training_data,
                        is_done=True,
                        additional_columns=additional_columns,
                    )
                )
            return job

        with method.with_streaming_response.stream_messages(
            job_id=job_id, timeout=None
        ) as stream:
            with Live(auto_refresh=True, console=console, refresh_per_second=4) as live:
                is_done = False
                for line in stream.iter_lines():
                    if line == "DONE":
                        is_done = True
                    _ingest_status_line(training_data, line)
                    live.update(
                        generate_table(
                            job_id,
                            training_data,
                            is_done,
                            additional_columns=additional_columns,
                        )
                    )
                    streamed_any = True


# Load a contract and dataset

We have a pre-existing contract you can play with, along with the first 200 examples of the `withpi/tldr` dataset.


In [None]:
import datasets

tldr_contract = load_contract(
    "https://raw.githubusercontent.com/withpi/cookbook-withpi/refs/heads/main/contracts/tldr.json"
)

# Take the first `num_examples` rows of the train split. The count is clamped
# to the split's size, so a smaller dataset doesn't request out-of-range rows.
num_examples = 200
tldr_train = datasets.load_dataset("withpi/tldr")["train"]
tldr_data = tldr_train.select(range(min(num_examples, len(tldr_train))))

print(tldr_data)

## Kick off the job

The SFT job performs a 90/10 train/test split internally, which is why the data is not split before it is passed in.

This process takes a while; please be patient while a cloud GPU is acquired, fine-tuning is performed, and a result is returned.

In [None]:
# Turn each dataset row into the llm_input/llm_output dict passed as `examples`.
sft_examples = [
    {"llm_input": row["prompt"], "llm_output": row["completion"]}
    for row in tldr_data
]
status = client.model.sft.start_job(
    contract=tldr_contract,
    examples=sft_examples,
    base_sft_model="LLAMA_3.2_3B",
    num_train_epochs=5,
)
print(status)

## Monitor for completion

Now run the following cell to see how training is progressing.

In [None]:
response = stream_response(
    status.job_id,
    client.model.sft,
    additional_columns={"Pi Score": "contract_score"},
)
# stream_response returns as soon as the job leaves QUEUED/RUNNING, whatever
# the resulting state. If no model was produced, fail with the final state
# rather than an opaque IndexError.
if not response.trained_models:
    raise RuntimeError(
        f"SFT job {status.job_id} ended in state {response.state!r} "
        "without a trained model"
    )
print(f"SFT model = {response.trained_models[0].model_dump_json(indent=2)}")

# Now load the model!

Run the following cell to load the trained model into an inference backend and wait until it is ready.

In [None]:
client.model.sft.load(status.job_id)

# Poll every 3 seconds, up to 200 attempts, until the model reports it's loaded.
for _attempt in range(200):
    is_done = client.model.sft.check(status.job_id)
    if is_done:
        print("Loaded!")
        break
    time.sleep(3)
else:
    # Loop finished without a break: every check came back not-done.
    print("Did not load in time.")

# Query the model!

Models are hosted on [Fireworks](https://fireworks.ai) and are accessible with your API key through your favorite frontend. Try the following:

In [None]:
# Example Reddit post ending in "TL;DR:". This presumably matches the prompt
# format of the withpi/tldr rows used for training, so the model should
# continue with a summary.
prompt = """SUBREDDIT: r/relationships TITLE: I (f/22) have to figure out if I want to still know these girls or not and would hate to sound insulting POST: Not sure if this belongs here but it's worth a try. Backstory: When I (f/22) went through my first real breakup 2 years ago because he needed space after a year of dating roand it effected me more than I thought. It was a horrible time in my life due to living with my mother and finally having the chance to cut her out of my life. I can admit because of it was an emotional wreck and this guy was stable and didn't know how to deal with me. We ended by him avoiding for a month or so after going to a festival with my friends. When I think back I wish he just ended. So after he ended it added my depression I suffered but my friends helped me through it and I got rid of everything from him along with cutting contact. Now: Its been almost 3 years now and I've gotten better after counselling and mild anti depressants. My mother has been out of my life since then so there's been alot of progress. Being stronger after learning some lessons there been more insight about that time of my life but when I see him or a picture everything comes back. The emotions and memories bring me back down. His friends (both girls) are on my facebook because we get along well which is hard to find and I know they'll always have his back. But seeing him in a picture or talking to him at a convention having a conversation is tough. Crying confront of my current boyfriend is something I want to avoid. So I've been thinking that I have to cut contact with these girls because it's time to move on because it's healthier. It's best to avoid him as well. But will they be insulted? Will they accept it? Is there going to be awkwardness? I'm not sure if it's the right to do and could use some outside opinions. TL;DR:"""

# Both requests below go to the same endpoint for this job's fine-tuned model.
sft_endpoint = dict(
    model="fireworks_ai/unused",
    api_base=f"https://api.withpi.ai/v1/model/sft/{status.job_id}",
    api_key=os.environ["WITHPI_API_KEY"],
    max_tokens=2048,
)

# 1) Plain text completion on the raw prompt.
response = litellm.text_completion(prompt=prompt, **sft_endpoint)

print("Raw Completion response:\n")
print_response(response.choices[0].text)

# 2) Chat completion, sending the same prompt as a single user message.
response = litellm.completion(
    messages=[{"content": prompt, "role": "user"}],
    **sft_endpoint,
)
print("\nChat completion:\n")
print_response(response.choices[0].message.content)