# RAG System Implementation with LangChain

This notebook implements a complete Retrieval-Augmented Generation (RAG) system using LangChain.

## Components:
1. **Document Loading & Indexing**: Load RAG-Instruct dataset
2. **Vector Store**: FAISS with Sentence-BERT embeddings
3. **Retrievers**:
   - Dense retrieval (semantic search)
   - Sparse retrieval (BM25)
   - Hybrid retrieval (ensemble)
4. **LLM Integration**: Base and fine-tuned models
5. **RAG Chain**: Complete question-answering pipeline
6. **Evaluation**: Compare different configurations

## 1. Setup and Installation

In [None]:
%pip install -q langchain langchain-community langchain-huggingface
%pip install -q datasets sentence-transformers faiss-cpu rank-bm25
%pip install -q transformers torch peft accelerate
%pip install -q rouge-score bert-score nltk

## 2. Imports


In [None]:
import os
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from typing import List, Dict, Any


# LangChain imports
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain_community.llms import HuggingFacePipeline
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_classic.chains.retrieval import create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain

# HuggingFace and ML imports
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel

# Fix the NumPy and PyTorch RNG seeds so seeded steps are repeatable
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")


## 2b. Custom Ensemble Retriever

Create a custom ensemble retriever that fuses the rankings of several retrievers using weighted Reciprocal Rank Fusion (RRF).


In [None]:
from typing import List, Optional, Dict, Any
from pydantic import Field
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from langchain_core.callbacks import CallbackManagerForRetrieverRun

class CustomEnsembleRetriever(BaseRetriever):
    """Retriever that merges several rankings with weighted Reciprocal Rank Fusion.

    Every wrapped retriever is invoked with the query. A document found at
    1-based rank ``r`` by a retriever with weight ``w`` contributes
    ``w / (rrf_k + r)`` to its fused score. Contributions are summed per
    document key (see ``_doc_key``) and the ``k`` best documents are returned.
    """

    retrievers: List[Any] = Field(description="Retrievers/runnables to ensemble")
    weights: Optional[List[float]] = Field(default=None, description="Weights for each retriever")
    k: int = Field(default=3, description="Number of docs to return")
    rrf_k: int = Field(default=60, description="RRF constant")

    model_config = {"arbitrary_types_allowed": True}

    @staticmethod
    def _doc_key(doc: Document) -> str:
        """Return the key used to recognise the same document across retrievers.

        Preference order: ``chunk_uid`` metadata, then ``parent_id`` + ``chunk_id``,
        then ``id``, and finally the raw page content.
        """
        # chunk-unique identifiers
        if "chunk_uid" in doc.metadata:
            return str(doc.metadata["chunk_uid"])

        parent_id = doc.metadata.get("parent_id")
        chunk_id = doc.metadata.get("chunk_id")
        if parent_id is not None and chunk_id is not None:
            return f"{parent_id}_{chunk_id}"

        # Fall back to any explicit id, else content
        if "id" in doc.metadata:
            return str(doc.metadata["id"])

        return doc.page_content

    def _resolve_weights(self) -> List[float]:
        """Return one weight per retriever (uniform 1.0 when none are configured).

        Raises:
            ValueError: if the number of weights differs from the number of retrievers.
        """
        weights = self.weights or [1.0] * len(self.retrievers)
        if len(self.retrievers) != len(weights):
            raise ValueError("Number of retrievers must match number of weights")
        return weights

    def _fuse(self, rankings: List[List[Document]], weights: List[float]) -> List[Document]:
        """Combine per-retriever rankings into the top-``k`` fused documents."""
        doc_scores: Dict[str, float] = {}
        docs_by_key: Dict[str, Document] = {}

        for docs, w in zip(rankings, weights):
            for rank, doc in enumerate(docs, start=1):
                key = self._doc_key(doc)
                docs_by_key[key] = doc
                doc_scores[key] = doc_scores.get(key, 0.0) + w * (1.0 / (self.rrf_k + rank))

        # sorted() is stable, so ties keep first-seen order
        top = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)[: self.k]
        return [docs_by_key[key] for key, _ in top]

    def _get_relevant_documents(
        self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None
    ) -> List[Document]:
        """Synchronously query every retriever and return the fused top-``k`` documents."""
        # Validate before any retriever is called, as a mismatch is a configuration error
        weights = self._resolve_weights()
        rankings = [retriever.invoke(query) for retriever in self.retrievers]
        return self._fuse(rankings, weights)

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: Optional[CallbackManagerForRetrieverRun] = None
    ) -> List[Document]:
        """Async variant of ``_get_relevant_documents``; retrievers are awaited one after another."""
        weights = self._resolve_weights()
        rankings = [await retriever.ainvoke(query) for retriever in self.retrievers]
        return self._fuse(rankings, weights)


## 3. Configuration


In [None]:
# --- Models ---
BASE_MODEL_NAME = "unsloth/gemma-3-4b-it"
# LoRA adapter directory written by the fine-tuning notebook
FINETUNED_ADAPTER_PATH = "./rag-instruct-gemma-3-finetuned"

# --- Dataset ---
DATASET_NAME = "FreedomIntelligence/RAG-Instruct"
SUBSET_SIZE = 200  # None -> use the full dataset

# --- Retrieval ---
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
TOP_K = 3  # documents handed to the LLM per question

