In [1]:
# Install Dependencies
!pip install wikipedia langchain langchain-community langchain-text-splitters
!pip install sentence-transformers faiss-cpu transformers accelerate

Collecting wikipedia
  Downloading wikipedia-1.4.0.tar.gz (27 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting langchain-community
  Downloading langchain_community-0.4.1-py3-none-any.whl.metadata (3.0 kB)
Collecting langchain-text-splitters
  Downloading langchain_text_splitters-1.0.0-py3-none-any.whl.metadata (2.6 kB)
Collecting langchain-classic<2.0.0,>=1.0.0 (from langchain-community)
  Downloading langchain_classic-1.0.0-py3-none-any.whl.metadata (3.9 kB)
Collecting requests<3.0.0,>=2.0.0 (from wikipedia)
  Downloading requests-2.32.5-py3-none-any.whl.metadata (4.9 kB)
Collecting dataclasses-json<0.7.0,>=0.6.7 (from langchain-community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7.0,>=0.6.7->langchain-community)
  Downloading marshmallow-3.26.1-py3-none-any.whl.metadata (7.3 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7.0,>=0.6.7->langchain-community)


In [2]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="wikipedia")

import re
import wikipedia
import torch

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

from transformers import AutoTokenizer, AutoModelForCausalLM


Get Wikipedia + Text Cleaning + Text Splitting

In [3]:
topics = [
    "Algebra", "Calculus", "Derivative", "Integral",
    "Matrix (mathematics)", "Probability", "Statistics",
    "Geometry", "Trigonometry", "Number theory"
]

def prune_wiki(text: str) -> str:
    stop_markers = [
        "== See also ==",
        "== References ==",
        "== External links ==",
        "== Further reading =="
    ]
    for marker in stop_markers:
        idx = text.find(marker)
        if idx != -1:
            text = text[:idx]
    return text

raw_docs = []
for topic in topics:
    try:
        page = wikipedia.page(topic)
        content = prune_wiki(page.content)
        raw_docs.append({"title": topic, "content": content})
        print("Collected:", topic)
    except Exception as e:
        print(f"Skipped {topic}: {e}")

splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,
    chunk_overlap=150,
    separators=["\n\n", "\n", ". ", " ", ""],
)

docs = []
for d in raw_docs:
    chunks = splitter.create_documents(
        [d["content"]],
        metadatas=[{"title": d["title"]}]
    )
    docs.extend(chunks)

print("Total chunks:", len(docs))
print("Example chunk:\n", docs[0].page_content[:400], "...")
print("Metadata:", docs[0].metadata)


Collected: Algebra
Collected: Calculus
Skipped Derivative: "derivation" may refer to: 
Morphological derivation
Parse tree
Derivative work
Derivation proceeding
derived row
Derivation (differential  algebra)
Formal proof
Vilfredo Pareto
Derive (disambiguation)
Derivative
Derivative (disambiguation)
All pages with titles containing Derivation
Collected: Integral
Collected: Matrix (mathematics)
Collected: Probability
Collected: Statistics
Collected: Geometry
Collected: Trigonometry
Collected: Number theory
Total chunks: 851
Example chunk:
 Algeria, officially the People's Democratic Republic of Algeria, is a country in the Maghreb region of North Africa. It covers an area of over 2,381,741 square kilometres (919,595 sq mi), and is the largest country in Africa and the tenth-largest country in the world by land area. With a population of 47 million, Algeria is the tenth-most populous country in Africa. It is bordered to the northeast ...
Metadata: {'title': 'Algebra'}


Construct Embeddings + FAISS vectorsctore

In [4]:
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"

embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model_name
)

vectorstore = FAISS.from_documents(docs, embeddings)

vectorstore.save_local("wiki_math_faiss")

retriever = vectorstore.as_retriever(
    search_kwargs={"k": 5}
)

print("Vectorstore ready.")

  embeddings = HuggingFaceEmbeddings(
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Vectorstore ready.


Load our Qwen3-0.6B model

In [5]:
model_id = "Qwen/Qwen3-1.7B"

tokenizer = AutoTokenizer.from_pretrained(model_id)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)

model.eval()

print(f"Loaded {model_id}.")

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/622M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Loaded Qwen/Qwen3-1.7B.


