# Chapter 4: Improving Reasoning with Inference-Time Scaling

## Learning Objectives
- Understand inference-time compute scaling as a way to improve accuracy without retraining
- Implement chain-of-thought prompting to encourage step-by-step reasoning
- Build self-consistency sampling with majority voting
- Create flexible text generation with swappable sampling strategies (temperature, top-p)

## Core Techniques (notes from the chapter)

- Method 1: Chain-of-thought prompting, where the prompt is extended to ask the model to explain its reasoning step by step. This simple technique can substantially improve accuracy.
- Method 2: Parallel sampling via self-consistency, where the model generates multiple responses and the most frequent final answer is selected.
- Method 3: Iterative self-refinement, where the model reviews and improves its own reasoning and answers across multiple steps. (This topic is implemented and covered in more detail in the next chapter.)
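At its core, Method 2 is a majority vote over the final answers extracted from several sampled responses. The following is a minimal sketch that assumes the answers have already been extracted as strings. `majority_vote` is a hypothetical helper. It compares answers by exact string match, whereas a real pipeline would normalize them first (for example, `1/2` versus `\frac{1}{2}`).

```python
from collections import Counter


def majority_vote(answers):
    """Return the most frequent answer among sampled responses, or None.

    Missing (None) or empty answers are ignored. Ties go to the answer that
    appeared first, because Counter.most_common keeps first-seen order for
    equal counts.
    """
    # Drop missing answers and strip surrounding whitespace
    valid = [a.strip() for a in answers if a and a.strip()]
    if not valid:
        return None
    return Counter(valid).most_common(1)[0][0]


# Hypothetical final answers extracted from five sampled responses
samples = ["83", "20", "83", None, "83 "]
print(majority_vote(samples))  # prints 83
```

Ignoring missing answers means a response without an extractable answer neither helps nor hurts the vote.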

## Note

- We evaluate reasoning techniques using a fixed symbolic grading pipeline adapted from Chapter 3. For each problem, the model generates a free-form response, from which a final answer is extracted (preferring boxed expressions). The extracted answer is normalized and compared against the dataset’s ground-truth answer using SymPy-based equivalence checking. This evaluation verifies final answer correctness but does not validate intermediate reasoning steps.
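Extracting a boxed answer requires brace matching, because answers such as `\frac{\pi}{2}` contain nested braces. The sketch below is a minimal standalone version. `extract_last_boxed` is a hypothetical helper, not the chapter's `extract_final_candidate`, and it omits the normalization and SymPy equivalence steps entirely.

```python
def extract_last_boxed(text):
    """Return the contents of the last \\boxed{...} in text, or None.

    Braces are counted so that nested groups such as
    \\boxed{\\frac{\\pi}{2}} are returned whole.
    """
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None  # no boxed answer at all

    content_start = start + len(marker)
    depth = 1  # already inside the opening brace of \boxed{
    for i in range(content_start, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:i]
    return None  # unbalanced braces: the boxed expression never closes


print(extract_last_boxed(r"First \boxed{20}, then \boxed{(3,\frac{\pi}{2})}"))
# prints (3,\frac{\pi}{2})
```

Taking the last match follows the common convention that the final boxed expression in a response is the answer.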

```python
import torch
from reasoning_from_scratch.ch02 import get_device
from reasoning_from_scratch.ch03 import load_model_and_tokenizer

device = get_device()
# Force CPU execution regardless of the detected device;
# delete this line to use the device returned by get_device()
device = torch.device("cpu")

model, tokenizer = load_model_and_tokenizer(
    which_model="base",
    device=device,
    use_compile=False
)
```

```
Using Apple Silicon GPU (MPS)
✓ qwen3/qwen3-0.6B-base.pth already up-to-date
```


```python
from reasoning_from_scratch.ch03 import render_prompt

raw_prompt = (
    "Half the value of $3x-9$ is $x+37$. "
    "What is the value of $x$?"
)
prompt = render_prompt(raw_prompt)
print(prompt)
```

```
You are a helpful math assistant.
Answer the question and write the final result on a new line as:
\boxed{ANSWER}

Question:
Half the value of $3x-9$ is $x+37$. What is the value of $x$?

Answer:
```


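```python
import torch
from reasoning_from_scratch.ch02_ex import generate_text_basic_stream_cache


def generate_text_stream_concat_flex(
    model, tokenizer, prompt, device, max_new_tokens,
    verbose=False,
    generate_func=None,
    **generate_kwargs
):
    """Generate a response token by token and return it as one string.

    generate_func can be any streaming generator with the same call signature
    as generate_text_basic_stream_cache; extra keyword arguments are
    forwarded to it unchanged.
    """
    if generate_func is None:
        generate_func = generate_text_basic_stream_cache

    # Encode the prompt and add a batch dimension: shape (1, num_tokens)
    input_ids = torch.tensor(
        tokenizer.encode(prompt), device=device
    ).unsqueeze(0)

    generated_ids = []
    for token in generate_func(
        model=model,
        token_ids=input_ids,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
        **generate_kwargs,
    ):
        # Drop the batch dimension and store the new token ID
        next_token_id = token.squeeze(0)
        generated_ids.append(next_token_id.item())

        # Optionally stream each token to stdout as it is generated
        if verbose:
            print(
                tokenizer.decode(next_token_id.tolist()),
                end="",
                flush=True
            )
    # Decode all generated tokens at once to build the final response
    return tokenizer.decode(generated_ids)
```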
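The `generate_func` and `**generate_kwargs` parameters are what make the sampling strategy swappable. Any streaming generator with the same call signature can be passed in, along with its own settings. The per-step choice that such a sampling generator might make can be sketched in plain Python. `sample_next_token` is a hypothetical helper that works on a list of logits rather than a PyTorch tensor, and it is not part of `reasoning_from_scratch`.

```python
import math
import random


def sample_next_token(logits, temperature=1.0, top_p=1.0, rng=None):
    """Pick a token index from raw logits with temperature and top-p sampling.

    A temperature <= 0 falls back to greedy decoding (argmax).
    """
    if temperature <= 0:
        # Greedy decoding: index of the largest logit
        return max(range(len(logits)), key=lambda i: logits[i])
    if rng is None:
        rng = random.Random()

    # Temperature scaling, then a numerically stable softmax
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-p (nucleus) filtering: keep the smallest set of most likely
    # tokens whose cumulative probability reaches top_p
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Draw one token; random.choices renormalizes the kept weights itself
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


logits = [2.0, 1.0, 0.1, -1.0]
print(sample_next_token(logits, temperature=0))  # prints 0 (greedy)
print(sample_next_token(logits, temperature=0.7, top_p=0.9, rng=random.Random(123)))
```

Lower temperatures sharpen the distribution toward the most likely token, and smaller `top_p` values cut off the low-probability tail. Self-consistency depends on this kind of stochastic sampling so that repeated runs can follow different reasoning paths.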

```python
response = generate_text_stream_concat_flex(
    model, tokenizer, prompt, device,
    max_new_tokens=2048, verbose=True,
    generate_func=generate_text_basic_stream_cache
)
```

```
 \boxed{20}
```
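The direct answer of 20 is wrong. Solving the equation and substituting both candidate values back into it shows that $x = 83$ is the solution:

$$
\begin{aligned}
\tfrac{1}{2}(3x - 9) = x + 37 \;&\Longrightarrow\; 3x - 9 = 2x + 74 \;\Longrightarrow\; x = 83,\\
x = 83:\quad \tfrac{1}{2}(3 \cdot 83 - 9) &= \tfrac{1}{2}(240) = 120 = 83 + 37 \quad \checkmark,\\
x = 20:\quad \tfrac{1}{2}(3 \cdot 20 - 9) &= \tfrac{1}{2}(51) = 25.5 \neq 57 = 20 + 37.
\end{aligned}
$$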

```python
prompt_cot = prompt + " \n\nExplain step by step."

response_cot = generate_text_stream_concat_flex(
    model, tokenizer, prompt_cot, device,
    max_new_tokens=2048, verbose=True,
)
```

```
 To solve the problem, we need to find the value of \( x \) such that half the value of \( 3x - 9 \) is equal to \( x + 37 \).

### Step 1: Set up the equation
We are given that half the value of \( 3x - 9 \) is equal to \( x + 37 \). This can be written as:
\[
\frac{1}{2}(3x - 9) = x + 37
\]

### Step 2: Eliminate the fraction
To eliminate the fraction, multiply both sides of the equation by 2:
\[
2 \cdot \frac{1}{2}(3x - 9) = 2(x + 37)
\]
Simplifying both sides:
\[
3x - 9 = 2x + 74
\]

### Step 3: Solve for \( x \)
Subtract \( 2x \) from both sides to isolate \( x \):
\[
3x - 2x - 9 = 74
\]
Simplify:
\[
x - 9 = 74
\]
Add 9 to both sides to solve for \( x \):
\[
x = 74 + 9
\]
\[
x = 83
\]

### Final Answer:
\[
\boxed{83}
\]
```
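The evaluation cell below first tries `extract_final_candidate`. If that returns nothing, it falls back to two regular expressions. Pulling those fallbacks into a standalone function makes their behavior easy to check. `fallback_extract` is a hypothetical name, and the first step, `extract_final_candidate`, is omitted here.

```python
import re


def fallback_extract(response):
    """Regex fallbacks mirroring extract_answer (first extraction step omitted)."""
    # 1) The last number that follows an equals sign, e.g. "x = 83"
    matches = re.findall(r"=\s*([+-]?\d+(?:\.\d+)?)", response)
    if matches:
        return matches[-1]
    # 2) An integer at the very end of the response
    matches = re.findall(r"(\d+)\s*$", response.strip())
    if matches:
        return matches[-1]
    return None


print(fallback_extract("x = 74 + 9\nx = 83"))  # prints 83
print(fallback_extract("The answer is 42"))    # prints 42
```

The first rule can pick up an intermediate value, such as the 74 in `x = 74 + 9`. Neither rule recovers non-numeric answers like `\frac{\pi}{2}`, which is why the boxed format requested in the prompt matters.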

```python
import json
import re
import torch
import requests
from pathlib import Path

from reasoning_from_scratch.ch02 import get_device
from reasoning_from_scratch.ch02_ex import generate_text_basic_stream_cache
from reasoning_from_scratch.ch03 import (
    load_model_and_tokenizer,
    extract_final_candidate,
    grade_answer,
)


# ─────────────────────────────────────────────────────────────
# Non-CoT prompt template (mirrors the render_prompt output shown above)
# ─────────────────────────────────────────────────────────────

def render_prompt(prompt):
    template = (
        "You are a helpful math assistant.\n"
        "Answer the question and write the final result on a new line as:\n"
        "\\boxed{ANSWER}\n\n"
        f"Question:\n{prompt}\n\nAnswer:"
    )
    return template


# ─────────────────────────────────────────────────────────────
# Load MATH-500
# ─────────────────────────────────────────────────────────────

def load_math500_test(local_path="math500_test.json", save_copy=True):
    """Load the MATH-500 test set, downloading it if no local copy exists."""
    local_path = Path(local_path)
    url = (
        "https://raw.githubusercontent.com/rasbt/reasoning-from-scratch/"
        "main/ch03/01_main-chapter-code/math500_test.json"
    )

    if local_path.exists():
        with local_path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    else:
        r = requests.get(url, timeout=30)
        r.raise_for_status()
        data = r.json()

        # Cache the download so later runs work offline
        if save_copy:
            with local_path.open("w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)

    return data


# ─────────────────────────────────────────────────────────────
# Setup
# ─────────────────────────────────────────────────────────────

device = get_device()
print(f"Using device: {device}")

model, tokenizer = load_model_and_tokenizer(
    which_model="base",
    device=device,
    use_compile=False
)

math_data = load_math500_test()
print("MATH-500 entries:", len(math_data))


# ─────────────────────────────────────────────────────────────
# Generate (same helper as earlier, repeated so this cell runs standalone)
# ─────────────────────────────────────────────────────────────

def generate_text_stream_concat_flex(
    model, tokenizer, prompt, device, max_new_tokens,
    verbose=False,
    generate_func=None,
    **generate_kwargs
):
    if generate_func is None:
        generate_func = generate_text_basic_stream_cache

    input_ids = torch.tensor(
        tokenizer.encode(prompt), device=device
    ).unsqueeze(0)

    generated_ids = []
    for token in generate_func(
        model=model,
        token_ids=input_ids,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
        **generate_kwargs,
    ):
        next_token_id = token.squeeze(0)
        generated_ids.append(next_token_id.item())

        if verbose:
            print(
                tokenizer.decode(next_token_id.tolist()),
                end="",
                flush=True
            )
    return tokenizer.decode(generated_ids)


# ─────────────────────────────────────────────────────────────
# Answer extraction
# ─────────────────────────────────────────────────────────────

def extract_answer(response):
    """Extract a final answer, trying the Chapter 3 extractor first."""
    # 1) Chapter 3 extractor (prefers boxed answers)
    answer = extract_final_candidate(response)
    if answer:
        return answer

    # 2) Fallback: last number after an equals sign, e.g. "x = 83"
    matches = re.findall(r'=\s*([+-]?\d+(?:\.\d+)?)', response)
    if matches:
        return matches[-1]

    # 3) Fallback: integer at the very end of the response
    matches = re.findall(r'(\d+)\s*$', response.strip())
    if matches:
        return matches[-1]

    return None


# ─────────────────────────────────────────────────────────────
# Prompts
# ─────────────────────────────────────────────────────────────

def make_prompt_non_cot(raw_prompt):
    return render_prompt(raw_prompt)


def make_prompt_cot(raw_prompt):
    return render_prompt(raw_prompt) + "\n\nExplain step by step."


# ─────────────────────────────────────────────────────────────
# Evaluate the first MATH-500 problems (single-shot, both prompts)
# ─────────────────────────────────────────────────────────────

num_examples = 10

print("\n" + "#" * 60)
print(f"MATH-500 (first {num_examples}) single-shot: Non-CoT vs CoT")
print("#" * 60)

non_cot_correct = 0
cot_correct = 0

for i, ex in enumerate(math_data[:num_examples], start=1):
    raw_problem = ex["problem"]
    truth = ex["answer"]

    print(f"\n{'=' * 70}")
    print(f"PROBLEM {i}/{num_examples} | unique_id: {ex.get('unique_id', '')}")
    print(raw_problem)
    print(f"GROUND TRUTH: {truth}")
    print("=" * 70)

    # -------------------------
    # Non-CoT
    # -------------------------
    print("\n--- NON-CoT OUTPUT ---")
    prompt_non_cot = make_prompt_non_cot(raw_problem)
    response_non_cot = generate_text_stream_concat_flex(
        model, tokenizer, prompt_non_cot, device,
        max_new_tokens=2048, verbose=True,
    )

    pred_non_cot = extract_answer(response_non_cot)
    ok_non_cot = grade_answer(pred_non_cot, truth)
    non_cot_correct += int(ok_non_cot)

    print(f"\n\n>>> Extracted: {pred_non_cot}")
    print(f">>> CORRECT: {ok_non_cot}")

    # -------------------------
    # CoT
    # -------------------------
    print("\n--- CoT OUTPUT ---")
    prompt_cot = make_prompt_cot(raw_problem)
    response_cot = generate_text_stream_concat_flex(
        model, tokenizer, prompt_cot, device,
        max_new_tokens=2048, verbose=True,
    )

    pred_cot = extract_answer(response_cot)
    ok_cot = grade_answer(pred_cot, truth)
    cot_correct += int(ok_cot)

    print(f"\n\n>>> Extracted: {pred_cot}")
    print(f">>> CORRECT: {ok_cot}")


print("\n" + "=" * 60)
print(f"FINAL RESULTS (first {num_examples} MATH-500)")
print("=" * 60)
print(f"Non-CoT: {100 * non_cot_correct / num_examples:.1f}% "
      f"({non_cot_correct}/{num_examples})")
print(f"CoT:     {100 * cot_correct / num_examples:.1f}% "
      f"({cot_correct}/{num_examples})")
```


Using Apple Silicon GPU (MPS)
Using device: mps
✓ qwen3/qwen3-0.6B-base.pth already up-to-date
MATH-500 entries: 500

############################################################
MATH-500 (first 10) single-shot: Non-CoT vs CoT
############################################################

PROBLEM 1/10 | unique_id: test/precalculus/807.json
Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$
GROUND TRUTH: \left( 3, \frac{\pi}{2} \right)

--- NON-CoT OUTPUT ---
 \boxed{(3,\frac{\pi}{2})}

>>> Extracted: (3,\frac{\pi}{2})
>>> CORRECT: True

--- CoT OUTPUT ---
 To convert the point \((0, 3)\) from rectangular coordinates to polar coordinates, we need to find the radius \(r\) and the angle \(\theta\). Here's a step-by-step explanation:

### Step 1: Find the radius \(r\)
The radius \(r\) is the distance from the origin to the point. It can be calculated using the distance formula:
\[
