<a href="https://colab.research.google.com/github/prrmzz/RAG-for-Iranian-High-School-Biology-Textbook/blob/main/RAG_biology.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!apt-get update -qq
!apt-get install -qq -y tesseract-ocr tesseract-ocr-fas tesseract-ocr-ara poppler-utils
!pip install -q faiss-cpu transformers==4.44.2 sentence-transformers==3.0.1 FlagEmbedding \
pdfplumber==0.11.4 pdf2image==1.17.0 pytesseract==0.3.13 pymupdf==1.24.8 rank-bm25==0.2.2


W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)


In [6]:
import os, re, json, pickle, subprocess, numpy as np, torch, faiss
from pathlib import Path
import fitz, pdfplumber, pytesseract
from pdf2image import convert_from_path
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer

# --------- Mount Drive ----------
try:
    from google.colab import drive
    drive.mount('/content/drive')
except:
    pass

DOC_DIR = Path("/content/drive/MyDrive/biologybooks")
WORKDIR = Path("/content/rag_fa"); WORKDIR.mkdir(parents=True, exist_ok=True)
assert DOC_DIR.exists(), f"پوشه پیدا نشد: {DOC_DIR}"

# --------- Persian text helpers ----------
FA_CHARS = re.compile(r'[\u0600-\u06FF]')
DIACRITICS = re.compile(r'[\u064B-\u065F\u0670\u0640]')
def fa_ratio(s:str)->float:
    if not s: return 0.0
    return len(FA_CHARS.findall(s)) / max(1, len(s))
def strip_diacritics(s:str)->str:
    return DIACRITICS.sub('', s)
def normalize_fa(s:str)->str:
    s = s.replace("ك","ک").replace("ي","ی")
    s = re.sub(r'[\u200c\u200f\ufeff]',' ', s)
    s = re.sub(r'\s+',' ', s).strip()
    return s
def clean_line(s:str)->str:
    s = strip_diacritics(normalize_fa(s))
    s = re.sub(r'[A-Za-z_]{2,}',' ', s)             # drop long latin tokens
    s = re.sub(r'(\d[\s/:|,-]){5,}',' ', s)         # chart/axis-like sequences
    s = re.sub(r'\s+',' ', s).strip()
    return s

# --------- Extraction (tiered) ----------
def extract_pymupdf(path:Path)->str:
    out=[]
    try:
        with fitz.open(str(path)) as doc:
            for page in doc:
                txt = page.get_text("text", flags=fitz.TEXTFLAG_TEXT)
                if txt: out.append(txt)
    except Exception:
        pass
    return "\n".join(out).strip()

def extract_pdftotext(path:Path)->str:
    try:
        r = subprocess.run(["pdftotext","-layout","-enc","UTF-8",str(path),"-"],
                           capture_output=True, text=True, timeout=120)
        return (r.stdout or "").strip()
    except Exception:
        return ""

def extract_pdfplumber(path:Path)->str:
    pages=[]
    try:
        with pdfplumber.open(str(path)) as pdf:
            for p in pdf.pages:
                t = p.extract_text(x_tolerance=1, y_tolerance=1) or ""
                if t.strip(): pages.append(t.strip())
    except Exception:
        pass
    return "\n".join(pages).strip()

def extract_ocr(path:Path)->str:
    try:
        images = convert_from_path(str(path), dpi=300)
        ocr = [pytesseract.image_to_string(img, lang="fas+ara", config="--oem 1 --psm 6")
               for img in images]
        return "\n".join(ocr).strip()
    except Exception:
        return ""

def extract_text(path:Path)->str:
    for fn in (extract_pymupdf, extract_pdftotext, extract_pdfplumber, extract_ocr):
        txt = fn(path)
        if len(txt) >= 200:  # accept only sufficiently long text
            return txt
    return ""

def split_sentences(text:str):
    text = normalize_fa(text)
    text = re.sub(r'([\.!؟؛…])', r'\1§', text)
    sents = [clean_line(s) for s in text.split('§') if s.strip()]
    sents = [s for s in sents if fa_ratio(s) >= 0.55 and len(s) >= 12]
    return sents

# --------- Load chosen embedding model (GTE with trust_remote_code; fallback to E5) ----------
device = "cuda" if torch.cuda.is_available() else "cpu"
EMBED_STYLE = "gte"  # or "e5" (set automatically below)

try:
    embedder = SentenceTransformer("Alibaba-NLP/gte-multilingual-base",
                                   device=device, trust_remote_code=True)
    E_TOK = AutoTokenizer.from_pretrained("Alibaba-NLP/gte-multilingual-base")
    EMBED_STYLE = "gte"
    print("✅ Using GTE multilingual base (trust_remote_code=True).")
except Exception as e:
    print("⚠️ GTE load failed, falling back to E5. Error:", repr(e))
    embedder = SentenceTransformer("intfloat/multilingual-e5-base", device=device)
    E_TOK = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-base")
    EMBED_STYLE = "e5"
    print("✅ Using E5-base.")

def enc_passages(texts):
    if EMBED_STYLE == "e5":
        texts = [f"passage: {t}" for t in texts]
    return embedder.encode(texts, batch_size=48, show_progress_bar=True, normalize_embeddings=True)

