## Stanford CS 329a Self-Improving AI Agents, Homework 1

In this homework, we will go over basic techniques to scale test-time compute and self-improve. 

The homework has four parts plus a bonus:

1. Evaluating zero-shot predictions
2. Majority Voting
3. Best-of-N with a Generative Reward Model
4. Self-refinement
5. [Bonus] Generative Process Reward Model

In [None]:
## Importing the relevant tools.
## Make sure you have installed this package following the instructions in README.md
from cs329_hw1.tasks import MATH500
from cs329_hw1.methods.verifiers import MATH500Verifier
from cs329_hw1.methods import get_sampler
import matplotlib.pyplot as plt


DEBUG_MODE = True # Set this to True when developing your code, set this to False when rendering your final numbers.

###  Dataset: MATH500
We will work with a commonly used mathematical question answering dataset, MATH500.
This is a subset of the larger MATH dataset, based on [Lightman et al. 2023](https://arxiv.org/abs/2305.20050).
<br>
First, let's get familiar with the dataset.

In [None]:
math500 = MATH500()
problems = math500.get_problems(debug_mode=DEBUG_MODE)
system_prompt = math500.get_system_prompt()

# Each problem is a dictionary with a "problem" and an "answer"
print(problems[0]["problem"])
print(problems[0]["answer"])

Here, each problem is a dictionary with a "problem" and an "answer". 
<br>
"problem" is the statement of the mathematical question, and "answer" is the correct answer to the question. Importantly, this is not a multiple-choice question, so the answer can often be a mathematical expression.


###  Useful tools
To help you implement the techniques we will cover in this homework, we implemented a few useful methods for you under `cs329_hw1.methods`. 

In [None]:
# We implement a few useful methods for you under `methods`. 

# First, the simple multiple sampler.
# This method returns a list of lists, where for each prompt that is passed as an input, we return a list of responses.
# You will use this method as the basis of the techniques we will implement.
# See the below example where we sample 3 responses from gpt-4o-mini with a temperature of 0.7.

method = get_sampler("sample_multiple", "gpt-4o-mini", temperature=0.7, system_prompt=system_prompt)
prompts = ["What language model are you based on?"]
responses = method(prompts)
print("\n".join(responses[0]))
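
To make the list-of-lists return shape concrete, here is a minimal sketch with a stand-in sampler. `fake_sampler` is hypothetical and makes no API calls; it only mimics the layout described above (one inner list of responses per prompt).

```python
from typing import List

def fake_sampler(prompts: List[str], n_samples: int = 3) -> List[List[str]]:
    """Hypothetical stand-in: returns n_samples placeholder strings per prompt."""
    return [[f"response {j} to {p!r}" for j in range(n_samples)] for p in prompts]

demo_responses = fake_sampler(["prompt A", "prompt B"])
# The outer index selects the prompt, the inner index selects the sample.
print(len(demo_responses), len(demo_responses[0]))
print(demo_responses[1][0])
```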


### Testing on a single problem

In [None]:
# You can call `method` on a list of problem statements to get predictions.
# Let's get three predictions for the first problem and read them out.
problem_prompts = [problems[0]["problem"]]
predictions = method(problem_prompts)[0]
print("\n".join(predictions))


### Verifier

In [None]:
# We give you a verifier tool for the MATH500 dataset.
# The verifier for MATH500 is a simple tool that checks if the prediction is correct.
# Internally, it parses the prediction and the answer and compares the final mathematical expression.
# Importantly, the verifier is not perfect due to parsing challenges. 
# We mostly re-use the Qwen verifier based on the [Qwen-2.5 MATH repository](https://github.com/QwenLM/Qwen2.5-Math/tree/main/evaluation).
# If you can be nerdsniped into writing a better verifier, we'd love to see it!

# Let's test the predictions we got for the first problem that we have.
verifier = MATH500Verifier()

print("Correct answer: ", problems[0]["answer"])
for prediction in predictions:
    print("Last line of the prediction: ", prediction.split("\n\n")[-1])
    print("Is correct: ", verifier.verify(prediction, problems[0]["answer"]))

In [None]:
## Here's an example of how the verifier may actually fail:
# Raw strings keep the LaTeX backslashes intact (otherwise "\f" is parsed as a form feed).
print(verifier.verify(solution=r"My final answer is \frac{3\pi}{2}", ground_truth=r"3\frac{\pi}{2}"))
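
One pitfall when writing LaTeX answers in Python source: in a normal string literal, `\f` is an escape sequence (form feed), so `\frac` does not survive intact unless the string is raw. The small check below is independent of the course package and only illustrates the string behavior.

```python
plain = "\frac{3}{2}"   # "\f" is interpreted as a single form-feed character
raw = r"\frac{3}{2}"    # a raw string keeps the backslash and the letter f

print(len(plain), len(raw))
print(plain == raw)
print("\\frac" in plain, "\\frac" in raw)
```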

## Your turn!

### 1- Evaluating zero-shot predictions (10 points)

First, we will evaluate the accuracy of the predictions with a single sample, without using any test-time compute techniques. <br>
You will only need to use the `method` we defined above and the `verifier` to compute the accuracy. This procedure makes 1 API call per problem.


Deliverable: 
- Write your code in the section specified by `TODO: YOUR CODE STARTS HERE` and `TODO: YOUR CODE ENDS HERE`.
- Report the accuracy of the predictions below.
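
As a rough sketch of the bookkeeping involved, the snippet below computes accuracy from one prediction per problem. The toy problems, the toy predictions, and the `exact_match` check are all hypothetical stand-ins; in the homework you would use `method` and `verifier` instead.

```python
from typing import Callable, Dict, List

def accuracy_from_predictions(
    predictions: List[str],
    problems: List[Dict[str, str]],
    is_correct: Callable[[str, str], bool],
) -> float:
    """Fraction of predictions judged correct against each problem's answer."""
    correctness = [
        is_correct(pred, prob["answer"]) for pred, prob in zip(predictions, problems)
    ]
    return sum(correctness) / len(correctness)

def exact_match(prediction: str, answer: str) -> bool:
    """Hypothetical check: plain string equality after trimming whitespace."""
    return prediction.strip() == answer.strip()

toy_problems = [{"problem": "1+1", "answer": "2"}, {"problem": "2*3", "answer": "6"}]
toy_predictions = ["2", "5"]
print(accuracy_from_predictions(toy_predictions, toy_problems, exact_match))
```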

In [None]:
method = get_sampler("sample_multiple", "gpt-4o-mini", temperature=0.7, n_samples=1, system_prompt=system_prompt)

test_problems = problems

### TODO: YOUR CODE STARTS HERE

### TODO: YOUR CODE ENDS HERE
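### Report the accuracy of the predictions below.
### Note: the plotting cells later in this notebook expect a list of per-problem booleans named `zero_shot_correctness`.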

# print(f"Accuracy: {accuracy}")

### 2- Majority Voting (30 points)

Here we will implement our first test-time compute technique, majority voting, as described in [this paper](https://arxiv.org/abs/2408.03314) or [this earlier paper](https://arxiv.org/abs/2203.11171). In particular,
- You will sample multiple (in this case, 16) responses for each problem.
- You will then take the majority vote as the prediction. Voting is performed over normalized expressions (i.e., given the entire solution, we parse the final numerical expression and vote on that). We provide utility functions to do this.
- You will then evaluate the prediction against the ground truth.

Deliverable: 
- Write your code in the section specified by `TODO: YOUR CODE STARTS HERE` and `TODO: YOUR CODE ENDS HERE`.
- Report the accuracy of the predictions below.
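
The voting step itself can be sketched with a counter over normalized answers. This is an illustration under stated assumptions, not the expected solution: `toy_normalize` is a hypothetical stand-in for the provided parsing utilities, and reusing prefixes of one sample list for the smaller budgets is just one possible design.

```python
from collections import Counter
from typing import List

def toy_normalize(answer: str) -> str:
    """Hypothetical normalizer: drops spaces and a trailing period."""
    return answer.strip().rstrip(".").replace(" ", "")

def majority_vote(answers: List[str]) -> str:
    """Return the most frequent normalized answer; ties go to the first one seen."""
    counts = Counter(toy_normalize(a) for a in answers)
    return counts.most_common(1)[0][0]

samples = ["42", " 42.", "41", "42", "7", "41", "41", "41"]
for k in [1, 2, 4, 8]:
    # Vote over the first k samples to emulate a budget of k calls.
    print(k, majority_vote(samples[:k]))
```

Note that the winner can change as the budget grows, which is why accuracy is reported per voting level.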

In [8]:
## 2.1 - Implementing majority voting.
from typing import List, Union
from cs329_hw1.methods.simple_samplers import SampleMultiple
from cs329_hw1.tasks.math_utils import (
    strip_string,
    extract_answer,
)

class MajorityVoting:
    """
    A class that implements a majority voting strategy using multiple samples.
    It generates multiple responses for each prompt and selects the most common answer.
    """

    def __init__(
        self,
        model: str,
        system_prompt: str = None,
        n_samples: int = 5,
        temperature: float = 0.7,
        max_workers: int = 256
    ):
        """
        Initialize the majority voting method.

        Args:
            model (str): The name of the model to use
            system_prompt (str, optional): System prompt to use for the model
            n_samples (int, optional): Number of samples to generate per prompt. Defaults to 5.
            temperature (float, optional): Temperature for sampling. Defaults to 0.7.
        """
        self.sampler = SampleMultiple(
            model=model,
            system_prompt=system_prompt,
            n_samples=n_samples,
            temperature=temperature,
            max_workers=max_workers
        )

    def _parse_answer(self, response: str) -> str:
        return strip_string(extract_answer(response, "math"))

    def _get_majority_answer(self, responses: List[str]) -> str:
        """
        Determine the majority answer from a list of responses using a simple counter.

        Args:
            responses (List[str]): List of model responses

        Returns:
            str: The most common answer
        """
        assert isinstance(responses, list), "Responses must be a list"
        assert all(
            isinstance(r, str) for r in responses
        ), "All responses must be strings"
        # Extract answers from responses

        ## TODO: YOUR CODE STARTS HERE
        ## Implement the majority voting logic here.
        ## Feel free to use the `_parse_answer` method we implemented for you.
        # Create a counter for each unique answer

        ## TODO: YOUR CODE ENDS HERE
        return majority_answer

    def __call__(self, prompts: Union[str, List[str]], majority_voting_levels: List[int] = [1, 2, 4, 8, 16]) -> List[List[str]]:
        """
        Execute majority voting on given prompt(s).

        Args:
            prompts (str or List[str]): The input prompt(s) to process
            majority_voting_levels (List[int], optional): Sample budgets at which to compute the majority answer.

        Returns:
            List[List[str]]: For each majority voting level, the list of majority answers (one per prompt).
        """
        
        ## TODO: YOUR CODE STARTS HERE
        ## Implement the majority voting logic here.
        ## Feel free to use the `_parse_answer` method we implemented for you.
        ## len(majority_answers) should be equal to len(majority_voting_levels), and majority_answers[i] should be a list of strings of length len(prompts).
        ## TODO: YOUR CODE ENDS HERE
        return majority_answers


In [9]:
method = MajorityVoting(
    model="gpt-4o-mini",
    n_samples=16,
    temperature=0.7,
    system_prompt=system_prompt,
    max_workers=1024
)

In [None]:
## Do not modify the code below; it is used to evaluate the accuracy of the predictions across different majority voting budgets.

test_problems = problems

total = len(test_problems)
prompts = [problem["problem"] for problem in test_problems]

majority_voting_levels = [1, 2, 4, 8, 16]
# Get all responses at once (this uses internal threading)
all_responses_list = method(prompts)


majority_voting_accuracies = []
for responses_list, majority_voting_level in zip(all_responses_list, majority_voting_levels):
    correct = 0
    
    for prediction, problem in zip(responses_list, test_problems):
        # We do not re-normalize the prediction here, as we already normalized it in the majority voting step.
        # To do so, we set the normalize_prediction flag to False.
        if verifier.verify(prediction, problem["answer"], normalize_prediction=False):
            correct += 1
        
    accuracy = correct / len(test_problems)
    majority_voting_accuracies.append(accuracy)


In [None]:
plt.figure(figsize=(10, 6))
plt.scatter([1], [sum(zero_shot_correctness) / len(zero_shot_correctness)], s=100, color="darkorange", marker="o", label='No test-time compute')
plt.plot(majority_voting_levels, majority_voting_accuracies, marker='o', color="green", label='Majority Voting')
plt.xlabel('Test-time Budget (# of API calls)')
plt.ylabel('Accuracy')
plt.title('Accuracy vs Number of Samples in Majority Voting')
plt.legend()
plt.grid(True)
plt.show()

### 3- Best-of-N with a Generative Reward Model (30 points)

Here we will implement our second test-time compute technique, best-of-N with a generative reward model. In particular,
- You will sample multiple (in this case, 16) responses for each problem.
- You will then use an LLM to aggregate these responses and select the best answer. This is akin to best-of-N with a reward model as in [the paper](https://arxiv.org/abs/2408.03314), but we use an LLM as a judge instead of a discriminative reward model. This is sometimes referred to as [Generative Verifiers](https://arxiv.org/abs/2408.15240) or [Generative Reward Models](https://arxiv.org/abs/2410.12832).
- Overall, this procedure makes 16+1 API calls per problem.
- You will then verify the prediction against the ground truth.
- (Bonus, 2 points) We give a default prompt for the generative reward model. If you can improve on the default prompt and raise the accuracy of the technique relative to the numbers you get with it, you will get 2 points of extra credit. There is no absolute number to beat, but ideally you observe an improvement of about 5 absolute percentage points over the default prompt.

Deliverables:
- Write your code in the section specified by `TODO: YOUR CODE STARTS HERE` and `TODO: YOUR CODE ENDS HERE`.
- Report the figure below.
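
The control flow of best-of-N with a generative judge can be sketched as follows. Everything here is a stand-in: `build_judge_prompt`, `best_of_n`, and `stub_judge` are hypothetical names, and the stub returns a fixed string where a real judge would be an LLM call. The point is only that the N candidates are packed into a single prompt, so the judge adds one call on top of the N sampling calls.

```python
from typing import Callable, List

def build_judge_prompt(problem: str, candidates: List[str]) -> str:
    """Format the problem and the numbered candidate solutions into one prompt."""
    numbered = "\n\n".join(f"Solution {i + 1}:\n{c}" for i, c in enumerate(candidates))
    return (
        f"Here is a math problem:\n{problem}\n\n"
        f"Candidate solutions:\n\n{numbered}\n\n"
        "Restate the solution you think is most likely correct."
    )

def best_of_n(problem: str, candidates: List[str], judge: Callable[[str], str]) -> str:
    """Make a single judge call over all candidates and return its answer."""
    return judge(build_judge_prompt(problem, candidates))

judge_calls: List[str] = []

def stub_judge(prompt: str) -> str:
    """Hypothetical judge: records the prompt and returns a fixed answer."""
    judge_calls.append(prompt)
    return "2 + 2 = 4, so the answer is 4."

candidates = ["2 + 2 = 5", "2 + 2 = 4", "2 + 2 = 4"]
chosen = best_of_n("What is 2 + 2?", candidates, stub_judge)
print(chosen)
print(len(judge_calls))
```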

In [12]:
class LLMVoting:
    """
    A class that uses an LLM to aggregate multiple responses and select the best answer.
    It generates multiple responses for each prompt and uses another LLM call to choose the best answer.
    """

    def __init__(
        self,
        model: str,
        system_prompt: str = None,
        n_samples: int = 5,
        temperature: float = 0.7,
        max_workers: int = 256
    ):
        """
        Initialize the LLM voting method.

        Args:
            model (str): The name of the model to use
            system_prompt (str, optional): System prompt to use for the model
            n_samples (int, optional): Number of samples to generate per prompt. Defaults to 5.
            temperature (float, optional): Temperature for sampling. Defaults to 0.7.
        """
        self.sampler = SampleMultiple(
            model=model,
            system_prompt=system_prompt,
            n_samples=n_samples,
            temperature=temperature,
            max_workers=max_workers
        )
        
        self.aggregator = SampleMultiple(
            model=model,
            system_prompt=(system_prompt or "") + """\n\nYou are a mathematical expert. You will be shown multiple solutions to a math problem.
Your task is to analyze these solutions and select the most likely correct answer.""",
            n_samples=1,
            temperature=0,  
            max_workers=max_workers
        )

    def _create_aggregation_prompts(self, problems: List[str], all_responses: List[List[str]]) -> List[str]:
        """
        Create prompts for aggregation in parallel.
        
        Args:
            problems (List[str]): List of original problems
            all_responses (List[List[str]]): List of response lists for each problem
            
        Returns:
            List[str]: List of prompts for the aggregator
        """
        return [
            f"""Here is a math problem:
{problem}

I have received multiple solutions. Here they are:

{chr(10).join(f'Solution {i+1}:{chr(10)}{r}' for i, r in enumerate(responses))}

Based on these solutions, restate the solution here that you think is most likely correct."""
            for problem, responses in zip(problems, all_responses)
        ]

    def __call__(self, prompts: Union[str, List[str]]) -> List[List[str]]:
        """
        Execute LLM-based voting on given prompt(s).

        Args:
            prompts (str or List[str]): The input prompt(s) to process

        Returns:
            List[List[str]]: The chosen answer(s) and all responses
        """
        # Get multiple samples for each prompt
        ## TODO: YOUR CODE STARTS HERE
        ## TODO: YOUR CODE ENDS HERE

In [None]:
method = LLMVoting(
    model="gpt-4o-mini",
    n_samples=16,
    temperature=0.7,
    system_prompt=system_prompt,
    max_workers=128
)

test_problems = problems
prompts = [problem["problem"] for problem in test_problems]
all_responses = method(prompts)

In [None]:
# Calculate accuracy
correct_llm_voting = 0
total_llm_voting = len(test_problems)
for responses, problem in zip(all_responses, test_problems):
    is_correct = verifier.verify(
        responses[0],
        problem["answer"],
        normalize_prediction=True
    )
    if is_correct:
        correct_llm_voting += 1

print(f"\nFinal accuracy: {correct_llm_voting/total_llm_voting:.2%}")

In [None]:
plt.figure(figsize=(10, 6))
plt.plot(majority_voting_levels, majority_voting_accuracies, marker='o', color="green", label='Majority Voting')
# we sample 16 responses, and then we sample 1 response to aggregate them. thus, test-time budget is 16+1 API calls.
plt.scatter([16+1], [correct_llm_voting/total_llm_voting], s=100, color="darkblue", marker="s", label='LLM Voting')
plt.scatter([1], [sum(zero_shot_correctness) / len(zero_shot_correctness)], s=100, color="darkorange", marker="o", label='No test-time compute')
plt.xlabel('Test-time Budget (# of API calls)')
plt.ylabel('Accuracy')
plt.title('Accuracy vs Number of Samples')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
## Take a look at some of the aggregated responses, and the rationale there. 
## What patterns do you notice?
print(all_responses[0][0])

### 4- Self-refinement

Here we will implement our third test-time compute technique, self-refinement. In particular,
- You will sample a single response for each problem.
- You will then use an LLM to judge and refine this response. In the terminology of [this paper](https://arxiv.org/abs/2408.03314), in this case, both the refinement of the proposal distribution and the evaluation of a prediction are done by a generative model. However, using out-of-the-box LLMs for this step can be challenging to get right.
- To match the test-time budget of 16 API calls, we will perform 16 iterations of refinement.
- Finally, you will verify the prediction against the ground truth.
- (Bonus, 2 points) We give a default prompt for the refiner. If you can improve on the default refiner prompt and raise the accuracy of the technique by 5 absolute percentage points, you will get 2 points of extra credit.
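
The refinement loop described above can be sketched with a stand-in refiner. `refine_loop` and `stub_refine` are hypothetical names, and the stub only appends a marker where a real refiner would call the model. Keeping the initial solution at index 0 of the trajectory is an assumption made here for illustration.

```python
from typing import Callable, List

def refine_loop(
    problem: str,
    initial_solution: str,
    refine: Callable[[str, str], str],
    n_iterations: int,
) -> List[str]:
    """Return the trajectory [initial, refined_1, ..., refined_n]."""
    trajectory = [initial_solution]
    for _ in range(n_iterations):
        # Each step refines the most recent solution only.
        trajectory.append(refine(problem, trajectory[-1]))
    return trajectory

def stub_refine(problem: str, solution: str) -> str:
    """Hypothetical refiner: appends a marker instead of calling a model."""
    return solution + " (revised)"

history = refine_loop("What is 2 + 2?", "4", stub_refine, n_iterations=3)
print(len(history))
print(history[-1])
```

Storing the whole trajectory, rather than only the final answer, is what allows accuracy to be reported per iteration.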


In [None]:
class SelfRefinement:
    """
    A class that implements an iterative self-refinement strategy.
    It generates an initial response and then repeatedly refines it.
    """

    def __init__(
        self,
        model: str,
        system_prompt: str = None,
        n_iterations: int = 3,
        temperature: float = 0,
        max_workers: int = 256
    ):
        """
        Initialize the self-refinement method.

        Args:
            model (str): The name of the model to use
            system_prompt (str, optional): System prompt to use for the model
            n_iterations (int, optional): Number of refinement iterations. Defaults to 3.
            temperature (float, optional): Temperature for sampling. Defaults to 0.
        """
        # Initial solution generator
        self.generator = SampleMultiple(
            model=model,
            system_prompt=system_prompt,
            n_samples=1,
            temperature=temperature,
            max_workers=max_workers
        )
        
        # Refinement sampler
        self.refiner = SampleMultiple(
            model=model,
            system_prompt="""You are a mathematical expert. You will be shown a math problem and a previous solution attempt.
Your task is to carefully review the solution and provide an improved version.
Focus on fixing any errors and making the solution more precise.""",
            n_samples=1,
            temperature=temperature,
            max_workers=max_workers
        )
        
        self.n_iterations = n_iterations

    def _create_refinement_prompts(self, problems: List[str], current_solutions: List[str]) -> List[str]:
        """
        Create prompts for refinement in parallel.
        
        Args:
            problems (List[str]): List of original problems
            current_solutions (List[str]): List of current solutions to refine
            
        Returns:
            List[str]: List of prompts for the refiner
        """
        return [
            f"""Here is a math problem:
{problem}

Here is a previous solution attempt:
{solution}

Please provide an improved solution to this problem. Focus on accuracy and clarity."""
            for problem, solution in zip(problems, current_solutions)
        ]

    def __call__(self, prompts: Union[str, List[str]]) -> List[List[str]]:
        """
        Execute self-refinement on given prompt(s).

        Args:
            prompts (str or List[str]): The input prompt(s) to process

        Returns:
            List[List[str]]: The final refined answer(s) and intermediate solutions
        """
        
        ## TODO: YOUR CODE STARTS HERE
        ## TODO: YOUR CODE ENDS HERE
        # Return a list of lists where for each problem, we have the sequence of solutions.
        # e.g., all_solutions[0] is the sequence of solutions for the first problem, and all_solutions[0][0] is the first solution for the first problem.

method = SelfRefinement(
    model="gpt-4o-mini",
    n_iterations=16,
    temperature=0.7,
    system_prompt=system_prompt,
    max_workers=1024
)

# Test it on your problems
test_problems = problems
prompts = [problem["problem"] for problem in test_problems]
final_responses = method(prompts)


In [None]:
iteration_accuracies = []
n_iterations = len(final_responses[0])

for iteration in range(n_iterations):
    correct = 0
    for problem_iterations, problem in zip(final_responses, test_problems):
        solution = problem_iterations[iteration]
        is_correct = verifier.verify(
            solution,
            problem["answer"],
            normalize_prediction=True
        )
        if is_correct:
            correct += 1
    iteration_accuracies.append(correct/len(test_problems))

plt.figure(figsize=(10, 6))
plt.scatter([16+1], [correct_llm_voting/total_llm_voting], s=100, color="darkblue", marker="s", label='LLM Voting')
plt.plot(majority_voting_levels, majority_voting_accuracies, marker='o', color="green", label='Majority Voting')
plt.scatter([1], [sum(zero_shot_correctness) / len(zero_shot_correctness)], s=100, color="darkorange", marker="o", label='No test-time compute')
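# Assuming index 0 holds the initial solution (1 API call), entry i corresponds to i + 1 calls.
iterations = list(range(1, n_iterations + 1))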
plt.plot(iterations, iteration_accuracies, marker='^', color="red", label='Self-Refinement')

plt.xlabel('Test-time Budget (# of API calls)')
plt.ylabel('Accuracy')
plt.title('Comparison of Different Methods')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
## Looking at the self-refinement results, what patterns do you notice?
# let's look at a few examples
print(final_responses[0][-1])


## Bonus: Generative Process Reward Model (10 points)

Here we will try to get some intuition about situations where process reward models can be useful compared to outcome reward models. In this exercise, we will try to generate a flawed solution for a problem, and demonstrate when a process reward model can be useful compared to an outcome reward model.

- Your goal is to write a problem / solution pair.
- The problem / solution pair should be such that the outcome reward model cannot find the mistake, but the process reward model can.
- You can reuse the generative outcome / process reward model functions we wrote for you, or write your own prompts / process reward models.
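
To build intuition for the difference, here is a toy, rule-based illustration (no LLM involved). `outcome_check` and `process_check` are hypothetical stand-ins for an outcome and a process reward model: the first looks only at the final number, the second verifies each arithmetic step. Two compensating mistakes produce the right final answer, so only the step-level check notices them.

```python
import re
from typing import List

STEP_PATTERN = re.compile(r"(-?\d+)\s*([+\-*])\s*(-?\d+)\s*=\s*(-?\d+)")
OPS = {"+": lambda a, b: a + b, "-": lambda a, b: a - b, "*": lambda a, b: a * b}

def outcome_check(solution: str, expected: int) -> bool:
    """Outcome-style check: only the final number in the solution is inspected."""
    return int(re.findall(r"-?\d+", solution)[-1]) == expected

def process_check(solution: str) -> List[bool]:
    """Process-style check: verify each 'a op b = c' step independently."""
    results = []
    for a, op, b, c in STEP_PATTERN.findall(solution):
        results.append(OPS[op](int(a), int(b)) == int(c))
    return results

# Toy problem: compute (2 + 3) * 2, whose correct value is 10.
toy_solution = "Step 1) 2 + 3 = 6\nStep 2) 6 * 2 = 10"
print(outcome_check(toy_solution, expected=10))
print(process_check(toy_solution))
```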


In [None]:
## TODO: YOUR CODE STARTS HERE
problem_text = r""" """

flawed_solution = r""" """
## TODO: YOUR CODE ENDS HERE

def outcome_rm_judge_flaw(problem: str, solution: str) -> str:
    prompt = f"""
Problem:
{problem}

Proposed Multi-Step Solution:
{solution}

Is this entire multi-step solution correct? Think step by step, then respond with "Yes" or "No".
"""
    sampler = SampleMultiple(
        model="gpt-4o-mini",
        system_prompt=None,
        n_samples=1,
        temperature=0.7
    )
    # Pass a list of prompts, matching how the samplers are called elsewhere in this notebook.
    response = sampler([prompt])[0][0]
    return response


def process_rm_judge_flaw(problem: str, solution: str) -> list:
    """
    Asks the LLM to judge correctness of each step in parallel.
    If a step is incorrect, the LLM should indicate the mistake.
    We provide each step individually.
    """
    lines = solution.strip().split("\n")
    lines = [ln.strip() for ln in lines if ln.strip()]

    # We'll treat each "Step i)" block as a separate unit to judge.
    prompts = []
    for line in lines:
        if line.startswith("Step "):
            prompt = f"""
We will judge whether this step of the solution is correct or not.

Problem:
{problem}

Step to judge:
{line}

If there's a hidden or subtle mistake (including numeric assumptions), please indicate it.
Again, do NOT provide any corrections beyond identifying an error, if present.
"""
            prompts.append(prompt)

    # Create a parallel sampler to check each step
    sampler = SampleMultiple(
        model="gpt-4o-mini",
        system_prompt=None,
        n_samples=1,
        temperature=0.1,
        max_workers=8
    )

    batch_responses = sampler(prompts)
    return [resp_list[0] for resp_list in batch_responses]  # Flatten single responses

# --------------------------------------------
# Demonstration & Comparison of Both Methods
# --------------------------------------------
# 1) Outcome RM (entire solution at once)
outcome_judgment = outcome_rm_judge_flaw(problem_text, flawed_solution)
print("=== Outcome RM's Judgment (Single-Shot) ===")
print(outcome_judgment, "\n")

# 2) Process RM (step-by-step)
process_judgments = process_rm_judge_flaw(problem_text, flawed_solution)
print("=== Process RM's Judgment (Step-by-Step, Parallel) ===")
for idx, judge in enumerate(process_judgments, start=1):
    print(f"Step {idx} Judgment:", judge, "\n")
