# Qwen + RAG Runner

This notebook uses your existing `legal_content.json` (unchanged) as the retrieval corpus, reading each node's `node_id` and `rag_text`. For every evaluation question it retrieves the top-k most relevant passages (ranked by a simple character-overlap score), injects them into the prompt, and saves the results to a CSV file.

**You only need to edit the PATHS cell.**

In [None]:
!pip -q install -U transformers accelerate bitsandbytes pandas tqdm

In [None]:
# ===== PATHS (EDIT ME) =====
# Evaluation questions, one JSON object per line; each needs "qid" and "question".
EVAL_SET_PATH = "/content/drive/MyDrive/your_project/data/eval_set.jsonl"      # <- TODO
# Retrieval corpus: a JSON list of nodes carrying "node_id" and "rag_text".
LEGAL_CONTENT_JSON = "/content/drive/MyDrive/your_project/legal_content.json"  # <- TODO (your existing file)
# Results are written to <OUTPUT_DIR>/rag/<run_id>.csv.
OUTPUT_DIR = "/content/drive/MyDrive/your_project/outputs"                    # <- TODO

# Model identifier handed to the tokenizer/model from_pretrained() calls.
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

# Retrieval
TOP_K = 5  # number of passages retrieved per question
MAX_CONTEXT_CHARS = 6000  # joined passages are cut to this many characters to keep prompt size manageable

# Generation (keep temperature=0 for reproducibility; sampling is only enabled when TEMPERATURE > 0)
TEMPERATURE = 0.0
TOP_P = 1.0
MAX_NEW_TOKENS = 256  # upper bound on tokens generated per question

