# Multimodal Report Generation Agent 

<a href="https://colab.research.google.com/github/run-llama/llama_parse/blob/main/examples/multimodal/multimodal_report_generation_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this cookbook we show you how to build a multimodal report generation agent from a bank of research reports. We use a set of ICLR papers (which were also used as the dataset in our [DeepLearning.ai course](https://www.deeplearning.ai/short-courses/building-agentic-rag-with-llamaindex/?utm_campaign=llamaindexC2-launch&utm_medium=headband&utm_source=dlai-homepage)).

We use our workflow abstraction to define an agentic system that first performs research to pull in the relevant documents, and then generates a report from them.

## Setup

In [None]:
import nest_asyncio

# Patch asyncio so that event loops can be nested. Presumably required because
# the notebook kernel already runs its own loop -- confirm against the
# nest_asyncio documentation.
nest_asyncio.apply()

### Setup Observability

We set up an integration with LlamaTrace (integration with Arize).

If you haven't already done so, make sure to create an account here: https://llamatrace.com/login. Then create an API key and put it in the `PHOENIX_API_KEY` variable below.

In [None]:
!pip install -U llama-index-callbacks-arize-phoenix

In [None]:
# Register Arize Phoenix (LlamaTrace) as the global logging/observability handler
import os

import llama_index.core

PHOENIX_API_KEY = "<PHOENIX_API_KEY>"

# the API key is passed to the OTLP exporter through its headers env variable
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = "api_key=" + PHOENIX_API_KEY

llama_index.core.set_global_handler(
    "arize_phoenix",
    endpoint="https://llamatrace.com/v1/traces",
)

### Model Setup

Set up the models that will be used for downstream orchestration.

In [None]:
from llama_index.core import Settings
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding

# Create the embedding model and the LLM once, and register each as the global
# default while keeping a module-level handle to it.
Settings.embed_model = embed_model = OpenAIEmbedding(model="text-embedding-3-large")
Settings.llm = llm = OpenAI(model="gpt-4o")

## Load, Parse, and Index Research Papers

Here we load 11 popular ICLR 2024 papers, and then parse them with LlamaParse.

**NOTE**: this may be slow. To save your results, run the cell to save all your outputs to JSON, so you can reload instead of having to re-parse.

In [None]:
# Local PDF file name -> OpenReview id of the paper downloaded into that file.
# Insertion order is significant: `urls` and `papers` are paired positionally.
_OPENREVIEW_IDS = {
    "metagpt.pdf": "VtmBAGCN7o",
    "longlora.pdf": "6PmJoRfdaK",
    "loftq.pdf": "LzPWWPAdY4",
    "swebench.pdf": "VTF8yNQM66",
    "selfrag.pdf": "hSyW5go0v8",
    "zipformer.pdf": "9WD9KwssyT",
    "values.pdf": "yV6fD7LYkF",
    "finetune_fair_diffusion.pdf": "hnrB5YHoYu",
    "knowledge_card.pdf": "WbWtOYIzIK",
    "metra.pdf": "c5pwL0Soay",
    "vr_mcl.pdf": "TpD2aG1h0D",
}

# download URL of every paper, in table order
urls = [
    f"https://openreview.net/pdf?id={openreview_id}"
    for openreview_id in _OPENREVIEW_IDS.values()
]

# file name each paper is saved under, in the same order as `urls`
papers = list(_OPENREVIEW_IDS)

# directory the PDFs are downloaded into
data_dir = "iclr_docs"

In [None]:
!mkdir "{data_dir}"
for url, paper in zip(urls, papers):
    !wget "{url}" -O "{data_dir}/{paper}"

In [None]:
from llama_parse import LlamaParse

# Parser shared by the cells below: results are requested as markdown, and
# parsing is delegated to the named vendor multimodal model.
parser = LlamaParse(
    result_type="markdown",
    use_vendor_multimodal_model=True,
    vendor_multimodal_model_name="anthropic-sonnet-3.5",
)

In [None]:
from pathlib import Path

