<a href="https://colab.research.google.com/github/withpi/cookbook-withpi/blob/main/colabs/GRPO_with_Pi_Scorer_TLDR_Usecase.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://withpi.ai"><img src="https://play.withpi.ai/logo/logoFullBlack.svg" width="240"></a>

<a href="https://code.withpi.ai"><font size="4">Documentation</font></a>

<a href="https://play.withpi.ai"><font size="4">Technique Catalog</font></a>

# Introduction to GRPO with Pi-Scorer

Group Relative Policy Optimization (GRPO) represents a significant advancement in reinforcement learning for fine-tuning language models. By incorporating our innovative Pi-Scorer as the reward function, we've created a more robust method for optimizing model behavior across complex tasks. This approach is a promising direction for developing more capable and aligned AI systems across a wide range of applications.


In this Colab notebook, we demonstrate how to apply GRPO with Pi-Scorer to the TLDR problem: writing short summaries of Reddit posts.

# Install packages and utility functions
Here we install the Pi SDK and import a few additional helpers for this use case, including a dataset utility and functions that print scores and side-by-side comparisons more legibly.

In [38]:
# @title Install necessary packages
%%capture
%pip install withpi
%pip install datasets
%pip install litellm
%pip install httpx jinja2 tqdm

In [39]:
# @title Import utility functions for Pi SDK
%%capture

import os
from google.colab import files, userdata

# Load the notebook secret into the environment so the Pi Client can access it.
os.environ["WITHPI_API_KEY"] = userdata.get('WITHPI_API_KEY')

# Import a bunch of useful libraries for later.
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
import json
from pathlib import Path
import re

import datasets
import httpx
import litellm
import jinja2
from tqdm.notebook import tqdm
from withpi import PiClient
from withpi.types import Contract
from IPython.display import display
import pandas as pd

client = PiClient()


def print_contract(contract: Contract):
    """Print each dimension label of *contract*, followed by that dimension's
    sub-dimension descriptions, each indented with a tab."""
    for dim in contract.dimensions:
        print(dim.label)
        indented = (f"\t{sub.description}" for sub in dim.sub_dimensions)
        for text in indented:
            print(text)


def generate(system: str, user: str, model: str) -> str:
    """Run a single chat completion through LiteLLM and return the reply text.

    The conversation sent to *model* is one system message (*system*)
    followed by one user message (*user*).
    """
    response = litellm.completion(
        model=model,
        messages=[
            {"content": system, "role": "system"},
            {"content": user, "role": "user"},
        ],
    )
    return response.choices[0].message.content


class printer(str):
    """A ``str`` whose repr is its own text.

    Displaying it therefore shows embedded newlines as real line breaks
    rather than as escaped ``\\n`` sequences.
    """

    def __repr__(self):
        return str(self)


def print_response(response: str):
    """Display *response* in the notebook, wrapped in ``printer`` so its
    newlines render as line breaks."""
    wrapped = printer(response)
    display(wrapped)


def print_scores(pi_scores):
    """Print a Pi score response as plain text.

    For each dimension, print its total score, then each sub-dimension score
    indented with a tab, then a blank spacer. Finish with a separator line
    and the overall total score.
    """
    for name, dim in pi_scores.dimension_scores.items():
        print(f"{name}: {dim.total_score}")
        for sub_name, sub_score in dim.subdimension_scores.items():
            print(f"\t{sub_name}: {sub_score}")
        print("\n")
    print("---------------------")
    print(f"Total score: {pi_scores.total_score}")


def save_file(filename: str, model: str):
    """Write the text *model* to *filename* locally, then trigger a Colab
    browser download of that file."""
    target = Path(filename)
    target.write_text(model)
    files.download(filename)


def load_contract(url: str) -> Contract:
    """Fetch a Contract JSON document from *url* and validate it.

    Redirects are followed, which httpx does not do by default. Hosted files
    (e.g. raw GitHub links) can therefore be loaded directly.

    Returns:
        The validated ``Contract``.

    Raises:
        httpx.HTTPStatusError: if the final response is not a success status.
            Raising here avoids feeding an error page to the JSON validator,
            which would fail with a misleading validation error.
    """
    resp = httpx.get(url, follow_redirects=True)
    resp.raise_for_status()
    return Contract.model_validate_json(resp.content)


def load_and_split_dataset(
    url: str, test_size: float = 0.1, seed: int | None = None
) -> datasets.DatasetDict:
    """Load the Parquet file at *url* and split it into train/test sets.

    Args:
        url: path or URL of the Parquet file, loaded as a single "train" split.
        test_size: fraction of rows held out for the test split. The default
            of 0.1 gives the original 90/10 split.
        seed: optional shuffle seed. Pass an int to make the split
            reproducible across runs; the default of None keeps the
            previous unseeded behavior.

    Returns:
        A DatasetDict with "train" and "test" splits.
    """
    full = datasets.load_dataset("parquet", data_files=url, split="train")
    return full.train_test_split(test_size=test_size, seed=seed)


def do_bulk_inference(dataset, system, model, max_workers=4):
    """Generate a completion for every row's ``"input"`` concurrently.

    Each row's ``"input"`` is sent as the user message, together with the
    shared *system* prompt, through ``generate`` (LiteLLM) using *model*.

    Args:
        dataset: iterable of rows with an ``"input"`` key. ``len(dataset)``
            sizes the progress bar.
        system: system prompt shared by every request.
        model: LiteLLM model identifier.
        max_workers: number of concurrent requests (default 4, as before).

    Returns:
        List of completion strings, in the same order as *dataset*.
    """

    def do_generate(user, pbar):
        result = generate(system, user, model)
        pbar.update(1)
        return result

    # tqdm is the outer context: the executor shuts down (waiting for all
    # requests) before the progress bar is closed.
    with tqdm(total=len(dataset)) as pbar, ThreadPoolExecutor(
        max_workers=max_workers
    ) as executor:
        futures = [executor.submit(do_generate, row["input"], pbar) for row in dataset]
        # Collect in submission order so output aligns with the input rows.
        return [future.result() for future in futures]