In [6]:
def build_rag_prompt(question: str, context_docs, strategy: int = 2) -> str:
    """
    Build RAG prompt for math problems under two strategies:
    strategy 1 — Competition-math optimized prompt
    strategy 2 — Minimal context-only prompt
    """

    # -----------------------------
    # Merge retrieved docs
    # -----------------------------
    context_blocks = []
    for i, doc in enumerate(context_docs):
        title = doc.metadata.get("title", "Wikipedia")
        context_blocks.append(
            f"[Source {i+1} — {title}]\n{doc.page_content}"
        )
    context_str = "\n\n".join(context_blocks)

    # ============================================================
    # STRATEGY 1 — Detailed competition math prompt
    # ============================================================
    if strategy == 1:
        prompt = f"""
                      You are a math assistant specializing in solving competition-level algebra problems.
                      You will be given:
                      1) A user problem.
                      2) Several reference passages retrieved from Wikipedia.

                      ---------------------------------
                      REFERENCE MATERIAL (RAG CONTEXT)
                      ---------------------------------
                      {context_str}

                      -------------------------
                      YOUR TASK INSTRUCTIONS
                      -------------------------
                      • Use the reference material **only if directly relevant**.
                      • If irrelevant, **ignore it** and solve normally.
                      • Provide a **concise, logically correct step-by-step solution** in English.
                      • Use LaTeX for math expressions.
                      • The solution should be ≤ **6 short lines** where possible.
                      • On the **very last line**, output **only the final result** as:

                          \\boxed{{final\_answer}}

                      No extra text is allowed after this line.

                      -------------------------
                      USER PROBLEM
                      -------------------------
                      {question}

                      -------------------------
                      BEGIN YOUR SOLUTION
                      -------------------------
                  """
        return prompt

    # ============================================================
    # STRATEGY 2 — Very simple context-only prompt
    # ============================================================
    elif strategy == 2:
        prompt = (
            "You are a helpful math tutor.\n\n"
            "Use ONLY the following context to answer the question. "
            "On the very last line, output **only the final result** "
            "in the form: \\boxed{final_answer}.\n"
            "If the answer is not contained in the context, say you don't know.\n\n"
            f"Context:\n{context_str}\n\n"
            f"Question:\n{question}\n\n"
            "Answer:"
        )
        return prompt

    else:
        raise ValueError("strategy must be 1 or 2")



  """


In [7]:
def answer_with_rag(question: str, max_new_tokens: int = 256, temperature: float = 0.3, top_p: float = 0.9):

    context_docs = retriever.invoke(question)
    prompt = build_rag_prompt(question, context_docs)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # 4. Qwen generate
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=1.15,
            eos_token_id=tokenizer.eos_token_id,
        )

    generated = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    if "<|im_start|>assistant" in generated:
        answer = generated.split("<|im_start|>assistant")[-1]
    else:
        answer = generated

    answer = answer.split("<|im_end|>")[0].strip()

    return answer, context_docs



In [9]:
import json
from collections import defaultdict

with open("test_math.json", "r", encoding="utf-8") as f:
    testdata_all = json.load(f)

testdata_dict = defaultdict(list)

for dt in testdata_all:
    testdata_dict[dt['level']].append(dt)

In [10]:
import re
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
import torch
from sympy import sympify, simplify

# print("Loading Hendrycks Math (Algebra)...")
# ds_algebra = load_dataset("EleutherAI/hendrycks_math", "algebra")

def extract_all_boxed(text):
    results = []
    i = 0
    key = r"\boxed"

    while True:
        start = text.find(key, i)
        if start == -1:
            break

        j = start + len(key)
        while j < len(text) and text[j].isspace():
            j += 1

        if j >= len(text) or text[j] != '{':
            i = start + 1
            continue

        depth = 0
        content_start = j + 1

        for k in range(content_start, len(text)):
            if text[k] == '{':
                depth += 1
            elif text[k] == '}':
                if depth == 0:
                    results.append(text[content_start:k].strip())
                    i = k + 1
                    break
                depth -= 1
        else:
            break

    return results


def extract_last_boxed(text):
    all_boxed = extract_all_boxed(text)
    return all_boxed[-1] if all_boxed else None


def normalize_ans(ans: str):

    if ans is None:
        return None

    # 统一 lower
    ans = ans.strip()

    # 去掉所有空格
    ans = ans.replace(" ", "")

    # 去掉外层括号 ( ) => (x) -> x
    if ans.startswith("(") and ans.endswith(")"):
        ans = ans[1:-1]

    # \dfrac -> frac
    ans = ans.replace("\\dfrac", "\\frac")

    # 把 \frac{a}{b} -> (a)/(b)
    # 这样 sympify 才能吃进去
    # e.g. \frac{1}{2} -> (1)/(2)
    ans = re.sub(r'\\frac\{(.+?)\}\{(.+?)\}', r'(\1)/(\2)', ans)

    # \frac12 （没有花括号）
    ans = re.sub(r'\\frac(\d+)(\d+)', r'(\1)/(\2)', ans)

    # 删除冗余的 \left \right
    ans = ans.replace("\\left", "").replace("\\right", "")

    # 尝试用 sympy 简化
    try:
        simp = simplify(sympify(ans))
        ans = str(simp)
    except:
        pass

    return ans


def extract_ref_answer(ref_text):
    return extract_last_boxed(ref_text)


In [11]:
def generate_rag_answer(problem, max_new_tokens=512):
    """
    用 RAG pipeline 生成答案，并提取 boxed。
    并打印所有 boxed 与最终 boxed（用于 debug）
    """
    answer, ctx_docs = answer_with_rag(problem)

    # ALL boxed answers
    all_boxes = extract_all_boxed(answer)
    last_box = extract_last_boxed(answer)

    # Leave some messages for debug
    debug_output = []
    debug_output.append("====== RAG MODEL RAW OUTPUT ======")
    debug_output.append(answer)
    debug_output.append("\n====== ALL COLLECTED BOXED ======")
    if all_boxes:
        for idx, b in enumerate(all_boxes, 1):
            debug_output.append(f"{idx}. {b}")
    else:
        debug_output.append("No boxed answers found.")
    debug_output.append("\n====== LAST BOXED (FINAL) ======")
    debug_output.append(str(last_box))
    debug_output.append("=================================")

    full_debug_text = "\n".join(debug_output)

    if last_box is None or last_box.strip() == "":
        return "Null", full_debug_text

    return last_box.strip(), full_debug_text

def exact_match(preds, refs, verbose=True):

    n = len(preds)

    correct = 0
    total = 0

    for pred, ref in zip(preds, refs):
        ref_box = extract_ref_answer(ref)

        pred_norm = normalize_ans(pred)
        ref_norm = normalize_ans(ref_box)

        if pred_norm is None or pred_norm == "Null":
            continue

        total += 1
        if pred_norm == ref_norm:
            correct += 1
    acc = correct / max(total, 1)
    if verbose:
        print(f"\n📌 RAG Exact Match Accuracy on {n} problems = {acc:.4f} [{correct}/{max(total, 1)}]")

    return acc

def run_rag_eval(n=20, prob_level='Level 2', show_debug=False, save=True, save_path="rag_debug_log.txt"):

    preds = []
    refs = []
    outputs = []

    debug_lines = []

    for i in tqdm(range(n)):

        # item = ds_algebra["test"][i]
        item = testdata_dict[prob_level][i]
        problem = item["problem"]
        reference = item["output"]

        pred, full_output = generate_rag_answer(problem)
        ref_box = extract_ref_answer(reference)

        preds.append(pred)
        refs.append(reference)
        outputs.append(full_output)


        debug_lines.append("\n=====================")
        debug_lines.append(f"Problem {i+1}")
        debug_lines.append("---------------------")
        debug_lines.append(f"Problem:\n{problem}")
        debug_lines.append("")
        debug_lines.append(full_output)
        debug_lines.append("")
        debug_lines.append(f"Final Pred Boxed  : {pred}")
        debug_lines.append(f"Ref Full Solution : {reference}")
        debug_lines.append(f"Ref Boxed         : {ref_box}")
        same = (normalize_ans(pred) == normalize_ans(ref_box))
        debug_lines.append(f"Exact Match       : {same}")
        debug_lines.append("=====================\n")

        if show_debug:
            print("\n".join(debug_lines[-30:]))

    acc = exact_match(preds, refs)

    if save:
        with open(save_path, "w", encoding="utf-8") as f:
            f.write("\n".join(debug_lines))
            f.write("\n RAG Exact Match Accuracy on {n} problems = {acc:.4f}")
        print(f"\n🔍 Debug saved to: {save_path}")

    return acc, preds, refs, outputs



# acc, preds, refs, outputs = run_rag_eval(30, show_debug=False, save=True)

In [16]:
acc, preds, refs, outputs = run_rag_eval(50, prob_level = 'Level 1', show_debug=False, save=True, save_path='Qwen1.7B_rag_level_1.txt')

100%|██████████| 50/50 [11:02<00:00, 13.26s/it]


📌 RAG Exact Match Accuracy on 50 problems = 0.6923 [27/39]

🔍 Debug saved to: Qwen1.7B_rag_level_1.txt





In [12]:
acc, preds, refs, outputs = run_rag_eval(30, prob_level = 'Level 2', show_debug=False, save=True, save_path='Qwen1.7B_rag_level_2.txt')

100%|██████████| 30/30 [07:16<00:00, 14.53s/it]


📌 RAG Exact Match Accuracy on 30 problems = 0.3333 [8/24]

🔍 Debug saved to: Qwen1.7B_rag_level_2.txt





In [13]:
acc, preds, refs, outputs = run_rag_eval(30, prob_level = 'Level 3', show_debug=False, save=True, save_path='Qwen1.7B_rag_level_3.txt')

100%|██████████| 30/30 [06:52<00:00, 13.76s/it]


📌 RAG Exact Match Accuracy on 30 problems = 0.2353 [4/17]

🔍 Debug saved to: Qwen1.7B_rag_level_3.txt





In [14]:
acc, preds, refs, outputs = run_rag_eval(50, prob_level = 'Level 4', show_debug=False, save=True, save_path='Qwen1.7B_rag_level_4.txt')

100%|██████████| 50/50 [11:33<00:00, 13.86s/it]


📌 RAG Exact Match Accuracy on 50 problems = 0.1667 [5/30]

🔍 Debug saved to: Qwen1.7B_rag_level_4.txt





In [15]:
acc, preds, refs, outputs = run_rag_eval(50, prob_level = 'Level 5', show_debug=False, save=True, save_path='Qwen1.7B_rag_level_5.txt')

100%|██████████| 50/50 [11:19<00:00, 13.59s/it]


📌 RAG Exact Match Accuracy on 50 problems = 0.0000 [0/35]

🔍 Debug saved to: Qwen1.7B_rag_level_5.txt