# --- Generation ---
MAX_NEW_TOKENS = 64
TEMPERATURE = 0.7
TOP_P = 0.9

# Alternative presets; uncomment one group to override the values above.
#
# Quick test:
#   BASE_MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
#   MAX_NEW_TOKENS = 64
#   TOP_K = 3
#   SUBSET_SIZE = 200  # optional for speed
#
# Actual run:
#   BASE_MODEL_NAME = "unsloth/gemma-3-4b-it"
#   FINETUNED_ADAPTER_PATH = "./rag-instruct-gemma-3-finetuned"
#   MAX_NEW_TOKENS = 128
#   TOP_K = 3
#   SUBSET_SIZE = 1000

# Echo the effective settings, one per line
print(
    f"Base Model: {BASE_MODEL_NAME}",
    f"Embedding Model: {EMBEDDING_MODEL}",
    f"Device: {device}",
    f"Top-K Retrieval: {TOP_K}",
    sep="\n",
)


## 4. Load Dataset


In [None]:
# Load the RAG-Instruct dataset (train split)
print("Loading RAG-Instruct dataset...")
dataset = load_dataset(DATASET_NAME, split="train")

# Optionally keep only the first SUBSET_SIZE rows to speed up experiments
if SUBSET_SIZE:
    dataset = dataset.select(range(min(SUBSET_SIZE, len(dataset))))
    print(f"Working with subset of {len(dataset)} examples")

# Hold out 10% for evaluation; seeded with the notebook-wide SEED so the split
# follows the same seed as every other random step
split_dataset = dataset.train_test_split(test_size=0.1, seed=SEED, shuffle=True)
train_dataset = split_dataset['train']
test_dataset = split_dataset['test']

print(f"Total examples: {len(dataset)}")
print(f"Train examples: {len(train_dataset)}")
print(f"Test examples: {len(test_dataset)}")

# Read the first row once instead of decoding it for every print
first_example = dataset[0]
print("\nExample structure:")
print(f"Question: {first_example['question'][:150]}...")
print(f"Answer: {first_example['answer'][:150]}...")
print(f"Documents: {len(first_example['documents'])} documents")


## 5. Prepare Documents for LangChain

Convert documents to LangChain `Document` format with metadata.


In [None]:
# Extract unique documents from all examples (with provenance + deterministic IDs)
print("Extracting unique documents (dedup + provenance + deterministic IDs)...")

from collections import defaultdict

# Map stripped document text -> indices of the dataset examples that contain it
doc_to_examples = defaultdict(list)

for ex_idx, example in enumerate(tqdm(dataset)):
    # Track texts already recorded for this example so a document listed twice
    # in the same example is counted once ("appears in N examples").
    seen_in_example = set()
    # `or []` also covers a row whose "documents" field is present but null
    for doc in example.get("documents") or []:
        if not doc:
            continue
        text = doc.strip()
        if not text or text in seen_in_example:
            continue
        seen_in_example.add(text)
        # Keep provenance: which dataset examples contained this doc
        doc_to_examples[text].append(ex_idx)

# Deterministic ordering (stable IDs across runs)
unique_docs_sorted = sorted(doc_to_examples.keys())

# Convert to LangChain Document objects
langchain_documents = [
    Document(
        page_content=text,
        metadata={
            "id": i,
            "example_indices": doc_to_examples[text],   # provenance
            "n_examples": len(doc_to_examples[text]),   # number of distinct examples
            "source": DATASET_NAME                      # dataset identifier
        },
    )
    for i, text in enumerate(unique_docs_sorted)
]

# Report some statistics (lengths are whitespace-separated word counts)
lengths = [len(d.page_content.split()) for d in langchain_documents]
print(f"\nTotal unique documents: {len(langchain_documents)}")
print(f"Average document length: {np.mean(lengths):.1f} words")
print(f"Median document length: {np.median(lengths):.1f} words")
print(f"\nExample document (id={langchain_documents[0].metadata['id']}):")
print(langchain_documents[0].page_content[:300] + "...")
print(f"Provenance: appears in {langchain_documents[0].metadata['n_examples']} examples")


In [None]:
# Chunk documents with LangChain
print("Chunking documents with LangChain RecursiveCharacterTextSplitter...")

from langchain_text_splitters import RecursiveCharacterTextSplitter

# Splitter settings (character-based):
#   chunk_size    - target characters per chunk
#   chunk_overlap - characters shared between neighbouring chunks for context
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", ". ", " ", ""],
    chunk_size=1000,
    chunk_overlap=150,
)

# Split every parent document; each chunk inherits the parent's metadata and
# additionally records which parent it came from and its position within it.
chunked_documents = []
for parent_doc in tqdm(langchain_documents):
    parent_meta = parent_doc.metadata
    pieces = text_splitter.split_text(parent_doc.page_content)
    chunked_documents.extend(
        Document(
            page_content=piece,
            metadata={
                **parent_meta,
                "parent_id": parent_meta["id"],
                "chunk_id": position,
            },
        )
        for position, piece in enumerate(pieces)
    )

