In [1]:
# =============================================================================
# FINAL RAG Wikipedia Answerer (Best Stable Version)
# Fully Spec Compliant | Strong Grounding | Deterministic
# =============================================================================

import re
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
from sklearn.metrics.pairwise import cosine_similarity

# ---------------- CONFIG ----------------

CHUNK_TOKENS = 400
OVERLAP = 80
TOP_K = 3
TOP_SENTENCES = 5
THRESHOLD = 0.60
MAX_CHARS = 50_000_000
BATCH_SIZE = 32
MAX_PROMPT_TOKENS = 460

DATA_PATH = "/kaggle/input/datasets/vaibhavchourasia2611/wikipedia-english/AllCombined.txt"
EMBED_MODEL = "intfloat/e5-large"
LLM_MODEL = "google/flan-t5-base"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# ---------------- LOAD DATA ----------------

print("Loading dataset...")
with open(DATA_PATH, "r", encoding="utf-8") as f:
    raw_text = f.read(MAX_CHARS)
print(f"Loaded {len(raw_text):,} chars")

# ---------------- ARTICLE SPLIT ----------------

articles = []
current_title = None
current_text = []

for line in raw_text.split("\n"):
    line = line.strip()
    if not line:
        continue
    if len(line) < 80 and not line.endswith(".") and not line.endswith(","):
        if current_title and current_text:
            articles.append((current_title, " ".join(current_text)))
        current_title = line
        current_text = []
    else:
        current_text.append(line)

if current_title and current_text:
    articles.append((current_title, " ".join(current_text)))

print(f"Articles: {len(articles)}")

# ---------------- TOKENIZERS ----------------

embed_tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL)
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)

# ---------------- CHUNKING ----------------

chunks = []
chunk_titles = []
stride = CHUNK_TOKENS - OVERLAP

for title, text in articles:
    paras = [p.strip() for p in re.split(r"\n+", text) if len(p.strip()) > 20]
    ids = []
    for p in paras:
        ids.extend(embed_tokenizer.encode(
            p, add_special_tokens=False, max_length=512, truncation=True))
    for i in range(0, len(ids), stride):
        piece = ids[i:i+CHUNK_TOKENS]
        if len(piece) >= 80:
            chunks.append(embed_tokenizer.decode(piece, skip_special_tokens=True))
            chunk_titles.append(title.lower())

print(f"Chunks: {len(chunks)}")

# ---------------- EMBEDDING ----------------

embed_model = AutoModel.from_pretrained(EMBED_MODEL).to(DEVICE)
embed_model.eval()

def encode(texts, is_query=False):
    prefix = "query: " if is_query else "passage: "
    prefixed = [prefix + t for t in texts]
    enc = embed_tokenizer(prefixed, return_tensors="pt",
                          truncation=True, max_length=192, padding=True)
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    with torch.no_grad():
        hidden = embed_model(**enc).last_hidden_state
    mask = enc["attention_mask"].unsqueeze(-1).float()
    pooled = (hidden * mask).sum(1) / mask.sum(1)
    vecs = pooled.cpu().numpy().astype("float32")
    return vecs / (np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-9)

print("Building index...")
emb_list = []
for i in range(0, len(chunks), BATCH_SIZE):
    emb_list.append(encode(chunks[i:i+BATCH_SIZE], False))
chunk_embeddings = np.vstack(emb_list)
print("Index shape:", chunk_embeddings.shape)

# ---------------- LLM ----------------

llm = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL).to(DEVICE)
llm.eval()

# ---------------- RETRIEVE ----------------

def retrieve(question):
    q_vec = encode([question], True)
    scores = cosine_similarity(q_vec, chunk_embeddings)[0]

    q_lower = question.lower()
    boosted = scores.copy()

    for i, title in enumerate(chunk_titles):
        if title and title in q_lower:
            boosted[i] += 0.25

    idx = np.argsort(boosted)[::-1][:TOP_K]
    return [chunks[i] for i in idx], scores[idx], round(float(max(scores)),2)

# ---------------- REFUSAL ----------------

def should_refuse(question, raw_scores, top_chunks):
    if max(raw_scores) < THRESHOLD:
        return True

    q_clean = re.sub(r"^(who|what|where|when|how|why|is|was|are|were)\s+",
                     "", question, flags=re.IGNORECASE).strip()
    proper = re.findall(r"\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)\b", q_clean)

    if proper:
        combined = " ".join(top_chunks)
        for name in proper:
            if re.search(re.escape(name), combined, re.IGNORECASE):
                return False
            surname = name.split()[-1]
            if re.search(r"\b" + re.escape(surname) + r"\b", combined, re.IGNORECASE):
                return False
        return True

    return False

# ---------------- SENTENCE FILTERING ----------------

FINITE_V = re.compile(
    r"\b(is|are|was|were|has|have|had|invented|discovered|born|died|created|"
    r"made|wrote|built|refers|known|called|converts|uses|became|worked|won|directed)\b",
    re.IGNORECASE)

