# **Modular RAG** with Haystack and Hypster

This project implements the concepts described in the paper: [Modular RAG: Transforming RAG Systems into LEGO-like Reconfigurable Frameworks](https://arxiv.org/abs/2407.21059) by Yunfan Gao et al. 

For a detailed walkthrough, refer to this [Medium article](https://medium.com/p/d2f0ecc88b8f).

## Key Objectives

- Decompose RAG (Retrieval-Augmented Generation) into its fundamental components using **Haystack**.
- Utilize **Hypster** to manage a "hyper-space" of potential RAG configurations.
- Facilitate easy swapping and experimentation with various implementations.

# Setup

## Install Dependencies

This project supports multiple package managers. Choose the method that best suits your environment:

### Poetry

```bash
poetry install
```

### pip

```bash
pip install -r requirements.txt
```

### Conda

```bash
conda create -n modular-rag python=3.10 -y
conda activate modular-rag
pip install -r requirements.txt
```

## Set Current Working Directory

In [1]:
import os


# Find the directory containing pyproject.toml by searching upwards from a start directory
def find_project_root(start_dir: "str | os.PathLike[str] | None" = None) -> str:
    """Return the nearest directory (itself or an ancestor) containing ``pyproject.toml``.

    Args:
        start_dir: Directory to start searching from. Defaults to the current
            working directory *at call time* (a ``os.getcwd()`` default would be
            frozen at definition time). Relative paths are resolved against the
            current working directory.

    Returns:
        Absolute path of the first directory, walking upwards, that holds a
        ``pyproject.toml`` file.

    Raises:
        FileNotFoundError: If no directory up to the filesystem root has one.
    """
    # abspath() also makes relative inputs like "." walk correctly; dirname(".") is "".
    current_dir = os.path.abspath(os.getcwd() if start_dir is None else start_dir)
    while True:
        if os.path.isfile(os.path.join(current_dir, "pyproject.toml")):
            return current_dir
        parent_dir = os.path.dirname(current_dir)
        if parent_dir == current_dir:  # Reached the filesystem root
            raise FileNotFoundError("Could not find pyproject.toml in any parent directory.")
        current_dir = parent_dir


# Make every relative path used below (configs/, data/, qdrant/) resolve from the
# project root, wherever the notebook itself was launched from.
PROJECT_ROOT = find_project_root()
os.chdir(PROJECT_ROOT)

# Echo the resulting working directory as a sanity check.
print(f"Current working directory: {os.getcwd()}")

Current working directory: /Users/giladrubin/python_workspace/modular-rag


# Config

In [2]:
from hypster import HP, config

In [3]:
@config
def llm_config(hp: HP):
    # LLM config: picks one generator (OpenAI or Anthropic) by short model name.
    # NOTE: hypster derives hyperparameter names ("model", "temperature") and
    # result keys from these local names; "llm" is read by the indexing and
    # response configs, so do not rename locals.
    anthropic_models = {"haiku": "claude-3-haiku-20240307", "sonnet": "claude-3-5-sonnet-20240620"}
    openai_models = {"gpt-4o-mini": "gpt-4o-mini", "gpt-4o": "gpt-4o", "gpt-4o-latest": "gpt-4o-2024-08-06"}
    model_options = {**anthropic_models, **openai_models}

    # Selecting from a dict yields the mapped model id (not the short key),
    # which is why the provider check below compares against .values().
    model = hp.select(model_options, default="gpt-4o-mini")
    temperature = hp.number_input(0.0)

    if model in openai_models.values():
        from haystack.components.generators import OpenAIGenerator

        llm = OpenAIGenerator(model=model, generation_kwargs={"temperature": temperature})
    else:
        # Any non-OpenAI id falls through to Anthropic.
        from haystack_integrations.components.generators.anthropic import AnthropicGenerator

        llm = AnthropicGenerator(model=model, generation_kwargs={"temperature": temperature})


llm_config.save("configs/llm.py")

In [4]:
from hypster import HP, config


@config
def indexing_config(hp: HP):
    # Indexing config: PDF loader -> (optional LLM enrichment) -> DocumentSplitter.
    # The document embedder and writer are appended later by modular_rag.
    # NOTE: hypster derives hyperparameter names and result keys from local
    # variable names (modular_rag reads "pipeline"), so do not rename locals.
    from haystack import Pipeline
    from haystack.components.converters import PyPDFToDocument

    pipeline = Pipeline()
    pipeline.add_component("loader", PyPDFToDocument())

    # Optionally ask an LLM for a one-sentence summary plus keywords per loaded file.
    enrich_doc_w_llm = hp.select([True, False], default=True)
    if enrich_doc_w_llm:
        from textwrap import dedent

        from haystack.components.builders import PromptBuilder

        from src.haystack_utils import AddLLMMetadata

        # The prompt only sees the first 1000 characters of documents[0], so each
        # run is presumably meant to load a single file (the notebook indexes one
        # file per run). The backslash below joins two source lines inside the string.
        template = dedent("""
            Summarize the document's main topic in one sentence (15 words max). 
            Then list 3-5 keywords or acronyms that best \
            represent its content for search purposes.

            Context:
            {{ documents[0].content[:1000] }}
            
            ============================
            
            Output format:

            Summary:
            Keywords:
        """)

        # Nested LLM config; its choices are namespaced under this config's
        # prefix (e.g. "indexing.llm.model" in the run cell).
        llm = hp.propagate("configs/llm.py")
        pipeline.add_component("prompt_builder", PromptBuilder(template=template))
        pipeline.add_component("llm", llm["llm"])
        # NOTE(review): AddLLMMetadata is defined in src.haystack_utils; it receives
        # both the LLM replies and the original documents and presumably attaches the
        # reply to each document's meta — confirm which meta key it writes.
        pipeline.add_component("document_enricher", AddLLMMetadata())
        pipeline.connect("loader", "prompt_builder")
        pipeline.connect("prompt_builder", "llm")
        pipeline.connect("llm", "document_enricher")
        pipeline.connect("loader", "document_enricher")

        splitter_source = "document_enricher"
    else:
        splitter_source = "loader"

    from haystack.components.preprocessors import DocumentSplitter

    # Chunking granularity; split_length/split_overlap default to 10 and 2 units of split_by.
    split_by = hp.select(["sentence", "word", "passage", "page"], default="sentence")
    splitter = DocumentSplitter(split_by=split_by, split_length=hp.int_input(10), split_overlap=hp.int_input(2))

    pipeline.add_component("splitter", splitter)
    # Split either the enriched documents or the raw loader output.
    pipeline.connect(splitter_source, "splitter")


indexing_config.save("configs/indexing.py")

In [5]:
@config
def fast_embed(hp: HP):
    # Local FastEmbed embedders: a document/query embedder pair plus the model's
    # vector size ("embedding_dim", which modular_rag forwards to Qdrant).
    # NOTE: hypster derives hyperparameter names and result keys from local
    # variable names, so do not rename locals (including the misspelled one).
    from typing import Any, Dict, List

    from fastembed import TextEmbedding

    # Look up the output dimension of `chosen_model` in fastembed's supported-model
    # list; raises ValueError if the model is not listed.
    def get_model_dim(chosen_model: str, model_list: List[Dict[str, Any]]) -> int:
        for model in model_list:
            if model["model"] == chosen_model:
                return model["dim"]
        raise ValueError(f"Model {chosen_model} not found in the list of supported models.")

    from haystack_integrations.components.embedders.fastembed import (
        FastembedDocumentEmbedder,
        FastembedTextEmbedder,
    )

    # Meta fields embedded alongside each chunk's content.
    # NOTE(review): the response template reads meta.llm_extracted_info — confirm
    # which key AddLLMMetadata actually writes and keep both sides consistent.
    meta_fileds_to_embed = ["parent_doc_summary"]

    # Selecting from a dict yields the full model id (the value).
    model = hp.select(
        {"bge-small": "BAAI/bge-small-en-v1.5", "mini-lm": "sentence-transformers/all-MiniLM-L6-v2"},
        default="mini-lm",
    )
    import os

    # Default parallelism to the machine's CPU count (at least 1).
    cpu_count = os.cpu_count() or 1
    doc_embedder = FastembedDocumentEmbedder(
        model=model,
        parallel=hp.int_input(cpu_count),
        meta_fields_to_embed=meta_fileds_to_embed,
    )
    # Queries are embedded with the same model so vectors are comparable.
    text_embedder = FastembedTextEmbedder(model=model)
    embedding_dim = get_model_dim(model, TextEmbedding.list_supported_models())


fast_embed.save("configs/fast_embed.py")

In [6]:
@config
def jina_embed(hp: HP):
    # Jina-hosted embedders: a document/query embedder pair sharing model and
    # output dimensions, plus "embedding_dim" for the vector store.
    # NOTE: hypster derives hyperparameter names and result keys from local
    # variable names, so do not rename locals (including the misspelled one).
    from haystack_integrations.components.embedders.jina import JinaDocumentEmbedder, JinaTextEmbedder

    # Meta fields embedded alongside each chunk's content.
    meta_fileds_to_embed = ["parent_doc_summary"]

    # hp.select over a dict yields the mapped model id (e.g. "jina-embeddings-v3"),
    # not the short key.
    # NOTE(review): Jina's v2 models are usually addressed with a size/language
    # suffix (e.g. jina-embeddings-v2-base-en) and may not accept dimensions/task —
    # confirm before selecting "v2".
    model = hp.select({"v3": "jina-embeddings-v3", "v2": "jina-embeddings-v2"}, default="v3")
    # Late chunking is only offered for v3. Compare against the full model id; the
    # former `model == "v3"` check never matched, so late chunking was always off.
    late_chunking = hp.select([True, False], default=True, name="late_chunking") if model == "jina-embeddings-v3" else False
    doc_embedder = JinaDocumentEmbedder(
        model=model,
        batch_size=hp.int_input(16),
        dimensions=hp.int_input(256),
        task="retrieval.passage",
        late_chunking=late_chunking,
        meta_fields_to_embed=meta_fileds_to_embed,
    )
    # Query side must use the same model and dimensions as the documents.
    text_embedder = JinaTextEmbedder(model=model, dimensions=doc_embedder.dimensions, task="retrieval.query")
    embedding_dim = doc_embedder.dimensions


jina_embed.save("configs/jina_embed.py")

In [7]:
@config
def in_memory_retrieval(hp: HP):
    # In-memory retrieval config: BM25 and/or embedding retrieval over one
    # InMemoryDocumentStore, fused with a DocumentJoiner when both are enabled.
    # NOTE: hypster derives hyperparameter names and result keys from local
    # variable names (modular_rag reads "pipeline" and "document_store").
    from haystack import Pipeline
    from haystack.document_stores.in_memory import InMemoryDocumentStore

    from src.haystack_utils import PassThroughDocuments, PassThroughText

    pipeline = Pipeline()
    # utility components for the first and last parts of the pipeline
    pipeline.add_component("query", PassThroughText())
    pipeline.add_component("retrieved_documents", PassThroughDocuments())

    retrieval_types = hp.multi_select(["bm25", "embeddings"], default=["bm25", "embeddings"])
    if not retrieval_types:
        raise ValueError("At least one retrieval type must be selected.")

    # Pick the store settings first and pass them to the constructor: the store
    # resolves its BM25 scorer at construction time, so assigning bm25_algorithm
    # afterwards silently kept the default algorithm.
    document_store_kwargs = {}
    if "embeddings" in retrieval_types:
        embedding_similarity_function = hp.select(["cosine", "dot_product"], default="cosine")
        document_store_kwargs["embedding_similarity_function"] = embedding_similarity_function
    if "bm25" in retrieval_types:
        bm25_algorithm = hp.select(["BM25Okapi", "BM25L", "BM25Plus"], default="BM25L")
        document_store_kwargs["bm25_algorithm"] = bm25_algorithm
    document_store = InMemoryDocumentStore(**document_store_kwargs)

    if "embeddings" in retrieval_types:
        from haystack.components.retrievers.in_memory import InMemoryEmbeddingRetriever

        # Its query_embedding input is connected later by modular_rag.
        pipeline.add_component("embedding_retriever", InMemoryEmbeddingRetriever(document_store=document_store))

    if "bm25" in retrieval_types:
        from haystack.components.retrievers.in_memory import InMemoryBM25Retriever

        pipeline.add_component("bm25_retriever", InMemoryBM25Retriever(document_store=document_store))
        pipeline.connect("query", "bm25_retriever")

    if len(retrieval_types) == 2:  # both bm25 and embeddings
        from haystack.components.joiners.document_joiner import DocumentJoiner

        join_mode = hp.select(
            ["distribution_based_rank_fusion", "concatenate", "merge", "reciprocal_rank_fusion"],
            default="distribution_based_rank_fusion",
        )
        # NOTE(review): weights assume the joiner receives the BM25 list first and the
        # embedding list second — confirm the joiner's input ordering.
        bm25_weight = hp.number_input(0.5)
        joiner = DocumentJoiner(join_mode=join_mode, top_k=hp.int_input(10), weights=[bm25_weight, 1 - bm25_weight])

        pipeline.add_component("document_joiner", joiner)
        pipeline.connect("bm25_retriever", "document_joiner")
        pipeline.connect("embedding_retriever", "document_joiner")
        pipeline.connect("document_joiner", "retrieved_documents")
    elif "embeddings" in retrieval_types:
        pipeline.connect("embedding_retriever", "retrieved_documents")
    else:  # only bm25
        pipeline.connect("bm25_retriever", "retrieved_documents")


in_memory_retrieval.save("configs/in_memory_retrieval.py")

In [8]:
@config
def qdrant_retrieval(hp: HP):
    # Qdrant retrieval config: embedding-only retrieval over a local, on-disk
    # Qdrant store under qdrant/data.
    # NOTE: hypster derives hyperparameter names and result keys from local
    # variable names (modular_rag reads "pipeline" and "document_store").
    from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever
    from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

    # location = hp.text_input(":memory:")
    embedding_similarity_function = hp.select(["cosine", "dot_product", "l2"], default="cosine")

    # NOTE(review): recreate_index=True drops and rebuilds the collection whenever
    # this config is instantiated, so previously indexed data presumably does not
    # survive a rebuild — confirm this is intended.
    document_store = QdrantDocumentStore(
        # location=location,
        recreate_index=True,
        similarity=embedding_similarity_function,
        # Named explicitly so modular_rag can override it with the embedder's dimension.
        embedding_dim=hp.int_input(256, name="embedding_dim"),
        on_disk=True,
        path="qdrant/data",
    )

    embedding_retriever = QdrantEmbeddingRetriever(document_store=document_store, top_k=hp.int_input(20))

    from haystack import Pipeline

    from src.haystack_utils import PassThroughDocuments, PassThroughText

    pipeline = Pipeline()
    # "query" is left unconnected here; modular_rag feeds it through text_embedder
    # into embedding_retriever.query_embedding.
    pipeline.add_component("query", PassThroughText())
    pipeline.add_component("embedding_retriever", embedding_retriever)
    pipeline.add_component("retrieved_documents", PassThroughDocuments())
    pipeline.connect("embedding_retriever", "retrieved_documents")


qdrant_retrieval.save("configs/qdrant_retrieval.py")

In [9]:
@config
def reranker(hp: HP):
    # Reranker config: Jina-hosted rerankers or local cross-encoders, keeping top_k docs.
    # NOTE: the local "reranker" below is the result key modular_rag reads;
    # hypster derives names from locals, so do not rename them.
    jina_models = {
        "reranker-v2": "jina-reranker-v2-base-multilingual",
        "colbert-v2": "jina-colbert-v2",
        "reranker-v1": "jina-reranker-v1-base-en",
    }

    transformers_models = {
        "tiny-bert-v2": "cross-encoder/ms-marco-TinyBERT-L-2-v2",
        "minilm-v2": "cross-encoder/ms-marco-MiniLM-L-2-v2",
    }

    # Selecting from the merged dict yields the full model id (the value).
    model = hp.select({**jina_models, **transformers_models}, default="reranker-v2")
    top_k = hp.int_input(3)
    if model in jina_models.values():
        from haystack_integrations.components.rankers.jina import JinaRanker

        reranker = JinaRanker(model=model, top_k=top_k)
    else:
        from haystack.components.rankers import TransformersSimilarityRanker

        reranker = TransformersSimilarityRanker(model=model, top_k=top_k)


reranker.save("configs/reranker.py")

In [10]:
@config
def response_config(hp: HP):
    # Response config: the answer-generation LLM plus a Jinja prompt template that
    # expects "documents" and "query" variables.
    # NOTE: modular_rag reads the result keys "llm" and "template"; hypster
    # derives them from these local names, so do not rename them.
    from textwrap import dedent

    llm = hp.propagate("configs/llm.py")
    # Replace the nested config's results with the generator component itself.
    llm = llm["llm"]

    # NOTE(review): each document is rendered with meta.llm_extracted_info, while the
    # embedders embed meta "parent_doc_summary" — confirm which key AddLLMMetadata sets.
    template = dedent("""\
    Given the following information,
    answer the question concisely in one to two sentences,
    using only the relevant details provided in the documents.
    Support your answer with a brief, word-for-word quote from the most pertinent document. 
    Note that some documents may not be relevant to the question.
    ========================================
    Context:
    {% for document in documents %}
    Document {{loop.index}}:
    {{ document.meta.llm_extracted_info }}
    {{ document.content }}
    ---
    {% endfor %}
    ========================================
    Question: {{query}}

    Answer:

    Supporting Quote:
    """)


response_config.save("configs/response.py")

In [11]:
@config
def modular_rag(hp: HP):
    indexing = hp.propagate("configs/indexing.py")
    indexing_pipeline = indexing["pipeline"]

    embedder_type = hp.select(["fastembed", "jina"], default="fastembed")
    match embedder_type:
        case "fastembed":
            embedder = hp.propagate("configs/fast_embed.py")
        case "jina":
            embedder = hp.propagate("configs/jina_embed.py")

    indexing_pipeline.add_component("doc_embedder", embedder["doc_embedder"])

    document_store_type = hp.select(["in_memory", "qdrant"], default="in_memory")
    match document_store_type:
        case "in_memory":
            retrieval = hp.propagate("configs/in_memory_retrieval.py")
        case "qdrant":
            retrieval = hp.propagate(
                "configs/qdrant_retrieval.py", overrides={"embedding_dim": embedder["embedding_dim"]}
            )

    from haystack.components.writers import DocumentWriter
    from haystack.document_stores.types import DuplicatePolicy

    document_writer = DocumentWriter(retrieval["document_store"], policy=DuplicatePolicy.OVERWRITE)
    indexing_pipeline.add_component("document_writer", document_writer)

    indexing_pipeline.connect("splitter", "doc_embedder")
    indexing_pipeline.connect("doc_embedder", "document_writer")

    pipeline = retrieval["pipeline"]
    pipeline.add_component("text_embedder", embedder["text_embedder"])
    pipeline.connect("query", "text_embedder")
    pipeline.connect("text_embedder", "embedding_retriever.query_embedding")

    from src.haystack_utils import PassThroughDocuments

    pipeline.add_component("docs_for_generation", PassThroughDocuments())
    use_reranker = hp.select([True, False], default=True)
    if use_reranker:
        reranker = hp.propagate("configs/reranker.py")
        pipeline.add_component("reranker", reranker["reranker"])
        pipeline.connect("retrieved_documents", "reranker")
        pipeline.connect("reranker", "docs_for_generation")
        pipeline.connect("query", "reranker")
    else:
        pipeline.connect("retrieved_documents", "docs_for_generation")

    response = hp.propagate("configs/response.py")
    from haystack.components.builders import PromptBuilder

    pipeline.add_component("prompt_builder", PromptBuilder(template=response["template"]))
    pipeline.add_component("llm", response["llm"])
    pipeline.connect("prompt_builder", "llm")
    pipeline.connect("query.text", "prompt_builder.query")
    pipeline.connect("docs_for_generation", "prompt_builder")

modular_rag.save("configs/modular_rag.py")

In [13]:
# Concrete point in the configuration space. Nested choices are addressed by
# their propagation path (e.g. "indexing.llm.model").
# NOTE: "retrieval.bm25_weight" only exists in the in-memory retrieval config,
# so it is not used while document_store_type is "qdrant".
selections = {
    "indexing.enrich_doc_w_llm": True,
    "indexing.llm.model": "gpt-4o-mini",
    "document_store_type": "qdrant",
    "retrieval.bm25_weight": 0.8,
    "embedder_type": "fastembed",
    "reranker.model": "tiny-bert-v2",
    "response.llm.model": "haiku",
}
overrides = {"indexing.splitter.split_length": 6, "reranker.top_k": 3}

results = modular_rag(selections=selections, overrides=overrides)

In [14]:
# Take the built indexing pipeline and warm it up once before running it
# (this is where the model files are fetched, per the cell output).
indexing_pipeline = results["indexing_pipeline"]
indexing_pipeline.warm_up()

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

In [16]:
file_paths = ["data/raw/modular_rag.pdf"]

# One run per file: the enrichment prompt only summarizes documents[0], so batching
# several PDFs into a single run would presumably reuse the first file's summary.
# The runs are independent and could be parallelized.
for pdf_path in file_paths:
    loader_inputs = {"sources": [pdf_path]}
    indexing_pipeline.run({"loader": loader_inputs})

In [17]:
query = "What are the 6 main modules of the modular RAG framework?"

pipeline = results["pipeline"]
pipeline.warm_up()

# Besides the final LLM output, keep the rendered prompt and the documents handed
# to generation so they can be inspected.
query_inputs = {"query": {"text": query}}
inspected_components = ["prompt_builder", "docs_for_generation"]
response = pipeline.run(query_inputs, include_outputs_from=inspected_components)

Calculating embeddings: 100%|██████████| 1/1 [00:00<00:00, 49.26it/s]


In [18]:
# The generator returns a list of replies; show the first one.
first_reply = response["llm"]["replies"][0]
print(first_reply)

According to Document 2, the six main modules of the modular RAG framework are: "Indexing, Pre-retrieval, Retrieval, Post-retrieval, Generation, and Orchestration."

The relevant quote is: "Based on the current stage of RAG development, we have established six main modules: Indexing, Pre-retrieval, Retrieval, Post-retrieval, Generation, and Orchestration."
