<a href="https://colab.research.google.com/github/Nebius-Academy/LLM-Engineering-Essentials/blob/main/topic2/r.2_inference_time_compute.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LLM Engineering Essentials by Nebius Academy

Course github: [link](https://github.com/Nebius-Academy/LLM-Engineering-Essentials/tree/main)

The course is in development now, with more materials coming soon. [Subscribe to stay updated](https://academy.nebius.com/llm-engineering-essentials/update/)

# R.2. Inference-time compute

We've seen that the success of reasoning in math can be partly attributed to sheer computation, which introduces a trade-off between compute and quality. This raises an important question: **Given a certain amount of inference-time compute per question, how can we allocate it optimally?**

If we only have enough computational resources for a single LLM query, there's not much to consider. We just run it and hope for the best. But what if we can query an LLM 10 times for one user's request? Or even 1000 times? What if we're as reckless as OpenAI, which [spent $10k+ **per task** when they tested **o3** in the ARC-AGI prize](https://arcprize.org/blog/oai-o3-pub-breakthrough)?

In this notebook, we'll discuss how to use **inference-time compute** wisely through thoughtful orchestration.

# Parallel vs sequential orchestration

Before the recent breakthroughs in non-linear reasoning, there were several ways to "bloat" the compute. Let's start by discussing two of the most straightforward options (see the illustration below), and then move to more intricate ones.

- **Parallel**: Running $N$ identical queries in parallel with **non-zero temperature**, then aggregating the results. For example, we might choose the final answer using a majority vote (this strategy is quite confusingly called **Self-consistency**). A higher temperature ensures the outputs differ enough; it's like having $N$ independent experts, each offering their own opinion. With enough experts, the truth could be uncovered from the variety of answers.

  In the practice section, we implemented **self-consistency** for you.

  This approach is similar to **ensembling** in classical machine learning: a number of sufficiently different models is likely to outperform any individual model. It's not surprising that self-consistency is rather popular due to its power and simplicity (compared to the other approaches we'll be discussing further).

- **Sequential**: Making an LLM revise its solution repeatedly in an "unsupervised" manner, simply prompting it to correct its previous response. The tricky part is that [LLMs often aren't very good at self-correction](https://arxiv.org/pdf/2310.01798). So, you'd likely need to fine-tune the LLM for this task. First, though, you'd need to gather a dataset, which is a real challenge. No surprise that few people attempt this.

  <details>
  <summary>How would you collect data to fine-tune an LLM for self-correction? Click here if you're curious.</summary>
  I'll share a method from [this paper](https://arxiv.org/pdf/2408.03314). The authors:

  * Sampled 64 responses in parallel at a higher temperature.
  * Paired each correct solution with up to four incorrect ones to create multi-turn self-correction data.
  * Used a character edit distance metric to prioritize selecting incorrect solutions closely related to the correct solution. This is a somewhat naive method for determining if one solution is an edit of another, but it worked!
  </details>


* Finally, if you've mastered both approaches, you can combine them into a **hybrid strategy**: self-consistency over a series of rewriting sequences!
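To make the parallel strategy concrete, here is a minimal sketch of the majority vote behind self-consistency. The `sample_answer` callable is a hypothetical stand-in for one LLM call at non-zero temperature followed by answer extraction; it is an assumption for illustration, not the implementation used in the practice section.

```python
from collections import Counter
from typing import Callable, List


def majority_vote(answers: List[str]) -> str:
    """Return the most frequent answer (ties go to the one seen first)."""
    return Counter(answers).most_common(1)[0][0]


def self_consistency(sample_answer: Callable[[], str], n: int) -> str:
    """Draw n independent answers and aggregate them by majority vote."""
    return majority_vote([sample_answer() for _ in range(n)])


# Toy usage: a fake "LLM" that replays a fixed list of sampled answers.
fake_samples = iter(["2", "1", "2", "3", "2"])
result = self_consistency(lambda: next(fake_samples), n=5)
print(result)  # 2
```

In a real setup the $N$ calls would be issued concurrently; the aggregation step stays the same.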

<center>
<img src="https://drive.google.com/uc?export=view&id=17NfFNgMoWliod4rw3t9W4MGZBwYgiKbC" width=600 />
</center>

Image source: [Scaling LLM Test-Time Compute Optimally can be More Effective than Scaling Model Parameters](https://arxiv.org/pdf/2408.03314) by Google DeepMind, which is a very insightful paper on the topic of inference-time compute.

But to truly understand what's going on in this image, we need to take the next step: **checking the solution, not just the answer**.

# Process Reward Model (PRM)

I used to teach math a lot, and I strongly believe that it's important to check not only the final answer but also the entire solution. Even a completely wrong solution can sometimes lead to the correct answer, as shown in the following example. When trying to simplify the fraction $\frac{19}{95}$, why not just cancel out the 9's from both the numerator and the denominator? Wouldn't that give the correct answer?

$$\frac{1\color{red}{9}}{\color{red}{9}5} = \frac15$$

Let's look at another example with a multi-step solution:

**Problem**: Find the maximal root of the equation $\frac{x^3 - x^2}{x - 1} = 3x - 2$.

**Solution**:

1) First, let's factor the numerator on the left-hand side: $x^3 - x^2 = x^2(x - 1)$.

2) Now we can cancel the $x - 1$ terms, simplifying the equation to $x^2 = 3x - 2$.

3) Moving everything to one side, we get $x^2 - 3x + 2 = 0$.

4) This equation has roots at $x = 1$ and $x = 2$.

5) Therefore, the maximal root is $\boxed{2}$.

**A grave mistake in the solution** was canceling out $x - 1$ without considering that it can't be zero. $x = 1$ cannot be a valid root of the original equation, because substituting $x = 1$ results in $\frac00$ on the left-hand side. However, the final answer is still correct, because we don't need to consider $x = 1$ in this context.

Note that both Llama-3.1-8B and Llama-3.1-70B fall into this trap, while Llama-3.1-405B correctly acknowledges that $x = 1$ cannot be a solution. Without checking the solution, we wouldn't understand where the problem lies.

There have been various attempts to rate solutions as a whole (**Outcome Reward Models**, **ORM**) or for each partial step (**Process Reward Models**, **PRM**). PRMs are trickier, more exciting, and potentially more powerful, so let's discuss them in slightly more detail.

## How to define solution steps?

There's no simple answer to this question. In the example above, the solution is neatly divided into stages, but LLMs don't always do this. It can also be useful to split solutions by `"\n\n"` (end of paragraph), but LaTeX equations displayed on their own lines can cause issues with that. For example, see the solution to the same problem provided by Llama-3.1-405B below; nothing good will happen if you just split it by paragraphs.

<details>
<summary>Click to see Llama's solution</summary>

To find the maximal root of the equation, we first need to simplify it.

The given equation is $\frac{x^3 - x^2}{x - 1} = 3x - 2$.

We can factor the numerator as $x^2(x - 1)$.

So, the equation becomes $\frac{x^2(x - 1)}{x - 1} = 3x - 2$.

Now, we can cancel out the $(x - 1)$ terms:

$x^2 = 3x - 2$

Rearranging the equation to form a quadratic equation:

$x^2 - 3x + 2 = 0$

Factoring the quadratic equation:

$(x - 2)(x - 1) = 0$

Solving for $x$, we get:

$x = 1$ or $x = 2$

However, the original equation is undefined when $x = 1$, so we exclude this solution.

Therefore, the maximal root of the equation is $x = 2$.

</details>

So, generally, if you want to use a PRM, it's better to specifically prompt your LLM with something like `Split your solution into individual logical steps and keep each step in one line.`

## Setting up the task

A PRM scores not just individual steps, but **partial solutions**, like this:

| Partial solution         | Value |
|---------------|-------|
| **Step 1**        | 0.98  |
| **Steps 1+2**     | 0.76  |
| **Steps 1+2+3**   | 0.21  |
| **Steps 1+2+3+4** | 0.29  |

That's more logical, because an individual step only makes sense given what was before in the solution.
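As a small illustration of what gets scored, the sketch below builds the growing partial solutions from a solution text. It assumes the LLM followed the "one step per line" instruction; `split_into_steps` and `partial_solutions` are hypothetical helper names introduced here.

```python
from typing import List


def split_into_steps(solution: str) -> List[str]:
    """Split a one-step-per-line solution into non-empty steps."""
    return [line.strip() for line in solution.splitlines() if line.strip()]


def partial_solutions(steps: List[str]) -> List[str]:
    """Return the growing prefixes: step 1, steps 1+2, steps 1+2+3, ..."""
    return ["\n".join(steps[:k]) for k in range(1, len(steps) + 1)]


solution = (
    "Factor the numerator.\n"
    "Cancel x - 1, noting x != 1.\n"
    "Solve the quadratic."
)
prefixes = partial_solutions(split_into_steps(solution))
print(len(prefixes))  # 3
```

Each element of `prefixes` is what a PRM would receive (together with the problem statement) to produce one row of the table above.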

It's reasonable to treat Process Reward Modeling as a binary classification task ("good"/"bad"), which naturally results in predicting a score between 0 and 1 ("class probabilities"). As we'll see, we may actually have class probability estimates as targets.

An LLM may be fine-tuned into a Process Reward Model, for example, by teaching it to respond to good partial solutions with the token `"+"` and to bad ones with `"-"`. (Or, more generally, to predict `+` with the probability of class "good" and `-` with the probability of class "bad".)

We'll try an LLM tuned to be a PRM in the practice part.
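One way to turn the `+`/`-` tokens into a score between 0 and 1 is to renormalize their probabilities. The sketch below assumes the serving API exposes log-probabilities for both candidate tokens; restricting the normalization to just these two tokens is an illustrative choice, not necessarily what a particular PRM expects.

```python
import math


def prm_score(logprob_plus: float, logprob_minus: float) -> float:
    """Renormalize the two token probabilities into an estimate of P("good")."""
    p_plus = math.exp(logprob_plus)
    p_minus = math.exp(logprob_minus)
    return p_plus / (p_plus + p_minus)


# Toy numbers: the model puts 0.6 on "+" and 0.2 on "-" (the rest on other tokens).
score = prm_score(math.log(0.6), math.log(0.2))
print(round(score, 2))  # 0.75
```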

**Note**. Of course, an LLM can also be used as a PRM without any fine-tuning, in an **LLM-as-a-Judge** mode. For that, you'd need to prompt it to reason about the potential of each partial solution and to give a score, preferably on a small discrete scale (like 1-5). Though tempting, this approach is not without caveats:

* We put too much hope on LLMs' ability to judge reasoning, though it's their ability to reason we want to improve or score in the first place.
* Generally, you'd take a powerful model as an LLM-as-a-Judge scorer, while you can fine-tune a smaller model to be a decent PRM. If you only need a scorer to evaluate something once, that may be ok, but if you want to further use it for steering generation at inference, the cost of a large judge might become an issue.

## Where to get PRM training data?

This is one of those situations where acquiring data is a tough task. There is no ready-made dataset available, and human labeling would be terribly expensive. (Also, human-generated labels for partial solutions are likely to be poorly calibrated.)

A more viable approach is, unsurprisingly, an unsupervised one, based on running **Monte Carlo rollouts** from each step in the solution. It was suggested in the [Math-Shepherd](https://arxiv.org/pdf/2312.08935) paper and it works like this:

* To score a partial solution, generate a large number $N$ of its continuations and check how many of them arrive at the correct answer. (Luckily, most math datasets contain correct answers.) The ratio of valid continuations to $N$ will be the score. It seems to be a good estimate of the "probability that this partial solution is good".

<center>
<img src="https://drive.google.com/uc?export=view&id=1-FnTNw1GV9FMu1k4hKnNTjHZCjGlkq-1" width=600 />
</center>

You'll find the code in the practice part of the notebook.
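Before getting there, here is a rough sketch of the rollout-based score described above. The `complete_and_answer` callable is a hypothetical stand-in for "ask the LLM to finish this partial solution and extract the final answer"; here it is replaced by a deterministic toy so the example runs offline.

```python
from itertools import cycle
from typing import Callable


def rollout_score(partial_solution: str,
                  complete_and_answer: Callable[[str], str],
                  correct_answer: str,
                  n_rollouts: int) -> float:
    """Fraction of sampled continuations that end in the correct answer."""
    hits = sum(
        complete_and_answer(partial_solution) == correct_answer
        for _ in range(n_rollouts)
    )
    return hits / n_rollouts


# Toy stand-in for the LLM: three of every four rollouts reach the answer "2".
fake_answers = cycle(["2", "2", "1", "2"])
score = rollout_score(
    "x^2 = 3x - 2",
    lambda prefix: next(fake_answers),
    correct_answer="2",
    n_rollouts=8,
)
print(score)  # 0.75
```

The resulting ratio is exactly the kind of soft target mentioned earlier: a class probability estimate rather than a hard "good"/"bad" label.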

**A word of caution**. The ability of the PRM to detect math and logical errors depends on whether those errors in the training data affect the answers. For example, if the training dataset contains only the problem above, where canceling $x-1$ does not influence the answer, the PRM won't learn that cancellations should be performed responsibly. However, if the PRM training dataset includes the problem of finding **all** the roots of $\frac{x^3 - x^2}{x - 1} = 3x - 2$, it may teach the PRM something about cancellations.

## PRM applications and caveats

There are several ways we could use a PRM:

* As a reward model for further LLM fine tuning. For example, the [Math-Shepherd](https://arxiv.org/pdf/2312.08935) paper demonstrates that RLHF with a PRM trained for math problems may improve the LLM's math capabilities.
* As an alternative to majority voting in **Self-consistency**. Indeed, majority voting only chooses the most popular answer, but disregards the solution quality. Choosing a solution with the max PRM score may result in more correct outputs. This is exactly what was used in the [DeepMind paper](https://arxiv.org/pdf/2408.03314) we mentioned earlier and from which we borrowed the already familiar picture.

  <center>
  <img src="https://drive.google.com/uc?export=view&id=17NfFNgMoWliod4rw3t9W4MGZBwYgiKbC" width=600 />
  </center>

  There are several ways of scoring an entire solution with a PRM.
  
  * Score each partial solution (step 1, steps 1+2, steps 1+2+3,...) and then aggregate the scores.
  * Just score the whole solution without bothering about its prefixes.
  
  There's evidence that the second approach works well enough. It's worth noting though, that PRMs, which are trained to score partial solutions, are more accurate than ORMs (Outcome Reward Models), which are trained to score only full solutions.

* In sequential rewriting settings, instead of picking the final rewrite, we can use PRM to score every solution in a chain and pick the best scoring one.

* Finally, we can use PRM on intermediate steps to steer the generation. And that's what we'll discuss next!

Although cool, PRMs have their own issues. They are not totally reliable, and, even worse, they don't transfer too well between different models. You should definitely be careful when scoring solutions by DeepSeek R1 using a PRM trained on solutions generated by Mistral.
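Returning to PRM-based selection among parallel samples, the sketch below shows how the choice of aggregation can change which solution wins. The three aggregators (last prefix score, minimum, product) are options picked here for illustration, and the scores are made-up numbers, not outputs of a real PRM.

```python
import math
from typing import Callable, Dict, List

# Hypothetical aggregation options for a list of prefix scores.
AGGREGATORS: Dict[str, Callable[[List[float]], float]] = {
    "last": lambda scores: scores[-1],  # score of the whole solution only
    "min": min,                         # the weakest prefix decides
    "prod": math.prod,                  # every prefix contributes
}


def best_of_n(candidates: Dict[str, List[float]], how: str = "last") -> str:
    """Pick the candidate whose aggregated PRM score is the highest."""
    aggregate = AGGREGATORS[how]
    return max(candidates, key=lambda name: aggregate(candidates[name]))


# Made-up prefix scores for two sampled solutions.
candidates = {
    "solution A": [0.9, 0.4, 0.8],
    "solution B": [0.7, 0.7, 0.7],
}
print(best_of_n(candidates, "last"))  # solution A
print(best_of_n(candidates, "min"))   # solution B
```

Solution A recovers after a shaky middle step, so it wins under "last" but loses under "min" and "prod"; which behaviour is preferable depends on the PRM and the task.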

# ORM as a PRM

In some situations we have an **ORM** (**Outcome Reward Model**), which scores only complete solutions, instead of a PRM. For example, in code generation, a set of tests or other automatically verified requirements may serve the role of an ORM.

In such cases, you can use **lookahead** to score partial solutions. The idea won't be new for you:

* For a given partial solution, generate a number of full continuations (**rollouts**).
* Score each of the continuations and average their score. This will give you the partial solution's score.

**Note**. Lookahead search isn't only good for simulating a Process Reward Model with an Outcome Reward Model. Even if you have a PRM, it's sometimes useful to score a partial solution by

- generating several steps ahead,
- and then scoring this longer partial solution.

Moreover, you can use several lookaheads; this may potentially give you a more reliable reward value.

**Note**. Sometimes your ORM is just a deterministic check. For example, in a coding task, automatic tests can be run once there is a full solution. Of course, in this case lookahead search is a great way of scoring partial solutions.

# PRM-guided generation and orchestration of non-linear reasoning

As we've mentioned before, humans solve problems in a non-linear way - exploring different solution paths, discarding some, and focusing more on others. Self-consistency is a rough analogy of this process - exploring multiple independent thought trajectories - but a rather coarse one.

Before LLMs learned to perform non-linear reasoning on their own, various approaches emerged to orchestrate it. In this section, we'll discuss several of them.

## Tree of Thoughts and Beam Search

Tree of Thoughts, first introduced in [the eponymous paper](https://arxiv.org/pdf/2305.10601), is a rather straightforward implementation of nonlinear generation. The rough idea is to explore a tree of potential solutions, where each vertex is a "thought" (a solution step), with Breadth-First Search (BFS) or Depth-First Search (DFS).

<center>
<img src="https://drive.google.com/uc?export=view&id=16J6w1QdzkX81zm10H1hKetdwvhHKOlz_" width=600 />

[Source](https://arxiv.org/pdf/2305.10601)
</center>

For a Tree of Thoughts algorithm to take shape, we need to choose:

* The potential **tree structure**: **max number of branches** at each vertex, **max depth**.
* **The way we sample next "thoughts"**. Basically, we either make several parallel queries for the LLM to "generate a next logical step", or ask it to "generate several next logical step options" in one prompt.

  The latter approach may be useful for avoiding duplication when there are not too many options present (think of something like finding the shortest route in a graph). Otherwise, I'd go for parallel queries.

* How we decide which paths to explore further and which to abandon. For that, you need a way of **scoring** a vertex (a state) or a partial solution that led to it. The original Tree of Thoughts paper used LLM-as-a-Judge (which was well-motivated by the specific tasks they considered), but for most tasks a trained **Process Reward Model** would be a reasonable choice.

  Specific PRM-based criteria may be used to determine which branches are hopeless and to prioritize high-potential ones.

  A special case of Tree of Thoughts is **Beam Search**, a variation of Breadth-First Search, that keeps at most $B$ (**beam size**) vertices at each level. If $B=2$, it works as follows:

  * To begin, 2 first step options are generated.
  
  * During each of the following iterations, 2 next steps are generated for each of the 2 intermediate solutions we have.

  * 2 of them, the top-scoring ones, are passed to the next iteration.
    
  In some variations 4 first steps are generated initially, with 2 top-scoring ones taken for further generation.


<center>
<img src="https://drive.google.com/uc?export=view&id=1y3io1RqfqyIKcfu7kxRIiZKNKGMXtt9P" width=500 />
</center>

In the practice part, we'll implement beam search.
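As a preview, here is a toy sketch of the beam search loop described above, with the LLM and the PRM replaced by plain Python callables. `expand` (propose next steps for a partial solution) and `score` (rate a partial solution) are hypothetical placeholders; this is not the practice-session solution.

```python
from typing import Callable, List, Tuple

Path = Tuple[str, ...]  # a partial solution as a sequence of "thoughts"


def beam_search(expand: Callable[[Path], List[str]],
                score: Callable[[Path], float],
                beam_size: int,
                depth: int) -> Path:
    """Keep the beam_size best partial solutions at every level."""
    beams: List[Path] = [()]
    for _ in range(depth):
        # Extend every kept partial solution with each proposed next step.
        candidates = [path + (step,) for path in beams for step in expand(path)]
        if not candidates:
            break
        candidates.sort(key=score, reverse=True)
        beams = candidates[:beam_size]
    return beams[0]


# Toy problem: the "correct" solution is the step sequence a, b, a.
target = ("a", "b", "a")
toy_expand = lambda path: ["a", "b"]
toy_score = lambda path: sum(s == t for s, t in zip(path, target))

best = beam_search(toy_expand, toy_score, beam_size=2, depth=3)
print(best)  # ('a', 'b', 'a')
```

With a real LLM, `expand` would issue $B$ sampling calls per kept path and `score` would call the PRM on the joined steps.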

## Beyond Tree of Thoughts

Even more elaborate orchestration efforts emerged in 2023 and 2024, like [Graph of Thoughts](https://arxiv.org/pdf/2308.09687) or [Algorithm of Thoughts](https://arxiv.org/pdf/2308.10379). Another interesting example of a guided tree-based generation strategy is **Monte Carlo Tree Search** (**MCTS**) which we'll discuss in the next notebook. However, these approaches are costly and complicated, and from this point of view they tend to fall behind simply using larger LLMs or LLMs with native non-linear reasoning like OpenAI's o1 and DeepSeek R1. Still, it's interesting to check these approaches and to ponder how they resonate with what happens in R1 and similar models.

<center>
<img src="https://drive.google.com/uc?export=view&id=1WZWjI7aY3Vu0zEsAO8u7R73iwsC6KJeq" width=600 />

[Source](https://arxiv.org/pdf/2308.09687)
</center>

# Practical guidelines and inference scaling laws

With such an exciting choice of approaches, how can we choose the best one? To start with, we need to understand our **inference budget**, that is the amount of money (or, simplifying this, inference calls) we can spend on processing one query. With the inference budget in mind, we can try choosing

* **The LLM we want to use**: which size tier, non-linear reasoning capabilities, etc.
* **How to stretch the inference budget**: techniques like self-consistency, beam search, and more.

  For example, if our budget allows for $N$ LLM calls per query, we can leverage self-consistency with $N$ parallel calls.

  <details>
  <summary>Check a very rough comparison of beam search vs self-consistency, if you're curious.</summary>
  
  Assume that the beam size is $B$, each solution consists of $D$ thoughts, with $T$ tokens each, and the problem consists of $P$ tokens. Also, it will be reasonable to estimate that processing of an input token costs $\$c$, while generating one output token costs $\$3c$ (that's more or less true for most APIs). Now, with beam search:

  * The first step will cost us $B\cdot(cP + 3cT)$,
  * The second step will cost $B\cdot([cP + cT] + 3cT)$ (the prompts are now problem + first thought),
  * The third step will cost $B\cdot([cP + 2cT] + 3cT)$ (the prompts are now problem + first thought + second thought),
  * ...
  * The last, $D$-th step will cost $B\cdot([cP + (D-1)\cdot cT] + 3cT)$

  In total, this gives
  $$Bc\cdot\left(DP + T + 2T + \ldots + (D-1)T + 3DT \right)=
  Bc\cdot\left(DP + \left[\frac12D(D-1) + 3D\right]T \right)$$

  For self-consistency with $B$ parallel queries, we need $B$ calls with $P$ input tokens and $DT$ output tokens, which results in the cost
  $$B\cdot(cP + 3cDT) = Bc\cdot(P + 3DT)$$
  Note that we didn't count $BD$ PRM calls here!
  
  Typically, beam search will be more expensive than self-consistency, but now we also see that, due to beam search's cost being quadratic in $D$ (solution length in "thoughts"), for long solutions beam search will be staggeringly more expensive.

  </details>
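To get a feel for the numbers, the sketch below simply adds up the per-step costs from the rough comparison above and contrasts them with self-consistency. It keeps the same simplifying assumptions (output tokens cost three times as much as input tokens, PRM calls are ignored); the concrete values of $B$, $D$, $T$, and $P$ are arbitrary.

```python
def beam_search_cost(B: int, D: int, T: int, P: int, c: float = 1.0) -> float:
    """Sum the per-step costs of beam search (PRM calls are not counted)."""
    total = 0.0
    for step in range(1, D + 1):
        input_tokens = P + (step - 1) * T  # problem + previously kept thoughts
        total += B * (c * input_tokens + 3 * c * T)
    return total


def self_consistency_cost(B: int, D: int, T: int, P: int, c: float = 1.0) -> float:
    """B parallel calls, each reading the problem and writing D thoughts."""
    return B * (c * P + 3 * c * D * T)


# Arbitrary example: 4 beams, 10 thoughts of 50 tokens, a 200-token problem.
beam = beam_search_cost(B=4, D=10, T=50, P=200)
parallel = self_consistency_cost(B=4, D=10, T=50, P=200)
print(beam, parallel)  # 23000.0 6800.0
```

Doubling $D$ in this toy setting roughly triples the beam search cost while only doubling the self-consistency cost, which is the quadratic effect at work.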

Some key considerations are:

* **Larger models vs. smarter strategies**
  Upgrading to a more powerful LLM (e.g., **Llama-3.1-405B** instead of **Llama-3.1-8B**) can provide a major boost in quality, but at a steep cost. In some cases, a smaller LLM combined with self-consistency or a more advanced strategy may outperform the larger model at a lower cost.

  However, as the inference-time budget grows, a larger LLM, straightforwardly queried, may eventually become a more favourable choice than a tricky, bug-prone orchestration.
  
  We'll explore this trade-off in the practical section.

* **Trade-offs in PRM-guided generation**
  PRM-guided, tree-based generation is costly and more complex to set up. (And good PRMs aren't just lying around!) If you're looking for a starting point, self-consistency is a solid choice.

  The [Scaling LLM Test-Time Compute Optimally can be More Effective than Scaling Model Parameters](https://arxiv.org/pdf/2408.03314) paper presents several experiments aimed at identifying **inference scaling laws** - the patterns of how efficiency evolves as inference budgets change.

  The results below are for the [MATH benchmark](https://huggingface.co/datasets/nlile/hendrycks-MATH-benchmark).

  * For smaller inference budgets, beam search consistently outperforms other approaches (as long as there's enough compute to run it).
  * For larger inference budgets, another method takes the lead: generating multiple solutions in parallel and selecting the best one based on the top PRM score.

  Of course, there's no guarantee that these results will generalize to all LLMs and datasets. However, they strongly suggest that even with a large inference budget, you don't necessarily need highly complex strategies to optimize performance.

  <center>
  <img src="https://drive.google.com/uc?export=view&id=1Bh81Ycx-2H0676F7MTpnRhGYFMg8QMyo" width=600 />
  </center>




# Ready for more?

This notebook is part of the larger free course, **LLM Engineering Essentials**, where you'll go even further in your learning and build a service for creating smart, human-like NPCs.

🎓 New materials are coming soon. Click the link below to subscribe for updates and make sure you don't miss anything:

[Stay updated](https://academy.nebius.com/llm-engineering-essentials/update/)

# Practice session

We have ambitious plans for this practice session. We will:

* Practice balancing the inference budget between using large LLMs and leveraging complex orchestration strategies with smaller LLMs.
* Implement Beam Search using both model-based and confidence-based PRMs.

If you encounter any difficulties or simply want to see our solutions, feel free to check the [Solutions notebook](https://colab.research.google.com/github/Nebius-Academy/LLM-Engineering-Essentials/blob/main/topic2/r.2_inference_time_compute_solutions.ipynb).

## Getting ready

In [1]:
!pip install -q openai

In [2]:
import os

with open("nebius_api_key", "r") as file:
    nebius_api_key = file.read().strip()

os.environ["NEBIUS_API_KEY"] = nebius_api_key

We'll be calling APIs quite often in this notebook, so let's define a shortcut function to avoid repeating all the code:

In [3]:
from openai import OpenAI

nebius_client = OpenAI(
    base_url="https://api.studio.nebius.ai/v1/",
    api_key=os.environ.get("NEBIUS_API_KEY"),
)

llama_8b_model = "meta-llama/Meta-Llama-3.1-8B-Instruct"

def prettify_string(text, max_line_length=80):
    """Returns the string with line breaks inserted at spaces to prevent horizontal scrolling.

    Args:
        text: The string to print.
        max_line_length: The maximum length of each line.
    """

    output_lines = []
    lines = text.split("\n")
    for line in lines:
        current_line = ""
        words = line.split()
        for word in words:
            if len(current_line) + len(word) + 1 <= max_line_length:
                current_line += word + " "
            else:
                output_lines.append(current_line.strip())
                current_line = word + " "
        output_lines.append(current_line.strip())  # Append the last line
    return "\n".join(output_lines)

def answer_with_llm(prompt: str,
                    system_prompt="You are a helpful assistant",
                    max_tokens=512,
                    client=nebius_client,
                    model=llama_8b_model,
                    prettify=True,
                    temperature=None) -> str:

    messages = []

    if system_prompt:
        messages.append(
            {
                "role": "system",
                "content": system_prompt
            }
        )

    messages.append(
        {
            "role": "user",
            "content": prompt
        }
    )

    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature
    )

    if prettify:
        return prettify_string(completion.choices[0].message.content)
    else:
        return completion.choices[0].message.content

# Practice, part 1. Larger models vs smarter strategies: Classifying dialogue roles with CoT and Self-Consistency

In this task, we'll work with the [DialogRE benchmark](https://github.com/nlpdata/dialogre) and classify roles in dialogs using Llama models and Nebius AI Studio.

Let's start by loading the dataset.

In [None]:
import json

with open('dialogre_dev.json', 'r') as f:
    dialog_data_raw = json.load(f)

It's important to look at the data.

In [None]:
len(dialog_data_raw)

358

In [None]:
dialog_data_raw[0]

[['Speaker 1: Hey!',
  'Speaker 2: Hey.',
  "Speaker 3: Hey, man. What's up?",
  "Speaker 1: Maybe you can tell me. My agent would like to know why I didn't show up at the audition I didn't know I had today. The first good thing she gets me in weeks. How could you not give me the message?!",
  "Speaker 3: Well, I'll tell ya I do enjoy guilt, but, ah, it wasn't me.",
  'Speaker 2: Yes, it was! It was him! Uh huh! Okay, it was me!',
  'Speaker 1: How is it you?',
  "Speaker 2: Well, it was just, it was all so crazy, you know. I mean, Chandler was in the closet, counting to 10, and he was up to 7 and I hadn't found a place to hide yet. I-I-I meant to tell you, and I wrote it all down on my hand. See, all of it.",
  "Speaker 1: Yep, that's my audition.",
  'Speaker 4: See, now this is why I keep notepads everywhere.',
  "Speaker 2: Yep, and that's why we don't invite you to play.",
  'Speaker 5: What is the great tragedy here? You go get yourself another appointment.',
  'Speaker 1: Well, 

As you see, many different character roles are labeled for every dialog.

Here is the list of all roles:

In [None]:
all_relations = []
for dialog in dialog_data_raw:
    for relation in dialog[1]:
        for individual_relation in relation['r']:
            if individual_relation not in all_relations:
                all_relations.append(individual_relation)
all_relations

['per:title',
 'per:alternate_names',
 'per:client',
 'per:friends',
 'unanswerable',
 'per:spouse',
 'per:children',
 'per:parents',
 'per:age',
 'per:siblings',
 'per:roommate',
 'per:negative_impression',
 'per:pet',
 'per:positive_impression',
 'per:girl/boyfriend',
 'org:employees_or_members',
 'per:employee_or_member_of',
 'per:dates',
 'per:boss',
 'per:subordinate',
 'per:other_family',
 'org:students',
 'per:major',
 'per:schools_attended',
 'per:origin',
 'gpe:visitors_of_place',
 'per:visited_place',
 'per:alumni',
 'per:works',
 'per:place_of_residence',
 'gpe:residents_of_place',
 'per:place_of_work',
 'per:date_of_birth',
 'per:acquaintance',
 'per:neighbor',
 'gpe:births_in_place',
 'per:place_of_birth']

We don't want to overcomplicate the task, so we'll stick to simpler and more frequent roles. We'll also exclude `per:alternate_names`, because it's too abundant in the dataset and it would affect class balancing too much.

In [None]:
our_relations = [
    'per:friends', 'per:spouse', 'per:children', 'per:parents',
    'per:siblings', 'per:girl/boyfriend', 'per:boss', 'per:subordinate'
]

We'll select only dialogs that contain at least one of the chosen relationship statuses.

In [None]:
dialog_data = []
for dialog in dialog_data_raw:
    current_dialog = dialog[0]
    current_relations = []
    for relation in dialog[1]:
        if relation['r'][0] in our_relations:
            current_relations.append({
                'x': relation['x'],
                'y': relation['y'],
                'r': relation['r'][0]
            })
    if len(current_relations) > 0:
        dialog_data.append([current_dialog, current_relations])

In [None]:
len(dialog_data)

180

In [None]:
dialog_data[1]

[['Speaker 1, Speaker 2: Hi',
  'Speaker 3: Hi! Hey mom.',
  'Speaker 4: This is such a great party! 35 years. Very impressive, do you guys have any pearls of wisdom?',
  'Speaker 2: Jack?',
  'Speaker 1: Why would you serve food on such a sharp stick?',
  'Speaker 3: That’s a good question, dad. That’s a good question…',
  'Speaker 4: Hmmm….'],
 [{'x': 'Speaker 1', 'y': 'Speaker 2', 'r': 'per:spouse'},
  {'x': 'Speaker 1', 'y': 'Speaker 3', 'r': 'per:children'},
  {'x': 'Speaker 2', 'y': 'Speaker 1', 'r': 'per:spouse'},
  {'x': 'Speaker 2', 'y': 'Speaker 3', 'r': 'per:children'},
  {'x': 'Jack', 'y': 'Speaker 3', 'r': 'per:children'},
  {'x': 'Speaker 3', 'y': 'Speaker 2', 'r': 'per:parents'},
  {'x': 'Speaker 3', 'y': 'Speaker 1', 'r': 'per:parents'},
  {'x': 'Speaker 3', 'y': 'Jack', 'r': 'per:parents'}]]

To further simplify the task, for each dialog we'll select only one role to predict. We'll use random sampling, but **we'll fix the random seed** to make the selection procedure reproducible. And to save time and money, we'll only take the first 50 dialogues from the `dev` set.

In [None]:
import numpy as np

np.random.seed(28)
dialog_data_short = [[dialog, np.random.choice(relations)] for dialog, relations in dialog_data[:50]]

In [None]:
dialog_data_short[1]

[['Speaker 1, Speaker 2: Hi',
  'Speaker 3: Hi! Hey mom.',
  'Speaker 4: This is such a great party! 35 years. Very impressive, do you guys have any pearls of wisdom?',
  'Speaker 2: Jack?',
  'Speaker 1: Why would you serve food on such a sharp stick?',
  'Speaker 3: That’s a good question, dad. That’s a good question…',
  'Speaker 4: Hmmm….'],
 {'x': 'Speaker 1', 'y': 'Speaker 3', 'r': 'per:children'}]

In [None]:
verdicts_true = [relations['r'] for _, relations in dialog_data_short]

In ML tasks it's always important to look at the target label distribution. In our case it's not balanced: there are many more friends and girl/boyfriends than other roles. We won't take this into account for now, but in a real-life Data Science problem, we'd try to adjust our metrics for class imbalance.

In [None]:
from collections import Counter

relations_counter = Counter(verdicts_true)

relations_counter

Counter({'per:friends': 11,
         'per:children': 3,
         'per:parents': 4,
         'per:siblings': 6,
         'per:spouse': 3,
         'per:girl/boyfriend': 16,
         'per:boss': 5,
         'per:subordinate': 2})
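One possible imbalance-aware metric is macro-averaged recall (often called balanced accuracy), where every class counts equally regardless of its size. The sketch below is a plain-Python illustration with made-up labels; it is not used later in this notebook.

```python
from collections import defaultdict
from typing import List


def macro_recall(y_true: List[str], y_pred: List[str]) -> float:
    """Average per-class recall, so every class counts equally."""
    hits = defaultdict(int)
    totals = defaultdict(int)
    for true_label, predicted in zip(y_true, y_pred):
        totals[true_label] += 1
        hits[true_label] += int(true_label == predicted)
    return sum(hits[label] / totals[label] for label in totals) / len(totals)


# Made-up example: always predicting the majority class looks fine on
# plain accuracy (0.75) but not on macro recall.
toy_true = ["per:friends"] * 3 + ["per:boss"]
toy_pred = ["per:friends"] * 4
print(macro_recall(toy_true, toy_pred))  # 0.5
```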

## Take 1: A large model with a simple CoT strategy

To start, we'll use Llama-3.1-405B with a straightforward CoT + programmed answer parsing strategy.

We could have used an LLM chain with the second model extracting the answer, but it shouldn't be difficult to just parse it. The only non-trivial thing we introduce at the parsing stage comes from the observation that sometimes LLMs predict `boyfriend` or `girlfriend` instead of `girl/boyfriend`. So, we just manually map `boyfriend` or `girlfriend` to `girl/boyfriend`.

In [None]:
import re

class RelationClassifier():
    def __init__(self, client, model):
        self.client = client
        self.model = model
        self.raw_classes = ["friends", "spouse", "children", "parents",
                            "siblings", "girl/boyfriend", "boss", "subordinate"]

    def predict(self, dialog, character_x, character_y, verbose=False):
        reasoning_completion = self.client.chat.completions.create(
            messages=[
                {
            "role": "user",
            "content": f"""You are an expert in Natural Language Understanding.
You are given a dialog and two characters participating or mentioned in this dialog.
You need to predict relationships between these characters choosing from the following list:
- friends
- spouse
- children
- parents
- siblings
- girl/boyfriend
- boss
- subordinate
Provide a clear reasoning justifying your choice. Then write your final answer after #VERDICT:
Now, take a deep breath and work out this problem step by step. If you do well, I'll tip you 200$.

DIALOG: {dialog}

FIRST CHARACTER: {character_x}

SECOND CHARACTER: {character_y}

REASONING:"""
                }
            ],
            model=self.model,
            )
        reasoning = reasoning_completion.choices[0].message.content

        # Extract whatever is after "#VERDICT:" (the colon and spaces are skipped)
        re_match = re.search(r"#VERDICT:?\s*(.*)", reasoning, re.DOTALL)
        if re_match:
            extracted_answer = re_match.group(1).strip()
        else:
            extracted_answer = "Failed to parse"

        # Parse the answer
        verdict = extracted_answer.lower().strip("'\".; ")
        if verdict == "boyfriend" or verdict == "girlfriend":
            verdict = "girl/boyfriend"
        if verdict in self.raw_classes:
            verdict = "per:" + verdict
        else:
            verdict = "per:failed"

        if verbose:
            return {
                "reasoning_completion": reasoning_completion,
                "extracted_answer": extracted_answer,
                "verdict": verdict
            }
        else:
            return verdict

In [None]:
client = OpenAI(
    base_url="https://api.studio.nebius.ai/v1/",
    api_key=os.environ.get("NEBIUS_API_KEY"),
)

classifier_llama_405b = RelationClassifier(
    client=client, model="meta-llama/Meta-Llama-3.1-405B-Instruct"
    )

It's a good idea to start logging the results:

In [None]:
completions_log = dict() # Raw completions
verdicts_log = dict() # Final verdicts

In [None]:
# The tqdm library draws progress bars for loops
from tqdm import tqdm

current_configuration = "Meta-Llama-3.1-405B-Instruct, no enhancements"
results = []

# If you're short on compute, try dialog_data_short[:10]
for dialog, relations in tqdm(dialog_data_short):
    results.append(classifier_llama_405b.predict(dialog, relations['x'], relations['y'], verbose=True))

completions_log[current_configuration] = results

100%|██████████| 50/50 [08:53<00:00, 10.66s/it]


Let's look at the results:

In [None]:
results[1]

{'reasoning_completion': ChatCompletion(id='chat-96edbd003f1c489f93aa2d75dd2a958f', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='To determine the relationship between Speaker 1 and Speaker 3, let\'s analyze the dialog provided.\n\n1. Speaker 3 greets with "Hi! Hey mom." Initially, this could suggest that Speaker 3 is addressing their mother, but since Speaker 3 later addresses Speaker 1 as "dad," we can infer that Speaker 3 was addressing their mother who is not participating in this part of the conversation or is Speaker 4. It doesn\'t impact our main relationship in question.\n\n2. The critical clue is when Speaker 3 says, "That’s a good question, dad. That’s a good question…" This line directly addresses Speaker 1 as "dad," which indicates a parental relationship.\n\nGiven these points, the relationship between Speaker 1 and Speaker 3 can be classified under the \'per:parents\' category since Speaker 1 is the parent (fath

First, let's check the percentage of cases when we failed to parse the verdict:

In [None]:
verdicts = [result["verdict"] for result in results]
print(sum([verdict == "per:failed" for verdict in verdicts]) / len(verdicts))

['per:friends', "'per:parents'", "'per:parents'", 'per:friends', 'per:spouse', 'per:friends', 'per:girl/boyfriend', "'per:boss'", 'per:spouse', 'per:girl/boyfriend', 'per:boss', 'per:girl/boyfriend', 'per:subordinate', 'per:girl/boyfriend', 'per:girl/boyfriend', 'per:boss', 'per:friends', 'per:girl/boyfriend', 'per:spouse', 'per:friends', 'per:friends', 'per:girl/boyfriend', 'per:girl/boyfriend', 'per:friends', 'per:girl/boyfriend', 'per:siblings', 'per:boss', 'per:friends', 'per:parents', 'per:girl/boyfriend', 'per:friends', 'per:friends', 'per:spouse', 'per:friends', 'per:siblings', "'per:girl/boyfriend'", "'per:siblings'", 'per:girl/boyfriend', 'per:parents', 'per:friends', 'per:friends', 'per:girl/boyfriend', 'per:siblings', 'per:girl/boyfriend', 'per:girl/boyfriend', 'per:parents', 'per:children', 'per:boss', 'per:subordinate', 'per:friends']


Now, let's check the accuracy of our predictions.

In [None]:
import numpy as np
def accuracy_score(y_true, y_pred):
    return sum(np.array(y_true) == np.array(y_pred)) / len(y_true)

accuracy_score(verdicts_true, verdicts)

0.74

That's not bad for a start!

## Take 2: A smaller LLM

Let's see how a smaller model, **Llama-3.1-8B**, copes with this task!

In [None]:
classifier_llama_8b = RelationClassifier(
    client=client, model="meta-llama/Meta-Llama-3.1-8B-Instruct"
    )

current_configuration = "Meta-Llama-3.1-8B-Instruct, no enhancement"
results = []
# To save time, you can run it on dialog_data_short[-10:]
for dialog, relations in tqdm(dialog_data_short):
    results.append(classifier_llama_8b.predict(dialog, relations['x'], relations['y'], verbose=True))

completions_log[current_configuration] = results

100%|██████████| 50/50 [04:13<00:00,  5.07s/it]


In [None]:
results[-3]

{'reasoning_completion': ChatCompletion(id='chat-934dd724a2814e11847567c550d93a25', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Based on the dialog, I predict the relationship between Mr. Kostelick and Speaker 2 as \'per:boss\'.\n\nReasoning:\n\n1. Speaker 2 is mentioned as needing to stop by Mr. Kostelick\'s office at the end of the day (Speaker 3\'s line).\n2. Mr. Kostelick is referred to as "Mr. Kostelick wants you to stop by his office at the end of the day" (Speaker 3\'s line), indicating that Speaker 2 reports to or is supervised by Mr. Kostelick.\n3. Speaker 2 is also defensive when accused of being involved in a prank memo, and mentions "Mr. Kostelick" in the context of work (Speaker 2\'s lines "If this is about those prank memos, I had nothing to do with them.").\n4. Speaker 7, Chandler, mentions that Speaker 2 has been at the office for five years, implying a long-standing employment relationship.\n5. Speaker 2\'s lines

Let's see how many verdicts failed to parse:

In [None]:
verdicts = [result["verdict"] for result in results]
print(sum([verdict == "per:failed" for verdict in verdicts]) / len(verdicts))

["'per:client'", "'per:parents'.", '**per:parents**', "'per:friends'", 'Ross is the boyfriend or partner of Speaker 1.', "['per:friends']", 'per:spouse', 'per:boss', "'per:girl/boyfriend'", "'per:friends'", "['per:subordinate']", 'per:girl/boyfriend', "['per:boss']", 'per:spouse', "['per:friends']", '\'per:girl/boyfriend\' is unlikely, but... \'per:girl/boyfriend\' is not the final verdict\n\nWAIT, MY DIAGNOSTIC IS \'per:girl/boyfriend\' UNLIKELY BUT...\n\nMy diagnosis is \'per:girl/boyfriend\' unlikely, but the final verdict is likely one of the remaining options which is:\n\n\'per:girl/boyfriend\' is unlikely but another option is more likely\n\nThe final verdict is:\n\'per:girl/boyfriend\' is unlikely but another option is more likely.\n\nHowever, there is more text that indicates the verdict.\n\n* \'per:girl/boyfriend\'- unlikely, as the dialogue does not suggest a romantic relationship.\n* \'per:subordinate\' - unlikely, as Speaker 2 seems to be on an equal level or possibly highe

And, finally, the accuracy:

In [None]:
accuracy_score(verdicts_true, verdicts)

0.44

That's worse.

But let's also check how many of the raw verdicts are in the correct format:
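
A minimal sketch of such a check, run on a few raw answers copied from the output above. The helper `is_well_formatted` and the set of characters it strips are assumptions, not part of the original notebook; on real data you would apply it to the `extracted_answer` field of each logged result.

```python
RAW_CLASSES = ["friends", "spouse", "children", "parents",
               "siblings", "girl/boyfriend", "boss", "subordinate"]

def is_well_formatted(extracted_answer):
    """True if the raw answer is a single class name, optionally wrapped in
    quotes, brackets or markdown and optionally prefixed with 'per:'."""
    cleaned = extracted_answer.lower().strip("'\".;:*[] \n")
    if cleaned.startswith("per:"):
        cleaned = cleaned[len("per:"):]
    return cleaned in RAW_CLASSES

# A few raw answers taken from the 8B output above
toy_answers = [
    "'per:parents'.",
    "**per:parents**",
    "['per:friends']",
    "Ross is the boyfriend or partner of Speaker 1.",
    "'per:client'",
]
share = sum(is_well_formatted(answer) for answer in toy_answers) / len(toy_answers)
print(share)
```

Free-text answers and classes outside the allowed list (such as `per:client`) are the ones counted as badly formatted here.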

## Take 3: A smaller LLM with Self-consistency

An LLM's generations change from run to run, and the answers may change as well. As a result, the accuracy of the approach above is unstable.

We'll turn this into another way of improving quality. There's reason to believe that even if a single reasoning path goes wrong, several attempts at reasoning may together reveal the truth.

The most popular approach is called **Self-consistency**. It works as follows:

- Generate several (say, 5 or 7) reasoning paths and extract the answer from each of them.
- Choose the most frequent answer.

This is like a majority vote of several identical LLMs.
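
The voting step can be sketched in a few lines. This is only an illustration on toy verdicts: the function name `majority_vote`, the choice to ignore failed parses, and the returned agreement rate are assumptions rather than part of the classifier defined later in this notebook.

```python
from collections import Counter

def majority_vote(verdicts, failed_label="per:failed"):
    """Return (winner, agreement) where agreement is the share of all
    trials that voted for the winner. Failed parses do not vote."""
    valid = [verdict for verdict in verdicts if verdict != failed_label]
    if not valid:
        return failed_label, 0.0
    winner, count = Counter(valid).most_common(1)[0]
    return winner, count / len(verdicts)

toy_verdicts = ["per:spouse", "per:friends", "per:spouse", "per:failed", "per:spouse"]
winner, agreement = majority_vote(toy_verdicts)
print(winner, agreement)
```

A low agreement rate is a cheap signal that the model is unsure about a dialog. Note that `Counter.most_common` breaks ties by first occurrence, so an odd number of trials only avoids ties when there are two competing answers.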

In [None]:
dialog, relations = dialog_data_short[5]
for _ in range(5):
    print(classifier_llama_8b.predict(
        dialog, relations['x'], relations['y'], verbose=True
        )["reasoning_completion"].choices[0].message.content)
    print("\n###\n")

Speaker 2 and Speaker 3 are mentioned in the dialog as being in a relationship, as evidenced by Speaker 3's use of the term "my husband" to refer to Speaker 2 (at the end of the dialog). This implies that Speaker 2 and Speaker 3 are married.

Therefore, based on the relationship options provided, I choose:

VERDICT: 'per:spouse'

###

After analyzing the dialog, I can infer the following relationships between the characters:

* Speaker 2 (Monica) is a close friend of Speaker 3 (Rachel), as they share a warm and intimate conversation, with Speaker 3 involving Speaker 2 in her personal life (treating Speaker 3 to a solo where she discusses how to communicate with Phoebe).
* Speaker 3 (Rachel) and Speaker 4 (Phoebe) are also close friends, as they are part of the same social circle and share inside jokes ("What about Mike? Alright, well, let's just gag him and handcuff him and force him down the aisle. I can just see it: 'Mike, do you take Phoebe...").
* Speaker 2 (Monica) seems to be awa

Now, let's create a self-consistency-based classifier.

In [None]:
from collections import Counter

def most_frequent(items):
    # Counter.most_common(1) returns [(item, count)] for the most frequent item
    occurrence_count = Counter(items)
    return occurrence_count.most_common(1)[0][0]

class RelationClassifierSelfConsistency():
    def __init__(self, client: OpenAI, model: str, n_trials: int = 5):
        self.client = client
        self.model = model
        self.n_trials = n_trials
        self.raw_classes = ["friends", "spouse", "children", "parents",
                            "siblings", "girl/boyfriend", "boss", "subordinate"]

    def predict(self, dialog, character_x, character_y, verbose=False):
        reasoning_completion = self.client.chat.completions.create(
            messages=[
                {
            "role": "user",
            "content": f"""You are an expert in Natural Language Understanding.
You are given a dialog and two characters participating or mentioned in this dialog.
You need to predict relationships between these characters choosing from the following list:
- friends
- spouse
- children
- parents
- siblings
- girl/boyfriend
- boss
- subordinate
Provide a clear reasoning justifying your choice. Then write your final answer after #VERDICT:
Now, take a deep breath and work out this problem step by step. If you do well, I'll tip you 200$.

DIALOG: {dialog}

FIRST CHARACTER: {character_x}

SECOND CHARACTER: {character_y}

REASONING:"""
                }
            ],
            model=self.model,
            n=self.n_trials  # That's the main difference: request several completions at once.
            )

        reasoning_completions = []
        verdicts = []
        for i in range(self.n_trials):
            reasoning = reasoning_completion.choices[i].message.content
            reasoning_completions.append(reasoning)

            # Extract whatever is after "#VERDICT:" (the colon and spaces are skipped)
            re_match = re.search(r"#VERDICT:?\s*(.*)", reasoning, re.DOTALL)
            if re_match:
                extracted_answer = re_match.group(1).strip()
            else:
                extracted_answer = "Failed to parse"

            # Parse the answer
            verdict = extracted_answer.lower().strip("'\".; ")
            if verdict == "boyfriend" or verdict == "girlfriend":
                verdict = "girl/boyfriend"
            if verdict in self.raw_classes:
                verdict = "per:" + verdict
            else:
                verdict = "per:failed"

            verdicts.append(verdict)

        final_verdict = most_frequent(verdicts)
        if verbose:
            return {
                "reasoning_completion": reasoning_completion,  # Raw API response (includes token usage)
                "reasoning_completions": reasoning_completions,
                "verdicts": verdicts,
                "verdict": final_verdict
            }
        else:
            return final_verdict

Now, let's see if self-consistency is able to improve the results of **Llama-3.1-8B**.

**This may take time!**

In [None]:
# Create the self-consistency classifier and try it on a single dialog first

classifier_sc = RelationClassifierSelfConsistency(
    client=client, model="meta-llama/Meta-Llama-3.1-8B-Instruct", n_trials=5
)

result = classifier_sc.predict(
    dialog_data_short[5][0], dialog_data_short[5][1]['x'], dialog_data_short[5][1]['y'],
    verbose=True)

Let's classify all the dialogs and check the accuracy.



In [None]:
from tqdm import tqdm

current_configuration = "Meta-Llama-3.1-8B-Instruct, self-consistency"
results = []
# If you want, do it for dialog_data_short[-20:] to save time
for dialog, relations in tqdm(dialog_data_short):
    results.append(classifier_sc.predict(dialog, relations['x'], relations['y'], verbose=True))

completions_log[current_configuration] = results

100%|██████████| 50/50 [18:04<00:00, 21.69s/it]


As before, we check both the failed extraction rate and the classification accuracy.

In [None]:
verdicts = [result["verdict"] for result in results]
print(sum([verdict == "per:failed" for verdict in verdicts]) / len(verdicts))

['per:friends', 'per:parents', 'per:parents', 'per:friends', 'per:girl/boyfriend', 'per:spouse', 'per:girl/boyfriend', 'per:boss', 'per:friends', 'per:friends', 'per:boss', 'per:friends', 'per:boss', 'per:spouse', 'per:girl/boyfriend', 'per:boss', 'per:friends', 'per:girl/boyfriend', 'per:spouse', 'per:friends', 'per:friends', 'per:girl/boyfriend', 'per:girl/boyfriend', 'per:friends', 'per:girl/boyfriend', 'per:siblings', 'per:boss', 'per:friends', 'per:girl/boyfriend', 'per:spouse', 'per:friends', 'per:friends', 'per:spouse', 'per:friends', 'per:siblings', 'per:girl/boyfriend', 'per:siblings', 'per:ex-spouse', 'per:parents', 'per:spouse', 'per:friends', 'per:girl/boyfriend', 'per:siblings', 'per:spouse', 'per:girl/boyfriend', 'per:parents', 'per:children', 'per:boss', 'per:boss', 'per:friends']


In [None]:
# `verdicts` was computed from `results` in the cell above
verdicts_log[current_configuration] = verdicts

In [None]:
accuracy_score(verdicts_true[10:], verdicts[10:])

0.675

That's close to what we had with Llama-3.1-405B (0.74 there vs 0.675 here; note that the cell above scores only the last 40 dialogs). Now, let's save our logs to avoid losing all the results:

In [None]:
import pickle

# Context managers make sure the files are closed after writing
with open("completions_log.pkl", "wb") as f:
    pickle.dump(completions_log, f)
with open("verdicts_log.pkl", "wb") as f:
    pickle.dump(verdicts_log, f)

## Computing the cost

To wrap up the dialog relationship classifier task, let's calculate the LLM API cost for each of the scenarios using:

* Prices from the [Nebius AI Studio model reference](https://studio.nebius.ai/)
* Token counts from the prompts and completions that we diligently logged.

In [None]:
# Model costs, per 1M tokens:
costs = {
    '405B': {
        'input': 1,
        'output': 3
    },
    '8B': {
        'input': 0.02,
        'output': 0.06
    }
}

for key, value in completions_log.items():
    print(f'=== With {key} ===')
    total_input_tokens = 0
    total_output_tokens = 0
    for result in value:
        for k, v in result.items():
            if 'completion' in k:
                # A result may hold a single completion or a list of them
                completions = v if isinstance(v, list) else [v]
                for completion in completions:
                    usage = getattr(completion, 'usage', None)
                    if usage is None:
                        # Skip entries without token usage (e.g. plain strings)
                        continue
                    total_input_tokens += usage.prompt_tokens
                    total_output_tokens += usage.completion_tokens
    if '405' in key:
        model_size = '405B'
    elif '8B' in key:
        model_size = '8B'
    else:
        print('And what is that?..')
        continue  # Unknown model size: there is no price to apply
    input_cost = total_input_tokens / 1000000 * costs[model_size]['input']
    output_cost = total_output_tokens / 1000000 * costs[model_size]['output']
    total_cost = input_cost + output_cost

    print(f'''
        Input cost: {input_cost}
        Output cost: {output_cost}
        Total cost: {total_cost}
              ''')

=== With Meta-Llama-3.1-405B-Instruct, chain ===

        Input cost: 0.046945
        Output cost: 0.049715999999999996
        Total cost: 0.096661
              
=== With Meta-Llama-3.1-8B-Instruct, chain ===

        Input cost: 0.00084292
        Output cost: 0.00088362
        Total cost: 0.00172654
              
=== With Meta-Llama-3.1-8B-Instruct, few-shot ===

        Input cost: 0.00229196
        Output cost: 0.0006172199999999999
        Total cost: 0.00290918
              
=== With Meta-Llama-3.1-8B-Instruct, self-consistency ===

        Input cost: 0.0116306
        Output cost: 0.00359088
        Total cost: 0.015221479999999999
              


As we see, even with self-consistency, Llama-3.1-8B is still far cheaper than Llama-3.1-405B. Moreover, we could greatly increase `n_trials` before we hit the price of the larger model.

**Key takeaway**: always be aware of the larger model vs smarter strategy trade-off and take costs into account.
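
The remark about `n_trials` can be turned into a rough estimate. The sketch below takes the 405B total and the 8B self-consistency total (5 trials) from the printout above and assumes that cost grows linearly with `n_trials`. That linearity is an assumption: prompt tokens may be billed differently when several completions are requested in one call.

```python
import math

# Totals copied from the cost printout above (50 dialogs each)
cost_405b_single = 0.096661
cost_8b_sc_5_trials = 0.015221479999999999

# Assumption: the 8B cost is proportional to the number of trials
cost_8b_per_trial = cost_8b_sc_5_trials / 5
max_trials = math.floor(cost_405b_single / cost_8b_per_trial)
print(max_trials)
```

Under this assumption, the smaller model could run roughly 30 trials per dialog before matching the cost of a single pass of the larger one.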

# Practice, part 2. Your turn now!

Now it's time to put the `MMLUEvaluator` we played with in the previous notebook to good use! Choose one of the math-related fields (you can check the options in [the paper](https://arxiv.org/pdf/2009.03300)) and select the first 50 examples.

**Your task.** Compare accuracy on this small dataset when using:

- **Llama-3.1-70B** with CoT suppression,
- **Llama-3.1-8B** with CoT suppression,
- **Llama-3.1-70B** with basic CoT,
- **Llama-3.1-8B** with basic CoT,
- **Llama-3.1-8B** with self-consistency.

Also, compare the cost of processing 50 examples with:

- **Llama-3.1-70B** without any additional tricks,
- **Llama-3.1-8B** with all additional tricks you can use.

Can you reach **Llama-3.1-70B**'s quality with **Llama-3.1-8B** while staying cheaper?

In [None]:
# <Your experiments here>

# Practice, part 3. PRMs and Beam Search

## Using a PRM

In this section, you'll play with a **PRM** (**Process Reward Model**).

At the moment, we're not aware of any good PRM API (if you know of one, please share it with us!), so we'll have to use an open-source model. Namely, we'll use this one: [RLHFlow/Llama3.1-8B-PRM-Deepseek-Data](https://huggingface.co/RLHFlow/Llama3.1-8B-PRM-Deepseek-Data), which is a fine-tune of Llama-3.1-8B. We haven't yet discussed using open-source models, but you'll have all the code, so we hope you won't have problems with it.

**To run this model, you'll need a GPU**. The model itself will take around 16 GB, and some more GPU memory will be used for inference. Either an L40S in Nebius cloud or an L4 in Colab should be enough for the task. Just don't forget to switch off (or, better, delete) the virtual machine after you finish; otherwise, you'll be charged for the time it stays idle.

We've created a wrapper class `ProcessRewardModel`, which loads the model for you and also offers an `evaluate_partial_solution(prompt, partial_solution)` method that encapsulates all the details of calling the model, so you don't have to think about how it's done under the hood.

However, if you're curious, we'll share a little bit about how this model works. It is a chat model trained to answer each user message with a `"+"` or a `"-"`. The probability of predicting `"+"` is used as the score of the partial solution ("how likely it is to lead to a correct solution, if continued").
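
To make the "probability of `"+"`" concrete, here is a minimal sketch of how two logits can be turned into a score with a softmax restricted to the two candidate tokens, which is what the wrapper class in this notebook does. The function name `plus_probability` and the example logit values are made up for illustration.

```python
import math

def plus_probability(plus_logit, minus_logit):
    """Softmax over the two candidate tokens, returning P("+")."""
    max_logit = max(plus_logit, minus_logit)  # subtract the max for numerical stability
    plus_exp = math.exp(plus_logit - max_logit)
    minus_exp = math.exp(minus_logit - max_logit)
    return plus_exp / (plus_exp + minus_exp)

# Equal logits mean the model is undecided
undecided = plus_probability(2.0, 2.0)
# A much larger "+" logit pushes the score towards 1
confident = plus_probability(5.0, 1.0)
print(undecided, confident)
```

Because only two tokens are compared, the score equals the sigmoid of the difference between the `"+"` and `"-"` logits.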

A multi-step solution, when scored by this PRM, is transformed into a dialog like this:

```
[
      {"role": "user", "content": "Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. To convert from rectangular coordinates $(x, y)$ to polar coordinates $(r, \\theta)$, we can use the formulas\n\\[r = \\sqrt{x^2 + y^2}\\]\n\\[\\theta = \\arctan \\frac{y}{x}\\]"},
      {"role": "assistant", "content": "+"},
      {"role": "user", "content": "In this case, the rectangular coordinates are $(0,3)$, so $x = 0$ and $y = 3$."},
      {"role": "assistant", "content": "+"},
      {"role": "user", "content": "First, we calculate $r$:\n\\[r = \\sqrt{0^2 + 3^2} = \\sqrt{9} = 3\\]"},
      {"role": "assistant", "content": "+"},
      {"role": "user", "content": "Next, we calculate $\\theta$:\n\\[\\theta = \\arctan \\frac{3}{0}\\]"},
      {"role": "assistant", "content": "+"},
      {"role": "user", "content": "Since the tangent function is not defined for $x = 0$, we need to use a special case. When $x = 0$, $\\theta = \\frac{\\pi}{2}$ if $y > 0$, and $\\theta = \\frac{3\\pi}{2}$ if $y < 0$."},
      {"role": "assistant", "content": "+"},
      {"role": "user", "content": "In this case, $y = 3 > 0$, so $\\theta = \\frac{\\pi}{2}$."},
      {"role": "assistant", "content": "+"},
      {"role": "user", "content": "So, the polar coordinates equivalent to $(0,3)$ are $\\boxed{(3,\\frac{\\pi}{2})}$."},
      {"role": "assistant", "content": "+"},
]
```

Source: [model's github](https://github.com/RLHFlow/RLHF-Reward-Modeling/tree/main/math-rm)

Basically, `evaluate_partial_solution` does the following:

* Splits the solution by `"\n\n"` (paragraph end), which seems to be a proxy for a change of "thoughts" in the PRM's training.
* Turns the solution steps into a dialog.
* Returns an estimate of the probability of generating a `"+"`.

The code below is mostly an adaptation of [this evaluation script](https://github.com/RLHFlow/RLHF-Reward-Modeling/blob/main/math-rm/prm_evaluate.py).


In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class ProcessRewardModel:
    def __init__(self, model: str = "RLHFlow/Llama3.1-8B-PRM-Deepseek-Data"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto"
        )

        # Set up tokenizer settings
        self.tokenizer.padding_side = "right"
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.config.pad_token_id = self.model.config.eos_token_id

        # Get token IDs for + and -
        self.plus_token_id = self.tokenizer.encode("+")[-1]
        self.minus_token_id = self.tokenizer.encode("-")[-1]
        self.candidate_tokens = [self.plus_token_id, self.minus_token_id]

    def evaluate_partial_solution(self, prompt: str, partial_solution: str) -> float:
        """Evaluate a partial solution using PRM."""
        # Split solution into steps
        steps = [step.strip() for step in partial_solution.split("\n\n") if step.strip()]

        # Convert to chat format starting with prompt
        conversation = []
        first_text = prompt + " " + steps[0]
        conversation.append({"role": "user", "content": first_text})
        conversation.append({"role": "assistant", "content": "+"})

        for step in steps[1:]:
            conversation.append({"role": "user", "content": step})
            conversation.append({"role": "assistant", "content": "+"})

        # Remove last assistant message for scoring
        conversation = conversation[:-1]

        # Get model prediction
        input_ids = self.tokenizer.apply_chat_template(
            conversation,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits[:, -3, self.candidate_tokens]
            scores = logits.softmax(dim=-1)
            plus_prob = scores[:, 0].item()  # Probability of + token

        return plus_prob

Let's load the PRM and use it to score one correct and one incorrect solution of the equation $x^2 - x - 2 = 0$.

In [5]:
# Initialize PRM once
prm = ProcessRewardModel(model="RLHFlow/Llama3.1-8B-PRM-Deepseek-Data")

In [6]:
prompt = "Solve the equation $x^2 - x - 2 = 0$."

# A correct solution
solution = """Factor the left part as $(x + 1)(x - 2)$.

Now, let's rewrite the initial equation $x^2 - x - 2 = 0$ as $(x + 1)(x - 2) = 0$.

Thus, either $x + 1 = 0$ or $x - 2 = 0$.

Thus, the roots are -1 and 2"""

# Get step-score pairs
steps = solution.split("\n\n")
for i in range(len(steps)):
    partial_solution = "\n\n".join(steps[:i+1])
    step_score = prm.evaluate_partial_solution(prompt, partial_solution)
    print(f"\n#PARTIAL SOLUTION:\n {partial_solution}")
    print(f"#SCORE: {step_score:.4f}")


#PARTIAL SOLUTION:
 Factor the left part as $(x + 1)(x - 2)$.
#SCORE: 0.9995

#PARTIAL SOLUTION:
 Factor the left part as $(x + 1)(x - 2)$.

Now, let's rewrite the initial equation $x^2 - x - 2 = 0$ as $(x + 1)(x - 2) = 0$.
#SCORE: 0.9478

#PARTIAL SOLUTION:
 Factor the left part as $(x + 1)(x - 2)$.

Now, let's rewrite the initial equation $x^2 - x - 2 = 0$ as $(x + 1)(x - 2) = 0$.

Thus, either $x + 1 = 0$ or $x - 2 = 0$.
#SCORE: 0.9209

#PARTIAL SOLUTION:
 Factor the left part as $(x + 1)(x - 2)$.

Now, let's rewrite the initial equation $x^2 - x - 2 = 0$ as $(x + 1)(x - 2) = 0$.

Thus, either $x + 1 = 0$ or $x - 2 = 0$.

Thus, the roots are -1 and 2
#SCORE: 1.0000


We'll comment briefly on the output above. The solution consists of 4 paragraphs and is thus split into 4 individual "thoughts" (solution steps). In the loop, we evaluate:

* Step 1
* Steps 1+2
* Steps 1+2+3
* Steps 1+2+3+4 (the full solution)

As you can see, all the scores are quite high, and the full solution gets the maximum score (and it is indeed correct).

In [7]:
prompt = "Solve the equation $x^2 - x - 2 = 0$."

# An incorrect solution
solution = """Rewrite the equation as $x^2 = x + 2$.

Divide both parts by x: $x = 1 + 2x$.

Rewrite it as $x = 1$. Thus, x = 1."""

# Get step-score pairs
steps = solution.split("\n\n")
for i in range(len(steps)):
    partial_solution = "\n\n".join(steps[:i+1])
    step_score = prm.evaluate_partial_solution(prompt, partial_solution)
    print(f"\n#PARTIAL SOLUTION:\n {partial_solution}")
    print(f"#SCORE: {step_score:.4f}")


#PARTIAL SOLUTION:
 Rewrite the equation as $x^2 = x + 2$.
#SCORE: 0.5889

#PARTIAL SOLUTION:
 Rewrite the equation as $x^2 = x + 2$.

Divide both parts by x: $x = 1 + 2x$.
#SCORE: 0.1550

#PARTIAL SOLUTION:
 Rewrite the equation as $x^2 = x + 2$.

Divide both parts by x: $x = 1 + 2x$.

Rewrite it as $x = 1$. Thus, x = 1.
#SCORE: 0.3999


For an incorrect solution, the scores are much lower. Moreover:

* The first step, which is somewhat strange in solving a simple quadratic equation, is scored as dubious.
* The next step, which introduces a serious mistake, gets a very low grade.

So, the PRM seems to be somewhat aligned with our math intuition.
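
One simple way to use such scores is to point at the weakest step of a solution. The sketch below is only an illustration: the helper name `weakest_step` is an assumption, and the scores are copied from the printout for the incorrect solution above.

```python
def weakest_step(prefix_scores):
    """Return the 1-based index and the score of the lowest-scoring prefix."""
    index = min(range(len(prefix_scores)), key=lambda i: prefix_scores[i])
    return index + 1, prefix_scores[index]

# Prefix scores of the incorrect solution, copied from the output above
incorrect_solution_scores = [0.5889, 0.1550, 0.3999]
step, score = weakest_step(incorrect_solution_scores)
print(step, score)
```

Here the lowest score lands on step 2, the division by $x$ that introduces the mistake.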

## Beam Search

In this part, we'll share our implementation of **Beam Search**: the class `MathBeamSearch`. It's quite big, so we'll comment on several important things about it:

* It uses a [min heap](https://en.wikipedia.org/wiki/Heap_(data_structure)) to store the newly generated completions and their scores, because this data structure allows fast insertion and fast removal of the smallest element.

  We chose the `heapq` implementation; because `heapq` is a min heap, we actually store pairs `(-score, partial_solution)` so that it behaves like a max heap and the highest-scoring candidates come out first.

* Each time, the LLM is prompted to generate the next logical step of the solution and to keep it on one line. In most cases, this lets us use `"\n\n"` as the separator between individual "thoughts", as expected by the PRM.

* We prompt the LLM to output `#ANSWER: <answer>` when it obtains the final answer. This lets us finalize successful solutions instead of continuing them aimlessly until they hit `max_steps` "thoughts".
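
The negated-score trick from the first bullet can be shown in isolation. This is a minimal sketch with made-up scores and steps; the helper `top_k_candidates` is hypothetical and simpler than what `MathBeamSearch` does.

```python
import heapq

def top_k_candidates(scored_candidates, k):
    """Keep the k highest-scoring (score, partial_solution) pairs."""
    # heapq is a min heap, so negate the scores to pop the best ones first
    heap = [(-score, solution) for score, solution in scored_candidates]
    heapq.heapify(heap)
    best = [heapq.heappop(heap) for _ in range(min(k, len(heap)))]
    return [(-neg_score, solution) for neg_score, solution in best]

toy_candidates = [
    (0.59, "Rewrite as $x^2 = x + 2$."),
    (0.99, "Factor as $(x + 1)(x - 2)$."),
    (0.16, "Divide both sides by $x$."),
]
beam = top_k_candidates(toy_candidates, k=2)
print(beam)
```

The two candidates with the highest scores survive, in descending order of score, which is the pruning step a beam of width 2 needs.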

In [22]:
import heapq
from typing import List, Dict, Tuple, Optional
from openai import OpenAI

class LLMClient:
    """Wrapper for OpenAI-compatible API clients with consistent interface."""
    def __init__(
        self,
        client: OpenAI,
        model: str,
        default_temperature: float = 0.0,
        default_max_tokens: int = 1024,
        system_prompt: Optional[str] = None
    ):
        self.client = client
        self.model = model
        self.default_temperature = default_temperature
        self.default_max_tokens = default_max_tokens
        self.system_prompt = system_prompt

    def generate(
        self,
        prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None
    ) -> str:
        """Generate completion with consistent interface across different LLM providers."""
        messages = []

        # Use provided system prompt or fall back to default
        current_system_prompt = system_prompt or self.system_prompt
        if current_system_prompt:
            messages.append({"role": "system", "content": current_system_prompt})

        messages.append({"role": "user", "content": prompt})

        completion = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=temperature if temperature is not None else self.default_temperature,
            max_tokens=max_tokens if max_tokens is not None else self.default_max_tokens
        )

        return completion.choices[0].message.content

class MathBeamSearch:
    """Beam search implementation for math problem solving."""
    def __init__(
        self,
        prm: ProcessRewardModel,
        llm_client: LLMClient,
        beam_width: int = 2,
        max_steps: int = 10
    ):
        self.prm = prm
        self.llm_client = llm_client
        self.beam_width = beam_width
        self.max_steps = max_steps

    def generate_next_steps(self, prompt: str, partial_solution: str, num_continuations: int) -> List[str]:
        """Generate next possible steps using LLM."""
        if partial_solution:
            message = f"""You are an expert math problem solver. Given a math problem and a partial solution, generate the next logical step.
Keep the step concise and focused on one specific calculation or logical deduction.

Problem:
{prompt}

Current partial solution:
{partial_solution}

Generate the next step in the solution. Only output:
- the next step
- #ANSWER: followed by your answer, if you can determine the final answer. If you output #ANSWER:, you need to output the actual answer after it.
Don't output anything else!
Keep the whole new step on a single line.
If you need to write formulas, use latex $ markup to format them."""
        else:
            message = f"""You are an expert math problem solver. Given a math problem, generate the first step of the solution.
Keep the step concise and focused on one specific calculation or logical deduction.

Problem:
{prompt}

Generate the first step in the solution. Only output:
- the first step,
- #ANSWER: followed by your answer, if you can determine the final answer. If you output #ANSWER:, you need to output the actual answer after it.
Don't output anything else!
Keep the whole first step on a single line.
If you need to write formulas, use latex $ markup to format them."""

        responses = []
        for _ in range(num_continuations):
            response = self.llm_client.generate(
                message
            )
            responses.append(response.strip())
        return responses

    def beam_search(self, prompt: str,
            verbose: bool = False) -> List[Tuple[float, str]]:
        """Perform beam search to find the best solution."""
        # Initialize beam with empty solutions
        current_beam = [(0.0, "", False)]  # (score, solution, is_finalized)

        # Get initial steps
        initial_continuations = self.generate_next_steps(prompt, None, self.beam_width)

        # Initialize beam with scored initial steps
        candidates = []
        for continuation in initial_continuations:
            score = self.prm.evaluate_partial_solution(prompt, continuation)
            is_finalized = "#ANSWER:" in continuation
            candidates.append((-score, continuation, is_finalized))  # Negative for max-heap
            if verbose:
                print(f"\nInitial step (score {score:.4f}):")
                print(f"{continuation}\n")

        # Select top-k candidates for initial beam
        heapq.heapify(candidates)
        current_beam = [(-score, solution, is_finalized)
                       for score, solution, is_finalized in heapq.nsmallest(self.beam_width, candidates)]

        # Beam search iterations
        step = 0
        while step < self.max_steps:
            # Check if all solutions are finalized
            if all(is_finalized for _, _, is_finalized in current_beam):
                break

            if verbose:
                print(f"\n=== Step {step + 1} ===")
            candidates = []

            # Keep finalized solutions and generate continuations for unfinished ones
            for score, partial_solution, is_finalized in current_beam:
                if is_finalized:
                    # Keep finalized solutions in candidates without modification
                    candidates.append((-score, partial_solution, True))
                else:
                    # Generate continuations only for unfinished solutions
                    continuations = self.generate_next_steps(
                        prompt,
                        partial_solution,
                        self.beam_width
                    )

                    # Evaluate each continuation
                    for continuation in continuations:
                        new_solution = partial_solution + "\n\n" + continuation if partial_solution else continuation
                        new_score = self.prm.evaluate_partial_solution(prompt, new_solution)
                        is_finished = "#ANSWER:" in continuation
                        candidates.append((-new_score, new_solution, is_finished))
                        if verbose:
                            print(f"\nCandidate (score {new_score:.4f}):")
                            print(f"{continuation}\n")

            # Select top-k candidates for next beam
            heapq.heapify(candidates)
            current_beam = [(-score, solution, is_finalized)
                          for score, solution, is_finalized in heapq.nsmallest(self.beam_width, candidates)]

            if verbose:
                print("\nSelected for next beam:")
                for score, solution, is_finalized in current_beam:
                    status = "FINALIZED" if is_finalized else "IN PROGRESS"
                    print(f"\nScore: {score:.4f} [{status}]")
                    print(f"{solution}\n")

            step += 1

        # Return the final beam (finalized solutions are included only if they stayed in the top-k)
        return [(score, solution) for score, solution, _ in current_beam]
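
The selection step in `beam_search` stores negated scores so that `heapq.nsmallest` returns the highest-scoring candidates. Here is a minimal standalone sketch of that trick. The candidate strings and scores are made up for illustration, and `heapq.nlargest` is shown only as an equivalent alternative when all scores are distinct, not as what the class above does.

```python
import heapq

# Hypothetical (score, partial_solution, is_finalized) candidates; values are illustrative only.
scored = [
    (0.47, "step A", False),
    (0.84, "step B", False),
    (0.58, "step C", True),
]
beam_width = 2

# Negate the scores so that the smallest tuples correspond to the highest scores.
candidates = [(-score, solution, done) for score, solution, done in scored]

# Keep the beam_width best candidates and flip the sign back.
beam = [
    (-neg_score, solution, done)
    for neg_score, solution, done in heapq.nsmallest(beam_width, candidates)
]

# With distinct scores, nlargest on the original tuples selects the same beam.
same_beam = heapq.nlargest(beam_width, scored)
```

Note that a finalized candidate ("step C" here) competes with unfinished ones on score alone, so it stays in the beam only while it remains among the top `beam_width` entries.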

Let's try it!

In [23]:
client = OpenAI(
    base_url="https://api.studio.nebius.ai/v1/",
    api_key=os.environ.get("NEBIUS_API_KEY"),
)


# Create beam search instance
beam_search = MathBeamSearch(
    prm=prm,
    llm_client=LLMClient(
        client=client,
        model="meta-llama/Meta-Llama-3.1-70B-Instruct",
        default_temperature=1,
        default_max_tokens=8192
    ),
    beam_width=2,
    max_steps=20
)

We'll turn on the `verbose` parameter to see all the intermediate results. (It's `False` by default.)

In [24]:
prompt = "Inside a circle, two parallel chords are 6 units apart. One chord has length 14 and the other has length 10. Find the radius of the circle."

# Run beam search
results = beam_search.beam_search(prompt, verbose=True)

# Print final results
print("\n=== Final Results ===")
for score, solution in results:
    print(f"\nScore: {score:.4f}")
    print(f"{solution}\n")


Initial step (score 0.2830):
Let $r$ be the radius of the circle, and draw a perpendicular line from the center of the circle to each of the two chords, which divides each chord into two equal parts: $7$ and $5$.
$ANSWER:


Initial step (score 0.6875):
Draw a perpendicular line from the center of the circle to each chord, and denote the radius of the circle as $r$. Let $d$ be the distance from the center of the circle to the chord with length 10, so the distance from the center of the circle to the chord with length 14 is $d+6$.


=== Step 1 ===

Candidate (score 0.4688):
Using the Pythagorean theorem, we can set up two equations: $r^2 = d^2 + 5^2$ and $r^2 = (d+6)^2 + 7^2$.


Candidate (score 0.5796):
Apply the Pythagorean theorem to the two right triangles formed by the radii and the chords, resulting in the equations: $r^2 = d^2 + 5^2$ and $r^2 = (d+6)^2 + 7^2$.


Candidate (score 0.8408):
The perpendicular line from the center of the circle and the chords form two right triangles:

# Practice, Part 4: Confidence as a Synthetic PRM

Training **Process Reward Models (PRMs)** is challenging, and only a few such models are available on Hugging Face, none of which are ideal. Therefore, having a **model-free** method for evaluating solutions would be beneficial. One simple surrogate for process reward to consider is **confidence**.

In the [LLM Inference Parameters notebook](https://colab.research.google.com/github/Nebius-Academy/LLM-Engineering-Essentials/blob/main/topic1/1.6_llm_inference_parameters.ipynb), we discussed that LLMs exhibit varying levels of confidence in their generated outputs:

<center>
<img src="https://drive.google.com/uc?export=view&id=12k5EFzMZAcHntuJZBZwbm6NKqJZ1OF3l" width=600 />
</center>

The left image illustrates a case where the LLM is almost certain to generate "LLM," while the right image shows a scenario where the model is less confident in its output. While uncertainty can be valuable in creative writing, it may indicate confusion, or even hallucinations, in mathematical problem-solving. Thus, for math and logical reasoning tasks, it is reasonable to assume that **solutions generated with higher confidence are more likely to be correct**.

### Simple approach: using top predicted probability

With this in mind, we suggest modifying the **Beam Search** algorithm to evaluate partial solutions based on their **mean confidence**. Confidence can be estimated using the **mean top predicted log probability**, calculated as:

$$\frac{1}{\mathrm{n\_steps}}\sum_{i=1}^{\mathrm{n\_steps}}\log\left(\mbox{Top token probability predicted at step $i$}\right)$$

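Here each step $i$ is one generated token. The log probabilities can be obtained by calling `client.chat.completions.create` with `logprobs=True` and reading the per-token entries in `completion.choices[0].logprobs.content`. Each entry's `logprob` field is the log probability of the token that was actually sampled; to also get the most likely alternatives at each position, pass `top_logprobs` as well.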

Although this approach is fairly simplistic, it may still be effective. A higher top probability implies lower probabilities for alternative tokens, indicating greater confidence in the top prediction.

### A fancier approach: Negative Mean Entropy

A more robust method involves using **negative mean entropy**. [Entropy](https://en.wikipedia.org/wiki/Entropy_(information_theory)) quantifies the uncertainty of a probability distribution. For next-token generation, it is calculated as:

$$-\sum_{w\in\mbox{Vocab}}\widehat{p}_{w}\log{ \widehat{p}_{w} },$$

where $\widehat{p}_{w}$ represents the predicted probability of token $w$. Entropy behaves as follows:

- $0$ when one token has a probability of 1 while all others have 0 (**absolute certainty**).
- Maximum when all tokens have equal probabilities (**absolute uncertainty**).

Thus, solutions with **lower entropy** are more confidently generated.
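
To make the two extreme cases concrete, here is a small sketch that computes the entropy of a few toy distributions over a four-token vocabulary. The helper name `entropy` and the probabilities are illustrative; the natural logarithm is used, so values are in nats.

```python
import math


def entropy(probs):
    """Shannon entropy in nats; terms with zero probability contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)


certain = entropy([1.0, 0.0, 0.0, 0.0])      # one token has probability 1
uniform = entropy([0.25, 0.25, 0.25, 0.25])  # all tokens equally likely: log(4)
peaked = entropy([0.9, 0.05, 0.03, 0.02])    # mostly confident, between the two extremes
```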

Unfortunately, OpenAI's API returns only the most likely tokens at each position (at most 20 via the `top_logprobs` parameter of Chat Completions), which prevents computing the entropy over the full vocabulary directly. However, entropy can still be estimated from these top probabilities. You can try this as well, but we recommend starting with only the top probability.
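
One possible way to estimate entropy from a truncated list of top token log probabilities is to lump all remaining probability mass into a single "other" bucket. This is only a sketch: the function names are our own, the input is assumed to be a list of log probabilities per generated token (for example, collected from `top_logprobs`), and because merging tokens into one bucket can only reduce entropy, the result is a lower bound on the true value rather than the exact entropy.

```python
import math


def topk_entropy_estimate(top_logprobs):
    """Entropy estimate (nats) from the top-k log probabilities at one position."""
    probs = [math.exp(lp) for lp in top_logprobs]
    # Probability mass of every token outside the top-k, treated as one bucket.
    rest = max(0.0, 1.0 - sum(probs))
    return -sum(p * math.log(p) for p in probs + [rest] if p > 0)


def negative_mean_entropy(per_token_top_logprobs):
    """Confidence score: minus the average estimated entropy over generated tokens."""
    entropies = [topk_entropy_estimate(lps) for lps in per_token_top_logprobs]
    return -sum(entropies) / len(entropies)


# Illustrative top-3 log probabilities for two hypothetical two-token generations.
confident = [[math.log(p) for p in (0.97, 0.02, 0.01)]] * 2
unsure = [[math.log(p) for p in (0.4, 0.3, 0.2)]] * 2

confident_score = negative_mean_entropy(confident)
unsure_score = negative_mean_entropy(unsure)
```

As with the mean top log probability, a higher score (closer to zero) means a more confidently generated solution, so the score can be used in place of the PRM score when ranking beam candidates.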

In [None]:
# <YOUR CODE HERE>