<a href="https://colab.research.google.com/github/Nebius-Academy/LLM-Engineering-Essentials/blob/main/topic2/r.3_establishing_non_linear_reasoning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LLM Engineering Essentials by Nebius Academy

Course github: [link](https://github.com/Nebius-Academy/LLM-Engineering-Essentials/tree/main)

The course is in development now, with more materials coming soon. [Subscribe to stay updated](https://academy.nebius.com/llm-engineering-essentials/update/)

# R.3. Establishing non-linear reasoning capabilities

In the previous notebook, we discussed ways of orchestrating non-linear reasoning process. However, some of today's LLMs - exaples include **o1**, **o3-mini**, **Gemini** (a "thinking" species of this family), **DeepSeek R1**, **Grok 3**, and **Claude 3.7 Sonnet** - are capable of generating non-linear reasoning without additional engineering on the user's side.

In this notebook, we'll discuss several approaches to establishing native non-linear reasoning capabilities.

But before we proceed, let's briefly revisit what **non-linear reasoning** is. Unlike immediately arriving at a correct solution, this approach is exploratory - experimenting, iterating, self-correcting, and occasionally starting from scratch. In essence, it's like a whole **Tree of Thoughts** condensed into a single text. It is also often referred to as **long reasoning** due to its extended reasoning trace.

# Non-linear reasoning and inference-time compute

As we discussed earlier, one reason Chain of Thought (CoT) brings value is that it allows an LLM to perform more computations under the hood - and increasing these computations even further may bring additional profit. In the previous notebook, we expanded inference-time compute through orchestration. **Non-linear reasoning** takes this a step further - it's essentially a **native Tree of Thoughts**.  

It's no surprise that as the length of non-linear reasoning (i.e., the size of this native tree) increases, quality improves, but at an increading cost, of course.

Here are figures from **OpenAI** about **o1**:  

<center>  
<img src="https://drive.google.com/uc?export=view&id=1x8gxVEENI8tD2yV2KD1sV8jSuWNXTLDh" width=600 />  

[Source](https://openai.com/index/learning-to-reason-with-llms/)  
</center>  

And here are figures from **Anthropic** about **Claude 3.7 Sonnet**:  

<center>  
<img src="https://drive.google.com/uc?export=view&id=1cNKRvHupStpMh4EVG4XgwZ9l53N-1E5x" width=600 />  

[Source](https://www.anthropic.com/news/visible-extended-thinking)  
</center>  

It's also not surprising that the latest **GPT-4.5**, that doesn't possess long reasoning capabilities, performs significantly worse in math than **o3-mini-high**, an earlier model capable of non-linear reasoning.  

<img src="https://drive.google.com/uc?export=view&id=1N7hHOFr8SHRbUFWBz5ZQDHxJqnEPVpXW" width=600 />

[Source](https://openai.com/index/introducing-gpt-4-5/)

</center>

A great feature of Claude 3.7 Sonnet is conrollable reasoning length. Let's illustrate this on a math problem.

In [None]:
import os

with open("anthropic_api_key", "r") as file:
    nebius_api_key = file.read().strip()

os.environ["ANTHROPIC_API_KEY"] = nebius_api_key

In [None]:
!pip install -q anthropic

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/242.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m235.5/242.4 kB[0m [31m9.8 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.4/242.4 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25h

Anthropic's `client.messages.create` has additional argument `thinking` for **Claude 3.7 Sonnet**. If it's disabled, you'll get an ordinary CoT answer.

In [None]:
import anthropic

# Problem 22 from the AIME dataset: https://huggingface.co/datasets/AI-MO/aimo-validation-aime
prompt = """Denali and Nate work for a dog walking business and are paid for each dog they walk. Denali is responsible for $16$ dogs and Nate is responsible for $12$ dogs. Under the company's new policy, they will be assigned or unassigned new dogs in groups of $x$ dogs. The ratio of Denali's pay to Nate's pay would be the same if Denali started walking $4x$ more dogs and Nate stayed at $12$ dogs or if $x$ of Nate's dogs were reassigned to Denali. Find $x$ if $x\\neq0$."""


client = anthropic.Anthropic(
    api_key=os.environ.get("ANTHROPIC_API_KEY"),
)
reply = client.messages.create(
    model="claude-3-7-sonnet-20250219",
    max_tokens=1024,
    thinking={
        "type": "disabled",
    },
    messages=[
        {"role": "user", "content": prompt}
    ]
)
print(reply.content[0].text)

# Finding $x$ in the Dog Walking Problem

I need to determine the value of $x$ where $x$ represents the number of dogs in a group that can be added or removed from Denali's and Nate's responsibilities.

## Setting up the problem

Given information:
- Denali has 16 dogs
- Nate has 12 dogs
- The ratio of their pay would be the same under two scenarios:
  1. Denali gets 4x more dogs (Nate's stays at 12)
  2. x dogs are transferred from Nate to Denali

## Creating equations

Let me denote Denali's pay rate per dog as $d$ and Nate's pay rate as $n$.

### Initial pay ratio
Initial ratio of Denali's pay to Nate's pay: $\frac{16d}{12n}$

### First scenario
When Denali gets 4x more dogs, the ratio becomes:
$\frac{(16+4x)d}{12n}$

### Second scenario
When x dogs are transferred from Nate to Denali, the ratio becomes:
$\frac{(16+x)d}{(12-x)n}$

### Setting the scenarios equal
Since both scenarios must yield the same ratio:
$\frac{(16+4x)d}{12n} = \frac{(16+x)d}{(12-x)n}$

## Solving the equation


The correct answer is $5$, and it is obtained here. Let's also calculate the number of tokens in the answer. Anthropic's interface for this is a bit strange: you need to supply whatever string you want to know the token length of as a user's message, and the token count will manifest as the number of input tokens:

In [None]:
token_count = client.messages.count_tokens(
    model="claude-3-7-sonnet-20250219",
    messages=[
        {"role": "user", "content": reply.content[0].text}
    ]
)
print(token_count.model_dump_json())

{"input_tokens":568}


To turn long reasoning on, we set `type` of `thinking` to `enabled` and set some budget. The larger this budget is, the longer (at least theoretically) the resulting reasoning will be.

In [None]:
answers = []

budgets = [1024, 2048, 4096, 8192]

for budget in budgets:
    reply = client.messages.create(
        model="claude-3-7-sonnet-20250219",
        max_tokens=512+budget,
        thinking={
            "type": "enabled",
            "budget_tokens": budget
        },
        messages=[
            {"role": "user", "content": prompt}
        ]
    )
    answers.append(reply)

    # Checking actual answer length:
    thinking_token_count = client.messages.count_tokens(
        model="claude-3-7-sonnet-20250219",
        messages=[
            {"role": "user", "content": reply.content[0].thinking}
        ]
    )

    print(f"""For budget={budget}:
          Thinking tokens: {thinking_token_count.model_dump_json()}""")

    if len(reply.content) > 1:
        answer_token_count = client.messages.count_tokens(
            model="claude-3-7-sonnet-20250219",
            messages=[
                {"role": "user", "content": reply.content[1].text}
            ]
        )
        print(f"""          Answer tokens: {answer_token_count.model_dump_json()}""")
    else:
        print("    Reasoning was cut short, no answer generated")

For budget=1024:
          Thinking tokens: {"input_tokens":1035}
          Answer tokens: {"input_tokens":440}
For budget=2048:
          Thinking tokens: {"input_tokens":1120}
          Answer tokens: {"input_tokens":600}
For budget=4096:
          Thinking tokens: {"input_tokens":2159}
          Answer tokens: {"input_tokens":547}
For budget=8192:
          Thinking tokens: {"input_tokens":2102}
          Answer tokens: {"input_tokens":614}


As you see, the number of reasoning tokens grows indeed (the particular numbers may vary), while the answer length doesn't change much. Now, if you check the reasonings, you'll see that with high budgets Claude does several rounds of double-checking:

In [None]:
for budget, answer in zip(budgets, answers):
    print(f"===BUDGET {budget}===")
    print("====REASONING====")
    print(answer.content[0].thinking)
    if len(answer.content) > 1:
        print("====ANSWER====")
        print(answer.content[1].text)
    else:
        print("====NO ANSWER GENERATED====")
    print("\n\n")

===BUDGET 1024===
====REASONING====
Let's think this through step by step.

Initially, Denali has 16 dogs and Nate has 12 dogs. Let's say the pay is proportional to the number of dogs, with proportionality constant $P$.

Denali's initial pay is $16P$.
Nate's initial pay is $12P$.
The ratio of Denali's pay to Nate's pay is $\frac{16P}{12P} = \frac{16}{12} = \frac{4}{3}$.

Scenario 1: Denali gets 4x more dogs, and Nate stays at 12 dogs.
Denali would have $16 + 4x$ dogs, with pay $(16 + 4x)P$.
Nate would still have 12 dogs, with pay $12P$.
The new ratio would be $\frac{(16 + 4x)P}{12P} = \frac{16 + 4x}{12}$.

Scenario 2: x of Nate's dogs are reassigned to Denali.
Denali would have $16 + x$ dogs, with pay $(16 + x)P$.
Nate would have $12 - x$ dogs, with pay $(12 - x)P$.
The new ratio would be $\frac{(16 + x)P}{(12 - x)P} = \frac{16 + x}{12 - x}$.

Since these two ratios are equal, we have:
$\frac{16 + 4x}{12} = \frac{16 + x}{12 - x}$

Let's solve for $x$.
$(16 + 4x)(12 - x) = (16 + x)(12)$

We have little change to peek into the training process of Claude, but luckily we can get some insights from the [s1: Simple test-time scaling](https://arxiv.org/pdf/2501.19393) paper. We'll yet return to it when discussing training data quality and quantity, and right now we'll discuss the **budget forcing** tactics suggested in this paper.

Budget forcing is a surprisingly simple idea. The authors suggest making an LLM to reason for a longer time by just introducing `Wait` every time it's going to produce an answer, until some target length is hit. As we've already seen, `Wait` prompts the LLM to double-check everything or maybe explore a new approach.

And indeed, this approach to increasing inference-time budget proves to be more efficient than self-consistency:

<center>
<img src="https://drive.google.com/uc?export=view&id=1lD-NBNW242onaE4_v-qYq3nMGoDVxCNS" width=600 />
</center>

# **The Starting Point for Training a Model in Non-Linear Reasoning**  

In all cases where training details are publicly known, the process begins with a **pre-trained model**. For example, **DeepSeek R1** was trained starting from **DeepSeek V3**.  

In this notebook, we'll explore surprisingly simple approaches to establishing non-linear reasoning capabilities. Underlying these methods, I believe, is a key principle highlighted in the [Demystifying Long Chain-of-Thought Reasoning in LLMs](https://arxiv.org/pdf/2502.03373) paper:  

**Core reasoning abilities, such as error correction, are inherently present in base (pre-trained) models—even if activating them can be tricky.**  

Consider this: a pre-trained model has encountered vast amounts of text, including examples of non-linear reasoning. However, since such examples are likely scarce, the model typically does not generate them on its own. To adjust the likelihood of this "way of thinking" being used, some ingenuity is required. And in the following sections we'll discuss in more details several of the approaches.

# **A Classical Approach: SFT**  

The most straightforward way to fine-tune an LLM for any task, including non-linear reasoning, is **supervised fine-tuning (SFT)** - that is, fine-tuning on a dataset of **(prompt, completion)** pairs, where completions exhibit the desired properties.  

If we choose SFT, we must first gather a dataset of non-linear reasoning examples, which is far from simple:  

- **Hiring human annotators** is expensive, and even for a math researcher, crafting high-quality non-linear reasoning examples is difficult. A good example should explore multiple thought paths (including ultimately unhelpful ones), incorporate mistakes, and demonstrate self-correction - something which rarely is on top of your mind.  
- **Math textbooks rarely contain true non-linear reasoning.** Theorem proofs and problem solutions are typically presented in their final form, with little insight into the creative problem-solving process.  

However, the comparison of **non-linear reasoning** to a **Tree of Thoughts** suggests a promising approach for **automatically generating** such examples: we can simply take a Tree of Thoughts and rewrite it as a single sequence!  

A similar (tree search) -> (non-linear reasoning) approach was explored, for example, in the [*Towards System 2 Reasoning in LLMs: Learning How to Think With Meta Chain-of-Thought*](https://arxiv.org/pdf/2501.04682) paper. However, instead of a simple Breadth/Depth-First Search, they used more sophisticated tree search strategy based on **Monte Carlo Tree Search (MCTS)**. Let's briefly discuss what that is.  

## **Monte Carlo Tree Search (MCTS)**  

MCTS is a **tree search algorithm** that appeared long pefore LLMs and has been widely used in game AI, including the famous [AlphaGo](https://en.wikipedia.org/wiki/AlphaGo) system.  

If we apply MCTS for reasoning orchestration, each **node** in the search tree represents a **partial solution**, and each node stores:  

- **$n_i$**: The number of times this node has been visited.  
- **$v_i$**: The cumulative **reward** from all visits to this node.  
- **Possible actions** (i.e., possible next reasoning steps). At each particular moment, some actions are **tried**, while others remain **untried**.  

The tree grows step by step through four iterative phases: **Selection, Expansion, Simulation, and Backpropagation**.  

### **1. Selection**  

Here we select a **partial solution** (=node) to extend from. Starting from the root (an empty solution), we follow these rules:  

- If the **current node has untried actions**, we select it.  
- Otherwise, we choose the **child node** that maximizes the **Upper Confidence Bound (UCB) score**:  

  $$UCB(i) = \frac{v_i}{n_i} + c\sqrt{\frac{\log{n_p}}{n_i}},$$  

  where **$p$** is the index of the parent node and **$i$** is the index of its child.  
  - The **first term** encourages **exploitation** (choosing nodes with high rewards).  
  - The **second term** encourages **exploration** (trying less-visited nodes).  

Additionally, it's good to exclude nodes whose entire subtree has already been fully explored and reached terminal states (i.e., final solutions).  

### **2. Expansion**  

For the selected node, we **explore one untried action** - adding a new reasoning step to the partial solution.  

- If this **new partial solution** is **not terminal** (i.e., does not yet contain the final answer), we generate possible next steps.  
- The **number of generated next steps** is controlled by a hyperparameter (`branching_factor`).  
- The `max_height` constraint prevents the tree from expanding too far from the root.  

### **3. Simulation**  
We score the newly generated **partial solution**.  

- If no **Process Reward Model** is available, we estimate the quality using **rollouts** - generating multiple completions and computing the ratio of correct answers. This approach would unfeasible on inference, where answers would be unknown, but it works well for creating a non-linear reasoning dataset from an available `(problem, answer dataset)`.

### **4. Backpropagation**  
The computed score is propagated up the tree:  

- For the **new partial solution** and all its **ancestor nodes**, we update:  
  - $v_j$, adding the computed score.  
  - $n_j$, incrementing the visit count by $1$.  

These four stages **repeat for `n_iterations`**, after which we select the **highest-scoring terminal solution** - or, if none exists, we grieve and select the top-scoring stump.

---

Here is a simple visualization:

<center>
<img src="https://drive.google.com/uc?export=view&id=1m3UxM4s_x-kAzprHUh30TQZL7rkORv3O" width=600 />
</center>

Here, each node shows

```
v: <mean_value>
n: <n_visits>
```

Terminal nodes are depicted as small squares, while non-terminal ones are drawn as circles.

---

See the practical part for the code of this visualization and also for an implementation of MCTS for LLM orchestration.

An interesting thing about MCTS is that, as noticed in [Towards System 2 Reasoning in LLMs](https://arxiv.org/pdf/2501.04682), actual reasoning of **o1** and similar models is much alike MCTS search trajectories.

**Note**. We could have included MCTS in the previous notebook among other resoning orchestration techniques. And indeed, sometimes it may be used this way. But still, its complexity makes MCTS as unlikely tool in most situations.

## Data quality and quantity

It may seem that training an LLM for non-linear reasoning would take tremendous amount of data, but it turns out that data quality may compensate the lack of its quantity. Let's briefly discuss two examples.

#### [S1: Simple test-time scaling](https://arxiv.org/pdf/2501.19393)

We've already discuss **budget forcing** suggested in this paper. Now, it's time to mention that they fine tuned their base model, **Qwen2.5-
32B-Instruct** on only **1000** high-quality reasoning examples.

How they collected this data:

* They got 59,029 questions from 16 diverse sources (including many Olympiad problems).
*	For each question, they generated a reasoning trace and solution using the experimental Google Gemini Flash Thinking API. This approach to data collection is known as **distillation**, because larger/cooler model's capabilities are being distilled into a smaller/more modest one.
*	Then, they filter the problems, first by *quality* (no poor formatting etc), then by *difficulty*. A problem is deemed difficult if neither **Qwen2.5-7B-Instruct** nor **Qwen2.5-32B-Instruct** is able to solve it, and also the reasoning length is large enough.
*	Finally, 1000 problems were sampled in a stratified way with respect to a number of topics.

In terms of quality + training cost, this model is very cool:

<center>
<img src="https://drive.google.com/uc?export=view&id=1edwrnzsCzSfskwJUgecIcsnO7QOqBatO" width=400 />
</center>

#### [LIMO: Less is More for Reasoning](https://arxiv.org/pdf/2502.03387)

The authors of this paper take **NuminaMath** as a base model and fine tune it on only **817** high-quality curated training samples to obtain an LLM that achieves quite impressive math performance with exceptional out-of-distribution generalization. Just look at the benchmark scores below. (And it really takes courage to compare onself with **o1**.)

<center>
<img src="https://drive.google.com/uc?export=view&id=1uAj-Dara6G985eLxHT1mVgnIB_T-uxMS" width=600 />
</center>

It's also interesting to compare their model with the base **NuminaMath** LLM, which is already quite cool

<center>
<img src="https://drive.google.com/uc?export=view&id=1IlomIGaROqDfoT4BDxxuTzlBNuXkzp8q" width=600 />
</center>

So, when it comes to high-quality non-linear reasoning data, what guidelines do the authors propose? There are three:

*	**Structured Organization**. Tokens are allocated to individual "thoughts" according to their importance and complexity, with more tokens for key reasoning points, while keeping simpler steps concise.
*	***Cognitive Scaffolding**. Concepts are introduced strategically, with careful bridging of gaps to make complex reasoning more accessible.
*	**Rigorous Verification**. Intermediate results and assumptions are frequently checked; logical consistency is ensured.

Regarding the third point, verification is especially important due to the risk of hallucinations. An interesting example here is the [rStar-Math](https://arxiv.org/pdf/2501.04519) paper, where the authors train their LLM to produce solutions as Python code with text as code comments.

<center>
<img src="https://drive.google.com/uc?export=view&id=1Psnpfn5iCdPLKGlj6-GfwnmPt2qU6iVn" width=800 />
</center>

With agentic capabilities, executing this code could provide the LLM with valuable feedback, ultimately improving its reasoning.

**LIMO** is, of course, only one example. The authors of DeepSeek R1, for instance, used thousands of examples during the cold-start fine-tuning phase of training.


# A new approach: Reinforcement Learning (RL)

Before diving into reasoning, let's briefly discuss why traditional Supervised Fine-Tuning (SFT) may sometimes be not enough.

SFT works well when we have a dataset of `(prompt, answer)` pairs - essentially, when we can clearly specify what the model should generate to align with our expectations. However, what if we want to train an LLM to produce **non-toxic** text? Simply providing examples of non-toxic outputs isn't enough to teach the model what crosses the line into toxicity. Without explicit guidance on what is unacceptable, the model won't learn the boundaries it shouldn't cross.  

On the other hand, training a **reward model** to distinguish toxic from non-toxic text is relatively straightforward. So, **instead of explicitly showing the LLM what to generate, we train it to maximize a reward signal.** And that's where **Reinforcement Learning (RL)** comes into play!  



## RL in a nutshell (feel free to skip if you're not new to the topic)

Imagine you want to train an AI bot to play [Prince of Persia](https://www.youtube.com/watch?v=FGQmtlxllWY) (the 1989 game). In this game, the player's character (that is, the titular prince) can:

* Walk left or right, jump and fight guards with his sword
* Fall into pits, get impaled on spikes, or killed by guards
* Run out of time and lose
* Save the princess and win

<center>
<img src="https://drive.google.com/uc?export=view&id=1jye3OWWmdRr2dWSDQw5NUSFa_jOyJfJK" width=600 />
</center>

The simplest AI bot would be a neural network that takes the current screen (or maybe several recent screens) as an input and predicts the next action - but how to train it?

A supervised learning paradigm would probably require us to play many successful games, record all the screens, and train the model to predict the actions we chose. But there are several problems with this approach, including the following:

* The game is quite long, so it'd simply be too tiresome to collect a satisfactory number of rounds.
* It's not sufficient to show the right ways of playing; the bot should also learn the moves to avoid.

* The game provides a huge number of possible actions on many different screens. It's reasonable to expect that successful games played by experienced gamers won't produce data with the level of necessary diversity for the bot to "understand" the entire distribution of actions.

So, these considerations have us move to consider training the bot by **trial-and-error**:

1. Initializing its behavior ("**policy**") somehow.
2. Allowing it to play according to this policy, checking various moves (including very awkward ones) and to enjoy falling to the bottom of a pit, and so on.
3. Correct the policy based on its success or failures.
4. Repeat step 2 and 3 until we're tired of waiting or the bot learns to play Prince of Persia like a pro.

Let's formalize this a bit using conventional RL terminology:

* The (observed) **state** is the information we have about the game at the present moment. In our case, this is the content of the current screen.
* The **agent** is a bot which is capable of several **actions**.
* The **environment** is the game. It defines the possible states, the possible actions, and the effects of each action on the current state – and which state will be the next.
* The **policy** is the (trainable) strategy the bot uses for playing. In our case, this is the neural network that predicts actions given the current state, or the state history.
* The **reward** is the score that we assign to the states. For example, defeating a guard, progressing to a next level, or winning the game might have positive rewards, while falling into a pit or getting wounded by a guard would mean negative rewards.

<center>
<img src="https://drive.google.com/uc?export=view&id=1PqJaNX6bvSPHIc6JCNStIo7v7YKzTKHM" width=600 />
</center>

The goal of the training is finding a policy that maximizes reward, and there are many ways of achieving this.
We'll now see that LLM training sometimes has much relevance with Prince of Persia.

## RL in LLM training

Here is the most simple way RL manifests in LLM training:

* An **agent** is our LLM
* An observed **state** is the prompt
* **Action** is generation of completion
* **Reward** is the **reward model** score

<center>
<img src="https://drive.google.com/uc?export=view&id=1sk9h95G_JsU7QkZiDAcCeRZmfP-tlVrZ" width=600 />
</center>

This schematics is much simpler, with only one iteration of state -> action -> reward.

Probably the most influential training method based on RL is **RLHF** (which simply stands for **Reinforcement Learning with Human Feedback**), which is still used for establishment LLM alignment with human values and preferences - helpfulness, harmlessness etc, non-tocixity included. (Although now it's often replaced by RL-free alternatives such as Direct Policy Optimization)

## RL for reasoning: how DeepSeek trained R1-Zero

You probably wonder what it all has in common with reasoning. Quite a revolutionary thing that DeepSeek accomplished was training their **R1-Zero** model for non-linear reasoning not only without a `(prompt, long reasoning)` dataset, but even without a Process Reward Model. For that, they took they existing **DeepSeek V3** model and trained it with **Reinforcement Learning** for optimizing rewards of two kinds:

* **Accuracy rewards**, which is just the accuracy of the answer. So no, the solution only needs to lead to correct answer, in a however strange way.
* **Format following reward** enforces the model to enclose its thinking process between `<think>` and  `</think>`, as we've seen in the examples.

And it was terribly exciting to see that, despite getting no special guidelines towards that, R1Zero not only learnt to output longer and longer solutions

<center>
<img src="https://drive.google.com/uc?export=view&id=1lO0wN7fbsgLSsrWojHh68YgF-FVXNTEi" width=600 />

[Source](file:///C:/Users/st-fe/Downloads/DeepSeek_R1-2.pdf)
</center>

but also started exhibiting human-like patterns of thought expression:

<center>
<img src="https://drive.google.com/uc?export=view&id=14z0RVH_GyHaoRXzKasnM5QMX4z_PLolW" width=600 />

[The famous aha moment](file:///C:/Users/st-fe/Downloads/DeepSeek_R1-2.pdf)
</center>

So, why did it work? To answer, let's return to one of the main highlights of this topic: **a well-pre-trained LLM already has latent reasoning capabilities, and we only need to awake them**. **DeepSeek V3** was already quite a capable model, and it definitely "saw" some amount of non-linear reasoning during pre-training. RL helped to turn this latent knowledge into active skills.

We won't go deep into paticular RL algorithm DeepSeek used. That's not only because we don't want to scare you off with math. It seems that different algorithms may be leveraged here, if only the authors are lucky and capable enough to make them work, which is nothing like easy - RL is notoriously tricky. We'll just briefly mention two of them:

* DeepSeek used the algorithm they'd recently developed, called **GPRO** (**group relative policy optimization**). Its interesting feature is that the reward or, more accurately, the **advantage** is evaluated from a mini-batch:

  * For a prompt $q$, several answers $y_1,\ldots,y_G$ are generated.
  * For each $y_i$, a reward $r_i$ is calculated.
  * The **relative advantage** is calculated for each answer as
  $$A_i = \frac{r_i - \mathrm{mean}(r_1,\ldots,r_G)}{\mathrm{std}(r_1,\ldots,r_G)}$$

  This helps to stabilize the training.

* The [Open Reasoning Zero](https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero) project just uses the traditional PPO (Proximal Policy Optimization), familiar from RLHF - but surprisingly, even without KL penalty :O

**Note**. The authors of DeepSeek actually tried to use Process Reward Models, but it didn't work out. This helps to underline that PRMs are capricious creatures.

## RL is not enough: from R1-Zero to R1

Despite having non-linear reasoning capabilities, **DeepSeek R1-Zero** has issues like poor readability and language mixing. So, several improvements were introduced into the training of the final **DeepSeek-R1**. Here's what its training pipeline looked like:

* First, fine tuning on several thousands of high-quality data created with CoT and refined by human annotators. (And not unlikely obtained from already existing long-reasoning LLMs.) This was made to improve the unstable cold start of RL + to give the model some intial guidelines for what to learn during RL. That sounds like a good idea.
* Then, the RL.
* Then, perform additional SFT on both reasoning and non-reasoning (such as writing, factual QA, self-cognition, and translation) data. Partially, this stage tries to address the known post-RL problems with readability.
* Finally, one more RL stage to establish alignment with human preferences. (An analog of RLHF.)

# Takeaways

In these three notebooks, we explored two of the hottest topics related to LLMs: **LLM Reasoning** and **Inference-Time Compute**. Now you know:

* When to rely on reasoning and when to hold back.
* How to allocate your inference-time budget and choose among different options.
* What non-linear reasoning is, how it is established, and why it has suddenly become so popular.

We have not yet covered a topic closely related to reasoning—LLM planning and execution—but we will discuss that when we introduce LLM Agents. Stay tuned for further course updates!

# Ready for more?

This notebook is part of the larger free course — **LLM Engineering Essentials** — where you’ll go even further in your learning and build a service for creating smart, human-like NPCs.

🎓 New materials are coming soon. Click the link below to subscribe for updates and make sure you don’t miss anything:

[Stay updated](https://academy.nebius.com/llm-engineering-essentials/update/)

# Practice session

## Getting ready

In [4]:
!pip install -q openai

In [5]:
import os

with open("nebius_api_key", "r") as file:
    nebius_api_key = file.read().strip()

os.environ["NEBIUS_API_KEY"] = nebius_api_key

Below is the revised text, followed by an explanation of the edits:

---

## Practice, Part 1: MCTS

In this section, we share our implementation of Monte Carlo Tree Search (MCTS) and highlight some of its key features:

- **Parallel rollouts:**  
  Because we use rollouts to evaluate partial solutions, our approach involves many LLM calls. This is expensive - so please monitor your API usage - and, if done sequentially, can be very slow. Therefore, it is beneficial to parallelize these calls. Unfortunately, MCTS iterations are inherently sequential, so the only part that can be parallelized is the rollouts. We achieve this by using `client.chat.completions.create(..., n=number_of_rollouts)`.

- You can adjust several parameters:
  - `max_depth`: The maximum depth of the search tree.
  - `num_rollouts`: The number of full continuations generated to evaluate a partial solution.
  - `num_iterations`: The number of selection-expansion-simulation-backpropagation cycles to run.
  - `branching_factor`: How many untried actions to give to a node when expanding to it.
  - `c`: the coefficient in the formula `ucb = exploitation + self.c * exploration`. This parameter helps to balance between uniform tree exploration and steady push towards the goal. We recommend to actually start from lower values of `c`, favouring exploitation. When you're sure that you reach the answer in a reasonable number of iteration, you can try extending exploration.

- The generated answer is compared with the actual answer using an evaluation function, passed to the MCTS, that defaults to `answer_comparison_evaluator(generated_answer, true_answer)`. If your answer is a LaTeX formula and you need to compute its value before comparison, you can redefine this evaluator.

We encourage you to experiment with MCTS on different tasks. Just remember that it is significantly more expensive than Beam Search because of the additional cost of rollouts.

In [40]:
import numpy as np
import random
from typing import List, Dict, Optional, Tuple, Callable, Any
from dataclasses import asdict, dataclass
import json
import re

@dataclass
class NodeStats:
    """Statistics for a node in the MCTS tree"""
    visits: int
    value: float
    success_rate: float
    depth: int
    is_terminal: bool
    num_children: int

    def to_dict(self) -> Dict:
        return asdict(self)

@dataclass
class MCTSNode:
    state: str  # Current solution state
    parent: Optional['MCTSNode']
    children: List['MCTSNode']
    visits: int
    value: float
    untried_actions: List[str]  # Possible next steps
    is_terminal: bool
    correct_continuations: int = 0
    total_continuations: int = 0

    def get_stats(self) -> NodeStats:
        """Get statistics for this node"""
        return NodeStats(
            visits=self.visits,
            value=self.value,
            success_rate=self.value / max(self.visits, 1),
            depth=self.get_depth(),
            is_terminal=self.is_terminal,
            num_children=len(self.children)
        )

    def get_depth(self) -> int:
        """Get the depth of this node in the tree"""
        depth = 0
        node = self
        while node.parent:
            depth += 1
            node = node.parent
        return depth

    def to_dict(self) -> Dict:
        """Convert node to dictionary representation"""
        return {
            'state': self.state,
            'stats': self.get_stats().to_dict(),
            'children': [child.to_dict() for child in self.children]
        }

class LLMClient:
    """Wrapper for OpenAI-compatible API clients."""
    def __init__(
        self,
        client,
        model: str,
        default_temperature: float = 0.7,
        default_max_tokens: int = 1024,
        system_prompt: Optional[str] = None
    ):
        self.client = client
        self.model = model
        self.default_temperature = default_temperature
        self.default_max_tokens = default_max_tokens
        self.system_prompt = system_prompt

    def generate(
        self,
        prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        n: int = 1
    ) -> List[str]:
        """Generate completions with consistent interface."""
        messages = []

        # Use provided system prompt or fall back to default
        current_system_prompt = system_prompt or self.system_prompt
        if current_system_prompt:
            messages.append({"role": "system", "content": current_system_prompt})

        messages.append({"role": "user", "content": prompt})

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature if temperature is not None else self.default_temperature,
                max_tokens=max_tokens if max_tokens is not None else self.default_max_tokens,
                n=n
            )
            return [choice.message.content.strip() for choice in response.choices]
        except Exception as e:
            return [""] * n

class MCTS:
    def __init__(self,
                 llm_client: LLMClient,
                 exploration_weight: float = 1.414,
                 num_rollouts: int = 5,
                 max_depth: int = 5,
                 branching_factor: int = 3,
                 c: float = 1.):
        """
        Initialize MCTS for LLM-based problem solving.

        Args:
            llm_client: LLMClient instance
            exploration_weight: UCB1 exploration parameter
            num_rollouts: Number of Monte Carlo rollouts for evaluation
            max_depth: Maximum depth of the search tree
        """
        self.llm_client = llm_client
        self.exploration_weight = exploration_weight
        self.num_rollouts = num_rollouts
        self.max_depth = max_depth
        self.search_tree = None
        self.branching_factor = branching_factor
        self.c = c

    def ucb1(self, node: MCTSNode, parent_visits: int) -> float:
        """Calculate UCB1 value for node selection."""
        if node.visits == 0:
            return float('inf')

        exploitation = node.value / node.visits
        exploration = self.exploration_weight * np.sqrt(np.log(parent_visits) / node.visits)
        return exploitation + self.c * exploration

    def get_possible_actions(self, state: str, problem: str) -> List[str]:
        """Get possible next steps from current state using LLM."""
        prompt = f"""Given the math problem and current partial solution:
        Problem: {problem}
        Current solution steps:
        {state}\n\n"""

        prompt += """Generate the next logical step in the solution. Be concise and direct.
        - Each step should be a single, clear logical or math operation
        - Don't repeat previous steps
        - If during the next step you get the final answer for the problem, you immediately output this answer in \\boxed\{\}. This is very important!

        Next step:"""

        responses = self.llm_client.generate(prompt, n=self.branching_factor)
        return [step for step in responses if step.strip()]

    def evaluate_partial_solution(self,
                               state: str,
                               problem: str,
                               evaluate_continuation: Callable[[str, str], Tuple[int, int]]) -> Tuple[int, int]:
        """
        Evaluate partial solution using the provided evaluation function.
        Returns (correct_continuations, total_continuations)
        """
        return evaluate_continuation(state, problem)

    def select(self, node: MCTSNode) -> MCTSNode:
        """Select most promising node using UCB1."""
        while node.children and not node.is_terminal:
            if node.untried_actions:
                return node

            node = max(node.children,
                      key=lambda n: self.ucb1(n, node.visits))
        return node

    def expand(self,
              node: MCTSNode,
              problem: str,
              evaluate_continuation: Callable[[str, str], Tuple[int, int]]) -> Optional[MCTSNode]:
        """Expand node by adding a child with an untried action."""

        if not node.untried_actions or node.is_terminal:
            return None

        action = random.choice(node.untried_actions)
        node.untried_actions.remove(action)

        new_state = f"{node.state}\n{action}" if node.state else action
        possible_actions = self.get_possible_actions(new_state, problem)

        actual_depth = node.get_depth() + 1  # +1 for the child we're creating
        is_terminal = (actual_depth >= self.max_depth) or ("\\boxed" in new_state)
        correct_cont, total_cont = self.evaluate_partial_solution(
            new_state, problem, evaluate_continuation)

        child = MCTSNode(
            state=new_state,
            parent=node,
            children=[],
            visits=0,
            value=0.0,
            untried_actions=possible_actions,
            is_terminal=is_terminal,
            correct_continuations=correct_cont,
            total_continuations=total_cont
        )

        node.children.append(child)
        return child

    def rollout(self,
            node: MCTSNode,
            problem: str,
            evaluate_continuation: Callable[[str, str], Tuple[int, int]]) -> float:
        """
        Run multiple rollouts by generating full continuations from the node's state.
        For each rollout, generate a complete solution by calling the LLM,
        then evaluate that complete solution using the provided deterministic evaluator.
        Returns an average score (ratio of correct to total continuations).
        """
        # Construct a prompt that asks the LLM to complete the chain-of-thought
        prompt = f"""You are given the math problem and current partial solution:
        Problem: {problem}
        Current solution steps:
        {node.state}\n\n"""

        prompt += """Complete the solution and output the final answer in \\boxed\{\}"""

        # Generate multiple full continuations
        responses = self.llm_client.generate(prompt, n=self.num_rollouts)

        scores = []
        for response in responses:
            # Construct the full solution by appending the generated continuation
            full_solution = f"{node.state}\n{response}"

            # Evaluate the full solution using the deterministic evaluator
            correct, total = evaluate_continuation(full_solution, problem)
            score = correct / max(total, 1)
            scores.append(score)

        # Return the average score over all generated completions
        aggregated_score = sum(scores) / len(scores) if scores else 0.0
        return aggregated_score

    def backpropagate(self, node: MCTSNode, value: float):
        """Backpropagate value up the tree."""
        while node:
            node.visits += 1
            node.value += value
            node = node.parent

    def search(self,
              problem: str,
              evaluate_continuation: Callable[[str, str], Tuple[int, int]],
              initial_state: str = "",
              num_iterations: int = 100,
              verbose: bool = False) -> Tuple[str, MCTSNode]:
        """
        Perform MCTS search to find solution.

        Args:
            problem: Problem description
            evaluate_continuation: Function to evaluate solution continuations
            initial_state: Starting point for solution
            num_iterations: Number of MCTS iterations
            verbose: Whether to print detailed logs

        Returns:
            Best solution found and the search tree root
        """
        possible_actions = self.get_possible_actions(initial_state, problem)
        root = MCTSNode(
            state=initial_state,
            parent=None,
            children=[],
            visits=0,
            value=0.0,
            untried_actions=possible_actions,
            is_terminal=False
        )

        for iteration in range(num_iterations):
            node = self.select(root)
            child = self.expand(node, problem, evaluate_continuation)

            if child:
                # Use multi-rollout simulation to get a robust estimate
                simulation_value = self.rollout(child, problem, evaluate_continuation)
                self.backpropagate(child, simulation_value)

            if verbose and iteration % 5 == 0:
                tree_info = self.get_tree_summary(root)
                print(f"\nIteration {iteration} - Tree Statistics:")
                print(json.dumps(tree_info['statistics'], indent=2))

        # Store the search tree for later analysis
        self.search_tree = root

        best_path = self._get_best_path(root)
        final_solution = best_path[-1]['state']
        return final_solution, root

    def get_tree_summary(self, node: MCTSNode, max_depth: Optional[int] = None) -> Dict:
        """
        Get a summary of the search tree up to max_depth.
        If max_depth is None, returns the entire tree.
        """
        return {
            'tree_structure': node.to_dict(),
            'statistics': {
                'total_nodes': self._count_nodes(node),
                'max_depth': self._get_max_depth(node),
                'leaf_nodes': self._count_leaf_nodes(node),
                'avg_branching': self._get_avg_branching(node),
                'best_path': self._get_best_path(node)
            }
        }

    def _count_nodes(self, node: MCTSNode) -> int:
        """Count total nodes in tree"""
        return 1 + sum(self._count_nodes(child) for child in node.children)

    def _get_max_depth(self, node: MCTSNode) -> int:
        """Get maximum depth of tree"""
        if not node.children:
            return 0
        return 1 + max(self._get_max_depth(child) for child in node.children)

    def _count_leaf_nodes(self, node: MCTSNode) -> int:
        """Count leaf nodes in tree"""
        if not node.children:
            return 1
        return sum(self._count_leaf_nodes(child) for child in node.children)

    def _get_avg_branching(self, node: MCTSNode) -> float:
        """Calculate average branching factor"""
        total_nodes = self._count_nodes(node)
        non_leaf_nodes = total_nodes - self._count_leaf_nodes(node)
        if non_leaf_nodes == 0:
            return 0.0
        return (total_nodes - 1) / non_leaf_nodes

    def _get_best_path(self, node: MCTSNode) -> List[Dict]:
        """Get the path with highest value that leads to a terminal state"""
        path = []
        current = node
        while current:
            path.append({
                'state': current.state,
                'stats': current.get_stats().to_dict()
            })
            if not current.children:
                break

            # Among children that have been visited at least once,
            # choose the one with highest success rate
            visited_children = [c for c in current.children if c.visits > 0]
            if not visited_children:
                break

            # If any child is terminal with a good score, prefer it
            terminal_children = [c for c in visited_children if c.is_terminal and c.value/max(c.visits, 1) > 0]
            if terminal_children:
                current = max(terminal_children,
                             key=lambda n: n.value / max(n.visits, 1))
            else:
                current = max(visited_children,
                         key=lambda n: n.value / max(n.visits, 1))
        return path

# Example evaluation function:
def answer_comparison_evaluator(state: str, problem: str,
                               expected_answer: float) -> Tuple[int, int]:
    """Evaluates solution by comparing extracted answers with expected answer."""
    # Extract answer from \boxed{} in the solution
    answer_match = re.search(r'oxed\{([^}]+)\}', state)
    if not answer_match:
        return (0, 1)

    solution_answer = answer_match.group(1).strip().replace("\\", "").strip()
    try:
        solution_answer = float(solution_answer)
    except ValueError:
        return (0, 1)

    return (1, 1) if np.abs(solution_answer - expected_answer) < 1e-10 else (0, 1)

In [41]:
from openai import OpenAI
import os

# Initialize the LLM client
async_client = OpenAI(
    base_url="https://api.studio.nebius.ai/v1/",
    api_key=os.environ.get("NEBIUS_API_KEY"),
)

# Create the LLM client wrapper
llm_client = LLMClient(
    client=async_client,
    model="meta-llama/Meta-Llama-3.1-70B-Instruct",
    system_prompt=None,
)

# Initialize MCTS with the client
mcts = MCTS(
    llm_client=llm_client,
    max_depth=10,
    num_rollouts=5,
    c=0.5
)

Now, the `mcts.search` function will launch MCTS, showing general tree statistics and the best solution path every 5 iterations.

In [42]:
# Run the search
solution, tree = mcts.search(
    problem="Inside a circle, two parallel chords are 6 units apart. One chord has length 14 and the other has length 10. Find the square of the radius of the circle.",
    evaluate_continuation=lambda state, problem: answer_comparison_evaluator(state, problem, 50.0),
    num_iterations=100,
    verbose=True
)

print(solution)


Iteration 0 - Tree Statistics:
{
  "total_nodes": 2,
  "max_depth": 1,
  "leaf_nodes": 1,
  "avg_branching": 1.0,
  "best_path": [
    {
      "state": "",
      "stats": {
        "visits": 1,
        "value": 0.2,
        "success_rate": 0.2,
        "depth": 0,
        "is_terminal": false,
        "num_children": 1
      }
    },
    {
      "state": "Let's denote the radius of the circle as r. Draw a line from the center of the circle to the midpoint of each chord, and another line from the center of the circle that is perpendicular to the chords. This will create two right triangles. Using the Pythagorean theorem, we can express the radius squared as r^2 = (7^2 + 3^2) = (5^2 + 3^2), where 7 and 5 are half the lengths of the two chords.",
      "stats": {
        "visits": 1,
        "value": 0.2,
        "success_rate": 0.2,
        "depth": 1,
        "is_terminal": false,
        "num_children": 0
      }
    }
  ]
}

Iteration 5 - Tree Statistics:
{
  "total_nodes": 7,
  "max

If you look at the best solution path, you'll observe that the `success_rate` grows with depth, which is logical:

- at start, there are many paths starting from the partial solution stump, and many of them are wrong.
- later, if we're really on a path towards the correct solution, less and less space is left for failure.

**A question for you**. What do you think, how many LLM calls have we done here?

**A task for you**. In out implementation, MCTS stops when it exhausts the number of iterations. However, in simpler tasks we may reach terminal nodes and stop producing new ones earlier, so an early stopping criterion would be handy. There may be several ways of defining it, for example:

* No new nodes for several iterations,
* Convergence of value estimates: if values almost don't change for a number of iterations, it may be a signal that it's time to stop.

## Practice, Part 2: Budget Forcing

Although we cannot afford to fine-tune an LLM as in the [S1 paper](https://arxiv.org/pdf/2501.19393), we can still apply budget forcing.

Recall that the idea is simple: force the LLM to generate “Wait” (or another thought-provoking phrase) every time it attempts to output an answer—until the target solution length is reached.

This approach requires a degree of control over LLM generation that APIs cannot offer, so we must run the model locally in this Colab notebook. Before you begin, please switch to a GPU machine—an L4 GPU will suffice.

Our implementation of budget forcing generates a solution in two stages:

1. **Pre-Wait Phase:**  
   Until the solution length reaches `wait_tokens`, we intervene in the next token generation using a custom **logits processor**. In this processor, we monitor several markers, and when one is detected, we insert a *wait sequence* (`wait_text`) instead of the marker's final token. By default, the LLM we use encloses answers in `\boxed{}`, so we use `\boxed` as a marker.

2. **Post-Wait Phase:**  
   Once the solution length reaches `wait_tokens`, we allow the LLM to generate freely.

We encourage you to experiment with budget forcing on various tasks and with different wait sequences.

**Food for Thought: Questions to Consider**

- What would happen if we set the log probability of the `<eos>` (End Of String) token to `-inf` before reaching `wait_tokens`?
- What would happen if we suppressed the generation of `\boxed{}` by setting the log probability of `boxed` to `-inf` after the backslash (`\`)?

First of all, you'll need to load your Hugging Face access token:

In [1]:
!pip install -q openai

In [2]:
import os

with open("nebius_api_key", "r") as file:
    nebius_api_key = file.read().strip()

os.environ["NEBIUS_API_KEY"] = nebius_api_key

with open("hf_access_token", "r") as file:
    hf_access_token = file.read().strip()

In [46]:
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

class BudgetForcingProcessor(torch.nn.Module):
    """Custom logits processor that forces the model to wait before generating answer markers"""
    def __init__(self, tokenizer, wait_tokens=200, debug=True, markers=None, wait_text=None):
        super().__init__()
        self.tokenizer = tokenizer
        self.wait_tokens = wait_tokens
        self.debug = debug

        # These will be set by the parent class, but provide defaults if needed
        # They're initially set here but will be overridden by BudgetForcingLLM
        self.markers = markers or {
            "boxed": {"text": "\\boxed", "case_sensitive": True},
            "eos": {"text": self.tokenizer.eos_token, "case_sensitive": True}
        }
        self.wait_text = wait_text or "\nWait, I need to think more about this..."

        # Get EOS token ID specifically
        self.eos_token_id = self.tokenizer.eos_token_id

        # Initialize token maps
        self.marker_tokens = {}
        self.wait_tokens_ids = []
        self.update_token_maps()

        # State tracking
        self.token_buffer = []
        self.intervention_active = False
        self.total_tokens_generated = 0  # Count from the beginning
        self.wait_text_inserted = False
        self.forced_wait_index = 0
        self.last_intervention_position = 0  # Track position of last intervention
        self.cooldown_tokens = 50  # Minimum tokens between interventions
        self.budget_notification_sent = False

        # Flag to indicate if the budget is exhausted
        self.budget_exhausted = False

        # Flag to signal early stopping once budget is met
        self.stop_at_budget = True

    def update_token_maps(self):
        """Update token maps when markers or wait text changes"""
        # Find how each marker is tokenized
        self.marker_tokens = {}
        for key, marker_info in self.markers.items():
            marker_text = marker_info["text"]
            if marker_text:  # Skip None values
                self.marker_tokens[key] = self.tokenizer.encode(marker_text, add_special_tokens=False)

        # Get wait text tokens
        self.wait_tokens_ids = self.tokenizer.encode(self.wait_text, add_special_tokens=False)

        # State tracking
        self.token_buffer = []
        self.intervention_active = False
        self.total_tokens_generated = 0  # Count from the beginning
        self.wait_text_inserted = False
        self.forced_wait_index = 0
        self.last_intervention_position = 0  # Track position of last intervention
        self.cooldown_tokens = 50  # Minimum tokens between interventions
        self.budget_notification_sent = False

        # Flag to indicate if the budget is exhausted
        self.budget_exhausted = False

        # Flag to signal early stopping once budget is met
        self.stop_at_budget = True

    def detect_marker_start(self, token_ids):
        """Check if the current sequence of tokens might be starting to generate any of the answer markers"""
        # For efficiency, only check recent tokens
        check_window = 10  # Check the last 10 tokens at most
        start_idx = max(0, len(token_ids) - check_window)
        recent_token_ids = token_ids[start_idx:].tolist()

        # Completely rebuild the buffer from the recent tokens to avoid stale state
        self.token_buffer = recent_token_ids

        # For debugging, show the tokens we're checking
        if self.debug:
            recent_str = self.tokenizer.decode(self.token_buffer)

        # Check if we're in cooldown period after a recent intervention
        if self.last_intervention_position > 0:
            tokens_since_intervention = len(token_ids) - self.last_intervention_position
            if tokens_since_intervention < self.cooldown_tokens:
                # Don't trigger intervention during cooldown
                return False

        # Direct check for EOS token - this still needs to be caught immediately
        if self.eos_token_id in self.token_buffer and self.total_tokens_generated < self.wait_tokens:
            if self.debug:
                print(f"EOS token detected at position {len(token_ids)}, before wait_tokens threshold!")
            self.detected_marker = "EOS token"
            return True

        # Check for markers in the current window
        token_str = self.tokenizer.decode(self.token_buffer)

        # Use the markers dictionary with its properties
        for marker_key, marker_info in self.markers.items():
            marker_text = marker_info["text"]
            case_sensitive = marker_info["case_sensitive"]

            if not marker_text or marker_key == "eos":
                continue  # Skip None or EOS markers

            # Check for the marker in the token string
            if case_sensitive:
                # Case-sensitive check
                marker_present = marker_text in token_str
            else:
                # Case-insensitive check
                marker_present = marker_text.lower() in token_str.lower()

            if marker_present:
                if self.debug:
                    print(f"{marker_key.upper()} marker detected in: '{token_str}'")
                self.detected_marker = marker_text
                return True

        return False

    def __call__(self, input_ids, scores):
        """Process logits to implement budget forcing"""
        # Count tokens from the beginning, excluding the prompt
        prompt_token_length = len(input_ids)
        self.total_tokens_generated = len(input_ids[0]) - prompt_token_length

        if self.debug and self.total_tokens_generated % 50 == 0:
            print(f"Total tokens generated: {self.total_tokens_generated}/{self.wait_tokens}")

            # Periodically show the last tokens in debug mode to help with debugging
            if len(input_ids[0]) > 10:
                recent_text = self.tokenizer.decode(input_ids[0][-10:])

        # Check if we've reached the token budget
        self.budget_exhausted = self.total_tokens_generated >= self.wait_tokens

        # If budget is exhausted and we want to stop, force EOS token
        if self.budget_exhausted and self.stop_at_budget:
            if self.debug and not self.budget_notification_sent:
                print(f"\n==== BUDGET REACHED, STOPPING FIRST STAGE ====")
                print(f"Generated {self.total_tokens_generated} tokens")
                self.budget_notification_sent = True

            # Force EOS token to stop generation
            scores[0, :] = -float('inf')
            scores[0, self.eos_token_id] = 100.0
            return scores

        # Clear the token buffer periodically to avoid keeping stale tokens
        if len(self.token_buffer) > 0 and self.total_tokens_generated % 100 == 0:
            self.token_buffer = []

        # Always block EOS until budget is exhausted to prevent early termination
        if not self.budget_exhausted:
            # Make EOS impossible - this is critical
            scores[0, self.eos_token_id] = -float('inf')

            # Reset detected_marker before checking
            self.detected_marker = None

            # Only check for specific markers, no rule-based pattern detection
            if not self.intervention_active and self.detect_marker_start(input_ids[0]):

                if self.debug:
                    print("\n==== ACTIVATING BUDGET FORCING INTERVENTION ====")
                    print(f"Detected marker: {self.detected_marker}")
                    print(f"At position {len(input_ids[0])}")
                    print(f"Current token count: {self.total_tokens_generated}/{self.wait_tokens}")

                    # Print recent context
                    if len(input_ids[0]) > 20:
                        recent_text = self.tokenizer.decode(input_ids[0][-20:])
                        print(f"Recent context: '{recent_text}'")

                # Activate intervention
                self.intervention_active = True
                self.wait_text_inserted = False
                self.forced_wait_index = 0

                # Force the start of wait text by modifying scores
                next_wait_token = self.wait_tokens_ids[0]
                scores[0, :] = -float('inf')  # Set all token probabilities to very low
                scores[0, next_wait_token] = 100.0  # Force the first wait token
                self.forced_wait_index = 1

                if self.debug:
                    print(f"Starting wait text: '{self.tokenizer.decode([next_wait_token])}'")

                return scores

        # If we're in active intervention mode, force wait text
        if self.intervention_active:
            # If we haven't started inserting wait text
            if self.forced_wait_index == 0:
                next_wait_token = self.wait_tokens_ids[0]
                scores[0, :] = -float('inf')  # Set all token probabilities to very low
                scores[0, next_wait_token] = 100.0  # Force the first wait token
                self.forced_wait_index = 1

                if self.debug:
                    print(f"Starting wait text: '{self.tokenizer.decode([next_wait_token])}'")

                return scores

            # If we're in the middle of inserting wait text
            elif not self.wait_text_inserted:
                # Check if we've inserted all wait tokens
                if self.forced_wait_index >= len(self.wait_tokens_ids):
                    self.wait_text_inserted = True
                    self.intervention_active = False  # End intervention after wait text

                    # Add a cooldown period to prevent multiple consecutive wait texts
                    self.last_intervention_position = len(input_ids[0])

                    if self.debug:
                        print("Wait text fully inserted, continuing normal generation")
                    return scores

                # Force next wait token
                next_wait_token = self.wait_tokens_ids[self.forced_wait_index]
                scores[0, :] = -float('inf')
                scores[0, next_wait_token] = 100.0
                self.forced_wait_index += 1

                return scores

        return scores


class BudgetForcingLLM:
    # Define markers at the class level for consistency
    # Each marker has: text value and case_sensitive flag
    DEFAULT_MARKERS = {
        "correct_answer": {"text": """the correct answer is:""", "case_sensitive": False},
        "final_answer": {"text": """the final answer is:""", "case_sensitive": False},
        "boxed": {"text": "\\boxed", "case_sensitive": True},
        "eos": {"text": None, "case_sensitive": True}  # Will be filled with the tokenizer's EOS token
    }
    DEFAULT_WAIT_TEXT = "\nWait, I need to think more about this..."

    def __init__(
        self,
        model_id: str = "meta-llama/Llama-3.1-8B",
        device: str = "cuda",
        wait_tokens: int = 200,  # Number of tokens to wait before allowing markers
        debug: bool = True,
        hf_access_token: str = None,
        markers: dict = None,
        wait_text: str = None
    ):
        """
        Initialize the Budget Forcing implementation for Llama models.

        Args:
            model_id: HuggingFace model ID
            device: Device to run the model on (cuda/cpu)
            wait_tokens: Number of tokens to generate before allowing answer markers
            debug: Whether to print debug information
            markers: Custom markers dictionary (optional)
            wait_text: Custom wait text (optional)
        """
        print(f"Loading model {model_id} on {device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_access_token)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            token=hf_access_token
        )
        self.model.to(device)
        self.device = device
        self.wait_tokens = wait_tokens
        self.debug = debug

        # Initialize markers and wait text
        self.markers = markers or self.DEFAULT_MARKERS.copy()
        # Ensure EOS token is correctly set
        if "eos" in self.markers:
            self.markers["eos"]["text"] = self.tokenizer.eos_token
        self.wait_text = wait_text or self.DEFAULT_WAIT_TEXT

        # Create our logits processor with our markers and wait text
        self.budget_processor = BudgetForcingProcessor(
            self.tokenizer,
            wait_tokens=wait_tokens,
            debug=debug,
            markers=self.markers,
            wait_text=self.wait_text
        )

    def clean_markers(self, text):
        """Remove only markers from generated text, keeping wait text intact"""
        cleaned_text = text

        # Remove all markers from the text using the dictionary with its properties
        for marker_key, marker_info in self.budget_processor.markers.items():
            marker_text = marker_info["text"]
            case_sensitive = marker_info["case_sensitive"]

            if marker_text and marker_key != "eos":  # Skip None values and EOS token
                pattern = re.escape(marker_text)

                if case_sensitive:
                    # Case-sensitive removal
                    cleaned_text = re.sub(pattern, "", cleaned_text)
                else:
                    # Case-insensitive removal
                    cleaned_text = re.sub(pattern, "", cleaned_text, flags=re.IGNORECASE)

        return cleaned_text

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> str:
        """Generate text with two-stage budget forcing"""
        if self.debug:
            print("\n===== STARTING TWO-STAGE GENERATION =====")
            print(f"Wait tokens: {self.wait_tokens}")
            print(f"Max new tokens: {max_new_tokens}")

        # Customize prompt to strongly encourage the model to use the marker


        # STAGE 1: Generate with budget forcing until budget is exhausted
        # -------------------------------------------------------------------
        # Reset the processor state completely
        self.budget_processor = BudgetForcingProcessor(
            self.tokenizer,
            wait_tokens=self.wait_tokens,
            debug=self.debug
        )

        # Enable "stop at budget" mode
        self.budget_processor.stop_at_budget = True

        # Store the prompt token length for accurate token counting
        self.budget_processor.prompt_token_length = len(self.tokenizer.encode(prompt, add_special_tokens=False))

        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # Define our logits processor list
        logits_processor = LogitsProcessorList([self.budget_processor])

        # Generate first stage with our custom processor
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=min(self.wait_tokens + 100, max_new_tokens),  # Just enough to reach budget
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                logits_processor=logits_processor,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode the output and extract only the generated text
        full_output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        first_stage_text = full_output[len(prompt):]

        if self.debug:
            print(f"\nStage 1 generated {len(self.tokenizer.encode(first_stage_text))} tokens")

        # Clean the generated text to remove markers but keep wait text
        cleaned_text = self.clean_markers(first_stage_text)

        if self.debug:
            print(f"Cleaned text for stage 2: {cleaned_text}")

        # STAGE 2: Continue generation from cleaned text without budget forcing
        # -------------------------------------------------------------------
        # Create new prompt with the cleaned text from stage 1
        stage2_prompt = prompt + cleaned_text

        # Calculate remaining tokens for stage 2
        tokens_used_stage1 = len(self.tokenizer.encode(cleaned_text))
        remaining_tokens = max_new_tokens - tokens_used_stage1

        if remaining_tokens <= 0:
            if self.debug:
                print("No tokens remaining for stage 2, returning stage 1 result")
            return cleaned_text

        if self.debug:
            print(f"\n===== STARTING SECOND STAGE GENERATION =====")
            print(f"Tokens remaining: {remaining_tokens}")

        # Tokenize the new prompt
        stage2_inputs = self.tokenizer(stage2_prompt, return_tensors="pt").to(self.device)

        # Generate without budget forcing for the second stage
        with torch.no_grad():
            stage2_output_ids = self.model.generate(
                **stage2_inputs,
                max_new_tokens=remaining_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Extract only the new generated text (beyond the stage 2 prompt)
        stage2_full_output = self.tokenizer.decode(stage2_output_ids[0], skip_special_tokens=True)
        stage2_generated_text = stage2_full_output[len(stage2_prompt):]

        if self.debug:
            print(f"Stage 2 generated {len(self.tokenizer.encode(stage2_generated_text))} tokens")

        # Combine the cleaned first stage and second stage text
        final_output = cleaned_text + stage2_generated_text

        if self.debug:
            print(f"\nTotal tokens generated: {len(self.tokenizer.encode(final_output))}/{max_new_tokens}")

        return final_output

Let's create a budget-forcing LLM class. We chose a small LLM to be sure that we won't exhaust GPU even with long solutions.

In [47]:
# Example usage
budget_llm = BudgetForcingLLM(
        model_id="Qwen/Qwen2.5-1.5B-Instruct",
        wait_tokens=2048,
        debug=True,
        hf_access_token=hf_access_token
    )

Loading model Qwen/Qwen2.5-1.5B-Instruct on cuda


Let's run an attempt at generation!

In [48]:
# Problem 63 from the AIME dataset: https://huggingface.co/datasets/AI-MO/aimo-validation-aime
math_problem = """Let\n$$x^8 + 3x^4 - 4 = p_1(x) p_2(x) \\dotsm p_k(x),$$\nwhere each non-constant polynomial $p_i(x)$ is monic with integer coefficients, and cannot be factored further over the integers.  Compute $p_1(1) + p_2(1) + \\dots + p_k(1).$"""

    # Generate with budget forcing
print("\n--- Generating solution with budget forcing ---")
output = budget_llm.generate(
    math_problem,
    max_new_tokens=3096,
    temperature=0.8
)

print("\nGenerated solution:")
print(output)




--- Generating solution with budget forcing ---

===== STARTING TWO-STAGE GENERATION =====
Wait tokens: 2048
Max new tokens: 3096
Total tokens generated: 100/2048
Total tokens generated: 150/2048
Total tokens generated: 200/2048
Total tokens generated: 250/2048
Total tokens generated: 300/2048
Total tokens generated: 350/2048
Total tokens generated: 400/2048
Total tokens generated: 450/2048
Total tokens generated: 500/2048
Total tokens generated: 550/2048
Total tokens generated: 600/2048
Total tokens generated: 650/2048
Total tokens generated: 700/2048
Total tokens generated: 750/2048
Total tokens generated: 800/2048
Total tokens generated: 850/2048
BOXED marker detected in: ']
Therefore, the answer is:
\[ \boxed'

==== ACTIVATING BUDGET FORCING INTERVENTION ====
Detected marker: \boxed
At position 863
Current token count: 862/2048
Recent context: ' 2 + 0 = 8. \]
Therefore, the answer is:
\[ \boxed'
Starting wait text: '
'
Wait text fully inserted, continuing normal generation
Total t

In [49]:
math_problem = """Inside a circle, two parallel chords are 6 units apart. One chord has length 14 and the other has length 10. Find the radius of the circle."""

    # Generate with budget forcing
print("\n--- Generating solution with budget forcing ---")
output = budget_llm.generate(
    math_problem,
    max_new_tokens=3096,
    temperature=0.8
)

print("\nGenerated solution:")
print(output)




--- Generating solution with budget forcing ---

===== STARTING TWO-STAGE GENERATION =====
Wait tokens: 2048
Max new tokens: 3096
Total tokens generated: 50/2048
Total tokens generated: 100/2048
Total tokens generated: 150/2048
Total tokens generated: 200/2048
Total tokens generated: 250/2048
Total tokens generated: 300/2048
Total tokens generated: 350/2048
Total tokens generated: 400/2048
Total tokens generated: 450/2048
Total tokens generated: 500/2048
Total tokens generated: 550/2048
Total tokens generated: 600/2048
Total tokens generated: 650/2048
BOXED marker detected in: ', the radius of the circle is \(\boxed'

==== ACTIVATING BUDGET FORCING INTERVENTION ====
Detected marker: \boxed
At position 694
Current token count: 693/2048
Recent context: ' 5\sqrt{2} \]

So, the radius of the circle is \(\boxed'
Starting wait text: '
'
Total tokens generated: 700/2048
Wait text fully inserted, continuing normal generation
Total tokens generated: 750/2048
Total tokens generated: 800/2048
To

If you experiment with our implementation of budget forcing, you'll observe that **In-context learning** becomes an unexpected problem for this model. Indeed, after several forces "Wait" after what converges to an answer, makes a model anxious, and it often exhibits the same, "waiting" behaviour even after the interventions are over. So, indeed additional fine tuning wouldn't hurt this model.

**Your task**. Experiment with different problems and lengths of intervention window. Try different temperatures (we made it relatively high for a reason). If your GPU allows, try a larger model. Will you be able to overcome model's anxiety without fine tuning?