In [1]:
import sys, subprocess, pkgutil, os
def pip_install(pkgs):
  """Quietly pip-install the given requirement strings into the running interpreter.

  Args:
    pkgs: list of requirement strings appended to ``pip install -q``.

  Raises:
    subprocess.CalledProcessError: if pip exits with a non-zero status.
  """
  base_cmd = [sys.executable, "-m", "pip", "install", "-q"]
  subprocess.run(base_cmd + pkgs, check=True)
import importlib.util

# pip distribution names this notebook depends on ("markdown" is imported below as `md`).
reqs = [
    "transformers","accelerate","bitsandbytes",
    "sentence-transformers","faiss-cpu","rank-bm25",
    "trafilatura","pypdf","beautifulsoup4","markdown",
    "tqdm","orjson","requests","arxiv","gradio"
]

# Distributions whose importable module name is not simply the pip name with "-" -> "_".
_IMPORT_NAMES = {"faiss-cpu": "faiss", "beautifulsoup4": "bs4"}

def _is_importable(dist_name: str) -> bool:
  """Return True if the module provided by pip distribution *dist_name* can be found."""
  module_name = _IMPORT_NAMES.get(dist_name, dist_name.replace("-", "_"))
  try:
    # find_spec replaces the deprecated pkgutil.find_loader.
    return importlib.util.find_spec(module_name) is not None
  except (ImportError, ValueError):
    # A broken or half-initialised module is treated as missing so pip can repair it.
    return False

missing = [p for p in reqs if not _is_importable(p)]
if missing: pip_install(missing)

# --- Imports ---
import re, json, warnings, shutil
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple
from urllib.parse import quote, urlparse
import numpy as np

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
from rank_bm25 import BM25Okapi

from pypdf import PdfReader
import markdown as md
from bs4 import BeautifulSoup
import trafilatura, orjson, requests, arxiv
import gradio as gr  # UI

# Hide DeprecationWarning messages for the rest of the session.
warnings.filterwarnings("ignore", category=DeprecationWarning)
# Set before any tokenizer is created further down.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# --- Config ---
# Output shaping
BULLET = "•"            # bullet marker the LLM is asked to emit
MAX_BULLETS = 12        # upper bound on bullets in the final answer
WEB_BULLET_MIN = 8
AUTHORITY_MIN = 2
CHUNK_SIZE = 1000       # target characters per sentence-aligned chunk
TIMEOUT = 20            # seconds, passed to the HTTP requests below
LOCAL_DIR = "/content/local_docs"   # where uploaded files are stored and ingested from

# Retrieval
NUM_HYDE = 3            # hypothetical answer paragraphs per question
QUERY_VARIANTS = 3      # paraphrased search queries per question
BM25_TOP = 200          # candidates kept per query from BM25
DENSE_TOP = 200         # candidates kept per query from the dense index
RRF_K = 60              # constant in the reciprocal-rank-fusion denominator
RERANKER_ID = "BAAI/bge-reranker-v2-m3"
RERANK_MIN = 0.25       # minimum reranker score for a chunk to be kept

# Web budgets (per-source result limits)
WIKI_MAX_PAGES = 12
OPENALEX_MAX_WORKS = 50
ARXIV_MAX_RESULTS = 40
CROSSREF_ROWS = 60
PUBMED_MAX_PMIDS = 50

# Authority pages that are always fetched and listed first during chunk selection
AUTHORITY_URLS = [
  "https://www.nccoe.nist.gov/crypto-agility-considerations-migrating-post-quantum-cryptographic-algorithms",
  "https://csrc.nist.gov/pubs/ir/8547/ipd",
  "https://www.cisa.gov/quantum",
  "https://www.etsi.org/technologies/quantum-safe-cryptography",
]

# Wikipedia titles appended after the search results
WIKI_FALLBACK_TITLES = [
  "Post-quantum cryptography",
  "Shor's algorithm",
  "Grover's algorithm",
  "RSA (cryptosystem)",
  "Elliptic-curve cryptography",
  "Quantum key distribution",
  "CRYSTALS-Kyber",
  "CRYSTALS-Dilithium",
  "SPHINCS+",
  "Harvest now, decrypt later",
]

# Maximum number of selected chunks per host
DOMAIN_CAP = 3

# Intent hints (word stems)
MITIGATION_HINTS = {
  "mitigat", "strategy", "strategies", "propos", "recommend", "roadmap",
  "guidance", "migrate", "adopt", "transition", "countermeas", "best practice",
}
EXPLANATION_HINTS = {
  "explain", "impact", "effect", "cause", "risk", "threat", "consequence", "challenge",
}

# --- Helpers ---
def sha1(s: str) -> str:
  """Return the hexadecimal SHA-1 digest of *s* encoded as UTF-8."""
  from hashlib import sha1 as _sha1
  digest = _sha1(s.encode("utf-8"))
  return digest.hexdigest()

def host_of(url: str) -> str:
  """Return the lower-cased network location (host[:port]) of *url*.

  Returns "" when the URL has no network location or cannot be parsed.
  """
  try:
    return urlparse(url).netloc.lower()
  except (ValueError, TypeError, AttributeError):
    # ValueError: malformed URL (e.g. an unbalanced IPv6 bracket);
    # TypeError/AttributeError: a non-string argument. Anything else should surface.
    return ""