def do_bulk_templated_inference(dataset, optimized, model, max_workers=4):
    """Run a DSPy-optimized prompt over every row's ``"input"`` concurrently.

    *optimized* is a Jinja2 template (as returned by DSPy). Rendered with
    ``input=<row input>``, it produces JSON-encoded chat messages for LiteLLM.
    Each completion must contain DSPy's ``[[ ## response ## ]]`` ...
    ``[[ ## completed ## ]]`` markers; the text between them is returned.

    Args:
        dataset: iterable of rows with an ``"input"`` key. ``len(dataset)``
            sizes the progress bar.
        optimized: Jinja2 template source for the chat messages.
        model: LiteLLM model identifier.
        max_workers: number of concurrent requests (default 4, as before).

    Returns:
        List of extracted response strings, in the same order as *dataset*.

    Raises:
        ValueError: if a completion lacks the response/completed markers.
    """
    prompt_template = jinja2.Template(optimized)
    # The greedy leading ``.*`` makes the match use the last response marker.
    result_extractor = re.compile(
        r".*\[\[ ## response ## \]\](.*)\[\[ ## completed ## \]\]", re.DOTALL
    )

    def do_generate(prompt: str, pbar) -> str:
        messages = json.loads(prompt_template.render(input=prompt))
        result = (
            litellm.completion(model=model, messages=messages)
            .choices[0]
            .message.content
        )
        pbar.update(1)
        # Guard against missing markers (or empty content) instead of
        # crashing on ``None.group``.
        match = result_extractor.match(result or "")
        if match is None:
            raise ValueError(
                "Model output is missing the DSPy '[[ ## response ## ]]' / "
                f"'[[ ## completed ## ]]' markers: {str(result)[:200]!r}"
            )
        return match.group(1)

    # tqdm is the outer context so the bar closes after all requests finish.
    with tqdm(total=len(dataset)) as pbar, ThreadPoolExecutor(
        max_workers=max_workers
    ) as executor:
        futures = [executor.submit(do_generate, row["input"], pbar) for row in dataset]
        return [future.result() for future in futures]


def generate_table(
    job_id: str, training_data: dict, is_done: bool, additional_columns: dict[str, str]
):
    """Build a pandas DataFrame summarizing training progress.

    There is one row per entry of *training_data* (step -> metrics dict).
    Columns are Step, Epoch, Learning_Rate, Training_Loss and Eval_Loss,
    followed by one column per *additional_columns* entry (header -> metric
    key). Missing metrics are shown as "X". While the job is still running
    (``is_done`` is false), a trailing placeholder row with Step "..." and
    empty cells is appended. *job_id* is accepted but not used.
    """
    metric_columns = [
        ("Epoch", "epoch"),
        ("Learning_Rate", "learning_rate"),
        ("Training_Loss", "loss"),
        ("Eval_Loss", "eval_loss"),
        *additional_columns.items(),
    ]
    table = {"Step": []}
    table.update((header, []) for header, _ in metric_columns)

    for step, metrics in training_data.items():
        table["Step"].append(step)
        for header, key in metric_columns:
            table[header].append(metrics.get(key, "X"))

    if not is_done:
        table["Step"].append("...")
        for header, _ in metric_columns:
            table[header].append("")

    return pd.DataFrame(table)


def stream_response(job_id: str, method, additional_columns: dict[str, str]):
    """stream_response streams messages from the provided method

    method should be a Pi client object with `retrieve` and `stream_messages`
    endpoints.  This is primarily for convenience.

    A progress table (built by generate_table) is displayed once and then
    updated in place as lines arrive. A line is merged into the table only
    when it parses as JSON and has a "step" key; any other line is skipped.

    Args:
        job_id: id of the job to monitor.
        method: resource exposing `retrieve(job_id=...)` and
            `with_streaming_response.stream_messages(job_id=..., timeout=...)`.
        additional_columns: extra table columns, as header -> metric key.

    Returns:
        The `retrieve` result from the first poll whose state is neither
        "QUEUED" nor "RUNNING".
    """

    print(f"Training Status for {job_id}")

    # step -> merged metrics dict for that step.
    training_data = defaultdict(dict)
    # Becomes True once any line has been streamed live. It is used to skip
    # the rebuild from detailed_status when the job finishes.
    is_log_console = False

    # display_id=True returns a handle whose .update() redraws this output.
    stream_output = display(
        generate_table(
            job_id, training_data, is_done=False, additional_columns=additional_columns
        ),
        display_id=True,
    )

    while True:
        response = method.retrieve(job_id=job_id)
        # Any state other than QUEUED/RUNNING ends monitoring.
        if (response.state != "QUEUED") and (response.state != "RUNNING"):
            # Nothing was streamed live (e.g. the job was already finished),
            # so rebuild the table from the job's detailed_status lines.
            if response.state == "DONE" and not is_log_console:
                for line in response.detailed_status:
                    try:
                        data_dict = json.loads(line)
                        training_data[data_dict["step"]].update(data_dict)
                    except Exception:
                        # Non-metric status lines are expected; skip them.
                        pass
                stream_output.update(
                    generate_table(
                        job_id,
                        training_data,
                        is_done=True,
                        additional_columns=additional_columns,
                    )
                )
            return response

        # NOTE(review): if this stream closes immediately while the job is
        # still QUEUED/RUNNING, the loop re-polls retrieve() with no delay.
        # Confirm that the endpoint blocks until messages are available.
        # `response` is rebound here to the streaming HTTP response.
        with method.with_streaming_response.stream_messages(
            job_id=job_id, timeout=None
        ) as response:
            is_done = False
            for line in response.iter_lines():
                # A literal "DONE" line appears to mark the end of the stream.
                # It is not JSON, so the parse below skips it.
                if line == "DONE":
                    is_done = True
                try:
                    data_dict = json.loads(line)
                    training_data[data_dict["step"]].update(data_dict)
                except Exception:
                    # Non-JSON or step-less log lines are expected; skip them.
                    pass
                stream_output.update(
                    generate_table(
                        job_id,
                        training_data,
                        is_done,
                        additional_columns=additional_columns,
                    )
                )
                is_log_console = True


