# Chapter 3: Evaluating Reasoning Models

  ## Learning Objectives
  - Extract and parse final answers from LLM text responses reliably
  - Verify answer correctness using symbolic math solvers (calculator-like verification)
  - Build an evaluation pipeline: load model → generate outputs → grade against dataset
  - Implement a verifiable-rewards system (the foundation for reinforcement learning in Chapter 6)

 <img src="figure1.png" alt="Figure 1" width="600">

In [None]:
from pathlib import Path
import torch


from reasoning_from_scratch.qwen3 import (
    download_qwen3_small,
    Qwen3Tokenizer,
    Qwen3Model,
    QWEN_CONFIG_06_B
)


def load_model_and_tokenizer(
    which_model, device, use_compile, local_dir="qwen3"
):
    """Load a Qwen3 0.6B model variant together with its tokenizer.

    Args:
        which_model: Either "base" or "reasoning".
        device: Target device the model is moved to via ``model.to``.
        use_compile: If true, wrap the model with ``torch.compile``.
        local_dir: Directory passed to ``download_qwen3_small`` and from
            which the tokenizer and weight files are read.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``which_model`` is neither "base" nor "reasoning".
    """
    # Tokenizer keyword arguments for each supported variant; the
    # reasoning tokenizer additionally enables the chat template,
    # generation prompt, and thinking options.
    variants = (
        ("base", {}),
        (
            "reasoning",
            {
                "apply_chat_template": True,
                "add_generation_prompt": True,
                "add_thinking": True,
            },
        ),
    )

    for kind, tokenizer_kwargs in variants:
        if which_model == kind:
            break
    else:
        raise ValueError(f"Invalid choice: which_model={which_model}")

    download_qwen3_small(
        kind=kind, tokenizer_only=False, out_dir=local_dir
    )

    weights_dir = Path(local_dir)
    model_path = weights_dir / f"qwen3-0.6B-{kind}.pth"
    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path=weights_dir / f"tokenizer-{kind}.json",
        **tokenizer_kwargs,
    )

    model = Qwen3Model(QWEN_CONFIG_06_B)
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model.to(device)

    if use_compile:
        torch._dynamo.config.allow_unspec_int_on_nn_module = True
        model = torch.compile(model)

    return model, tokenizer

In [14]:
from reasoning_from_scratch.ch02 import (
    get_device
)

# Model variant to load; load_model_and_tokenizer accepts
# "base" or "reasoning"
WHICH_MODEL = "base"
device = get_device()

# If you have compatibility issues, try to
# uncomment the line below and rerun the notebook
# device = torch.device("cpu")

# Load the model without torch.compile (use_compile=False)
model, tokenizer = load_model_and_tokenizer(
    which_model=WHICH_MODEL,
    device=device,
    use_compile=False
)

Using Apple Silicon GPU (MPS)
✓ qwen3/qwen3-0.6B-base.pth already up-to-date


In [15]:
from reasoning_from_scratch.ch02_ex import (
    generate_text_basic_stream_cache
)

# Example math problem, written as raw strings so the LaTeX
# backslashes are kept verbatim
prompt = (
    r"If $a+b=3$ and $ab=\tfrac{13}{6}$, "
    r"what is the value of $a^2+b^2$?"
)

# Similar to chapter 2 exercise solution:
# encode the prompt and add a leading dimension via unsqueeze(0)
input_token_ids_tensor = torch.tensor(
    tokenizer.encode(prompt),
    device=device
    ).unsqueeze(0)

# Stream tokens one at a time: print each decoded piece immediately
# and keep the token for decoding the full response afterwards
all_token_ids = []
for token in generate_text_basic_stream_cache(
    model=model,
    token_ids=input_token_ids_tensor,
    max_new_tokens=2048,
    eos_token_id=tokenizer.eos_token_id
):
    token_id = token.squeeze(0)
    decoded_id = tokenizer.decode(token_id.tolist())
    print(
        decoded_id,
        end="",
        flush=True
    )
    # NOTE(review): this stores the tensor itself, whereas
    # generate_text_stream_concat stores plain ints via .item();
    # consider aligning the two
    all_token_ids.append(token_id)

# Full response text, decoded in one pass from all collected tokens
all_tokens = tokenizer.decode(all_token_ids)

 To find the value of \( a^2 + b^2 \) given that \( a + b = 3 \) and \( ab = \frac{13}{6} \), we can use the following algebraic identity:

\[
a^2 + b^2 = (a + b)^2 - 2ab
\]

**Step 1:** Substitute the given values into the equation.

\[
a^2 + b^2 = (3)^2 - 2 \left( \frac{13}{6} \right)
\]

**Step 2:** Calculate \( (3)^2 \).

\[
(3)^2 = 9
\]

**Step 3:** Calculate \( 2 \times \frac{13}{6} \).

\[
2 \times \frac{13}{6} = \frac{26}{6} = \frac{13}{3}
\]

**Step 4:** Subtract the second result from the first.

\[
a^2 + b^2 = 9 - \frac{13}{3}
\]

**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.

\[
9 = \frac{27}{3}
\]