def split_sentences_regex(text: str) -> List[str]:
  """Split *text* into sentences at whitespace that follows '.', '!' or '?'.

  Each sentence is stripped; empty pieces are discarded.
  """
  pieces = (piece.strip() for piece in re.split(r'(?<=[.!?])\s+', text.strip()))
  return [piece for piece in pieces if piece]

def chunk_text_sentence_aligned(text: str, size: int = CHUNK_SIZE) -> List[str]:
  """Greedily pack whole sentences of *text* into chunks of roughly *size* characters.

  Each sentence counts as its length plus one separator character. A chunk is
  closed as soon as the next sentence would push it past *size*; a sentence
  longer than *size* is kept whole as its own chunk.
  """
  chunks = []
  pending = []
  pending_len = 0
  for sentence in split_sentences_regex(text):
    cost = len(sentence) + 1
    if pending and pending_len + cost > size:
      # The current chunk is full: emit it and start a new one with this sentence.
      chunks.append(" ".join(pending))
      pending = []
      pending_len = 0
    pending.append(sentence)
    pending_len += cost
  if pending:
    chunks.append(" ".join(pending))
  return chunks

def md_to_text(md_content: str) -> str:
  """Render Markdown to HTML, then return its text with newline separators."""
  rendered = md.markdown(md_content)
  soup = BeautifulSoup(rendered, "html.parser")
  return soup.get_text("\n")

def read_pdf(path: str) -> str:
  """Extract the text of every page of the PDF at *path*, joined by newlines.

  Extraction is best-effort per page: a page that fails to extract is skipped
  so one bad page does not discard the rest of the document. Errors raised
  while opening the file propagate to the caller.
  """
  reader = PdfReader(path)
  pages = []
  for page in reader.pages:
    try:
      pages.append(page.extract_text() or "")
    except Exception:
      # Skip the unreadable page, but let KeyboardInterrupt/SystemExit through.
      continue
  return "\n".join(pages)

def read_local_file(path: str) -> str:
  """Read a local document as plain text, dispatching on the file extension.

  ``.pdf`` files go through ``read_pdf``; ``.md``/``.markdown`` files are read
  as UTF-8 and converted with ``md_to_text``; every other file is returned as
  UTF-8 text. Undecodable bytes are ignored.
  """
  suffix = os.path.splitext(path)[1].lower()
  if suffix == ".pdf":
    return read_pdf(path)
  with open(path, "r", encoding="utf-8", errors="ignore") as handle:
    raw = handle.read()
  if suffix in (".md", ".markdown"):
    return md_to_text(raw)
  return raw

def fetch_clean_text(url: str) -> str:
  """Download *url* and return its body text with short/boilerplate lines removed.

  Extraction is tried with trafilatura first; if that yields nothing the page
  is fetched with ``requests`` and flattened with BeautifulSoup. Lines shorter
  than 40 characters, lines mentioning navigation/social keywords, and bare
  section headings are dropped. Returns "" when nothing could be fetched.
  """
  text=""
  try:
    # fetch_url takes no `timeout` keyword; passing one raised a TypeError that the
    # old bare `except` hid, so trafilatura was never actually used.
    downloaded = trafilatura.fetch_url(url)
    if downloaded:
      text = trafilatura.extract(downloaded, include_comments=False, include_tables=False) or ""
  except Exception:
    text=""  # best-effort: fall through to the plain requests path
  if not text:
    try:
      r = requests.get(url, timeout=TIMEOUT, headers={"User-Agent":"Mozilla/5.0"})
      if r.ok:
        text = BeautifulSoup(r.text,"html.parser").get_text("\n")
    except Exception:
      text=""
  # Whole-word match so that e.g. "research" is not dropped for containing "search".
  boilerplate = re.compile(r"\b(?:facebook|twitter|linkedin|search|menu|subscribe|sign up|member portal)\b")
  headings = {"introduction","abstract","keywords","conclusion"}
  clean=[]
  for ln in (raw.strip() for raw in text.splitlines()):
    low=ln.lower()
    if len(ln)<40: continue
    if boilerplate.search(low): continue
    if low in headings: continue
    clean.append(ln)
  return "\n".join(clean)

# --- Data types ---
@dataclass
class Chunk:
  """One retrievable piece of text together with its provenance."""
  # Chunk identifier built by the ingest functions ("L…" for local, "W…" for web).
  id: str
  # Either "local" or "web".
  source_type: str
  # File path for local chunks, URL (or source identifier) for web chunks.
  source: str
  # Display title of the originating document.
  title: str
  # Position of this chunk within its document.
  chunk_index: int
  # The chunk's text content.
  text: str

# --- LLM (Qwen2.5 4-bit) for HyDE + finalization ---
MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
nf4 = BitsAndBytesConfig(
  load_in_4bit=True,
  bnb_4bit_quant_type="nf4",
  bnb_4bit_use_double_quant=True,
  bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
  MODEL_ID,
  quantization_config=nf4,
  device_map="auto",
)