# delete and recreate the image directory if not already created
out_image_dir = "out_iclr_images"
!rm -rf "{out_image_dir}"
!mkdir "{out_image_dir}"


paper_dicts = {}

for paper_path in papers:
    paper_base = Path(paper_path).basename
    full_paper_path = str(Path(data_dir) / paper_path)
    print(paper_base)
    raise Exception
    md_json_objs = parser.get_json_result(full_paper_path)
    json_dicts = md_json_objs[0]["pages"]

    image_path = str(Path(out_image_dir) / paper_base)
    image_dicts = parser.get_images(md_json_objs, download_path=image_path)
    paper_dicts[paper_path] = {
        "paper_path": full_paper_path,
        "json_dicts": json_dicts,
        "image_path": image_path,
    }

#### Get Text Nodes

Convert the dictionary above into TextNode objects that we can put into a vector store.

In [None]:
from llama_index.core.schema import TextNode
from typing import Optional

In [None]:
# NOTE: these are utility functions to sort the dumped images by page number
# (image file names are expected to end in "-page-{page_num}.jpg")
import re


def get_page_number(file_name):
    """Return the page number encoded in an image file name.

    The name (a string or path-like object) must end in ``-page-<digits>.jpg``;
    anything else yields 0.
    """
    found = re.search(r"-page-(\d+)\.jpg$", str(file_name))
    return int(found.group(1)) if found else 0


def _get_sorted_image_files(image_dir):
    """Return the regular files directly inside ``image_dir``, ordered by page number."""
    files_only = [entry for entry in Path(image_dir).iterdir() if entry.is_file()]
    # stable sort: files without a recognisable page number (key 0) come first
    files_only.sort(key=get_page_number)
    return files_only

In [None]:
from copy import deepcopy
from pathlib import Path


# attach image metadata to the text nodes
def get_text_nodes(json_dicts, paper_path, image_dir=None):
    """Build one TextNode per parsed page.

    Each node has empty text; the page's markdown, its 1-based page number and
    ``paper_path`` are stored in the node metadata. When ``image_dir`` is
    given, the image file at the same position (after sorting by page number)
    is recorded under ``image_path``.
    """
    page_images = None
    if image_dir is not None:
        page_images = _get_sorted_image_files(image_dir)

    page_markdowns = [page["md"] for page in json_dicts]

    nodes = []
    for page_num, markdown in enumerate(page_markdowns, start=1):
        metadata = {
            "page_num": page_num,
            "parsed_text_markdown": markdown,
            "paper_path": paper_path,
        }
        if page_images is not None:
            # images are matched to pages purely by position
            metadata["image_path"] = str(page_images[page_num - 1])
        nodes.append(TextNode(text="", metadata=metadata))

    return nodes

In [None]:
# Build the nodes of every paper, keeping them both grouped per paper
# (text_nodes_dict) and flattened into a single list (all_text_nodes).
text_nodes_dict = {}
all_text_nodes = []
for paper_path, paper_dict in paper_dicts.items():
    json_dicts = paper_dict["json_dicts"]
    text_nodes = get_text_nodes(
        json_dicts,
        paper_dict["paper_path"],
        image_dir=paper_dict["image_path"],
    )
    text_nodes_dict[paper_path] = text_nodes
    all_text_nodes += text_nodes

In [None]:
# spot-check a single node: print its content with all metadata included
print(all_text_nodes[10].get_content(metadata_mode="all"))

page_num: 11
image_path: data_images/412ac275-abe2-4585-be43-5680e7754740-page-10.jpg
parsed_text_markdown: # Commitment to Disciplined Reinvestment Rate

Disciplined Reinvestment Rate is the Foundation for Superior Returns on and of Capital, while Driving Durable CFO Growth

| Metric | Value |
|--------|-------|
| 10-Year Reinvestment Rate | ~50% |
| CFO CAGR 2024-2032 | ~6% |
| Mid-Cycle Planning Price | at $60/BBL WTI |

