## Enhanced RAG with Reranking

This notebook demonstrates how to apply reranking methods to improve retrieval accuracy in RAG systems.

**Pipeline Overview:**
1. Document Loader 
2. Text Splitter
3. Embedding Model
4. Vector Store
5. **Two-Stage Retrieval**:
   - Initial Retrieval (Dense/Sparse)
   - **Reranking with Cross-Encoder**
6. Prompt Template
7. LLM 
8. Enhanced Chain with Reranking
9. Evaluation with Reranking Metrics

In [17]:
import os 
from dotenv import load_dotenv
import time
import numpy as np
from typing import List, Tuple

# Load environment variables
load_dotenv("/Users/reejungkim/Documents/Git/working-in-progress/.env")

True

# Install required packages if not already installed
!pip install sentence-transformers faiss-cpu langchain-community langchain-groq -qㅡ

### 1. Document Loader

In [3]:
from langchain_community.document_loaders import PyPDFLoader

# Load your PDF file
file_path = "/Users/reejungkim/Documents/Git/AI-Agent/Amazon-2024-Annual-Report.pdf"
loader = PyPDFLoader(file_path)

docs = loader.load()
print(f"Loaded {len(docs)} pages from PDF")

# Show sample content
if docs:
    print(f"\nSample content from page 1:")
    print(docs[0].page_content[:100] + "...")

Loaded 91 pages from PDF

Sample content from page 1:
ANNUAL REPORT
2 0 2 4...


### 2. Text Splitter

In [4]:
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Text splitter configuration
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, 
    chunk_overlap=200
)

texts = text_splitter.split_documents(docs)
print(f"Split into {len(texts)} chunks")

# Show sample chunk
if texts:
    print(f"\nSample chunk:")
    print(texts[0].page_content[:300] + "...")

Split into 425 chunks

Sample chunk:
ANNUAL REPORT
2 0 2 4...


### 3. Embedding Model

Using the user's preferred embedding model: `all-MiniLM-L6-v2` for speed.

In [5]:
from langchain_huggingface import HuggingFaceEmbeddings

# Initialize embedding model (user's preferred model for speed)
embedding_model = HuggingFaceEmbeddings(
    model_name='sentence-transformers/all-MiniLM-L6-v2',  # User's preference
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True}
)

print("Embedding model loaded successfully")

Embedding model loaded successfully


### 4. Vector Store

In [6]:
from langchain_community.vectorstores import FAISS

# Create vector store
print("Creating vector store...")
start_time = time.time()

vectorstore = FAISS.from_documents(texts, embedding_model)

creation_time = time.time() - start_time
print(f"Vector store created in {creation_time:.2f} seconds")
print(f"Indexed {vectorstore.index.ntotal} vectors")

Creating vector store...
Vector store created in 31.16 seconds
Indexed 425 vectors


### 5. Reranking Setup

This is the key enhancement - adding a reranking layer using cross-encoder models.

In [7]:
# Import the reranking module we created
from rerank_module import (
    DocumentReranker, 
    RerankConfig, 
    RetrievalWithReranking,
    RERANKER_MODELS
)

# Display available reranker models
print("Available Reranker Models:")
for name, info in RERANKER_MODELS.items():
    print(f"- {name}: {info['description']} (Speed: {info['speed']}, Accuracy: {info['accuracy']})")

Available Reranker Models:
- ms-marco-miniLM-L6-v2: Fast and balanced model for general reranking (Speed: fast, Accuracy: good)
- ms-marco-miniLM-L12-v2: Better accuracy with moderate speed (Speed: medium, Accuracy: very good)
- ms-marco-TinyBERT-L6: Fastest model with decent accuracy (Speed: very fast, Accuracy: good)
- qnli-electra-base: Specialized for question-answering tasks (Speed: medium, Accuracy: very good)


In [8]:
# Initialize reranker with a fast model (user preference)
rerank_config = RerankConfig(
    model_name="cross-encoder/ms-marco-MiniLM-L6-v2",  # Fast and balanced
    threshold=0.0,  # Include all documents above this score
    max_length=512,
    batch_size=32
)

print("Loading reranker model...")
start_time = time.time()

reranker = DocumentReranker(rerank_config)

if reranker.is_available():
    load_time = time.time() - start_time
    print(f"Reranker loaded successfully in {load_time:.2f} seconds")
else:
    print("Reranker not available. Please install sentence-transformers.")