\[
a^2 + b^2 = \frac{27}{3} - \frac{13}{3} = \frac{14}{3}
\]

**Final Answer:**

\[
\boxed{\dfrac{14}{3}}
\]

In [None]:
from IPython.display import Latex, display
display(Latex(all_tokens))

 To find the value of \( a^2 + b^2 \) given that \( a + b = 3 \) and \( ab = \frac{13}{6} \), we can use the following algebraic identity:

\[
a^2 + b^2 = (a + b)^2 - 2ab
\]

**Step 1:** Substitute the given values into the equation.

\[
a^2 + b^2 = (3)^2 - 2 \left( \frac{13}{6} \right)
\]

**Step 2:** Calculate \( (3)^2 \).

\[
(3)^2 = 9
\]

**Step 3:** Calculate \( 2 \times \frac{13}{6} \).

\[
2 \times \frac{13}{6} = \frac{26}{6} = \frac{13}{3}
\]

**Step 4:** Subtract the second result from the first.

\[
a^2 + b^2 = 9 - \frac{13}{3}
\]

**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.

\[
9 = \frac{27}{3}
\]

\[
a^2 + b^2 = \frac{27}{3} - \frac{13}{3} = \frac{14}{3}
\]

**Final Answer:**

\[
\boxed{\dfrac{14}{3}}
\]

In [20]:
from IPython.display import Math

display(Math(r"\dfrac{14}{3}"))

<IPython.core.display.Math object>

In [25]:
def generate_text_stream_concat(
    model, tokenizer, prompt, device, max_new_tokens,
    verbose=False,
):
    """Generate a response for ``prompt`` and return it as one string.

    Tokens are produced by ``generate_text_basic_stream_cache``; when
    ``verbose`` is true each token is also decoded and printed as soon
    as it arrives. Generation is bounded by ``max_new_tokens`` and the
    tokenizer's ``eos_token_id`` is passed as the stop token.
    """
    prompt_ids = tokenizer.encode(prompt)
    batch = torch.tensor(prompt_ids, device=device).unsqueeze(0)

    token_stream = generate_text_basic_stream_cache(
        model=model,
        token_ids=batch,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
    )

    collected = []
    for step_token in token_stream:
        flat_token = step_token.squeeze(0)
        collected.append(flat_token.item())

        if not verbose:
            continue
        # Echo the newly generated piece without a trailing newline
        piece = tokenizer.decode(flat_token.tolist())
        print(piece, end="", flush=True)

    return tokenizer.decode(collected)


skip_portion = False

if not skip_portion:
    generated_text = generate_text_stream_concat(
        model, tokenizer, prompt, device,
        max_new_tokens=2048,
        verbose=True
    )

 To find the value of \( a^2 + b^2 \) given that \( a + b = 3 \) and \( ab = \frac{13}{6} \), we can use the following algebraic identity:

\[
a^2 + b^2 = (a + b)^2 - 2ab
\]

**Step 1:** Substitute the given values into the equation.

\[
a^2 + b^2 = (3)^2 - 2 \left( \frac{13}{6} \right)
\]

**Step 2:** Calculate \( (3)^2 \).

\[
(3)^2 = 9
\]

**Step 3:** Calculate \( 2 \times \frac{13}{6} \).

\[
2 \times \frac{13}{6} = \frac{26}{6} = \frac{13}{3}
\]

**Step 4:** Subtract the second result from the first.

\[
a^2 + b^2 = 9 - \frac{13}{3}
\]

**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.

\[
9 = \frac{27}{3}
\]

\[
a^2 + b^2 = \frac{27}{3} - \frac{13}{3} = \frac{14}{3}
\]

**Final Answer:**

\[
\boxed{\dfrac{14}{3}}
\]

In [27]:
generated_text = generate_text_stream_concat(
    model, tokenizer, prompt, device,
    max_new_tokens=2048,
    verbose=True  #A
)

 To find the value of \( a^2 + b^2 \) given that \( a + b = 3 \) and \( ab = \frac{13}{6} \), we can use the following algebraic identity:

\[
a^2 + b^2 = (a + b)^2 - 2ab
\]

**Step 1:** Substitute the given values into the equation.

\[
a^2 + b^2 = (3)^2 - 2 \left( \frac{13}{6} \right)
\]

**Step 2:** Calculate \( (3)^2 \).

\[
(3)^2 = 9
\]

**Step 3:** Calculate \( 2 \times \frac{13}{6} \).

\[
2 \times \frac{13}{6} = \frac{26}{6} = \frac{13}{3}
\]

**Step 4:** Subtract the second result from the first.

\[
a^2 + b^2 = 9 - \frac{13}{3}
\]

**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.

\[
9 = \frac{27}{3}
\]

\[
a^2 + b^2 = \frac{27}{3} - \frac{13}{3} = \frac{14}{3}
\]

**Final Answer:**

\[
\boxed{\dfrac{14}{3}}
\]

In [38]:
# ============================================
# IMPORTS
# ============================================
from pathlib import Path
import torch
from IPython.display import display, Markdown

