<a href="https://colab.research.google.com/github/withpi/cookbook-withpi/blob/main/colabs/GRPO_RL_with_Pi_Scorer_to_build_a_better_Summarizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://withpi.ai"><img src="https://play.withpi.ai/logo/logoFullBlack.svg" width="240"></a>

<a href="https://code.withpi.ai"><font size="4">Documentation</font></a>

<a href="https://play.withpi.ai"><font size="4">Technique Catalog</font></a>

# Introduction to GRPO with Pi-Scorer

Group Relative Policy Optimization (GRPO) is a Reinforcement Learning method that tunes a model by sampling several outputs for each input, scoring every one, and reinforcing the outputs that score better than the rest of their group. With a Pi Scorer as the reward function, GRPO optimizes the model directly against the quality criteria you define, without hand-written reference outputs.


In this colab we will demonstrate how to use the Pi Scoring system to steer GRPO toward tuning a great Reddit post summarizer (a [TL;DR](https://www.reddit.com/r/help/comments/gzlnku/i_am_fairly_new_and_have_a_question_what_does/) generator).
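
For intuition, the "group relative" part of GRPO can be sketched in a few lines: for one input, the model samples a group of candidate outputs, a reward function scores each of them, and each candidate's advantage is its reward standardized against its own group. This is an illustrative sketch only. `group_relative_advantages` is a hypothetical helper, the reward values are made up, it uses the population standard deviation (implementations differ on this detail), and the Pi training service's actual implementation is not shown in this notebook.

```python
from statistics import mean, pstdev

def group_relative_advantages(rewards, eps=1e-8):
    """Standardize each reward against the mean and spread of its own group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    # eps avoids division by zero when every reward in the group is identical.
    return [(r - mu) / (sigma + eps) for r in rewards]

# Hypothetical rewards for four summaries sampled from the same Reddit post.
rewards = [0.45, 0.69, 0.89, 0.97]
advantages = group_relative_advantages(rewards)

for r, a in zip(rewards, advantages):
    print(f"reward={r:.2f}  advantage={a:+.2f}")
```

Summaries that score above their group's mean get a positive advantage and are reinforced; those below the mean get a negative one.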

# Install packages and utility functions
Here we install the Pi SDK, along with a few additional packages for this use case: a dataset utility and helper functions that print scores and side-by-side comparisons more legibly.

In [18]:
# @title Install necessary packages
%%capture
%pip install withpi withpi-utils
%pip install datasets
%pip install litellm
%pip install httpx jinja2 tqdm

In [19]:
# @title Initialize PiClient
import os
from google.colab import files, userdata
from withpi import PiClient

# Load the notebook secret into the environment so the Pi Client can access it.
os.environ["WITHPI_API_KEY"] = userdata.get('WITHPI_API_KEY')


client = PiClient()

# Define your scoring system
Here, we capture the criteria that we believe define a great summary: whether it is 15-35 words long, contains the important points of the original post, and is self-contained enough to make sense without the reader having to view the original content.


In [20]:
# @title Initialize the Pi scoring system from a JSON description

from withpi.types import Scorer
from withpi_utils.colab import display_scorer

tldr_scorer_json = """
{
  "name": "Default",
  "description": "Generate a short TLDR of a subreddit post without any surrounding text. Here are some requirement of the TLDR: 1. Make sure that the TLDR is 1 to 3 sentence long and no more than 35 words. 2. Make sure that the TLDR state the important points of the post 3. Make sure that the TLDR should make sense on its own.",
  "dimensions": [
    {
      "label": "Length",
      "description": "Length",
      "sub_dimensions": [
        {
          "label": "Length Compliance",
          "description": "Is the TLDR between 15 to 35 word long?",
          "scoring_type": "PYTHON_CODE",
          "python_code": "\\nimport re\\n\\ndef score(\\n    response_text: str,\\n    input_text: str,\\n    kwargs: dict,\\n) -> dict:\\n    response_len = len(re.findall(r\'\\\\S+\', response_text))\\n\\n    return {\\n      \\"score\\": 1.0 if response_len < 35 and response_len > 15 else 0.0,\\n      \\"explanation\\": \\"\\"\\n    }\\n"
        }
      ],
      "weight": 1.0
    },
    {
      "label": "Structure",
      "description": "Structure",
      "sub_dimensions": [
        {
          "label": "Conciseness",
          "description": "Is the TLDR concise and to the point?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Redundancy",
          "description": "Does the TLDR avoid redundant information?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Repetition",
          "description": "Does the TLDR avoid repetition of the same point?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Incomplete Sentences",
          "description": "Does the TLDR avoid incomplete sentences?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    },
    {
      "label": "Content Accuracy",
      "description": "Content Accuracy",
      "sub_dimensions": [
        {
          "label": "Important Points",
          "description": "Does the TLDR state the important points of the post?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "Relevance",
          "description": "Is the content of the TLDR relevant to the original post?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "Factually Correct",
          "description": "Is the information in the TLDR factually correct?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Assumptions",
          "description": "Does the TLDR avoid making assumptions not supported by the original post?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Speculation",
          "description": "Does the TLDR avoid speculation?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    },
    {
      "label": "Clarity and Readability",
      "description": "Clarity and Readability",
      "sub_dimensions": [
        {
          "label": "Clarity",
          "description": "Is the TLDR clear and easy to understand?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "Grammar",
          "description": "Is the TLDR grammatically correct?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "Spelling",
          "description": "Is the TLDR free of spelling errors?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "Proper Punctuation",
          "description": "Is the TLDR properly punctuated?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Ambiguity",
          "description": "Is the TLDR free from ambiguity?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    },
    {
      "label": "Objectivity and Neutrality",
      "description": "Objectivity and Neutrality",
      "sub_dimensions": [
        {
          "label": "No Personal Opinions",
          "description": "Does the TLDR avoid including personal opinions?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "Objective Language",
          "description": "Is the language in the TLDR objective?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Bias",
          "description": "Is the TLDR free from bias?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    },
    {
      "label": "Self-Containment",
      "description": "Self-Containment",
      "sub_dimensions": [
        {
          "label": "Self-Contained",
          "description": "Does the TLDR make sense on its own without needing to refer to the original post?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Contradictions",
          "description": "Is the TLDR free from contradictions?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    },
    {
      "label": "Language Use",
      "description": "Language Use",
      "sub_dimensions": [
        {
          "label": "No Jargon",
          "description": "Does the TLDR avoid unnecessary jargon?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Hyperbole",
          "description": "Does the TLDR avoid hyperbole?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    },
    {
      "label": "Relevance and Focus",
      "description": "Relevance and Focus",
      "sub_dimensions": [
        {
          "label": "No Extraneous Information",
          "description": "Does the TLDR avoid including extraneous information not present in the original post?",
          "scoring_type": "PI_SCORER"
        },
        {
          "label": "No Irrelevant Details",
          "description": "Does the TLDR avoid irrelevant details?",
          "scoring_type": "PI_SCORER"
        }
      ],
      "weight": 0.3
    }
  ]
}
"""
tldr_scorer = Scorer.model_validate_json(tldr_scorer_json)

display_scorer(tldr_scorer)
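
The `Length Compliance` sub-dimension is the only one scored by code rather than by the Pi scorer, and its source is hard to read as an escaped JSON string. Unescaped, it is equivalent to the function below (lightly reformatted; the `words` helper and the loop are hypothetical additions for illustration). Note that both comparisons are strict, so a summary of exactly 15 or exactly 35 words scores 0.0.

```python
import re

def score(response_text: str, input_text: str, kwargs: dict) -> dict:
    # Count whitespace-delimited tokens as words.
    response_len = len(re.findall(r"\S+", response_text))
    return {
        "score": 1.0 if response_len < 35 and response_len > 15 else 0.0,
        "explanation": "",
    }

def words(n: int) -> str:
    """Hypothetical helper: build a dummy response containing n words."""
    return " ".join(["word"] * n)

# Probe the boundaries of the accepted range.
for n in (15, 16, 34, 35):
    print(n, score(words(n), "", {})["score"])
```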


# Generate summaries (TL;DRs) using a base model
Before we get into trying anything fancy, let's see how well a base model can work with a simple system prompt to generate summaries. We'll:

1. Write a system prompt for summarization
2. Write a function to run that system prompt to generate summaries for a set of Reddit posts
3. Use our Pi Scoring system to evaluate the quality of the generated summaries

In [None]:
# @title Define a system prompt for TLDR
system_prompt_for_tldr = """
Generate a short TLDR of a subreddit post. Here are the requirements of the TLDR:
1. Make sure that the TLDR is 1 to 3 sentences long and no more than 35 words.
2. Make sure that the TLDR states the important points of the post.
3. Make sure that the TLDR makes sense on its own.
4. Make sure that the TLDR does not have any surrounding text, TLDR: prefix, or quotes.
"""

In [None]:
# @title Define a TLDR generator
import litellm
import asyncio

async def generate_tldrs(reddit_posts, system_prompt, model_id, api_base, api_key, concurrency_limit=5):
    """Generate TLDR for all REDDIT posts with TaskGroup and rate limiting"""
    # Create a semaphore to limit concurrent API calls
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def generate_single_tldr(reddit_post, index):
        """Process a single REDDIT post generation with rate limiting"""
        async with semaphore:
            try:
                response = await litellm.acompletion(
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": reddit_post},
                    ],
                    model=model_id,
                    api_base=api_base,
                    api_key=api_key,
                    temperature=0.2,
                )
                generated_tldr = response.choices[0].message.content
                print("Generated a tldr for post #{}: {}".format(index, reddit_post[:40]))
                return generated_tldr
            except Exception as e:
                print(f"Error generating tldr for post #{index}: {e}")
                return f"Error: {str(e)}"

    # asyncio.TaskGroup (Python 3.11+) waits for every task before the block exits
    async with asyncio.TaskGroup() as tg:
        tasks = [
            tg.create_task(generate_single_tldr(reddit_post, i + 1))
            for i, reddit_post in enumerate(reddit_posts)
        ]

    # Collect results in the same order as the input posts
    generated_tldrs = [task.result() for task in tasks]

    print("Done generating TLDRs!!")
    return generated_tldrs
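
The concurrency pattern in `generate_tldrs` can be tried without an API key. The sketch below swaps the LLM call for a hypothetical `fake_llm_call` coroutine that sleeps for a random time; `bounded_map` and `demo` are likewise illustrative names. It demonstrates the two properties the generator relies on: the semaphore caps the number of in-flight calls, and results come back in input order even though individual calls finish (and log) out of order. It uses `asyncio.gather` rather than `asyncio.TaskGroup` so that it also runs on Python versions older than 3.11.

```python
import asyncio
import random

async def bounded_map(worker, items, concurrency_limit=5):
    """Apply an async worker to every item with at most concurrency_limit in flight."""
    semaphore = asyncio.Semaphore(concurrency_limit)
    in_flight = 0
    peak = 0

    async def run_one(item):
        nonlocal in_flight, peak
        async with semaphore:
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                return await worker(item)
            finally:
                in_flight -= 1

    # gather returns results in the order the awaitables were passed in.
    results = await asyncio.gather(*(run_one(item) for item in items))
    return results, peak

async def fake_llm_call(post):
    """Hypothetical stand-in for litellm.acompletion: wait briefly, return a stub."""
    await asyncio.sleep(random.uniform(0.01, 0.05))
    return f"summary of {post}"

async def demo():
    posts = [f"post {i}" for i in range(12)]
    results, peak = await bounded_map(fake_llm_call, posts, concurrency_limit=3)
    print("peak concurrent calls:", peak)
    print(results[:3])
    return results, peak
```

In a notebook cell, run `results, peak = await demo()`; in a plain script, use `asyncio.run(demo())`.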

In [None]:
# @title Prepare the REDDIT posts
from datasets import load_dataset
from google.colab import userdata

ds = load_dataset("withpi/tldr", split="train").select(range(200))

reddit_posts = ds["prompt"]


In [None]:
# @title Generate TLDRs using an untrained model for evaluation

# Generate the TLDRs using the base (not yet fine-tuned) Llama 3.2 3B model
loop = asyncio.get_running_loop()
generated_tldrs = await loop.create_task(
    generate_tldrs(
        reddit_posts,
        model_id="fireworks_ai/llama-v3p2-3b-instruct",
        api_key=userdata.get("FIREWORKS_API_KEY"),
        api_base = None,
        system_prompt=system_prompt_for_tldr
    )
)

Generated a tldr for post #5: SUBREDDIT: r/relationships

TITLE: My[25
Generated a tldr for post #3: SUBREDDIT: r/relationships

TITLE: Me [1
Generated a tldr for post #1: SUBREDDIT: r/relationships

TITLE: I (f/
Generated a tldr for post #2: SUBREDDIT: r/loseit

TITLE: SV & NSV! Ke
Generated a tldr for post #4: SUBREDDIT: r/personalfinance

TITLE: Pri
Generated a tldr for post #6: SUBREDDIT: r/relationships

TITLE: Me 28
Generated a tldr for post #7: SUBREDDIT: r/relationships

TITLE: Is it
Generated a tldr for post #9: SUBREDDIT: r/relationships

TITLE: Advic
Generated a tldr for post #8: SUBREDDIT: r/relationships

TITLE: I (27
Generated a tldr for post #10: SUBREDDIT: r/relationships

TITLE: Me [2
Generated a tldr for post #11: SUBREDDIT: r/offmychest

TITLE: I'm just
Generated a tldr for post #12: SUBREDDIT: r/relationship_advice

TITLE:
Generated a tldr for post #15: SUBREDDIT: r/dating_advice

TITLE: I thi
Generated a tldr for post #13: SUBREDDIT: r/relationships

TITLE: Me [ 
...

In [None]:
# @title Use the Pi scoring system to evaluate the TLDRs' quality
from tqdm import tqdm
import pandas as pd

df_data = []
scores = []
for reddit_post, tldr in tqdm(zip(reddit_posts, generated_tldrs)):
  score = client.scoring_system.score(
      llm_input=reddit_post,
      llm_output=tldr,
      scorer=tldr_scorer)
  scores.append(score)
  df_data.append({'reddit post': reddit_post, 'tldr': tldr, 'pi-score': score.total_score})

df = pd.DataFrame(df_data)
display(df)
print("Mean pi-score: {}".format(df["pi-score"].mean()))

200it [01:37,  2.06it/s]


| | reddit post | tldr | pi-score |
|---|---|---|---|
| 0 | SUBREDDIT: r/relationships\n\nTITLE: I (f/22) ... | TL;DR: I'm considering cutting contact with my... | 0.686561 |
| 1 | SUBREDDIT: r/loseit\n\nTITLE: SV & NSV! Keepin... | I weighed myself and measured myself, and I'm ... | 0.893800 |
| 2 | SUBREDDIT: r/relationships\n\nTITLE: Me [19F] ... | TL;DR: Woman (19F) seeks advice on how to disc... | 0.452692 |
| 3 | SUBREDDIT: r/personalfinance\n\nTITLE: Priorit... | I have $25k in student debt with a 9.5% loan a... | 0.875428 |
| 4 | SUBREDDIT: r/relationships\n\nTITLE: My[25m] g... | TL;DR: Girlfriend becomes pleasant and attenti... | 0.590579 |
| ... | ... | ... | ... |
| 195 | SUBREDDIT: r/relationships\n\nTITLE: Earlier t... | Here is a possible TLDR:\n\n"29M breaks up wit... | 0.554732 |
| 196 | SUBREDDIT: r/relationships\n\nTITLE: Me [25F] ... | TL;DR: Boyfriend of 2.5 years confesses to che... | 0.963099 |
| 197 | SUBREDDIT: r/relationships\n\nTITLE: She[23f] ... | TL;DR: Friend is having 2nd ACL surgery in 2 w... | 0.953459 |
| 198 | SUBREDDIT: r/pettyrevenge\n\nTITLE: Tailgate m... | Woman tailgates me on highway, refuses to pass... | 0.999584 |


Mean pi-score: 0.7808680722742313


In [None]:
# @title Manually inspect TLDRs with scores
from withpi_utils.colab import pretty_print_responses

def pretty_print_tldr(i):
  pretty_print_responses(
      response1 = generated_tldrs[i],
      header="##### " + reddit_posts[i],
      scores_left=scores[i])

for i in range(10):
  pretty_print_tldr(i)
  print("\n\n")

Score breakdown for posts 0 to 9 (the output shows only the first ten rows of each score table). `Length`, `Structure`, and `Content Accuracy` are dimension scores; the columns that follow each one are its sub-dimensions.

| Post | Length | Length Compliance | Structure | Conciseness | No Redundancy | No Repetition | No Incomplete Sentences | Content Accuracy | Important Points | Relevance |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.0 | 1.0 | 0.422 | 0.43 | 0.586 | 0.279 | 0.393 | 0.508 | 0.245 | 0.271 |
| 1 | 1.0 | 1.0 | 0.846 | 0.805 | 1.0 | 0.789 | 0.789 | 0.745 | 0.652 | 0.68 |
| 2 | 1.0 | 1.0 | 0.171 | 0.218 | 0.142 | 0.171 | 0.152 | 0.2 | 0.213 | 0.185 |
| 3 | 1.0 | 1.0 | 0.658 | 0.695 | 0.777 | 0.598 | 0.562 | 0.78 | 0.492 | 0.613 |
| 4 | 1.0 | 1.0 | 0.32 | 0.245 | 0.344 | 0.346 | 0.344 | 0.477 | 0.239 | 0.4 |
| 5 | 1.0 | 1.0 | 0.686 | 0.734 | 0.887 | 0.555 | 0.566 | 0.656 | 0.42 | 0.609 |
| 6 | 1.0 | 1.0 | 0.532 | 0.613 | 0.664 | 0.418 | 0.432 | 0.727 | 0.484 | 0.996 |
| 7 | 0.0 | 0.0 | 0.867 | 0.793 | 0.914 | 0.922 | 0.84 | 0.912 | 0.766 | 0.875 |
| 8 | 1.0 | 1.0 | 0.654 | 0.773 | 0.797 | 0.455 | 0.59 | 0.841 | 0.766 | 0.863 |
| 9 | 1.0 | 1.0 | 0.337 | 0.307 | 0.426 | 0.246 | 0.371 | 0.512 | 0.371 | 0.342 |

# Use GRPO to fine-tune the model to generate better TLDRs
We can see that the prompted summarizer model has clear room for improvement. Let's see if we can beat this performance with Reinforcement Learning. We will:

1. Prepare the example set. Remember that, unlike Supervised Fine-Tuning, Reinforcement Learning does NOT require example outputs. It needs only the inputs, because it generates and scores different outputs on its own to figure out what works best.

2. Run and monitor the progress of Reinforcement Learning, paying special attention to "Eval_Pi_Reward", which indicates how well the model is performing according to the Scoring System we defined.


In [None]:
# @title Prepare training data without golden TLDR

examples = [{"llm_input": post} for post in reddit_posts]

In [None]:
# @title [SLOW - will run for 80+ minutes] Run GRPO on the model based on the above training data
status = client.training.grpo.start_job(
    scorer=tldr_scorer,
    examples=examples,
    base_rl_model="LLAMA_3.2_3B",
    system_prompt=system_prompt_for_tldr,
    lora_config={"lora_rank": "R_64"},
    learning_rate=5e-6,
    num_train_epochs=10,
)

print(status)

RlGrpoStatus(detailed_status=['LAUNCHING'], job_id='rl_grpo_jobs:09ac227c912792130a876568ea872593308c0d4b3d7c896ec7991f041cbeedd8:d96e6d51-da4e-415e-9b36-8919b9db1ce6', state='QUEUED', trained_models=[])


In [None]:
# @title Monitor the GRPO job for completion (watch the Eval_Pi_Reward increase!)
from withpi_utils.colab import stream_training_response

# Job ID printed by start_job above; in the same session, status.job_id holds the same value.
GRPO_JOB_ID = "rl_grpo_jobs:09ac227c912792130a876568ea872593308c0d4b3d7c896ec7991f041cbeedd8:d96e6d51-da4e-415e-9b36-8919b9db1ce6"

response = stream_training_response(
    GRPO_JOB_ID,
    client.training.grpo,
    additional_columns={
        "Train_Pi_Reward": "rewards/pi_reward_func",
        "Train_Std_Reward": "reward_std",
        "Eval_Pi_Reward": "eval_rewards/pi_reward_func",
        "Eval_Std_Reward": "eval_reward_std",
        "Train_KL": "kl",
        "Eval_KL": "eval_kl",
        "Train_Completion_Length": "completion_length",
        "Eval_Completion_Length": "eval_completion_length",
    },
)
if response.state == "ERROR":
  print("The job failed due to:\n{}".format('\n'.join(response.detailed_status[-5:])))
else:
  print("GRPO model = {}".format(response.trained_models[0].model_dump_json(indent=2)))


Training Status for rl_grpo_jobs:09ac227c912792130a876568ea872593308c0d4b3d7c896ec7991f041cbeedd8:d96e6d51-da4e-415e-9b36-8919b9db1ce6


| | Step | Epoch | Learning_Rate | Training_Loss | Eval_Loss | Train_Pi_Reward | Train_Std_Reward | Eval_Pi_Reward | Eval_Std_Reward | Train_KL | Eval_KL | Train_Completion_Length | Eval_Completion_Length |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0.0 | X | X | 0.044004 | X | X | 0.623223 | 0.211226 | 0.0 | X | X | 46.191668 |
| 1 | 45 | 0.5 | 0.000005 | 0.0466 | X | 0.644 | 0.171707 | X | X | 0.002429 | X | 46.444446 | X |
| 2 | 90 | 1.0 | 0.000005 | 0.0575 | 0.057458 | X | X | 0.637817 | 0.177759 | X | 0.118085 | X | 44.983334 |
| 3 | 135 | 1.5 | 0.000005 | 0.0653 | X | 0.653613 | 0.179628 | X | X | 0.169914 | X | 45.395372 | X |
| 4 | 180 | 2.0 | 0.000005 | 0.0594 | 0.064911 | X | X | 0.634552 | 0.211285 | X | 0.273814 | X | 44.975001 |
| 5 | 225 | 2.5 | 0.000004 | 0.0511 | X | 0.655746 | 0.188338 | X | X | 0.184623 | X | 45.026853 | X |
| 6 | 270 | 3.0 | 0.000004 | 0.0593 | 0.041731 | X | X | 0.640519 | 0.175897 | X | 0.285641 | X | 46.225001 |
| 7 | 315 | 3.5 | 0.000004 | 0.0603 | X | 0.689227 | 0.192666 | X | X | 0.266206 | X | 44.412964 | X |
| 8 | 360 | 4.0 | 0.000004 | 0.056 | 0.071562 | X | X | 0.660285 | 0.188727 | X | 0.401534 | X | 45.100002 |
| 9 | 405 | 4.5 | 0.000003 | 0.0628 | X | 0.678925 | 0.182098 | X | X | 0.36702 | X | 44.532409 | X |


GRPO model = {
  "contract_score": 0.6787437707185745,
  "epoch": 8.0,
  "eval_loss": 0.06985440850257874,
  "serving_id": 0,
  "serving_state": "UNLOADED",
  "step": 720
}
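
To read the trend without scanning the whole table, the evaluation rows can be pulled out on their own. The values below are copied from the `Eval_Pi_Reward` column of the training status output above (steps 0 to 360); the variable names are illustrative, and this is a reading aid rather than part of the Pi SDK.

```python
# (step, Eval_Pi_Reward) pairs copied from the training status output above.
eval_rewards = [
    (0, 0.623223),
    (90, 0.637817),
    (180, 0.634552),
    (270, 0.640519),
    (360, 0.660285),
]

first_step, first_reward = eval_rewards[0]
best_step, best_reward = max(eval_rewards, key=lambda pair: pair[1])

print(f"start: step {first_step}, reward {first_reward:.3f}")
print(f"best:  step {best_step}, reward {best_reward:.3f}")
print(f"gain:  {best_reward - first_reward:+.3f}")
```

The reward dips slightly at step 180 before rising again, so it is safer to judge progress by the trend over several evaluations than by any single step.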


# Test Out & Evaluate Your GRPO RL Model
Don't take the above metrics at face value! Let's try out our newly tuned model and assess its performance for ourselves. To do so we will:

1. Prepare our evaluation set and generate summaries for it with both the original prompted model and the GRPO-tuned model, for side-by-side comparison.

2. Use our scoring system to compare the two summarizers' performance quantitatively.

3. Print example outputs from each summarizer and review them, along with their Pi Scores, against each other.

In [None]:
# @title Prepare the evaluation REDDIT posts
from datasets import load_dataset
from google.colab import userdata
import asyncio

ds = load_dataset("withpi/tldr", split="train").select(range(1000, 1100))

reddit_posts = ds["prompt"]

loop = asyncio.get_running_loop()
generated_tldrs = await loop.create_task(
    generate_tldrs(
        reddit_posts,
        model_id="fireworks_ai/llama-v3p2-3b-instruct",
        api_key=userdata.get("FIREWORKS_API_KEY"),
        api_base = None,
        system_prompt=system_prompt_for_tldr
    )
)

Generated a tldr for post #4: SUBREDDIT: r/AskReddit

TITLE: Redditors
Generated a tldr for post #3: SUBREDDIT: r/offmychest

TITLE: Very rec
Generated a tldr for post #1: SUBREDDIT: r/BreakUps

TITLE: Ex Girlfri
Generated a tldr for post #2: SUBREDDIT: r/relationships

TITLE: Me [2
Generated a tldr for post #5: SUBREDDIT: r/relationship_advice

TITLE:
Generated a tldr for post #7: SUBREDDIT: r/relationships

TITLE: I [19
Generated a tldr for post #6: SUBREDDIT: r/tifu

TITLE: TIFU by watchi
Generated a tldr for post #8: SUBREDDIT: r/relationships

TITLE: Me [2
Generated a tldr for post #9: SUBREDDIT: r/dating_advice

TITLE: Male,
Generated a tldr for post #10: SUBREDDIT: r/relationships

TITLE: My (1
Generated a tldr for post #12: SUBREDDIT: r/legaladvice

TITLE: Father 
Generated a tldr for post #13: SUBREDDIT: r/running

TITLE: Night runni
Generated a tldr for post #15: SUBREDDIT: r/relationships

TITLE: Me [2
Generated a tldr for post #11: SUBREDDIT: r/tifu

TITLE: TIFU by not op
...

In [None]:
# @title Generate TLDRs using the fine tuned model for evaluation
import time

# Generate the TLDRs using the GRPO-tuned Llama 3.2 3B model
client.training.grpo.load(GRPO_JOB_ID)

# Wait for the model to be loaded
while not (client.training.grpo.retrieve(GRPO_JOB_ID).trained_models[0].serving_state == "SERVING"):
    time.sleep(3)


loop = asyncio.get_running_loop()
new_generated_tldrs = await loop.create_task(
    generate_tldrs(
        reddit_posts,
        model_id="fireworks_ai/0",
        api_base=f"https://api.withpi.ai/v1/training/grpo/{GRPO_JOB_ID}",
        api_key=os.environ["WITHPI_API_KEY"],
        system_prompt=system_prompt_for_tldr
    )
)


Generated a tldr for post #5: SUBREDDIT: r/relationship_advice

TITLE:
Generated a tldr for post #4: SUBREDDIT: r/AskReddit

TITLE: Redditors
Generated a tldr for post #2: SUBREDDIT: r/relationships

TITLE: Me [2
Generated a tldr for post #3: SUBREDDIT: r/offmychest

TITLE: Very rec
Generated a tldr for post #1: SUBREDDIT: r/BreakUps

TITLE: Ex Girlfri
Generated a tldr for post #6: SUBREDDIT: r/tifu

TITLE: TIFU by watchi
Generated a tldr for post #7: SUBREDDIT: r/relationships

TITLE: I [19
Generated a tldr for post #8: SUBREDDIT: r/relationships

TITLE: Me [2
Generated a tldr for post #10: SUBREDDIT: r/relationships

TITLE: My (1
Generated a tldr for post #9: SUBREDDIT: r/dating_advice

TITLE: Male,
Generated a tldr for post #12: SUBREDDIT: r/legaladvice

TITLE: Father 
Generated a tldr for post #11: SUBREDDIT: r/tifu

TITLE: TIFU by not op
Generated a tldr for post #14: SUBREDDIT: r/AskReddit

TITLE: Wouldn't 
Generated a tldr for post #13: SUBREDDIT: r/running

TITLE: Night runni
...
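
The polling loop in the cell above waits indefinitely if the model never reaches `SERVING`. A minimal sketch of a bounded wait follows. `wait_until` is a hypothetical helper, not part of the Pi SDK, and the timeout and interval values are arbitrary; the commented usage line reuses the `retrieve` call from the cell above.

```python
import time

def wait_until(predicate, timeout_s=600.0, interval_s=3.0):
    """Poll predicate() until it returns True; raise TimeoutError after timeout_s seconds."""
    deadline = time.monotonic() + timeout_s
    while not predicate():
        if time.monotonic() >= deadline:
            raise TimeoutError(f"condition not met within {timeout_s} seconds")
        time.sleep(interval_s)

# Self-contained demonstration: a fake readiness check that succeeds on the third poll.
polls = {"count": 0}

def fake_is_serving():
    polls["count"] += 1
    return polls["count"] >= 3

wait_until(fake_is_serving, timeout_s=5.0, interval_s=0.01)
print("ready after", polls["count"], "polls")

# With the objects defined in this notebook, the call would look like:
# wait_until(lambda: client.training.grpo.retrieve(GRPO_JOB_ID)
#            .trained_models[0].serving_state == "SERVING")
```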

In [None]:
# @title Compare the GRPO fine-tuned TLDRs against previous ones using the Pi scoring system
from tqdm import tqdm
import pandas as pd

scores = []
generated_scores = []
new_generated_scores = []
for reddit_post, tldr, new_tldr in tqdm(zip(reddit_posts, generated_tldrs, new_generated_tldrs)):
  generated_score = client.scoring_system.score(
      llm_input=reddit_post,
      llm_output=tldr,
      scorer=tldr_scorer)
  new_generated_score = client.scoring_system.score(
      llm_input=reddit_post,
      llm_output=new_tldr,
      scorer=tldr_scorer)
  generated_scores.append(generated_score)
  new_generated_scores.append(new_generated_score)
  score = {'reddit post': reddit_post, 'generated': generated_score.total_score, 'new generated': new_generated_score.total_score}
  scores.append(score)

df = pd.DataFrame(scores)
display(df)
print("Mean generated scores: {}".format(df["generated"].mean()))
print("Mean new generated scores: {}".format(df["new generated"].mean()))

100it [01:22,  1.21it/s]


| | reddit post | generated | new generated |
|---|---|---|---|
| 0 | SUBREDDIT: r/BreakUps\n\nTITLE: Ex Girlfriend ... | 0.652911 | 0.999149 |
| 1 | SUBREDDIT: r/relationships\n\nTITLE: Me [24F] ... | 0.968643 | 0.982214 |
| 2 | SUBREDDIT: r/offmychest\n\nTITLE: Very recentl... | 0.338344 | 0.944437 |
| 3 | SUBREDDIT: r/AskReddit\n\nTITLE: Redditors, I'... | 0.501142 | 0.685881 |
| 4 | SUBREDDIT: r/relationship_advice\n\nTITLE: Sho... | 0.906329 | 0.986448 |
| ... | ... | ... | ... |
| 95 | SUBREDDIT: r/relationships\n\nTITLE: Me [24F] ... | 0.788489 | 0.921856 |
| 96 | SUBREDDIT: r/tifu\n\nTITLE: TIFU by going to t... | 0.920609 | 0.915877 |
| 97 | SUBREDDIT: r/relationships\n\nTITLE: Me (23f) ... | 0.676934 | 0.862481 |
| 98 | SUBREDDIT: r/relationships\n\nTITLE: My [20 F]... | 0.984785 | 1.000000 |


Mean generated scores: 0.7418496625346525
Mean new generated scores: 0.9270921591481855


In [None]:
# @title Manually inspect the newly generated TLDRs against the previous ones with scores
from withpi_utils.colab import pretty_print_responses

def pretty_print_tldr(i):
  pretty_print_responses(
      response1 = generated_tldrs[i],
      response2 = new_generated_tldrs[i],
      header="##### " + reddit_posts[i],
      left_label="Base (generated)",
      right_label="Test (new generated)",
      scores_left=generated_scores[i],
      scores_right=new_generated_scores[i])

for i in range(10):
  pretty_print_tldr(i)
  print("\n\n")

Score breakdown for posts 0 to 9, base model versus GRPO-tuned model (the output shows only the first ten rows of each score table). `Length`, `Structure`, and `Content Accuracy` are dimension scores; the columns that follow each one are its sub-dimensions.

| Post | Model | Length | Length Compliance | Structure | Conciseness | No Redundancy | No Repetition | No Incomplete Sentences | Content Accuracy | Important Points | Relevance |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Base (generated) | 0.0 | 0.0 | 0.977 | 0.91 | 1.0 | 1.0 | 0.996 | 0.952 | 0.785 | 0.977 |
| 0 | Test (new generated) | 1.0 | 1.0 | 0.999 | 0.996 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 |
| 1 | Base (generated) | 1.0 | 1.0 | 0.884 | 0.77 | 1.0 | 0.82 | 0.945 | 0.938 | 0.781 | 1.0 |
| 1 | Test (new generated) | 1.0 | 1.0 | 0.96 | 0.84 | 1.0 | 1.0 | 1.0 | 0.958 | 0.789 | 1.0 |
| 2 | Base (generated) | 0.0 | 0.0 | 0.458 | 0.527 | 0.512 | 0.406 | 0.389 | 0.539 | 0.377 | 0.377 |
| 2 | Test (new generated) | 1.0 | 1.0 | 0.829 | 0.758 | 0.859 | 0.809 | 0.891 | 0.899 | 0.762 | 0.82 |
| 3 | Base (generated) | 1.0 | 1.0 | 0.228 | 0.224 | 0.23 | 0.23 | 0.227 | 0.256 | 0.226 | 0.227 |
| 3 | Test (new generated) | 1.0 | 1.0 | 0.424 | 0.535 | 0.5 | 0.379 | 0.281 | 0.538 | 0.422 | 0.471 |
| 4 | Base (generated) | 1.0 | 1.0 | 0.775 | 0.766 | 0.93 | 0.467 | 0.938 | 0.85 | 0.75 | 0.719 |
| 4 | Test (new generated) | 1.0 | 1.0 | 0.995 | 0.988 | 1.0 | 0.992 | 1.0 | 0.963 | 0.816 | 1.0 |
| 5 | Base (generated) | 1.0 | 1.0 | 0.867 | 0.91 | 0.953 | 0.82 | 0.785 | 0.951 | 0.758 | 0.996 |
| 5 | Test (new generated) | 1.0 | 1.0 | 0.911 | 0.906 | 1.0 | 0.926 | 0.812 | 0.951 | 0.773 | 0.98 |
| 6 | Base (generated) | 1.0 | 1.0 | 0.275 | 0.254 | 0.373 | 0.224 | 0.25 | 0.35 | 0.239 | 0.243 |
| 6 | Test (new generated) | 1.0 | 1.0 | 0.934 | 0.91 | 1.0 | 0.824 | 1.0 | 0.952 | 0.762 | 1.0 |
| 7 | Base (generated) | 0.0 | 0.0 | 0.844 | 0.832 | 0.832 | 0.91 | 0.801 | 0.93 | 0.77 | 0.891 |
| 7 | Test (new generated) | 1.0 | 1.0 | 0.87 | 0.793 | 1.0 | 0.867 | 0.82 | 0.952 | 0.762 | 1.0 |
| 8 | Base (generated) | 1.0 | 1.0 | 0.817 | 0.762 | 0.977 | 0.785 | 0.746 | 0.952 | 0.762 | 1.0 |
| 8 | Test (new generated) | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 |
| 9 | Base (generated) | 1.0 | 1.0 | 0.992 | 0.969 | 1.0 | 1.0 | 1.0 | 0.918 | 0.789 | 0.805 |
| 9 | Test (new generated) | 1.0 | 1.0 | 0.991 | 1.0 | 1.0 | 1.0 | 0.965 | 0.912 | 0.785 | 0.777 |