## RAG with Re-Ranking (Better Relevance)

How it works:
Instead of blindly sending the top search results, a re-ranker (often another small ML model) scores each retrieved result against the query.
Only the most relevant documents go to the LLM.
Example Use Case: In healthcare, a clinical assistant retrieves only the most relevant medical guidelines for a doctor’s query instead of dumping too much text into the prompt.

User Query → Retriever → Re-Ranker (sort/filter results) → Top Docs → LLM → Answer
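The flow above can be sketched in plain Python. This is a minimal illustration that assumes two hypothetical scoring functions: `word_overlap` stands in for the fast vector retriever and `phrase_bonus` for the slower, more precise cross-encoder. Neither is a real ML model; only the two-stage shape (score everything cheaply, then rescore a small pool carefully) mirrors the real pipeline.

```python
from typing import Callable, List, Tuple


def retrieve_then_rerank(
    query: str,
    corpus: List[str],
    cheap_score: Callable[[str, str], float],
    precise_score: Callable[[str, str], float],
    pool_n: int,
    top_k: int,
) -> List[Tuple[str, float]]:
    """Stage 1: keep the pool_n best docs by a cheap score.
    Stage 2: rescore only that pool with a precise (more expensive) score and keep top_k."""
    pool = sorted(corpus, key=lambda d: cheap_score(query, d), reverse=True)[:pool_n]
    rescored = [(d, precise_score(query, d)) for d in pool]
    rescored.sort(key=lambda pair: pair[1], reverse=True)
    return rescored[:top_k]


# Hypothetical stand-in for the vector retriever: fraction of query words present in the doc.
def word_overlap(query: str, doc: str) -> float:
    q = set(query.lower().split())
    return len(q & set(doc.lower().split())) / (len(q) or 1)


# Hypothetical stand-in for the cross-encoder: word overlap plus a bonus when the
# whole query appears as a phrase (it "reads" query and doc together).
def phrase_bonus(query: str, doc: str) -> float:
    return word_overlap(query, doc) + (1.0 if query.lower() in doc.lower() else 0.0)


corpus = [
    "treatment of sepsis in adults",
    "sepsis treatment guideline for children",
    "sepsis screening checklist",
    "hospital billing codes",
]
result = retrieve_then_rerank("sepsis treatment", corpus, word_overlap, phrase_bonus, pool_n=3, top_k=2)
print(result)
# [('sepsis treatment guideline for children', 2.0), ('treatment of sepsis in adults', 1.0)]
```

The first two documents tie on word overlap, so the cheap stage keeps them in corpus order; the precise stage promotes the second one. The billing document never reaches the expensive scorer, which is the point of using a pool.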

High-level idea (what the app does)

1. Load a PDF (default: DEFAULT_LOCAL_PDF, set below to Amazon SageMaker AI-Developer Guide.pdf).
2. Split it into chunks → embed chunks → store in FAISS (a vector index).
3. For a user question, retrieve a broad pool of chunks (top-N).
4. Re-rank that pool with a Cross-Encoder (query+text scorer) → keep best K.
5. Send those best K chunks as context to an LLM to generate a grounded answer.
6. Show the answer + sources in a simple Gradio UI.


In [None]:
# Install the re-ranking dependency first; this cell must run before the imports cell.
# (Outside a uv project, use %pip install sentence-transformers instead.)
!uv add sentence-transformers

In [None]:
import os
import tempfile
from dataclasses import dataclass
from typing import List, Tuple, Any, Dict

import gradio as gr
from dotenv import load_dotenv

# LangChain (langchain_core import paths; the older langchain.schema / langchain.prompts paths are deprecated)
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableMap
from langchain_core.output_parsers import StrOutputParser

# Re-ranking model
from sentence_transformers import CrossEncoder
import torch

In [None]:
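# Load environment variables from a file called .env (if present).
# Don't fall back to a placeholder key: a fake value would bypass the
# "Missing OPENAI_API_KEY" check in ui_build_index and fail later with an auth error.
load_dotenv(override=True)
if not os.getenv("OPENAI_API_KEY"):
    print("OPENAI_API_KEY is not set; paste your key into the UI before building the index.")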


In [None]:
DEFAULT_LOCAL_PDF = "Amazon SageMaker AI-Developer Guide.pdf"

In [None]:
# Purpose: Control how the LLM behaves.
# What the instructions say:
# Be factual.
# Use only the provided context.
# If the answer is not in the context, say “I don’t know”.
# Keep it concise and optionally add short citations like [source].
# Why it matters: Keeps the LLM grounded and reduces hallucinations.

SYSTEM_PROMPT = """You are a factual assistant. Answer ONLY using the provided context.
If the answer is not present, say you don't know. Keep answers concise.
Include short inline citations like [source] when helpful."""

PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_PROMPT),
        ("human",
         "Question:\n{question}\n\n"
         "Context:\n{context}\n\n"
         "Answer using ONLY the context.")
    ]
)


In [None]:
# Purpose: If Gradio returns raw bytes (instead of a file path), this makes a temporary .pdf file and returns its path.
# How it works:
# Creates a temp folder, writes the bytes to uploaded.pdf.
# Returns that file path.
# Why it matters: Later steps (PDF loader) need a filesystem path, not raw bytes.

def _bytes_to_temp_pdf(file_bytes: bytes, suggested_name: str = "uploaded.pdf") -> str:
    """Write bytes to a temp .pdf and return the path."""
    tmp_dir = tempfile.mkdtemp()
    pdf_path = os.path.join(tmp_dir, suggested_name)
    with open(pdf_path, "wb") as f:
        f.write(file_bytes)
    return pdf_path

In [None]:
# Purpose: Robustly figure out the actual PDF path to load.
# How it works:
# If use_default=True, ensure DEFAULT_LOCAL_PDF exists in the current folder and return it.
# Else, if user uploaded:
# If it’s a string path (Gradio type="filepath"), verify and return it.
# If it’s bytes, call _bytes_to_temp_pdf(...) and return that path.
# Otherwise, raise a helpful error.
# Why it matters: No matter how the file arrives, the rest of the code always gets a valid path.

def _resolve_pdf_path(pdf_input, use_default: bool) -> str:
    """
    Returns a filesystem path to a readable PDF.
    - use_default=True -> uses DEFAULT_LOCAL_PDF from current folder.
    - pdf_input can be a filepath (str) or bytes (if input type changed).
    """
    if use_default:
        if not os.path.exists(DEFAULT_LOCAL_PDF):
            raise FileNotFoundError(
                f"Default PDF not found: {DEFAULT_LOCAL_PDF} (place it in this folder or upload another file)."
            )
        return DEFAULT_LOCAL_PDF

    if pdf_input is None:
        raise ValueError("Please upload a PDF or check 'Use default PDF'.")

    if isinstance(pdf_input, str):
        if not os.path.exists(pdf_input):
            raise FileNotFoundError(f"Uploaded path not found: {pdf_input}")
        return pdf_input

    if isinstance(pdf_input, (bytes, bytearray)):
        return _bytes_to_temp_pdf(pdf_input)

    raise TypeError(f"Unsupported file input type: {type(pdf_input)}")

In [None]:
# Purpose: Takes a list of retrieved chunks and turns them into a single readable context string with simple citations.
# How it works:
# Loops each Document d.
# Builds labels like [<source file> (page 3)], converting PyPDFLoader's 0-based page number to a 1-based one.
# Concatenates their page_content into one string the LLM can read.
# Why it matters: LLMs need a single text block of context; this function creates it and preserves where each chunk came from.

def format_docs(docs: List[Document]) -> str:
    """Formats retrieved docs into a single string with simple citations."""
    lines = []
    for i, d in enumerate(docs, start=1):
        src = d.metadata.get("source", f"doc_{i}")
        page = d.metadata.get("page", None)
        page_str = f" (page {page+1})" if isinstance(page, int) else ""
        lines.append(f"[{src}{page_str}] {d.page_content}")
    return "\n\n".join(lines)

In [None]:
# Purpose: Pretty-prints which chunks/pages were used in the answer.
# How it works:
# For each Document, prints source (filename), page, and start_index (the character offset set by the splitter).
# Returns a readable bullet list.
# Why it matters: Learners see transparency—exactly which pages powered the answer.

def show_sources(docs: List[Document]) -> str:
    if not docs:
        return "No sources."
    lines = []
    for d in docs:
        src = d.metadata.get("source", "unknown")
        page = d.metadata.get("page", None)
        start_idx = d.metadata.get("start_index", "?")
        page_str = f"page {page+1}" if isinstance(page, int) else "page ?"
        lines.append(f"- {src} ({page_str}), start_char={start_idx}")
    return "\n".join(lines)

In [None]:
# =========================
# Indexing
# Purpose: Read the PDF and convert it to a list of LangChain Documents (one per page).
# How it works:
# Uses PyPDFLoader(pdf_path).load() → returns one Document per page, with metadata["page"] set.
# Adds metadata["source"] = filename to each doc (used for citations).
# Why it matters: Converts raw PDF into a structure that LangChain’s splitters and retrievers understand.

def load_pdf_as_docs(pdf_path: str) -> List[Document]:
    loader = PyPDFLoader(pdf_path)
    docs = loader.load()  # each page is a Document with page metadata
    for d in docs:
        d.metadata["source"] = os.path.basename(pdf_path)
    return docs

# Purpose: Build the vector search part of RAG.
# How it works:
# Split pages into overlapping chunks (RecursiveCharacterTextSplitter) so each chunk is a manageable size for embeddings/LLM; chunk_size and chunk_overlap are measured in characters, and add_start_index=True records each chunk's character offset as metadata["start_index"].
# Embed chunks with OpenAIEmbeddings.
# Index them in FAISS: FAISS.from_documents(...).
# Create a retriever view with k=initial_k: this is the broad candidate pool (e.g., top-15) returned for each question.
# Return the retriever and some stats (pages, chunks, k, etc.) for the UI.
# Why it matters: This is classic “RAG retrieval”: given a query, quickly find similar chunks by vector similarity.

def build_vector_index(
    docs: List[Document],
    embed_model: str = "text-embedding-3-small",
    chunk_size: int = 500,
    chunk_overlap: int = 80,
    initial_k: int = 15,
) -> Tuple[Any, Dict[str, Any]]:
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, add_start_index=True
    )
    chunks = splitter.split_documents(docs)

    embeddings = OpenAIEmbeddings(model=embed_model)
    vectorstore = FAISS.from_documents(chunks, embedding=embeddings)
    retriever = vectorstore.as_retriever(search_kwargs={"k": initial_k})

    stats = {
        "pages": len(docs),
        "chunks": len(chunks),
        "initial_k": initial_k,
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
        "embed_model": embed_model,
    }
    return retriever, stats
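To see what `chunk_size` and `chunk_overlap` mean, here is a simplified fixed-window splitter. It is a sketch under stated assumptions, not `RecursiveCharacterTextSplitter`'s algorithm: the real splitter prefers to break on paragraph, line, and word separators, so its chunks are rarely exactly `chunk_size` characters. What carries over is that sizes are counted in characters by default, neighbouring chunks share up to `chunk_overlap` characters, and `start_index` records where each chunk begins. The function name `fixed_window_chunks` is hypothetical.

```python
from typing import Dict, List


def fixed_window_chunks(text: str, chunk_size: int, chunk_overlap: int) -> List[Dict]:
    """Slide a window of chunk_size characters, advancing chunk_size - chunk_overlap each time."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("need 0 <= chunk_overlap < chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append({"start_index": start, "text": text[start:start + chunk_size]})
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks


text = "abcdefghijklmnopqrstuvwxyz"  # 26 characters
for c in fixed_window_chunks(text, chunk_size=10, chunk_overlap=3):
    print(c)
# {'start_index': 0, 'text': 'abcdefghij'}
# {'start_index': 7, 'text': 'hijklmnopq'}
# {'start_index': 14, 'text': 'opqrstuvwx'}
# {'start_index': 21, 'text': 'vwxyz'}
```

With the app's defaults (500/80), a fixed window like this would start a new chunk every 420 characters; the recursive splitter's offsets vary because it snaps to separators.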

In [None]:
# =========================
# Re-Ranker
# Holds which Cross-Encoder model to use and how many chunks to keep (top_k) after re-ranking.
# =========================
@dataclass
class RerankerConfig:
    model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    top_k: int = 5

# Purpose: Ranks the candidate chunks using a cross-encoder (a model that reads the query and the chunk together and outputs a relevance score).
# How it works:
# __init__: loads the sentence-transformers cross-encoder once (CPU or GPU).
# rerank(query, docs):
# 1. Build (query, doc_text) pairs,
# 2. self.model.predict(...) → get a score per pair,
# 3. Sort by score descending,
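# 4. Return the top_k (Document, score) pairs; the re-rank chain then drops the scores and passes only the Documents on.
# Why it matters: Plain vector similarity is fast but can be fuzzy, because the query and each chunk are embedded separately. A cross-encoder reads the query and chunk together (up to the model's maximum input length), so it judges relevance more precisely, at the cost of one forward pass per (query, chunk) pair.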

class CrossEncoderReranker:
    """Re-ranks retrieved chunks using CrossEncoder scoring of (query, chunk)."""
    def __init__(self, cfg: RerankerConfig):
        self.cfg = cfg
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CrossEncoder(cfg.model_name, device=device)

    def rerank(self, query: str, docs: List[Document]) -> List[Tuple[Document, float]]:
        if not docs:
            return []
        pairs = [(query, d.page_content) for d in docs]
        scores = self.model.predict(pairs)  # numpy array
        ranked = sorted(zip(docs, scores), key=lambda x: float(x[1]), reverse=True)
        return ranked[: self.cfg.top_k]
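When comparing re-rank ON vs OFF, it helps to see which candidates the cross-encoder promoted or dropped. The helpers below are hypothetical additions, not part of the app. They assume you already have the retriever's order (e.g. from `retriever.invoke(question)`) and the re-ranked order (e.g. `result["sources"]`), each turned into chunk IDs such as `(source, page, start_index)` tuples built from metadata.

```python
from typing import Hashable, List, Tuple


def rank_changes(
    retrieved_ids: List[Hashable], reranked_ids: List[Hashable]
) -> List[Tuple[Hashable, int, int]]:
    """Return (chunk_id, retriever_rank, reranked_rank) for each kept chunk (1-based ranks).
    Assumes every re-ranked ID came from the retrieved pool, as in the app's chain."""
    old_rank = {cid: i for i, cid in enumerate(retrieved_ids, start=1)}
    return [(cid, old_rank[cid], new) for new, cid in enumerate(reranked_ids, start=1)]


def dropped(retrieved_ids: List[Hashable], reranked_ids: List[Hashable]) -> List[Hashable]:
    """Candidates the re-ranker left out, in the retriever's original order."""
    kept = set(reranked_ids)
    return [cid for cid in retrieved_ids if cid not in kept]


retrieved = ["p3", "p7", "p1", "p9", "p4"]  # retriever order (hypothetical IDs)
reranked = ["p9", "p3", "p4"]               # cross-encoder order, top_k = 3
print(rank_changes(retrieved, reranked))    # [('p9', 4, 1), ('p3', 1, 2), ('p4', 5, 3)]
print(dropped(retrieved, reranked))         # ['p7', 'p1']
```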

In [None]:
# ===================================================
# RAG Chains
# ===================================================

# Purpose: A vanilla RAG chain without re-ranking.
# Steps (Runnable pipeline):
# Inputs: "question" (passthrough) + "docs" from retriever.
# Build the prompt inputs:
# "context" = format_docs(docs)
# "sources" = the same docs
# PROMPT → LLM → StrOutputParser to get plain text.
# Output "answer" and "sources".
# Why it matters: Lets learners compare with/without re-ranking.
# ===========================================================
def make_chain_no_rerank(retriever, model_name="gpt-4o-mini", temperature=0.0):
    llm = ChatOpenAI(model=model_name, temperature=temperature)
    chain = (
        RunnableMap({"question": RunnablePassthrough(), "docs": retriever})
        | RunnableMap({
            "question": lambda x: x["question"],
            "context": lambda x: format_docs(x["docs"]),
            "sources": lambda x: x["docs"],
        })
        | RunnableMap({
            # Same PROMPT as the re-rank chain, so ON/OFF comparisons differ only in the retrieved context.
            "answer": PROMPT | llm | StrOutputParser(),
            "sources": lambda x: x["sources"]
        })
    )
    return chain

# Purpose: The RAG + Re-Ranking chain.
# Steps:
# Inputs: "question" and a broad pool of "candidates" from retriever (e.g., 15 chunks).
# Re-rank: reranker.rerank(question, candidates) → keep top-K most relevant chunks.
# Build the prompt inputs with those reranked chunks:
# "context" = format_docs(reranked)
# "sources" = reranked
# PROMPT → LLM → StrOutputParser to get the final answer.
# Output "answer" and "sources".
# Why it matters: This is the “Better Relevance” upgrade—improves precision by giving the LLM fewer but better chunks.

def make_chain_with_rerank(retriever, reranker: CrossEncoderReranker,
                           model_name="gpt-4o-mini", temperature=0.0):
    llm = ChatOpenAI(model=model_name, temperature=temperature)
    chain = (
        RunnableMap({"question": RunnablePassthrough(), "candidates": retriever})
        | RunnableMap({
            "question": lambda x: x["question"],
            "reranked": lambda x: [d for d, s in reranker.rerank(x["question"], x["candidates"])],
        })
        | RunnableMap({
            "question": lambda x: x["question"],
            "context": lambda x: format_docs(x["reranked"]),
            "sources": lambda x: x["reranked"],
        })
        | RunnableMap({
            "answer": PROMPT | llm | StrOutputParser(),
            "sources": lambda x: x["sources"]
        })
    )
    return chain
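Stripped of LangChain, the re-rank chain is a few steps passing a dict along. The sketch below uses hypothetical callables (`retrieve`, `rerank`, `generate`) in place of the retriever, `CrossEncoderReranker`, and `PROMPT | llm | StrOutputParser()`. It only shows how `question`, `candidates`, `reranked`, `context`, and `sources` flow between steps; it is not how LangChain executes runnables internally, and it joins chunk text plainly where the real chain uses `format_docs(...)` with citation labels.

```python
from typing import Callable, Dict, List


def run_rerank_chain(
    question: str,
    retrieve: Callable[[str], List[str]],
    rerank: Callable[[str, List[str]], List[str]],
    generate: Callable[[str, str], str],
) -> Dict[str, object]:
    # Step 1: like RunnableMap({"question": passthrough, "candidates": retriever})
    state = {"question": question, "candidates": retrieve(question)}
    # Step 2: keep only the re-ranked subset of the candidate pool
    state = {"question": state["question"], "reranked": rerank(state["question"], state["candidates"])}
    # Step 3: build prompt inputs (plain join here; the real chain calls format_docs)
    state = {
        "question": state["question"],
        "context": "\n\n".join(state["reranked"]),
        "sources": state["reranked"],
    }
    # Step 4: answer from (question, context); sources pass through untouched
    return {"answer": generate(state["question"], state["context"]), "sources": state["sources"]}


result = run_rerank_chain(
    "What is FAISS?",
    retrieve=lambda q: ["FAISS is a vector index.", "Gradio builds UIs.", "FAISS searches embeddings."],
    rerank=lambda q, docs: [d for d in docs if "FAISS" in d][:2],  # stand-in re-ranker
    generate=lambda q, ctx: "FAISS is a vector index." if "vector index" in ctx else "I don't know.",
)
print(result["answer"])   # FAISS is a vector index.
print(result["sources"])  # ['FAISS is a vector index.', 'FAISS searches embeddings.']
```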

In [None]:
# =========================
# Gradio Callbacks
# =========================

# What it does:
# 1. API key: use the textbox value if provided; else rely on the env var.
# 2. Resolve PDF path: _resolve_pdf_path(pdf_file, use_default).
# 3. Load & index:
# pages = load_pdf_as_docs(pdf_path)
# retriever, stats = build_vector_index(...) using chosen chunk sizes, embeddings, and pool size (pool_k).
# 4. Choose chain:
# If use_rerank=True, construct a CrossEncoderReranker with rerank_model and rerank_top_k, then call make_chain_with_rerank(...).
# Else, call make_chain_no_rerank(...).
# 5. Build a short summary string (pages, chunks, k, models) for the UI.
# 6. Return:
# chain (stored in a hidden Gradio State)
# index_summary (Markdown)
# status message.
# Why it matters: One-time setup per PDF; afterwards users can ask many questions. Re-ranking ON/OFF and all other settings are baked into the chain here, so changing any of them requires rebuilding the index.

def ui_build_index(api_key: str,
                   pdf_file,
                   use_default: bool,
                   model_name: str,
                   temperature: float,
                   embed_model: str,
                   chunk_size: int,
                   chunk_overlap: int,
                   pool_k: int,
                   use_rerank: bool,
                   rerank_model: str,
                   rerank_top_k: int):
    try:
        # API key
        if api_key:
            os.environ["OPENAI_API_KEY"] = api_key
        if not os.getenv("OPENAI_API_KEY"):
            raise RuntimeError("Missing OPENAI_API_KEY. Provide it in the textbox or environment.")

        # Resolve PDF path
        pdf_path = _resolve_pdf_path(pdf_file, use_default)

        # Load docs & build index
        pages = load_pdf_as_docs(pdf_path)
        retriever, stats = build_vector_index(
            pages, embed_model=embed_model,
            # Defensive casts: slider values may arrive as floats, while FAISS's k
            # and the splitter's sizes are meant to be integers.
            chunk_size=int(chunk_size), chunk_overlap=int(chunk_overlap),
            initial_k=int(pool_k)
        )

        # Choose chain (with or without reranking)
        if use_rerank:
            reranker = CrossEncoderReranker(RerankerConfig(
                model_name=rerank_model, top_k=int(rerank_top_k)
            ))
            chain = make_chain_with_rerank(retriever, reranker,
                                           model_name=model_name, temperature=temperature)
        else:
            chain = make_chain_no_rerank(retriever,
                                         model_name=model_name, temperature=temperature)

        summary = (
            f"✅ Index ready for **{os.path.basename(pdf_path)}**\n"
            f"- Pages: {stats['pages']} | Chunks: {stats['chunks']}\n"
            f"- Retriever pool k: {stats['initial_k']} | "
            f"Re-rank: {'ON' if use_rerank else 'OFF'} (top_k={rerank_top_k if use_rerank else '-'})\n"
            f"- Chunk size/overlap: {stats['chunk_size']}/{stats['chunk_overlap']}\n"
            f"- Embeddings: {stats['embed_model']} | LLM: {model_name} (T={temperature})"
        )
        return chain, gr.update(value=summary, visible=True), "Index built successfully."
    except Exception as e:
        return None, gr.update(value="", visible=True), f"❌ Error: {e}"


# Purpose: Runs when you click “Ask”.
# What it does:
# 1. Checks that a chain exists (i.e., you built the index first) and that a question is provided.
# 2. Calls chain.invoke(question).
# 3. Extracts:
# result["answer"] → the final text
# result["sources"] → which chunks were used
# 4. Formats sources with show_sources(...).
# 5. Returns both to the UI.
# Why it matters: This is the live Q&A endpoint.

def ui_ask(chain, question: str):
    if chain is None:
        return "Please build the index first.", ""
    if not question or not question.strip():
        return "Please enter a question.", ""
    try:
        result = chain.invoke(question.strip())
        answer = result["answer"]
        docs = result["sources"]
        return answer.strip(), show_sources(docs)
    except Exception as e:
        return f"❌ Error while answering: {e}", ""

In [None]:
# =========================
# Build Gradio UI
# =========================

# 1. API Key textbox (optional if you set the env var).
# 2. Data source:
# File input (type=filepath) to upload a different PDF
# Checkbox to use the default PDF (DEFAULT_LOCAL_PDF)
# 3. RAG settings: LLM, temperature, embeddings model, chunk size/overlap, retriever pool (top-N).
# 4. Re-ranking settings: toggle ON/OFF, choose cross-encoder model, set rerank top-K.
# 5. Build Index button → runs ui_build_index(...).
# 6. Question textbox + Ask button → runs ui_ask(...).
# 7. Answer (Markdown) + Sources (Textbox) display.
# Why it matters: Learners can see how changing knobs (pool size, rerank top-K) affects quality.

with gr.Blocks(title="RAG with Re-Ranking — PDF QA") as demo:
    gr.Markdown(
        """
        # 🔎 RAG with Re-Ranking — PDF Question Answering
        1) Provide your **OpenAI API key** (or set `OPENAI_API_KEY` in env).  
        2) Use the default **Amazon Bedrock - User Guide.pdf** (place it in this folder) or upload a PDF.  
        3) Choose retrieval pool size and re-ranking settings, then **Build Index**.  
        4) Ask questions and compare re-rank ON vs OFF for relevance.
        """
    )

    with gr.Row():
        api_key = gr.Textbox(
            label="OpenAI API Key (optional if set in environment)",
            type="password",
            placeholder="sk-...",
        )

    with gr.Accordion("Data source", open=True):
        with gr.Row():
            pdf = gr.File(label="Upload PDF", file_types=[".pdf"], type="filepath")
            use_default = gr.Checkbox(value=True, label=f"Use default: {DEFAULT_LOCAL_PDF}")

    with gr.Accordion("RAG settings", open=True):
        with gr.Row():
            model_name = gr.Dropdown(choices=["gpt-4o-mini"], value="gpt-4o-mini", label="LLM")
            temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Temperature")
        with gr.Row():
            embed_model = gr.Dropdown(
                choices=["text-embedding-3-small", "text-embedding-3-large"],
                value="text-embedding-3-small",
                label="Embeddings model"
            )
            chunk_size = gr.Slider(200, 1200, value=500, step=50, label="Chunk size")
            chunk_overlap = gr.Slider(0, 400, value=80, step=10, label="Chunk overlap")
            pool_k = gr.Slider(5, 50, value=15, step=1, label="Retriever pool size (top-N)")
    with gr.Accordion("Re-ranking settings", open=True):
        with gr.Row():
            use_rerank = gr.Checkbox(value=True, label="Enable Cross-Encoder Re-Ranking")
            rerank_model = gr.Dropdown(
                choices=[
                    "cross-encoder/ms-marco-MiniLM-L-6-v2",
                    "cross-encoder/ms-marco-MiniLM-L-12-v2",
                    "cross-encoder/ms-marco-TinyBERT-L-2-v2"
                ],
                value="cross-encoder/ms-marco-MiniLM-L-6-v2",
                label="Cross-Encoder model"
            )
            rerank_top_k = gr.Slider(1, 15, value=5, step=1, label="Re-rank top-K")

    build_btn = gr.Button("🔧 Build Index", variant="primary")

    index_summary = gr.Markdown(visible=False)
    status = gr.Markdown("")
    state_chain = gr.State()

    gr.Markdown("---")
    question = gr.Textbox(
        label="Ask a question about the PDF",
        lines=2,
        placeholder="e.g., What does Amazon Bedrock provide to developers?"
    )
    ask_btn = gr.Button("💬 Ask")
    answer = gr.Markdown(label="Answer")
    sources = gr.Textbox(label="Sources (retrieved chunks/pages)", lines=6)

    build_btn.click(
        ui_build_index,
        inputs=[
            api_key, pdf, use_default,
            model_name, temperature, embed_model, chunk_size, chunk_overlap,
            pool_k, use_rerank, rerank_model, rerank_top_k
        ],
        outputs=[state_chain, index_summary, status],
        api_name="build_index",
    )

    ask_btn.click(
        ui_ask,
        inputs=[state_chain, question],
        outputs=[answer, sources],
        api_name="ask",
    )

In [None]:
demo.launch()

<u><b>End-to-end flow (quick recap)</b></u>

1. Click Build Index → PDF is loaded, chunked, embedded, indexed; chain is prepared (with or without re-ranking).

2. Type a question → The chain retrieves a broad pool → optionally re-ranks → sends top-K to the LLM → shows answer + sources.


<u><b>What to tweak (for more learning/homework)</b></u>

1. Retriever pool size (top-N): larger means better recall but more noise, and slower re-ranking because every candidate in the pool is scored.

2. Re-rank top-K: smaller K = tighter context (more precise), but a higher risk of missing something.

3. Chunk size/overlap (in characters): bigger chunks carry more context but may dilute relevance; overlap helps preserve sentence continuity across chunk boundaries.

4. Cross-encoder model: MiniLM-L-6-v2 is a fast, sensible default; TinyBERT-L-2-v2 is faster but less accurate; MiniLM-L-12-v2 is slower and may be slightly more accurate.

5. Temperature: keep at 0.0 for factual answers; higher values give more varied wording (not recommended for document QA).
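A quick way to reason about these knobs is to count the work per question. The helper below is a hypothetical back-of-the-envelope sketch, not part of the app: the cross-encoder scores every chunk in the pool, so its cost grows with `pool_k`, while `rerank_top_k` and `chunk_size` bound how much chunk text reaches the LLM. With re-ranking OFF, the no-rerank chain sends the whole pool. The splitter aims to keep chunks at or below `chunk_size` characters, so the context figure is a rough upper bound that ignores citation labels and prompt text.

```python
def per_question_budget(pool_k: int, rerank_top_k: int, chunk_size: int, use_rerank: bool = True) -> dict:
    """Rough per-question work for the app's settings (upper bounds, characters not tokens)."""
    chunks_to_llm = min(rerank_top_k, pool_k) if use_rerank else pool_k
    return {
        "cross_encoder_pairs": pool_k if use_rerank else 0,  # one (query, chunk) pair per candidate
        "chunks_to_llm": chunks_to_llm,
        "max_context_chars": chunks_to_llm * chunk_size,  # excludes labels and prompt text
    }


print(per_question_budget(pool_k=15, rerank_top_k=5, chunk_size=500))
# {'cross_encoder_pairs': 15, 'chunks_to_llm': 5, 'max_context_chars': 2500}
print(per_question_budget(pool_k=15, rerank_top_k=5, chunk_size=500, use_rerank=False))
# {'cross_encoder_pairs': 0, 'chunks_to_llm': 15, 'max_context_chars': 7500}
```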


<u><b>Common pitfalls (and fixes)</b></u>

1. “Default PDF not found” → place the file named in DEFAULT_LOCAL_PDF (Amazon SageMaker AI-Developer Guide.pdf) in the notebook's working folder, or uncheck “Use default” and upload a PDF.

2. “Missing OPENAI_API_KEY” → set the env var or paste the key into the UI field.

3. Slow re-ranking → reduce pool_k (the cross-encoder scores every candidate in the pool; rerank_top_k does not change that cost) or choose a smaller cross-encoder such as TinyBERT-L-2-v2.

4. GPU not used → the cross-encoder auto-detects CUDA; if none is available, it runs on CPU (slower, but it works).

5. Large PDFs → consider persisting the FAISS index to disk and reusing it, or pre-chunking offline.
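For the last pitfall, one way to reuse an index is to key a cache folder on everything that changes the vectors: the PDF bytes, the embedding model, and the chunking settings (the pool size does not, so it is left out). The helper below is a hypothetical sketch that only computes that key with `hashlib`. Saving and loading would go through LangChain's FAISS persistence, `vectorstore.save_local(folder)` and `FAISS.load_local(folder, embeddings, allow_dangerous_deserialization=True)`, which would require `build_vector_index` to return the `vectorstore` it currently discards. That flag is needed because part of the saved index is pickled, so only load folders you created yourself.

```python
import hashlib
import os
import tempfile


def index_cache_dir(pdf_path: str, embed_model: str, chunk_size: int, chunk_overlap: int,
                    root: str = "faiss_cache") -> str:
    """Return a folder path that changes whenever the PDF or the indexing settings change."""
    h = hashlib.sha256()
    with open(pdf_path, "rb") as f:
        for block in iter(lambda: f.read(1 << 20), b""):  # hash the file in 1 MiB blocks
            h.update(block)
    h.update(f"|{embed_model}|{chunk_size}|{chunk_overlap}".encode("utf-8"))
    return os.path.join(root, h.hexdigest()[:16])


# Usage with a stand-in file (placeholder bytes, not a real PDF).
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
    tmp.write(b"%PDF-1.4 placeholder bytes")
key_a = index_cache_dir(tmp.name, "text-embedding-3-small", 500, 80)
key_again = index_cache_dir(tmp.name, "text-embedding-3-small", 500, 80)
key_b = index_cache_dir(tmp.name, "text-embedding-3-small", 600, 80)
os.remove(tmp.name)
print(key_a == key_again, key_a != key_b)  # True True
```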