from reasoning_from_scratch.qwen3 import (
    download_qwen3_small,
    Qwen3Tokenizer,
    Qwen3Model,
    QWEN_CONFIG_06_B
)
from reasoning_from_scratch.ch02 import get_device
from reasoning_from_scratch.ch02_ex import generate_text_basic_stream_cache


# ============================================
# LOAD MODEL AND TOKENIZER
# ============================================
# Pick the compute device via the chapter 2 helper
device = get_device()

# Request the base checkpoint and tokenizer files in ./qwen3
download_qwen3_small(kind="base", tokenizer_only=False, out_dir="qwen3")

# Base tokenizer: constructed with the tokenizer file only, no
# chat-template options
tokenizer = Qwen3Tokenizer(
    tokenizer_file_path=Path("qwen3") / "tokenizer-base.json"
)

# Build the 0.6B architecture, load the saved weights, then move the
# model to the selected device
model = Qwen3Model(QWEN_CONFIG_06_B)
model.load_state_dict(torch.load(Path("qwen3") / "qwen3-0.6B-base.pth"))
model.to(device)


# ============================================
# DEFINE THE FUNCTION 
# ============================================
def generate_text_stream_concat(
    model, tokenizer, prompt, device, max_new_tokens,
    verbose=False,
):
    """Generate a response for ``prompt`` and return it as one string.

    Tokens are produced by ``generate_text_basic_stream_cache``; when
    ``verbose`` is true each token is also decoded and printed as soon
    as it arrives. Generation is bounded by ``max_new_tokens`` and the
    tokenizer's ``eos_token_id`` is passed as the stop token.
    """
    prompt_ids = tokenizer.encode(prompt)
    batch = torch.tensor(prompt_ids, device=device).unsqueeze(0)

    token_stream = generate_text_basic_stream_cache(
        model=model,
        token_ids=batch,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
    )

    collected = []
    for step_token in token_stream:
        flat_token = step_token.squeeze(0)
        collected.append(flat_token.item())

        if not verbose:
            continue
        # Echo the newly generated piece without a trailing newline
        piece = tokenizer.decode(flat_token.tolist())
        print(piece, end="", flush=True)

    return tokenizer.decode(collected)


# ============================================
# USE THE FUNCTION
# ============================================
# Raw string so the LaTeX backslashes are kept verbatim
prompt = r"If $a+b=3$ and $ab=\tfrac{13}{6}$, what is the value of $a^2+b^2$?"
#prompt = r"Compute $2+2=$"

# Call with streaming output
result = generate_text_stream_concat(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    device=device,
    max_new_tokens=2048,
    verbose=True
)

# Swap the \[ ... \] display-math delimiters for $$ ... $$ before
# passing the text to Markdown. Inline \( ... \) delimiters are left
# unchanged by this replacement.
result = result.replace(r"\[", "$$").replace(r"\]", "$$")
display(Markdown(result))



Using Apple Silicon GPU (MPS)
✓ qwen3/qwen3-0.6B-base.pth already up-to-date
 To find the value of \( a^2 + b^2 \) given that \( a + b = 3 \) and \( ab = \frac{13}{6} \), we can use the following algebraic identity:

\[
a^2 + b^2 = (a + b)^2 - 2ab
\]

**Step 1:** Substitute the given values into the equation.

\[
a^2 + b^2 = (3)^2 - 2 \left( \frac{13}{6} \right)
\]

**Step 2:** Calculate \( (3)^2 \).

\[
(3)^2 = 9
\]

**Step 3:** Calculate \( 2 \times \frac{13}{6} \).

\[
2 \times \frac{13}{6} = \frac{26}{6} = \frac{13}{3}
\]

**Step 4:** Subtract the second result from the first.

\[
a^2 + b^2 = 9 - \frac{13}{3}
\]

**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.

\[
9 = \frac{27}{3}
\]

\[
a^2 + b^2 = \frac{27}{3} - \frac{13}{3} = \frac{14}{3}
\]

**Final Answer:**

\[
\boxed{\dfrac{14}{3}}
\]

 To find the value of \( a^2 + b^2 \) given that \( a + b = 3 \) and \( ab = \frac{13}{6} \), we can use the following algebraic identity:

$$
a^2 + b^2 = (a + b)^2 - 2ab
$$

**Step 1:** Substitute the given values into the equation.

$$
a^2 + b^2 = (3)^2 - 2 \left( \frac{13}{6} \right)
$$

**Step 2:** Calculate \( (3)^2 \).

$$
(3)^2 = 9
$$

**Step 3:** Calculate \( 2 \times \frac{13}{6} \).

$$
2 \times \frac{13}{6} = \frac{26}{6} = \frac{13}{3}
$$

**Step 4:** Subtract the second result from the first.

$$
a^2 + b^2 = 9 - \frac{13}{3}
$$

**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.

$$
9 = \frac{27}{3}
$$

$$
a^2 + b^2 = \frac{27}{3} - \frac{13}{3} = \frac{14}{3}
$$

**Final Answer:**

$$
\boxed{\dfrac{14}{3}}
$$