| Period | Industry Growth Focus | ConocoPhillips Strategy Reset | Reinvestment Rate |
|--------|------------------------|-------------------------------|-------------------|
| 2012-2016 | >100% Reinvestment Rate | - | ~$75/BBL WTI Average |
| 2017-2022 | - | <60% Reinvestment Rate | ~$63/BBL WTI Average |
| 2023E | - | - | at $80/BBL WTI |
| 2024-2028 | - | - | at $60/BBL WTI (with $80/BBL WTI option shown) |
| 2029-2032 | - | - | at $60/BBL WTI (with $80/BBL WTI option shown) |

*Chart shows ConocoPhillips Average Annual Reinvestment Rate (%) over time, with histo

### Build Indexes

Once the text nodes are ready, we feed into our vector store index abstraction, which will index these nodes into a simple in-memory vector store (of course, you should definitely check out our 40+ vector store integrations!)

Besides vector indexing, we **also** store a mapping from each paper path to a summary index over that paper's nodes.

In [None]:
import os
from llama_index.core import (
    StorageContext,
    SummaryIndex,
    VectorStoreIndex,
    load_index_from_storage,
)

# Vector Indexing
persist_dir = "storage_nodes_papers"
if not os.path.exists(persist_dir):
    # Index the nodes of ALL papers. (`text_nodes` only holds the nodes of the
    # last paper processed by the aggregation loop.)
    index = VectorStoreIndex(all_text_nodes)
    # save index to disk
    index.set_index_id("vector_index")
    index.storage_context.persist(persist_dir)
else:
    # rebuild storage context
    storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
    # load index
    index = load_index_from_storage(storage_context, index_id="vector_index")


# Summary Index dictionary - store map from paper path to a summary index around it
# (keys are the bare file names listed in `papers`)
paper_summary_indexes = {
    paper_path: SummaryIndex(text_nodes_dict[paper_path])
    for paper_path in papers
}

## Define Tools

We define two tools for the downstream agent: a chunk-level retriever tool and a document-retrieval tool.

In [None]:
from llama_index.core.tools import FunctionTool
from llama_index.core.schema import NodeWithScore

# function tools
# NOTE: the docstring below is handed to FunctionTool and so doubles as the
# tool description; keep its wording aimed at the agent.
# The builtin `list[...]` generic is used because `typing.List` is not yet
# imported when this cell runs.
def chunk_retriever_fn(query: str) -> list[NodeWithScore]:
    """Retrieves a small set of relevant document chunks from the corpus.

    ONLY use for research questions that want to look up specific facts from the knowledge corpus,
    and don't need entire documents.

    """
    retriever = index.as_retriever(similarity_top_k=5)
    nodes = retriever.retrieve(query)
    return nodes



def _get_document_nodes(
    nodes: list[NodeWithScore], top_n: int = 2
) -> list[NodeWithScore]:
    """Get document nodes from a set of chunk nodes.

    Given chunk nodes, "de-reference" them into whole documents: every chunk's
    score is added onto the paper it came from, the papers are ranked by that
    cumulative total, and all nodes of the ``top_n`` best papers are returned.

    Args:
        nodes: retrieved chunk nodes; each must carry ``paper_path`` metadata.
        top_n: maximum number of papers to expand.

    Returns:
        All nodes of the selected papers, best-ranked paper first.
    """
    # accumulate the retrieval score of every chunk onto its source paper
    paper_path_scores: dict[str, float] = {}
    for n in nodes:
        paper_path = n.metadata["paper_path"]
        # a node may carry no score; count that as zero
        paper_path_scores[paper_path] = paper_path_scores.get(paper_path, 0.0) + (
            n.score or 0.0
        )

    # rank papers by cumulative score (descending) and keep the top_n
    top_paper_paths = sorted(
        paper_path_scores, key=paper_path_scores.get, reverse=True
    )[:top_n]

    # use summary index to get nodes from all paper paths
    all_nodes = []
    for paper_path in top_paper_paths:
        # The summary indexes are keyed by bare file name, while chunk metadata
        # holds the data-directory-qualified path; accept either form.
        index_key = (
            paper_path
            if paper_path in paper_summary_indexes
            else Path(paper_path).name
        )
        # NOTE: input to retriever can be blank
        all_nodes.extend(paper_summary_indexes[index_key].as_retriever().retrieve(""))

    return all_nodes