In [None]:
import os, json, re, time, datetime
import pandas as pd
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
def load_eval_set(jsonl_path: str):
    """Load the evaluation set from a JSON Lines file.

    Blank lines are skipped. Every remaining line must be a JSON object that
    contains at least the keys ``"qid"`` and ``"question"``.

    Args:
        jsonl_path: Path to a UTF-8 encoded ``.jsonl`` file (a leading BOM is
            tolerated).

    Returns:
        A list of dicts, one per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON. The message is
            prefixed with ``<path>:<line number>``.
        ValueError: If a line is not a JSON object or lacks a required key.
            (Raised explicitly instead of ``assert`` so the check survives
            ``python -O``.)
    """
    required_keys = ("qid", "question")
    items = []
    # "utf-8-sig" reads plain UTF-8 too, but also drops a BOM that would
    # otherwise make json.loads() fail on the first line.
    with open(jsonl_path, "r", encoding="utf-8-sig") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError as err:
                # Same exception type as before, now with file/line context.
                raise json.JSONDecodeError(
                    f"{jsonl_path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            if not isinstance(item, dict):
                raise ValueError(
                    f"{jsonl_path}:{lineno}: expected a JSON object, "
                    f"got {type(item).__name__}"
                )
            missing = [key for key in required_keys if key not in item]
            if missing:
                raise ValueError(
                    f"{jsonl_path}:{lineno}: missing required key(s): "
                    + ", ".join(missing)
                )
            items.append(item)
    return items

def load_legal_docs(legal_json_path: str):
    """Build the retrieval corpus from ``legal_content.json`` without modifying it.

    The file is expected to hold a JSON list of nodes (dicts). From each node
    only ``node_id`` and ``rag_text`` are read; both are whitespace-stripped.
    Nodes where either field is missing, ``None`` or empty after stripping
    are left out.

    Args:
        legal_json_path: Path to the UTF-8 encoded JSON file.

    Returns:
        A list of ``{"doc_id": ..., "text": ...}`` dicts in file order.
    """

    def _field(node, key):
        # Missing / None values collapse to "" so they are filtered out below.
        return (node.get(key) or "").strip()

    with open(legal_json_path, "r", encoding="utf-8") as fh:
        nodes = json.load(fh)

    cleaned = ((_field(n, "node_id"), _field(n, "rag_text")) for n in nodes)
    return [
        {"doc_id": ident, "text": body}
        for ident, body in cleaned
        if ident and body
    ]

def score_overlap(query: str, text: str) -> int:
    """Return how many distinct characters of ``query`` also occur in ``text``.

    A dependency-free, language-agnostic relevance score. Every character of
    the query counts once, whatever it is (whitespace and punctuation
    included) and however often it appears.
    """
    hits = 0
    for char in set(query):
        if char in text:
            hits += 1
    return hits

def retrieve(query: str, docs, top_k: int):
    """Return the ``top_k`` documents that best match ``query``.

    Each document (a dict with a ``"text"`` entry) is scored with
    ``score_overlap``; documents scoring 0 are dropped. The rest are ordered
    by descending score, ties keeping their original relative order.

    Args:
        query: The question text.
        docs: Iterable of document dicts.
        top_k: Maximum number of documents to return.

    Returns:
        A list of at most ``top_k`` document dicts, best match first.
    """
    candidates = [(score_overlap(query, doc["text"]), doc) for doc in docs]
    matching = [pair for pair in candidates if pair[0] > 0]
    # sorted() is stable, so equally scored documents stay in corpus order.
    ranked = sorted(matching, key=lambda pair: pair[0], reverse=True)
    return [doc for _score, doc in ranked[:top_k]]

def format_mcq_prompt_with_context(item: dict, context: str) -> str:
    """Assemble the multiple-choice prompt with the retrieved passages.

    Args:
        item: Evaluation item. ``item["question"]`` is required; the optional
            ``item["choices"]`` maps an option label to its text.
        context: Retrieved reference text placed under the reference header.

    Returns:
        The prompt: instruction, reference material, question, one
        ``"<label>. <text>"`` line per choice, and a trailing answer marker
        that the model is meant to complete.
    """
    question = item["question"].strip()

    option_lines = []
    for label, option_text in item.get("choices", {}).items():
        option_lines.append(f"{label}. {option_text}")
    options_block = "\n".join(option_lines)

    sections = [
        "你是一位職安衛法規助理。請根據【參考資料】與題目選出最正確的選項，只輸出選項字母(A/B/C/D)。\n\n",
        f"【參考資料】\n{context}\n\n",
        f"題目：{question}\n",
        f"{options_block}\n\n",
        "答案：",
    ]
    return "".join(sections)

# A stand-alone option letter. re.ASCII restricts \b / \w to ASCII, so a letter
# that directly follows CJK text (e.g. "答案是A") is still on a word boundary;
# with Unicode-aware \b, CJK characters count as word characters and such
# answers would never match.
ANSWER_RE = re.compile(r"\b([ABCD])\b", re.IGNORECASE | re.ASCII)

def parse_choice(text: str):
    """Extract the first stand-alone option letter from model output.

    Args:
        text: Model output (may be empty or ``None``).

    Returns:
        ``"A"``, ``"B"``, ``"C"`` or ``"D"`` (always upper-case), or ``None``
        when no stand-alone letter is found. Letters embedded in a longer
        ASCII word (e.g. ``"ABCD"``, ``"Because"``) do not count.
    """
    import unicodedata  # local import keeps this block self-contained

    if not text:
        return None
    # NFKC folds full-width letters such as "Ａ" to their ASCII form.
    normalized = unicodedata.normalize("NFKC", text)
    m = ANSWER_RE.search(normalized)
    return m.group(1).upper() if m else None

In [None]:
# ===== Load model =====
# 4-bit quantization is requested through an explicit BitsAndBytesConfig;
# passing load_in_4bit=True straight to from_pretrained() is deprecated.
from transformers import BitsAndBytesConfig

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype="auto",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
# Inference only: switch off dropout and other training-time behaviour.
model.eval()

In [None]:
# Load the evaluation questions and the retrieval corpus, then sanity-check them.
eval_items = load_eval_set(EVAL_SET_PATH)
docs = load_legal_docs(LEGAL_CONTENT_JSON)
print("eval items:", len(eval_items))
print("legal docs:", len(docs))
if not docs:
    # Fail with an actionable message instead of a bare IndexError on docs[0];
    # an empty corpus would make every prompt run without reference material.
    raise ValueError(
        f"No usable documents in {LEGAL_CONTENT_JSON}: every node needs a "
        "non-empty 'node_id' and 'rag_text'."
    )
print("example doc_id:", docs[0]["doc_id"])

In [None]:
def generate_one(prompt: str):
    """Generate a completion for ``prompt`` with the module-level model.

    Sampling is enabled only when ``TEMPERATURE > 0``; otherwise decoding is
    greedy and ``temperature`` is passed as ``None``.

    Args:
        prompt: The full prompt text.

    Returns:
        A ``(text, tail)`` tuple. ``text`` is the decoded full output sequence
        (prompt plus continuation). ``tail`` is only the newly generated
        continuation, whitespace-stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=(TEMPERATURE > 0),
            temperature=TEMPERATURE if TEMPERATURE > 0 else None,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id,
        )
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    # Take the continuation by dropping the prompt tokens rather than splitting
    # the decoded text on "答案：": that marker can also occur inside the
    # retrieved context or the question, which would make the "answer" start
    # in the middle of the prompt.
    prompt_len = inputs["input_ids"].shape[1]
    tail = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
    return text, tail

# A timestamped run id keeps the outputs of successive runs apart.
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
run_id = f"rag_{run_ts}"
os.makedirs(os.path.join(OUTPUT_DIR, "rag"), exist_ok=True)

# One result row per evaluation item: retrieve -> build prompt -> generate -> score.
rows = []
for item in tqdm(eval_items, desc="RAG eval"):
    qid = item["qid"]
    gold = item.get("answer")
    # parse_choice() always yields an upper-case letter, so compare against a
    # trimmed, upper-cased gold label; the stored gold_choice stays untouched.
    gold_norm = gold.strip().upper() if isinstance(gold, str) else gold
    question_text = item["question"]

    retrieved = retrieve(question_text, docs, top_k=TOP_K)
    retrieved_ids = [d["doc_id"] for d in retrieved]
    context = "\n\n".join([d["text"] for d in retrieved])

    # Cap the reference material so the prompt stays a manageable size.
    if len(context) > MAX_CONTEXT_CHARS:
        context = context[:MAX_CONTEXT_CHARS] + "\n...(truncated)..."

    prompt = format_mcq_prompt_with_context(item, context)

    # perf_counter() is monotonic, unlike time.time(), so the measured latency
    # cannot be skewed by system clock adjustments.
    t0 = time.perf_counter()
    raw, tail = generate_one(prompt)
    latency_ms = int((time.perf_counter() - t0) * 1000)

    pred = parse_choice(tail)
    # None marks "not scorable": no parsed answer or no gold label.
    correct = int(pred == gold_norm) if (pred is not None and gold is not None) else None

    rows.append({
        "run_id": run_id,
        "method": "rag",
        "model_name": MODEL_NAME,
        "qid": qid,
        "gold_choice": gold,
        "parsed_choice": pred,
        "correct": correct,
        "raw_output": raw,
        "latency_ms": latency_ms,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "max_new_tokens": MAX_NEW_TOKENS,
        "retrieved_k": TOP_K,
        "retrieved_ids": ";".join(retrieved_ids) if retrieved_ids else None,
        "context_chars": len(context),  # length after truncation, marker included
    })

# Persist all rows; utf-8-sig adds a BOM to the CSV.
df = pd.DataFrame(rows)
out_path = os.path.join(OUTPUT_DIR, "rag", f"{run_id}.csv")
df.to_csv(out_path, index=False, encoding="utf-8-sig")
print("Saved:", out_path)

if df.empty:
    # No rows means no "correct" column; avoid a KeyError.
    print("Accuracy: n/a (no rows)")
else:
    has_gold = df["gold_choice"].notna()
    n_gold = int(has_gold.sum())
    if n_gold == 0:
        print("Accuracy: n/a (no gold answers)")
    else:
        scored = pd.to_numeric(df.loc[has_gold, "correct"], errors="coerce")
        # A plain mean() skips NaN, i.e. it drops unparsed answers from the
        # denominator and overstates accuracy. Count them as wrong here and
        # report the parsed-only figure separately.
        print("Accuracy:", scored.fillna(0).mean())
        print("Accuracy (parsed answers only):", scored.mean())
        print("Unparsed answers:", int(scored.isna().sum()), "of", n_gold)