# Part 3: Advanced Retrieval (Multi-Query + RAG-Fusion)

## Learning Objectives

By the end of this notebook, you will:
1. Understand the limitations of single-query retrieval
2. Generate multiple query perspectives using LLMs
3. Implement multi-query retrieval for improved coverage
4. Apply Reciprocal Rank Fusion (RRF) to combine results
5. Build a RAG-Fusion pipeline
6. Compare retrieval strategies for security use cases

## The Problem with Basic Retrieval

In Part 2, we built a basic RAG system that retrieves documents based on similarity to a single query. However, this has limitations:

### Example: "How do attackers steal ML models?"

**Single query retrieval might miss:**
- Documents using different terminology ("model extraction" vs "model theft")
- Related concepts ("knowledge distillation", "API exploitation")
- Different perspectives (attacker techniques vs defender detection)
- Broader context (intellectual property protection)

### Solution: Multi-Query Retrieval

Instead of one query, generate multiple related queries:
1. "What are model extraction techniques?"
2. "How to detect model theft attempts?"
3. "What is knowledge distillation in model stealing?"
4. "How do attackers replicate proprietary models?"
5. "What defenses prevent unauthorized model access?"

Each query can surface different relevant documents, so together they cover more of the knowledge base than any single query would.
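As a minimal, framework-free sketch of why this helps, the toy example below replaces embedding search with a hypothetical word-overlap retriever and replaces the LLM with hand-written query variations. Keyword overlap exaggerates the vocabulary-mismatch problem (real embeddings handle synonyms better), but the mechanism is the same: variations reach documents that the original wording misses.

```python
# Toy corpus: only one document uses the words in the original question
CORPUS = {
    "doc_theft": "attackers steal ml models by exfiltrating weights",
    "doc_extraction": "model extraction attacks query an api to replicate a proprietary model",
    "doc_distillation": "knowledge distillation can copy model behavior from outputs",
}

def tokenize(text: str) -> set:
    """Lowercase and split on whitespace (no stemming, so 'models' != 'model')."""
    return set(text.lower().split())

def keyword_retrieve(query: str, corpus: dict, top_k: int = 2) -> list:
    """Rank documents by word overlap with the query; a stand-in for vector search."""
    query_words = tokenize(query)
    scored = [(len(query_words & tokenize(text)), doc_id) for doc_id, text in corpus.items()]
    scored.sort(key=lambda pair: (-pair[0], pair[1]))  # most overlap first, ties by id
    return [doc_id for score, doc_id in scored if score > 0][:top_k]

original = "How do attackers steal ML models"
# Hand-written variations standing in for LLM-generated ones
variations = [
    "What are model extraction techniques",
    "How does knowledge distillation copy a model",
]

single = keyword_retrieve(original, CORPUS)

multi = []
for query in [original] + variations:
    for doc_id in keyword_retrieve(query, CORPUS):
        if doc_id not in multi:  # merge, keeping the first occurrence only
            multi.append(doc_id)

print("single query:", single)
print("multi-query: ", multi)
```

The single query finds only `doc_theft`; adding the two variations also reaches `doc_extraction` and `doc_distillation`.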

## Approaches We'll Implement

1. **Multi-Query Retrieval**: Generate query variations, retrieve for each, merge unique results
2. **RAG-Fusion**: Generate related queries, retrieve for each, apply RRF to rank results
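To make the difference between these two approaches concrete, here is a small sketch on hand-made ranked lists (document IDs `A` to `F` are placeholders, not real retrieval output). Document `B` is ranked second by every query, yet plain merging keeps first-seen order and leaves `B` behind `A`; RRF, which rewards consistent appearances, moves `B` to the top.

```python
from collections import defaultdict

# Hypothetical ranked results (best first) for three query variations
ranked_lists = [
    ["A", "B", "C"],  # original question
    ["D", "B", "E"],  # variation 1
    ["F", "B", "A"],  # variation 2
]

def merge_unique(lists):
    """Multi-query style: concatenate the lists, dropping documents already seen."""
    merged = []
    for ranking in lists:
        for doc in ranking:
            if doc not in merged:
                merged.append(doc)
    return merged

def rrf(lists, k=60):
    """RAG-Fusion style: sum 1 / (k + rank) over every list a document appears in."""
    scores = defaultdict(float)
    for ranking in lists:
        for rank, doc in enumerate(ranking, start=1):
            scores[doc] += 1.0 / (k + rank)
    return sorted(scores, key=lambda d: (-scores[d], d))  # highest score first, ties by id

print("merge unique:", merge_unique(ranked_lists))
print("RRF:         ", rrf(ranked_lists))
```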

---
## 1. Environment Setup

Let's start by loading our existing vector store from Part 2.

In [None]:
# Import required libraries
import os
from collections import defaultdict
from typing import List

from dotenv import load_dotenv

# LangChain imports (core types come from langchain_core in current releases)
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document

# Load environment variables
load_dotenv()

if not os.getenv("OPENAI_API_KEY"):
    print("⚠️  WARNING: OPENAI_API_KEY not found")
else:
    print("✅ OpenAI API key loaded")

In [None]:
# Initialize embeddings and LLM
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    openai_api_key=os.getenv("OPENAI_API_KEY")
)

llm = ChatOpenAI(
    model="gpt-4",
    temperature=0,
    openai_api_key=os.getenv("OPENAI_API_KEY")
)

print("✅ Embeddings and LLM initialized")

In [None]:
# Load existing vector store from Part 2
vectorstore = Chroma(
    collection_name="owasp_llm_top10",
    embedding_function=embeddings,
    persist_directory="../data/chroma_db"
)

# Create basic retriever
retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3}
)

print("✅ Vector store loaded")
print(f"   Collection: {vectorstore._collection.count()} documents")

---
## 2. Multi-Query Retrieval

### Concept

Multi-query retrieval generates multiple perspectives of the same question and retrieves documents for each. This helps capture:
- **Synonyms**: Different ways to express the same concept
- **Perspectives**: Attacker vs defender viewpoints
- **Granularity**: High-level vs technical details
- **Related Concepts**: Adjacent security topics

### Implementation Steps

1. **Generate Query Variations**: Use LLM to create 5 related queries
2. **Retrieve for Each**: Get top-k documents for each variation
3. **Merge Unique**: Combine results, removing duplicates
4. **Truncate**: Return the first unique documents found (merging does not re-rank them by relevance)
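Step 1 depends on turning the LLM's free-text reply into a list of queries. LLMs often number their lines, add bullets, or prepend a sentence such as "Here are 5 versions:", and a plain split on newlines passes that noise straight to the retriever. A minimal sketch of a more defensive parser follows; `parse_query_lines` is a hypothetical helper, and its rules (strip list markers, skip lines ending in a colon, drop duplicates) are heuristics, not a LangChain feature.

```python
import re

# Leading list markers such as "1.", "2)", "-" or "*"
_MARKER = re.compile(r"^\s*(?:\d+[.)]|[-*])\s*")

def parse_query_lines(llm_output: str, max_queries: int = 5) -> list:
    """Turn an LLM reply into a clean, de-duplicated list of queries (hypothetical helper)."""
    queries = []
    for line in llm_output.splitlines():
        cleaned = _MARKER.sub("", line, count=1).strip()
        # Skip blank lines and preamble lines like "Here are 5 versions:"
        if not cleaned or cleaned.endswith(":"):
            continue
        if cleaned not in queries:
            queries.append(cleaned)
    return queries[:max_queries]

reply = """Here are 5 versions:
1. What are model extraction techniques?
2) How can defenders detect model theft?

- What is knowledge distillation in model stealing?
- What is knowledge distillation in model stealing?"""

print(parse_query_lines(reply))
```

One option is to use this function in place of the `lambda x: x.split("\n")` step of the query-generation chain; LangChain coerces a plain function in a `|` pipeline into a `RunnableLambda`.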

In [None]:
# Prompt template for generating query variations
multi_query_template = """You are an AI assistant specialized in cybersecurity and LLM security.
Your task is to generate 5 different versions of the given security question to help retrieve relevant documents from a security knowledge base.

Generate variations that:
1. Use different terminology (synonyms, technical terms, common phrases)
2. Explore different perspectives (attacker view, defender view, analyst view)
3. Vary the level of detail (high-level, technical, specific)
4. Include related concepts and adjacent topics
5. Cover different aspects of the security concern

Original question: {question}

Provide 5 alternative versions of this question, one per line, without numbering or any other text:"""

multi_query_prompt = ChatPromptTemplate.from_template(multi_query_template)

# Create chain to generate queries
generate_queries_chain = (
    multi_query_prompt
    | llm
    | StrOutputParser()
    | (lambda x: x.split("\n"))
)

print("✅ Multi-query generator created")

In [None]:
# Test query generation
test_question = "How do attackers steal ML models?"
print(f"🔍 Original Question: {test_question}")
print("\n📝 Generated Query Variations:")
print("=" * 80)

generated_queries = generate_queries_chain.invoke({"question": test_question})
queries = [q.strip() for q in generated_queries if q.strip()]  # drop blank lines before numbering
for i, query in enumerate(queries, 1):
    print(f"{i}. {query}")

In [None]:
def multi_query_retrieval(question: str, retriever, k: int = 3) -> List[Document]:
    """
    Retrieve documents using multiple query variations.
    
    Args:
        question: Original question
        retriever: Vector store retriever (its search_kwargs["k"] sets docs per query)
        k: Return at most 2 * k unique documents
        
    Returns:
        List of unique documents, in the order they were first retrieved
    """
    # Generate query variations
    queries = generate_queries_chain.invoke({"question": question})
    queries = [q.strip() for q in queries if q.strip()]
    
    # Add original question first so its results lead the merged list
    all_queries = [question] + queries
    
    print(f"\n🔄 Retrieving for {len(all_queries)} query variations...")
    
    # Retrieve documents for each query
    all_docs = []
    seen_content = set()
    total_retrieved = 0
    
    for i, query in enumerate(all_queries):
        print(f"   Query {i+1}: {query[:60]}...")
        docs = retriever.invoke(query)
        total_retrieved += len(docs)
        
        # Keep only the first occurrence of each document (deduplicate by content)
        for doc in docs:
            if doc.page_content not in seen_content:
                seen_content.add(doc.page_content)
                all_docs.append(doc)
    
    print(f"\n✅ Retrieved {len(all_docs)} unique documents (from {total_retrieved} total)")
    
    # No relevance ranking happens here: this is simply the first 2k unique documents
    return all_docs[:k * 2]

print("✅ Multi-query retrieval function created")

In [None]:
# Test multi-query retrieval
print("=" * 80)
print("🧪 Testing Multi-Query Retrieval")
print("=" * 80)

question = "How do attackers steal ML models?"
print(f"\n❓ Question: {question}\n")

# Compare basic retrieval vs multi-query
print("\n1️⃣  BASIC RETRIEVAL (Single Query):")
print("-" * 80)
basic_docs = retriever.invoke(question)
print(f"Retrieved {len(basic_docs)} documents:\n")
for i, doc in enumerate(basic_docs, 1):
    print(f"{i}. {doc.metadata['id']}: {doc.metadata['title']}")
    print(f"   Preview: {doc.page_content[:100]}...\n")

print("\n2️⃣  MULTI-QUERY RETRIEVAL (Multiple Perspectives):")
print("-" * 80)
multi_docs = multi_query_retrieval(question, retriever)
print(f"\nRetrieved {len(multi_docs)} unique documents:\n")
for i, doc in enumerate(multi_docs, 1):
    print(f"{i}. {doc.metadata['id']}: {doc.metadata['title']}")
    print(f"   Preview: {doc.page_content[:100]}...\n")

---
## 3. RAG-Fusion with Reciprocal Rank Fusion (RRF)

### Concept

RAG-Fusion improves on multi-query retrieval by ranking the combined results with **Reciprocal Rank Fusion (RRF)** instead of simply concatenating them.

### Reciprocal Rank Fusion (RRF)

RRF is a simple but effective algorithm for combining ranked lists:

```
RRF_score(doc) = Σ [ 1 / (k + rank(doc, query_i)) ]
```

Where:
- `rank(doc, query_i)` = position of document in results for query i (1-indexed)
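- `k` = smoothing constant (typically 60, the value used in the original RRF paper) that shrinks the gap between top ranks, so a single first-place hit cannot outweigh consistent appearances across queries
- Higher score = ranked highly by more of the queries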

### Example

**Document A**:
- Rank 1 in Query 1: 1/(60+1) = 0.0164
- Rank 3 in Query 2: 1/(60+3) = 0.0159
- Not in Query 3: 0
- **Total RRF: 0.0323**

**Document B**:
- Rank 2 in Query 1: 1/(60+2) = 0.0161
- Rank 2 in Query 2: 1/(60+2) = 0.0161
- Rank 5 in Query 3: 1/(60+5) = 0.0154
- **Total RRF: 0.0476** ← Higher score, more consistently relevant
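The arithmetic in this example can be checked with a few lines of Python (ranks copied from the example; a document missing from a query's results contributes nothing).

```python
def rrf_contribution(rank: int, k: int = 60) -> float:
    """Score a document earns from one ranked list: 1 / (k + rank), rank 1-indexed."""
    return 1.0 / (k + rank)

# Document A: rank 1 in Query 1, rank 3 in Query 2, absent from Query 3
doc_a = rrf_contribution(1) + rrf_contribution(3)
# Document B: rank 2 in Query 1, rank 2 in Query 2, rank 5 in Query 3
doc_b = rrf_contribution(2) + rrf_contribution(2) + rrf_contribution(5)

print(f"Document A: {doc_a:.4f}")  # 0.0323
print(f"Document B: {doc_b:.4f}")  # 0.0476
```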

### Why RRF Works

- **Rank-based**: Uses position, not raw similarity scores (more robust)
- **Multiple evidence**: Documents appearing in multiple queries get boosted
- **Balanced**: No single query dominates the ranking
- **Simple**: No parameters to tune (except k)
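The "rank-based" point is easiest to see when the lists come from scorers on different scales. The sketch below uses made-up scores: one scorer behaves like cosine similarity in [0, 1], the other like an unbounded keyword score (BM25-style). Summing raw scores lets the larger-scale scorer decide the order, while RRF only looks at each document's position in each list.

```python
from collections import defaultdict

# Hypothetical scores from two retrievers on very different scales
dense_scores = {"A": 0.91, "B": 0.89, "C": 0.40}    # cosine-like, in [0, 1]
keyword_scores = {"B": 14.2, "D": 13.9, "A": 2.1}   # unbounded keyword-style score

def fuse_raw(*score_maps):
    """Naive fusion: add raw scores, so the largest scale dominates."""
    total = defaultdict(float)
    for scores in score_maps:
        for doc, score in scores.items():
            total[doc] += score
    return sorted(total, key=lambda d: (-total[d], d))

def fuse_rrf(*score_maps, k=60):
    """RRF: convert each score map to a ranking, then sum 1 / (k + rank)."""
    total = defaultdict(float)
    for scores in score_maps:
        ranking = sorted(scores, key=lambda d: -scores[d])
        for rank, doc in enumerate(ranking, start=1):
            total[doc] += 1.0 / (k + rank)
    return sorted(total, key=lambda d: (-total[d], d))

print("raw sum:", fuse_raw(dense_scores, keyword_scores))
print("RRF:    ", fuse_rrf(dense_scores, keyword_scores))
```

With raw sums, `D` (liked only by the keyword scorer) outranks `A`, the dense scorer's top hit; with RRF, `A` is second because both retrievers rank it highly.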

In [None]:
def reciprocal_rank_fusion(results: List[List[Document]], k: int = 60) -> List[tuple]:
    """
    Apply Reciprocal Rank Fusion to combine multiple ranked lists.
    
    Args:
        results: List of ranked document lists (one per query)
        k: Constant for RRF formula (default: 60)
        
    Returns:
        List of (document, rrf_score) tuples, sorted by score
    """
    # Dictionary to accumulate RRF scores
    rrf_scores = defaultdict(float)
    
    # For each query's results
    for doc_list in results:
        # For each document in the ranked list
        for rank, doc in enumerate(doc_list, 1):
            # Use content hash as unique identifier
            doc_id = hash(doc.page_content)
            
            # Add RRF score: 1 / (k + rank)
            rrf_scores[doc_id] += 1.0 / (k + rank)
    
    # Create list of (document, score) tuples
    # Need to map back from hash to actual document
    doc_map = {}
    for doc_list in results:
        for doc in doc_list:
            doc_id = hash(doc.page_content)
            if doc_id not in doc_map:
                doc_map[doc_id] = doc
    
    # Create scored document list
    scored_docs = [(doc_map[doc_id], score) for doc_id, score in rrf_scores.items()]
    
    # Sort by RRF score (descending)
    scored_docs.sort(key=lambda x: x[1], reverse=True)
    
    return scored_docs

print("✅ RRF function created")

In [None]:
def rag_fusion_retrieval(question: str, retriever, rrf_k: int = 60) -> List[tuple]:
    """
    RAG-Fusion: Generate multiple queries, retrieve, and apply RRF.
    
    Args:
        question: Original question
        retriever: Vector store retriever (its search_kwargs["k"] sets docs per query)
        rrf_k: Smoothing constant passed to reciprocal_rank_fusion
        
    Returns:
        List of (document, rrf_score) tuples, highest score first
    """
    # Generate query variations
    queries = generate_queries_chain.invoke({"question": question})
    queries = [q.strip() for q in queries if q.strip()]
    
    # Add original question
    all_queries = [question] + queries
    
    print(f"\n🔄 RAG-Fusion: Processing {len(all_queries)} queries...")
    
    # Retrieve a ranked list for each query (order within each list matters for RRF)
    all_results = []
    for i, query in enumerate(all_queries):
        print(f"   Query {i+1}: {query[:60]}...")
        docs = retriever.invoke(query)
        all_results.append(docs)
    
    # Apply Reciprocal Rank Fusion
    print("\n🔀 Applying Reciprocal Rank Fusion...")
    fused_results = reciprocal_rank_fusion(all_results, k=rrf_k)
    
    print(f"✅ RAG-Fusion complete: {len(fused_results)} unique documents ranked\n")
    
    return fused_results

print("✅ RAG-Fusion function created")

In [None]:
# Test RAG-Fusion
print("=" * 80)
print("🧪 Testing RAG-Fusion")
print("=" * 80)

question = "How do attackers steal ML models?"
print(f"\n❓ Question: {question}\n")

# RAG-Fusion retrieval
fusion_results = rag_fusion_retrieval(question, retriever)

print("📊 RAG-Fusion Results (Ranked by RRF Score):\n")
for i, (doc, score) in enumerate(fusion_results[:5], 1):
    print(f"{i}. {doc.metadata['id']}: {doc.metadata['title']}")
    print(f"   RRF Score: {score:.4f}")
    print(f"   Preview: {doc.page_content[:100]}...\n")

---
## 4. Building Complete RAG Chains

Now let's create complete RAG chains using each retrieval method.

In [None]:
# Prompt template for answer generation
rag_template = """You are an AI security expert assistant helping users understand LLM vulnerabilities and security best practices.

Use the following security documentation context to answer the question. Be specific, accurate, and cite relevant details from the context.

Guidelines:
1. If the answer is in the context, provide a comprehensive explanation with examples.
2. Always mention which OWASP LLM vulnerability (LLM01-LLM10) is relevant.
3. Include prevention measures and best practices when applicable.
4. If the question cannot be answered from the context, say "I don't have enough information in the documentation to answer that question accurately."
5. Be concise but thorough - aim for 3-5 paragraphs.

Context:
{context}

Question: {question}

Answer:"""

rag_prompt = ChatPromptTemplate.from_template(rag_template)

print("✅ RAG prompt template created")

In [None]:
# Helper function to format documents
def format_docs(docs):
    """Format documents for context."""
    return "\n\n".join([f"Document {i+1} ({doc.metadata['id']} - {doc.metadata['title']}):\n{doc.page_content}" 
                        for i, doc in enumerate(docs)])

# Helper function to format fusion results
def format_fusion_docs(fusion_results):
    """Format fusion results for context."""
    docs = [doc for doc, score in fusion_results]
    return format_docs(docs)

print("✅ Formatting helpers created")

In [None]:
# 1. Basic RAG Chain (from Part 2)
basic_rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | rag_prompt
    | llm
    | StrOutputParser()
)

print("✅ Basic RAG chain created")

In [None]:
# 2. Multi-Query RAG Chain
def multi_query_rag_chain_func(question: str) -> str:
    """RAG chain using multi-query retrieval."""
    docs = multi_query_retrieval(question, retriever)
    context = format_docs(docs[:5])  # Use top 5 documents
    
    prompt_value = rag_prompt.invoke({"context": context, "question": question})
    response = llm.invoke(prompt_value)
    return response.content

print("✅ Multi-Query RAG chain created")

In [None]:
# 3. RAG-Fusion Chain
def rag_fusion_chain_func(question: str) -> str:
    """RAG chain using RAG-Fusion."""
    fusion_results = rag_fusion_retrieval(question, retriever)
    context = format_fusion_docs(fusion_results[:5])  # Use top 5 by RRF score
    
    prompt_value = rag_prompt.invoke({"context": context, "question": question})
    response = llm.invoke(prompt_value)
    return response.content

print("✅ RAG-Fusion chain created")

---
## 5. Comprehensive Comparison

Let's compare all three approaches on various security questions.

In [None]:
def compare_retrieval_methods(question: str):
    """
    Compare basic, multi-query, and RAG-Fusion retrieval for a given question.
    """
    print("=" * 80)
    print(f"❓ QUESTION: {question}")
    print("=" * 80)
    
    # 1. Basic Retrieval
    print("\n1️⃣  BASIC RETRIEVAL")
    print("-" * 80)
    basic_docs = retriever.invoke(question)
    print(f"Retrieved {len(basic_docs)} documents:\n")
    for i, doc in enumerate(basic_docs, 1):
        print(f"  {i}. {doc.metadata['id']}: {doc.metadata['title']}")
    
    print("\n🤖 Answer:")
    basic_answer = basic_rag_chain.invoke(question)
    print(basic_answer[:300] + "...\n")
    
    # 2. Multi-Query Retrieval
    print("\n2️⃣  MULTI-QUERY RETRIEVAL")
    print("-" * 80)
    multi_answer = multi_query_rag_chain_func(question)
    print("\n🤖 Answer:")
    print(multi_answer[:300] + "...\n")
    
    # 3. RAG-Fusion
    print("\n3️⃣  RAG-FUSION")
    print("-" * 80)
    fusion_answer = rag_fusion_chain_func(question)
    print("\n🤖 Answer:")
    print(fusion_answer[:300] + "...\n")
    
    print("=" * 80)
    print()

print("✅ Comparison function created")

### Test Case 1: Model Theft

In [None]:
compare_retrieval_methods("How do attackers steal ML models?")

### Test Case 2: Prompt Injection Defenses

In [None]:
compare_retrieval_methods("What defenses exist against prompt injection attacks?")

### Test Case 3: Information Leakage

In [None]:
compare_retrieval_methods("How can LLMs accidentally leak sensitive information?")

---
## 6. Analysis and Recommendations

### When to Use Each Approach

#### 🎯 Basic Retrieval

**Best for:**
- Specific, well-defined questions
- When query matches document terminology exactly
- Low-latency requirements (single query)
- Cost-sensitive applications

**Example:** "What is LLM01 Prompt Injection?"

**Pros:**
- Fastest (single retrieval)
- Lowest cost (one embedding)
- Simple to implement

**Cons:**
- May miss relevant documents with different wording
- Single perspective only
- Lower recall

---

#### 🎯 Multi-Query Retrieval

**Best for:**
- Ambiguous or broad questions
- When you want comprehensive coverage
- Exploratory queries
- When terminology varies across documents

**Example:** "How do I secure my ML deployment?"

**Pros:**
- Better recall (finds more relevant documents)
- Covers multiple perspectives
- Handles synonyms and paraphrasing

**Cons:**
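- Higher latency (one extra LLM call to generate queries, plus one retrieval per query)
- Higher cost (query-generation tokens plus one embedding call per query)
- Overlapping results waste retrievals, and the merged list is not re-ranked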
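Much of the retrieval latency can be recovered by running the per-query retrievals concurrently, since each one is an I/O-bound API round trip. The sketch below assumes a stand-in `fake_retrieve` that sleeps instead of calling a real vector store; with a LangChain retriever, `retriever.batch(all_queries)` should give a similar effect, because Runnables run batched inputs concurrently by default.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def fake_retrieve(query: str) -> list:
    """Stand-in for retriever.invoke(query): sleeps to mimic an embedding + search round trip."""
    time.sleep(0.2)
    return [f"doc for: {query}"]

def retrieve_all(queries: list, retrieve_fn, max_workers: int = 8) -> list:
    """Run one retrieval per query concurrently; results[i] corresponds to queries[i]."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map yields results in input order, preserving per-query rankings
        return list(pool.map(retrieve_fn, queries))

queries = [f"query variation {i}" for i in range(6)]
start = time.perf_counter()
results = retrieve_all(queries, fake_retrieve)
elapsed = time.perf_counter() - start
print(f"{len(results)} result lists in {elapsed:.2f}s")
```

With six queries and enough workers, wall time should be close to one round trip rather than six; the extra LLM call for query generation still runs first and cannot be parallelized this way.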

---

#### 🎯 RAG-Fusion

**Best for:**
- Complex, multi-faceted questions
- When ranking quality matters most
- Research and analysis tasks
- When you need documents consistently relevant across perspectives

**Example:** "What are the most effective defenses against adversarial attacks on LLMs?"

**Pros:**
- Principled ordering of merged results (RRF scoring)
- Prioritizes documents relevant across multiple queries
- Robust to single-query biases
- Often better precision at top-k

**Cons:**
- Latency and cost on par with multi-query (query generation + multiple retrievals; the RRF step itself is a cheap in-memory sort)
- Noticeably slower and costlier than basic retrieval
- Most complex implementation
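A rough way to compare cost is to count external calls per question in the retrieval step. The sketch below assumes each retriever call embeds its query once and runs one vector search (which is how a Chroma similarity retriever behaves), and that the LLM returns exactly `n_variations` queries; answer generation (one more LLM call for every method) is excluded. `CallCounts` and `retrieval_calls` are hypothetical helpers for illustration.

```python
from dataclasses import dataclass

@dataclass
class CallCounts:
    llm_calls: int
    embedding_calls: int
    vector_searches: int

def retrieval_calls(method: str, n_variations: int = 5) -> CallCounts:
    """External calls per question for the retrieval step only (answer generation excluded)."""
    if method == "basic":
        return CallCounts(llm_calls=0, embedding_calls=1, vector_searches=1)
    if method in ("multi_query", "rag_fusion"):
        n_queries = 1 + n_variations  # original question + generated variations
        # RAG-Fusion adds only an in-memory RRF sort on top of multi-query's calls
        return CallCounts(llm_calls=1, embedding_calls=n_queries, vector_searches=n_queries)
    raise ValueError(f"unknown method: {method}")

for method in ("basic", "multi_query", "rag_fusion"):
    print(f"{method:12s} {retrieval_calls(method)}")
```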

### Performance Comparison

Let's analyze the trade-offs:

In [None]:
import time

def benchmark_retrieval(question: str):
    """Benchmark retrieval methods (wall-clock time, including LLM query generation)."""
    print(f"\n⏱️  Benchmarking retrieval methods for: '{question}'")
    print("=" * 80)
    
    # Basic
    start = time.perf_counter()
    basic_docs = retriever.invoke(question)
    basic_time = time.perf_counter() - start
    print("\n1. Basic Retrieval:")
    print(f"   Time: {basic_time:.2f}s")
    print(f"   Documents: {len(basic_docs)}")
    print(f"   Unique: {len(set([doc.page_content for doc in basic_docs]))}")
    
    # Multi-Query
    start = time.perf_counter()
    multi_docs = multi_query_retrieval(question, retriever)
    multi_time = time.perf_counter() - start
    print("\n2. Multi-Query Retrieval:")
    print(f"   Time: {multi_time:.2f}s ({multi_time/basic_time:.1f}x slower)")
    print(f"   Documents: {len(multi_docs)}")
    print(f"   Unique: {len(set([doc.page_content for doc in multi_docs]))}")
    
    # RAG-Fusion
    start = time.perf_counter()
    fusion_results = rag_fusion_retrieval(question, retriever)
    fusion_time = time.perf_counter() - start
    print("\n3. RAG-Fusion:")
    print(f"   Time: {fusion_time:.2f}s ({fusion_time/basic_time:.1f}x slower)")
    print(f"   Documents: {len(fusion_results)}")
    print(f"   Top-3 RRF scores: {[f'{score:.4f}' for _, score in fusion_results[:3]]}")
    
    print("\n" + "=" * 80)

print("✅ Benchmark function created")

In [None]:
# Run benchmark
benchmark_retrieval("How do attackers steal ML models?")

---
## 7. Summary and Key Takeaways

### What We Built

✅ Two advanced retrieval strategies, plus a way to compare them:
1. **Multi-Query Retrieval**: Generate query variations for broader coverage
2. **RAG-Fusion**: Apply Reciprocal Rank Fusion to rank the merged results
3. **Comparison Framework**: Evaluate the approaches side by side

### Core Concepts Learned

1. **Query Expansion**: Using LLMs to generate related queries
2. **Reciprocal Rank Fusion**: Combining ranked lists effectively
3. **Trade-offs**: Latency vs quality vs cost
4. **Use Cases**: When to apply each approach

### Key Insights

**Multi-Query Retrieval:**
- ↑ Better recall (finds more relevant documents)
- ↑ Handles terminology variations
- ↓ Higher latency and cost
- ✅ Great for exploratory queries

**RAG-Fusion:**
- ↑↑ Often the best precision at top-k
- ↑↑ Rank-based fusion rewards documents that rank well across perspectives
- ↓ Latency and cost similar to multi-query (RRF itself is cheap)
- ✅ Best for complex, important queries

**Production Recommendation:**
- Use **Basic** for simple, well-specified queries, likely the bulk of traffic (fast, cheap)
- Use **Multi-Query** for ambiguous/exploratory queries (medium cost)
- Use **RAG-Fusion** for critical/complex queries (highest quality)
- Consider **hybrid routing**: classify query complexity, route to appropriate method
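A minimal sketch of such a router, assuming simple keyword and length heuristics chosen to match the examples in Section 6; the rules and thresholds are hypothetical, and a production router would more likely use an LLM or a trained classifier (see Practice Exercise 5).

```python
import re

def route_query(question: str) -> str:
    """Pick a retrieval method with simple, hypothetical heuristics."""
    q = question.lower()
    words = q.split()
    # Short lookups of a named OWASP entry (LLM01-LLM10) rarely need expansion
    if re.search(r"\bllm(0[1-9]|10)\b", q) and len(words) <= 8:
        return "basic"
    # Comparative or multi-part questions benefit most from fused ranking
    if any(marker in q for marker in ("most effective", "compare", " vs ", " and ")):
        return "rag_fusion"
    # Longer, broad questions: expand for coverage
    if len(words) > 6:
        return "multi_query"
    return "basic"

for question in (
    "What is LLM01 Prompt Injection?",
    "How do I secure my ML deployment?",
    "What are the most effective defenses against adversarial attacks on LLMs?",
):
    print(f"{route_query(question):12s} <- {question}")
```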

### Next Steps

In **Part 4**, we'll tackle **Query Decomposition**:
- Break complex questions into sub-questions
- Answer sequentially or in parallel
- Synthesize comprehensive answers
- Handle multi-step reasoning

Example: "How do I secure my entire ML pipeline?" →
1. "What are security risks in ML training?"
2. "How to secure ML model deployment?"
3. "What are inference-time security considerations?"
4. "How to monitor ML systems for security?"

---

### 🎯 Practice Exercises

1. **Experiment with the RRF constant k**: Try different values for k in `reciprocal_rank_fusion` (not the retriever's top-k) and watch how the ranking changes
2. **Query Templates**: Design different prompt templates for query generation
3. **Weighted Fusion**: Modify RRF to weight queries differently
4. **Hybrid Retrieval**: Combine dense embeddings with keyword search (BM25)
5. **Query Classification**: Build a classifier to route queries to appropriate method

### 📚 Further Reading

- [RAG-Fusion Paper](https://arxiv.org/abs/2402.03367)
- [Reciprocal Rank Fusion](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf)
- [LangChain Multi-Query Retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/MultiQueryRetriever)
- [Query Expansion Techniques](https://en.wikipedia.org/wiki/Query_expansion)