# NOTE: the docstring below is handed to FunctionTool and so doubles as the
# tool description; keep its wording aimed at the agent.
def doc_retriever_fn(query: str) -> list[NodeWithScore]:
    """Document retriever that retrieves entire documents from the corpus.

    ONLY use for research questions that may require searching over entire research reports.

    Will be slower and more expensive than chunk-level retrieval but may be necessary.
    """
    # find the most relevant chunks first, then expand them to whole documents
    retriever = index.as_retriever(similarity_top_k=5)
    nodes = retriever.retrieve(query)
    return _get_document_nodes(nodes)


# Wrap both retrieval functions as tools for the agent, in a single pass.
chunk_retriever_tool, doc_retriever_tool = (
    FunctionTool.from_defaults(fn=retriever_fn)
    for retriever_fn in (chunk_retriever_fn, doc_retriever_fn)
)

## Build Workflow 

Now that we've built the index, we're ready to build the report generation workflow. 

The workflow contains roughly the following steps: 

1. **Research Gathering**: Perform a function calling loop where the agent tries to reason about what tool to call (chunk-level or document-level retrieval) in order to gather more information. All information is shared to a dictionary that is propagated throughout each step. The tools return an indication of the type of information returned to the agent. After the agent feels like it's gathered enough information, move on to the next phase.


2. **Report Generation**: Generate a research report given the pooled research. For now, try to stuff as much information into the context window through the summary index.