def enc_query(q):
    q = normalize_fa(q)
    if EMBED_STYLE == "e5":
        q = f"query: {q}"
    return embedder.encode([q], show_progress_bar=False, normalize_embeddings=True).astype("float32")

def chunk_sents(sents, max_tok=90, overlap=20):
    out, buf, buf_tok = [], [], 0
    for s in sents:
        toks = E_TOK(s, add_special_tokens=False)["input_ids"]
        if buf_tok + len(toks) <= max_tok:
            buf.append(s); buf_tok += len(toks)
        else:
            if buf: out.append(" ".join(buf))
            keep, acc = [], 0
            for ss in reversed(buf):
                L = len(E_TOK(ss, add_special_tokens=False)["input_ids"])
                if acc + L <= overlap: keep.append(ss); acc += L
                else: break
            buf = list(reversed(keep)) + [s]
            buf_tok = sum(len(E_TOK(x, add_special_tokens=False)["input_ids"]) for x in buf)
    if buf: out.append(" ".join(buf))
    out = [c for c in out if fa_ratio(c) >= 0.6 and len(c) >= 30]
    return out

# --------- Load docs & build chunks ----------
files = sorted(DOC_DIR.glob("*.pdf"))
if not files: raise SystemExit(f"هیچ PDFی در {DOC_DIR} پیدا نشد.")
print("PDFs:", [f.name for f in files])

all_chunks, all_meta = [], []
report = []
for i, pdf_path in enumerate(files):
    raw = extract_text(pdf_path)
    sents = split_sentences(raw) if raw else []
    chunks = chunk_sents(sents, max_tok=90, overlap=22) if sents else []
    # additional noise filters
    keep=[]
    for ch in chunks:
        if re.search(r'(شکل|نمودار|جدول)\s*\d+', ch): continue
        if sum(c.isdigit() for c in ch) > 0.25*len(ch): continue
        keep.append(ch)
    chunks = keep
    report.append((pdf_path.name, len(raw), len(sents), len(chunks)))
    for j, ch in enumerate(chunks):
        all_chunks.append(ch)
        all_meta.append({"doc": pdf_path.name, "chunk": j})
print("Extract report:", report)
if not all_chunks:
    raise SystemExit("هیچ متنی استخراج نشد. OCR/مسیر را بررسی کنید.")

# --------- Embeddings + FAISS ----------
embs = enc_passages(all_chunks)
index = faiss.IndexFlatIP(embs.shape[1]); index.add(embs.astype("float32"))

# --------- BM25 lexical backstop ----------
def tok4bm25(s):
    s = normalize_fa(s)
    return [w for w in re.split(r'\W+', s) if 2 <= len(w) <= 32 and FA_CHARS.search(w)]
bm25 = BM25Okapi([tok4bm25(c) for c in all_chunks])

# --------- Reranker ----------
try:
    from FlagEmbedding import FlagReranker as _R
    reranker = _R("BAAI/bge-reranker-base", use_fp16=torch.cuda.is_available())
except Exception:
    from FlagEmbedding.FlagModel import Reranker as _R
    reranker = _R("BAAI/bge-reranker-base", use_fp16=torch.cuda.is_available())

# --------- Retrieval ----------
def retrieve(query:str, topk=10):
    qn = normalize_fa(query)
    qvec = enc_query(qn)
    D, I = index.search(qvec, topk*5)
    vec_cands = [(float(D[0][i]), int(I[0][i])) for i in range(len(I[0])) if I[0][i] >= 0]
    bm = bm25.get_top_n(tok4bm25(qn), list(range(len(all_chunks))), n=topk*5)
    bm_cands = [(0.0, idx) for idx in bm]
    merged = {}
    for s, idx in vec_cands + bm_cands:
        merged[idx] = max(merged.get(idx, -1), s)
    idxs = list(merged.keys())
    pair_texts = [all_chunks[i] for i in idxs]
    scores = reranker.compute_score([[qn, t] for t in pair_texts], batch_size=64)
    ranked = sorted(zip(idxs, scores), key=lambda x: x[1], reverse=True)[:topk]
    hits = [{"text": all_chunks[i], "meta": all_meta[i], "score": float(sc)} for i, sc in ranked]
    return hits

# --------- Structured extractive answer ----------
def split_for_parts(text):
    return [s for s in re.split(r'(?<=[\.!؟؛…])\s+', text) if 6 < len(s) < 260]

def pick_summary(hits):
    sents = split_for_parts(hits[0]["text"])
    return sents[0] if sents else hits[0]["text"][:220]

def pick_definition(hits, q):
    keys = ["تعریف","کارکرد","نقش","فرآیند","واکنش","عملکرد","ویژگی"]
    for h in hits[:5]:
        for s in split_for_parts(h["text"]):
            if any(k in s for k in keys): return s[:240]
    return split_for_parts(hits[0]["text"])[1] if len(split_for_parts(hits[0]["text"]))>1 else hits[0]["text"][:200]

