In [None]:
import sys
!{sys.executable} -m pip -q install requests beautifulsoup4 lxml markdownify fastembed onnxruntime chromadb rank-bm25 pymupdf

# ---------------- CONFIG ----------------
# HTML documentation sites to crawl: (tool name, start URL, path prefix that a
# crawled URL path must start with).
SITES_HTML = [
    ("BIDS", "https://bids-specification.readthedocs.io/en/stable/", "/en/stable/"),
    ("fMRIPrep", "https://fmriprep.readthedocs.io/en/stable/", "/en/stable/"),
]

# PDF manuals to download and split into paragraphs: (tool name, PDF URL).
SITES_PDF = [
    ("MRtrix", "https://media.readthedocs.org/pdf/mrtrix/latest/mrtrix.pdf"),
    ("SPM12", "https://www.fil.ion.ucl.ac.uk/spm/doc/spm12_manual.pdf"),
]

# Per-tool URL whitelist: apart from the start URL, a crawled URL path must
# contain at least one of these substrings. Tools without an entry are unrestricted.
ALLOW_SUBSTR = {
    "BIDS": ["specification", "glossary", "derivatives", "intro"],
    "fMRIPrep": ["usage", "installation", "faq", "outputs", "reports"],
}

COL_NAME = "neuro_docs"    # Chroma collection name
MAX_PAGES_PER_SITE = 40    # HTML pages saved per site; raise for fuller coverage
BATCH = 128                # chunks embedded and upserted per Chroma call

# ---------------- IMPORTS ----------------
import os, re, time, requests, fitz, numpy as np
from urllib.parse import urljoin, urlsplit
from bs4 import BeautifulSoup
from markdownify import markdownify as md
from pathlib import Path
import chromadb
from fastembed import TextEmbedding
from rank_bm25 import BM25Okapi
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# ---------------- HELPERS ----------------
def slug(s):
    """Return a filesystem-safe, lowercase version of *s*.

    Surrounding whitespace is stripped and the text is lowercased. Every run of
    characters outside ``[a-z0-9._-]`` is then replaced by a single ``-``.
    ``None`` (or any other falsy value) yields ``""``.
    """
    normalized = (s or "").strip().lower()
    return re.sub(r"[^a-z0-9._-]+", "-", normalized)
def save_text(p: Path, text: str):
    """Write *text* to the file *p* as UTF-8, creating parent directories first.

    Characters that cannot be encoded are silently dropped (``errors="ignore"``).
    """
    target_dir = p.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    p.write_text(text, encoding="utf-8", errors="ignore")

def html_to_md(html: str) -> str:
    """Convert a documentation page's HTML into Markdown.

    Only the first element (in document order) matching ``main``,
    ``.md-content``, ``div.body`` or ``div.document`` is converted. If none
    matches, the whole page is used. Headings are emitted ATX-style (``#``).

    Two footer clean-ups are then applied to the Markdown:
      * everything from a "Next ... Previous" navigation pair to the end is cut;
      * a "(c) ... Read the Docs ..." footer is cut. Because that pattern has no
        DOTALL/MULTILINE flag, it only matches when both pieces sit on the
        final line of the text.
    """
    soup = BeautifulSoup(html, "lxml")
    container = soup.select_one("main, .md-content, div.body, div.document")
    if container is None:
        container = soup
    markdown = md(str(container), heading_style="ATX")
    for footer_pattern in (r"(?s)Next\s+Previous.*$", r"©.*?Read the Docs.*$"):
        markdown = re.sub(footer_pattern, "", markdown)
    return markdown
