# Scaling Test-Time Compute for Longer Thinking in LLMs

_Adapted from [Hugging Face](https://github.com/huggingface/search-and-learn)_

_**Requirements:**_ A100 GPU (good luck)

---

This notebook adapts the test-time compute solution presented in [this **blog post**](https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute) and extends its capabilities. The goal is to analyse inference-time compute and produce plots of the number of generations versus quality. The analysis is intended to cover two approaches to reasoning:
* Verifier-based: plot the number of generations against the number of generated tokens, which ultimately informs the cost. Additionally, compute accuracy by connecting the results to NeMo Evaluator.
* CoT-based: run DeepSeek-R1 recursively (no verifier), using the distilled model.

This extension also makes it possible to plug in different datasets for a quick check of how well the results generalize.

---


<img src="https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/methods-thumbnail.png" alt="Instruct LLM Methodology" width="800"/>

---

## 1. Install Dependencies _(copied from HF)_

Let’s start by installing the [search-and-learn](https://github.com/huggingface/search-and-learn) repository! 🚀  
This repo is designed to replicate the experimental results and is not a Python pip package. However, we can still use it to generate our system. To do so, we’ll need to install it from source with the following steps:

In [None]:
!git clone https://github.com/huggingface/search-and-learn

In [None]:
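# The clone above creates a directory named search-and-learn;
# if you cloned a fork under another name, change into that directory instead
%cd search-and-learn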

In [None]:
!pip install -e '.[dev]'
!pip install matplotlib

Log in to Hugging Face to access [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct), as it is a gated model! 🗝️  
If you haven't previously requested access, you'll need to submit a request before proceeding.


In [2]:
from huggingface_hub import notebook_login

notebook_login()


## 2. Setup the Large Language Model (LLM) and the Process Reward Model (PRM) 💬 _(copied from HF)_

As illustrated in the diagram, the system consists of an LLM that generates intermediate answers based on user input, a [PRM model](https://huggingface.co/papers/2211.14275) that evaluates and scores these answers, and a search strategy that uses the PRM feedback to guide the subsequent steps in the search process until reaching the final answer.

Let’s begin by initializing each model. For the LLM, we’ll use the [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct) model, and for the PRM, we’ll use the [RLHFlow/Llama3.1-8B-PRM-Deepseek-Data](https://huggingface.co/RLHFlow/Llama3.1-8B-PRM-Deepseek-Data) model.




![system](https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/system.png)

In [None]:
import sys
import os

project_src = "src/"

# Add it to sys.path
sys.path.append(project_src)

In [None]:
import torch
from vllm import LLM
from sal.models.reward_models import RLHFFlow

model_path="meta-llama/Llama-3.2-1B-Instruct"
prm_path="RLHFlow/Llama3.1-8B-PRM-Deepseek-Data"

llm = LLM(
    model=model_path,
    gpu_memory_utilization=0.5,  # Utilize 50% of GPU memory
    enable_prefix_caching=True,  # Optimize repeated prefix computations
    seed=42,                     # Set seed for reproducibility
)

prm = RLHFFlow(prm_path)

### 2.1 Instantiate the Question, Search Strategy, and Call the Pipeline

Now that we've set up the LLM and PRM, let's proceed by defining the question, selecting a search strategy to explore candidate answers, and calling the pipeline to process the question through the models.

1. **Instantiate the Question**: In this step, we define the input question that the system will answer, considering the given context.

2. **Search Strategy**: The system currently supports the following search strategies: `best_of_n`, `beam_search`, and `dvts` (see diagram). For this example, we'll use `best_of_n`, but you can easily switch to any of the other strategies based on your needs. Each strategy is controlled through a few configuration parameters. You can check the full list [here](https://github.com/huggingface/search-and-learn/blob/main/src/sal/config.py).

3. **Call the Pipeline**: With the question and search strategy in place, we’ll call the inference pipeline, processing the inputs through both the LLM and PRM to generate the final answer.

![](https://huggingface.co/datasets/HuggingFaceH4/blogpost-images/resolve/main/search-strategies.png)

The first step is to clearly define the question that the system will answer. This ensures that we have a precise task for the model to tackle.

In [None]:
question_text = 'Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$'
input_batch = {"problem": [question_text]}

Next, we define the configuration, including parameters like the number of candidate answers `(N)`, and choose the search strategy that will be used. The search strategy dictates how we explore the potential answers. In this case, we'll use `best_of_n`.

With the question and configuration in place, we use the selected search strategy to generate multiple candidate answers. These candidates are evaluated based on their relevance and quality and the final answer is returned.


In [None]:
from sal.config import Config
from sal.search import beam_search, best_of_n, dvts

config = Config()
config.n=32 # Number of answers to generate during the search

search_result = best_of_n(x=input_batch, config=config, llm=llm, prm=prm)
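Conceptually, best-of-n boils down to scoring every candidate and keeping the highest-scoring one. The following is a minimal sketch of that selection step on made-up scores, not the `sal` implementation: the helper names `aggregate` and `pick_best`, the aggregation strategies, and the numbers are all assumptions chosen for illustration.

```python
from typing import Sequence


def aggregate(step_scores: Sequence[float], strategy: str = "last") -> float:
    """Collapse the per-step verifier scores of one candidate into a single score."""
    if strategy == "last":
        return step_scores[-1]
    if strategy == "min":
        return min(step_scores)
    if strategy == "prod":
        result = 1.0
        for value in step_scores:
            result *= value
        return result
    raise ValueError(f"Unknown aggregation strategy: {strategy!r}")


def pick_best(candidates: Sequence[str],
              scores: Sequence[Sequence[float]],
              strategy: str = "last") -> str:
    """Return the candidate whose aggregated score is highest."""
    aggregated = [aggregate(s, strategy) for s in scores]
    best_index = max(range(len(candidates)), key=aggregated.__getitem__)
    return candidates[best_index]


# Hypothetical candidates and hypothetical per-step scores
candidates = ["(3, 0)", "(3, pi/2)", "(0, 3)"]
step_scores = [[0.9, 0.2], [0.7, 0.95], [0.4, 0.1]]

best = pick_best(candidates, step_scores)
print(best)
```

With larger `n`, the pool of candidates grows, so the chance that at least one of them scores well increases, at the price of generating and scoring more text.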

### 2.2 Display the Final Result

Once the pipeline has processed the question through the LLM and PRM, we can display the final result. This result will be the model's output after considering the intermediate answers and scoring them using the PRM.

Here's how to display the final answer:

In [None]:
search_result['pred'][0]

The model’s output might include special tokens, such as `<|start_header_id|>` or `<|end_header_id|>`. To make the answer more readable, we can safely remove them before displaying it to the end user.

In [None]:
formatted_output = search_result['pred'][0].replace("<|start_header_id|>assistant<|end_header_id|>\n\n", "").strip()
formatted_output
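The `replace` call above removes one exact header string. If other role headers or end-of-turn markers show up in the output, a slightly more general cleanup may be handy. This is a sketch under the assumption that the output follows the Llama 3 chat format; `strip_chat_tokens` is a hypothetical helper, not part of `sal`.

```python
import re

# Matches a whole role header such as <|start_header_id|>assistant<|end_header_id|>
HEADER_PATTERN = re.compile(r"<\|start_header_id\|>.*?<\|end_header_id\|>\s*", re.DOTALL)
# End-of-turn markers that may trail the answer
TRAILING_TOKENS = ("<|eot_id|>", "<|end_of_text|>")


def strip_chat_tokens(text: str) -> str:
    """Remove role headers and end-of-turn markers, then trim whitespace."""
    cleaned = HEADER_PATTERN.sub("", text)
    for token in TRAILING_TOKENS:
        cleaned = cleaned.replace(token, "")
    return cleaned.strip()


raw = "<|start_header_id|>assistant<|end_header_id|>\n\nThe answer is $(3, \\frac{\\pi}{2})$.<|eot_id|>"
print(strip_chat_tokens(raw))
```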

After removing any special tokens, we can display the final answer to the user. Since the answer is written in Markdown, we can render it as such.

In [None]:
from IPython.display import display, Markdown

display(Markdown(formatted_output))

## 3. Assembling It All! 🧑‍🏭️ _(copied from HF)_

Now, let's create a method that encapsulates the entire pipeline. This will allow us to easily reuse the process in future applications, making it efficient and modular.

By combining the LLM, PRM, search strategy, and result display, we simplify the workflow and make it reusable for other tasks or questions. Additionally, we’ll track the time spent and the number of tokens generated by each method so that we can **understand the practical implications** of using each strategy and configuration.

Here’s how we can structure the method:

In [None]:
import time

def generate_with_search_and_learn(question, config, llm, prm, method='best_of_n'):
    """
    Generate an answer for a given question using the search-and-learn pipeline.

    Args:
    - question (str): The input question to generate an answer for.
    - config (Config): Configuration object containing parameters for search strategy.
    - llm (LLM): Pretrained large language model used for generating answers.
    - prm (RLHFFlow): Process reward model used for evaluating answers.
    - method (str): Search strategy to use. Options are 'best_of_n', 'beam_search', 'dvts'. Default is 'best_of_n'.

    Returns:
    - tuple: (formatted_output, elapsed_time, total_tokens), i.e. the cleaned answer (str),
      the wall-clock time in seconds (float), and the token count over all completions (int).

    Raises:
    - ValueError: If `method` is not one of the supported strategies.
    """
    batch = {"problem": [question]}

    start_time = time.time()
    if method == 'best_of_n':
        result = best_of_n(x=batch, config=config, llm=llm, prm=prm)
    elif method == 'beam_search':
        result = beam_search(examples=batch, config=config, llm=llm, prm=prm)
    elif method == 'dvts':
        result = dvts(examples=batch, config=config, llm=llm, prm=prm)
    else:
        raise ValueError(f"Unknown search method: {method!r}")

    elapsed_time = time.time() - start_time
    print(f"\nFinished in {elapsed_time:.2f} seconds\n")

    # Count the tokens of every completion produced during the search
    tokenizer = llm.get_tokenizer()
    total_tokens = 0
    for completion in result['completions']:
        for comp in completion:
            output_tokens = tokenizer.encode(comp)
            total_tokens += len(output_tokens)

    print(f"Total tokens in all completions: {total_tokens}")

    # Strip the assistant header so the answer renders cleanly
    formatted_output = result['pred'][0].replace("<|start_header_id|>assistant<|end_header_id|>\n\n", "").strip()
    return formatted_output, elapsed_time, total_tokens
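A note on the token count: Hugging Face tokenizers add special tokens by default when calling `encode`, so each completion may contribute an extra begin-of-text token to the total. The sketch below separates the counting from the tokenizer so that this choice is explicit. `count_completion_tokens` and `toy_encode` are hypothetical names, and the whitespace tokenizer is only a stand-in so the example runs without downloading a model.

```python
from typing import Callable, Iterable, List


def count_completion_tokens(completions: Iterable[Iterable[str]],
                            encode: Callable[[str], List[int]]) -> int:
    """Sum the token counts of all completions, using the given encode function."""
    return sum(len(encode(text)) for group in completions for text in group)


def toy_encode(text: str) -> List[int]:
    """Stand-in tokenizer for illustration: one 'token' per whitespace-separated word."""
    return list(range(len(text.split())))


# Same nesting as result['completions']: one list of completions per problem
completions = [["step one", "step one two three"], ["done"]]
total = count_completion_tokens(completions, toy_encode)
print(total)
```

With a real tokenizer, one could pass `lambda t: tokenizer.encode(t, add_special_tokens=False)` as `encode`, assuming the object returned by `llm.get_tokenizer()` is a Hugging Face tokenizer that accepts that argument.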

### ⏳  3.1 Comparing Thinking Time for Each Strategy

Let’s compare the **thinking time** of three methods: `best_of_n`, `beam_search`, and `dvts`. Each method is evaluated using the same number of answers during the search process, measuring the time spent thinking in seconds and the number of generated tokens.

In the results below, `best_of_n` shows the least thinking time, while `beam_search` takes the most. However, `best_of_n` also generates the most tokens, which is consistent with its simpler search strategy: every candidate is generated in full, with no pruning along the way.

| **Method**      | **Number of Answers During Search** | **Thinking Time (Seconds)** | **Generated Tokens** |
|------------------|-------------------------------------|-----------------------------|-----------------------|
| **best_of_n**    | 8                                   | 3.54                        | 3087                  |
| **beam_search**  | 8                                   | 10.06                       | 2049                  |
| **dvts**         | 8                                   | 8.46                        | 2544                  |

This comparison illustrates the trade-offs between the strategies, balancing time spent thinking and the complexity of the search process.
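One way to read the table is as throughput, that is, generated tokens divided by thinking time. The short calculation below uses only the numbers from the table above; keep in mind that they come from a single run, so the ratios are indicative rather than a benchmark.

```python
# Thinking time and token counts copied from the comparison table
results = {
    "best_of_n": {"seconds": 3.54, "tokens": 3087},
    "beam_search": {"seconds": 10.06, "tokens": 2049},
    "dvts": {"seconds": 8.46, "tokens": 2544},
}


def tokens_per_second(tokens: int, seconds: float) -> float:
    """Throughput of one run in generated tokens per second."""
    return tokens / seconds


throughput = {
    method: tokens_per_second(run["tokens"], run["seconds"])
    for method, run in results.items()
}

# Print the methods from highest to lowest throughput
for method, value in sorted(throughput.items(), key=lambda item: item[1], reverse=True):
    print(f"{method:12s} {value:7.1f} tokens/s")
```

A plausible reading is that `best_of_n` generates all candidates in one batched pass, while the tree-based strategies alternate between generation and scoring rounds.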


#### 1. **Best of n**

We’ll begin by using the `best_of_n` strategy. Here’s how to track the thinking time for this method:

In [None]:
question = 'Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$'

config.n=8

formatted_output, elapsed_time, total_tokens = generate_with_search_and_learn(question=question, config=config, llm=llm, prm=prm, method='best_of_n')

In [None]:
display(Markdown(formatted_output))

#### 2. **Beam Search**

Now, let's try using the `beam_search` strategy.

In [None]:
config.n=8
# beam search specific
config.sort_completed=True
config.filter_duplicates=True

formatted_output, elapsed_time, total_tokens = generate_with_search_and_learn(question=question, config=config, llm=llm, prm=prm, method='beam_search')

In [None]:
display(Markdown(formatted_output))
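To build intuition for what beam search does with the budget `n`, here is a toy version on tuples of digits instead of text. It is a sketch only: the real implementation generates reasoning steps with the LLM, scores them with the PRM, and handles completed and duplicate beams. `toy_beam_search`, `expand`, and `score` are hypothetical stand-ins, and `beam_width=4` is an assumed value.

```python
from typing import Callable, List, Tuple

Path = Tuple[int, ...]


def toy_beam_search(expand: Callable[[Path], List[Path]],
                    score: Callable[[Path], float],
                    n: int, beam_width: int, depth: int) -> Path:
    """Keep the n // beam_width best partial solutions after every step."""
    keep = n // beam_width           # beams that survive each round
    beams: List[Path] = [()]         # start from the empty solution
    for _ in range(depth):
        # Every surviving beam proposes its next steps
        candidates = [child for beam in beams for child in expand(beam)]
        # Rank all proposals by the (stand-in) verifier score, best first
        candidates.sort(key=score, reverse=True)
        beams = candidates[:keep]
    return beams[0]


def expand(path: Path) -> List[Path]:
    """Stand-in for the LLM: propose four possible next steps."""
    return [path + (digit,) for digit in range(4)]


def score(path: Path) -> float:
    """Stand-in for the PRM: count positions that agree with a hidden target."""
    target = (3, 1, 2)
    return sum(1.0 for a, b in zip(path, target) if a == b)


best_path = toy_beam_search(expand, score, n=8, beam_width=4, depth=3)
print(best_path)
```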

#### 3. **Diverse Verifier Tree Search (DVTS)**

Finally, let's try the `dvts` strategy.

In [None]:
config.n=8
# dvts specific
config.n_beams = config.n // config.beam_width

formatted_output, elapsed_time, total_tokens = generate_with_search_and_learn(question=question, config=config, llm=llm, prm=prm, method='dvts')

In [None]:
display(Markdown(formatted_output))
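The line `config.n_beams = config.n // config.beam_width` splits the budget `n` into independent subtrees, each of which expands `beam_width` candidates per step. The small helper below spells out that arithmetic. `dvts_layout` is a hypothetical name, the divisibility check is a design choice of this sketch rather than something the library is known to enforce, and `beam_width=4` is an assumed value.

```python
def dvts_layout(n: int, beam_width: int) -> dict:
    """Describe how a budget of n generations is split into independent subtrees."""
    if n % beam_width != 0:
        raise ValueError("n should be a multiple of beam_width")
    n_beams = n // beam_width
    return {
        "n_beams": n_beams,                          # independent subtrees
        "candidates_per_step": n_beams * beam_width, # total expansions per step
    }


print(dvts_layout(n=8, beam_width=4))
```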

### 🙋 3.2 Testing the System with a Simple Question

In this final example, we’ll test the system using a straightforward question to observe how it performs in simpler cases. This allows us to verify that the system works as expected even for basic queries.

Let's try the following question:

In [None]:
question = 'What\'s the capital of Spain?'

config.n=32

formatted_output, elapsed_time, total_tokens = generate_with_search_and_learn(question=question, config=config, llm=llm, prm=prm, method='best_of_n')

In [None]:
display(Markdown(formatted_output))

Even though we set a larger number of candidate answers (`N`), the time spent thinking remains relatively small (1.03 seconds and 544 generated tokens). This demonstrates the system’s ability to efficiently handle easier problems, spending less time on them, while leveraging its enhanced capabilities for more complex questions.

🏆 **We now have a fully operational pipeline** that leverages test-time compute, enabling the system to "think longer" for more complicated queries, while also maintaining fast response times for straightforward questions.

This approach ensures the system can scale its thinking time based on the task's complexity, offering an efficient and responsive solution for both simple and challenging problems.


## 4. Benchmarking _(extension by NVIDIA)_

### 4.1 Defining what we are evaluating

- Model to be analysed

In [None]:
model_path = "meta-llama/Llama-3.1-8B-Instruct"
prm_path = "RLHFlow/Llama3.1-8B-PRM-Deepseek-Data"

- Methods to analyse

In [None]:
# Define methods
methods = ["Best-of-n", "Beam search", "Diverse verifier tree search"]

In [15]:
# Which n values are being tested

n_values = [2**i for i in range(2, 9)]
print(n_values)

[4, 8, 16, 32, 64, 128, 256]


- Dataset

In [18]:
question = 'Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$'

In [13]:
dataset_name = "allenai/math_qa"
split = "test"
samples = 10

#### Calling the model using vllm

In [None]:
import torch
from vllm import LLM
from sal.models.reward_models import RLHFFlow
from sal.search import beam_search, best_of_n, dvts
import sys
import os

project_src = "src/"

llm = LLM(
    model=model_path,
    gpu_memory_utilization=0.5,  # Utilize 50% of GPU memory
    enable_prefix_caching=True,  # Optimize repeated prefix computations
    seed=42,                     # Set seed for reproducibility
)

prm = RLHFFlow(prm_path)

### 4.2 Generating all results

In [None]:
# sys.path entries must be directories, so add the folder that contains prep.py
sys.path.append(os.path.abspath("/hack-search-and-learn/evaluation/hack-eval"))
from prep import preparing_input_dataset, preparing_output_dataset

In [None]:
# Getting questions from the dataset (all_entries)
all_entries, all_options = preparing_input_dataset(dataset_name, split, samples)
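The `prep.py` helpers are not shown in this notebook. To make the expected data shape concrete, here is a hypothetical stand-in that turns multiple-choice rows into entries with a `"prompt"` key, which is the key the benchmarking loop reads. The field names `Problem` and `options` are an assumption about the `allenai/math_qa` schema, and the prompt wording is invented for illustration; the actual `preparing_input_dataset` may differ.

```python
from typing import Dict, List, Tuple


def build_prompt(problem: str, options: str) -> str:
    """Combine the problem statement and its options into one prompt string."""
    return (
        f"{problem.strip()}\n"
        f"Options: {options.strip()}\n"
        "Answer with the letter of the correct option."
    )


def prepare_entries(rows: List[Dict[str, str]], samples: int) -> Tuple[List[Dict[str, str]], List[str]]:
    """Return the first `samples` rows as prompt entries plus their option strings."""
    entries, all_options = [], []
    for row in rows[:samples]:
        entries.append({"prompt": build_prompt(row["Problem"], row["options"])})
        all_options.append(row["options"])
    return entries, all_options


# Made-up rows in the assumed schema
rows = [
    {"Problem": "what is 2 + 3 ?", "options": "a ) 4 , b ) 5 , c ) 6", "correct": "b"},
    {"Problem": "what is 10 / 2 ?", "options": "a ) 5 , b ) 2 , c ) 20", "correct": "a"},
]

entries, options = prepare_entries(rows, samples=1)
print(entries[0]["prompt"])
```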

In [None]:
import json
import csv
# Add it to sys.path
sys.path.append(project_src)
from sal.config import Config

In [None]:
all_outputs = {}

# Map the display names in `methods` to the strategy keys expected by
# generate_with_search_and_learn
method_keys = {
    "Best-of-n": "best_of_n",
    "Beam search": "beam_search",
    "Diverse verifier tree search": "dvts",
}

# Define output filename for tokens and time
csv_filename = "search_methods_results.csv"

# Define headers (each method has its own Time (s) and Total tokens columns)
headers = ["Number of generations"]

for method in methods:
    headers.append(f"{method} Time (s)")
    headers.append(f"{method} Total tokens")

with open(csv_filename, mode="w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(headers)  # Write header

    for i in n_values:
        row = [i]

        for method in methods:
            # Fresh config per method so strategy-specific settings do not leak
            config = Config()
            config.n = i

            if method == "Beam search":
                config.sort_completed = True
                config.filter_duplicates = True
            elif method == "Diverse verifier tree search":
                config.n_beams = config.n // config.beam_width

            method_output = {}
            for idx, example in enumerate(all_entries):
                prompt = example["prompt"]  # Input prompt from the dataset
                formatted_output, elapsed_time, token_number = generate_with_search_and_learn(
                    question=prompt, config=config, llm=llm, prm=prm, method=method_keys[method]
                )

                # Only the first example's time and token count are logged to the CSV
                if idx == 0:
                    row.append(elapsed_time)
                    row.append(token_number)

                method_output[prompt] = formatted_output

            # Key format: "<method>_<n>", e.g. "Best-of-n_4"
            all_outputs[f"{method}_{i}"] = method_output

        writer.writerow(row)

    print(f"CSV file '{csv_filename}' has been created successfully.")
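The loop above writes the time and token count of the first example only, which can be noisy. If per-example measurements are collected instead, averaging them gives a steadier number per configuration. The following is a minimal sketch of that aggregation; `summarize_runs` is a hypothetical helper and the measurements are made up.

```python
from statistics import mean
from typing import Dict, List, Tuple


def summarize_runs(records: List[Tuple[float, int]]) -> Dict[str, float]:
    """Average (elapsed_seconds, total_tokens) pairs collected over several examples."""
    if not records:
        raise ValueError("records must not be empty")
    times, tokens = zip(*records)
    return {
        "mean_time_s": mean(times),
        "mean_tokens": mean(tokens),
        "examples": len(records),
    }


# Made-up measurements for three examples
records = [(3.0, 3000), (4.0, 3200), (5.0, 2800)]
summary = summarize_runs(records)
print(summary)
```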

### 4.3 Computing Accuracy

In [None]:
# This generates a {method}_outputs.json file that can be ingested by NeMo Evaluator

# Choose which run to evaluate first
method_to_evaluate = "Best-of-n"
n_to_evaluate = 4

results_to_evaluate = all_outputs[f"{method_to_evaluate}_{n_to_evaluate}"]

preparing_output_dataset(all_entries, all_options, results_to_evaluate, dataset_name, split, method_to_evaluate)

**<< Offline connection to NEMO Evaluator >>**
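Before handing the outputs to the evaluator, a quick local sanity check can catch obvious formatting problems. The sketch below assumes answers refer to options by letter in the `a )` style used by multiple-choice datasets such as `math_qa`. It is not a replacement for the evaluator, and `extract_choice` and `accuracy` are hypothetical helpers with invented sample data.

```python
import re
from typing import List, Optional

# An option letter a-e followed by a closing parenthesis, e.g. "b )" or "c)"
CHOICE_PATTERN = re.compile(r"\b([a-e])\s*\)", re.IGNORECASE)


def extract_choice(answer_text: str) -> Optional[str]:
    """Return the last option letter mentioned in the text, or None if there is none."""
    matches = CHOICE_PATTERN.findall(answer_text)
    return matches[-1].lower() if matches else None


def accuracy(predictions: List[str], gold: List[str]) -> float:
    """Fraction of predictions whose extracted letter equals the gold letter."""
    if not gold or len(predictions) != len(gold):
        raise ValueError("predictions and gold must be non-empty and equally long")
    correct = sum(extract_choice(p) == g for p, g in zip(predictions, gold))
    return correct / len(gold)


# Made-up model outputs and gold labels
predictions = [
    "The total is 5, so the answer is b ) 5.",
    "I think it is c) 12",
    "No idea.",
]
gold = ["b", "a", "d"]
print(accuracy(predictions, gold))
```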

### 4.4 Evaluation: Token / Time Versus Generation

#### Plotting Results

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np


def plot_methods(csv_file, analysis_type="Time (s)"):  # or "Total tokens"

    # Read from the path passed in, not from a global variable
    df = pd.read_csv(csv_file)

    col_names = df.columns

    # Extract relevant columns
    x = df[col_names[0]]  # Should be "Number of generations"

    # Plot the data
    plt.figure(figsize=(10, 6))

    if "Time" in analysis_type:
        start = 1
        markers = ['o', 'o', 'o']
    else: # Tokens
        start = 2
        markers = ['o', 's', '^']

    j = 0
    for i in range(start, len(col_names), 2):
        plt.plot(x, df[col_names[i]], marker=markers[j], label=col_names[i].split(analysis_type)[0])
        j += 1

    # Labels and title
    plt.xlabel(col_names[0])
    plt.ylabel(analysis_type)

    plt.suptitle(f"{analysis_type} in All Completions vs. Number of Generations", fontsize=14)
    plt.title(f"LLM: {model_path}, PRM: {prm_path}", fontsize=10, color='gray')
    plt.legend()

    # Set log scale
    plt.xscale("log")
    plt.yscale("log")

    # Ensure x-axis ticks show as integers
    plt.xticks(x, labels=[str(int(val)) for val in x])  

    # Grid and formatting
    plt.grid(True, which="both", linestyle="--", linewidth=0.5)

    # Show the plot
    plt.show()

In [None]:
# Load CSV file
csv_filename = "search_methods_results.csv"  # Update this with the actual CSV file path

plot_methods(csv_filename, "Time (s)")

In [None]:
plot_methods(csv_filename, "Total tokens")
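On a log-log plot, a straight line means a power-law relationship, and its slope is the exponent: a slope near 1 indicates that cost grows roughly linearly with the number of generations. The sketch below estimates that slope with an ordinary least-squares fit using only the standard library. `loglog_slope` is a hypothetical helper, and the data is synthetic, not a measurement from this notebook.

```python
import math
from typing import Sequence


def loglog_slope(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Least-squares slope of log(y) against log(x)."""
    log_x = [math.log(x) for x in xs]
    log_y = [math.log(y) for y in ys]
    mean_x = sum(log_x) / len(log_x)
    mean_y = sum(log_y) / len(log_y)
    numerator = sum((a - mean_x) * (b - mean_y) for a, b in zip(log_x, log_y))
    denominator = sum((a - mean_x) ** 2 for a in log_x)
    return numerator / denominator


example_n = [4, 8, 16, 32, 64, 128, 256]
# Synthetic token counts that grow linearly with the number of generations
synthetic_tokens = [350 * n for n in example_n]

print(round(loglog_slope(example_n, synthetic_tokens), 3))
```

Applied to a column of the results CSV, a slope clearly below 1 would suggest that a strategy prunes or reuses work as `n` grows, while a slope above 1 would point to overhead that grows faster than the budget.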

Compare Llama 3.2 1b, Llama 3.2 3b, and Llama 3.1 8b on the basis of DVTS time

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Load CSV files
llama1b_csv = "llama_3_2_1b.csv"  
llama3b_csv = "llama_3_2_3b.csv"  
llama8b_csv = "llama_3_1_8b.csv"

df_1b = pd.read_csv(llama1b_csv)
df_3b = pd.read_csv(llama3b_csv)
df_8b = pd.read_csv(llama8b_csv)

# Extract relevant data
x_1b = df_1b["Number of generations"]
time_1b = df_1b["Diverse verifier tree search Time (s)"]

x_3b = df_3b["Number of generations"]
time_3b = df_3b["Diverse verifier tree search Time (s)"]

x_8b = df_8b["Number of generations"]
time_8b = df_8b["Diverse verifier tree search Time (s)"]

# Plot comparison
plt.figure(figsize=(10, 6))
plt.plot(x_1b, time_1b, marker='o', label="Llama 3.2 1b", linestyle='-')
plt.plot(x_3b, time_3b, marker='s', label="Llama 3.2 3b", linestyle='--')
plt.plot(x_8b, time_8b, marker='x', label="Llama 3.1 8b", linestyle=':')

# Titles and Labels
plt.suptitle("Comparison of Diverse Verifier Tree Search Time", fontsize=14)
plt.title("Llama 3.2 1b vs. Llama 3.2 3b vs Llama 3.1 8b", fontsize=10, color='gray')
plt.xlabel("Number of Generations")
plt.ylabel("Diverse Verifier Tree Search Time (s)")
plt.legend()

# Ensure integer values on x-axis
plt.xticks(x_1b, labels=[str(int(val)) for val in x_1b])  

# Grid for better readability
plt.grid(True, linestyle="--", linewidth=0.5)

# Show the plot
plt.show()


Compare Llama 3.2 1b, Llama 3.2 3b, and Llama 3.1 8b on the basis of DVTS token count

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Load CSV files
llama1b_csv = "llama_3_2_1b.csv"  
llama3b_csv = "llama_3_2_3b.csv"  
llama8b_csv = "llama_3_1_8b.csv"

df_1b = pd.read_csv(llama1b_csv)
df_3b = pd.read_csv(llama3b_csv)
df_8b = pd.read_csv(llama8b_csv)

# Extract relevant data
x_1b = df_1b["Number of generations"]
tokens_1b = df_1b["Diverse verifier tree search Total tokens"].astype(int)

x_3b = df_3b["Number of generations"]
tokens_3b = df_3b["Diverse verifier tree search Total tokens"].astype(int)

x_8b = df_8b["Number of generations"]
tokens_8b = df_8b["Diverse verifier tree search Total tokens"].astype(int)

# Plot comparison
plt.figure(figsize=(10, 6))
plt.plot(x_1b, tokens_1b, marker='o', label="Llama 3.2 1b", linestyle='-')
plt.plot(x_3b, tokens_3b, marker='s', label="Llama 3.2 3b", linestyle='--')
plt.plot(x_8b, tokens_8b, marker='x', label="Llama 3.1 8b", linestyle=':')

# Titles and Labels
plt.suptitle("Comparison of Total Tokens in Diverse Verifier Tree Search", fontsize=14)
plt.title("Llama 3.2 1b vs. Llama 3.2 3b vs Llama 3.1 8b", fontsize=10, color='gray')
plt.xlabel("Number of Generations")
plt.ylabel("Total Tokens")
plt.legend()

# Ensure integer values on x-axis
plt.xticks(x_1b, labels=[str(int(val)) for val in x_1b])  

# Ensure integer values on y-axis
y_ticks = sorted(set(tokens_1b.tolist() + tokens_3b.tolist() + tokens_8b.tolist()))
plt.yticks(y_ticks, labels=[str(val) for val in y_ticks])

# Grid for better readability
plt.grid(True, linestyle="--", linewidth=0.5)

# Show the plot
plt.show()