# Summary statistics (lengths are whitespace-separated word counts)
chunk_lengths = [len(d.page_content.split()) for d in chunked_documents]
first_chunk = chunked_documents[0]
print(f"\nParent docs: {len(langchain_documents)}")
print(f"Chunks: {len(chunked_documents)}")
print(f"Avg chunk length: {np.mean(chunk_lengths):.1f} words")
print(f"Median chunk length: {np.median(chunk_lengths):.1f} words")
print("\nExample chunk metadata:", first_chunk.metadata)
print("Example chunk preview:\n", first_chunk.page_content[:300] + "...")


## 6. Create Embeddings Model

Initialize HuggingFace embeddings for dense retrieval.


In [None]:
# Initialize HuggingFace embeddings
print(f"Loading embedding model: {EMBEDDING_MODEL}")

# Sentence-embedding model used for dense retrieval; vectors are normalized
# (the original note: "For cosine similarity").
embeddings = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL,
    encode_kwargs={'normalize_embeddings': True},
    model_kwargs={'device': device},
)

# Embed a throwaway query to confirm the model works and read off its dimension
embedding_dim = len(embeddings.embed_query('test'))

print(f"✓ Embedding model loaded")
print(f"  Embedding dimension: {embedding_dim}")


## 7. Create FAISS Vector Store (Dense Retrieval)

Build vector store for semantic search.


In [None]:
# Create FAISS vector store from CHUNKED documents
print("Creating FAISS vector store from chunked documents...")
print("This may take a few minutes...")

# Embed every chunk and index it in FAISS
vectorstore = FAISS.from_documents(chunked_documents, embeddings)
print(f"✓ Vector store created with {vectorstore.index.ntotal} chunks")

# Dense retriever: plain similarity search returning TOP_K chunks
dense_retriever = vectorstore.as_retriever(
    search_kwargs={"k": TOP_K},
    search_type="similarity",
)
print(f"✓ Dense retriever ready (top-{TOP_K})")

# Quick sanity check: run one query and show the best hit
sample_docs = dense_retriever.invoke("sanity check query")
top_hit = sample_docs[0]
print("Sample retrieved metadata:", top_hit.metadata)
print("Sample retrieved preview:", top_hit.page_content[:200], "...")


## 8. Create BM25 Retriever (Sparse Retrieval)

Keyword-based retrieval using BM25 algorithm.


In [None]:
# Create BM25 retriever for keyword-based search
print("Creating BM25 retriever...")

# Sparse (keyword) retriever built over the same chunks as the dense index
bm25_retriever = BM25Retriever.from_documents(chunked_documents)
# Return the same number of results as the dense retriever
bm25_retriever.k = TOP_K

print(f"✓ BM25 retriever ready (top-{TOP_K})")


## 9. Create Hybrid Retriever (Ensemble)

Combine the dense and sparse retrievers with weighted Reciprocal Rank Fusion so that both semantic and keyword matches contribute to the final ranking.


In [None]:
# Make each retriever return more candidates for better fusion
# Fusion benefits from a larger candidate pool than the final TOP_K.
# The pool is requested from DEDICATED retrievers: changing `k` on the shared
# `dense_retriever` / `bm25_retriever` would also make the standalone dense and
# sparse RAG chains feed 2*TOP_K documents to the LLM, skewing the comparison
# against the hybrid chain (which returns TOP_K).
HYBRID_CANDIDATES_PER_RETRIEVER = TOP_K * 2  # raise to TOP_K * 4 for real runs

print("Creating hybrid retriever...")

# Same FAISS index, deeper search
hybrid_dense_retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": HYBRID_CANDIDATES_PER_RETRIEVER},
)

# Separate BM25 instance over the same chunks, deeper search
hybrid_bm25_retriever = BM25Retriever.from_documents(chunked_documents)
hybrid_bm25_retriever.k = HYBRID_CANDIDATES_PER_RETRIEVER

hybrid_retriever = CustomEnsembleRetriever(
    retrievers=[hybrid_dense_retriever, hybrid_bm25_retriever],
    weights=[0.5, 0.5],
    k=TOP_K,
    rrf_k=60
)

print("✓ Hybrid retriever ready (ensemble of dense + BM25)")


## 10. Test Retrievers

Test all three retrieval methods with a sample query.


In [None]:
# Test all three retrievers
# Run the same held-out question through each retriever and preview the hits
test_query = test_dataset[0]['question']
print(f"Test Query: {test_query}\n")

banner = "=" * 80

# Dense
print(banner, "DENSE RETRIEVAL (Semantic Search)", banner, sep="\n")
dense_docs = dense_retriever.invoke(test_query)
for position, hit in enumerate(dense_docs, start=1):
    print(f"\nDoc {position}:", hit.page_content[:200] + "...", sep="\n")

# Sparse
print("\n" + banner, "SPARSE RETRIEVAL (BM25 Keyword Search)", banner, sep="\n")
sparse_docs = bm25_retriever.invoke(test_query)
for position, hit in enumerate(sparse_docs, start=1):
    print(f"\nDoc {position}:", hit.page_content[:200] + "...", sep="\n")

# Hybrid (only the first TOP_K fused hits are shown)
print("\n" + banner, "HYBRID RETRIEVAL (Ensemble)", banner, sep="\n")
hybrid_docs = hybrid_retriever.invoke(test_query)
for position, hit in enumerate(hybrid_docs[:TOP_K], start=1):
    print(f"\nDoc {position}:", hit.page_content[:200] + "...", sep="\n")


In [None]:
# Sanity check: does it actually return TOP_K docs?
# How many documents does the hybrid retriever actually return?
docs = hybrid_retriever.invoke("test query")
print("Returned:", len(docs))