def chat(messages, max_new_tokens=850):
  """Run one greedy chat completion and return only the assistant's reply.

  Args:
    messages: list of ``{"role": ..., "content": ...}`` dicts for the chat template.
    max_new_tokens: generation budget for the reply.

  Returns:
    The generated reply, stripped of surrounding whitespace.
  """
  prompt=tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  inputs=tokenizer([prompt], return_tensors="pt").to(model.device)
  with torch.no_grad():
    outputs=model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=tokenizer.eos_token_id)
  # Decode only the newly generated tokens. Decoding the full sequence with
  # skip_special_tokens=True strips the chat-template markers, so the rendered
  # prompt never matches the decoded text and the whole prompt leaked into the reply.
  prompt_len = inputs["input_ids"].shape[1]
  reply_ids = outputs[0][prompt_len:]
  return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()

# --- Retrieval: BGE-M3 + FAISS + BM25 + CrossEncoder rerank ---
EMBED_ID="Shitao/bge-m3"
class VectorIndex:
  """Hybrid index over chunks: a FAISS inner-product index plus a BM25 index.

  ``meta[i]`` is the chunk stored at index position ``i``. Both search methods
  return ``{position: best_rank}``, where ``best_rank`` is the smallest 0-based
  rank the position reached across all queries.
  """

  def __init__(self, model_name=EMBED_ID, device="cpu"):
    self.model = SentenceTransformer(model_name, device=device)
    self.dim = self.model.get_sentence_embedding_dimension()
    self.index = faiss.IndexFlatIP(self.dim)
    self.meta: List[Chunk] = []
    self._bm25 = None
    self._bm25_corpus = None

  def add(self, chunks: List[Chunk]):
    """Embed and index *chunks*; the BM25 index is rebuilt over the whole corpus."""
    if not chunks:
      return
    texts = [chunk.text for chunk in chunks]
    vectors = self.model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
    self.index.add(vectors.astype(np.float32))
    self.meta.extend(chunks)
    tokenized = [re.findall(r"\w+", text.lower()) for text in texts]
    if self._bm25 is None:
      self._bm25_corpus = tokenized
    else:
      self._bm25_corpus = self._bm25_corpus + tokenized
    self._bm25 = BM25Okapi(self._bm25_corpus)

  def dense_search(self, queries: List[str], topk:int)->Dict[int,int]:
    """Search the FAISS index with each query and keep every hit's best rank."""
    best: Dict[int,int] = {}
    query_vecs = self.model.encode(queries, convert_to_numpy=True, normalize_embeddings=True).astype(np.float32)
    for vec in query_vecs:
      _, neighbours = self.index.search(vec[None,:], topk)
      for position, doc_id in enumerate(neighbours[0].tolist()):
        if doc_id == -1:
          # -1 entries are skipped (no result for that slot).
          continue
        key = int(doc_id)
        if position < best.get(key, 10**9):
          best[key] = position
    return best

  def bm25_search(self, queries: List[str], topk:int)->Dict[int,int]:
    """Score the corpus with BM25 for each query and keep every hit's best rank."""
    best: Dict[int,int] = {}
    if self._bm25 is None:
      return best
    for query in queries:
      query_tokens = re.findall(r"\w+", query.lower())
      doc_scores = self._bm25.get_scores(query_tokens)
      leaders = np.argsort(doc_scores)[::-1][:topk]
      for position, doc_id in enumerate(leaders.tolist()):
        key = int(doc_id)
        if position < best.get(key, 10**9):
          best[key] = position
    return best

# Cross-encoder loaded once at module level; retrieve() uses it to score (question, chunk text) pairs.
reranker=CrossEncoder(RERANKER_ID, device=device)

def rrf_fuse(rank_lists: List[Dict[int,int]], kconst:int=RRF_K, topn:int=300)->List[int]:
  """Merge several ``{item: rank}`` maps with reciprocal rank fusion.

  Each appearance of an item adds ``1 / (kconst + rank + 1)`` to its score.
  Returns the *topn* item ids ordered by descending fused score.
  """
  fused: Dict[int,float] = {}
  for ranking in rank_lists:
    for item in ranking:
      fused[item] = fused.get(item, 0.0) + 1.0 / (kconst + ranking[item] + 1)
  best_first = sorted(fused, key=fused.get, reverse=True)
  return best_first[:topn]

# --- Ingestion (local + broad web) ---
def ingest_local(directory=LOCAL_DIR)->List[Chunk]:
  """Walk *directory* recursively and turn every readable file into local chunks.

  The directory is created if missing. Files that cannot be read, or whose
  stripped text is shorter than 40 characters, are skipped.
  """
  chunks=[]; os.makedirs(directory, exist_ok=True)
  for root,_,files in os.walk(directory):
    for fname in files:
      path=os.path.join(root,fname)
      try:
        text=read_local_file(path)
      except Exception:
        # Best-effort ingestion: one unreadable file must not abort the run.
        text=""
      if not text or len(text.strip())<40: continue
      for i,part in enumerate(chunk_text_sentence_aligned(text, CHUNK_SIZE)):
        # Hash the full path, not just the file name, so same-named files in
        # different subfolders do not end up with identical chunk ids.
        cid=f"L{sha1(f'{path}::{i}')[:8]}#{i}"
        chunks.append(Chunk(cid,"local",path,fname[:120],i,part))
  return chunks