In [40]:
# @title Import a utility function to pretty print Pi scores
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

# Key colors for score_to_color, as (position, hex color) stops on [0, 1].
_SCORE_COLOR_STOPS = [
    (0.0, "#e74c3c"),  # Red
    (0.3, "#e67e22"),  # Orange
    (0.5, "#f1c40f"),  # Yellow
    (0.7, "#2ecc71"),  # Green-ish
    (1.0, "#27ae60"),  # Bright Green
]
# Built once at import time instead of on every call. from_list accepts
# (value, color) pairs, so the positions above are honored rather than the
# colors being spaced evenly.
_SCORE_CMAP = LinearSegmentedColormap.from_list(
    "custom_colormap", _SCORE_COLOR_STOPS, N=256
)


def score_to_color(score):
    """Map a score to a red-to-green hex color string ("#rrggbb").

    The score is clipped to [0, 1] first. Low scores map to red, mid scores
    to yellow, and high scores to green, following _SCORE_COLOR_STOPS.
    """
    score = np.clip(score, 0, 1)
    rgba = _SCORE_CMAP(score)
    # Drop alpha and convert the 0-1 RGB channels to a hex string.
    return "#{:02x}{:02x}{:02x}".format(
        int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255)
    )

def print_scores(pi_scores):
    """Render a Pi score response as an HTML table string (nothing is printed).

    The table has one bold row per dimension with its total score, one row per
    sub-dimension, and a final "Total score" row. Scores are rounded to 3
    decimals and colored with score_to_color. Dimension and sub-dimension
    labels are HTML-escaped, so characters like "<" or "&" in a contract
    label cannot break the table markup.

    Returns:
        str: an HTML fragment consisting of a <style> block and a <table>.
    """
    from html import escape

    score_html = """
  <style>
  table {
    border-collapse: collapse; /* Ensures borders don't double up */
    width: 100%; /* Optional: makes the table full width */
  }

  tr {
    border-bottom: 1px solid #ccc; /* Sets a bottom border for each row */
    border-top: 1px solid #ccc; /* Sets a bottom border for each row */
  }

  th, td {
    font-weight: bold;
    padding: 4px; /* Adds some spacing */
    text-align: left; /* Aligns text to the left */
    border-right: 1px solid #ccc; /* Sets a bottom border for each row */
    border-left: 1px solid #ccc; /* Sets a bottom border for each row */
  }
  img {
    width: 30%;
  }
  </style>
  <table>"""

    for dimension_name, dimension_scores in pi_scores.dimension_scores.items():
        dim_total = dimension_scores.total_score
        score_html += (
            f"<tr><td><b>{escape(str(dimension_name))}</b></td><td></td>"
            f"<td style='color: {score_to_color(dim_total)}'>{round(dim_total, 3)}</td></tr>\n"
        )
        for subdimension_name, subdimension_score in dimension_scores.subdimension_scores.items():
            score_html += (
                f"<tr><td></td><td style='font-weight: normal;'>{escape(str(subdimension_name))}</td>"
                f"<td style='color: {score_to_color(subdimension_score)}'>{round(subdimension_score, 3)}</td></tr>\n"
            )
        score_html += "\n\n"
    score_html += "<tr></tr>" + "\n"
    score_html += f"<tr><td>Total score</td><td></td><td style='color: {score_to_color(pi_scores.total_score)}'><b>{round(pi_scores.total_score, 3)}</b></td></tr>" + "\n"
    score_html += "</table>"
    return score_html

In [41]:
# @title Import a utility function to pretty print side by sides with Pi scores
from IPython.core.display import display, HTML
import markdown


def _side_by_side_row(left_html, right_html, extra_style=""):
    """Return one flex row holding two 40%-wide bordered panels side by side."""
    panel_style = f"width: 40%; padding: 10px; border: 1px solid #ddd;{extra_style}"
    return f"""
    <div style="display: flex; gap: 20px;">
        <div style="{panel_style}">
            {left_html}
        </div>
        <div style="{panel_style}">
            {right_html}
        </div>
    </div>"""


def pretty_print_responses(response1, response2=None, header=None, left_label="Base", right_label="Test", scores_left=None, scores_right=None, debug_left=None, debug_right=None):
    """Display one or two LLM responses side by side as HTML in the notebook.

    Args:
        response1: markdown text for the left panel.
        response2: optional markdown text for the right panel; empty if falsy.
        header: optional markdown shown in a banner above the panels.
        left_label: title of the left panel.
        right_label: title of the right panel.
        scores_left: optional Pi score response, rendered below the left
            panel as an HTML table via print_scores.
        scores_right: same as scores_left, for the right panel.
        debug_left: optional raw HTML shown in a final row under the left panel.
        debug_right: optional raw HTML shown in a final row under the right panel.
    """
    md1 = markdown.markdown(response1)
    md2 = markdown.markdown(response2 or "")

    # print_scores returns an HTML table string for a score response.
    scores_left_html = print_scores(scores_left) if scores_left is not None else None
    scores_right_html = print_scores(scores_right) if scores_right is not None else None

    html = ""
    if header:
        html += f"""
      <div style="display: flex; gap: 40px;">
          <div style="width: 80%; padding: 30px; border: 1px solid #ddd; background-color: #fff9f5;">
              <h4>{markdown.markdown(header)}</h4>
          </div>
      </div>"""

    # Panel titles, then the two responses.
    html += _side_by_side_row(
        f"<h4>{left_label}</h4>",
        f"<h4>{right_label}</h4>",
        " background-color: #f0f0f0; text-align:center;",
    )
    html += _side_by_side_row(md1, md2)

    if scores_left_html or scores_right_html:
        html += _side_by_side_row(
            scores_left_html or "",
            scores_right_html or "",
            " background-color: #f2f1fe;",
        )
    if debug_left or debug_right:
        html += _side_by_side_row(
            debug_left or "",
            debug_right or "",
            " background-color: #f0f0f0;",
        )

    display(HTML(html))

