In [None]:
import utils
import os
import importlib
import random
from pprint import pprint
import regex as re


# Re-execute the already-imported `utils` module so that edits made to it on
# disk are picked up without restarting the kernel.
importlib.reload(utils)

[2023-09-12 14:14:16,537] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


<module 'utils' from '/data/tongyx361/reward-by-prm800k/src/utils.py'>

In [None]:
# Load the PRM800K training samples through the project's `utils` helper
# (its implementation lives outside this notebook).
prm800k_all_train_samples = utils.get_prm800k_all_train_samples()

len(prm800k_all_train_samples) = 87012
example: {'labeler': '2d45c8c4-b303-4c12-9fc7-12dcb4f1d028', 'timestamp': '2022-11-06T03:24:46.649379', 'generation': 5, 'is_quality_control_question': False, 'is_initial_screening_question': False, 'question': {'problem': 'Find the distance between the planes $x - 3y + 3z = 8$ and $2x - 6y + 6z = 2.$', 'ground_truth_solution': 'We can write the equation of the second plane as $x - 3y + 3z = 1.$  Note that $(1,0,0)$ is a point on this plane.  (Also, note that both plane have the same normal vector, so they are parallel.)\n\nTherefore, from the formula for the distance between a point and a plane, the distance between the two planes is\n\\[\\frac{|1 - 3 \\cdot 0 + 3 \\cdot 0 - 8|}{\\sqrt{1^2 + (-3)^2 + 3^2}} = \\boxed{\\frac{7 \\sqrt{19}}{19}}.\\]', 'ground_truth_answer': '\\frac{7 \\sqrt{19}}{19}', 'pre_generated_steps': ['One way to find the distance between two parallel planes is to find the length of a perpendicular segment that joins them.', '

In [None]:
def prm800k_sample2prompt(
    example, idx=None, compact=False, sample_heading_level=1, step_delimiter='"""'
):
    """Render one PRM800K sample as a markdown-style prompt section.

    Two modes, selected by ``idx``:

    * ``idx`` given: the sample is a worked example. Steps and ratings come
      from ``utils.reformat_prm800k_sample(example)["step_ratings"]`` and the
      heading reads ``Example {idx+1}``.
    * ``idx`` is None: the sample is the solution to be rated. Steps come from
      ``example["question"]["pre_generated_steps"]`` and no ratings are
      emitted, only the heading under which they are expected.

    Args:
        example: Sample dict; ``example["question"]["problem"]`` is always read.
        idx: Zero-based example index, or None for the solution to rate.
        compact: For examples, put step and rating on one line instead of two
            separate sections. For the solution to rate, only changes the
            trailing heading.
        sample_heading_level: Number of ``#`` used for the sample heading;
            sub-sections use one more.
        step_delimiter: String placed immediately before and after each step.

    Returns:
        The assembled prompt text for this sample.
    """
    top_prefix, sub_prefix = sample_heading_level2prefix(sample_heading_level)
    problem_text = example["question"]["problem"]

    def wrap_step(number, text):
        # One delimited step line, numbered from 1.
        return f"Step {number}: {step_delimiter}{text}{step_delimiter}"

    if idx is None:
        # Solution to rate: list the steps and end on the heading that the
        # ratings are expected to follow.
        title = f"{top_prefix} Solution to Rate"
        step_lines = [
            wrap_step(number, text)
            for number, text in enumerate(
                example["question"]["pre_generated_steps"], start=1
            )
        ]
        body = f"{sub_prefix} Steps\n" + "\n".join(step_lines)
        closing_heading = "Step and Rating Pairs" if compact else "Ratings"
        body += f"\n{sub_prefix} {closing_heading}"
    else:
        title = f"{top_prefix} Example {idx+1}"
        reformatted = utils.reformat_prm800k_sample(example)

        if compact:
            pair_lines = [
                wrap_step(number, entry["step"]) + f" Rating: {entry['rating']}"
                for number, entry in enumerate(reformatted["step_ratings"], start=1)
            ]
            # NOTE(review): this branch starts with a newline, unlike the
            # others, so a blank line follows the problem text.
            body = f"\n{sub_prefix} Step and Rating Pairs"
            body += "\n" + "\n".join(pair_lines)
        else:
            step_lines = [
                wrap_step(number, entry["step"])
                for number, entry in enumerate(reformatted["step_ratings"], start=1)
            ]
            ratings_text = prm800k_get_ratings_prompt_from_reformatted_sample(
                reformatted
            )
            body = "\n".join(
                [
                    f"{sub_prefix} Steps",
                    "\n".join(step_lines),
                    f"{sub_prefix} Ratings",
                    ratings_text,
                ]
            )

    return "\n".join([title, f"{sub_prefix} Problem", f"{problem_text}", body])


def sample_heading_level2prefix(sample_heading_level):
    """Return the markdown heading markers for a sample and its sub-sections.

    Args:
        sample_heading_level: Number of ``#`` characters for the sample heading.

    Returns:
        A ``(sample_prefix, component_prefix)`` tuple, where the component
        prefix has one more ``#`` than the sample prefix.
    """
    return "#" * sample_heading_level, "#" * (sample_heading_level + 1)


def prm800k_get_ratings_prompt_from_reformatted_sample(reformatted_sample):
    """Format the ratings of a reformatted sample, one ``Rating N: value`` per line.

    Args:
        reformatted_sample: Mapping whose ``"step_ratings"`` entry is a sequence
            of mappings that each carry a ``"rating"`` key.

    Returns:
        The newline-joined rating lines, numbered from 1 (an empty string when
        there are no step ratings).
    """
    rating_lines = []
    for number, entry in enumerate(reformatted_sample["step_ratings"], start=1):
        rating_lines.append(f"Rating {number}: {entry['rating']}")
    return "\n".join(rating_lines)


def sampleprm800k_examples2prompt_list(examples, compact=False, verbose=False):
    """Render every example as a prompt section, indexed by its position.

    Args:
        examples: Iterable of samples; each is passed to
            ``prm800k_sample2prompt`` together with its zero-based position.
        compact: Forwarded to ``prm800k_sample2prompt``.
        verbose: When true, print each rendered prompt after all of them have
            been built.

    Returns:
        List of prompt strings in the same order as ``examples``.
    """
    prompts = [
        prm800k_sample2prompt(sample, position, compact=compact)
        for position, sample in enumerate(examples)
    ]

    if verbose:
        for text in prompts:
            print(text)

    return prompts


def prm800k_pick_constrainted_ramdom_sample(prm800k_all_train_samples, constraint):
    """Return a randomly chosen sample for which ``constraint`` is truthy.

    The candidates are visited in a random order and the first match wins.
    The shuffle is done on a private copy, so the caller's collection keeps
    its original order (shuffling it in place would silently reorder the
    shared training set for every later cell).

    Args:
        prm800k_all_train_samples: Iterable of samples to choose from.
        constraint: Callable taking one sample and returning a truthy value
            when the sample is acceptable.

    Returns:
        A matching sample, or None when no sample satisfies ``constraint``.
    """
    candidates = list(prm800k_all_train_samples)
    random.shuffle(candidates)
    for sample in candidates:
        if constraint(sample):
            return sample
    return None


def prm800k_sample2ratings(sample):
    """Return the per-step ratings of a sample as a list.

    The sample is first passed through ``utils.reformat_prm800k_sample``; the
    ``"rating"`` of each entry in its ``"step_ratings"`` is collected in order.
    """
    step_ratings = utils.reformat_prm800k_sample(sample)["step_ratings"]
    collected = []
    for entry in step_ratings:
        collected.append(entry["rating"])
    return collected


# Notebook-wide switch passed as `compact=` to the prompt builders in the
# following cells.
compact = True

In [None]:
# Draw `num_examples` distinct training samples at random. No seed is set in
# this cell, so the selection differs from run to run.
num_examples = 5
examples = random.sample(prm800k_all_train_samples, num_examples)

# Render the drawn samples as numbered example prompts; verbose=True also
# prints each one.
example_prompt_list = sampleprm800k_examples2prompt_list(
    examples, compact=compact, verbose=True
)

In [None]:
def is_prm800k_sample_ratings_enough(sample, min_num_ratings=10):
    """Tell whether the sample has at least ``min_num_ratings`` step ratings.

    The count is taken from ``utils.reformat_prm800k_sample(sample)["step_ratings"]``.
    """
    step_ratings = utils.reformat_prm800k_sample(sample)["step_ratings"]
    return len(step_ratings) >= min_num_ratings


def is_prm800k_sample_ratings_meaningful(sample, min_meaningful_ratio=0.8):
    """Tell whether most ratings are non-zero and labelling ended on an error.

    Args:
        sample: Sample dict; ``sample["label"]["finish_reason"]`` is read and
            the ratings are obtained via ``prm800k_sample2ratings``.
        min_meaningful_ratio: The fraction of non-zero ratings must be strictly
            greater than this value.

    Returns:
        True when the non-zero fraction exceeds ``min_meaningful_ratio`` and
        the finish reason is ``"found_error"``. A sample without any ratings
        yields False rather than a ZeroDivisionError.
    """
    finish_reason = sample["label"]["finish_reason"]
    ratings = prm800k_sample2ratings(sample)

    # No ratings means no ratio can be computed; treat it as not meaningful.
    if not ratings:
        return False

    num_meaningful_ratings = sum(rating != 0 for rating in ratings)
    return (
        num_meaningful_ratings / len(ratings) > min_meaningful_ratio
        and finish_reason == "found_error"
    )


def is_prm800k_sample_non_negative_ratings_balanced(sample, non_negative_min_ratio=0.2):
    """Tell whether both positive and neutral ratings are well represented.

    Args:
        sample: Sample whose ratings are obtained via ``prm800k_sample2ratings``.
        non_negative_min_ratio: Both the fraction of ratings ``> 0`` and the
            fraction of ratings ``== 0`` must be strictly greater than this.

    Returns:
        True when both fractions exceed ``non_negative_min_ratio``. A sample
        without any ratings yields False rather than a ZeroDivisionError.
    """
    ratings = prm800k_sample2ratings(sample)
    num_all_ratings = len(ratings)

    # No ratings means no fraction can be computed; treat it as unbalanced.
    if num_all_ratings == 0:
        return False

    num_positive_ratings = sum(rating > 0 for rating in ratings)
    num_neutral_ratings = sum(rating == 0 for rating in ratings)
    return (num_positive_ratings / num_all_ratings > non_negative_min_ratio) and (
        num_neutral_ratings / num_all_ratings > non_negative_min_ratio
    )


def is_prm800k_sample_ratings_short(sample, max_num_ratings=1):
    """Tell whether the sample has at most ``max_num_ratings`` step ratings.

    The count is taken from ``utils.reformat_prm800k_sample(sample)["step_ratings"]``.
    """
    step_ratings = utils.reformat_prm800k_sample(sample)["step_ratings"]
    return len(step_ratings) <= max_num_ratings


def is_prm800k_sample_ratings_as_long_as(sample, num_ratings=3):
    """Tell whether the sample has exactly ``num_ratings`` step ratings.

    The count is taken from ``utils.reformat_prm800k_sample(sample)["step_ratings"]``.
    """
    step_ratings = utils.reformat_prm800k_sample(sample)["step_ratings"]
    return len(step_ratings) == num_ratings


# Alternative constraints kept for reference (disabled):
#   1. enough ratings AND mostly non-zero ratings with a "found_error" finish
# constraint = lambda sample: is_prm800k_sample_ratings_enough(
#     sample
# ) and is_prm800k_sample_ratings_meaningful(sample)

#   2. very short solutions (at most one rating)
# constraint = lambda sample: is_prm800k_sample_ratings_short(sample, max_num_ratings=1)

# Active constraint: enough ratings AND both positive and neutral ratings
# well represented.
constraint = lambda sample: is_prm800k_sample_ratings_enough(
    sample
) and is_prm800k_sample_non_negative_ratings_balanced(sample)


# Pick one random training sample satisfying the constraint. The picker returns
# None when nothing matches, which would make the calls below fail.
random_sample = prm800k_pick_constrainted_ramdom_sample(
    prm800k_all_train_samples, constraint
)

# With no idx the sample is rendered as the "Solution to Rate" (steps only,
# no ratings).
random_sample_prompt = prm800k_sample2prompt(random_sample, compact=compact)
print(random_sample_prompt)
reformatted_random_sample = utils.reformat_prm800k_sample(random_sample)
# Ratings of the same sample as "Rating N: value" lines; not used again within
# this cell.
ratings_prompt = prm800k_get_ratings_prompt_from_reformatted_sample(
    reformatted_random_sample
)

# The same sample rendered as a worked example ("Example 1"), ratings included.
print(sampleprm800k_examples2prompt_list([random_sample], compact=compact)[0])

# Solution to Rate
## Problem
What is the smallest integer $b > 3$ for which the base $b$ number $23_b$ is a perfect square?
## Steps
Step 1: """To solve this problem, I need to find the smallest base $b$ that makes $23_b$ a perfect square."""
Step 2: """I can rewrite $23_b$ as $2b + 3$ in base 10, since the digits in base $b$ represent powers of $b$."""
Step 3: """So I need to find the smallest $b > 3$ that makes $2b + 3$ a perfect square."""
Step 4: """I can try different values of $b$ and see if they work."""
Step 5: """If $b = 4$, then $2b + 3 = 2 \cdot 4 + 3 = 11$, which is not a perfect square."""
Step 6: """If $b = 5$, then $2b + 3 = 2 \cdot 5 + 3 = 13$, which is not a perfect square."""
Step 7: """If $b = 6$, then $2b + 3 = 2 \cdot 6 + 3 = 15$, which is not a perfect square."""
Step 8: """If $b = 7$, then $2b + 3 = 2 \cdot 7 + 3 = 17$, which is not a perfect square."""
Step 9: """If $b = 8$, then $2b + 3 = 2 \cdot 8 + 3 = 19$, which is not a perfect square."""
Step 10: """If $b

In [None]:
# Read the stored example response from a text file next to the notebook
# (path is relative to the working directory; the platform default encoding is
# used) and print it for inspection.
with open("./example-gpt4-response-with-analysis.txt") as f:
    example_response = f.read()
print(example_response)

# Solution to Provide Steps with Intermediate Analysis for Rating
## Problem
Let $\theta$ be the smallest acute angle for which $\sin \theta,$ $\sin 2 \theta,$ $\sin 3 \theta$ form an arithmetic progression, in some order.  Find $\cos \theta.$
## Step and Rating Pairs
Step 1: """I notice that the problem involves trigonometric functions and arithmetic progressions, so I wonder if there is a connection between them.""" Rating: 0
Step 2: """I recall that an arithmetic progression is a sequence of numbers where each term is obtained by adding a constant amount to the previous term.""" Rating: 1
Step 3: """I also remember that the sine function is periodic, which means that it repeats the same values over and over again at regular intervals.""" Rating: 1
Step 4: """I wonder if I can use these facts to find a relationship between $\sin \theta,$ $\sin 2 \theta,$ and $\sin 3 \theta.$""" Rating: 0
Step 5: """I try to visualize what the graph of the sine function looks like, and how it changes 

In [None]:
# def prm800k_extract_synthesized_analysis(response: str):


def prm800k_extract_synthesized_analysis(
    synthesized_analysis: str,
    step_rating_analysis_pattern: str = r'Step \d+: """(.+)""" Rating: (-1|0|1) Analysis: (.+)',
):
    """Extract the analysis text of every step line in a synthesized response.

    Each match of ``step_rating_analysis_pattern`` must yield three groups:
    step text, rating and analysis. With the default pattern a match cannot
    span lines, so every step is expected on a single line.

    Args:
        synthesized_analysis: Response text containing the step lines.
        step_rating_analysis_pattern: Regular expression with exactly three
            capture groups (step, rating, analysis).

    Returns:
        List of analysis strings in order of appearance (empty when nothing
        matches).

    Raises:
        AssertionError: If a rating is not one of -1/0/1, or if the rating does
            not occur in the part of the analysis after its last comma.
    """
    analyses = []

    for _step, rating, analysis in re.findall(
        step_rating_analysis_pattern, synthesized_analysis
    ):
        assert rating in ("-1", "0", "1"), f"rating {rating} is not valid"
        # Consistency check: the final clause of the analysis should restate
        # the rating. NOTE(review): this is a substring test, so rating "1"
        # also passes when the clause says "-1".
        assert (
            rating in analysis.split(",")[-1]
        ), f"rating {rating} is not in analysis {analysis}"
        analyses.append(analysis)

    return analyses


def prm800k_synthesized_analysis2prompt_with_analysis(
    synthesized_analysis: str,
    step_rating_analysis_pattern: str = r'Step \d+: """(.+)""" Rating: (-1|0|1) Analysis: (.+)',
):
    """Rewrite a synthesized response into step / analysis / rating order.

    Every match of ``step_rating_analysis_pattern`` (three groups: step text,
    rating, analysis) becomes one line of the form
    ``Step N: <step> Analysis: <analysis> Rating: <rating>`` under a
    ``## Step-Analysis-Rating`` heading. Steps are renumbered from 1 in order
    of appearance and the step delimiters are dropped.

    Args:
        synthesized_analysis: Response text containing the step lines.
        step_rating_analysis_pattern: Regular expression with exactly three
            capture groups (step, rating, analysis).

    Returns:
        The assembled prompt. It is also printed, as before; previously the
        function returned None, so the built text was lost to callers.

    Raises:
        AssertionError: If a rating is not one of -1/0/1, or if the rating does
            not occur in the part of the analysis after its last comma.
    """
    search_results = re.findall(step_rating_analysis_pattern, synthesized_analysis)
    prompt_lines = ["## Step-Analysis-Rating"]
    for idx, (step, rating, analysis) in enumerate(search_results):
        assert rating in ("-1", "0", "1"), f"rating {rating} is not valid"
        # Consistency check: the final clause of the analysis should restate
        # the rating (substring test only).
        assert (
            rating in analysis.split(",")[-1]
        ), f"rating {rating} is not in analysis {analysis}"
        prompt_lines.append(
            f"Step {idx + 1}: {step} Analysis: {analysis} Rating: {rating}"
        )
    prompt_with_analysis = "\n".join(prompt_lines)
    print(prompt_with_analysis)
    return prompt_with_analysis

## Step-Analysis-Rating
Step 1: I notice that the problem involves trigonometric functions and arithmetic progressions, so I wonder if there is a connection between them. Analysis: this step points out what the problem involves and leads to probing into the connection between them, so this step is appropriate in conversation, contains no inaccuracies, contains no weirdness, and contains no computations to verify, but fails to substantially advance the process of solving the problem, so it should be rated as 0. Rating: 0
Step 2: I recall that an arithmetic progression is a sequence of numbers where each term is obtained by adding a constant amount to the previous term. Analysis: this step recalls the definition of an arithmetic progression, which is accurate and relevant to the problem. It is also appropriate in conversation and contains no weirdness. The computation is not applicable here, but the information provided helps to advance the solution process, so it should be rated as 1. R