# Wikipedia
def wiki_search_pages(query: str, limit:int=WIKI_MAX_PAGES, lang:str="en"):
  """Search Wikipedia's REST search endpoint and return the list of page dicts.

  Returns an empty list on any network, HTTP or decoding failure.
  """
  try:
    r=requests.get(f"https://{lang}.wikipedia.org/w/rest.php/v1/search/page", params={"q":query,"limit":limit}, headers={"User-Agent":"colab-rag-agent/0.1"}, timeout=TIMEOUT)
    if r.ok:
      # `or []` guards against a null "pages" value leaking to the caller's loop.
      return r.json().get("pages") or []
  except Exception:
    # Best-effort source: swallow request/JSON errors but not KeyboardInterrupt.
    return []
  return []
def wiki_get_plain(title: str, lang:str="en")->str:
  """Return the plain-text content of the Wikipedia article *title*, or "" on failure.

  Uses the Action API's TextExtracts query (``prop=extracts&explaintext``),
  following redirects. NOTE(review): the previous URL
  ``/w/rest.php/v1/page/plain/{title}`` does not appear to be a MediaWiki REST
  route, so it likely never returned text — confirm against the API docs.
  """
  params={
    "action":"query","prop":"extracts","explaintext":1,"redirects":1,
    "titles":title,"format":"json","formatversion":2,
  }
  try:
    r=requests.get(f"https://{lang}.wikipedia.org/w/api.php", params=params, headers={"User-Agent":"colab-rag-agent/0.1"}, timeout=TIMEOUT)
    if not r.ok: return ""
    pages=(r.json().get("query") or {}).get("pages") or []
    for page in pages:
      extract=page.get("extract")
      if extract: return extract
  except Exception:
    # Best-effort source: any request/JSON/shape problem yields no text.
    return ""
  return ""
def ingest_wikipedia(query: str, max_pages=WIKI_MAX_PAGES, min_chars=80)->List[Chunk]:
  """Search Wikipedia for *query* and chunk the plain text of the top pages.

  Search hits come first, followed by ``WIKI_FALLBACK_TITLES`` not already
  found; only the first *max_pages* titles are fetched. Pages with fewer than
  *min_chars* characters are skipped, and at most six chunks are kept per page.
  """
  ordered_titles=[]
  known=set()
  for page in wiki_search_pages(query, limit=max_pages, lang="en"):
    name=page.get("title") or page.get("key")
    if name and name not in known:
      ordered_titles.append(name)
      known.add(name)
  ordered_titles.extend(t for t in WIKI_FALLBACK_TITLES if t not in known)

  collected=[]
  for title in ordered_titles[:max_pages]:
    body=wiki_get_plain(title,"en")
    if not body or len(body.strip())<min_chars:
      continue
    url=f"https://en.wikipedia.org/wiki/{quote(title)}"
    label=f"Wikipedia: {title}"[:120]
    parts=chunk_text_sentence_aligned(body, CHUNK_SIZE)[:6]
    for j,part in enumerate(parts):
      cid=f"W{sha1(f'{url}::{j}')[:8]}#{j}"
      collected.append(Chunk(cid,"web",url,label,j,part))
  return collected

# OpenAlex
def reconstruct_openalex_abstract(inv_idx: Dict[str,List[int]])->str:
  """Rebuild abstract text from an inverted index mapping word -> positions.

  Words are placed at their positions and joined with single spaces; unfilled
  positions are skipped. Returns "" for an empty index or one with no positions.
  """
  if not inv_idx:
    return ""
  placements=[(position, word) for word, positions in inv_idx.items() for position in positions]
  if not placements:
    return ""
  slots=[""]*(max(position for position, _ in placements)+1)
  for position, word in placements:
    slots[position]=word
  return " ".join(filter(None, slots))
def ingest_openalex(query: str, max_works=OPENALEX_MAX_WORKS, min_chars=80)->List[Chunk]:
  """Search OpenAlex works for *query* and chunk their reconstructed abstracts.

  At most ``min(50, max_works)`` works are requested. Works whose abstract is
  missing or shorter than *min_chars* are skipped; at most six chunks are kept
  per work. On a request or decoding failure the chunks gathered so far are returned.
  """
  out=[]
  try:
    r=requests.get("https://api.openalex.org/works", params={"search":query,"per_page":min(50,max_works)}, headers={"User-Agent":"colab-rag-agent/0.1"}, timeout=TIMEOUT)
    if not r.ok: return out
    for w in r.json().get("results") or []:
      title=(w.get("display_name") or "OpenAlex Work")[:120]
      abstract=reconstruct_openalex_abstract(w.get("abstract_inverted_index") or {})
      if not abstract or len(abstract.strip())<min_chars: continue
      # Prefer the open-access URL, then the landing page, then the OpenAlex id.
      src=(w.get("open_access") or {}).get("oa_url") or (w.get("primary_location") or {}).get("landing_page_url") or w.get("id") or f"openalex:{w.get('id','')}"
      for j,part in enumerate(chunk_text_sentence_aligned(abstract, CHUNK_SIZE)[:6]):
        cid=f"W{sha1(f'{src}::{j}')[:8]}#{j}"
        out.append(Chunk(cid,"web",src,f"OpenAlex: {title}",j,part))
  except Exception:
    # Best-effort source: keep what was collected, let KeyboardInterrupt propagate.
    return out
  return out