In [32]:
# Answer we want to extract 

model_answer = (
r"""... some explanation...
**Final Answer:**
 
\[                     #A
\boxed{\dfrac{14}{3}}  #A
\]                     #A
""") 
    
     #A The answer box we want to extract

In [41]:

print(model_answer)

... some explanation...
**Final Answer:**

$$                     #A
\boxed{\dfrac{14}{3}}  #A
$$                     #A



In [34]:
def get_last_boxed(text):
    """Return the content of the last ``\\boxed{...}`` in ``text``.

    Only the final occurrence of ``\\boxed`` is considered. Whitespace
    between ``\\boxed`` and the opening brace is allowed, and nested
    braces inside the box are kept. Returns ``None`` if there is no
    ``\\boxed``, if it is not followed by ``{``, or if the braces never
    balance.
    """
    marker = r"\boxed"

    # Find the last occurrence of "\boxed"
    marker_pos = text.rfind(marker)
    if marker_pos == -1:
        return None

    # First non-whitespace position after "\boxed" (or end of text)
    text_len = len(text)
    after_marker = marker_pos + len(marker)
    cursor = next(
        (i for i in range(after_marker, text_len) if not text[i].isspace()),
        text_len,
    )

    # Expect an opening brace "{"
    if cursor >= text_len or text[cursor] != "{":
        return None

    # Walk forward tracking nesting depth; the box ends at the brace
    # that brings the depth back to zero
    inner_start = cursor + 1
    depth = 1
    for pos in range(inner_start, text_len):
        ch = text[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Content inside the outermost braces
                return text[inner_start:pos]

    # Unbalanced braces: the box was never closed
    return None

In [35]:
extracted_answer = get_last_boxed(model_answer)
print(extracted_answer)

\dfrac{14}{3}


In [42]:
import re
 
# Matches an optionally negative fraction (e.g. 14/3) or a decimal
# number with optional fractional part and exponent (e.g. -2.5e3)
RE_NUMBER = re.compile(
    r"-?(?:\d+/\d+|\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)"
)


def extract_final_candidate(text, fallback="number_then_full"):
    """Pick the most likely final answer out of a model response.

    The last non-empty ``\\boxed{...}`` content is preferred. Without
    one, the behavior depends on ``fallback``:

    - "number_then_full": the last number in the text, or the full text
      if it contains no number.
    - "number_only": the last number in the text, or "" if none.
    - any other value: "".

    Returns "" for empty or ``None`` input.
    """
    if not text:
        return ""

    # Prefer the last boxed expression if present
    boxed_content = get_last_boxed(text.strip())
    if boxed_content:
        return boxed_content.strip().strip("$ ")

    # No boxed expression: only the two known fallback modes apply
    if fallback not in ("number_then_full", "number_only"):
        return ""

    # Use the last number found in the text
    numbers = RE_NUMBER.findall(text)
    if numbers:
        return numbers[-1]

    # No number either: return the full text only in "number_then_full"
    return text if fallback == "number_then_full" else ""

In [43]:
print(extract_final_candidate(model_answer))

\dfrac{14}{3}


In [48]:
print(extract_final_candidate(r"\boxed{ 14/3. }"))

14/3.


In [49]:
print(extract_final_candidate("abc < > 14/3 abc"))

14/3


In [50]:
# normalize the answers

# Ordered (regex pattern, replacement) pairs; normalize_text applies
# each one to the text with re.sub
LATEX_FIXES = [
    (r"\\left\s*", ""),           # drop \left sizing commands
    (r"\\right\s*", ""),          # drop \right sizing commands
    (r"\\,|\\!|\\;|\\:", ""),     # drop LaTeX spacing commands
    (r"\\cdot", "*"),             # \cdot -> *
    (r"\u00B7|\u00D7", "*"),      # U+00B7 / U+00D7 -> *
    # NOTE(review): this matches the literal sequence \^\circ (with a
    # backslash before the caret); plain ^\circ is already removed
    # earlier in normalize_text
    (r"\\\^\\circ", ""),
    (r"\\dfrac", r"\\frac"),      # \dfrac -> \frac
    (r"\\tfrac", r"\\frac"),      # \tfrac -> \frac
    (r"°", ""),                   # Unicode degree sign
]

RE_SPECIAL = re.compile(r"<\|[^>]+?\|>")  # strip chat special tokens like <|assistant|>

def normalize_text(text):
    """Convert an answer string into a plain, parser-friendly form.

    The steps run in a fixed order: strip chat special tokens and
    multiple-choice labels, remove degree markers and math-mode
    wrappers, apply LATEX_FIXES, convert Unicode superscripts to
    ``**`` exponents, drop ``%`` and ``$``, rewrite ``\\sqrt`` and
    ``\\frac`` as ``sqrt(...)`` and ``(a)/(b)``, turn ``^`` into ``**``,
    join mixed numbers with ``+``, and remove thousands separators.
    Remaining braces are removed and the result is stripped and
    lowercased.

    Returns "" for empty or ``None`` input.
    """
    if not text:
        return ""
    text = RE_SPECIAL.sub("", text).strip()
    # Unicode superscript characters and their plain equivalents
    SUPERSCRIPT_MAP = {
        "⁰": "0", "¹": "1", "²": "2", "³": "3", "⁴": "4",
        "⁵": "5", "⁶": "6", "⁷": "7", "⁸": "8", "⁹": "9",
        "⁺": "+", "⁻": "-", "⁽": "(", "⁾": ")",
    }

    # Strip leading multiple-choice labels
    # E.g., like "c. 3" -> 3, or "b: 2" -> 2
    match = re.match(r"^[A-Za-z]\s*[.:]\s*(.+)$", text)
    if match:
        text = match.group(1)

    # Remove angle-degree markers
    text = re.sub(r"\^\s*\{\s*\\circ\s*\}", "", text)   # ^{\circ}
    text = re.sub(r"\^\s*\\circ", "", text)             # ^\circ
    text = text.replace("°", "")                        # Unicode degree

    # unwrap \text{...} if the whole string is wrapped
    match = re.match(r"^\\text\{(?P<x>.+?)\}$", text)
    if match:
        text = match.group("x")

    # strip inline/display math wrappers \( \) \[ \]
    text = re.sub(r"\\\(|\\\)|\\\[|\\\]", "", text)

    # light LaTeX canonicalization
    for pat, rep in LATEX_FIXES:
        text = re.sub(pat, rep, text)

    def convert_superscripts(s, base=None):
        """Map superscript characters in ``s`` to plain characters.

        With ``base`` given, return ``base**converted`` instead.
        """
        converted = "".join(
            SUPERSCRIPT_MAP[ch] if ch in SUPERSCRIPT_MAP else ch
            for ch in s
        )
        if base is None:
            return converted
        return f"{base}**{converted}"

    # convert unicode superscripts into exponent form (e.g., 2² -> 2**2)
    text = re.sub(
        r"([0-9A-Za-z\)\]\}])([⁰¹²³⁴⁵⁶⁷⁸⁹⁺⁻]+)",
        lambda m: convert_superscripts(m.group(2), base=m.group(1)),
        text,
    )
    # any superscript characters still left are mapped to their plain
    # form without inserting "**"
    text = convert_superscripts(text)

    # numbers/roots
    # percent and dollar signs are dropped entirely (e.g., "50%" -> "50")
    text = text.replace("\\%", "%").replace("$", "").replace("%", "")
    # \sqrt{x} -> sqrt(x); the argument may not contain "}"
    text = re.sub(
        r"\\sqrt\s*\{([^}]*)\}",
        lambda match: f"sqrt({match.group(1)})",
        text,
    )
    # \sqrt x -> sqrt(x) for an unbraced argument
    text = re.sub(
        r"\\sqrt\s+([^\\\s{}]+)",
        lambda match: f"sqrt({match.group(1)})",
        text,
    )

    # fractions
    # \frac{a}{b} -> (a)/(b); neither argument may contain braces
    text = re.sub(
        r"\\frac\s*\{([^{}]+)\}\s*\{([^{}]+)\}",
        lambda match: f"({match.group(1)})/({match.group(2)})",
        text,
    )
    # \frac a b -> (a)/(b) for whitespace-separated arguments
    text = re.sub(
        r"\\frac\s+([^\s{}]+)\s+([^\s{}]+)",
        lambda match: f"({match.group(1)})/({match.group(2)})",
        text,
    )

    # exponent and mixed numbers
    text = text.replace("^", "**")
    # a digit followed by whitespace and a fraction is read as a mixed
    # number, e.g. "1 1/2" -> "1+1/2"
    text = re.sub(
        r"(?<=\d)\s+(\d+/\d+)",
        lambda match: "+" + match.group(1),
        text,
    )

    # 1,234 -> 1234
    text = re.sub(
        r"(?<=\d),(?=\d\d\d(\D|$))",
        "",
        text,
    )

    # drop any remaining braces, trim, and lowercase
    return text.replace("{", "").replace("}", "").strip().lower()

In [51]:
print(normalize_text(extract_final_candidate(model_answer)))

(14)/(3)


In [52]:
print(normalize_text(r"$\dfrac{14}{3.}$"))

(14)/(3.)


In [53]:
print(normalize_text(r"\text{\[\frac{14}{3}\]}"))

(14)/(3)


In [54]:
print(normalize_text("4/3"))

4/3


In [57]:
r'''
The grading process:
Get the model's answer from \boxed{}
Clean it up (normalize)
Use SymPy to check if it equals the correct answer
Mark as correct if they match mathematically
'''

"\nThe grading process:\nGet the model's answer from \x08oxed{}\nClean it up (normalize)\nUse SymPy to check if it equals the correct answer\nMark as correct if they match mathematically\n"

In [67]:
from sympy.parsing import sympy_parser as spp
from sympy.core.sympify import SympifyError
from tokenize import TokenError
 
def sympy_parser(expr):
    """Parse a normalized answer string into a SymPy expression.

    Uses SymPy's standard transformations plus implicit multiplication
    (e.g., 2y -> 2*y) and evaluates during parsing so simple constants
    simplify (e.g., 2+3 -> 5).

    Args:
        expr: Expression string, typically the output of normalize_text.

    Returns:
        The parsed SymPy object, or ``None`` if the string cannot be
        parsed.
    """
    # NOTE(security): parse_expr evaluates the transformed string with
    # eval(); the input here comes from model output, so do not reuse
    # this helper on input from untrusted users without sandboxing.
    try:
        return spp.parse_expr(
            expr,
            transformations=(
                # Standard transformations like handling parentheses
                *spp.standard_transformations,
                # Allow omitted multiplication symbols (e.g., 2y -> 2*y)
                spp.implicit_multiplication_application,
            ),
            # Evaluate during parsing so simple constants simplify
            evaluate=True,
        )
    except (
        SympifyError, SyntaxError, TypeError, IndexError, TokenError,
        # Malformed model output can also fail while the parsed code is
        # evaluated (e.g., "x.y" becomes attribute access on a Symbol).
        # Treat these as "unparseable" too, so a single bad answer
        # cannot abort a whole evaluation run.
        AttributeError, ValueError, RecursionError,
    ):
        return None

'''
     From the text: The sympy_parser function in listing 3.7 takes an input expression, such as the normalized answers we extract from the LLM response, and converts it into a SymPy object that can be reliably compared for mathematical equivalence. 
     To do so, it applies SymPy's standard parsing rules, supports implicit multiplication like (2y instead of 2*y), and also simplifies basic arithmetic (so 2+3 becomes 5). 
'''
   
 
  

"\n     From the text: The sympy_parser function in listing 3.7 takes an input expression, such as the normalized answers we extract from the LLM response, and converts it into a SymPy object that can be reliably compared for mathematical equivalence. \n     To do so, it applies SymPy's standard parsing rules, supports implicit multiplication like (2y instead of 2*y), and also simplifies basic arithmetic (so 2+3 becomes 5). \n"

In [68]:
print(sympy_parser(normalize_text(
    extract_final_candidate(model_answer)
)))

14/3


In [69]:
print(sympy_parser("28/6"))

14/3


In [70]:
#Equality check function using SymPy 
    
from sympy import simplify
 
def equality_check(expr_gtruth, expr_pred):
    """Return True if two normalized answer strings are equivalent.

    Identical strings match immediately. Otherwise both strings are
    parsed with sympy_parser and compared symbolically: they are
    equivalent if their difference simplifies to 0. Returns False if
    either side fails to parse or the symbolic comparison raises.
    """
    # First, check if the two expressions are exactly the same string
    if expr_gtruth == expr_pred:
        return True

    # Parse both expressions into SymPy objects (None if parsing fails)
    parsed_gtruth = sympy_parser(expr_gtruth)
    parsed_pred = sympy_parser(expr_pred)
    if parsed_gtruth is None or parsed_pred is None:
        return False

    # Both parsed successfully: if the difference is 0, they are
    # equivalent
    try:
        is_equivalent = simplify(parsed_gtruth - parsed_pred) == 0
    except (SympifyError, TypeError):
        is_equivalent = False
    return is_equivalent

In [71]:
print(equality_check(
    normalize_text("13/4."),
    normalize_text(r"(13)/(4)")
))

True


In [72]:
print(equality_check(
    normalize_text("13/4."),
    normalize_text(r"(1)/(4)")
))

False


In [73]:
from sympy import simplify
 
def equality_check(expr_gtruth, expr_pred):
    """Return True if two normalized answer strings are equivalent.

    Identical strings match immediately. Otherwise both strings are
    parsed with sympy_parser and compared symbolically: they are
    equivalent if their difference simplifies to 0. Returns False if
    either side fails to parse or the symbolic comparison raises.
    """
    # First, check if the two expressions are exactly the same string
    if expr_gtruth == expr_pred:
        return True

    # Parse both expressions into SymPy objects (None if parsing fails)
    parsed_gtruth = sympy_parser(expr_gtruth)
    parsed_pred = sympy_parser(expr_pred)
    if parsed_gtruth is None or parsed_pred is None:
        return False

    # Both parsed successfully: if the difference is 0, they are
    # equivalent
    try:
        is_equivalent = simplify(parsed_gtruth - parsed_pred) == 0
    except (SympifyError, TypeError):
        is_equivalent = False
    return is_equivalent

In [75]:
print(equality_check(
    normalize_text("13/4."),
    normalize_text(r"(13)/(4)")
))


True


In [76]:
print(equality_check(
    normalize_text("0.5"),
    normalize_text(r"(1)/(2)")
))  
    

True


In [77]:
print(equality_check(
    normalize_text("14/3"),
    normalize_text("15/3")
))

False


In [None]:
print(equality_check(
    normalize_text("(14/3, 2/3)"),
    normalize_text("(14/3, 4/6)")
))
'''
The function does not support tuples, hence it returns False.
'''


False


In [80]:
def split_into_parts(text):
    """Split a tuple- or list-like answer into its components.

    A string that starts with "(" or "[", ends with ")" or "]", and
    has a comma in between is split on commas into stripped items,
    provided none of the items is empty. Any other non-empty string is
    returned as a single-element list. Empty or ``None`` input gives
    an empty list.
    """
    if not text:
        return []

    looks_like_tuple = (
        len(text) >= 2
        and text[0] in "(["
        and text[-1] in ")]"
        and "," in text[1:-1]
    )
    if looks_like_tuple:
        # Drop the outer brackets and split the interior on commas
        pieces = [piece.strip() for piece in text[1:-1].split(",")]
        if all(pieces):
            return pieces

    # Not a tuple, or it had an empty item: keep the text whole
    return [text]

In [81]:
split_into_parts(normalize_text(r"(14/3, 2/3)"))

['14/3', '2/3']

In [82]:
def grade_answer(pred_text, gt_text):
    """Grade a predicted answer against the ground truth.

    Both texts are normalized and split into parts (so tuple answers
    are compared element by element, in order). The prediction is
    correct only if both sides yield the same non-zero number of parts
    and every pair passes equality_check. Returns False if either
    input is ``None``.
    """
    # Only continue if both inputs are present
    if pred_text is None or gt_text is None:
        return False

    gt_parts = split_into_parts(normalize_text(gt_text))
    pred_parts = split_into_parts(normalize_text(pred_text))

    # Both sides need at least one part, and the same number of parts
    if not gt_parts or not pred_parts:
        return False
    if len(gt_parts) != len(pred_parts):
        return False

    # Check each part for mathematical equivalence
    return all(
        equality_check(gt, pred)
        for gt, pred in zip(gt_parts, pred_parts)
    )

In [83]:
grade_answer(r"(14/3, 2/3)", "(14/3, 4/6)")


True

In [84]:
grade_answer("14/3", r"\frac{14}{3}")



True

In [85]:
# Test cases: (name, prediction, ground truth, expected result)
tests = [
    ("check_1", "3/4", r"\frac{3}{4}", True),
    ("check_2", "(3)/(4)", r"3/4", True),
    ("check_3", r"\frac{\sqrt{8}}{2}", "sqrt(2)", True),
    ("check_4", r"\( \frac{1}{2} + \frac{1}{6} \)", "2/3", True),
    ("check_5", "(1, 2)", r"(1,2)", True),
    ("check_6", "(2, 1)", "(1, 2)", False),
    ("check_7", "(1, 2, 3)", "(1, 2)", False),
    ("check_8", "0.5", "1/2", True),
    ("check_9", "0.3333333333", "1/3", False),
    ("check_10", "1,234/2", "617", True),
    ("check_11", r"\text{2/3}", "2/3", True),
    ("check_12", "50%", "1/2", False),
    ("check_13", r"2\cdot 3/4", "3/2", True),
    ("check_14", r"90^\circ", "90", True),
    ("check_15", r"\left(\frac{3}{4}\right)", "3/4", True),
]


def run_demos_table(tests):
    """Grade every test case and print an aligned results table.

    Each case is a ``(name, prediction, ground truth, expected)``
    tuple. One row is printed per case with the expected value, the
    value returned by grade_answer, and PASS/FAIL, followed by a
    summary line with the number of passing cases.
    """
    header = ("Test", "Expect", "Got", "Status")

    def grade_row(case):
        # Grade one case and format it as a table row of strings
        name, pred, gtruth, expect = case
        got = grade_answer(pred, gtruth)
        verdict = "PASS" if got == expect else "FAIL"
        return (name, str(expect), str(got), verdict)

    rows = [grade_row(case) for case in tests]
    table = [header, *rows]

    # Widest cell in each column, so the table lines up
    widths = [
        max(len(cell) for cell in column)
        for column in zip(*table)
    ]

    # Print the table row by row
    for record in table:
        print(" | ".join(
            cell.ljust(width)
            for cell, width in zip(record, widths)
        ))

    # Summary of passed tests
    passed = sum(1 for record in rows if record[3] == "PASS")
    print(f"\nPassed {passed}/{len(rows)}")
     


In [88]:
run_demos_table(tests)

Test     | Expect | Got   | Status
check_1  | True   | True  | PASS  
check_2  | True   | True  | PASS  
check_3  | True   | True  | PASS  
check_4  | True   | True  | PASS  
check_5  | True   | True  | PASS  
check_6  | False  | False | PASS  
check_7  | False  | False | PASS  
check_8  | True   | True  | PASS  
check_9  | False  | False | PASS  
check_10 | True   | True  | PASS  
check_11 | True   | True  | PASS  
check_12 | False  | False | PASS  
check_13 | True   | True  | PASS  
check_14 | True   | True  | PASS  
check_15 | True   | True  | PASS  

Passed 15/15


In [92]:
import json
import requests

def load_math500_test(local_path="math500_test.json", save_copy=True):
    """Load the MATH-500 test set from disk, downloading it if absent.

    Args:
        local_path: Location of the local JSON copy.
        save_copy: If true, write the downloaded data to ``local_path``.

    Returns:
        The parsed JSON content.

    Raises:
        requests.HTTPError: If the download returns an error status.
    """
    cache_file = Path(local_path)

    # Reuse the local copy when it exists
    if cache_file.exists():
        return json.loads(cache_file.read_text(encoding="utf-8"))

    url = (
        "https://raw.githubusercontent.com/rasbt/reasoning-from-scratch/"
        "main/ch03/01_main-chapter-code/math500_test.json"
    )
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    data = response.json()

    if save_copy:  # Saves a local copy
        with cache_file.open("w", encoding="utf-8") as out_file:
            json.dump(data, out_file, indent=2)

    return data

math_data = load_math500_test()
print("Number of entries:", len(math_data))

Number of entries: 500


In [94]:
from pprint import pprint
pprint(math_data[0])

{'answer': '\\left( 3, \\frac{\\pi}{2} \\right)',
 'level': 2,
 'problem': 'Convert the point $(0,3)$ in rectangular coordinates to polar '
            'coordinates.  Enter your answer in the form $(r,\\theta),$ where '
            '$r > 0$ and $0 \\le \\theta < 2 \\pi.$',
 'solution': 'We have that $r = \\sqrt{0^2 + 3^2} = 3.$  Also, if we draw the '
             'line connecting the origin and $(0,3),$ this line makes an angle '
             'of $\\frac{\\pi}{2}$ with the positive $x$-axis.\n'
             '\n'
             '[asy]\n'
             'unitsize(0.8 cm);\n'
             '\n'
             'draw((-0.5,0)--(3.5,0));\n'
             'draw((0,-0.5)--(0,3.5));\n'
             'draw(arc((0,0),3,0,90),red,Arrow(6));\n'
             '\n'
             'dot((0,3), red);\n'
             'label("$(0,3)$", (0,3), W);\n'
             'dot((3,0), red);\n'
             '[/asy]\n'
             '\n'
             'Therefore, the polar coordinates are $\\boxed{\\left( 3, '
             '\\frac

In [99]:
def render_prompt(prompt):
    """Wrap a math problem in the evaluation prompt template.

    The template asks the model to put the final result on its own
    line as ``\\boxed{ANSWER}`` and ends with "Answer:" so the model
    continues from there.
    """
    preamble = (
        "You are a helpful math assistant.\n"
        "Answer the question and write the final result on a new line as:\n"
        "\\boxed{ANSWER}\n\n"
    )
    question_block = f"Question:\n{prompt}\n\nAnswer:"
    return preamble + question_block
 
# ============================================
# IMPORTS 
# ============================================
from pathlib import Path
import torch
from reasoning_from_scratch.qwen3 import (
    download_qwen3_small,
    Qwen3Tokenizer,
    Qwen3Model,
    QWEN_CONFIG_06_B
)
from reasoning_from_scratch.ch02 import get_device
from reasoning_from_scratch.ch02_ex import generate_text_basic_stream_cache


# ============================================
# LOAD MODEL 
# ============================================
# Pick the compute device via the chapter 2 helper
device = get_device()
# Request the base checkpoint and tokenizer files in ./qwen3
download_qwen3_small(kind="base", tokenizer_only=False, out_dir="qwen3")
# Base tokenizer: constructed with the tokenizer file only
tokenizer = Qwen3Tokenizer(tokenizer_file_path=Path("qwen3") / "tokenizer-base.json")
# Build the 0.6B architecture, load the saved weights, then move the
# model to the selected device
model = Qwen3Model(QWEN_CONFIG_06_B)
model.load_state_dict(torch.load(Path("qwen3") / "qwen3-0.6B-base.pth"))
model.to(device)


# ============================================
# HELPER FUNCTIONS 
# ============================================
# - generate_text_stream_concat
# - render_prompt
# - extract_final_candidate
# - grade_answer


# ============================================
# DEFINE THE DEMO FUNCTION
# ============================================
def mini_eval_demo(model, tokenizer, device):
    """Run the full evaluation pipeline on one hard-coded example.

    Renders the prompt, generates up to 64 new tokens, extracts the
    final answer candidate, grades it against the ground truth, and
    prints the device, prediction, ground truth, and verdict.
    """
    example = {
        "problem": "Compute 1/2 + 1/6.",
        "answer": "2/3"
    }

    # prompt -> generation -> answer extraction -> grading
    gen_text = generate_text_stream_concat(
        model,
        tokenizer,
        render_prompt(example["problem"]),
        device,
        max_new_tokens=64,
    )
    pred_answer = extract_final_candidate(gen_text)
    is_correct = grade_answer(pred_answer, example["answer"])

    report = (
        f"Device: {device}",
        f"Prediction: {pred_answer}",
        f"Ground truth: {example['answer']}",
        f"Correct: {is_correct}",
    )
    for line in report:
        print(line)


# ============================================
# RUN IT
# ============================================
mini_eval_demo(model, tokenizer, device)

Using Apple Silicon GPU (MPS)
✓ qwen3/qwen3-0.6B-base.pth already up-to-date
Device: mps
Prediction: 1/3
Ground truth: 2/3
Correct: False