# Are the returned chunks distinct? Identify each by (parent_id, chunk_id).
keys = [(hit.metadata.get("parent_id"), hit.metadata.get("chunk_id")) for hit in docs]
distinct_keys = set(keys)

print("Unique chunk keys:", len(distinct_keys), "out of", len(keys))
print("Example metadata:", docs[0].metadata)


## 11. Load Base Model

Load the base LLM for generation.


In [None]:
# Load base model
print(f"Loading base model: {BASE_MODEL_NAME}")
print("This may take a few minutes...")

base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

# Reuse EOS as the padding token when the tokenizer defines none
if base_tokenizer.pad_token is None:
    base_tokenizer.pad_token = base_tokenizer.eos_token
    base_tokenizer.pad_token_id = base_tokenizer.eos_token_id

# Gemma: left padding is typically expected for batched generation
if "gemma" in BASE_MODEL_NAME.lower():
    base_tokenizer.padding_side = "left"

# Load the weights with a precision/placement that suits the device
if device == "cuda":
    # bf16 when the GPU supports it, fp16 otherwise; accelerate places the layers
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_NAME,
        torch_dtype=dtype,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
else:
    # MPS and CPU both load in float32; only MPS needs the explicit move
    on_mps = device == "mps"
    if on_mps:
        print("  Using float32 on MPS for numerical stability")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_NAME,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    if on_mps:
        base_model = base_model.to("mps")

# Keep the model config in sync with the tokenizer's special tokens
base_model.config.pad_token_id = base_tokenizer.pad_token_id
base_model.config.eos_token_id = base_tokenizer.eos_token_id
base_model.eval()

print(f"✓ Base model loaded: {BASE_MODEL_NAME}")


## 12. Create HuggingFace Pipeline for LangChain

Wrap the model in a LangChain-compatible pipeline.


In [None]:
print("Creating text generation pipeline...")

# Decoding strategy depends on the device: greedy on MPS, nucleus sampling elsewhere
if device == "mps":
    print("  Using greedy decoding for MPS stability (no sampling)")
    decoding_options = {"do_sample": False}
else:
    decoding_options = {
        "do_sample": True,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "top_k": 50,
    }

# Shared generation settings followed by the device-specific decoding options
generation_config = {
    "max_new_tokens": MAX_NEW_TOKENS,
    "min_new_tokens": 1,
    "pad_token_id": base_tokenizer.pad_token_id,
    "eos_token_id": base_tokenizer.eos_token_id,
    "return_full_text": False,  # return only the completion, not the prompt
    "num_return_sequences": 1,
    "use_cache": True,
    "repetition_penalty": 1.1,
    **decoding_options,
}

text_generation_pipeline = pipeline(
    "text-generation",
    model=base_model,
    tokenizer=base_tokenizer,
    **generation_config
)

# Expose the HF pipeline through LangChain's LLM interface
base_llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
print("✓ Base LLM ready for LangChain!")


## 13. Create RAG Prompt Template

Define how questions and retrieved documents are formatted.


In [None]:
# Define RAG prompt template
# Prompt text for the RAG chains. It has two placeholders:
#   {context} - filled with the retrieved documents
#   {input}   - the user's question (invoke_rag_chain passes it under the "input" key)
# The instructions restrict the model to the supplied context and give it an
# explicit refusal sentence for unanswerable questions.
rag_prompt_template = """You are a helpful assistant.
Use ONLY the provided context to answer the question.
If the answer cannot be found in the context, say: "I cannot find the answer in the provided documents."
If the context contains conflicting information, say so and use the most relevant part.

Context:
{context}

Question: {input}

Answer (be concise and, when possible, quote short phrases from the context):"""

# Wrap the text in a PromptTemplate, declaring both placeholders explicitly
RAG_PROMPT = PromptTemplate(
    template=rag_prompt_template,
    input_variables=["context", "input"]
)

print("✓ RAG prompt template created (compatible with create_retrieval_chain)")


## 13b. Helper Function for Modern RAG Chains

Create a helper to build RAG chains using modern LangChain (LCEL).


In [None]:
def create_rag_chain(llm, retriever, prompt_template):
    """Build a retrieval chain: `retriever` fetches documents, `llm` answers from them.

    The documents are combined with `prompt_template` by a "stuff documents"
    chain, which is then wrapped together with the retriever.
    """
    return create_retrieval_chain(
        retriever,
        create_stuff_documents_chain(llm, prompt_template),
    )

def invoke_rag_chain(chain, query):
    """Run `chain` on `query` and normalise its output.

    Returns a dict with:
      - "result": the first non-empty value among the chain's "answer",
        "output_text" and "result" keys ("" when none is set)
      - "source_documents": the chain's "context" entry ([] when absent)
      - "raw": the untouched chain output, for debugging
    """
    raw_output = chain.invoke({"input": query})

    # Different chain flavours name the answer differently; take the first hit
    answer_text = ""
    for answer_key in ("answer", "output_text", "result"):
        candidate = raw_output.get(answer_key)
        if candidate:
            answer_text = candidate
            break

    return {
        "result": answer_text,
        "source_documents": raw_output.get("context", []),  # usually List[Document]
        "raw": raw_output,
    }


## 14. Create RAG Chains