# Define your scoring system


In [8]:
# @title Initialize the Pi scoring system from a JSON description

from withpi.types import Contract

# Pi contract (scoring system) for TLDRs of subreddit posts, as JSON.
# `scoring_type` selects how each sub-dimension is scored: "PYTHON_CODE" runs
# the embedded `score` function; "PI_SCORER" presumably uses Pi's learned
# scorer. `weight` presumably sets each dimension's share of the total —
# confirm in the Pi docs.
# The length check counts whitespace-separated words and accepts 15-35 words
# inclusive, so a TLDR of exactly 35 words meets "no more than 35 words".
tldr_scoring_system_json = """
{
  "name": "Default",
  "description": "Generate a short TLDR of a subreddit post without any surrounding text. Here are some requirement of the TLDR: 1. Make sure that the TLDR is 1 to 3 sentence long and no more than 35 words. 2. Make sure that the TLDR state the important points of the post 3. Make sure that the TLDR should make sense on its own.",
  "dimensions": [
    {
      "label": "Length",
      "description": "Length",
      "sub_dimensions": [
        {
          "label": "Length Compliance",
          "description": "Is the TLDR between 15 to 35 word long?",
          "scoring_type": "PYTHON_CODE",
          "python_code": "\\nimport re\\n\\ndef score(\\n    response_text: str,\\n    input_text: str,\\n    kwargs: dict,\\n) -> dict:\\n    response_len = len(re.findall(r\'\\\\S+\', response_text))\\n\\n    return {\\n      \\"score\\": 1.0 if 15 <= response_len <= 35 else 0.0,\\n      \\"explanation\\": \\"\\"\\n    }\\n"
        }
      ],
      "weight": 1.0
    },
    {
      "label": "Structure",
      "description": "Structure",
      "sub_dimensions": [
        {
          "label": "Conciseness",
          "description": "Is the TLDR concise and to the point?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Redundancy",
          "description": "Does the TLDR avoid redundant information?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Repetition",
          "description": "Does the TLDR avoid repetition of the same point?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Incomplete Sentences",
          "description": "Does the TLDR avoid incomplete sentences?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    },
    {
      "label": "Content Accuracy",
      "description": "Content Accuracy",
      "sub_dimensions": [
        {
          "label": "Important Points",
          "description": "Does the TLDR state the important points of the post?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "Relevance",
          "description": "Is the content of the TLDR relevant to the original post?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "Factually Correct",
          "description": "Is the information in the TLDR factually correct?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Assumptions",
          "description": "Does the TLDR avoid making assumptions not supported by the original post?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Speculation",
          "description": "Does the TLDR avoid speculation?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    },
    {
      "label": "Clarity and Readability",
      "description": "Clarity and Readability",
      "sub_dimensions": [
        {
          "label": "Clarity",
          "description": "Is the TLDR clear and easy to understand?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "Grammar",
          "description": "Is the TLDR grammatically correct?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "Spelling",
          "description": "Is the TLDR free of spelling errors?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "Proper Punctuation",
          "description": "Is the TLDR properly punctuated?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Ambiguity",
          "description": "Is the TLDR free from ambiguity?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    },
    {
      "label": "Objectivity and Neutrality",
      "description": "Objectivity and Neutrality",
      "sub_dimensions": [
        {
          "label": "No Personal Opinions",
          "description": "Does the TLDR avoid including personal opinions?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "Objective Language",
          "description": "Is the language in the TLDR objective?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Bias",
          "description": "Is the TLDR free from bias?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    },
    {
      "label": "Self-Containment",
      "description": "Self-Containment",
      "sub_dimensions": [
        {
          "label": "Self-Contained",
          "description": "Does the TLDR make sense on its own without needing to refer to the original post?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Contradictions",
          "description": "Is the TLDR free from contradictions?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    },
    {
      "label": "Language Use",
      "description": "Language Use",
      "sub_dimensions": [
        {
          "label": "No Jargon",
          "description": "Does the TLDR avoid unnecessary jargon?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Hyperbole",
          "description": "Does the TLDR avoid hyperbole?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    },
    {
      "label": "Relevance and Focus",
      "description": "Relevance and Focus",
      "sub_dimensions": [
        {
          "label": "No Extraneous Information",
          "description": "Does the TLDR avoid including extraneous information not present in the original post?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Irrelevant Details",
          "description": "Does the TLDR avoid irrelevant details?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    }
  ]
}
"""
# Parse and validate the JSON text above into a Pi Contract object.
tldr_scoring_system = Contract.model_validate_json(tldr_scoring_system_json)


# Generate TLDR using base model

In [9]:
# @title Define a system prompt for TLDR
# System prompt passed to base-model generation, to the GRPO training job, and
# to fine-tuned-model generation later in this notebook, so all three use the
# same instructions.
system_prompt_for_tldr = """
Generate a short TLDR of a subreddit post without any introduction text or "TL;DR:". Here are the requirements of the TLDR:
1. Make sure that the TLDR is 1 to 3 sentences long and no more than 35 words.
2. Make sure that the TLDR states the important points of the post.
3. Make sure that the TLDR makes sense on its own.
"""