def pick_steps(hits, k=6):
    words = ["ابتدا","سپس","مرحله","بعد","در نهایت","نتیجه","تبدیل","انتقال","ساخت","جذب","تولید"]
    pool=[]
    for h in hits[:8]:
        for s in split_for_parts(h["text"]):
            if any(w in s for w in words): pool.append(s)
    if len(pool)<3:
        for h in hits[:8]: pool.extend(split_for_parts(h["text"])[:8])
    seen=set(); uniq=[]
    for p in pool:
        key = re.sub(r'\s+',' ', p)[:64]
        if key not in seen:
            uniq.append(p); seen.add(key)
        if len(uniq)>=k: break
    return uniq

def pick_keys(hits, used, k=4):
    pool=[]
    for h in hits[:8]: pool.extend(split_for_parts(h["text"])[:12])
    pool=[p for p in pool if p not in used]
    return pool[:k]

def format_answer(query, hits):
    if not hits: return "پاسخی پیدا نشد."
    summary = pick_summary(hits)
    definition = pick_definition(hits, query)
    steps = pick_steps(hits, k=6)
    used = set([summary, definition] + steps)
    keys = pick_keys(hits, used, k=4)
    cites=[]; seen=set()
    for h in hits[:8]:
        m=h["meta"]; kk=(m["doc"], m["chunk"])
        if kk not in seen:
            cites.append(f"- {m['doc']} | chunk {m['chunk']}")
            seen.add(kk)
    parts=[]
    parts.append(f"۱) خلاصه یک‌جمله‌ای:\n{summary}")
    parts.append(f"۲) تعریف علمی کوتاه:\n{definition}")
    if steps: parts.append("۳) مراحل گام‌به‌گام:\n" + "\n".join(f"- {s}" for s in steps))
    if keys: parts.append("۴) نکات کلیدی:\n" + "\n".join(f"- {s}" for s in keys))
    parts.append("۵) ارجاعات:\n" + "\n".join(cites))
    return "\n\n".join(parts)

def answer_persian(query:str, topk=10):
    hits = retrieve(query, topk=topk)
    return format_answer(query, hits)

# --------- Build & quick test ----------
print("✅ Index built over", len(all_chunks), "chunks.")
print(answer_persian("فتوسنتز چگونه در کلروپلاست انجام می‌شود؟"))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


configuration.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/Alibaba-NLP/new-impl:
- configuration.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/Alibaba-NLP/new-impl:
- modeling.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/611M [00:00<?, ?B/s]

Some weights of the model checkpoint at Alibaba-NLP/gte-multilingual-base were not used when initializing NewModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing NewModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing NewModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

✅ Using GTE multilingual base (trust_remote_code=True).
PDFs: ['1.pdf', '2.pdf', '3.pdf']
Extract report: [('1.pdf', 412204, 2041, 1085), ('2.pdf', 507540, 2410, 1265), ('3.pdf', 466727, 2109, 1158)]


Batches:   0%|          | 0/74 [00:00<?, ?it/s]

✅ Index built over 3508 chunks.


pre tokenize:   0%|          | 0/2 [00:00<?, ?it/s]You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
pre tokenize: 100%|██████████| 2/2 [00:00<00:00, 90.71it/s]
Compute Scores: 100%|██████████| 2/2 [00:00<00:00, 11.21it/s]

۱) خلاصه یک‌جمله‌ای:
یکی از این ویژگی ها‬ ‫داشتن مولکول های رنگیزه ای است که بتوانند انرژی نور خورشید را جذب کنند‪ .

۲) تعریف علمی کوتاه:
یکی از این ویژگی ها‬ ‫داشتن مولکول های رنگیزه ای است که بتوانند انرژی نور خورشید را جذب کنند‪ .

۳) مراحل گام‌به‌گام:
- یکی از این ویژگی ها‬ ‫داشتن مولکول های رنگیزه ای است که بتوانند انرژی نور خورشید را جذب کنند‪ .
- ‬همچنین‪ ،‬باید سامانه ای‬ ‫برای تبدیل این انرژی به انرژی شیمیایی وجود داشته باشد‪ .
- ‬بعد از مدتی‬ ‫این قطعه آگار را روی لبۀ دانه رستی قرار می دهند که نوک آن بریده شده؛
- )4‬‬ ‫در فتوسنتز‪،‬انرژی الکترون های برانگیخته در رنگیزه های موجود‬ ‫در آنتن ها از رنگیزه ای به رنگیزه دیگر منتقل و در نهایت‪ ،‬به مرکز‬ ‫واکنش می رود و در آنجا سبب ایجاد الکترون برانگیخته در سبزینۀ ‪a‬‬ ‫و خروج الکترون از آن می شود (شکل ‪.
- ‬‬ ‫محققان داد تا با استفاده از این مواد‪،‬‬ ‫همان طور که در شکل ‪ 7‬می بینید‪ ،‬تعدادی از این قندها برای ساخته شدن گلوکز و ترکیبات آلی دیگر‬ ‫فرایندهای زیستی را شناسایی کنند‪.

۴) نکات کلیدی:
- ‬انواعی از جانداران وجود دارند ک