Create complete RAG pipelines for each retrieval method.


In [None]:
# Create RAG chains for each retrieval method using modern LCEL
print("Creating RAG chains...")

# Keep a handle on the retriever behind each chain for later use
dense_rag_retriever = dense_retriever
sparse_rag_retriever = bm25_retriever
hybrid_rag_retriever = hybrid_retriever

# One chain per retrieval strategy, all sharing the base LLM and the RAG prompt
dense_rag_chain = create_rag_chain(base_llm, dense_rag_retriever, RAG_PROMPT)
sparse_rag_chain = create_rag_chain(base_llm, sparse_rag_retriever, RAG_PROMPT)
hybrid_rag_chain = create_rag_chain(base_llm, hybrid_rag_retriever, RAG_PROMPT)

print(
    "✓ Dense RAG Chain created (LCEL)",
    "✓ Sparse RAG Chain created (LCEL)",
    "✓ Hybrid RAG Chain created (LCEL)",
    sep="\n",
)


## 14b. Load Fine-Tuned Model

Load the fine-tuned adapter to compare with base model performance.


In [None]:
!pip -q uninstall -y bitsandbytes
!pip -q install -U bitsandbytes accelerate transformers peft

In [None]:
import bitsandbytes as bnb
import torch

# Environment diagnostics for the optional quantized-loading path
print("bitsandbytes:", bnb.__version__)

# torch.cuda.get_device_name(0) raises when no CUDA device is present, yet this
# notebook also targets MPS and CPU, so only query the GPU when CUDA is available.
if torch.cuda.is_available():
    print("cuda:", torch.version.cuda, "gpu:", torch.cuda.get_device_name(0))
else:
    print("cuda: not available")


In [None]:
import os

# Load the fine-tuned PEFT adapter on top of a fresh copy of its base model and
# wrap it as a LangChain LLM. On success `finetuned_llm` holds the LLM and
# `use_finetuned` is True; if the adapter directory is missing or anything in
# the loading sequence raises, both keep their "base model only" defaults.
finetuned_llm = None
use_finetuned = False

