<a href="https://colab.research.google.com/github/ywangumichigan/EECS595-Project/blob/main/Qwen3_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install Dependencies
!pip install wikipedia langchain langchain-community langchain-text-splitters
!pip install sentence-transformers faiss-cpu transformers accelerate

Collecting wikipedia
  Downloading wikipedia-1.4.0.tar.gz (27 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting langchain-community
  Downloading langchain_community-0.4.1-py3-none-any.whl.metadata (3.0 kB)
Collecting langchain-text-splitters
  Downloading langchain_text_splitters-1.0.0-py3-none-any.whl.metadata (2.6 kB)
Collecting langchain-classic<2.0.0,>=1.0.0 (from langchain-community)
  Downloading langchain_classic-1.0.0-py3-none-any.whl.metadata (3.9 kB)
Collecting requests<3.0.0,>=2.0.0 (from wikipedia)
  Downloading requests-2.32.5-py3-none-any.whl.metadata (4.9 kB)
Collecting dataclasses-json<0.7.0,>=0.6.7 (from langchain-community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7.0,>=0.6.7->langchain-community)
  Downloading marshmallow-3.26.1-py3-none-any.whl.metadata (7.3 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7.0,>=0.6.7->langchain-community)


In [2]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="wikipedia")

import re
import wikipedia
import torch

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

from transformers import AutoTokenizer, AutoModelForCausalLM


Get Wikipedia + Text Cleaning + Text Splitting

In [3]:
topics = [
    "Algebra", "Calculus", "Derivative", "Integral",
    "Matrix (mathematics)", "Probability", "Statistics",
    "Geometry", "Trigonometry", "Number theory"
]

def prune_wiki(text: str) -> str:
    stop_markers = [
        "== See also ==",
        "== References ==",
        "== External links ==",
        "== Further reading =="
    ]
    for marker in stop_markers:
        idx = text.find(marker)
        if idx != -1:
            text = text[:idx]
    return text

raw_docs = []
for topic in topics:
    try:
        page = wikipedia.page(topic)
        content = prune_wiki(page.content)
        raw_docs.append({"title": topic, "content": content})
        print("Collected:", topic)
    except Exception as e:
        print(f"Skipped {topic}: {e}")

splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,
    chunk_overlap=150,
    separators=["\n\n", "\n", ". ", " ", ""],
)

docs = []
for d in raw_docs:
    chunks = splitter.create_documents(
        [d["content"]],
        metadatas=[{"title": d["title"]}]
    )
    docs.extend(chunks)

print("Total chunks:", len(docs))
print("Example chunk:\n", docs[0].page_content[:400], "...")
print("Metadata:", docs[0].metadata)


Collected: Algebra
Collected: Calculus
Skipped Derivative: "derivation" may refer to: 
Morphological derivation
Parse tree
Derivative work
Derivation proceeding
derived row
Derivation (differential  algebra)
Formal proof
Vilfredo Pareto
Derive (disambiguation)
Derivative
Derivative (disambiguation)
All pages with titles containing Derivation
Collected: Integral
Collected: Matrix (mathematics)
Collected: Probability
Collected: Statistics
Collected: Geometry
Collected: Trigonometry
Collected: Number theory
Total chunks: 851
Example chunk:
 Algeria, officially the People's Democratic Republic of Algeria, is a country in the Maghreb region of North Africa. It covers an area of over 2,381,741 square kilometres (919,595 sq mi), and is the largest country in Africa and the tenth-largest country in the world by land area. With a population of 47 million, Algeria is the tenth-most populous country in Africa. It is bordered to the northeast ...
Metadata: {'title': 'Algebra'}


Construct Embeddings + FAISS vectorsctore

In [4]:
# 选择一个 embedding 模型（可以改成你喜欢的）
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"

embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model_name
)

# 用 LangChain 的 FAISS 封装建立向量库
vectorstore = FAISS.from_documents(docs, embeddings)

# 可选：保存到本地
vectorstore.save_local("wiki_math_faiss")

retriever = vectorstore.as_retriever(
    search_kwargs={"k": 5}  # 检索 top-5 文档
)

print("Vectorstore ready.")

  embeddings = HuggingFaceEmbeddings(
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Vectorstore ready.


Load our Qwen3-0.6B model

In [5]:
model_id = "Qwen/Qwen3-0.6B"

tokenizer = AutoTokenizer.from_pretrained(model_id)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)

model.eval()

print("Loaded Qwen3-0.6B.")

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Loaded Qwen3-0.6B.