# arXiv
def ingest_arxiv(query: str, max_results=ARXIV_MAX_RESULTS, min_chars=80)->List[Chunk]:
  """Search arXiv for *query* (sorted by relevance) and chunk the result summaries.

  Results whose summary is missing or shorter than *min_chars* are skipped; at
  most six chunks are kept per result. On failure the chunks gathered so far
  are returned.
  """
  out=[]
  try:
    search=arxiv.Search(query=query, max_results=max_results, sort_by=arxiv.SortCriterion.Relevance)
    # arxiv.py 2.x deprecates Search.results() in favour of Client.results(search).
    for res in arxiv.Client().results(search):
      title=(res.title or "arXiv")[:120]; summary=(res.summary or "").strip()
      if not summary or len(summary)<min_chars: continue
      src=res.entry_id or (res.doi or "")
      for j,part in enumerate(chunk_text_sentence_aligned(summary, CHUNK_SIZE)[:6]):
        cid=f"W{sha1(f'{src}::{j}')[:8]}#{j}"
        out.append(Chunk(cid,"web",src,f"arXiv: {title}",j,part))
  except Exception:
    # Best-effort source: keep what was collected, let KeyboardInterrupt propagate.
    return out
  return out

# Crossref
def strip_html(s: str)->str:
  """Return the text content of HTML fragment *s* (space-separated), or "" if *s* is empty."""
  if not s:
    return ""
  return BeautifulSoup(s,"html.parser").get_text(" ")
def ingest_crossref(query: str, rows=CROSSREF_ROWS, min_chars=80)->List[Chunk]:
  """Search Crossref works for *query* and chunk the abstracts that are present.

  Items whose (HTML-stripped) abstract is missing or shorter than *min_chars*
  are skipped; at most six chunks are kept per item. On a request or decoding
  failure the chunks gathered so far are returned.
  """
  out=[]
  try:
    r=requests.get("https://api.crossref.org/works", params={"query":query,"rows":rows}, timeout=TIMEOUT, headers={"User-Agent":"colab-rag-agent/0.1"})
    if not r.ok: return out
    for it in (r.json().get("message") or {}).get("items") or []:
      title=(" ".join(it.get("title") or []))[:120] or "Crossref Work"
      abstract=strip_html(it.get("abstract") or "")
      if not abstract or len(abstract.strip())<min_chars: continue
      # Prefer the record URL, then a DOI link, then a placeholder.
      src=it.get("URL") or (("https://doi.org/"+it["DOI"]) if it.get("DOI") else "crossref:unknown")
      for j,part in enumerate(chunk_text_sentence_aligned(abstract, CHUNK_SIZE)[:6]):
        cid=f"W{sha1(f'{src}::{j}')[:8]}#{j}"
        out.append(Chunk(cid,"web",src,f"Crossref: {title}",j,part))
  except Exception:
    # Best-effort source: keep what was collected, let KeyboardInterrupt propagate.
    return out
  return out

# PubMed
def pubmed_esearch(query: str, retmax: int = PUBMED_MAX_PMIDS) -> List[str]:
  """Return up to *retmax* PubMed ids matching *query* via the E-utilities esearch endpoint.

  Returns an empty list on any network, HTTP or decoding failure.
  """
  base = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
  params = {"db":"pubmed","term":query,"retmax":retmax,"retmode":"json"}
  try:
    r = requests.get(base, params=params, timeout=TIMEOUT, headers={"User-Agent":"colab-rag-agent/0.1"})
    if r.ok:
      data = r.json()
      return (data.get("esearchresult") or {}).get("idlist") or []
  except Exception:
    # Best-effort source: swallow request/JSON errors but not KeyboardInterrupt.
    return []
  return []
def pubmed_efetch_abstracts(pmids: List[str], batch: int = 20) -> List[str]:
  """Fetch text-format abstracts for *pmids* in batches and return their text blocks.

  Each response is split on blank lines and blocks of 60 characters or fewer
  are dropped, so the returned items are paragraphs of the efetch output. A
  failed batch is skipped; the remaining batches are still fetched.
  """
  out=[]
  base = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
  for i in range(0, len(pmids), batch):
    ids = ",".join(pmids[i:i+batch])
    params = {"db":"pubmed","id":ids,"retmode":"text","rettype":"abstract"}
    try:
      r = requests.get(base, params=params, timeout=TIMEOUT, headers={"User-Agent":"colab-rag-agent/0.1"})
    except Exception:
      # Best-effort: skip this batch, but let KeyboardInterrupt propagate.
      continue
    if r.ok and r.text:
      out.extend(a.strip() for a in re.split(r"\n{2,}", r.text) if len(a.strip())>60)
  return out