# Crawl one documentation website and store its pages as Markdown.
def crawl_html(tool, base, prefix, outdir: Path, max_pages=120):
    """Depth-first crawl of a documentation site, saving each page as Markdown.

    Starting at *base*, a link is followed only if:
      * it stays on the same host and has no skipped file extension;
      * its path starts with *prefix*;
      * its path contains one of ``ALLOW_SUBSTR[tool]`` (when that entry exists).
    *base* itself is always allowed. Each fetched page is converted with
    ``html_to_md`` and written to ``outdir/<slug(path)>.md``. The crawl stops
    once *max_pages* pages have been saved.

    Link fragments (``#section``) are dropped before de-duplication.
    Documentation themes emit an in-page anchor for every heading. Keeping
    the fragment would make each anchor look like a new URL, re-downloading
    and re-saving the same page and consuming the page budget. Responses with
    an HTTP error status or a non-HTML content type are skipped rather than
    saved as documents.
    """
    host = urlsplit(base).netloc
    seen, stack, saved = set(), [base], 0
    skip_ext = re.compile(r"\.(png|jpg|gif|svg|pdf|zip|tar\.gz|ico)$", re.I)
    allow = ALLOW_SUBSTR.get(tool)
    headers = {"User-Agent": "Mozilla/5.0"}

    def ok(url):
        # True if *url* should be visited under the host/prefix/whitelist rules.
        p = urlsplit(url)
        if p.netloc != host or skip_ext.search(p.path):
            return False
        if url == base:
            return True
        if not p.path.startswith(prefix):
            return False
        return True if not allow else any(s in p.path for s in allow)

    while stack and saved < max_pages:
        url = stack.pop()
        if url in seen or not ok(url):
            continue
        seen.add(url)
        try:
            resp = requests.get(url, timeout=30, headers=headers)
            resp.raise_for_status()
        except requests.RequestException:
            # Best effort: an unreachable or erroring page is skipped.
            continue
        ctype = resp.headers.get("Content-Type", "")
        if ctype and "html" not in ctype.lower():
            # e.g. a download served from a URL without a file extension
            continue
        html = resp.text
        rel = (urlsplit(url).path or "/").strip("/")
        save_text(outdir / f"{slug(rel or 'index')}.md", html_to_md(html))
        saved += 1
        soup = BeautifulSoup(html, "lxml")
        for a in soup.select("a[href]"):
            # The fragment only addresses a spot inside the same document.
            stack.append(urljoin(url, a["href"]).partition("#")[0])
    print(f"[{tool}] HTML pages saved: {saved} -> {outdir}")