In [10]:
# @title Define a TLDR generator
import litellm
import asyncio

async def generate_tldrs(reddit_posts, system_prompt, model_id, api_base, api_key, concurrency_limit=5):
    """Generate TLDR for all REDDIT posts with TaskGroup and rate limiting"""
    # Create a semaphore to limit concurrent API calls
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def generate_single_tldr(reddit_post, index):
        """Process a single REDDIT post generation with rate limiting"""
        async with semaphore:
            try:
                response = await litellm.acompletion(
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": reddit_post},
                    ],
                    model=model_id,
                    api_base=api_base,
                    api_key=api_key,
                    temperature=0.2,
                )
                generated_tldr = response.choices[0].message.content
                print("Generated a tldr for post #{}: {}".format(index, reddit_post[:40]))
                return generated_tldr
            except Exception as e:
                print(f"Error generating tldr for post #{index}: {e}")
                return f"Error: {str(e)}"

    generated_tldrs = []

    # Using TaskGroup for cleaner task management
    async with asyncio.TaskGroup() as tg:
        tasks = [
            tg.create_task(generate_single_tldr(reddit_post, i + 1))
            for i, reddit_post in enumerate(reddit_posts)
        ]

    # Collect results in the same order as topics
    for task in tasks:
        generated_tldrs.append(task.result())

    print("Done generating TLDRs!!")
    return generated_tldrs

In [36]:
# @title Prepare the REDDIT posts
from datasets import load_dataset
from google.colab import userdata

# Take the first 200 rows of the withpi/tldr train split; only the "prompt"
# column (the post text) is used below.
ds = load_dataset("withpi/tldr", split="train")
ds = ds.select(range(200))
reddit_posts = ds["prompt"]

In [11]:
# @title Generate TLDRs using an untrained model for evaluation

# Generate the blogs using an untrained llama 8B
loop = asyncio.get_running_loop()
generated_tldrs = await loop.create_task(
    generate_tldrs(
        reddit_posts,
        model_id="fireworks_ai/llama-v3p2-3b-instruct",
        api_key=userdata.get("FIREWORKS_API_KEY"),
        api_base = None,
        system_prompt=system_prompt_for_tldr
    )
)

README.md:   0%|          | 0.00/319 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/47.4M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generated a tldr for post #1: SUBREDDIT: r/relationships