def ingest_pubmed(query: str, max_pmids: int = PUBMED_MAX_PMIDS, min_chars=80) -> List[Chunk]:
  """Search PubMed for *query* and chunk the fetched abstract text blocks.

  Blocks shorter than *min_chars* are skipped and at most two chunks are kept
  per block. NOTE(review): every chunk carries the same generic PubMed URL as
  its source, so chunks from different articles share (source, chunk_index)
  keys and cannot be cited individually.
  """
  src = "https://pubmed.ncbi.nlm.nih.gov/"
  found_ids = pubmed_esearch(query, retmax=max_pmids)
  blocks = pubmed_efetch_abstracts(found_ids, batch=20)
  collected = []
  for j, block in enumerate(blocks):
    if not block or len(block.strip()) < min_chars:
      continue
    pieces = chunk_text_sentence_aligned(block, CHUNK_SIZE)[:2]
    collected.extend(
      Chunk(f"W{sha1(f'{src}::{j}:{k}')[:8]}#{k}", "web", src, "PubMed abstract", k, piece)
      for k, piece in enumerate(pieces)
    )
  return collected

# Authorities
def ingest_authorities(min_chars=120)->List[Chunk]:
  """Fetch every URL in ``AUTHORITY_URLS`` and chunk its cleaned text.

  Pages yielding fewer than *min_chars* characters are skipped; at most six
  chunks are kept per page. Titles take the form "Authority: <host>".
  """
  collected=[]
  for url in AUTHORITY_URLS:
    body=fetch_clean_text(url)
    if not body or len(body.strip())<min_chars:
      continue
    label=f"Authority: {host_of(url) or 'Authority'}"
    pieces=chunk_text_sentence_aligned(body, CHUNK_SIZE)[:6]
    collected.extend(
      Chunk(f"W{sha1(f'{url}::{j}')[:8]}#{j}","web",url,label,j,piece)
      for j,piece in enumerate(pieces)
    )
  return collected

def ingest_web(query: str, min_chars=80)->List[Chunk]:
  """Gather chunks for *query* from all web sources.

  Sources are queried in the order Wikipedia, OpenAlex, arXiv, Crossref,
  PubMed, authority pages. The result lists authority chunks first, then the
  others in query order. Authority pages always use a 120-character minimum.
  """
  from_wiki=ingest_wikipedia(query, max_pages=WIKI_MAX_PAGES, min_chars=min_chars)
  from_openalex=ingest_openalex(query, max_works=OPENALEX_MAX_WORKS, min_chars=min_chars)
  from_arxiv=ingest_arxiv(query, max_results=ARXIV_MAX_RESULTS, min_chars=min_chars)
  from_crossref=ingest_crossref(query, rows=CROSSREF_ROWS, min_chars=min_chars)
  from_pubmed=ingest_pubmed(query, max_pmids=PUBMED_MAX_PMIDS, min_chars=min_chars)
  from_authorities=ingest_authorities(min_chars=120)
  combined=[]
  for group in (from_authorities, from_wiki, from_openalex, from_arxiv, from_crossref, from_pubmed):
    combined = combined + group
  return combined

# --- Query variants and HyDE ---
def make_variants(question: str, n:int=QUERY_VARIANTS)->List[str]:
  """Ask the LLM for up to *n* keyword-rich paraphrases of *question*, one per line.

  Leading bullet markers and list numbering ("1. ", "2) ") are removed, and
  lines left empty by that cleanup are dropped so no empty query reaches retrieval.
  """
  prompt=f"Paraphrase the question into {n} short, keyword-rich search queries (one per line). Keep focus on the exact task, entities, and required outputs."
  out=chat([{"role":"user","content":prompt+"\nQuestion: "+question}], max_new_tokens=120)
  variants=[]
  for line in out.splitlines():
    cleaned=line.strip("-• ").strip()
    # Strip "1. " / "2) " numbering; whitespace is required after the marker so
    # queries that start with a number such as "3.5 GHz" are left intact.
    cleaned=re.sub(r"^\d+[.)]\s+", "", cleaned).strip()
    if cleaned:
      variants.append(cleaned)
  return variants[:n]

def make_hyde(question: str, n:int=NUM_HYDE)->List[str]:
  """Ask the LLM for *n* short hypothetical answer paragraphs to *question*.

  The reply is split on blank lines; at most *n* non-empty paragraphs are returned.
  """
  prompt=f"Write {n} short factual paragraphs that directly answer the question with precise domain terminology and minimal preamble."
  reply=chat([{"role":"user","content":prompt+"\nQuestion: "+question}], max_new_tokens=280)
  paragraphs=[]
  for block in re.split(r"\n{2,}", reply):
    block=block.strip()
    if block:
      paragraphs.append(block)
  return paragraphs[:n]

# --- Retrieval pipeline (Hybrid + RRF + strong rerank) ---
def retrieve(vindex:"VectorIndex", question:str)->List[Tuple[Chunk,float]]:
  """Hybrid retrieval: dense + BM25 candidates, RRF fusion, cross-encoder rerank.

  The question, its LLM paraphrases and its HyDE paragraphs are all used as
  queries. Fused candidates are rescored against the original question and
  those scoring below ``RERANK_MIN`` are dropped.

  Returns:
    ``(chunk, score)`` pairs sorted by descending reranker score; an empty list
    when the index produced no candidates.
  """
  queries=[question] + make_variants(question, n=QUERY_VARIANTS) + make_hyde(question, n=NUM_HYDE)
  dense=vindex.dense_search(queries, topk=DENSE_TOP)
  bm25=vindex.bm25_search(queries, topk=BM25_TOP)
  fused=rrf_fuse([dense,bm25], kconst=RRF_K, topn=300)
  if not fused:
    # Nothing was retrieved (e.g. empty index): skip reranking rather than
    # calling the reranker with an empty batch.
    return []
  pairs=[(question, vindex.meta[idx].text) for idx in fused]
  scores=reranker.predict(pairs)
  rescored=[(vindex.meta[idx], float(s)) for idx,s in zip(fused, scores)]
  rescored=[(c,s) for (c,s) in rescored if s >= RERANK_MIN]
  rescored.sort(key=lambda t: t[1], reverse=True)
  return rescored