Loading reranker model...


config.json:   0%|          | 0.00/794 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

Reranker loaded successfully in 23.08 seconds


### 6. Compare Regular Retrieval vs Reranked Retrieval

In [9]:
# Test query
test_query = "How much cash was used in operating activities in 2024?"

print(f"Test Query: {test_query}\n")

# Regular retrieval
print("=== REGULAR RETRIEVAL ===")
start_time = time.time()
regular_docs = vectorstore.similarity_search(test_query, k=5)
regular_time = time.time() - start_time

print(f"Retrieved {len(regular_docs)} documents in {regular_time:.3f}s")
for i, doc in enumerate(regular_docs[:3]):
    print(f"\n{i+1}. {doc.page_content[:200]}...")

print("\n" + "="*50 + "\n")

# Reranked retrieval
print("=== RERANKED RETRIEVAL ===")
if reranker.is_available():
    start_time = time.time()
    
    # Get more documents initially for reranking
    initial_docs = vectorstore.similarity_search(test_query, k=10)
    
    # Rerank and get top 5
    reranked_docs = reranker.rerank_documents_simple(
        test_query, 
        initial_docs, 
        top_k=5
    )
    
    rerank_time = time.time() - start_time
    
    print(f"Reranked {len(initial_docs)} documents, returned {len(reranked_docs)} in {rerank_time:.3f}s")
    for i, doc in enumerate(reranked_docs[:3]):
        score = doc.metadata.get('rerank_score', 'N/A')
        position = doc.metadata.get('rerank_position', 'N/A')
        orig_pos = doc.metadata.get('original_position', 'N/A')
        print(f"\n{i+1}. [Score: {score:.3f}, Rank: {position}, Orig: {orig_pos}]")
        print(f"   {doc.page_content[:200]}...")
else:
    print("Reranker not available")

Test Query: How much cash was used in operating activities in 2024?

=== REGULAR RETRIEVAL ===
Retrieved 5 documents in 0.076s

1. currency balances include British Pounds, Canadian Dollars, Euros, Indian Rupees, and Japanese Yen. 
Cash provided by (used in) operating activities was $84.9 billion and $115.9 billion in 2023 and 20...

2. Consolidated Statements of Cash Flows Reconciliation
The following table provides a reconciliation of the amount of cash, cash equivalents, and restricted cash reported within 
the consolidated balanc...

3. adequacy of our tax accruals. Although we believe our tax estimates are reasonable, the final outcome of audits, investigations, 
and any other tax controversies could be materially different from our...


=== RERANKED RETRIEVAL ===
Reranked 10 documents, returned 5 in 0.858s

1. [Score: 9.337, Rank: 1, Orig: 1]
   currency balances include British Pounds, Canadian Dollars, Euros, Indian Rupees, and Japanese Yen. 
Cash provided by (used in) operatin

### 7. LLM Setup

Using the user's preferred LLM model: `llama3-8b-8192` for speed.

In [None]:
from langchain_groq import ChatGroq

# Initialize LLM (user's preferred model for speed)
llm = ChatGroq(
    model="llama-3.1-8b-instant",  # Choose a model supported by Groq. llama3-8b-8192 Groq에서 가장 빠름 -> llama-3.1-8b-instant 모델로 교체됨 
    temperature=0,  # Deterministic responses
    max_tokens=512,
    groq_api_key=os.environ["groq_api"]
)

print("LLM initialized successfully")

# Test LLM
test_response = llm.invoke("Hello, how are you?")
print(f"LLM test response: {test_response.content[:100]}...") 

LLM initialized successfully
LLM test response: I'm functioning properly, thank you for asking. I'm a large language model, so I don't have emotions...


### 8. Enhanced QA Chain with Reranking

In [25]:
from langchain_core.retrievers import BaseRetriever 
from langchain_core.documents import Document
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from typing import List