if os.path.exists(FINETUNED_ADAPTER_PATH):
    print(f"Loading fine-tuned adapter from: {FINETUNED_ADAPTER_PATH}")

    try:
        # NOTE(review): hard-coded here rather than taken from BASE_MODEL_NAME —
        # presumably this must be the model the adapter was trained on; confirm.
        ADAPTER_BASE_MODEL = "unsloth/gemma-3-4b-it"
        print(f"Loading base model for adapter: {ADAPTER_BASE_MODEL}")

        finetuned_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_BASE_MODEL)

        # Reuse EOS as the padding token when the tokenizer defines none
        if finetuned_tokenizer.pad_token is None:
            finetuned_tokenizer.pad_token = finetuned_tokenizer.eos_token
            finetuned_tokenizer.pad_token_id = finetuned_tokenizer.eos_token_id

        # Same left-padding rule as applied to the base tokenizer
        if "gemma" in ADAPTER_BASE_MODEL.lower():
            finetuned_tokenizer.padding_side = "left"

        # Load base model for adapter (precision/placement chosen per device)
        if device == "mps":
            print("  Using float32 on MPS for numerical stability")
            base_for_adapter = AutoModelForCausalLM.from_pretrained(
                ADAPTER_BASE_MODEL,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            ).to("mps")

        elif device == "cuda":
            # bf16 when the GPU supports it, fp16 otherwise
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            base_for_adapter = AutoModelForCausalLM.from_pretrained(
                ADAPTER_BASE_MODEL,
                torch_dtype=dtype,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
            # Disabled alternative: 4-bit (NF4) quantized loading via bitsandbytes.
            # from transformers import BitsAndBytesConfig

            # dtype = torch.float16

            # bnb_config = BitsAndBytesConfig(
            #   load_in_4bit=True,
            #   bnb_4bit_quant_type="nf4",
            #   bnb_4bit_use_double_quant=True,
            #   bnb_4bit_compute_dtype=dtype,
            # )
            # base_for_adapter = AutoModelForCausalLM.from_pretrained(
            #     ADAPTER_BASE_MODEL,
            #     quantization_config=bnb_config,
            #     torch_dtype=dtype,
            #     device_map="auto",
            #     low_cpu_mem_usage=True,
            # )
        else:
            # CPU fallback
            base_for_adapter = AutoModelForCausalLM.from_pretrained(
                ADAPTER_BASE_MODEL,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            )

        print("Loading PEFT adapter...")
        finetuned_model = PeftModel.from_pretrained(base_for_adapter, FINETUNED_ADAPTER_PATH)

        # Align config and eval mode
        finetuned_model.config.pad_token_id = finetuned_tokenizer.pad_token_id
        finetuned_model.config.eos_token_id = finetuned_tokenizer.eos_token_id
        finetuned_model.eval()

        # Generation settings shared by all devices
        finetuned_gen_config = {
            "max_new_tokens": MAX_NEW_TOKENS,
            "min_new_tokens": 1,
            "pad_token_id": finetuned_tokenizer.pad_token_id,
            "eos_token_id": finetuned_tokenizer.eos_token_id,
            "return_full_text": False,
            "num_return_sequences": 1,
            "use_cache": True,
            "repetition_penalty": 1.1,
        }

        if device == "mps":
            finetuned_gen_config.update({"do_sample": False})
        else:
            # update() overwrites the repetition_penalty of 1.1 set above with 1.05.
            # NOTE(review): sampling is switched off here while the base pipeline
            # samples on non-MPS devices, and temperature/top_p/top_k are still
            # passed — confirm this asymmetry is intended for the comparison.
            finetuned_gen_config.update({
                "do_sample": False,  # greedy decoding (was changed to False)
                "temperature": 0.7,
                "top_p": TOP_P,
                "top_k": 50,
                "renormalize_logits": True,
                "repetition_penalty": 1.05,
            })

        finetuned_pipeline = pipeline(
            "text-generation",
            model=finetuned_model,
            tokenizer=finetuned_tokenizer,
            **finetuned_gen_config
        )

        # Expose the HF pipeline through LangChain's LLM interface
        finetuned_llm = HuggingFacePipeline(pipeline=finetuned_pipeline)
        use_finetuned = True
        print("✓ Fine-tuned model loaded successfully!")

    except Exception as e:
        # Best-effort: any failure above downgrades the notebook to base-model-only
        print(f"⚠ Could not load fine-tuned model: {e}")
        print("  Continuing with base model only...")
        finetuned_llm = None
        use_finetuned = False

else:
    print(f"⚠ Fine-tuned adapter not found at: {FINETUNED_ADAPTER_PATH}")
    print("  Continuing with base model only...")

print(f"\nUsing fine-tuned model: {use_finetuned}")


## 14c. Create Fine-Tuned RAG Chains

Create RAG chains using the fine-tuned model for comparison.


In [None]:
# Create RAG chains with fine-tuned model (if available)
if not use_finetuned:
    print("Skipping fine-tuned RAG chains (model not available)")
else:
    print("Creating RAG chains with fine-tuned model...")

    # Keep a handle on the retriever behind each fine-tuned chain
    finetuned_dense_rag_retriever = dense_retriever
    finetuned_sparse_rag_retriever = bm25_retriever
    finetuned_hybrid_rag_retriever = hybrid_retriever

    # Same retrievers and prompt as the base chains; only the LLM differs
    finetuned_dense_rag_chain = create_rag_chain(
        finetuned_llm, finetuned_dense_rag_retriever, RAG_PROMPT
    )
    finetuned_sparse_rag_chain = create_rag_chain(
        finetuned_llm, finetuned_sparse_rag_retriever, RAG_PROMPT
    )
    finetuned_hybrid_rag_chain = create_rag_chain(
        finetuned_llm, finetuned_hybrid_rag_retriever, RAG_PROMPT
    )

    print(
        "✓ Fine-tuned Dense RAG Chain created (LCEL)",
        "✓ Fine-tuned Sparse RAG Chain created (LCEL)",
        "✓ Fine-tuned Hybrid RAG Chain created (LCEL)",
        sep="\n",
    )


In [None]:
# Smoke test: run the fine-tuned hybrid chain once and show which keys it returns
if use_finetuned:
    sanity_payload = {"input": "sanity check question"}
    out = finetuned_hybrid_rag_chain.invoke(sanity_payload)
    print(out.keys())

In [None]:
import time

# Wall-clock time of a single hybrid retrieval
t0 = time.time()
docs = hybrid_retriever.invoke("sanity check question")
retriever_seconds = time.time() - t0

print("Retriever time:", retriever_seconds, "seconds")
print("Docs:", len(docs))

In [None]:
import time
# Wall-clock time of a single fine-tuned LLM call.
# `finetuned_llm` stays None when the adapter is missing or failed to load, so
# skip the timing instead of failing on `None.invoke(...)`.
if finetuned_llm is None:
    print("Skipping fine-tuned LLM timing (model not available)")
else:
    t0 = time.time()
    resp = finetuned_llm.invoke("Hello, answer briefly.")
    print("LLM time:", time.time() - t0, "seconds")
    print(resp)

## 15. Test RAG System

Test the complete RAG pipeline with a sample question.


In [None]:
# Test the RAG system
# End-to-end check of the base hybrid chain on the first held-out question
test_question = test_dataset[0]["question"]
print(f"Test Question: {test_question}\n")
print(f"Ground Truth: {test_dataset[0]['answer'][:200]}...\n")

banner = "=" * 80
print(banner, "TESTING HYBRID RAG CHAIN", banner, sep="\n")

# invoke_rag_chain returns the answer together with the documents it used
result = invoke_rag_chain(hybrid_rag_chain, test_question)

print("\nRetrieved Documents:")
for position, source_doc in enumerate(result["source_documents"], start=1):
    print(f"\nDoc {position}:")
    print("Metadata:", source_doc.metadata)
    print(source_doc.page_content[:200] + "...")

print("\n" + banner, "GENERATED ANSWER:", banner, result["result"], sep="\n")


In [None]:
# Same end-to-end check, run through the fine-tuned hybrid chain when available
if use_finetuned:
    banner = "=" * 80

    test_question = test_dataset[0]["question"]
    print(f"Test Question: {test_question}\n")
    print(f"Ground Truth: {test_dataset[0]['answer'][:200]}...\n")

    print(banner, "TESTING HYBRID RAG CHAIN", banner, sep="\n")

    # invoke_rag_chain returns the answer together with the documents it used
    result = invoke_rag_chain(finetuned_hybrid_rag_chain, test_question)

    print("\nRetrieved Documents:")
    for position, source_doc in enumerate(result["source_documents"], start=1):
        print(f"\nDoc {position}:")
        print("Metadata:", source_doc.metadata)
        print(source_doc.page_content[:200] + "...")

    print("\n" + banner, "GENERATED ANSWER:", banner, result["result"], sep="\n")


## 16. Evaluate RAG System

Systematic evaluation across all configurations.


In [None]:
NUM_EVAL_EXAMPLES = 5
print(f"Evaluating RAG system on {NUM_EVAL_EXAMPLES} examples...\n")

# Base-model configurations are always evaluated
configs = [
    {"name": "Base - Dense Retrieval", "chain": dense_rag_chain},
    {"name": "Base - Sparse Retrieval (BM25)", "chain": sparse_rag_chain},
    {"name": "Base - Hybrid Retrieval", "chain": hybrid_rag_chain},
]

# Fine-tuned configurations are added only when the adapter was loaded
if not use_finetuned:
    print(f"Evaluating {len(configs)} configurations (base model only)")
else:
    configs += [
        {"name": "Fine-tuned - Dense Retrieval", "chain": finetuned_dense_rag_chain},
        {"name": "Fine-tuned - Sparse Retrieval (BM25)", "chain": finetuned_sparse_rag_chain},
        {"name": "Fine-tuned - Hybrid Retrieval", "chain": finetuned_hybrid_rag_chain},
    ]
    print(f"Evaluating {len(configs)} configurations (including fine-tuned model)")

# Never ask for more examples than the test split holds
n_eval = min(NUM_EVAL_EXAMPLES, len(test_dataset))

results_by_config = {}

for config in configs:
    print(f"\n{'='*80}")
    print(f"Evaluating: {config['name']}")
    print(f"{'='*80}")

    config_results = []

    for i in tqdm(range(n_eval)):
        example = test_dataset[i]

        # Best-effort: a failing example is reported and skipped, not fatal
        try:
            result = invoke_rag_chain(config["chain"], example["question"])
            retrieved_texts = [doc.page_content for doc in result["source_documents"]]
            record = {
                "question": example["question"],
                "answer": result["result"],
                "ground_truth": example["answer"],
                "retrieved_docs": retrieved_texts,
                "true_documents": example["documents"],
            }
            config_results.append(record)
        except Exception as e:
            print(f"\nError on example {i}: {e}")
            continue

    results_by_config[config["name"]] = config_results
    print(f"Completed {config['name']}: {len(config_results)} examples")

print("\n✓ Evaluation complete!")


## 17. Calculate Metrics

Calculate ROUGE and BERTScore for evaluation.


In [None]:
from rouge_score import rouge_scorer
from bert_score import score as bert_score
import nltk
import torch

# Fetch the NLTK "punkt" resource without console output
nltk.download("punkt", quiet=True)

# One shared ROUGE scorer (with stemming) reused by calculate_metrics
scorer = rouge_scorer.RougeScorer(
    rouge_types=["rouge1", "rouge2", "rougeL"],
    use_stemmer=True,
)

def calculate_metrics(predictions: List[str], references: List[str]) -> Dict[str, float]:
    """Return mean ROUGE-1/2/L F-measures and mean BERTScore F1.

    Args:
        predictions: generated answers.
        references: ground-truth answers, paired with `predictions` by position.

    Returns:
        Dict with keys "rouge1", "rouge2", "rougeL" and "bertscore", each the
        mean over all pairs.
    """
    rouge_keys = ("rouge1", "rouge2", "rougeL")
    metrics = {key: [] for key in rouge_keys}

    # ROUGE: scored pair by pair with the module-level scorer (reference first)
    for pred, ref in zip(predictions, references):
        pair_scores = scorer.score(ref, pred)
        for key in rouge_keys:
            metrics[key].append(pair_scores[key].fmeasure)

    # BERTScore is the expensive part of the evaluation.
    # TODO: consider a lighter-weight substitute.
    print("  Calculating BERTScore...")
    if torch.cuda.is_available():
        bs_device = "cuda"
    elif torch.backends.mps.is_available():
        bs_device = "mps"
    else:
        bs_device = "cpu"

    # Only the F1 component of (precision, recall, F1) is kept
    _, _, f1_scores = bert_score(
        predictions,
        references,
        lang="en",
        model_type="distilroberta-base",
        device=bs_device,
        verbose=False
    )
    metrics["bertscore"] = f1_scores.detach().cpu().numpy().tolist()

    return {name: float(np.mean(values)) for name, values in metrics.items()}

print("\nCalculating metrics...\n")
metrics_by_config = {}

for config_name, results in results_by_config.items():
    # Configurations without any successful example have nothing to score
    if results:
        print(f"Metrics for: {config_name}")

        predictions = [record["answer"] for record in results]
        references = [record["ground_truth"] for record in results]

        avg_metrics = calculate_metrics(predictions, references)
        metrics_by_config[config_name] = avg_metrics

        print(f"  ROUGE-1:   {avg_metrics['rouge1']:.4f}")
        print(f"  ROUGE-2:   {avg_metrics['rouge2']:.4f}")
        print(f"  ROUGE-L:   {avg_metrics['rougeL']:.4f}")
        print(f"  BERTScore: {avg_metrics['bertscore']:.4f}\n")


## 18. Results Comparison and Visualization

Compare and visualize the performance of different retrieval methods.


In [None]:
import matplotlib.pyplot as plt

metrics_to_plot = ["rouge1", "rouge2", "rougeL", "bertscore"]
metric_names = ["ROUGE-1", "ROUGE-2", "ROUGE-L", "BERTScore"]

# Build the comparison table from the per-configuration averages: one row per
# configuration, one column per metric. Passing `columns` keeps the frame
# well-formed (and sortable) even when no configuration produced metrics.
comparison_df = pd.DataFrame(
    [{"Configuration": name, **scores} for name, scores in metrics_by_config.items()],
    columns=["Configuration", *metrics_to_plot],
)

# Sort by BERTScore for clearer comparison
comparison_df = comparison_df.sort_values("bertscore", ascending=True)

print("\n" + "="*80)
print("RAG SYSTEM COMPARISON (LangChain)")
print("="*80 + "\n")
print(comparison_df.to_string(index=False))

fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle("RAG System Performance with LangChain", fontsize=16, fontweight="bold")

# Named `config_names` so the evaluation's `configs` list is not overwritten
config_names = comparison_df["Configuration"].values

# One horizontal bar chart per metric, laid out on the 2x2 grid
for idx, (metric, name) in enumerate(zip(metrics_to_plot, metric_names)):
    ax = axes[idx // 2, idx % 2]
    values = comparison_df[metric].values

    bars = ax.barh(config_names, values, alpha=0.7)
    ax.set_xlabel(name)
    ax.set_title(f"{name} Scores", fontweight="bold")
    ax.set_xlim(0, 1.0)
    ax.grid(axis="x", alpha=0.3)

    # Print the numeric value just right of each bar
    for bar in bars:
        width = bar.get_width()
        ax.text(width + 0.01, bar.get_y() + bar.get_height() / 2,
                f"{width:.3f}", va="center")

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.savefig("rag_langchain_comparison.png", dpi=300, bbox_inches="tight")
plt.show()


## 19. Save Results

Save evaluation results for later analysis.


In [None]:
import pickle

# Bundle the raw generations, the averaged metrics and the comparison table
results_bundle = {
    'results_by_config': results_by_config,
    'metrics_by_config': metrics_by_config,
    'comparison_df': comparison_df,
}

# Full results as a pickle, metrics table as CSV
with open('rag_langchain_results.pkl', 'wb') as f:
    pickle.dump(results_bundle, f)

comparison_df.to_csv('rag_langchain_comparison.csv', index=False)

print(
    "✓ Results saved:",
    "  - rag_langchain_results.pkl (full results)",
    "  - rag_langchain_comparison.csv (metrics table)",
    "  - rag_langchain_comparison.png (visualization)",
    sep="\n",
)


## 20. Interactive Demo

Try the RAG system with your own questions!


In [None]:
def ask_question(question: str, retrieval_type="hybrid", use_finetuned_model=False):
    """Answer `question` with one of the RAG chains and print the answer and its sources.

    Args:
        question: the user's question.
        retrieval_type: "dense", "sparse" or "hybrid". Any other value falls
            back to "hybrid" (with a printed warning).
        use_finetuned_model: use the fine-tuned chains when they were created;
            otherwise the base-model chains are used.

    Returns:
        The dict produced by `invoke_rag_chain` ("result", "source_documents", "raw").
    """
    if use_finetuned_model and not use_finetuned:
        print("⚠ Fine-tuned model not available; using base model")

    if use_finetuned_model and use_finetuned:
        chain_map = {
            "dense": finetuned_dense_rag_chain,
            "sparse": finetuned_sparse_rag_chain,
            "hybrid": finetuned_hybrid_rag_chain
        }
        model_name = "Fine-tuned Model"
    else:
        chain_map = {
            "dense": dense_rag_chain,
            "sparse": sparse_rag_chain,
            "hybrid": hybrid_rag_chain
        }
        model_name = "Base Model"

    # Fall back to the hybrid chain of the SELECTED model; the base hybrid chain
    # must not be used silently while the output is labelled "Fine-tuned Model".
    if retrieval_type not in chain_map:
        print(f"⚠ Unknown retrieval_type {retrieval_type!r}; falling back to 'hybrid'")
        retrieval_type = "hybrid"
    chain = chain_map[retrieval_type]

    print(f"\n{'='*80}")
    print(f"Question: {question}")
    print(f"Model: {model_name}")
    print(f"Retrieval: {retrieval_type.upper()}")
    print(f"{'='*80}\n")

    # Run retrieval + generation through the shared helper
    result = invoke_rag_chain(chain, question)

    print("Retrieved Documents:")
    for i, doc in enumerate(result["source_documents"], 1):
        print(f"\n  {i}. {doc.page_content[:150]}...")

    print(f"\n{'='*80}")
    print("ANSWER:")
    print(f"{'='*80}")
    print(result["result"])

    return result


# Example: ask one question through the base-model hybrid chain
ask_question(question="Who is Lebron James?", retrieval_type="hybrid")


### References:

- [LangChain Documentation](https://python.langchain.com/)
- [FAISS](https://github.com/facebookresearch/faiss)
- [Sentence-Transformers](https://www.sbert.net/)
- [BM25 Algorithm](https://en.wikipedia.org/wiki/Okapi_BM25)