# --- LLM finalization + single-line normalization ---
def select_chunks_for_llm(question:str, reranked: List[Tuple[Chunk,float]], max_items:int=28)->List[Chunk]:
  """Pick the chunks handed to the LLM from the reranked candidates.

  Candidates are ordered authority-host web chunks, then other web chunks,
  then local chunks, each group keeping its reranked order. Duplicates by
  (source, chunk_index) are skipped, at most ``DOMAIN_CAP`` chunks are taken
  per host (all non-http sources share the "local" bucket), and at most
  *max_items* chunks are returned. *question* is currently unused.
  """
  authority_hosts={urlparse(u).netloc.lower() for u in AUTHORITY_URLS}
  from_authority, from_other_web, from_local = [], [], []
  for chunk, _score in reranked:
    if chunk.source_type=="local":
      from_local.append(chunk)
    elif chunk.source_type=="web":
      if host_of(chunk.source) in authority_hosts:
        from_authority.append(chunk)
      else:
        from_other_web.append(chunk)

  picked=[]
  seen_keys=set()
  domain_counts={}
  for chunk in from_authority + from_other_web + from_local:
    if len(picked)>=max_items:
      break
    key=(chunk.source, chunk.chunk_index)
    if key in seen_keys:
      continue
    domain=host_of(chunk.source) if chunk.source.startswith("http") else "local"
    used=domain_counts.get(domain, 0)
    if used>=DOMAIN_CAP:
      continue
    domain_counts[domain]=used+1
    seen_keys.add(key)
    picked.append(chunk)
  return picked

def fold_bullets_to_single_lines(md_text:str, bullet:str=BULLET, max_items:int=MAX_BULLETS)->List[str]:
  """Collapse wrapped bullet text so that every bullet occupies exactly one line.

  A line starting with *bullet* opens a new item; following non-blank lines are
  appended to it with whitespace collapsed. Text before the first bullet is
  ignored. Collection stops once *max_items* items have been gathered.
  """
  folded=[]
  current=None
  for raw in md_text.splitlines():
    piece=raw.strip()
    if not piece:
      continue
    if piece.startswith(bullet):
      if current is not None:
        folded.append(current)
      current=piece
    elif current is not None:
      # Continuation of the bullet being built.
      current=re.sub(r"\s+", " ", current + " " + piece)
  if current is not None:
    folded.append(current)

  result=[]
  for entry in folded:
    entry=entry.strip()
    if not entry.startswith(bullet):
      entry=f"{bullet} {entry.lstrip('-• ')}"
    result.append(re.sub(r"\s+", " ", entry))
    if len(result)>=max_items:
      break
  return result

def llm_generate_markdown_singleline(question:str, chunks:List[Chunk], min_bullets:int=4, max_bullets:int=12)->str:
  """Ask the LLM for sourced bullets answering *question* and return one bullet per line.

  Each chunk is shown to the model as a numbered item with its source and an
  excerpt flattened to one line and cut to 700 characters. If fewer than
  *min_bullets* bullets can be parsed from the reply, every non-empty reply
  line is used as a bullet instead. At most *max_bullets* lines are returned.
  """
  excerpt_limit=700
  numbered=[]
  for number, chunk in enumerate(chunks, start=1):
    excerpt=chunk.text.strip().replace("\n"," ")
    if len(excerpt)>excerpt_limit:
      excerpt=excerpt[:excerpt_limit]+"…"
    numbered.append(f"[{number}] Source: {chunk.source}\nExcerpt: {excerpt}")
  context="\n\n".join(numbered)
  sys_msg="You are a precise research assistant; output only the requested Markdown."
  user_msg = f"""
Task: From the context items, write concise Markdown bullets that directly answer the question.
Question: {question}

Context (each item has an id and a Source + Excerpt):
{context}

Output requirements:
- Output ONLY Markdown bullets (no prose, no headings), each starting with "{BULLET} ".
- Each bullet MUST be a single line (no internal line breaks).
- Write {min_bullets} to {max_bullets} bullets that are directly supported by the excerpts.
- After each bullet’s statement, append " — Source: <exact Source>" using the Source string from the corresponding item.
- Prefer authoritative or highly relevant sources; avoid redundancy; do not invent facts.
"""
  conversation=[
    {"role":"system","content":sys_msg},
    {"role":"user","content":user_msg},
  ]
  raw_md=chat(conversation, max_new_tokens=900)
  bullets=fold_bullets_to_single_lines(raw_md, bullet=BULLET, max_items=max_bullets)
  if len(bullets)<min_bullets:
    # Too few parseable bullets: fall back to treating each reply line as one bullet.
    fallback=[]
    for line in raw_md.splitlines():
      line=line.strip()
      if not line:
        continue
      fallback.append(line if line.startswith(BULLET) else f"{BULLET} {line}")
    bullets=[re.sub(r"\s+"," ", item) for item in fallback[:max_bullets]]
  return "\n".join(bullets[:max_bullets])