# PDF downloader
def fetch_pdf(url: str, out_dir: Path) -> Path:
    """Download the PDF at *url* into *out_dir* and return the saved path.

    The file name is the slugified last path component of *url* (``doc.pdf``
    if empty), with ``.pdf`` appended when missing. Status codes 429/500/502/
    503/504 are retried up to 5 times with backoff.

    The body is streamed into a temporary ``<name>.part`` file. That file is
    moved into place only after the download finished and its first 1 KiB
    contains the ``%PDF`` signature. A failed or non-PDF download (e.g. an
    HTML error page) therefore never leaves a truncated or bogus file for the
    parser. The session and the streamed response are closed in all cases.

    Raises:
        requests.HTTPError: non-success status after retries.
        requests.RequestException: other transport errors.
        ValueError: the response body is not a PDF.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    name = slug(Path(urlsplit(url).path).name or "doc.pdf")
    if not name.endswith(".pdf"):
        name += ".pdf"
    out_path = out_dir / name
    tmp_path = out_path.with_name(out_path.name + ".part")

    retry = Retry(total=5, backoff_factor=0.4,
                  status_forcelist=[429, 500, 502, 503, 504],
                  allowed_methods=["GET", "HEAD"])
    adapter = HTTPAdapter(max_retries=retry)
    with requests.Session() as sess:
        sess.headers.update({"User-Agent": "Mozilla/5.0"})
        sess.mount("https://", adapter)
        sess.mount("http://", adapter)   # redirects may downgrade the scheme
        try:
            with sess.get(url, timeout=60, stream=True, allow_redirects=True) as r:
                r.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for ch in r.iter_content(chunk_size=1 << 15):
                        if ch:
                            f.write(ch)
            with open(tmp_path, "rb") as f:
                head = f.read(1024)
            if b"%PDF" not in head:
                raise ValueError(f"{url} did not return a PDF")
            tmp_path.replace(out_path)
        finally:
            # Only left over when something above failed.
            tmp_path.unlink(missing_ok=True)
    return out_path

def pdf_paragraphs(pdf_path: Path):
    """Yield ``(page_no, para_no, text)`` for every paragraph of a PDF.

    On each page, text blocks are ordered top-to-bottom, then left-to-right.
    Image blocks are skipped. Each text block is split on blank lines, and
    every piece is reflowed:
      * words hyphenated across a line break are re-joined;
      * line breaks and whitespace runs become single spaces.
    Pieces of 6 words or fewer (running headers, page numbers, short captions)
    are dropped. Both numbers are 1-based; paragraph numbers restart on every
    page. The document is closed when the generator finishes or is closed.
    """
    def clean(s: str) -> str:
        s = re.sub(r'-\n(?=\w)', '', s)   # re-join words hyphenated at line end
        s = re.sub(r'\s*\n\s*', ' ', s)   # line breaks -> single space
        return re.sub(r'\s{2,}', ' ', s).strip()

    with fitz.open(pdf_path) as doc:
        for pno, page in enumerate(doc, 1):
            blocks = sorted(page.get_text("blocks"), key=lambda b: (b[1], b[0]))
            paras = []
            for b in blocks:
                # Block tuple: (x0, y0, x1, y1, text, block_no, block_type).
                # block_type != 0 is an image whose "text" is only a placeholder.
                if len(b) > 6 and b[6] != 0:
                    continue
                txt = b[4] if len(b) > 4 and isinstance(b[4], str) else ""
                if not txt.strip():
                    continue
                # Split BEFORE cleaning: clean() removes every newline, so
                # blank-line paragraph breaks must be found on the raw text.
                for raw in re.split(r'\n\s*\n', txt):
                    p = clean(raw)
                    if len(p.split()) > 6:
                        paras.append(p)
            for j, para in enumerate(paras, 1):
                yield pno, j, para

def chunk_md(text: str, size=480, overlap=60, min_words=30):
    """Yield overlapping word windows of *text* for embedding.

    Each window holds up to *size* whitespace-separated words. Consecutive
    windows start ``size - overlap`` words apart (at least 1), so they share
    *overlap* words. Windows of *min_words* words or fewer are not yielded.

    Iteration stops after the first window that reaches the end of the text.
    This guarantees no trailing window lies entirely inside the previous one.
    Empty or whitespace-only text yields nothing.
    """
    words = text.split()
    if not words:
        return
    step = max(1, size - overlap)
    for start in range(0, len(words), step):
        window = words[start:start + size]
        if len(window) > min_words:
            yield " ".join(window)
        if start + size >= len(words):
            break

def reflow(t: str) -> str:
    """Undo hard line wrapping in extracted text.

    Steps, in order:
      * words hyphenated across a line break are re-joined;
      * each line break, with its surrounding whitespace, becomes one space;
      * remaining runs of two or more whitespace characters collapse to one space;
      * the result is stripped.
    """
    dehyphenated = re.sub(r'-\n(?=\w)', '', t)
    single_line = re.sub(r'\s*\n\s*', ' ', dehyphenated)
    collapsed = re.sub(r'\s{2,}', ' ', single_line)
    return collapsed.strip()

ROOT = Path("corpus_multi")
# Parallel lists, one entry per chunk:
#   all_docs:  chunk text
#   all_ids:   unique chunk id
#   all_metas: metadata with "tool" and "source", plus "page"/"para" for PDFs
all_docs, all_ids, all_metas = [], [], []

# HTML (BIDS + fMRIPrep): crawl each site, then split every saved page into word windows.
for tool, base, pref in SITES_HTML:
    out = ROOT / "html" / slug(tool)
    crawl_html(tool, base, pref or urlsplit(base).path, out, max_pages=MAX_PAGES_PER_SITE)
    for mdfile in sorted(out.rglob("*.md")):
        txt = mdfile.read_text(encoding="utf-8", errors="ignore")
        for k, chunk in enumerate(chunk_md(txt), 1):
            all_docs.append(chunk)
            all_ids.append(f"{tool}__{mdfile.stem}__chunk{k:03}")
            all_metas.append({"tool": tool, "source": str(mdfile)})

# PDFs (MRtrix + SPM12): download each manual and index one entry per paragraph.
PDF_DIR = ROOT / "pdf"; PDF_DIR.mkdir(parents=True, exist_ok=True)
for tool, url in SITES_PDF:
    try:
        pdf_path = fetch_pdf(url, PDF_DIR / slug(tool))
        print(f"[{tool}] saved PDF:", pdf_path.name)
        # Parse completely before touching the shared lists. A corrupt or
        # non-PDF file then skips just this tool, instead of aborting the
        # whole run or leaving a half-appended manual in the corpus.
        paras = list(pdf_paragraphs(pdf_path))
    except Exception as e:
        print(f"[{tool}] SKIP {url} ->", e); continue
    for pno, j, para in paras:
        all_docs.append(para)
        all_ids.append(f"{tool}__{pdf_path.stem}__p{pno:03}__para{j:02}")
        all_metas.append({"tool": tool, "source": str(pdf_path), "page": pno, "para": j})

print("TOTAL chunks:", len(all_docs))

# Embedding: store every chunk of the corpus in a persistent Chroma collection.
client = chromadb.PersistentClient(path="vector_db")
col    = client.get_or_create_collection(COL_NAME)

# Ids already stored by earlier runs; those chunks are not embedded again.
existing = set()
try:
    got = col.get(include=[], limit=100000)
    existing.update(got.get("ids", []))
except Exception as e:
    # Best effort: without the id list every chunk is simply (re-)embedded.
    print("Could not list existing ids; embedding all chunks:", e)

embedder = TextEmbedding("BAAI/bge-small-en-v1.5")
t0 = time.time()
for i in range(0, len(all_docs), BATCH):
    # Corpus positions in this batch whose id is not stored yet. Only these
    # are embedded and upserted, so a partially-new batch does not re-embed
    # the chunks that are already indexed.
    keep = [j for j in range(i, min(i + BATCH, len(all_docs)))
            if all_ids[j] not in existing]
    if not keep:
        continue

    docs_b = [str(all_docs[j]) for j in keep]    # Chroma expects str documents
    ids_b  = [str(all_ids[j]) for j in keep]     # ... and str ids
    mets_b = [dict(all_metas[j]) for j in keep]  # copies, so Chroma never shares our dicts

    embs_b = [[float(x) for x in v] for v in embedder.embed(docs_b)]

    if not (len(ids_b) == len(docs_b) == len(mets_b) == len(embs_b)):
        raise RuntimeError(
            f"batch size mismatch: {len(ids_b)} ids, {len(docs_b)} docs, "
            f"{len(mets_b)} metadatas, {len(embs_b)} embeddings")

    col.upsert(
        ids=ids_b,
        embeddings=embs_b,
        metadatas=mets_b,
        documents=docs_b
    )
    existing.update(ids_b)
    if (i//BATCH) % 2 == 0 or i+BATCH >= len(all_docs):
        print(f"Indexed {min(i+BATCH,len(all_docs))}/{len(all_docs)}")

print("Embedding done.")

# ---------------- RETRIEVAL INDEXES ----------------
from collections import defaultdict

# Source path -> positions (into all_docs / all_ids / all_metas) of its chunks.
file_to_idxs = defaultdict(list)
for i, m in enumerate(all_metas):
    file_to_idxs[m["source"]].append(i)

# One BM25 "document" per source file: all of its chunks joined by spaces,
# tokenised on whitespace (case-sensitive, punctuation kept).
file_list = sorted(file_to_idxs)
file_texts = [" ".join([all_docs[n] for n in file_to_idxs[src]]) for src in file_list]
bm25_doc = BM25Okapi([text.split() for text in file_texts])

# Chunk id -> position in all_ids (if an id repeats, its last position wins).
id_to_idx = dict(zip(all_ids, range(len(all_ids))))

def _z(a):
    """Return the z-scores of *a* as a float ndarray.

    Uses the population standard deviation. ``1e-6`` is added to the
    denominator so a constant input maps to all zeros instead of dividing by zero.
    """
    values = np.asarray(a, dtype=float)
    centered = values - values.mean()
    return centered / (values.std() + 1e-6)
# Hybrid retrieval: choose candidate files first, then the best chunk inside them.
def best_idx_docfirst(q, k_files=8, k_vec=80, k_bm=80, w_vec=0.7, w_bm=0.3):
    """Return the corpus index of the chunk that best matches query *q*.

    Steps:
      1. Score every source file with the file-level BM25 index ``bm25_doc``
         and keep the top *k_files* files as candidates.
      2. Vector-search the Chroma collection for the *k_vec* nearest chunks,
         restricted to the candidate files (global if there are none).
      3. Score every chunk of the candidate files with a BM25 index built for
         this query.
      4. Z-normalise both score lists over the union of chunks and fuse them
         as ``w_vec * vec + w_bm * bm``. Return the highest fused chunk.

    *k_bm* is accepted for backward compatibility but is not used.

    Returns:
        int index into all_docs / all_ids / all_metas.

    Raises:
        LookupError: no chunk of the current corpus could be retrieved.
    """
    qtok = q.split()

    # 1) File-level lexical pre-selection.
    scores_doc = bm25_doc.get_scores(qtok)
    top_idx = np.argsort(scores_doc)[::-1][:max(1, min(k_files, len(scores_doc)))]
    cand_files = [file_list[i] for i in top_idx]
    cand_idxs = [i for f in cand_files for i in file_to_idxs.get(f, [])]

    # 2) Dense retrieval. Chroma returns distances: smaller means more similar.
    q_emb = next(iter(embedder.embed([q]))).tolist()
    query = {
        "query_embeddings": [q_emb],
        "n_results": max(1, min(k_vec, len(all_docs))),
        "include": ["distances"],
    }
    if cand_files:
        query["where"] = {"source": {"$in": cand_files}}
    v = col.query(**query)
    hit_ids = (v.get("ids") or [[]])[0]
    hit_dists = (v.get("distances") or [[]])[0]
    # Pair ids with distances BEFORE dropping ids unknown to this run (stale
    # entries in the persistent collection). Filtering first would shift the
    # remaining distances onto the wrong chunks.
    vec_sims = {id_to_idx[cid]: -float(d)
                for cid, d in zip(hit_ids, hit_dists) if cid in id_to_idx}

    # 3) Chunk-level BM25 over every chunk of the candidate files.
    pool_for_bm = cand_idxs or list(vec_sims)
    bm_scores = {}
    if pool_for_bm:
        bm_local = BM25Okapi([all_docs[i].split() for i in pool_for_bm])
        bm_scores = dict(zip(pool_for_bm, map(float, bm_local.get_scores(qtok))))

    pool = sorted(set(vec_sims) | set(pool_for_bm))
    if not pool:
        # Only possible when there were no candidate files (so step 2 already
        # searched globally) and that search returned no chunk of this corpus.
        raise LookupError(f"no indexed chunk found for query {q!r}")

    # 4) Fusion. A chunk the vector search did not return is less similar than
    # every chunk it did return, so it gets the worst observed similarity.
    # A default of 0.0 would exceed every (negative) -distance and rank
    # unretrieved chunks above retrieved ones.
    vec_floor = min(vec_sims.values()) if vec_sims else 0.0
    v_z = _z([vec_sims.get(i, vec_floor) for i in pool])
    b_z = _z([bm_scores.get(i, 0.0) for i in pool])
    fused = {i: w_vec * vz + w_bm * bz for i, vz, bz in zip(pool, v_z, b_z)}

    # The file holding the best chunk is the best file, so the global argmax
    # is exactly "best chunk of the best-scoring file".
    return max(pool, key=fused.__getitem__)
# Answer a question with the best passage and a citation.
def answer(q, neighbors=1, max_chars=900):
    """Return ``(text, citation)`` for query *q*.

    The winning chunk comes from ``best_idx_docfirst``. Its context is built
    as follows:
      * Neighbours are the chunks with the same source and the same page.
        HTML chunks have no "page", so for them the whole file counts.
      * Up to *neighbors* positions are taken on either side, in corpus order.
      * The text is whitespace-collapsed.
      * Longer than *max_chars*: cut there, trim back to the last ". " and end with ".".

    The citation is the source file name, followed by " (p<page>)" for PDF chunks.
    """
    best = best_idx_docfirst(q)
    meta = all_metas[best]
    src = meta["source"]
    page = meta.get("page")

    siblings = [n for n, other in enumerate(all_metas)
                if other["source"] == src and other.get("page") == page]
    if not siblings:
        siblings = [n for n, other in enumerate(all_metas) if other["source"] == src]
    siblings.sort()
    at = siblings.index(best)
    window = siblings[max(0, at - neighbors):at + neighbors + 1]

    joined = " ".join(all_docs[n] for n in window)
    text = re.sub(r'\s{2,}', ' ', re.sub(r'\s*\n\s*', ' ', joined)).strip()
    if len(text) > max_chars:
        text = text[:max_chars].rsplit(". ", 1)[0] + "."

    citation = Path(src).name
    if page:
        citation += f" (p{page})"
    return text, citation
