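# Advanced RAG Notebook

End-to-end retrieval-augmented QA pipeline: LLM query rewriting → FAISS dense retrieval → cross-encoder reranking (`BAAI/bge-reranker-base`) → LLM context compression → grounded answer generation → LLM-as-judge scoring. It also includes a small user-feedback memory and a chunk-size A/B comparison. It runs either on a SQuAD v1.1 sample (with ground-truth answers) or on your own files in `./data/uploaded`.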

In [1]:
import os
import time
import json
from typing import List, Dict, Any, Tuple, Optional

import numpy as np
import pandas as pd
import torch

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader, PyPDFLoader

from dotenv import load_dotenv

# Pull variables from a local .env file (if present) into os.environ before reading them.
load_dotenv()

# Fail fast: every LLM and embedding call below needs this key.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") or ""
if OPENAI_API_KEY == "":
    raise ValueError("Set OPENAI_API_KEY in environment or .env file.")

# --- Global config -----------------------------------------------------------
DATA_MODE = "squad"                         # corpus: "squad" or "uploaded"
EMBEDDING_MODEL = "text-embedding-3-small"  # OpenAI embedding model for the FAISS index
LLM_MODEL = "gpt-4o-mini"                   # default chat model for get_llm(); change if you like

# --- Retrieval pipeline knobs ------------------------------------------------
STAGE1_K = 30                 # candidates pulled from FAISS before reranking
TOP_K_RERANKED = 5            # documents kept after cross-encoder reranking
DEFAULT_CHUNK_SIZE = 400      # chunk length in characters (the splitter measures with len)
DEFAULT_CHUNK_OVERLAP = 80    # characters shared by consecutive chunks

# Run the reranker on GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE


  from .autonotebook import tqdm as notebook_tqdm


'cpu'

In [2]:
# Stage-2 reranker: a cross-encoder that scores each (query, passage) pair jointly.
# If BAAI/bge-reranker-base is too heavy for your machine, swap in another
# sequence-classification reranker, or fall back to LLM-based reranking.

RERANKER_MODEL_NAME = "BAAI/bge-reranker-base"

tokenizer = AutoTokenizer.from_pretrained(RERANKER_MODEL_NAME)
reranker_model = AutoModelForSequenceClassification.from_pretrained(RERANKER_MODEL_NAME)
reranker_model = reranker_model.to(DEVICE)

def cross_encoder_rerank(
    query: str,
    docs: List[Document],
    top_k: int = TOP_K_RERANKED,
    batch_size: int = 32,
) -> List[Document]:
    """Rerank ``docs`` against ``query`` with the cross-encoder and keep the best ``top_k``.

    Args:
        query: The (possibly rewritten) user query.
        docs: Candidate documents, typically the stage-1 FAISS hits.
        top_k: Number of highest-scoring documents to return. Values <= 0
            return an empty list.
        batch_size: Number of (query, passage) pairs scored per forward pass.
            This bounds peak memory when many candidates are reranked.

    Returns:
        Up to ``top_k`` documents, ordered by descending reranker score.
    """
    if not docs or top_k <= 0:
        return []

    batch_size = max(1, batch_size)
    passages = [d.page_content for d in docs]
    score_batches = []

    with torch.no_grad():
        for start in range(0, len(passages), batch_size):
            batch = passages[start:start + batch_size]
            inputs = tokenizer(
                [query] * len(batch),
                batch,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            ).to(DEVICE)
            logits = reranker_model(**inputs).logits
            # One relevance logit per pair. Cast to float32 because .numpy()
            # does not support bfloat16 tensors.
            score_batches.append(logits.squeeze(-1).float().cpu().numpy())

    scores = np.concatenate(score_batches)
    # A stable sort keeps the incoming (stage-1) order among tied scores.
    ranked_idx = np.argsort(-scores, kind="stable")
    return [docs[i] for i in ranked_idx[:top_k]]


In [3]:
def get_text_splitter(chunk_size: int = DEFAULT_CHUNK_SIZE,
                      chunk_overlap: int = DEFAULT_CHUNK_OVERLAP) -> RecursiveCharacterTextSplitter:
    """Build the recursive character splitter used for all chunking.

    Chunk length is measured with ``len`` (characters, not tokens). Separators
    are tried in order: paragraph break, line break, sentence end, space, and
    finally the empty string (split anywhere).
    """
    separator_priority = ["\n\n", "\n", ". ", " ", ""]
    return RecursiveCharacterTextSplitter(
        separators=separator_priority,
        length_function=len,
        chunk_overlap=chunk_overlap,
        chunk_size=chunk_size,
    )


def get_embedding_model() -> OpenAIEmbeddings:
    """Create an OpenAI embedding client for ``EMBEDDING_MODEL`` using the configured API key."""
    embedding_kwargs = {"model": EMBEDDING_MODEL, "api_key": OPENAI_API_KEY}
    return OpenAIEmbeddings(**embedding_kwargs)


# Memoised chat clients keyed by (model_name, temperature, api_key).
_LLM_CACHE: Dict[Tuple[str, float, str], ChatOpenAI] = {}


def get_llm(model_name: str = LLM_MODEL, temperature: float = 0.0) -> ChatOpenAI:
    """Return a chat client for ``model_name`` at ``temperature``.

    Clients are created once per (model_name, temperature, API key) and then
    reused. Each pipeline query calls this four times (rewrite, compress,
    answer, judge), so reuse avoids rebuilding the client object (and
    whatever connection state it holds) on every call.
    """
    key = (model_name, float(temperature), OPENAI_API_KEY)
    llm = _LLM_CACHE.get(key)
    if llm is None:
        llm = ChatOpenAI(
            model=model_name,
            temperature=temperature,
            api_key=OPENAI_API_KEY,
        )
        _LLM_CACHE[key] = llm
    return llm


In [4]:
def load_squad_subset(
    max_examples: int = 2000,
    split: str = "train[:10%]",
    seed: int = 42,
) -> Tuple[List[Document], pd.DataFrame]:
    """
    Load a shuffled sample of SQuAD v1.1 as retrieval Documents plus QA pairs.

    Args:
        max_examples: Maximum number of QA examples kept after shuffling.
        split: Hugging Face split spec to sample from (default: first 10% of train).
        seed: Shuffle seed, so the sample is reproducible.

    Returns:
        (docs, qa_df):
        - docs: one Document per *unique* context paragraph, in first-seen
          order, with ``metadata["source"] = "squad_paragraph_<i>"``.
        - qa_df: one row per example with columns id, context, question,
          answer (first gold answer, or "" if none) and source. ``source`` is
          the id of the paragraph the question was asked about, so retrieved
          chunks can be checked against the gold passage.
    """
    ds = load_dataset("squad", split=split)
    ds = ds.shuffle(seed=seed).select(range(min(max_examples, len(ds))))

    # Many questions share a paragraph. Index each distinct context once, in
    # first-seen order, and remember the source id it was given.
    source_by_context: Dict[str, str] = {}
    qa_rows = []

    for ex in ds:
        context = ex["context"]
        if context not in source_by_context:
            source_by_context[context] = f"squad_paragraph_{len(source_by_context)}"
        ans_texts = ex["answers"]["text"]
        qa_rows.append({
            "id": ex["id"],
            "context": context,
            "question": ex["question"],
            "answer": ans_texts[0] if ans_texts else "",
            "source": source_by_context[context],
        })

    docs = [
        Document(page_content=context, metadata={"source": source})
        for context, source in source_by_context.items()
    ]

    qa_df = pd.DataFrame(qa_rows)
    return docs, qa_df


In [5]:
from pathlib import Path

def load_uploaded_docs(data_dir: str = "./data/uploaded") -> List[Document]:
    """
    Load every supported file under ``data_dir`` (recursively) as Documents.

    Supported types: .txt and .md via TextLoader (UTF-8), and .pdf via
    PyPDFLoader (typically one Document per page). Other files are skipped;
    add more loaders as needed. Files are visited in sorted path order, so the
    chunk order (and therefore the FAISS index) is reproducible across runs
    and machines.

    Raises:
        ValueError: if ``data_dir`` does not exist, or contains no supported files.
    """
    base = Path(data_dir)
    if not base.exists():
        raise ValueError(f"{data_dir} does not exist. Create it and drop files there.")

    docs: List[Document] = []

    # rglob order depends on the filesystem; sort for determinism.
    for path in sorted(base.rglob("*")):
        # is_file() also skips broken symlinks and special files, not just dirs.
        if not path.is_file():
            continue

        ext = path.suffix.lower()
        if ext in (".txt", ".md"):
            loader = TextLoader(str(path), encoding="utf-8")
        elif ext == ".pdf":
            loader = PyPDFLoader(str(path))
        else:
            # ignore unsupported types for now
            continue

        docs.extend(loader.load())

    if not docs:
        raise ValueError(f"No supported files found in {data_dir}. Add txt/md/pdf files.")

    # We don't have ground truth answers for uploaded docs; evaluation will be question-only.
    return docs


In [6]:
def build_vectorstore(
    docs: List[Document],
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP
) -> Tuple[FAISS, List[Document]]:
    """Split ``docs`` into chunks and index them in an in-memory FAISS store.

    Args:
        docs: Documents to index.
        chunk_size: Chunk length in characters.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        (vectordb, chunks): the FAISS index and the chunk Documents it holds.

    Raises:
        ValueError: if splitting yields no chunks (e.g. every input has empty
            text, as with PDFs that have no extractable text). This is raised
            instead of letting the failure surface from inside the FAISS build.
    """
    splitter = get_text_splitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = splitter.split_documents(docs)
    if not chunks:
        raise ValueError(
            f"No text chunks produced from {len(docs)} document(s); "
            "check that the inputs contain extractable text."
        )

    embeddings = get_embedding_model()
    vectordb = FAISS.from_documents(chunks, embedding=embeddings)
    return vectordb, chunks


In [7]:
def rewrite_query(query: str) -> str:
    """Ask the LLM to rewrite ``query`` for retrieval (expand abbreviations, add synonyms).

    Returns the stripped rewrite, or the original ``query`` when the model
    returns only whitespace, so retrieval never runs against an empty string.
    """
    llm = get_llm()
    system_msg = (
        "You are a helpful assistant that rewrites user queries for better retrieval. "
        "Expand abbreviations and add key synonyms, but keep the core information need unchanged. "
        "Return ONLY the rewritten query text, no explanations."
    )
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": query}
    ]
    resp = llm.invoke(messages)
    rewritten = resp.content.strip()
    # Fall back to the user's own wording rather than searching for "".
    return rewritten if rewritten else query


In [8]:
def stage1_retrieve(
    vectordb: FAISS,
    query: str,
    k: int = STAGE1_K
) -> List[Document]:
    """Stage 1: dense retrieval of the ``k`` chunks closest to ``query`` in ``vectordb``."""
    return vectordb.similarity_search(query, k=k)


In [9]:
def rerank(
    query: str,
    candidates: List[Document],
    top_k: int = TOP_K_RERANKED
) -> List[Document]:
    """Stage 2: reorder ``candidates`` with the cross-encoder and keep the best ``top_k``.

    An empty (or None) candidate list short-circuits to ``[]`` without touching the model.
    """
    return cross_encoder_rerank(query, candidates, top_k=top_k) if candidates else []


In [10]:
def compress_context(
    query: str,
    docs: List[Document],
    max_tokens_hint: int = 800
) -> str:
    """Abstractively compress ``docs`` into a brief focused on ``query``.

    Each document is labelled ``[DOC i | source=...]``. The LLM is asked to
    keep only query-relevant facts and numbers. ``max_tokens_hint`` is a
    length target stated in the prompt only; no API-level limit is set here.

    Returns:
        The stripped model output, or ``""`` when ``docs`` is empty (no LLM call).
    """
    if not docs:
        return ""

    sections = []
    for idx, doc in enumerate(docs):
        src = doc.metadata.get('source', 'unknown')
        sections.append(f"[DOC {idx} | source={src}]\n{doc.page_content}")
    joined = "\n\n".join(sections)

    system_msg = (
        "You are a careful information compressor. Given a user query and multiple text chunks, "
        "you produce a concise but information-rich summary that retains ONLY the details relevant "
        "to answering the query. Keep factual details and numbers. Avoid general commentary."
    )

    prompt = f"""User query:
{query}

Raw retrieved text:
{joined}

Compress the above into a single focused brief (<= {max_tokens_hint} tokens)
containing only information directly relevant to answering the query.
"""

    response = get_llm().invoke([
        {"role": "system", "content": system_msg},
        {"role": "user", "content": prompt},
    ])
    return response.content.strip()


In [11]:
def generate_answer(
    query: str,
    compressed_context: str
) -> str:
    """Answer ``query`` using only ``compressed_context``.

    The system prompt instructs the model to stay grounded in the context, to
    say it doesn't know when the answer is absent, and to quote short
    supporting phrases. Returns the stripped model output.
    """
    rules = [
        "Use ONLY the context to answer.",
        "If the answer is not in the context, say you don't know.",
        "When possible, quote short phrases from the context and mention which part you used.",
        "Do not invent facts.",
    ]
    system_msg = (
        "You are a question answering system that MUST stay grounded in the provided context. "
        "Follow these rules:\n"
        + "".join(f"- {rule}\n" for rule in rules)
    )

    prompt = f"""Context to use (may be compressed):
\"\"\"{compressed_context}\"\"\"

User question:
{query}

Instructions:
- Base your answer ONLY on the context above.
- If the answer is unknown or not clearly supported, explicitly say so.
- Keep the answer concise but clear.
"""

    response = get_llm().invoke([
        {"role": "system", "content": system_msg},
        {"role": "user", "content": prompt},
    ])
    return response.content.strip()


In [12]:
def evaluate_answer(
    query: str,
    answer: str,
    context: str,
    ground_truth: Optional[str] = None
) -> Dict[str, Any]:
    """
    Score ``answer`` with an LLM judge.

    Scores are integers clamped to 1-5:
    - correctness: factual correctness vs. ground truth / context
    - groundedness: support by ``context``
    - relevance: focus on ``query``

    Returns:
        Dict with keys correctness, groundedness, relevance, explanation and
        parse_ok. If the judge's reply is not a parseable JSON object, or a
        score is missing or non-numeric, that score falls back to a neutral 3
        and ``parse_ok`` is False, so such rows can be excluded from
        aggregates. When there is no "explanation" key, the explanation falls
        back to the first 300 characters of the raw reply.
    """
    llm = get_llm()

    system_msg = (
        "You are an evaluation assistant. "
        "You will be given a question, a context, a model answer, and optionally a ground truth answer. "
        "You must score the model's answer on:\n"
        "- correctness: is it factually correct vs ground truth/context? (1-5)\n"
        "- groundedness: is it clearly supported by the context? (1-5)\n"
        "- relevance: is it focused on the question? (1-5)\n"
        "Return a JSON object with keys: correctness, groundedness, relevance, explanation."
    )

    gt_text = ground_truth if ground_truth is not None else "(no ground truth provided)"

    user_prompt = f"""Question: {query}

Ground truth answer (may be empty): {gt_text}

Context:
\"\"\"{context}\"\"\"

Model answer:
\"\"\"{answer}\"\"\"

Now output a JSON object like:
{{
  "correctness": 1-5 integer,
  "groundedness": 1-5 integer,
  "relevance": 1-5 integer,
  "explanation": "short explanation here"
}}
"""

    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_prompt}
    ]

    resp = llm.invoke(messages)
    raw = resp.content.strip()

    def parse_json_object(text: str) -> Optional[Dict[str, Any]]:
        # Try the whole reply first, then the outermost {...} span (handles
        # ```json fences, leading chatter, or a dict wrapped in a list).
        candidates = [text]
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            candidates.append(text[start:end + 1])
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    def coerce_score(value: Any) -> Optional[int]:
        # None / bools / non-numeric strings are invalid. inf raises OverflowError.
        if value is None or isinstance(value, bool):
            return None
        try:
            score = int(value)
        except (TypeError, ValueError, OverflowError):
            return None
        return min(5, max(1, score))

    data = parse_json_object(raw)
    parse_ok = data is not None
    fields = data if data is not None else {}

    scores: Dict[str, int] = {}
    for key in ("correctness", "groundedness", "relevance"):
        score = coerce_score(fields.get(key))
        if score is None:
            parse_ok = False
            score = 3  # neutral fallback so downstream tables keep an int
        scores[key] = score

    return {
        "correctness": scores["correctness"],
        "groundedness": scores["groundedness"],
        "relevance": scores["relevance"],
        "explanation": fields.get("explanation", raw[:300]),
        "parse_ok": parse_ok,
    }


In [13]:
class FeedbackMemory:
    """Small secondary FAISS index holding user-supplied corrections.

    Each correction is stored as a Document with ``source="user_feedback"``
    and the originating query in metadata. ``retrieve`` returns the stored
    corrections most similar to a query, so they can be merged with corpus
    hits before reranking.
    """

    def __init__(self):
        self.docs: List[Document] = []
        self.vectordb: Optional[FAISS] = None

    def rebuild_index(self):
        """Re-embed every stored correction from scratch (clears the index when there are none)."""
        if not self.docs:
            self.vectordb = None
            return
        embeddings = get_embedding_model()
        self.vectordb = FAISS.from_documents(self.docs, embeddings)

    def add_feedback(self, query: str, corrected_fact: str):
        """Store ``corrected_fact`` (tagged with ``query``) and make it retrievable.

        When an index already exists, only the new document is embedded and
        appended. Rebuilding on every add would re-embed all N stored
        corrections each time (O(N^2) embedding calls over N adds).
        """
        doc = Document(
            page_content=corrected_fact,
            metadata={"source": "user_feedback", "query": query}
        )
        self.docs.append(doc)
        if self.vectordb is None:
            # First correction (or index was cleared): build from all stored docs.
            self.rebuild_index()
        else:
            self.vectordb.add_documents([doc])

    def retrieve(self, query: str, k: int = 3) -> List[Document]:
        """Return up to ``k`` stored corrections most similar to ``query`` ([] when none are stored)."""
        if self.vectordb is None:
            return []
        return self.vectordb.similarity_search(query, k=k)


memory_store = FeedbackMemory()


In [14]:
def retrieve_with_memory(
    vectordb: FAISS,
    query: str,
    stage1_k: int = STAGE1_K,
    top_k_reranked: int = TOP_K_RERANKED,
    memory_k: int = 5,
) -> List[Document]:
    """Retrieve from the corpus index and the feedback memory, then rerank them together.

    Args:
        vectordb: Corpus FAISS index.
        query: Query used for both retrieval and reranking.
        stage1_k: Candidates fetched from the corpus index.
        top_k_reranked: Documents kept after cross-encoder reranking.
        memory_k: Candidates fetched from the global ``memory_store``
            (values <= 0 skip the memory lookup).

    Returns:
        Up to ``top_k_reranked`` documents. Feedback documents appear only if
        the reranker scores them among the best candidates.
    """
    corpus_hits = stage1_retrieve(vectordb, query, k=stage1_k)
    memory_hits = memory_store.retrieve(query, k=memory_k) if memory_k > 0 else []
    return rerank(query, corpus_hits + memory_hits, top_k=top_k_reranked)


In [16]:
def _print_pipeline_report(
    query: str,
    rewritten: str,
    top_docs: List[Document],
    compressed: str,
    ground_truth: Optional[str],
    answer: str,
    eval_result: Dict[str, Any],
    timings: Dict[str, int],
) -> None:
    """Pretty-print one pipeline run: queries, top sources, context preview, answer, scores, latency."""
    print("=" * 80)
    print("Original query:", query)
    print("Rewritten query:", rewritten)
    print("-" * 80)
    print("Top documents (sources):")
    for i, d in enumerate(top_docs):
        print(f"[{i}] source={d.metadata.get('source','unknown')}")
    print("-" * 80)
    print("Compressed context (preview):")
    print(compressed[:1000] + ("..." if len(compressed) > 1000 else ""))
    print("-" * 80)
    if ground_truth is not None:
        print("Ground truth answer:", ground_truth)
        print("-" * 80)
    print("Model answer:")
    print(answer)
    print("-" * 80)
    print("Evaluation scores:")
    print(f"  correctness:  {eval_result['correctness']}")
    print(f"  groundedness: {eval_result['groundedness']}")
    print(f"  relevance:    {eval_result['relevance']}")
    print(f"  explanation:  {eval_result['explanation']}")
    print("-" * 80)
    print("Latency (ms):")
    for k, v in timings.items():
        print(f"  {k}: {v}")
    print("=" * 80)


def run_single_query_pipeline(
    vectordb: FAISS,
    query: str,
    ground_truth: Optional[str] = None,
    use_memory: bool = True,
    stage1_k: int = STAGE1_K,
    top_k_reranked: int = TOP_K_RERANKED,
    verbose: bool = True,
) -> Dict[str, Any]:
    """Run rewrite -> retrieve (+memory) -> rerank -> compress -> answer -> judge for one query.

    Args:
        vectordb: Corpus FAISS index.
        query: The user's original question. The rewritten query drives
            retrieval, reranking and compression. The original query is what
            the answer generator and the judge see.
        ground_truth: Optional reference answer passed to the judge.
        use_memory: Also search the global feedback ``memory_store``.
        stage1_k: Candidates fetched from FAISS before reranking.
        top_k_reranked: Documents kept after reranking.
        verbose: Print a human-readable report of the run (default True).

    Returns:
        Dict with query, rewritten_query, answer, ground_truth,
        compressed_context, top_docs, eval (judge result) and timings
        (per-stage and total latency in integer milliseconds).
    """
    timings: Dict[str, int] = {}

    def elapsed_ms(start: float) -> int:
        # perf_counter is monotonic; time.time() can jump with clock changes.
        return int((time.perf_counter() - start) * 1000)

    t0 = time.perf_counter()

    # Rewrite
    t_start = time.perf_counter()
    rewritten = rewrite_query(query)
    timings["rewrite_ms"] = elapsed_ms(t_start)

    # Retrieval + rerank
    t_start = time.perf_counter()
    if use_memory:
        top_docs = retrieve_with_memory(
            vectordb, rewritten, stage1_k=stage1_k, top_k_reranked=top_k_reranked
        )
    else:
        candidates = stage1_retrieve(vectordb, rewritten, k=stage1_k)
        top_docs = rerank(rewritten, candidates, top_k=top_k_reranked)
    timings["retrieval_plus_rerank_ms"] = elapsed_ms(t_start)

    # Compression
    t_start = time.perf_counter()
    compressed = compress_context(rewritten, top_docs)
    timings["compression_ms"] = elapsed_ms(t_start)

    # Generation
    t_start = time.perf_counter()
    answer = generate_answer(query, compressed)
    timings["generation_ms"] = elapsed_ms(t_start)

    # Evaluation
    # NOTE(review): the judge only sees the LLM-compressed context, not the raw
    # retrieved chunks. A compression error that the answer repeats can
    # therefore still score as "grounded".
    t_start = time.perf_counter()
    eval_result = evaluate_answer(
        query=query,
        answer=answer,
        context=compressed,
        ground_truth=ground_truth
    )
    timings["evaluation_ms"] = elapsed_ms(t_start)

    timings["total_ms"] = elapsed_ms(t0)

    if verbose:
        _print_pipeline_report(
            query, rewritten, top_docs, compressed, ground_truth, answer, eval_result, timings
        )

    return {
        "query": query,
        "rewritten_query": rewritten,
        "answer": answer,
        "ground_truth": ground_truth,
        "compressed_context": compressed,
        "top_docs": top_docs,
        "eval": eval_result,
        "timings": timings,
    }


In [17]:
# Choose which corpus to run on. This overrides DATA_MODE from the config cell.
DATA_MODE = "squad"      # or "uploaded"

if DATA_MODE == "squad":
    base_docs, qa_df = load_squad_subset(max_examples=600)
    print(f"Loaded {len(base_docs)} SQuAD base documents and {len(qa_df)} QA pairs.")
elif DATA_MODE == "uploaded":
    base_docs = load_uploaded_docs("./data/uploaded")
    print(f"Loaded {len(base_docs)} uploaded documents.")
    qa_df = None  # no ground truth
else:
    # Fail loudly on a typo instead of silently falling into the "uploaded" branch.
    raise ValueError(f'DATA_MODE must be "squad" or "uploaded", got {DATA_MODE!r}')

# Same indexing step for both corpora.
vectordb, all_chunks = build_vectorstore(base_docs, chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP)
print(f"Vectorstore built with {len(all_chunks)} chunks.")


Generating train split: 100%|██████████| 87599/87599 [00:00<00:00, 261895.11 examples/s]
Generating validation split: 100%|██████████| 10570/10570 [00:00<00:00, 416311.02 examples/s]


Loaded 496 SQuAD base documents and 600 QA pairs.
Vectorstore built with 1272 chunks.


In [18]:
# Run a few end-to-end queries and collect judge scores and latency per query.
results = []

if DATA_MODE == "squad":
    # 3 SQuAD questions (fixed seed) paired with their gold answers.
    sample_df = qa_df.sample(3, random_state=42)  # pick 3 for quick test
    query_gt_pairs = [(r["question"], r["answer"]) for _, r in sample_df.iterrows()]
else:
    # Example queries for your own docs (no ground truth available).
    example_queries = [
        "What is the main purpose of this document?",
        "Summarize the key steps in the process described.",
    ]
    query_gt_pairs = [(eq, None) for eq in example_queries]

for q, gt in query_gt_pairs:
    out = run_single_query_pipeline(vectordb, q, ground_truth=gt, use_memory=True)
    scores = out["eval"]
    results.append({
        "query": q,
        "correct": scores["correctness"],
        "grounded": scores["groundedness"],
        "relevance": scores["relevance"],
        "total_ms": out["timings"]["total_ms"],
    })

summary_df = pd.DataFrame(results)
summary_df


Original query: Who was the superior of prince Kublai?
Rewritten query: Who was the superior or leader of Prince Kublai Khan?
--------------------------------------------------------------------------------
Top documents (sources):
[0] source=squad_paragraph_105
[1] source=squad_paragraph_105
[2] source=squad_paragraph_105
[3] source=squad_paragraph_129
[4] source=squad_paragraph_170
--------------------------------------------------------------------------------
Compressed context (preview):
Prince Kublai Khan's superior was Ögedei Khan, who granted Kublai a large appanage in North China starting in 1236. Kublai later ruled as Khagan from 1260 to 1294.
--------------------------------------------------------------------------------
Ground truth answer: Ögedei Khan
--------------------------------------------------------------------------------
Model answer:
Prince Kublai Khan's superior was Ögedei Khan.
--------------------------------------------------------------------------------
E

Unnamed: 0,query,correct,grounded,relevance,total_ms
0,Who was the superior of prince Kublai?,5,5,5,12230
1,"In turn, Chinese supporters have accused Weste...",5,5,5,13041
2,Which record label was the soundtrack album re...,5,5,5,9306


In [19]:
# Example: pick one query, simulate a user correction, then re-run the pipeline
# so the correction can compete with corpus chunks via memory_store.

if results:
    example_query = results[0]["query"]
    simulated_correction = "User correction: the correct answer is that the passage describes XYZ, not ABC."

    print("Adding feedback for query:")
    print(example_query)
    print("Correction:", simulated_correction)

    # memory_store lives in the kernel across cell runs. Add this correction
    # only once, so re-running the cell does not stack duplicate feedback docs.
    already_stored = any(
        d.page_content == simulated_correction and d.metadata.get("query") == example_query
        for d in memory_store.docs
    )
    if not already_stored:
        memory_store.add_feedback(example_query, simulated_correction)

    # Run again; memory is now part of retrieval
    print("\n--- Re-running pipeline with feedback memory ---\n")
    out_after = run_single_query_pipeline(vectordb, example_query, ground_truth=None, use_memory=True)


Adding feedback for query:
Who was the superior of prince Kublai?
Correction: User correction: the correct answer is that the passage describes XYZ, not ABC.

--- Re-running pipeline with feedback memory ---

Original query: Who was the superior of prince Kublai?
Rewritten query: Who was the superior or higher authority of Prince Kublai Khan?
--------------------------------------------------------------------------------
Top documents (sources):
[0] source=squad_paragraph_105
[1] source=squad_paragraph_105
[2] source=squad_paragraph_105
[3] source=squad_paragraph_170
[4] source=squad_paragraph_129
--------------------------------------------------------------------------------
Compressed context (preview):
Prince Kublai Khan's superior authority was Ögedei Khan, who granted him a large appanage in North China starting in 1236. Kublai later ruled as Khagan from 1260 to 1294. Additionally, Kublai recognized the Phagpa lama as a senior instructor in religious affairs, establishing a uniq

In [20]:
def build_vectorstore_with_chunk_size(
    chunk_size: int,
    docs: Optional[List[Document]] = None,
    overlap_ratio: float = 0.2,
) -> FAISS:
    """Build a throwaway FAISS index at ``chunk_size`` (the chunk list is discarded).

    Args:
        chunk_size: Chunk length in characters.
        docs: Documents to index. Defaults to the notebook-global ``base_docs``.
        overlap_ratio: Chunk overlap as a fraction of ``chunk_size``; the
            resulting overlap is truncated to an int.
    """
    source_docs = base_docs if docs is None else docs
    vectordb_tmp, _ = build_vectorstore(
        source_docs,
        chunk_size=chunk_size,
        chunk_overlap=int(chunk_size * overlap_ratio),
    )
    return vectordb_tmp

def compare_chunk_sizes(
    query: str,
    gt_answer: Optional[str] = None,
    chunk_sizes: Tuple[int, int] = (300, 800)
) -> Tuple[pd.DataFrame, Dict[str, Any], Dict[str, Any]]:
    """A/B-compare two chunk sizes on a single query, without feedback memory.

    A fresh index is built for each chunk size (so the corpus is embedded
    twice), and the full pipeline is run against each.

    Args:
        query: Question to run through both variants.
        gt_answer: Optional reference answer passed to the judge.
        chunk_sizes: (size_a, size_b) in characters.

    Returns:
        (summary_df, out_a, out_b): a two-row DataFrame with judge scores and
        total latency per variant, plus the raw pipeline outputs for A and B.

    Note: each variant makes its own query-rewrite and judge LLM calls, so
    score and latency differences can reflect LLM variance, not only chunk size.
    """
    size_a, size_b = chunk_sizes
    variants = (("A", size_a), ("B", size_b))

    outputs: Dict[str, Dict[str, Any]] = {}
    for label, size in variants:
        db = build_vectorstore_with_chunk_size(size)
        outputs[label] = run_single_query_pipeline(db, query, ground_truth=gt_answer, use_memory=False)

    rows = [
        {
            "variant": label,
            "chunk_size": size,
            "correct": outputs[label]["eval"]["correctness"],
            "grounded": outputs[label]["eval"]["groundedness"],
            "relevance": outputs[label]["eval"]["relevance"],
            "total_ms": outputs[label]["timings"]["total_ms"],
        }
        for label, size in variants
    ]

    return pd.DataFrame(rows), outputs["A"], outputs["B"]


# SQuAD mode only: compare chunk sizes 300 vs 800 on one fixed random question.
if DATA_MODE == "squad":
    row = qa_df.sample(n=1, random_state=123).iloc[0]
    q, gt = row["question"], row["answer"]

    cs_df, out_a, out_b = compare_chunk_sizes(q, gt_answer=gt, chunk_sizes=(300, 800))
    display(cs_df)


Original query: During which season of American Idol did Fox beat the other networks in ratings for the first time? 
Rewritten query: In which season of American Idol did the Fox network achieve higher television ratings than its competitors for the first time?
--------------------------------------------------------------------------------
Top documents (sources):
[0] source=squad_paragraph_149
[1] source=squad_paragraph_405
[2] source=squad_paragraph_322
[3] source=squad_paragraph_149
[4] source=squad_paragraph_322
--------------------------------------------------------------------------------
Compressed context (preview):
Fox achieved higher television ratings than its competitors for the first time during the seventh season of American Idol.
--------------------------------------------------------------------------------
Ground truth answer: season seven
--------------------------------------------------------------------------------
Model answer:
Fox beat the other networks in ra

Unnamed: 0,variant,chunk_size,correct,grounded,relevance,total_ms
0,A,300,5,5,5,9918
1,B,800,5,5,5,14979
