# Retrieval-Augmented Transformer Demo

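This notebook demonstrates the retrieval-augmented generation (RAG) system interactively, comparing a baseline QA model (no retrieval) against the RAG model on the same questions.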

## Setup


In [None]:
import sys
import os
sys.path.append('..')

import config
from model.base_model import BaseQAModel
from model.rag_model import RAGModel
from retrieval.hybrid_retriever import HybridRetriever
import torch

# Confirm the imports succeeded and report whether a CUDA device is visible to PyTorch.
print("Setup complete!")
print(f"Device: {'cpu' if not torch.cuda.is_available() else 'cuda'}")


## Load Models


In [None]:
# Choose the device once; both QA models below are constructed on it.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

# Retriever: restore the index saved under config.RETRIEVAL_DIR/hybrid_index.
print("Loading retriever...")
index_path = os.path.join(config.RETRIEVAL_DIR, 'hybrid_index')
retriever = HybridRetriever()
retriever.load_index(index_path)

# Baseline QA model (used without retrieval in the demo below).
print("\nLoading baseline model...")
baseline_model = BaseQAModel(device=device)
baseline_model.load_from_checkpoint(config.BASELINE_OUTPUT_DIR)

# RAG model: wraps the retriever loaded above.
print("\nLoading RAG model...")
rag_model = RAGModel(retriever=retriever, device=device)
rag_model.load_from_checkpoint(config.RAG_OUTPUT_DIR)

print("\nAll models loaded!")


## Interactive Demo

Ask questions and compare the two models' answers side by side!


In [None]:
def ask_question(question, *, max_chars=200, baseline=None, rag=None):
    """Ask *question* to the baseline and RAG models and print both answers.

    Prints the baseline answer, the RAG answer, and a short preview of each
    passage the RAG model reports as retrieved, framed by ruled lines.

    Args:
        question: Question text passed unchanged to both models.
        max_chars: Maximum number of characters shown per retrieved passage.
            An ellipsis is appended only when a passage is actually cut.
        baseline: Model used for the no-retrieval answer. Defaults to the
            notebook-level ``baseline_model``.
        rag: Model used for the retrieval-augmented answer. Defaults to the
            notebook-level ``rag_model``.

    Returns:
        None. Output is printed only, so a trailing call at the end of a
        notebook cell does not echo an extra value.
    """
    # Resolve defaults at call time so the notebook globals are only needed
    # when the caller does not pass models explicitly.
    if baseline is None:
        baseline = baseline_model
    if rag is None:
        rag = rag_model

    print(f"Question: {question}\n")
    print("=" * 80)

    # Baseline model
    print("\n🔵 BASELINE MODEL (No Retrieval):")
    baseline_answer = baseline.generate_answer(question)
    print(f"Answer: {baseline_answer}")

    # RAG model: returns the answer plus (passage, score) pairs.
    print("\n🟢 RAG MODEL (With Retrieval):")
    rag_answer, retrieved_passages = rag.generate_answer(question, use_retrieval=True)
    print(f"Answer: {rag_answer}")

    if retrieved_passages:
        print("\n📚 Retrieved Passages:")
        for i, (passage, score) in enumerate(retrieved_passages, 1):
            text = passage['text']
            preview = text[:max_chars]
            # Only signal truncation when the passage was really shortened.
            if len(text) > max_chars:
                preview += "..."
            print(f"\n  Passage {i} (score={score:.3f}):")
            print(f"  {preview}")

    print("\n" + "=" * 80)

# Well-known factual questions: handy for eyeballing whether retrieval
# changes the answers the two models give.
questions = [
    "Who was the first person to walk on the moon?",
    "What is the capital of France?",
    "When did World War II end?",
]

for q in questions:
    # Run the side-by-side comparison, then add spacing before the next one.
    ask_question(question=q)
    print("\n")


## Custom Question

Try your own question: edit `custom_question` in the cell below and re-run it.


In [None]:
# Replace this string with any question you want both models to answer.
custom_question = "What is quantum computing?"
ask_question(question=custom_question)