TITLE: I (f/
Generated a tldr for post #3: SUBREDDIT: r/relationships

TITLE: Me [1
Generated a tldr for post #2: SUBREDDIT: r/loseit

TITLE: SV & NSV! Ke
Generated a tldr for post #4: SUBREDDIT: r/personalfinance

TITLE: Pri
Generated a tldr for post #5: SUBREDDIT: r/relationships

TITLE: My[25
Generated a tldr for post #6: SUBREDDIT: r/relationships

TITLE: Me 28
Generated a tldr for post #7: SUBREDDIT: r/relationships

TITLE: Is it
Generated a tldr for post #8: SUBREDDIT: r/relationships

TITLE: I (27
Generated a tldr for post #9: SUBREDDIT: r/relationships

TITLE: Advic
Generated a tldr for post #11: SUBREDDIT: r/offmychest

TITLE: I'm just
Generated a tldr for post #10: SUBREDDIT: r/relationships

TITLE: Me [2
Generated a tldr for post #12: SUBREDDIT: r/relationship_advice

TITLE:
Generated a tldr for post #13: SUBREDDIT: r/relationships

TITLE: Me [ 
Generated a tldr for post #15: SUBREDDIT: r/dating_advice

TITLE: I thi
G

In [12]:
# @title Use the Pi scoring system to evaluate the TLDRs' quality
from tqdm import tqdm
import pandas as pd

# Score each (post, TLDR) pair with the contract. `scores` keeps the full Pi
# score objects for the inspection cell below; `df_data` feeds the summary table.
df_data = []
scores = []
for reddit_post, tldr in tqdm(zip(reddit_posts, generated_tldrs)):
    score = client.contracts.score(
        llm_input=reddit_post,
        llm_output=tldr,
        contract=tldr_scoring_system,
    )
    scores.append(score)
    df_data.append(
        {
            'reddit post': reddit_post,
            'tldr': tldr,
            'score': score.total_score,
        }
    )

df = pd.DataFrame(df_data)
df

100it [00:35,  2.85it/s]


Unnamed: 0,reddit post,tldr,score
0,SUBREDDIT: r/relationships\n\nTITLE: I (f/22) ...,I've had a past relationship that ended due to...,0.614529
1,SUBREDDIT: r/loseit\n\nTITLE: SV & NSV! Keepin...,I weighed myself weekly and measured myself mo...,0.348343
2,SUBREDDIT: r/relationships\n\nTITLE: Me [19F] ...,I have acne scars from a long-standing skin co...,0.622480
3,SUBREDDIT: r/personalfinance\n\nTITLE: Priorit...,I have $25k in student debt with a minimum pay...,0.638810
4,SUBREDDIT: r/relationships\n\nTITLE: My[25m] g...,My girlfriend only shows affection and is plea...,0.631817
...,...,...,...
95,SUBREDDIT: r/relationships\n\nTITLE: My [30 F]...,"I'll split household expenses 50/50, including...",0.529851
96,SUBREDDIT: r/relationships\n\nTITLE: Me[19M] p...,I've been bedridden for 2 weeks after a lumbar...,0.663710
97,SUBREDDIT: r/relationships\n\nTITLE: Am I bein...,I'm upset that my boyfriend wouldn't give me a...,0.664441
98,SUBREDDIT: r/relationships\n\nTITLE: My boyfri...,"I've had 29 sexual partners, including a three...",0.673803


In [14]:
# @title Manually inspect TLDR with scores
def pretty_print_blog(i):
    """Show post *i* as a header above its base-model TLDR and Pi scores."""
    pretty_print_responses(
        header="##### " + reddit_posts[i],
        response1=generated_tldrs[i],
        left_label="Base",
        scores_left=scores[i],
    )


for i in range(10):
    pretty_print_blog(i)
    print("\n\n")

0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.881
,Conciseness,0.777
,No Redundancy,0.996
,No Repetition,0.945
,No Incomplete Sentences,0.805
Content Accuracy,,0.909
,Important Points,0.746
,Relevance,0.801







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.474
,Conciseness,0.531
,No Redundancy,0.473
,No Repetition,0.443
,No Incomplete Sentences,0.447
Content Accuracy,,0.468
,Important Points,0.41
,Relevance,0.408







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.879
,Conciseness,0.926
,No Redundancy,0.984
,No Repetition,0.934
,No Incomplete Sentences,0.672
Content Accuracy,,0.882
,Important Points,0.711
,Relevance,0.699







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.936
,Conciseness,0.914
,No Redundancy,1.0
,No Repetition,1.0
,No Incomplete Sentences,0.828
Content Accuracy,,0.963
,Important Points,0.816
,Relevance,0.996







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.902
,Conciseness,0.82
,No Redundancy,1.0
,No Repetition,0.984
,No Incomplete Sentences,0.805
Content Accuracy,,0.956
,Important Points,0.781
,Relevance,1.0







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,1.0
,Conciseness,1.0
,No Redundancy,1.0
,No Repetition,1.0
,No Incomplete Sentences,1.0
Content Accuracy,,0.957
,Important Points,0.797
,Relevance,0.988







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.526
,Conciseness,0.719
,No Redundancy,0.727
,No Repetition,0.262
,No Incomplete Sentences,0.398
Content Accuracy,,0.636
,Important Points,0.508
,Relevance,0.711







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.727
,Conciseness,0.742
,No Redundancy,0.762
,No Repetition,0.691
,No Incomplete Sentences,0.711
Content Accuracy,,0.805
,Important Points,0.734
,Relevance,0.738







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.84
,Conciseness,0.781
,No Redundancy,0.91
,No Repetition,0.863
,No Incomplete Sentences,0.805
Content Accuracy,,0.919
,Important Points,0.777
,Relevance,0.816







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.812
,Conciseness,0.773
,No Redundancy,0.859
,No Repetition,0.812
,No Incomplete Sentences,0.805
Content Accuracy,,0.909
,Important Points,0.805
,Relevance,1.0







# Use GRPO to fine-tune the model to generate better TLDRs


In [35]:
# @title Prepare training data without golden TLDR

# Each training example is just the post text; no reference TLDR is provided.
examples = [dict(llm_input=post) for post in reddit_posts]

In [17]:
# @title [SLOW - will run for 80+ minutes] Run GRPO on the model based on the above training data
# Launch the GRPO job: the contract above supplies the reward signal, so only
# prompts (no golden TLDRs) are passed as examples.
status = client.model.rl.grpo.start_job(
    base_rl_model="LLAMA_3.2_3B",
    system_prompt=system_prompt_for_tldr,
    contract=tldr_scoring_system,
    examples=examples,
    # Training hyperparameters.
    learning_rate=5e-6,
    num_train_epochs=10,
    lora_config={"lora_rank": "R_64"},
)

print(status)

RlGrpoStatus(detailed_status=['LAUNCHING'], job_id='rl_grpo_jobs:09ac227c912792130a876568ea872593308c0d4b3d7c896ec7991f041cbeedd8:5f74c3aa-e82c-40b3-a86e-8a11e8717876', state='QUEUED', trained_models=[])


In [20]:
# @title Monitor the GRPO job for completion (watch the Eval_Pi_Score increase!)
# Stream training metrics until the job leaves QUEUED/RUNNING, then report
# either the tail of the error log or the trained model.
response = stream_response(
    status.job_id,
    client.model.rl.grpo,
    additional_columns={
        "Train_Pi_Reward": "rewards/pi_reward_func",
        "Train_Std_Reward": "reward_std",
        "Eval_Pi_Reward": "eval_rewards/pi_reward_func",
        "Eval_Std_Reward": "eval_reward_std",
        "Train_KL": "kl",
        "Eval_KL": "eval_kl",
        "Train_Completion_Length": "completion_length",
        "Eval_Completion_Length": "eval_completion_length",
    },
)
if response.state == "ERROR":
    print("The job failed due to:\n{}".format('\n'.join(response.detailed_status[-5:])))
elif response.trained_models:
    print("GRPO model = {}".format(response.trained_models[0].model_dump_json(indent=2)))
else:
    # Any other terminal state with no saved model; avoid an IndexError on
    # trained_models[0].
    print("The job ended in state {} without a trained model.".format(response.state))


Training Status for rl_grpo_jobs:09ac227c912792130a876568ea872593308c0d4b3d7c896ec7991f041cbeedd8:5f74c3aa-e82c-40b3-a86e-8a11e8717876


Unnamed: 0,Step,Epoch,Learning_Rate,Training_Loss,Eval_Loss,Train_Pi_Reward,Train_Std_Reward,Eval_Pi_Reward,Eval_Std_Reward,Train_KL,Eval_KL,Train_Completion_Length,Eval_Completion_Length
0,0,0.0,X,X,0.027354,X,X,0.566328,0.112464,0.0,X,X,67.96667
1,22,0.488889,0.000005,0.0009,X,0.552291,0.112587,X,X,0.000706,X,70.238638,X
2,44,0.977778,0.000005,0.0188,0.001999,X,X,0.559805,0.097716,X,0.001489,X,68.691669
3,66,1.466667,0.000005,0.0069,X,0.567319,0.105323,X,X,0.003976,X,68.390154,X
4,88,1.955556,0.000005,-0.002,0.01987,X,X,0.583637,0.146458,X,0.021817,X,68.633336
5,110,2.444444,0.000005,0.026,X,0.555968,0.112889,X,X,0.072703,X,70.102275,X
6,132,2.933333,0.000004,0.0009,0.030782,X,X,0.589604,0.120329,X,0.012842,X,68.041668
7,154,3.422222,0.000004,0.0026,X,0.565626,0.10308,X,X,0.033033,X,68.373108,X
8,176,3.911111,0.000004,0.033,-0.001534,X,X,0.592694,0.112277,X,0.030801,X,69.216669
9,198,4.4,0.000003,0.0191,X,0.574258,0.104621,X,X,0.044251,X,70.549245,X


GRPO model = {
  "contract_score": 0.5926938831806183,
  "epoch": 3.911111111111111,
  "eval_loss": -0.0015342840924859047,
  "is_loaded": true,
  "serving_id": 0,
  "step": 176
}


In [None]:
# NOTE(review): this call passes no job id, while the evaluation cell below
# calls `client.model.rl.grpo.load(GRPO_JOB_ID)`. It probably needs
# `status.job_id` as well; confirm against the Pi SDK. (The cell's execution
# count shows it was never run.)
client.model.rl.grpo.load()

# Test Out & Evaluate Your GRPO RL Model

In [27]:
# @title Generate TLDRs using the fine tuned model for evaluation

GRPO_JOB_ID = "rl_grpo_jobs:09ac227c912792130a876568ea872593308c0d4b3d7c896ec7991f041cbeedd8:5f74c3aa-e82c-40b3-a86e-8a11e8717876"

# Generate the blogs using GRPO llama 3B model
client.model.rl.grpo.load(GRPO_JOB_ID)
loop = asyncio.get_running_loop()
new_generated_tldrs = await loop.create_task(
    generate_tldrs(
        reddit_posts,
        model_id="fireworks_ai/0",
        api_base=f"https://api.withpi.ai/v1/model/rl/grpo/{GRPO_JOB_ID}",
        api_key=os.environ["WITHPI_API_KEY"],
        system_prompt=system_prompt_for_tldr
    )
)


Generated a tldr for post #1: SUBREDDIT: r/relationships

TITLE: I (f/
Generated a tldr for post #2: SUBREDDIT: r/loseit

TITLE: SV & NSV! Ke
Generated a tldr for post #3: SUBREDDIT: r/relationships

TITLE: Me [1
Generated a tldr for post #5: SUBREDDIT: r/relationships

TITLE: My[25
Generated a tldr for post #4: SUBREDDIT: r/personalfinance

TITLE: Pri
Generated a tldr for post #6: SUBREDDIT: r/relationships

TITLE: Me 28
Generated a tldr for post #7: SUBREDDIT: r/relationships

TITLE: Is it
Generated a tldr for post #11: SUBREDDIT: r/offmychest

TITLE: I'm just
Generated a tldr for post #8: SUBREDDIT: r/relationships

TITLE: I (27
Generated a tldr for post #9: SUBREDDIT: r/relationships

TITLE: Advic
Generated a tldr for post #10: SUBREDDIT: r/relationships

TITLE: Me [2
Generated a tldr for post #12: SUBREDDIT: r/relationship_advice

TITLE:
Generated a tldr for post #13: SUBREDDIT: r/relationships

TITLE: Me [ 
Generated a tldr for post #14: SUBREDDIT: r/relationships

TITLE: Me [2
G

In [34]:
# @title Compare the new TLDRs against the base-model TLDRs using the Pi scoring system
from tqdm import tqdm
import pandas as pd

scores = []
generated_scores = []
new_generated_scores = []
for reddit_post, tldr, new_tldr in tqdm(zip(reddit_posts, generated_tldrs, new_generated_tldrs)):
  generated_score = client.contracts.score(
      llm_input=reddit_post,
      llm_output=tldr,
      contract=tldr_scoring_system)
  new_generated_score = client.contracts.score(
      llm_input=reddit_post,
      llm_output=new_tldr,
      contract=tldr_scoring_system)
  generated_scores.append(generated_score)
  new_generated_scores.append(new_generated_score)
  score = {'reddit post': reddit_post, 'generated': generated_score.total_score, 'new generated': new_generated_score.total_score}
  scores.append(score)

df = pd.DataFrame(scores)
display(df)
print("Mean generated scores: {}".format(df["generated"].mean()))
print("Mean new generated scores: {}".format(df["new generated"].mean()))

100it [00:56,  1.78it/s]


Unnamed: 0,reddit post,generated,new generated
0,SUBREDDIT: r/relationships\n\nTITLE: I (f/22) ...,0.614529,0.955318
1,SUBREDDIT: r/loseit\n\nTITLE: SV & NSV! Keepin...,0.348343,0.496944
2,SUBREDDIT: r/relationships\n\nTITLE: Me [19F] ...,0.622480,0.659936
3,SUBREDDIT: r/personalfinance\n\nTITLE: Priorit...,0.638810,0.620955
4,SUBREDDIT: r/relationships\n\nTITLE: My[25m] g...,0.631817,0.654435
...,...,...,...
95,SUBREDDIT: r/relationships\n\nTITLE: My [30 F]...,0.529851,0.923828
96,SUBREDDIT: r/relationships\n\nTITLE: Me[19M] p...,0.663710,0.677419
97,SUBREDDIT: r/relationships\n\nTITLE: Am I bein...,0.664441,0.627552
98,SUBREDDIT: r/relationships\n\nTITLE: My boyfri...,0.673803,0.666797


Mean generated scores: 0.6048754668456894
Mean new generated scores: 0.6729859973538308


In [30]:
# @title Manually inspect new generated TLDRs against previous ones with scores

# How many posts to show side by side.
NUM_POSTS_TO_INSPECT = 10


def pretty_print_blog(i):
    """Show post *i*'s baseline and GRPO TLDRs side by side with their scores.

    Delegates to pretty_print_responses. The baseline (generated) TLDR goes on
    the left and the GRPO (new generated) TLDR on the right.

    Args:
      i: Position of the post in reddit_posts. The same position is used in
        generated_tldrs / new_generated_tldrs and in the score lists built by
        the scoring cell.
    """
    pretty_print_responses(
        response1=generated_tldrs[i],
        response2=new_generated_tldrs[i],
        header="##### " + reddit_posts[i],
        left_label="Base (generated)",
        right_label="Test (new generated)",
        scores_left=generated_scores[i],
        scores_right=new_generated_scores[i],
    )


# Stop at the number of scored posts, so fewer than NUM_POSTS_TO_INSPECT
# results does not raise IndexError.
num_to_show = min(NUM_POSTS_TO_INSPECT, len(generated_scores), len(new_generated_scores))
for i in range(num_to_show):
    pretty_print_blog(i)
    print("\n\n")

0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.881
,Conciseness,0.777
,No Redundancy,0.996
,No Repetition,0.945
,No Incomplete Sentences,0.805
Content Accuracy,,0.909
,Important Points,0.746
,Relevance,0.801

0,1,2
Length,,1.0
,Length Compliance,1.0
Structure,,0.951
,Conciseness,0.832
,No Redundancy,1.0
,No Repetition,1.0
,No Incomplete Sentences,0.973
Content Accuracy,,0.87
,Important Points,0.578
,Relevance,0.773







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.474
,Conciseness,0.531
,No Redundancy,0.473
,No Repetition,0.443
,No Incomplete Sentences,0.447
Content Accuracy,,0.468
,Important Points,0.41
,Relevance,0.408

0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.667
,Conciseness,0.742
,No Redundancy,0.777
,No Repetition,0.617
,No Incomplete Sentences,0.531
Content Accuracy,,0.77
,Important Points,0.734
,Relevance,0.887







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.879
,Conciseness,0.926
,No Redundancy,0.984
,No Repetition,0.934
,No Incomplete Sentences,0.672
Content Accuracy,,0.882
,Important Points,0.711
,Relevance,0.699

0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.987
,Conciseness,1.0
,No Redundancy,1.0
,No Repetition,1.0
,No Incomplete Sentences,0.949
Content Accuracy,,0.939
,Important Points,0.781
,Relevance,0.914







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.936
,Conciseness,0.914
,No Redundancy,1.0
,No Repetition,1.0
,No Incomplete Sentences,0.828
Content Accuracy,,0.963
,Important Points,0.816
,Relevance,0.996

0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.846
,Conciseness,0.77
,No Redundancy,1.0
,No Repetition,0.809
,No Incomplete Sentences,0.805
Content Accuracy,,0.906
,Important Points,0.766
,Relevance,0.77







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.902
,Conciseness,0.82
,No Redundancy,1.0
,No Repetition,0.984
,No Incomplete Sentences,0.805
Content Accuracy,,0.956
,Important Points,0.781
,Relevance,1.0

0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.977
,Conciseness,0.922
,No Redundancy,1.0
,No Repetition,1.0
,No Incomplete Sentences,0.984
Content Accuracy,,0.955
,Important Points,0.777
,Relevance,1.0







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,1.0
,Conciseness,1.0
,No Redundancy,1.0
,No Repetition,1.0
,No Incomplete Sentences,1.0
Content Accuracy,,0.957
,Important Points,0.797
,Relevance,0.988

0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,1.0
,Conciseness,1.0
,No Redundancy,1.0
,No Repetition,1.0
,No Incomplete Sentences,1.0
Content Accuracy,,0.921
,Important Points,0.77
,Relevance,0.836







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.526
,Conciseness,0.719
,No Redundancy,0.727
,No Repetition,0.262
,No Incomplete Sentences,0.398
Content Accuracy,,0.636
,Important Points,0.508
,Relevance,0.711

0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.614
,Conciseness,0.707
,No Redundancy,0.855
,No Repetition,0.451
,No Incomplete Sentences,0.443
Content Accuracy,,0.929
,Important Points,0.738
,Relevance,1.0







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.727
,Conciseness,0.742
,No Redundancy,0.762
,No Repetition,0.691
,No Incomplete Sentences,0.711
Content Accuracy,,0.805
,Important Points,0.734
,Relevance,0.738

0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.941
,Conciseness,0.961
,No Redundancy,1.0
,No Repetition,1.0
,No Incomplete Sentences,0.805
Content Accuracy,,0.955
,Important Points,0.773
,Relevance,1.0







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.84
,Conciseness,0.781
,No Redundancy,0.91
,No Repetition,0.863
,No Incomplete Sentences,0.805
Content Accuracy,,0.919
,Important Points,0.777
,Relevance,0.816

0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.862
,Conciseness,0.871
,No Redundancy,0.875
,No Repetition,0.875
,No Incomplete Sentences,0.828
Content Accuracy,,0.916
,Important Points,0.785
,Relevance,0.816







0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.812
,Conciseness,0.773
,No Redundancy,0.859
,No Repetition,0.812
,No Incomplete Sentences,0.805
Content Accuracy,,0.909
,Important Points,0.805
,Relevance,1.0

0,1,2
Length,,0.0
,Length Compliance,0.0
Structure,,0.997
,Conciseness,0.988
,No Redundancy,1.0
,No Repetition,1.0
,No Incomplete Sentences,1.0
Content Accuracy,,1.0
,Important Points,1.0
,Relevance,1.0