In [6]:
def build_rag_prompt(question: str, context_docs) -> str:
    """
    Revised RAG prompt for Hendrycks Math tasks.
    """

    # 将所有 retrieved chunks 合并（每个 chunk 有 title）
    context_blocks = []
    for i, doc in enumerate(context_docs):
        context_blocks.append(
            f"[Source {i+1} — {doc.metadata.get('title', 'Wikipedia')}]\n{doc.page_content}"
        )
    context_str = "\n\n".join(context_blocks)

    prompt = f"""
              You are a math assistant specializing in solving competition-level algebra problems.
              You will be given:
              1) A user problem.
              2) Several reference passages retrieved from Wikipedia. These passages may contain useful definitions, identities, or formulas.

              ---------------------------------
              REFERENCE MATERIAL (RAG CONTEXT)
              ---------------------------------
              {context_str}

              -------------------------
              YOUR TASK INSTRUCTIONS
              -------------------------
              • First, use the reference material **only if it is directly relevant**.
              • If the reference is irrelevant, **ignore it** and solve the problem normally.
              • Provide a **concise, logically correct step-by-step solution** in English.
              • Use LaTeX for mathematical expressions.
              • The solution should be no more than **6 short lines** when possible.
              • On the **very last line**, output **only the final result** in the form:

                  \\boxed{{final\_answer}}

              No extra text is allowed after this line.

              -------------------------
              USER PROBLEM
              -------------------------
              {question}

              -------------------------
              BEGIN YOUR SOLUTION
              -------------------------
              """

    return prompt



  """


In [7]:
def answer_with_rag(question: str,
                    max_new_tokens: int = 256,
                    temperature: float = 0.3,
                    top_p: float = 0.9):

    # 1. 检索相关文档
    context_docs = retriever.invoke(question)

    # 2. 构建 ChatML prompt
    prompt = build_rag_prompt(question, context_docs)

    # 3. tokenize 输入
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # 4. Qwen generate
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=1.15,   # 关键
            eos_token_id=tokenizer.eos_token_id,
        )

    generated = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # 截到 assistant 内容（避免输出 prompt 本身）
    if "<|im_start|>assistant" in generated:
        answer = generated.split("<|im_start|>assistant")[-1]
    else:
        answer = generated

    # 去掉后面多余段落
    answer = answer.split("<|im_end|>")[0].strip()

    return answer, context_docs



In [21]:
questions = [
    "What is the integral of x^2+2*x?",
    "What is the answer of 1+1?",
    "If Tony Yuan is taking some magical medicine and his dick is growing one centimeter longer everyday. After one week, his dick becomes 10cm. What is the original length of his dick?"
]

for q in questions:
    print("=" * 80)
    print("Question:", q)
    ans, ctx = answer_with_rag(q)
    print("\nAnswer:\n", ans, "\n")


Question: What is the integral of x^2+2*x?

Answer:
 You are a helpful math tutor. 
You will be given some reference materials from Wikipedia and a user question.
Please answer the question concisely based on the given context. 
If the context is insufficient, say you are not sure instead of making up facts.

Context:
[Doc 1 - Number theory]
x
                    
                  
                
                )
              
            
            )
          
          
            2
          
        
        +
        1
        =
        
          
            (
            
              
                
                  1
                  2
                
              
              
                (
                
                  x
                  +
                  
                    
                      1
                      x
                    
                  
                
                )
              
            
            )
     

In [8]:
import re
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
import torch

print("Loading Hendrycks Math (Algebra)...")
ds_algebra = load_dataset("EleutherAI/hendrycks_math", "algebra")

# 你之前的 boxed extractor 保持完全一致
def extract_all_boxed(text):
    results = []
    i = 0
    key = r"\boxed"

    while True:
        start = text.find(key, i)
        if start == -1:
            break

        j = start + len(key)
        while j < len(text) and text[j].isspace():
            j += 1

        if j >= len(text) or text[j] != '{':
            i = start + 1
            continue

        depth = 0
        content_start = j + 1

        for k in range(content_start, len(text)):
            if text[k] == '{':
                depth += 1
            elif text[k] == '}':
                if depth == 0:
                    results.append(text[content_start:k].strip())
                    i = k + 1
                    break
                depth -= 1
        else:
            break

    return results


def extract_last_boxed(text):
    all_boxed = extract_all_boxed(text)
    return all_boxed[-1] if all_boxed else None


def normalize_ans(ans):
    if ans is None:
        return None
    ans = ans.replace(" ", "").replace("\\frac", "frac").replace("\\dfrac", "frac")
    return ans


def extract_ref_answer(ref_text):
    return extract_last_boxed(ref_text)


Loading Hendrycks Math (Algebra)...


README.md: 0.00B [00:00, ?B/s]

algebra/train-00000-of-00001.parquet:   0%|          | 0.00/505k [00:00<?, ?B/s]

algebra/test-00000-of-00001.parquet:   0%|          | 0.00/353k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1744 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1187 [00:00<?, ? examples/s]

In [13]:
def generate_rag_answer(problem, max_new_tokens=512):
    """
    用 RAG pipeline 生成答案，并提取 boxed。
    """
    answer, ctx_docs = answer_with_rag(problem)

    # 封装输出供 debug
    full_output = f"RAG ANSWER:\n{answer}"
A1
    box = extract_last_boxed(answer)

    if box is None or box.strip() == "":
        return "Null", full_output
    return box.strip(), full_output

def exact_match(preds, refs):
    correct = 0
    total = 0

    for pred, ref in zip(preds, refs):
        ref_box = extract_ref_answer(ref)

        pred_norm = normalize_ans(pred)
        ref_norm = normalize_ans(ref_box)

        if pred_norm is None or pred_norm == "Null":
            continue

        total += 1
        if pred_norm == ref_norm:
            correct += 1

    return correct / max(total, 1)

def run_rag_eval(n=20, show_debug=False):
    preds = []
    refs = []
    outputs = []

    for i in tqdm(range(n)):
        item = ds_algebra["test"][i]
        problem = item["problem"]
        reference = item["solution"]

        pred, full_output = generate_rag_answer(problem)

        preds.append(pred)
        refs.append(reference)
        outputs.append(full_output)

        if show_debug:
            print("\n=====================")
            print("Problem:", problem)
            # print("\nReference (GT):", reference)
            # print("\nModel output:\n", full_output)
            print("Predicted boxed:", pred)
            print("=====================\n")

    acc = exact_match(preds, refs)
    print(f"\n📌 RAG Exact Match Accuracy on {n} problems = {acc:.4f}")
    return acc, preds, refs, outputs

acc, preds, refs, outputs = run_rag_eval(10, show_debug=True)

 10%|█         | 1/10 [00:22<03:20, 22.31s/it]


Problem: How many vertical asymptotes does the graph of $y=\frac{2}{x^2+x-6}$ have?
Predicted boxed: 2



 20%|██        | 2/10 [00:34<02:12, 16.62s/it]


Problem: What is the positive difference between $120\%$ of 30 and $130\%$ of 20?
Predicted boxed: final\_answer



 30%|███       | 3/10 [00:47<01:43, 14.76s/it]


Problem: Find $x$ such that $\lceil x \rceil + x = \dfrac{23}{7}$. Express $x$ as a common fraction.
Predicted boxed: final\_answer



 40%|████      | 4/10 [01:00<01:23, 13.89s/it]


Problem: Evaluate $i^5+i^{-25}+i^{45}$.
Predicted boxed: final\_answer



 50%|█████     | 5/10 [01:12<01:07, 13.42s/it]


Problem: If $2^8=4^x$, what is the value of $x$?
Predicted boxed: final\_answer



 60%|██████    | 6/10 [01:32<01:02, 15.52s/it]


Problem: What is the 100th term of the arithmetic sequence 6, 10, 14, 18, ...?
Predicted boxed: 100



 70%|███████   | 7/10 [01:44<00:43, 14.55s/it]


Problem: For what values of $x$ is it true that $x^2 - 5x - 4 \le 10$? Express your answer in interval notation.
Predicted boxed: final\_answer



 80%|████████  | 8/10 [01:57<00:27, 13.92s/it]


Problem: Mr. Madoff invests 1000 dollars in a fund that compounds annually at a constant interest rate.  After three years, his investment has grown to 1225 dollars.  What is the annual interest rate, as a percentage?  (Round your answer to the nearest integer.)
Predicted boxed: 7



 90%|█████████ | 9/10 [02:10<00:13, 13.53s/it]


Problem: Four distinct integers $a$, $b$, $c$ and $d$ have the property that when added in pairs, the sums 10, 18, 19, 20, 21, and 29 are obtained. What are the four integers in increasing order? (place a comma and then a space between each integer)
Predicted boxed: 1, 7, 12, 18



100%|██████████| 10/10 [02:22<00:00, 14.26s/it]


Problem: What is the smallest value of $x$ such that $|5x - 1| = |3x + 2|$? Express your answer as a common fraction.
Predicted boxed: final\_answer


📌 RAG Exact Match Accuracy on 10 problems = 0.2000