This implementation is inspired by our [Function Calling Agent workflow](https://docs.llamaindex.ai/en/latest/examples/workflow/function_calling_agent/) implementation.

In [None]:
from typing import Any, List

from llama_index.core.llms import ChatMessage
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import CompactAndRefine
from llama_index.core.tools.types import BaseTool
from llama_index.core.workflow import Context
from llama_index.core.workflow import Event
from llama_index.core.workflow import Workflow
from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step


class InputEvent(Event):
    """Chat history handed to the LLM for its next reasoning turn."""

    input: List[ChatMessage]


class ChunkRetrievalEvent(Event):
    """Chunk-level retrieval requested by an LLM tool call."""

    query: str
    # id / name of the originating tool call, echoed back in the tool message
    tool_id: str = ""
    tool_name: str = ""


class DocRetrievalEvent(Event):
    """Document-level retrieval requested by an LLM tool call."""

    query: str
    # id / name of the originating tool call, echoed back in the tool message
    tool_id: str = ""
    tool_name: str = ""


class ReportGenerationEvent(Event):
    """Research is finished; ``input`` is the user's original request."""

    input: str


class ReportGenerationAgent(Workflow):
    """Report generation agent.

    Runs a function-calling loop in which the LLM picks a retrieval tool
    (chunk-level or document-level) until it stops requesting tools, then
    synthesizes a report over everything that was retrieved.

    Args:
        llm: function-calling LLM; defaults to ``OpenAI()``.
        tools: tools offered to the LLM.
        chunk_tool_name: name of the chunk-level retrieval tool. The default
            assumes the tool is named after ``chunk_retriever_fn`` -- verify
            against how the tool was built.
        doc_tool_name: name of the document-level retrieval tool (same caveat,
            for ``doc_retriever_fn``).

    Raises:
        ValueError: if the LLM does not support function calling.
    """

    def __init__(
        self,
        *args: Any,
        llm: FunctionCallingLLM | None = None,
        tools: List[BaseTool] | None = None,
        chunk_tool_name: str = "chunk_retriever_fn",
        doc_tool_name: str = "doc_retriever_fn",
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.tools = tools or []
        self.chunk_tool_name = chunk_tool_name
        self.doc_tool_name = doc_tool_name
        # name -> tool, so a tool call can be dispatched by the name the LLM used
        self._tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}

        self.llm = llm or OpenAI()
        # explicit check instead of `assert`, which is stripped under -O
        if not self.llm.metadata.is_function_calling_model:
            raise ValueError("ReportGenerationAgent requires a function-calling LLM.")
        self.summarizer = CompactAndRefine(llm=self.llm)

        # use the resolved LLM: the `llm` argument may be None
        self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
        self.sources = []

    @step(pass_context=True)
    async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:
        """Record the user request and start the reasoning loop."""
        # clear per-run state
        self.sources = []
        ctx.data["stored_chunks"] = []

        # get user input; remembered so the report step can use it as its query
        user_input = ev.input
        ctx.data["query"] = user_input
        user_msg = ChatMessage(role="user", content=user_input)
        self.memory.put(user_msg)

        # get chat history
        chat_history = self.memory.get()
        return InputEvent(input=chat_history)

    @step(pass_context=True)
    async def handle_llm_input(
        self, ctx: Context, ev: InputEvent
    ) -> ChunkRetrievalEvent | DocRetrievalEvent | ReportGenerationEvent | StopEvent:
        """Ask the LLM for its next action and route it.

        A tool call becomes a retrieval event; no tool call means research is
        done and report generation starts; an unknown tool or a call without a
        ``query`` argument stops the workflow.
        """
        chat_history = ev.input

        response = await self.llm.achat_with_tools(
            self.tools, chat_history=chat_history
        )
        self.memory.put(response.message)

        tool_calls = self.llm.get_tool_calls_from_response(
            response, error_on_no_tool_call=False
        )
        if not tool_calls:
            # all the content is stored in the context, so just pass along the request
            return ReportGenerationEvent(input=ctx.data["query"])

        # only one retrieval is handled per reasoning turn
        tool_call = tool_calls[0]
        # both retrieval functions take a single `query` argument
        query = tool_call.tool_kwargs.get("query")
        if query is None or tool_call.tool_name not in self._tools_by_name:
            return StopEvent(result={"response": "Invalid tool."})

        if tool_call.tool_name == self.chunk_tool_name:
            event_cls = ChunkRetrievalEvent
        elif tool_call.tool_name == self.doc_tool_name:
            event_cls = DocRetrievalEvent
        else:
            return StopEvent(result={"response": "Invalid tool."})

        return event_cls(
            query=query,
            tool_id=tool_call.tool_id,
            tool_name=tool_call.tool_name,
        )

    @step(pass_context=True)
    async def handle_retrieval(
        self, ctx: Context, ev: ChunkRetrievalEvent | DocRetrievalEvent
    ) -> InputEvent:
        """Handle retrieval.

        Store retrieved chunks, and go back to agent reasoning loop.

        """
        # fall back to the configured tool names when the event carries none
        if ev.tool_name:
            tool_name = ev.tool_name
        elif isinstance(ev, ChunkRetrievalEvent):
            tool_name = self.chunk_tool_name
        else:
            tool_name = self.doc_tool_name

        retrieved_chunks = self._tools_by_name[tool_name](ev.query).raw_output
        ctx.data.setdefault("stored_chunks", []).extend(retrieved_chunks)

        # synthesize an answer given the query to return to the LLM.
        response = await self.summarizer.asynthesize(ev.query, nodes=retrieved_chunks)
        # answer the tool call with a proper tool message (memory stores
        # ChatMessage objects, not bare strings)
        self.memory.put(
            ChatMessage(
                role="tool",
                content=str(response),
                additional_kwargs={"tool_call_id": ev.tool_id, "name": tool_name},
            )
        )

        # send input event back with updated chat history
        return InputEvent(input=self.memory.get())

    @step(pass_context=True)
    async def generate_report(
        self, ctx: Context, ev: ReportGenerationEvent
    ) -> StopEvent:
        """Generate report."""
        # given all the gathered context, answer the user's original request
        response = await self.summarizer.asynthesize(
            ev.input, nodes=ctx.data.get("stored_chunks", [])
        )
        return StopEvent(result={"response": str(response)})