# Paper 29: Retrieval-Augmented Generation for Knowledge-Intensive Tasks
## Patrick Lewis, Ethan Perez, Aleksandra Piktus, et al., Facebook AI Research (now Meta AI), NeurIPS 2020

### RAG: Retrieval-Augmented Generation

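Combine dense retrieval (DPR) with seq2seq generation (BART). Best of both worlds: a non-parametric memory (a dense index of Wikipedia passages) supplies external knowledge, and a pretrained parametric generator turns it into fluent output.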

In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

## RAG Architecture

```
Input query (x)
    ↓
Retriever (DPR) → Top-k documents (z)
    ↓
Generator (BART) → P(y | x, z)
    ↓
Output (y)
```

**Two variants:**
- **RAG-Sequence**: Marginalize over documents for entire sequence
- **RAG-Token**: Marginalize over documents per token

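A hand-computed toy case makes the difference between the two variants concrete. The probabilities below are invented: they describe a two-token output whose first token is supported only by document z1 and whose second token is supported only by z2.

```python
import numpy as np

# Invented numbers: 2 retrieved documents, 2 output tokens
p_z = np.array([0.5, 0.5])            # P(z | x) for z1, z2
# p_tok[z, i] = P(y_i | x, z, y_<i): z1 explains token 1, z2 explains token 2
p_tok = np.array([[0.9, 0.1],
                  [0.1, 0.9]])

# RAG-Sequence: one document generates the whole sequence, then mix the sequences
p_seq = np.sum(p_z * np.prod(p_tok, axis=1))

# RAG-Token: mix the documents at every token, then multiply across tokens
p_token = np.prod(p_z @ p_tok)

print(f"RAG-Sequence P(y|x) = {p_seq:.3f}")    # 0.090
print(f"RAG-Token    P(y|x) = {p_token:.3f}")  # 0.250
```

RAG-Token scores this output higher because it may switch documents mid-sequence, whereas RAG-Sequence must explain every token with a single document.
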
In [None]:
def softmax(x):
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x)

class SimpleRetriever:
    """Simplified dense retriever (like DPR)"""
    def __init__(self, embedding_dim):
        self.embedding_dim = embedding_dim
        self.query_encoder_W = np.random.randn(embedding_dim, embedding_dim) * 0.01
    
    def encode_query(self, query_tokens):
        """Encode query to dense vector"""
        # Simplified stand-in for DPR's BERT query encoder:
        # mean-pool the token embeddings, then apply a random linear projection
        query_vec = np.mean(query_tokens, axis=0)
        encoded = np.dot(self.query_encoder_W, query_vec)
        # L2 normalize
        return encoded / (np.linalg.norm(encoded) + 1e-8)
    
    def retrieve(self, query_embedding, document_embeddings, k=5):
        """
        Retrieve top-k documents
        Returns: indices and probabilities
        """
        # Compute similarities
        similarities = np.dot(document_embeddings, query_embedding)
        
        # Get top-k
        top_k_indices = np.argsort(similarities)[::-1][:k]
        top_k_scores = similarities[top_k_indices]
        
        # P(z|x): softmax over the top-k scores only (the distribution is truncated to top-k)
        probs = softmax(top_k_scores)
        
        return top_k_indices, probs

# Test retriever
embedding_dim = 64
retriever = SimpleRetriever(embedding_dim)

# Dummy data
query_tokens = np.random.randn(10, embedding_dim)
document_embeddings = np.random.randn(20, embedding_dim)
# Normalize documents
document_embeddings = document_embeddings / (np.linalg.norm(document_embeddings, axis=1, keepdims=True) + 1e-8)

query_emb = retriever.encode_query(query_tokens)
top_indices, top_probs = retriever.retrieve(query_emb, document_embeddings, k=5)

print(f"Retrieved documents: {top_indices}")
print(f"Retrieval probabilities: {top_probs}")
print(f"Sum of probs: {np.sum(top_probs):.4f}")

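The retriever above ranks with a full `np.argsort`, which is fine for 20 documents. A sketch of the same exact inner-product search using partial sorting (`np.argpartition`) is shown below for a larger in-memory corpus; the function name `top_k_inner_product` is hypothetical, and at the paper's 21M-passage scale an approximate index such as FAISS would be used instead.

```python
import numpy as np

def top_k_inner_product(query_emb, doc_embs, k):
    """Exact maximum-inner-product search with partial sorting.

    np.argpartition isolates the k largest scores in O(N); only those k
    are then fully sorted, instead of sorting all N documents.
    """
    scores = doc_embs @ query_emb
    k = min(k, len(scores))
    candidates = np.argpartition(-scores, k - 1)[:k]   # top-k, unordered
    top = candidates[np.argsort(-scores[candidates])]  # sort only the k survivors
    # P(z|x): softmax over the retrieved scores (truncated to top-k)
    logits = scores[top] - scores[top].max()
    probs = np.exp(logits) / np.exp(logits).sum()
    return top, probs

rng = np.random.default_rng(0)
docs = rng.standard_normal((10_000, 64))
query = rng.standard_normal(64)
idx, probs = top_k_inner_product(query, docs, k=5)
print(idx, probs.round(3))
```
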
## Generator (Seq2Seq)

In [None]:
class SimpleGenerator:
    """Simplified seq2seq generator (like BART)"""
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        
        # Encoder
        self.encoder_W = np.random.randn(hidden_dim, embedding_dim) * 0.01
        
        # Decoder
        self.decoder_W = np.random.randn(hidden_dim, embedding_dim) * 0.01
        self.output_W = np.random.randn(vocab_size, hidden_dim) * 0.01
    
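    def generate_prob(self, query_tokens, doc_tokens, target_tokens):
        """
        Compute log P(y | x, z) = Σ_i log P(y_i | x, z, y_<i), where:
        - x: query
        - z: document
        - y: target output (scored with teacher forcing)
        """
        # Encode query + document (BART likewise sees x and z concatenated)
        combined = np.concatenate([query_tokens, doc_tokens], axis=0)
        encoder_hidden = np.tanh(np.dot(self.encoder_W, np.mean(combined, axis=0)))
        
        # Decode target with teacher forcing: the decoder input at step i is the
        # PREVIOUS target token y_{i-1} (zeros stand in for a start token), so the
        # model never sees the token it is predicting. Toy simplification: only
        # y_{i-1} is used, whereas BART's decoder attends to the whole prefix y_<i.
        log_prob = 0.0
        prev_token = np.zeros(self.embedding_dim)
        for target_token in target_tokens:
            decoder_hidden = np.tanh(np.dot(self.decoder_W, prev_token))
            
            # Combine encoder and decoder
            combined_hidden = encoder_hidden + decoder_hidden
            
            # Output distribution over the vocabulary
            logits = np.dot(self.output_W, combined_hidden)
            probs = softmax(logits)
            
            # Toy token id: the targets here are random embeddings, not one-hot
            # vectors, so argmax of the vector stands in for a vocabulary index
            target_idx = np.argmax(target_token)
            log_prob += np.log(probs[target_idx] + 1e-8)
            prev_token = target_token
        
        return log_prob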

# Test generator
vocab_size = 1000
generator = SimpleGenerator(vocab_size, embedding_dim, hidden_dim=128)

# Dummy tokens (embeddings)
query = np.random.randn(5, embedding_dim)
doc = np.random.randn(20, embedding_dim)
target = np.random.randn(8, embedding_dim)

log_prob = generator.generate_prob(query, doc, target)
print(f"\nLog P(y | x, z): {log_prob:.4f}")

## RAG-Sequence: Marginalize Over Documents

$$
P_{\text{RAG-Seq}}(y \mid x) = \sum_{z \in \text{top-}k} P(z \mid x) \, P(y \mid x, z) = \sum_{z \in \text{top-}k} P(z \mid x) \prod_{i=1}^{|y|} P(y_i \mid x, z, y_{<i})
$$

Score (or generate) the entire output sequence with each retrieved document, then combine the per-document sequence likelihoods, weighted by $P(z \mid x)$.

In [None]:
class RAGSequence:
    """RAG-Sequence model"""
    def __init__(self, retriever, generator):
        self.retriever = retriever
        self.generator = generator
    
    def forward(self, query_tokens, target_tokens, document_embeddings, documents_tokens, k=5):
        """
        RAG-Sequence forward pass
        
        P(y|x) = Σ_z P(z|x) * P(y|x,z)
        """
        # Retrieve documents
        query_emb = self.retriever.encode_query(query_tokens)
        doc_indices, doc_probs = self.retriever.retrieve(query_emb, document_embeddings, k=k)
        
        # Marginalize over documents in log space:
        # log P(y|x) = logsumexp_z [log P(z|x) + log P(y|x,z)].
        # Exponentiating log P(y|x,z) directly underflows for long outputs,
        # and a small epsilon added afterwards would then dominate the result.
        log_terms = []
        
        for doc_idx, p_z_given_x in zip(doc_indices, doc_probs):
            # Get document tokens
            doc_tokens = documents_tokens[doc_idx]
            
            # log P(y | x, z)
            log_p_y_given_xz = self.generator.generate_prob(query_tokens, doc_tokens, target_tokens)
            
            # log [P(z|x) * P(y|x,z)]
            log_terms.append(np.log(p_z_given_x) + log_p_y_given_xz)
        
        return np.logaddexp.reduce(log_terms), doc_indices, doc_probs

# Create RAG-Sequence model
rag_seq = RAGSequence(retriever, generator)

# Generate dummy documents
num_docs = 20
documents_tokens = [np.random.randn(15, embedding_dim) for _ in range(num_docs)]

# Test
log_prob, used_docs, used_probs = rag_seq.forward(
    query_tokens=query,
    target_tokens=target,
    document_embeddings=document_embeddings,
    documents_tokens=documents_tokens,
    k=5
)

print("\nRAG-Sequence:")
print(f"Log P(y|x): {log_prob:.4f}")
print(f"Used documents: {used_docs}")
print(f"Document weights: {used_probs}")

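Long outputs make $P(y \mid x, z)$ astronomically small, so the RAG-Sequence marginal is best computed in log space. This short sketch uses made-up log-likelihoods to compare exponentiate-then-sum (with a small epsilon guard) against log-sum-exp via `np.logaddexp.reduce`:

```python
import numpy as np

# Made-up values: log P(y | x, z) of a long answer is tiny under every document
log_p_y_given_xz = np.array([-120.0, -125.0, -130.0])
p_z_given_x = np.array([0.5, 0.3, 0.2])

# Naive: exponentiate first. exp(-120) is about 7.7e-53, so an epsilon such as
# 1e-8 swamps the true value (and below about -745, exp underflows to 0.0).
naive = np.log(np.sum(p_z_given_x * np.exp(log_p_y_given_xz)) + 1e-8)

# Stable: stay in log space with log-sum-exp
stable = np.logaddexp.reduce(np.log(p_z_given_x) + log_p_y_given_xz)

print(f"naive  log P(y|x) = {naive:.2f}")   # -18.42, i.e. just log(1e-8)
print(f"stable log P(y|x) = {stable:.2f}")
```

The naive result reports the epsilon rather than the model; the log-space result stays close to the best single term, $\log 0.5 - 120$.
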
## RAG-Token: Marginalize Per Token

$$
P_{\text{RAG-Token}}(y \mid x) = \prod_{i=1}^{|y|} \sum_{z \in \text{top-}k} P(z \mid x) \, P(y_i \mid x, z, y_{<i})
$$

Because the sum over documents sits inside the product, each output token can draw its probability mainly from a different retrieved document.

In [None]:
class RAGToken:
    """RAG-Token model (simplified)"""
    def __init__(self, retriever, generator):
        self.retriever = retriever
        self.generator = generator
    
    def token_log_probs(self, query_tokens, doc_tokens, target_tokens):
        """
        log P(y_i | x, z, y_<i) for every position i.
        
        generate_prob returns a sum of per-token terms, so the i-th term is the
        difference between the log-probabilities of the prefixes y_<=i and y_<i.
        This keeps the y_<i context instead of scoring each token in isolation.
        """
        prefix_log_probs = [
            self.generator.generate_prob(query_tokens, doc_tokens, target_tokens[:i])
            for i in range(len(target_tokens) + 1)
        ]
        return np.diff(prefix_log_probs)
    
    def forward(self, query_tokens, target_tokens, document_embeddings, documents_tokens, k=5):
        """
        Full sequence probability
        
        P(y|x) = ∏_i Σ_z P(z|x) * P(y_i | x, z, y_<i)
        """
        # Retrieve once: the same top-k documents and P(z|x) serve every token
        query_emb = self.retriever.encode_query(query_tokens)
        doc_indices, doc_probs = self.retriever.retrieve(query_emb, document_embeddings, k=k)
        
        # (k, |y|) matrix: per-token log-likelihood of y under each document
        per_doc_token_log_probs = np.array([
            self.token_log_probs(query_tokens, documents_tokens[doc_idx], target_tokens)
            for doc_idx in doc_indices
        ])
        
        # Marginalize over documents separately for each token (in log space),
        # then multiply across tokens (sum of logs)
        log_p_tokens = np.logaddexp.reduce(
            np.log(doc_probs)[:, None] + per_doc_token_log_probs, axis=0
        )
        return np.sum(log_p_tokens)

# Create RAG-Token model
rag_token = RAGToken(retriever, generator)

# Test
log_prob_token = rag_token.forward(
    query_tokens=query,
    target_tokens=target,
    document_embeddings=document_embeddings,
    documents_tokens=documents_tokens,
    k=5
)

print("\nRAG-Token:")
print(f"Log P(y|x): {log_prob_token:.4f}")
print("\nDifference: RAG-Token can use different docs per token!")

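Because RAG-Token marginalizes token by token, its mixed next-token distribution $\sum_z P(z \mid x) P(\cdot \mid x, z, y_{<i})$ is an ordinary autoregressive transition probability, and the paper plugs it into a standard beam decoder. A toy sketch of that mixing step with invented probabilities (the helper name `rag_token_next_token_dist` is hypothetical):

```python
import numpy as np

def rag_token_next_token_dist(p_z_given_x, per_doc_next_token_probs):
    """Mix per-document next-token distributions into one distribution.

    p_z_given_x: shape (k,); per_doc_next_token_probs: shape (k, vocab),
    where row z holds P(. | x, z, y_<i). Returns P(. | x, y_<i), shape (vocab,).
    """
    return p_z_given_x @ per_doc_next_token_probs

# Invented setup: 3 documents, vocabulary of 4 tokens
p_z = np.array([0.6, 0.3, 0.1])
per_doc = np.array([
    [0.70, 0.10, 0.10, 0.10],   # document 1 favours token 0
    [0.05, 0.80, 0.10, 0.05],   # document 2 favours token 1
    [0.25, 0.25, 0.25, 0.25],   # document 3 is uninformative
])
mixed = rag_token_next_token_dist(p_z, per_doc)
print(mixed, mixed.sum())                          # a proper distribution
print("greedy next token:", int(np.argmax(mixed)))  # token 0
```

A beam search would keep the top few entries of `mixed` at each step instead of only the argmax; RAG-Sequence cannot be decoded this way, because its likelihood does not factor into a single per-token distribution.
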
## Synthetic QA Example

In [None]:
# A small hand-written knowledge base; each QA pair records the index of its supporting passage
knowledge_base = [
    "The Eiffel Tower was built in 1889 by Gustave Eiffel.",
    "Paris is the capital of France and has a population of 2.2 million.",
    "The Statue of Liberty was a gift from France to the United States.",
    "Mount Everest is 8,849 meters tall and located in the Himalayas.",
    "The Amazon River flows through South America for 6,400 kilometers.",
]

qa_pairs = [
    ("When was the Eiffel Tower built?", "1889", 0),
    ("What is the height of Mount Everest?", "8,849 meters", 3),
    ("How long is the Amazon River?", "6,400 kilometers", 4),
]

print("Knowledge Base:")
for i, doc in enumerate(knowledge_base):
    print(f"  {i}. {doc}")

print("\nQA Pairs:")
for q, a, doc_idx in qa_pairs:
    print(f"  Q: {q}")
    print(f"  A: {a}")
    print(f"  Relevant doc: #{doc_idx}")
    print()

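The cell above only lists the data. As a rough stand-in for the DPR retriever (not what the paper uses), a bag-of-words cosine retriever can be run over the same five passages to check that each question lands on its supporting document. The stopword list and the helpers `bag_of_words` and `cosine` are assumptions made for this sketch.

```python
import math
import re
from collections import Counter

# Same passages and QA pairs as the cell above, copied so this block runs alone
knowledge_base = [
    "The Eiffel Tower was built in 1889 by Gustave Eiffel.",
    "Paris is the capital of France and has a population of 2.2 million.",
    "The Statue of Liberty was a gift from France to the United States.",
    "Mount Everest is 8,849 meters tall and located in the Himalayas.",
    "The Amazon River flows through South America for 6,400 kilometers.",
]
qa_pairs = [
    ("When was the Eiffel Tower built?", "1889", 0),
    ("What is the height of Mount Everest?", "8,849 meters", 3),
    ("How long is the Amazon River?", "6,400 kilometers", 4),
]

# Hypothetical stopword list for this toy example
STOPWORDS = {"the", "a", "is", "was", "of", "in", "and", "to", "by", "for",
             "from", "has", "how", "what", "when"}

def bag_of_words(text):
    """Sparse term-count vector of the non-stopword tokens."""
    tokens = re.findall(r"[a-z0-9]+", text.lower())
    return Counter(t for t in tokens if t not in STOPWORDS)

def cosine(a, b):
    dot = sum(count * b[term] for term, count in a.items())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

doc_vectors = [bag_of_words(doc) for doc in knowledge_base]
retrieved = []
for question, answer, gold_idx in qa_pairs:
    scores = [cosine(bag_of_words(question), d) for d in doc_vectors]
    best = max(range(len(scores)), key=scores.__getitem__)
    retrieved.append(best)
    print(f"{question} -> doc #{best} (gold #{gold_idx}), score {scores[best]:.2f}")
```

Lexical matching succeeds here only because each question repeats words from its passage; DPR's learned dense encoders are meant to also match paraphrases with little word overlap.
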
## Visualize RAG Architecture

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(16, 8))

def draw_rag_variant(ax, title, is_token=False):
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 12)
    ax.axis('off')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    
    # Query
    ax.add_patch(plt.Rectangle((4, 10.5), 2, 0.8, fill=True, 
                               color='lightblue', ec='black', linewidth=2))
    ax.text(5, 10.9, 'Query (x)', ha='center', va='center', fontsize=11, fontweight='bold')
    
    # Retriever
    ax.add_patch(plt.Rectangle((3.5, 9), 3, 1, fill=True, 
                               color='lightgreen', ec='black', linewidth=2))
    ax.text(5, 9.5, 'Retriever\n(DPR)', ha='center', va='center', fontsize=10, fontweight='bold')
    ax.arrow(5, 10.5, 0, -0.3, head_width=0.2, head_length=0.1, fc='black', ec='black', linewidth=2)
    
    # Retrieved documents
    doc_positions = [2, 4, 6, 8]
    ax.text(5, 7.8, 'Top-k Documents', ha='center', fontsize=10, fontweight='bold')
    for i, x in enumerate(doc_positions[:3]):
        ax.add_patch(plt.Rectangle((x-0.4, 6.5), 0.8, 1, fill=True, 
                                   color='lightyellow', ec='black', linewidth=1.5))
        ax.text(x, 7, f'z{i+1}', ha='center', va='center', fontsize=9)
        # Arrow from retriever
        ax.plot([5, x], [9, 7.5], 'k--', alpha=0.5, linewidth=1)
    
    if not is_token:
        # RAG-Sequence: each doc generates full sequence
        y_positions = [2, 4, 6]
        for i, (dx, dy) in enumerate(zip(doc_positions[:3], y_positions)):
            # Generator per document
            ax.add_patch(plt.Rectangle((dy-0.5, 4.5), 1, 0.8, fill=True, 
                                       color='lightcoral', ec='black', linewidth=1.5))
            ax.text(dy, 4.9, f'Gen', ha='center', va='center', fontsize=8)
            ax.arrow(dx, 6.5, dy-dx, -1.5, head_width=0.15, head_length=0.1, 
                    fc='gray', ec='gray', linewidth=1, alpha=0.6)
            
            # Output sequence
            ax.add_patch(plt.Rectangle((dy-0.6, 3), 1.2, 0.6, fill=True, 
                                       color='wheat', ec='black', linewidth=1))
            ax.text(dy, 3.3, f'y', ha='center', va='center', fontsize=8)
            ax.arrow(dy, 4.5, 0, -0.8, head_width=0.12, head_length=0.08, 
                    fc='black', ec='black', linewidth=1)
        
        # Combine
        ax.add_patch(plt.Rectangle((4, 1.2), 2, 0.8, fill=True, 
                                   color='plum', ec='black', linewidth=2))
        ax.text(5, 1.6, 'Σ P(z|x)P(y|x,z)', ha='center', va='center', fontsize=9, fontweight='bold')
        for dy in y_positions:
            ax.plot([dy, 5], [3, 2], 'k-', alpha=0.5, linewidth=1.5)
    else:
        # RAG-Token: combine docs for each token
        token_y = 4.5
        for t in range(3):
            tx = 2 + t * 2.5
            
            # Token position
            ax.add_patch(plt.Rectangle((tx-0.4, token_y), 0.8, 0.6, fill=True, 
                                       color='lightcoral', ec='black', linewidth=1.5))
            ax.text(tx, token_y+0.3, f'y{t+1}', ha='center', va='center', fontsize=9)
            
            # Arrows from all docs
            for dx in doc_positions[:3]:
                ax.plot([dx, tx], [6.5, token_y+0.6], 'k--', alpha=0.3, linewidth=0.8)
        
        # Final output
        ax.add_patch(plt.Rectangle((3.5, 2.5), 3, 0.8, fill=True, 
                                   color='plum', ec='black', linewidth=2))
        ax.text(5, 2.9, '∏ Σ P(z|x)P(yi|x,z)', ha='center', va='center', 
               fontsize=9, fontweight='bold')
        ax.arrow(4, token_y, 0.8, -1.3, head_width=0.15, head_length=0.1, 
                fc='black', ec='black', linewidth=1.5, alpha=0.5)
    
    # Final answer
    ax.add_patch(plt.Rectangle((4, 0.3), 2, 0.6, fill=True, 
                               color='lightgreen', ec='black', linewidth=2))
    ax.text(5, 0.6, 'Answer', ha='center', va='center', fontsize=11, fontweight='bold')
    ax.arrow(5, 1.2 if not is_token else 2.5, 0, 
            -0.2 if not is_token else -1.5, 
            head_width=0.2, head_length=0.1, fc='green', ec='green', linewidth=2)

draw_rag_variant(axes[0], 'RAG-Sequence', is_token=False)
draw_rag_variant(axes[1], 'RAG-Token', is_token=True)

plt.tight_layout()
plt.show()

## Compare RAG Variants

In [None]:
# Simulated weights for visualization (random numbers, not a trained model)
n_docs = 5
n_tokens = 8

# Both variants use the same retrieval prior P(z|x) for every token
doc_prior = softmax(np.random.randn(n_docs))

# RAG-Sequence: each document keeps weight P(z|x) across the whole output
weights_seq_matrix = np.tile(doc_prior, (n_tokens, 1))

# RAG-Token: the effective weight of document z at token i is the per-token posterior
# P(z | x, y_i, y_<i) ∝ P(z|x) * P(y_i | x, z, y_<i); the likelihoods are simulated here
token_likelihoods = np.exp(2 * np.random.randn(n_tokens, n_docs))
weights_token_matrix = doc_prior * token_likelihoods
weights_token_matrix /= weights_token_matrix.sum(axis=1, keepdims=True)

# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

im1 = ax1.imshow(weights_seq_matrix.T, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)
ax1.set_xlabel('Output Token Position', fontsize=12)
ax1.set_ylabel('Document', fontsize=12)
ax1.set_title('RAG-Sequence\n(Same document weights for all tokens)', fontsize=13, fontweight='bold')
plt.colorbar(im1, ax=ax1, label='P(z|x)')

im2 = ax2.imshow(weights_token_matrix.T, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)
ax2.set_xlabel('Output Token Position', fontsize=12)
ax2.set_ylabel('Document', fontsize=12)
ax2.set_title('RAG-Token\n(Document weights shift per token)', fontsize=13, fontweight='bold')
plt.colorbar(im2, ax=ax2, label='P(z | x, y_i, y_<i)')

plt.tight_layout()
plt.show()

print("\nRAG-Sequence: More consistent (one document explains the whole output)")
print("RAG-Token: More flexible (different documents can explain different tokens)")

## Key Takeaways

### RAG Architecture:

**Components**:
1. **Retriever**: Dense retrieval (DPR-style)
   - Query encoder: $q_{emb} = E_Q(x)$
   - Document encoder: $d_{emb} = E_D(z)$
   - Retrieval: $P(z|x) \propto \exp(q_{emb} \cdot d_{emb})$

2. **Generator**: Seq2seq model (BART)
   - Input: query $x$ + document $z$
   - Output: $P(y | x, z)$

### RAG-Sequence:

$$
P_{\text{RAG-Seq}}(y \mid x) = \sum_{z \in \text{top-}k} P(z \mid x) \cdot P_{\text{seq2seq}}(y \mid x, z)
$$

**Process**:
1. Retrieve top-k documents
2. Score (or, at decoding time, beam-search) the full sequence with each document
3. Sum the sequence likelihoods $P(y|x,z)$, weighted by $P(z|x)$

**Characteristics**:
- Each document generates complete answer
- More consistent (single knowledge source per sequence)
- Slightly ahead of RAG-Token on Natural Questions in the paper (44.5 vs. 44.1 EM)

### RAG-Token:

$$
P_{\text{RAG-Token}}(y \mid x) = \prod_{i=1}^{|y|} \left( \sum_{z \in \text{top-}k} P(z \mid x) \cdot P(y_i \mid x, z, y_{<i}) \right)
$$

**Process**:
1. Retrieve top-k documents (same for all tokens)
2. For each token: marginalize over documents
3. Different documents can contribute to different tokens

**Characteristics**:
- Can mix information from multiple documents
- More flexible generation
- In the paper, stronger on Jeopardy question generation, where one output combines facts from several documents

### Training:

**End-to-end**:
```
Loss = -log P(y* | x)
```

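Gradients flow through:
- Generator (BART parameters)
- Query encoder (retriever parameters), via $P(z|x)$

No label says which document should be retrieved; the retriever learns only from this marginal likelihood.

**Document encoder**: Kept frozen, so the pre-computed passage index does not have to be rebuilt during training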

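The training signal for the retriever can be made concrete for RAG-Sequence. Writing the scores of the top-k documents as $s_z = d(z)^\top q(x)$ with $P(z \mid x) = \mathrm{softmax}(s)_z$, the loss $L = -\log \sum_z P(z \mid x) P(y \mid x, z)$ has gradient $\partial L / \partial s_z = P(z \mid x) - P(z \mid x, y)$, where $P(z \mid x, y) \propto P(z \mid x) P(y \mid x, z)$ is the posterior over documents. The query encoder is thus pushed toward documents that helped the generator, without document labels. This sketch checks the identity numerically; the scores and likelihoods are made up, and `rag_seq_loss` is a hypothetical helper.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def rag_seq_loss(scores, log_p_y_given_xz):
    """-log sum_z softmax(scores)_z * P(y|x,z), computed in log space."""
    log_prior = np.log(softmax(scores))
    return -np.logaddexp.reduce(log_prior + log_p_y_given_xz)

rng = np.random.default_rng(0)
scores = rng.standard_normal(5)             # retrieval scores for the top-k documents
log_lik = rng.standard_normal(5) * 3 - 10   # made-up log P(y | x, z) per document

prior = softmax(scores)                        # P(z | x)
posterior = softmax(np.log(prior) + log_lik)   # P(z | x, y)
analytic_grad = prior - posterior

# Central finite differences as an independent check
eps = 1e-6
numeric_grad = np.array([
    (rag_seq_loss(scores + eps * e, log_lik) - rag_seq_loss(scores - eps * e, log_lik)) / (2 * eps)
    for e in np.eye(len(scores))
])
print("max |analytic - numeric| =", np.max(np.abs(analytic_grad - numeric_grad)))
```

The gradient only reaches documents already in the top-k; the search that selects them is not differentiated.
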
### Implementation Details:

**From paper**:
- Retriever: DPR with BERT-base
- Generator: BART-large (400M params)
- Knowledge: Wikipedia (21M passages)
- Top-k: k=5 or k=10
- Index: FAISS for fast retrieval

### Results:

**Natural Questions (Open)**:
- BART (no retrieval): 27.0% EM
- RAG-Sequence: 44.5% EM
- RAG-Token: 44.1% EM

**TriviaQA**:
- BART: 50.1%
- RAG: 56.8%

**WebQuestions**:
- BART: 27.6%
- RAG: 45.2%

### RAG vs Baselines:

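| Model | Knowledge source | Answer form | Open-domain QA (relative) |
|-------|------------------|-------------|---------------------------|
| T5-11B (closed-book) | Parametric only (memorized in weights) | Generated | Good |
| REALM | Parametric + retrieved | Extracted span | Better |
| **RAG** | **Parametric + retrieved** | **Generated** | **Best** |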

### Advantages:

- ✅ **Factual accuracy**: Access to external knowledge
- ✅ **Scalability**: Add knowledge without retraining
- ✅ **Interpretability**: Can inspect retrieved documents
- ✅ **Efficiency**: Strong results with far fewer parameters than large closed-book models such as T5-11B
- ✅ **Up-to-date**: Update index, not model weights

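The "update the index, not the weights" point can be sketched in a few lines. Everything here is hypothetical: `encode` is a parameter-free hashed bag-of-words stand-in for a frozen document encoder such as DPR's, and `DenseIndex` is a toy in-memory index, not FAISS.

```python
import re
import zlib
import numpy as np

DIM = 1024  # hypothetical embedding size

def encode(text):
    """Frozen stand-in encoder (no trainable weights): hashed bag of words, L2-normalized."""
    vec = np.zeros(DIM)
    for word in re.findall(r"[a-z0-9]+", text.lower()):
        vec[zlib.crc32(word.encode()) % DIM] += 1.0
    norm = np.linalg.norm(vec)
    return vec / norm if norm > 0 else vec

class DenseIndex:
    """Hypothetical minimal in-memory passage index."""
    def __init__(self):
        self.vectors = np.zeros((0, DIM))
        self.passages = []

    def add(self, passages):
        # Only the new passages are encoded; existing vectors and the model stay untouched
        self.vectors = np.vstack([self.vectors] + [encode(p) for p in passages])
        self.passages.extend(passages)

    def search(self, query, k=1):
        scores = self.vectors @ encode(query)
        return [self.passages[i] for i in np.argsort(-scores)[:k]]

index = DenseIndex()
index.add(["Version 1.0 of the toolkit supports CPU inference only.",
           "The toolkit is distributed under an open-source license."])
print(index.search("Does version 2.0 add GPU support?"))

# New knowledge arrives: add it to the index, no retraining involved
index.add(["Version 2.0 of the toolkit adds GPU support."])
print(index.search("Does version 2.0 add GPU support?"))
```
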
### Limitations:

- ❌ **Retrieval errors**: Wrong docs → wrong answers
- ❌ **Latency**: Retrieval adds overhead
- ❌ **Index maintenance**: New or edited passages must be encoded and indexed; changing the document encoder means re-encoding the whole corpus
- ❌ **Memory**: Full document index required

### When to Use:

**RAG-Sequence**:
- Factoid QA
- Short answers
- When single source is enough

**RAG-Token**:
- Long-form generation
- Multi-hop reasoning
- Combining multiple sources

### Modern Extensions:

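- **RETRO** (DeepMind): Retrieves neighbours for each chunk of the input and integrates them through chunked cross-attention
- **Atlas** (Meta): Jointly pre-trains the retriever and a Fusion-in-Decoder reader; strong few-shot results
- **Toolformer**: Learns, self-supervised, when to call external tools such as a search API
- **WebGPT**: Answers questions by browsing the web through a text-based browser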
- **Self-RAG**: Self-reflective retrieval

### Production Tips:

1. **Retrieve, then rerank**: Cheap first-stage retrieval for recall, then a stronger reranker (e.g., a cross-encoder) on the top candidates
2. **Cache**: Pre-retrieve for common queries
3. **Async**: Retrieve while generating
4. **Fallback**: Parametric generation if retrieval fails
5. **Monitor**: Track retrieval quality

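Tips 1, 2 and 4 can be combined in a few lines. The sketch below is hypothetical throughout: a word-overlap first stage, Jaccard similarity standing in for a learned reranker such as a cross-encoder, `functools.lru_cache` as the query cache, and an arbitrary `MIN_SCORE` threshold for falling back to parametric-only generation.

```python
import re
from functools import lru_cache

# Hypothetical toy corpus
CORPUS = (
    "RAG combines a retriever with a seq2seq generator.",
    "DPR encodes queries and passages with two BERT encoders.",
    "BART is a denoising sequence-to-sequence model.",
    "FAISS supports fast nearest-neighbour search.",
)
MIN_SCORE = 0.1  # hypothetical relevance threshold for the fallback

def words(text):
    return set(re.findall(r"[a-z0-9]+", text.lower()))

def first_stage(query, n):
    """Cheap recall stage: rank every passage by raw word overlap with the query."""
    q = words(query)
    return sorted(range(len(CORPUS)), key=lambda i: -len(q & words(CORPUS[i])))[:n]

def rerank_score(query, passage):
    """Stand-in for an expensive reranker such as a cross-encoder: Jaccard similarity."""
    q, p = words(query), words(passage)
    return len(q & p) / len(q | p) if q | p else 0.0

@lru_cache(maxsize=1024)
def retrieve(query, n=3, k=1):
    """Two-stage retrieval, memoized per query; returns a hashable tuple of (index, score)."""
    candidates = first_stage(query, n)
    reranked = sorted(candidates, key=lambda i: -rerank_score(query, CORPUS[i]))
    return tuple((i, rerank_score(query, CORPUS[i])) for i in reranked[:k])

def build_prompt(query):
    hits = retrieve(query)
    if not hits or hits[0][1] < MIN_SCORE:
        return query  # fallback: the generator answers from parametric knowledge alone
    return CORPUS[hits[0][0]] + "\n\n" + query

print(build_prompt("How does DPR encode passages?"))
print(build_prompt("Weather forecast tomorrow"))
build_prompt("How does DPR encode passages?")  # repeated query: served from the cache
print(retrieve.cache_info())
```

In a real system the reranker would only see the first-stage candidates because it is the expensive step; the threshold would need tuning on held-out queries.
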
### Applications:

- Open-domain QA and search-style answering
- Chatbots with knowledge bases
- Document QA
- Fact-checking
- Research assistants
- Customer support

### Key Insight:

**RAG = Best of both worlds**
- Parametric knowledge (generation capability)
- Non-parametric knowledge (external retrieval)
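- Trained end to end: gradients reach the query encoder through $P(z|x)$, although the top-k search itself is not differentiated and the document index stays fixed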
- Practical and effective!