<a href="https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/docs/examples/workflow/multi_step_query_engine.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MultiStep Query Engine

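Implementation of [MultiStepQueryEngine](https://docs.llamaindex.ai/en/stable/examples/query_transformations/SimpleIndexDemo-multistep/) using workflows. The engine breaks a question into a sequence of simpler sub-questions, answers each one against the index (feeding earlier answers into the next decomposition step), and finally synthesizes the sub-answers into a single response.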

In [None]:
!pip install -U llama-index

In [None]:
import os
from getpass import getpass

# Keep an OPENAI_API_KEY that is already exported; otherwise prompt for it so the
# key never has to be pasted into (and saved with) the notebook source.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

Since workflows are async-first, this all runs fine in a notebook, where an event loop is already running. If you run this code in your own script instead, wrap it in an `async` function and start an event loop with `asyncio.run()`:

```python
async def main():
    <async code>

if __name__ == "__main__":
    import asyncio
    asyncio.run(main())
```

## The Workflow

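Besides the built-in `StartEvent` and `StopEvent`, the workflow uses one custom event, `QueryMultiStepEvent`, which carries the collected sub-question answers from the query step to the synthesis step.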

## Define Event

In [None]:
from llama_index.core.workflow import Event
from typing import Dict, List, Any
from llama_index.core.schema import NodeWithScore


class QueryMultiStepEvent(Event):
    """
    Event passing the results of the multi-step query loop to synthesis.

    Emitted by ``MultiStepQueryEngineWorkflow._query_multistep`` and consumed
    by its ``synthesize`` step.

    Attributes:
        nodes (List[NodeWithScore]): One text node per answered sub-question,
            each holding that sub-question and its answer.
        source_nodes (List[NodeWithScore]): Source nodes returned by the query
            engine for all sub-questions, in the order they were asked.
        final_response_metadata (Dict[str, Any]): Metadata to attach to the
            final response; its "sub_qa" entry lists
            (sub_question, sub_response) tuples.
    """

    nodes: List[NodeWithScore]
    source_nodes: List[NodeWithScore]
    final_response_metadata: Dict[str, Any]

In [None]:
from llama_index.core.indices.query.query_transform.base import (
    StepDecomposeQueryTransform,
)
from llama_index.core.response_synthesizers import (
    get_response_synthesizer,
)

from llama_index.core.schema import QueryBundle, TextNode

from llama_index.core.workflow import (
    Context,
    Workflow,
    StartEvent,
    StopEvent,
    step,
)

from llama_index.llms.openai import OpenAI
from llama_index.core.llms import LLM

from IPython.display import Markdown, display


class MultiStepQueryEngineWorkflow(Workflow):
    """
    Workflow version of ``MultiStepQueryEngine``.

    Start it with ``await w.run(...)`` using these keyword arguments:

    - ``query`` (str, required): the question to answer.
    - ``query_engine`` (required): engine that answers each sub-question.
    - ``llm`` (LLM, optional): used to decompose the question into
      sub-questions and to synthesize the final answer. When omitted, ``None``
      is passed on to the library components, which use their own default.
    - ``index_summary`` (str, optional): description of the indexed data,
      given to the query-decomposition prompt.
    - ``num_steps`` (int, optional): maximum number of sub-questions to run;
      ``None`` means no limit.
    - ``stop_fn`` (callable, optional): receives
      ``{"query_bundle": QueryBundle}`` and returns True to stop before that
      sub-question is run. Defaults to stopping when the generated
      sub-question contains the standalone word "none".

    The workflow result is the synthesized response; its
    ``metadata["sub_qa"]`` lists the ``(sub_question, sub_response)`` pairs
    that were run.
    """

    @staticmethod
    def _default_stop_fn(stop_dict: Dict) -> bool:
        """Return True when the new sub-question contains the word "none"."""
        import re

        query_bundle = stop_dict.get("query_bundle")
        if query_bundle is None:
            raise ValueError("Query bundle must be provided to stop function.")
        # Whole-word, case-insensitive: avoids stopping on "nonexistent" etc.
        return (
            re.search(r"\bnone\b", query_bundle.query_str, re.IGNORECASE)
            is not None
        )

    def combine_queries(
        self,
        query_bundle: QueryBundle,
        prev_reasoning: str,
        index_summary: str,
        llm: LLM,
    ) -> QueryBundle:
        """
        Generate the next sub-question for ``query_bundle``.

        Args:
            query_bundle: The original user query.
            prev_reasoning: Sub-questions and answers gathered so far, as
                ``- ...`` lines; empty on the first step.
            index_summary: Description of the index passed to the prompt.
            llm: LLM handed to ``StepDecomposeQueryTransform``.

        Returns:
            A ``QueryBundle`` holding the generated sub-question.
        """
        transform_metadata = {
            "prev_reasoning": prev_reasoning,
            "index_summary": index_summary,
        }
        return StepDecomposeQueryTransform(llm=llm)(
            query_bundle, metadata=transform_metadata
        )

    @step(pass_context=True)
    async def _query_multistep(
        self, ctx: Context, ev: StartEvent
    ) -> QueryMultiStepEvent:
        """
        Decompose the query step by step, answering each sub-question.

        Stops after ``num_steps`` sub-questions or as soon as ``stop_fn``
        returns True, whichever comes first.

        Raises:
            ValueError: If ``query`` or ``query_engine`` was not passed to
                ``run()``.
        """
        query = ev.get("query")
        query_engine = ev.get("query_engine")
        if query is None or query_engine is None:
            raise ValueError(
                "`query` and `query_engine` must both be passed to `run()`."
            )
        llm = ev.get("llm")
        index_summary = ev.get("index_summary")
        num_steps = ev.get("num_steps")
        stop_fn = ev.get("stop_fn")
        if stop_fn is None:
            stop_fn = self._default_stop_fn

        # `synthesize` only receives a QueryMultiStepEvent, so share what it
        # needs through the context.
        ctx.data["query"] = query
        ctx.data["llm"] = llm

        prev_reasoning = ""
        text_chunks: List[str] = []
        source_nodes: List[NodeWithScore] = []
        final_response_metadata: Dict[str, Any] = {"sub_qa": []}

        cur_steps = 0
        while num_steps is None or cur_steps < num_steps:
            updated_query_bundle = self.combine_queries(
                QueryBundle(query_str=query),
                prev_reasoning,
                index_summary,
                llm,
            )
            sub_question = updated_query_bundle.query_str

            # Checked before querying so a stop signal from the decomposition
            # step is never sent to the query engine as a question.
            if stop_fn({"query_bundle": updated_query_bundle}):
                break

            # Async variant so this step does not block the event loop.
            cur_response = await query_engine.aquery(updated_query_bundle)

            text_chunks.append(
                f"\nQuestion: {sub_question}\n" f"Answer: {cur_response!s}"
            )
            source_nodes.extend(cur_response.source_nodes)
            final_response_metadata["sub_qa"].append(
                (sub_question, cur_response)
            )
            prev_reasoning += f"- {sub_question}\n" f"- {cur_response!s}\n"
            cur_steps += 1

        nodes = [
            NodeWithScore(node=TextNode(text=text_chunk))
            for text_chunk in text_chunks
        ]
        return QueryMultiStepEvent(
            nodes=nodes,
            source_nodes=source_nodes,
            final_response_metadata=final_response_metadata,
        )

    @step(pass_context=True)
    async def synthesize(
        self, ctx: Context, ev: QueryMultiStepEvent
    ) -> StopEvent:
        """
        Answer the original query from the collected sub-question Q&A nodes.

        Returns:
            A StopEvent whose ``result`` is the synthesized response, with
            ``ev.final_response_metadata`` attached as its metadata.
        """
        # Use the same LLM as the decomposition step (None -> library default).
        response_synthesizer = get_response_synthesizer(llm=ctx.data.get("llm"))
        final_response = await response_synthesizer.asynthesize(
            query=ctx.data.get("query"),
            nodes=ev.nodes,
            additional_source_nodes=ev.source_nodes,
        )
        final_response.metadata = ev.final_response_metadata

        return StopEvent(result=final_response)

## Download Data

In [None]:
from pathlib import Path
from urllib.request import urlopen

# Pure-Python download so this cell also works where `wget` may be missing
# (e.g. macOS or Windows); an existing copy of the essay is reused.
essay_url = (
    "https://raw.githubusercontent.com/run-llama/llama_index/main/"
    "docs/docs/examples/data/paul_graham/paul_graham_essay.txt"
)
essay_path = Path("data/paul_graham/paul_graham_essay.txt")
essay_path.parent.mkdir(parents=True, exist_ok=True)
if not essay_path.exists():
    with urlopen(essay_url) as response:
        essay_path.write_bytes(response.read())

--2024-08-16 19:51:26--  https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8002::154, 2606:50c0:8000::154, 2606:50c0:8003::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8002::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 75042 (73K) [text/plain]
Saving to: ‘data/paul_graham/paul_graham_essay.txt’


2024-08-16 19:51:26 (329 KB/s) - ‘data/paul_graham/paul_graham_essay.txt’ saved [75042/75042]



In [None]:
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, Settings
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding

# GPT-4 is kept in `llm` and also registered as the global default via Settings.
llm = Settings.llm = OpenAI(model="gpt-4")

# Load the Paul Graham essay and build a vector index over it.
documents = SimpleDirectoryReader("data/paul_graham").load_data()
index = VectorStoreIndex.from_documents(documents)

### Run the Workflow!

In [None]:
# Generous timeout: one run makes several sequential LLM and retrieval calls.
w = MultiStepQueryEngineWorkflow(
    timeout=200,
)

In [None]:
from typing import cast


def default_stop_fn(stop_dict: Dict) -> bool:
    """
    Stop function for the multi-step query loop.

    Args:
        stop_dict: Must contain ``"query_bundle"``, the newly generated
            sub-question as a ``QueryBundle``.

    Returns:
        True when the sub-question contains the standalone word "none"
        (case-insensitive), which is treated as the signal that no further
        sub-question is needed.

    Raises:
        ValueError: If ``"query_bundle"`` is missing or ``None``.
    """
    import re

    query_bundle = stop_dict.get("query_bundle")
    if query_bundle is None:
        raise ValueError("Query bundle must be provided to stop function.")

    # Whole-word match: a plain substring test would also stop on
    # sub-questions mentioning e.g. "nonexistent" or "nonetheless".
    return (
        re.search(r"\bnone\b", query_bundle.query_str, re.IGNORECASE)
        is not None
    )

In [None]:
# The question to answer and the limits of the multi-step loop.
query = "In which city did the author found his first company, Viaweb?"
num_steps = 3  # at most this many sub-questions are asked
stop_fn = default_stop_fn  # may end the loop before num_steps is reached

# Description of the index handed to the decomposition prompt, and the engine
# that answers each sub-question.
index_summary = "Used to answer questions about the author"
query_engine = index.as_query_engine(llm=llm)

In [None]:
# Run a query

result = await w.run(
    query=query,
    query_engine=query_engine,
    llm=llm,
    index_summary=index_summary,
    num_steps=num_steps,
    stop_fn=stop_fn,
)

display(
    Markdown("> Question: {}".format(query)),
    Markdown("Answer: {}".format(result)),
)

> Question: In which city did the author found his first company, Viaweb?

Answer: The author founded his first company, Viaweb, in Cambridge.

In [None]:
sub_qa = result.metadata["sub_qa"]
# Pair each sub-question with just the text of its answer.
tuples = [
    (sub_question, sub_response.response)
    for sub_question, sub_response in sub_qa
]
display(Markdown(str(tuples)))

[('Who is the author who founded Viaweb?', 'The author who founded Viaweb is Paul Graham.'), ('In which city did Paul Graham found his first company, Viaweb?', 'Paul Graham founded his first company, Viaweb, in Cambridge.')]