# **Modular RAG** with Haystack and Hypster

This project implements the concepts described in the paper: [Modular RAG: Transforming RAG Systems into LEGO-like Reconfigurable Frameworks](https://arxiv.org/abs/2407.21059) by Yunfan Gao et al. 

For a detailed walkthrough, refer to this [Medium article](https://medium.com/p/d2f0ecc88b8f).

## Key Objectives

- Decompose RAG (Retrieval-Augmented Generation) into its fundamental components using **Haystack**.
- Utilize **Hypster** to manage a "hyper-space" of potential RAG configurations.
- Facilitate easy swapping and experimentation with various implementations.
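
As a rough illustration of the last two objectives, the sketch below uses plain Python (no Haystack or Hypster) to show how a single configuration value can pick between interchangeable implementations. Every name in it (`RETRIEVERS`, `build_retriever`, the two toy retrievers) is hypothetical and exists only for this example.

```python
from typing import Callable, Dict, List

Retriever = Callable[[str, List[str]], List[str]]


def keyword_retriever(query: str, corpus: List[str]) -> List[str]:
    """Toy retriever: keep documents containing any query word (case-insensitive)."""
    words = query.lower().split()
    return [doc for doc in corpus if any(word in doc.lower() for word in words)]


def prefix_retriever(query: str, corpus: List[str]) -> List[str]:
    """Toy retriever: keep documents that start with the query (case-insensitive)."""
    return [doc for doc in corpus if doc.lower().startswith(query.lower())]


# Hypothetical registry: the "space" of implementations a config value can select from.
RETRIEVERS: Dict[str, Retriever] = {
    "keyword": keyword_retriever,
    "prefix": prefix_retriever,
}


def build_retriever(name: str = "keyword") -> Retriever:
    """Return the implementation registered under `name`."""
    if name not in RETRIEVERS:
        raise ValueError(f"Unknown retriever {name!r}; options: {sorted(RETRIEVERS)}")
    return RETRIEVERS[name]


corpus = [
    "Modular RAG splits a pipeline into modules",
    "Haystack builds pipelines",
    "Swapping modules should be cheap",
]
keyword_hits = build_retriever("keyword")("pipeline", corpus)
prefix_hits = build_retriever("prefix")("haystack", corpus)
```

Swapping the implementation is then a one-word change at the call site, which is the property the configs in this notebook aim for at a larger scale.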

# Setup

## Install Dependencies

This project supports multiple package managers. Choose the method that best suits your environment:

### uv

```bash
uv run --with jupyter jupyter lab
```

### Conda

```bash
conda create -n modular-rag python=3.10 -y
conda activate modular-rag
pip install -r requirements.txt
```

### pip

```bash
pip install -r requirements.txt
```

## Define Environment Variables

In [1]:
import getpass
import os

for var in ["ANTHROPIC_API_KEY", "OPENAI_API_KEY", "JINA_API_KEY"]:
    if var not in os.environ:
        os.environ[var] = getpass.getpass(f"Enter {var}: ")

## Set working directory to the project root

In [2]:
import os


# Find the directory containing pyproject.toml by searching upwards from the current directory
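def find_project_root(start_dir=None):
    # Resolve the default at call time; a default of os.getcwd() would be frozen at definition time
    current_dir = start_dir or os.getcwd()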
    while True:
        if "pyproject.toml" in os.listdir(current_dir):
            return current_dir
        parent_dir = os.path.dirname(current_dir)
        if parent_dir == current_dir:  # Reached the root directory
            raise FileNotFoundError("Could not find pyproject.toml in any parent directory.")
        current_dir = parent_dir


# Set the working directory to the project root
PROJECT_ROOT = find_project_root()
os.chdir(PROJECT_ROOT)

# Print the current working directory for confirmation
print("Current working directory:", os.getcwd())

Current working directory: /Users/giladrubin/python_workspace/modular-rag


## Load Environment Variables

In [3]:
from dotenv import load_dotenv

load_dotenv()

True

# Motivation

<img src="../assets/graphs/full_pipelines.webp" width="80%" style="display: block; margin: 0 auto;">

# LLM

In [4]:
from hypster import HP, config

In [5]:
@config
def llm_config(hp: HP):
    import os

    anthropic_models = {"haiku": "claude-3-haiku-20240307", "sonnet": "claude-3-5-sonnet-20240620"}
    openai_models = {"gpt-4o-mini": "gpt-4o-mini", "gpt-4o": "gpt-4o", "gpt-4o-latest": "gpt-4o-2024-08-06"}
    ollama_models = {"llama3": "llama3", "llama3.1": "llama3.1", "mistral": "mistral"}
    model_options = {**anthropic_models, **openai_models, **ollama_models}

    model = hp.select(model_options, default="gpt-4o-mini")
    temperature = hp.number(0.0, min=0.0, max=1.0)

    if model in openai_models.values():
        from haystack.components.generators import OpenAIGenerator

        llm = OpenAIGenerator(model=model, generation_kwargs={"temperature": temperature})
    elif model in anthropic_models.values():
        from haystack_integrations.components.generators.anthropic import AnthropicGenerator

        llm = AnthropicGenerator(model=model, generation_kwargs={"temperature": temperature})
    else:
        from haystack_integrations.components.generators.ollama import OllamaGenerator

        base_url = hp.text(os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434"), name="ollama_base_url")
        llm = OllamaGenerator(model=model, url=base_url, generation_kwargs={"temperature": temperature})


llm_config.save("configs/llm.py")


In [29]:
from hypster.ui import apply_vscode_theme, interactive_config

apply_vscode_theme()
results = interactive_config(llm_config)

VBox(children=(Dropdown(description='model', index=2, layout=Layout(min_width='300px', width='300px'), options…

# Indexing Pipeline

<img src="../assets/graphs/full_index.webp" width="50%" style="display: block; margin: 0 auto;">

In [7]:
from hypster import HP, config


@config
def indexing_config(hp: HP):
    from haystack import Pipeline
    from haystack.components.converters import PyPDFToDocument

    pipeline = Pipeline()
    pipeline.add_component("loader", PyPDFToDocument())

    enrich_doc_w_llm = hp.select([True, False], default=True)
    if enrich_doc_w_llm:
        from textwrap import dedent

        from haystack.components.builders import PromptBuilder

        from src.haystack_utils import AddLLMMetadata

        template = dedent("""
            Summarize the document's main topic in one sentence (15 words max). 
            Then list 3-5 keywords or acronyms that best \
            represent its content for search purposes.

            Context:
            {{ documents[0].content[:1000] }}
            
            ============================
            
            Output format:

            Summary:
            Keywords:
        """)

        llm = hp.nest("configs/llm.py")
        pipeline.add_component("prompt_builder", PromptBuilder(template=template))
        pipeline.add_component("llm", llm["llm"])
        pipeline.add_component("document_enricher", AddLLMMetadata())
        pipeline.connect("loader", "prompt_builder")
        pipeline.connect("prompt_builder", "llm")
        pipeline.connect("llm", "document_enricher")
        pipeline.connect("loader", "document_enricher")

        splitter_source = "document_enricher"
    else:
        splitter_source = "loader"

    from haystack.components.preprocessors import DocumentSplitter

    split_by = hp.select(["sentence", "word", "passage", "page"], default="sentence")
    splitter = DocumentSplitter(split_by=split_by, split_length=hp.int(10), split_overlap=hp.int(2))

    pipeline.add_component("splitter", splitter)
    pipeline.connect(splitter_source, "splitter")


indexing_config.save("configs/indexing.py")

In [8]:
results = interactive_config(indexing_config)

VBox(children=(Dropdown(description='enrich_doc_w_llm', layout=Layout(min_width='300px', width='300px'), optio…

In [9]:
results["pipeline"]

<haystack.core.pipeline.pipeline.Pipeline object at 0x31195c4f0>
🚅 Components
  - loader: PyPDFToDocument
  - prompt_builder: PromptBuilder
  - llm: OpenAIGenerator
  - document_enricher: AddLLMMetadata
  - splitter: DocumentSplitter
🛤️ Connections
  - loader.documents -> prompt_builder.documents (List[Document])
  - loader.documents -> document_enricher.documents (List[Document])
  - prompt_builder.prompt -> llm.prompt (str)
  - llm.replies -> document_enricher.replies (List[str])
  - document_enricher.documents -> splitter.documents (List[Document])
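
To make the `split_length` and `split_overlap` parameters more concrete, here is a minimal word-based sketch of sliding-window splitting in plain Python. It is an illustration under simplifying assumptions, not Haystack's `DocumentSplitter` implementation: the helper name `split_words` is hypothetical, and edge-case behavior may differ from the library.

```python
from typing import List


def split_words(text: str, split_length: int = 10, split_overlap: int = 2) -> List[str]:
    """Split text into chunks of `split_length` words, where consecutive
    chunks share `split_overlap` words."""
    if split_length <= 0:
        raise ValueError("split_length must be positive")
    if not 0 <= split_overlap < split_length:
        raise ValueError("split_overlap must satisfy 0 <= split_overlap < split_length")

    words = text.split()
    step = split_length - split_overlap  # how far the window advances each time
    chunks: List[str] = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start : start + split_length]))
        if start + split_length >= len(words):
            break  # this window already reached the end of the text
    return chunks


chunks = split_words("a b c d e f g h", split_length=4, split_overlap=1)
```

With these values the window advances three words at a time, so each chunk repeats the last word of the previous one.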

# Embedding

In [10]:
@config
def fast_embed(hp: HP):
    from typing import Any, Dict, List

    from fastembed import TextEmbedding

    def get_model_dim(chosen_model: str, model_list: List[Dict[str, Any]]) -> int:
        for model in model_list:
            if model["model"] == chosen_model:
                return model["dim"]
        raise ValueError(f"Model {chosen_model} not found in the list of supported models.")

    from haystack_integrations.components.embedders.fastembed import (
        FastembedDocumentEmbedder,
        FastembedTextEmbedder,
    )

    meta_fields_to_embed = ["parent_doc_summary"]

    model = hp.select(
        {"bge-small": "BAAI/bge-small-en-v1.5", "mini-lm": "sentence-transformers/all-MiniLM-L6-v2"},
        default="mini-lm",
    )
    import os

    cpu_count = os.cpu_count() or 1
    doc_embedder = FastembedDocumentEmbedder(
        model=model,
        parallel=hp.int(cpu_count),
        meta_fields_to_embed=meta_fields_to_embed,
    )
    text_embedder = FastembedTextEmbedder(model=model)
    embedding_dim = get_model_dim(model, TextEmbedding.list_supported_models())


fast_embed.save("configs/fast_embed.py")

In [11]:
results = interactive_config(fast_embed)

VBox(children=(Dropdown(description='model', index=1, layout=Layout(min_width='300px', width='300px'), options…

In [12]:
@config
def jina_embed(hp: HP):
    from haystack_integrations.components.embedders.jina import JinaDocumentEmbedder, JinaTextEmbedder

    meta_fields_to_embed = ["parent_doc_summary"]

    model = hp.select({"v3": "jina-embeddings-v3", "v2": "jina-embeddings-v2"}, default="v3")
    if model == "jina-embeddings-v3":
        late_chunking = hp.select([True, False], default=True, name="late_chunking")
    else:
        late_chunking = False

    doc_embedder = JinaDocumentEmbedder(
        model=model,
        batch_size=hp.int(16),
        dimensions=hp.int(256),
        task="retrieval.passage",
        late_chunking=late_chunking,
        meta_fields_to_embed=meta_fields_to_embed,
    )
    text_embedder = JinaTextEmbedder(model=model, dimensions=doc_embedder.dimensions, task="retrieval.query")
    embedding_dim = doc_embedder.dimensions


jina_embed.save("configs/jina_embed.py")

In [13]:
results = interactive_config(jina_embed)

VBox(children=(Dropdown(description='model', layout=Layout(min_width='300px', width='300px'), options=('v3', '…

# Retrieval

<img src="../assets/graphs/fusion_retriever.webp" width="60%">

In [14]:
@config
def in_memory_retrieval(hp: HP):
    from haystack import Pipeline
    from haystack.document_stores.in_memory import InMemoryDocumentStore

    from src.haystack_utils import PassThroughDocuments, PassThroughText

    pipeline = Pipeline()
    # utility components for the first and last parts of the pipeline
    pipeline.add_component("query", PassThroughText())
    pipeline.add_component("retrieved_documents", PassThroughDocuments())

    document_store = InMemoryDocumentStore()
    from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever

    embedding_similarity_function = hp.select(["cosine", "dot_product"], default="cosine")
    document_store.embedding_similarity_function = embedding_similarity_function
    pipeline.add_component("embedding_retriever", InMemoryEmbeddingRetriever(document_store=document_store))

    use_bm25 = hp.bool(True)
    if use_bm25:
        from haystack.components.joiners.document_joiner import DocumentJoiner
        from haystack.components.retrievers.in_memory import InMemoryBM25Retriever

        bm25_algorithm = hp.select(["BM25Okapi", "BM25L", "BM25Plus"], default="BM25L")
        document_store.bm25_algorithm = bm25_algorithm
        pipeline.add_component("bm25_retriever", InMemoryBM25Retriever(document_store=document_store))
        pipeline.connect("query", "bm25_retriever")

        join_mode = hp.select(
            ["distribution_based_rank_fusion", "concatenate", "merge", "reciprocal_rank_fusion"],
            default="distribution_based_rank_fusion",
        )
        bm25_weight = hp.number(0.5)
        joiner = DocumentJoiner(join_mode=join_mode, top_k=hp.int(10), weights=[bm25_weight, 1 - bm25_weight])

        pipeline.add_component("document_joiner", joiner)
        pipeline.connect("bm25_retriever", "document_joiner")
        pipeline.connect("embedding_retriever", "document_joiner")
        pipeline.connect("document_joiner", "retrieved_documents")
    else:
        pipeline.connect("embedding_retriever", "retrieved_documents")


in_memory_retrieval.save("configs/in_memory_retrieval.py")

In [15]:
results = interactive_config(in_memory_retrieval)

VBox(children=(Dropdown(description='embedding_similarity_function', layout=Layout(min_width='300px', width='3…
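
One of the `join_mode` options above is `reciprocal_rank_fusion`. The sketch below shows the textbook form of that technique in plain Python: each document scores `weight / (k + rank)` in every ranking it appears in, and the scores are summed. The function name, the constant `k = 60`, and the tie-breaking rule are assumptions for illustration; Haystack's `DocumentJoiner` may use a different constant or score scaling.

```python
from collections import defaultdict
from typing import Dict, List, Optional


def reciprocal_rank_fusion(
    rankings: List[List[str]],
    k: int = 60,
    weights: Optional[List[float]] = None,
    top_k: Optional[int] = None,
) -> List[str]:
    """Fuse several ranked lists of document ids into one ranked list."""
    if weights is None:
        weights = [1.0] * len(rankings)
    if len(weights) != len(rankings):
        raise ValueError("weights must have one entry per ranking")

    scores: Dict[str, float] = defaultdict(float)
    for ranking, weight in zip(rankings, weights):
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += weight / (k + rank)

    # Highest score first; ties are broken by document id to stay deterministic.
    fused = sorted(scores, key=lambda doc_id: (-scores[doc_id], doc_id))
    return fused[:top_k]


bm25_ranking = ["d1", "d2", "d3"]
embedding_ranking = ["d3", "d1", "d4"]
fused = reciprocal_rank_fusion([bm25_ranking, embedding_ranking])
fused_weighted = reciprocal_rank_fusion(
    [bm25_ranking, embedding_ranking], weights=[0.8, 0.2], top_k=2
)
```

Because only ranks are used, the two retrievers do not need comparable score scales, which is convenient when mixing BM25 scores with embedding similarities.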

In [16]:
@config
def qdrant_retrieval(hp: HP):
    from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever
    from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

    # location = hp.text(":memory:")
    embedding_similarity_function = hp.select(["cosine", "dot_product", "l2"], default="cosine")

    document_store = QdrantDocumentStore(
        # location=location,
        recreate_index=True,
        similarity=embedding_similarity_function,
        embedding_dim=hp.int(256, name="embedding_dim"),
        on_disk=True,
        path="qdrant/data",
    )

    embedding_retriever = QdrantEmbeddingRetriever(document_store=document_store, top_k=hp.int(20))

    from haystack import Pipeline

    from src.haystack_utils import PassThroughDocuments, PassThroughText

    pipeline = Pipeline()
    pipeline.add_component("query", PassThroughText())
    pipeline.add_component("embedding_retriever", embedding_retriever)
    pipeline.add_component("retrieved_documents", PassThroughDocuments())
    pipeline.connect("embedding_retriever", "retrieved_documents")


qdrant_retrieval.save("configs/qdrant_retrieval.py")

In [17]:
results = interactive_config(qdrant_retrieval)

VBox(children=(Dropdown(description='embedding_similarity_function', layout=Layout(min_width='300px', width='3…
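
The three similarity options offered above treat vector length differently. The following plain-Python sketch uses the standard definitions to show the contrast; the helper names are hypothetical, and this is not how Qdrant computes them internally.

```python
import math
from typing import Sequence


def dot_product(a: Sequence[float], b: Sequence[float]) -> float:
    """Sum of element-wise products; grows with vector magnitude."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return sum(x * y for x, y in zip(a, b))


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product of the normalized vectors; ignores magnitude."""
    norm = math.sqrt(dot_product(a, a)) * math.sqrt(dot_product(b, b))
    if norm == 0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot_product(a, b) / norm


def l2_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance; unlike the two scores above, lower means closer."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


query = [1.0, 0.0]
same_direction = [2.0, 0.0]  # same direction as the query, twice the length
orthogonal = [0.0, 3.0]
```

For `query` and `same_direction`, cosine similarity is 1.0 even though the dot product is 2.0 and the L2 distance is 1.0, so the choice matters mainly when embeddings are not normalized.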

# Reranker

In [18]:
from hypster import HP, config


@config
def reranker(hp: HP):
    jina_models = {
        "jina-reranker-v2": "jina-reranker-v2-base-multilingual",
        "jina-colbert-v2": "jina-colbert-v2",
        "jina-reranker-v1": "jina-reranker-v1-base-en",
    }

    transformers_models = {
        "tiny-bert-v2": "cross-encoder/ms-marco-TinyBERT-L-2-v2",
        "minilm-v2": "cross-encoder/ms-marco-MiniLM-L-2-v2",
    }

    model = hp.select({**jina_models, **transformers_models}, default="jina-reranker-v2")
    top_k = hp.int(3)
    if model in jina_models.values():
        from haystack_integrations.components.rankers.jina import JinaRanker

        reranker = JinaRanker(model=model, top_k=top_k)
    else:
        from haystack.components.rankers import TransformersSimilarityRanker

        reranker = TransformersSimilarityRanker(model=model, top_k=top_k)


reranker.save("configs/reranker.py")

In [19]:
results = interactive_config(reranker)

VBox(children=(Dropdown(description='model', layout=Layout(min_width='300px', width='300px'), options=('jina-r…

# Response

In [20]:
@config
def response_config(hp: HP):
    from textwrap import dedent

    llm = hp.nest("configs/llm.py")
    llm = llm["llm"]

    template = dedent("""\
    Given the following information,
    answer the question concisely in one to two sentences,
    using only the relevant details provided in the documents.
    Support your answer with a brief, word-for-word quote from the most pertinent document. 
    Note that some documents may not be relevant to the question.
    ========================================
    Context:
    {% for document in documents %}
    Document {{loop.index}}:
    {{ document.meta.llm_extracted_info }}
    {{ document.content }}
    ---
    {% endfor %}
    ========================================
    Question: {{query}}

    Answer:

    Supporting Quote:
    """)


response_config.save("configs/response.py")

In [21]:
results = interactive_config(response_config)

VBox(children=(VBox(children=(HTML(value="<span style='font-size: 1.2em;                 font-weight: bold; ma…

# Modular RAG

In [22]:
@config
def modular_rag(hp: HP):
    indexing = hp.nest("configs/indexing.py")
    indexing_pipeline = indexing["pipeline"]

    embedder_type = hp.select(["fastembed", "jina"], default="fastembed")
    match embedder_type:
        case "fastembed":
            embedder = hp.nest("configs/fast_embed.py")
        case "jina":
            embedder = hp.nest("configs/jina_embed.py")

    indexing_pipeline.add_component("doc_embedder", embedder["doc_embedder"])

    document_store_type = hp.select(["in_memory", "qdrant"], default="in_memory")
    match document_store_type:
        case "in_memory":
            retrieval = hp.nest("configs/in_memory_retrieval.py")
        case "qdrant":
            retrieval = hp.nest("configs/qdrant_retrieval.py", values={"embedding_dim": embedder["embedding_dim"]})

    from haystack.components.writers import DocumentWriter
    from haystack.document_stores.types import DuplicatePolicy

    document_writer = DocumentWriter(retrieval["document_store"], policy=DuplicatePolicy.OVERWRITE)
    indexing_pipeline.add_component("document_writer", document_writer)

    indexing_pipeline.connect("splitter", "doc_embedder")
    indexing_pipeline.connect("doc_embedder", "document_writer")

    pipeline = retrieval["pipeline"]

    pipeline.add_component("text_embedder", embedder["text_embedder"])
    pipeline.connect("query", "text_embedder")
    pipeline.connect("text_embedder", "embedding_retriever.query_embedding")

    from src.haystack_utils import PassThroughDocuments

    pipeline.add_component("docs_for_generation", PassThroughDocuments())
    use_reranker = hp.select([True, False], default=True)
    if use_reranker:
        reranker = hp.nest("configs/reranker.py")
        pipeline.add_component("reranker", reranker["reranker"])
        pipeline.connect("retrieved_documents", "reranker")
        pipeline.connect("reranker", "docs_for_generation")
        pipeline.connect("query", "reranker")
    else:
        pipeline.connect("retrieved_documents", "docs_for_generation")

    response = hp.nest("configs/response.py")
    from haystack.components.builders import PromptBuilder

    pipeline.add_component("prompt_builder", PromptBuilder(template=response["template"]))
    pipeline.add_component("llm", response["llm"])
    pipeline.connect("prompt_builder", "llm")
    pipeline.connect("query.text", "prompt_builder.query")
    pipeline.connect("docs_for_generation", "prompt_builder")


modular_rag.save("configs/modular_rag.py")

In [23]:
results = modular_rag(
    values={
        "indexing.enrich_doc_w_llm": True,
        "indexing.llm.model": "gpt-4o-mini",
        "document_store_type": "in_memory",
        "retrieval.bm25_weight": 0.8,
        "embedder_type": "fastembed",
        "reranker.model": "tiny-bert-v2",
        "response.llm.model": "gpt-4o-mini",
        "indexing.splitter.split_length": 6,
        "reranker.top_k": 3,
    },
)
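
The dotted keys in `values` address parameters inside nested configs, for example `indexing.llm.model` for the `model` parameter of the LLM config nested in the indexing config. As a rough mental model only, the sketch below expands such flat keys into a nested dictionary. The helper `nest_overrides` is hypothetical and is not part of Hypster, whose internal handling may differ.

```python
from typing import Any, Dict


def nest_overrides(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Expand {"a.b.c": 1} into {"a": {"b": {"c": 1}}}."""
    nested: Dict[str, Any] = {}
    for dotted_key, value in flat.items():
        *parents, leaf = dotted_key.split(".")
        node = nested
        for part in parents:
            child = node.setdefault(part, {})
            if not isinstance(child, dict):
                raise ValueError(f"{dotted_key!r} conflicts with a value set at {part!r}")
            node = child
        if isinstance(node.get(leaf), dict):
            raise ValueError(f"{dotted_key!r} conflicts with nested keys under {leaf!r}")
        node[leaf] = value
    return nested


overrides = nest_overrides(
    {
        "indexing.enrich_doc_w_llm": True,
        "indexing.llm.model": "gpt-4o-mini",
        "embedder_type": "fastembed",
    }
)
```

Reading the keys this way makes it easier to see which nested config each override belongs to.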

In [24]:
from hypster.ui import apply_vscode_theme, interactive_config

apply_vscode_theme()
results = interactive_config(modular_rag)

VBox(children=(VBox(children=(HTML(value="<span style='font-size: 1.2em;                 font-weight: bold; ma…

In [25]:
indexing_pipeline = results["indexing_pipeline"]
indexing_pipeline.warm_up()

file_paths = ["data/raw/modular_rag.pdf"]
for file_path in file_paths:  # this can be parallelized
    indexing_pipeline.run({"loader": {"sources": [file_path]}})

Fetching 5 files: 100%|██████████| 5/5 [00:00<00:00, 67432.54it/s]
Calculating embeddings: 100%|██████████| 312/312 [00:14<00:00, 21.02it/s]


In [26]:
query = "What are the 6 main modules of the modular RAG framework?"

pipeline = results["pipeline"]
pipeline.warm_up()
response = pipeline.run({"query": {"text": query}}, include_outputs_from=["prompt_builder", "docs_for_generation"])

Calculating embeddings: 100%|██████████| 1/1 [00:00<00:00, 45.74it/s]


In [27]:
print(response["llm"]["replies"][0])

The six main modules of the modular RAG framework are Indexing, Pre-retrieval, Retrieval, Post-retrieval, Generation, and Orchestration. 

Supporting Quote: "we have established six main modules: Indexing, Pre-retrieval, Retrieval, Post-retrieval, Generation, and Orchestration."