def is_good(s):
    if len(s.split()) < 6: return False
    if not s[0].isupper(): return False
    if not FINITE_V.search(s): return False
    if re.search(r"formula_\d", s): return False
    return True

def select_sentences(question, top_chunks):
    sentences = []
    for chunk in top_chunks:
        for s in re.split(r"(?<=[.!?])\s+", chunk):
            s = s.strip()
            if s and is_good(s):
                sentences.append(s)

    if not sentences:
        return " ".join(top_chunks)

    q_vec = encode([question], True)
    s_vecs = encode(sentences, False)
    scores = cosine_similarity(q_vec, s_vecs)[0]

    idx = np.argsort(scores)[::-1][:TOP_SENTENCES]
    return " ".join([sentences[i] for i in idx])

# ---------------- GENERATE ----------------

def generate(question, context):
    prefix = (
        "Answer the question in 2 to 3 complete sentences.\n"
        "Use only the facts written in the context.\n"
        "Do not add any new information.\n"
        "Rewrite in simple English.\n\n"
        "Context:\n"
    )
    suffix = f"\nQuestion: {question}\nAnswer:"
    overhead = len(llm_tokenizer.encode(prefix + suffix))
    budget = MAX_PROMPT_TOKENS - overhead

    ctx_ids = llm_tokenizer.convert_tokens_to_ids(llm_tokenizer.tokenize(context))
    if len(ctx_ids) > budget:
        context = llm_tokenizer.decode(ctx_ids[:budget], skip_special_tokens=True)

    prompt = prefix + context + suffix

    inputs = llm_tokenizer(prompt, return_tensors="pt",
                           truncation=True,
                           max_length=MAX_PROMPT_TOKENS).to(DEVICE)

    with torch.no_grad():
        out = llm.generate(
            **inputs,
            min_new_tokens=80,
            max_new_tokens=120,
            do_sample=False,
            num_beams=4,
            no_repeat_ngram_size=3,
            length_penalty=1.5,
            early_stopping=True
        )

    return llm_tokenizer.decode(out[0], skip_special_tokens=True)

# ---------------- POST PROCESS ----------------

def post_process(text):
    text = re.sub(r"\([^)]*\)", "", text)
    text = re.sub(r"\[[^\]]*\]", "", text)
    text = re.sub(r"formula_\d+", "", text)
    text = re.sub(r"\s-\s", "-", text)
    text = re.sub(r"\s+,", ",", text)
    text = re.sub(r"\s+\.", ".", text)
    text = re.sub(r"\s+", " ", text).strip()

    sents = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
    result = " ".join(sents[:3])

    if result and result[-1] not in ".!?":
        result += "."
    return result

# ---------------- FULL PIPELINE ----------------

def answer_question(question):
    top_chunks, raw_scores, confidence = retrieve(question)

    print("\nQuestion :", question)
    print("Confidence:", confidence)

    if should_refuse(question, raw_scores, top_chunks):
        print("Answer    : Not enough information in the Simple Wikipedia dataset.")
        return

    context = select_sentences(question, top_chunks)
    raw = generate(question, context)
    answer = post_process(raw)

    if not answer:
        answer = "Not enough information in the Simple Wikipedia dataset."

    print("Answer    :", answer)

# ---------------- TEST ----------------

print("="*60)
print(f"embed={EMBED_MODEL} | llm={LLM_MODEL}")
print("="*60)

for q in [
    "What is photosynthesis?",
    "Who was Albert Einstein?",
    "How does gravity work?",
    "Who invented the telephone?",
]:
    answer_question(q)


Device: cuda
Loading dataset...
Loaded 50,000,000 chars
Articles: 47817
Chunks: 32409


2026-02-17 06:47:15.911668: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1771310835.934225     191 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771310835.940755     191 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1771310835.957234     191 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771310835.957251     191 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771310835.957254     191 computation_placer.cc:177] computation placer alr

Building index...
Index shape: (32409, 1024)


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

embed=intfloat/e5-large | llm=google/flan-t5-base


Token indices sequence length is longer than the specified maximum sequence length for this model (770 > 512). Running this sequence through the model will result in indexing errors



Question : What is photosynthesis?
Confidence: 0.9
Answer    : Photosynthesis is an endothermic chemical process which uses sunlight to turn carbon dioxide into sugars. The sugars are used by the cell as energy, and to build other kinds of molecules. fundamentally, photosynthesis converts light energy into chemical energy.

Question : Who was Albert Einstein?
Confidence: 0.88
Answer    : he developed the theory of relativity. he won the nobel prize in physics in 1921 for theoretical physics. his most famous equation is formula _ 1 in which e is for energy, m for mass, c is the speed of light is therefore " energy " equals " mass " multiplied by " the speed _ of light " squared.

Question : How does gravity work?
Confidence: 0.89
Answer    : gravity, or gravitation, is one of the fundamental forces of the universe. it is an attraction, or pull, between any two objects with mass. we discuss it in three parts : some physicists think gravity is caused by gravitons, but they are still unsu