class RetrievalWithReranking(BaseRetriever):
    """Custom retriever that combines vector search with reranking."""
    
    vectorstore: any  # Your vectorstore
    reranker: any     # Your reranker
    k_initial: int = 10
    k_final: int = 4
    
    class Config:
        arbitrary_types_allowed = True
    
    def _get_relevant_documents(
        self, 
        query: str, 
        *, 
        run_manager: CallbackManagerForRetrieverRun = None
    ) -> List[Document]:
        """Retrieve and rerank documents."""
        # Step 1: Initial retrieval
        initial_docs = self.vectorstore.similarity_search(
            query, 
            k=self.k_initial
        )
        
        # Step 2: Rerank
        if not initial_docs:
            return []
        
        # Prepare documents for reranking
        doc_texts = [doc.page_content for doc in initial_docs]
        
        # Get reranking scores
        rerank_results = self.reranker.rerank(
            query=query,
            documents=doc_texts,
            top_n=self.k_final
        )
        
        # Reorder documents based on reranking scores
        reranked_docs = []
        for result in rerank_results:
            reranked_docs.append(initial_docs[result['index']])
        
        return reranked_docs


  warn(


In [26]:
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA

# Enhanced prompt template
template = """Use the following context to answer the question accurately and concisely.
If the answer is not in the context, say "I cannot find this information in the provided context."

Context: {context}

Question: {question}

Answer:"""

prompt = PromptTemplate.from_template(template)

# Create enhanced retriever with reranking
enhanced_retriever = RetrievalWithReranking(
    vectorstore=vectorstore,
    reranker=reranker,
    k_initial=10,  # Retrieve more documents initially
    k_final=4      # Use top 4 after reranking
)

# Create QA chain with reranking
qa_chain_reranked = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=enhanced_retriever,
    chain_type_kwargs={"prompt": prompt},
    return_source_documents=True
)

# Also create a regular chain for comparison
regular_retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
qa_chain_regular = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=regular_retriever,
    chain_type_kwargs={"prompt": prompt},
    return_source_documents=True
)

print("QA chains created successfully")

QA chains created successfully


### 9. Comparison: Regular vs Reranked RAG

In [27]:
# Test questions
test_questions = [
    "How much cash was used in operating activities in 2024?",
    "Who is Brad D. Smith and what is his role?",
    "What is Amazon's revenue for 2024?",
    "What are the main business segments mentioned?"
]

def compare_rag_systems(question):
    print(f"\n{'='*60}")
    print(f"QUESTION: {question}")
    print(f"{'='*60}")
    
    # Regular RAG
    print("\n🔍 REGULAR RAG:")
    start_time = time.time()
    regular_result = qa_chain_regular({"query": question})
    regular_time = time.time() - start_time
    
    print(f"Answer: {regular_result['result']}")
    print(f"Time: {regular_time:.3f}s")
    print(f"Sources: {len(regular_result['source_documents'])} documents")
    
    # Reranked RAG
    print("\n🔄 RERANKED RAG:")
    start_time = time.time()
    reranked_result = qa_chain_reranked({"query": question})
    reranked_time = time.time() - start_time
    
    print(f"Answer: {reranked_result['result']}")
    print(f"Time: {reranked_time:.3f}s")
    print(f"Sources: {len(reranked_result['source_documents'])} documents")
    
    # Show reranking scores if available
    if reranked_result['source_documents']:
        print("\nReranking Details:")
        for i, doc in enumerate(reranked_result['source_documents']):
            score = doc.metadata.get('rerank_score', 'N/A')
            orig_pos = doc.metadata.get('original_position', 'N/A')
            print(f"  {i+1}. Score: {score:.3f}, Original position: {orig_pos}")
    
    return regular_result, reranked_result

# Run comparisons
for question in test_questions[:2]:  # Test first 2 questions
    compare_rag_systems(question)


QUESTION: How much cash was used in operating activities in 2024?

🔍 REGULAR RAG:


  regular_result = qa_chain_regular({"query": question})


Answer: $115,877 million.
Time: 0.605s
Sources: 4 documents

🔄 RERANKED RAG:


AttributeError: 'DocumentReranker' object has no attribute 'rerank'

### 10. Interactive RAG with Reranking

In [None]:
def interactive_rag_with_reranking():
    """Interactive RAG system with reranking."""
    print("🔄 Interactive RAG with Reranking")
    print("Type 'quit' to exit\n")
    
    while True:
        question = input("❓ Your question: ").strip()
        
        if question.lower() in ['quit', 'exit', 'q']:
            print("Goodbye!")
            break
        
        if not question:
            continue
        
        print("\n🤖 Processing...")
        start_time = time.time()
        
        try:
            result = qa_chain_reranked({"query": question})
            response_time = time.time() - start_time
            
            print(f"\n✅ Answer: {result['result']}")
            print(f"⏱️  Response time: {response_time:.2f}s")
            
            # Show source information
            if result['source_documents']:
                print(f"\n📚 Sources ({len(result['source_documents'])} documents):")
                for i, doc in enumerate(result['source_documents']):
                    score = doc.metadata.get('rerank_score', 'N/A')
                    page = doc.metadata.get('page', 'N/A')
                    print(f"  {i+1}. Page {page}, Rerank Score: {score:.3f}")
                    print(f"     {doc.page_content[:100]}...")
        
        except Exception as e:
            print(f"❌ Error: {e}")
        
        print("\n" + "-"*50)