def render_as_markdown_list(single_line_bullets: List[str]) -> str:
  """Turn "• text" lines into a Markdown "- text" list.

  One leading bullet marker (•, * or -) and its surrounding whitespace are
  removed from each line. The result starts with a newline so the list renders
  as a list.
  """
  leading_marker = re.compile(r'^\s*[•*\-]\s*')
  rendered = ["- " + leading_marker.sub('', line, count=1) for line in single_line_bullets]
  return "\n" + "\n".join(rendered)

# --- Core runner (no prints) ---
def run_agent_markdown_singleline(question: str, local_dir=LOCAL_DIR):
  """Run the full pipeline for *question* and return a Markdown list of sourced bullets.

  Steps: ingest local files and web sources, build a fresh index, retrieve and
  rerank, select chunks, generate bullets with the LLM, then fold and render
  them as one Markdown list item per line.
  """
  index=VectorIndex(model_name=EMBED_ID, device=device) if False else None
  from_disk=ingest_local(local_dir)
  from_web=ingest_web(question, min_chars=80)
  index=VectorIndex(model_name=EMBED_ID, device=device)
  for batch in (from_disk, from_web):
    index.add(batch)
  ranked=retrieve(index, question)
  chosen=select_chunks_for_llm(question, ranked, max_items=28)
  draft=llm_generate_markdown_singleline(question, chosen, min_bullets=4, max_bullets=MAX_BULLETS)
  # Fold again so each bullet is a single line before rendering for Gradio.
  bullets=fold_bullets_to_single_lines(draft, bullet=BULLET, max_items=MAX_BULLETS)
  return render_as_markdown_list(bullets)

# --- Gradio UI ---
def prepare_local_dir(uploaded_files):
  """Reset ``LOCAL_DIR`` and copy the uploaded files into it.

  Args:
    uploaded_files: a single upload, a list/tuple of uploads, or a falsy value.
      Each upload is either an object with a ``name`` attribute holding its
      path or something whose ``str()`` is the path.

  Everything previously in ``LOCAL_DIR`` is removed first, including
  subdirectories, so stale documents never reach ingestion. Files that cannot
  be removed or copied are skipped.
  """
  os.makedirs(LOCAL_DIR, exist_ok=True)
  # Clear previous content to avoid stale context (ingestion walks subfolders too).
  for fn in os.listdir(LOCAL_DIR):
    stale = os.path.join(LOCAL_DIR, fn)
    try:
      if os.path.isdir(stale) and not os.path.islink(stale):
        shutil.rmtree(stale)
      else:
        os.remove(stale)
    except OSError:
      pass  # best-effort cleanup
  if not uploaded_files: return
  # Normalise to a list so single and multiple uploads share one code path.
  files = uploaded_files if isinstance(uploaded_files, (list, tuple)) else [uploaded_files]
  for f in files:
    if not f: continue
    src = f.name if hasattr(f, "name") else str(f)
    dst = os.path.join(LOCAL_DIR, os.path.basename(src))
    try:
      shutil.copyfile(src, dst)
    except OSError:
      # Covers missing/unreadable sources and shutil.SameFileError. The old
      # manual-copy fallback is dropped: it could not succeed where copyfile
      # failed, and it truncated the file when src and dst were the same.
      continue

def answer_fn(question, uploads):
  """Gradio callback: stage the uploads, run the agent, return the Markdown answer.

  A blank question returns a short prompt instead of running the pipeline with
  an empty query.
  """
  question = (question or "").strip()
  if not question:
    return "Please enter a question."
  prepare_local_dir(uploads)
  # Named `answer` rather than `md` so the module-level `markdown` alias is not shadowed.
  answer = run_agent_markdown_singleline(question, local_dir=LOCAL_DIR)
  return answer  # already a proper Markdown list with one item per line

# Build the UI: a question box, an optional multi-file upload, a button and a Markdown output area.
with gr.Blocks(title="Multi-Document Research Agent") as demo:
  gr.Markdown("### Multi-Document Research Agent\nEnter a complex research question, optionally upload local files (PDF/MD/TXT), and get single-line Markdown bullets with sources.", elem_id="title")
  with gr.Row():
    q = gr.Textbox(label="Question", lines=3, placeholder="Ask a complex research question...")
  with gr.Row():
    uploads = gr.File(label="Upload local files (optional)", file_count="multiple")
  run_btn = gr.Button("Answer")
  out_md = gr.Markdown(label="Answers")
  # Clicking the button sends (question, uploads) to answer_fn and shows its return value in out_md.
  run_btn.click(fn=answer_fn, inputs=[q, uploads], outputs=out_md)

# Start the app without requesting a public share link.
demo.launch(share=False)


  missing = [p for p in reqs if pkgutil.find_loader(p.replace("-","_")) is None]
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/795 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.27G [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/964 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.
* To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>