# Uncomment to run interactive mode
# interactive_rag_with_reranking()

### 11. Performance Analysis

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

def analyze_reranking_performance(questions):
    """Analyze performance differences between regular and reranked RAG."""
    results = []
    
    for question in questions:
        print(f"Testing: {question[:50]}...")
        
        # Regular RAG
        start_time = time.time()
        regular_result = qa_chain_regular({"query": question})
        regular_time = time.time() - start_time
        
        # Reranked RAG
        start_time = time.time()
        reranked_result = qa_chain_reranked({"query": question})
        reranked_time = time.time() - start_time
        
        results.append({
            'question': question[:30] + '...',
            'regular_time': regular_time,
            'reranked_time': reranked_time,
            'time_overhead': reranked_time - regular_time,
            'regular_answer_length': len(regular_result['result']),
            'reranked_answer_length': len(reranked_result['result'])
        })
    
    return pd.DataFrame(results)

# Analyze performance
if len(test_questions) > 0:
    print("Analyzing performance...")
    perf_df = analyze_reranking_performance(test_questions)
    
    print("\n📊 Performance Analysis:")
    print(perf_df[['question', 'regular_time', 'reranked_time', 'time_overhead']].round(3))
    
    # Summary statistics
    avg_regular_time = perf_df['regular_time'].mean()
    avg_reranked_time = perf_df['reranked_time'].mean()
    avg_overhead = perf_df['time_overhead'].mean()
    
    print(f"\n📈 Summary:")
    print(f"   Average Regular Time: {avg_regular_time:.3f}s")
    print(f"   Average Reranked Time: {avg_reranked_time:.3f}s")
    print(f"   Average Overhead: {avg_overhead:.3f}s ({(avg_overhead/avg_regular_time)*100:.1f}%)")

### 12. Summary and Best Practices

#### What We Implemented:
1. **Two-Stage Retrieval**: Initial retrieval + Cross-encoder reranking
2. **Multiple Reranker Models**: From fast TinyBERT to accurate MiniLM-L12
3. **Performance Optimization**: Batch processing and caching
4. **Comprehensive Comparison**: Regular vs Reranked RAG

#### Key Benefits of Reranking:
- 🎯 **Better Accuracy**: Cross-encoders understand query-document relationships better
- 🔍 **Context Awareness**: Models can evaluate semantic relevance more precisely
- ⚡ **Efficiency**: Only rerank top candidates, not entire corpus
- 📊 **Measurable Improvement**: Clear ranking scores for transparency

#### Best Practices:
1. **Model Selection**: Balance speed vs accuracy based on your needs
2. **Threshold Tuning**: Filter low-relevance documents
3. **Batch Size**: Optimize for your hardware
4. **Initial Retrieval**: Retrieve more documents (10-20) for better reranking candidates
5. **Final Selection**: Keep 3-5 top documents for LLM processing

### 13. Export Enhanced RAG Function

In [None]:
def enhanced_rag_with_reranking(query, verbose=True):
    """Enhanced RAG function with reranking."""
    start_time = time.time()
    
    try:
        result = qa_chain_reranked({"query": query})
        response_time = time.time() - start_time
        
        if verbose:
            print(f"Question: {query}")
            print(f"Answer: {result['result']}")
            print(f"Response time: {response_time:.3f}s")
            
            if result['source_documents']:
                print(f"\nTop sources:")
                for i, doc in enumerate(result['source_documents'][:2]):
                    score = doc.metadata.get('rerank_score', 'N/A')
                    print(f"  {i+1}. Score: {score:.3f} - {doc.page_content[:100]}...")
        
        return result
    
    except Exception as e:
        print(f"Error: {e}")
        return None

# Test the enhanced function
print("Testing enhanced RAG function:")
enhanced_rag_with_reranking("What was Amazon's operating cash flow in 2024?")