# Name: RAG with Doc Agent

## Description: Retrieval Augmented Generation using a Doc agent. Based on <https://docs.ag2.ai/latest/docs/user-guide/reference-agents/docagent/#example>

## Tags: RAG, Doc Agent

### 🧩 Generated with ❤️ by Waldiez.

### Requirements

In [None]:
import sys  # pyright: ignore

# # !{sys.executable} -m pip install -q ag2[openai]==0.9.6 ag2[rag]==0.9.6 llama-index llama-index-core llama-index-llms-openai

### Imports

In [None]:
# pyright: reportUnusedImport=false,reportMissingTypeStubs=false
import csv
import importlib
import json
import os
import sqlite3
import sys
from dataclasses import asdict
from pprint import pprint
from types import ModuleType
from typing import (
    Annotated,
    Any,
    Callable,
    Coroutine,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

import autogen  # type: ignore
from autogen import (
    Agent,
    Cache,
    ChatResult,
    ConversableAgent,
    GroupChat,
    UserProxyAgent,
    runtime_logging,
)
from autogen.agents.experimental import DocAgent
from autogen.agents.experimental.document_agent.chroma_query_engine import (
    VectorChromaQueryEngine,
)
from autogen.events import BaseEvent
from autogen.io.run_response import AsyncRunResponseProtocol, RunResponseProtocol
import numpy as np
from llama_index.llms.openai import OpenAI

#
# let's try to avoid:
# module 'numpy' has no attribute '_no_nep50_warning'"
# ref: https://github.com/numpy/numpy/blob/v2.2.2/doc/source/release/2.2.0-notes.rst#nep-50-promotion-state-option-removed
os.environ["NEP50_DEPRECATION_WARNING"] = "0"
os.environ["NEP50_DISABLE_WARNING"] = "1"
os.environ["NPY_PROMOTION_STATE"] = "weak"
if not hasattr(np, "_no_pep50_warning"):

    import contextlib
    from typing import Generator

    @contextlib.contextmanager
    def _np_no_nep50_warning() -> Generator[None, None, None]:
        """Dummy function to avoid the warning.

        Yields
        ------
        None
            Nothing.
        """
        yield

    setattr(np, "_no_pep50_warning", _np_no_nep50_warning)  # noqa

### Load model API keys

In [None]:
# NOTE:
# This section assumes that a file named "rag_with_doc_agent_api_keys"
# exists in the same directory as this file.
# This file contains the API keys for the models used in this flow.
# It should be .gitignored and not shared publicly.
# If this file is not present, you can either create it manually
# or change the way API keys are loaded in the flow.


def load_api_key_module(flow_name: str) -> ModuleType:
    """Load the api key module.

    Parameters
    ----------
    flow_name : str
        The flow name.

    Returns
    -------
    ModuleType
        The api keys loading module.
    """
    module_name = f"{flow_name}_api_keys"
    if module_name in sys.modules:
        return importlib.reload(sys.modules[module_name])
    return importlib.import_module(module_name)


# Import (or reload) the "rag_with_doc_agent_api_keys" module once, when this
# cell/module is executed; the api key lookups below are delegated to it.
__MODELS_MODULE__ = load_api_key_module("rag_with_doc_agent")


def get_rag_with_doc_agent_model_api_key(model_name: str) -> str:
    """Look up the api key of a model.

    Delegates to the function of the same name exposed by the api keys
    module held in ``__MODELS_MODULE__``.

    Parameters
    ----------
    model_name : str
        The model name.

    Returns
    -------
    str
        The model api key.
    """
    key_lookup = __MODELS_MODULE__.get_rag_with_doc_agent_model_api_key
    return key_lookup(model_name)

### Models

In [None]:
# LLM configuration entry for the "gpt-4o" model with the "openai" api type.
# The api key is resolved through the api keys module under the name "gpt_4o".
gpt_4o_llm_config: dict[str, Any] = {
    "model": "gpt-4o",
    "api_type": "openai",
    "api_key": get_rag_with_doc_agent_model_api_key("gpt_4o"),
}

### Agents

In [None]:
# pyright: reportUnnecessaryIsInstance=false

# Query engine handed to the doc agent below. It is configured with a
# llama-index OpenAI LLM ("gpt-4o", temperature 0.0), the "chroma" db path
# and the "financial_report" collection name.
# NOTE(review): no api key is passed to OpenAI(...) here, unlike in
# gpt_4o_llm_config; presumably it is read from the environment - confirm.
doc_agent_query_engine = VectorChromaQueryEngine(
    llm=OpenAI(model="gpt-4o", temperature=0.0),
    db_path="chroma",
    collection_name="financial_report",
)

# The document agent: uses the query engine above, "parsed_docs" as its
# parsed documents path, and the gpt-4o config with a cache seed of 42.
doc_agent = DocAgent(
    name="doc_agent",
    parsed_docs_path="parsed_docs",
    query_engine=doc_agent_query_engine,
    llm_config=autogen.LLMConfig(
        config_list=[
            gpt_4o_llm_config,
        ],
        cache_seed=42,
    ),
)

# The user proxy that starts the chat in main(): no LLM (llm_config=False),
# no code execution (code_execution_config=False), no custom termination
# check, and human_input_mode="ALWAYS".
# NOTE(review): "ALWAYS" presumably prompts for human input on every turn -
# confirm against the AG2 documentation.
user = UserProxyAgent(
    name="user",
    description="A new User agent",
    human_input_mode="ALWAYS",
    max_consecutive_auto_reply=None,
    default_auto_reply="",
    code_execution_config=False,
    is_termination_msg=None,  # pyright: ignore
    llm_config=False,  # pyright: ignore
)


def get_sqlite_out(dbname: str, table: str, csv_file: str) -> None:
    """Export a sqlite table to a csv file and a sibling json file.

    The json file path is ``csv_file`` with its extension replaced by
    ``.json`` (``.json`` is appended if ``csv_file`` has no extension).
    If the table cannot be queried (``sqlite3.OperationalError``, e.g. the
    table does not exist), nothing is written.

    Parameters
    ----------
    dbname : str
        The sqlite database name.
    table : str
        The table name.
    csv_file : str
        The csv file name.
    """
    conn = sqlite3.connect(dbname)
    try:
        # NOTE: the table name is interpolated into the statement (it cannot
        # be a bound parameter); only pass trusted table names.
        query = f"SELECT * FROM {table}"  # nosec
        try:
            cursor = conn.execute(query)
        except sqlite3.OperationalError:
            # Best effort: nothing to export.
            return
        column_names = [description[0] for description in cursor.description]
        data = [dict(zip(column_names, row)) for row in cursor.fetchall()]
    finally:
        # Always release the connection, also when reading the rows fails.
        conn.close()
    with open(csv_file, "w", newline="", encoding="utf-8") as file:
        csv_writer = csv.DictWriter(file, fieldnames=column_names)
        csv_writer.writeheader()
        csv_writer.writerows(data)
    # Only swap the final extension: a plain str.replace(".csv", ".json")
    # would also rewrite ".csv" elsewhere in the path, and would leave the
    # path unchanged (overwriting the csv) when it does not contain ".csv".
    json_file = os.path.splitext(csv_file)[0] + ".json"
    with open(json_file, "w", encoding="utf-8") as file:
        json.dump(data, file, indent=4, ensure_ascii=False)

### Start chatting

In [None]:
def main(
    on_event: Optional[Callable[[BaseEvent], bool]] = None,
) -> list[RunResponseProtocol]:
    """Start chatting.

    Runs ``user`` against ``doc_agent`` using a disk cache (seed 42) and
    then consumes the events of the run.

    Parameters
    ----------
    on_event : Optional[Callable[[BaseEvent], bool]]
        Optional event handler. When given, it is called with every event of
        each run response; iteration over a response's events stops as soon
        as the handler returns a falsy value or a ``"run_completion"`` event
        has been handled. When omitted, ``process()`` is called on each run
        response instead.

    Returns
    -------
    list[RunResponseProtocol]
        The run response(s) of the chat; a single response is wrapped in a
        one-element list.

    Raises
    ------
    RuntimeError
        If the ``on_event`` handler raises.
    """
    with Cache.disk(cache_seed=42) as cache:  # pyright: ignore
        results = user.run(
            doc_agent,
            cache=cache,
            summary_method="last_msg",
            max_turns=2,
            clear_history=True,
            message='Can you ingest "https://raw.githubusercontent.com/ag2ai/ag2/refs/heads/main/test/agentchat/contrib/graph_rag/Toast_financial_report.pdf"  and tell me the fiscal year 2024 financial summary?" ',
        )
        # Normalize once so both code paths below can iterate uniformly.
        if not isinstance(results, list):
            results = [results]
        if on_event:
            for result in results:
                for event in result.events:
                    try:
                        should_continue = on_event(event)
                    except Exception as e:
                        raise RuntimeError("Error in event handler: " + str(e)) from e
                    # The completion event is the last one we care about,
                    # whatever the handler answered.
                    if event.type == "run_completion":
                        should_continue = False
                    if not should_continue:
                        break
        else:
            for result in results:
                result.process()

    return results

In [None]:
# Run the flow without an on_event handler.
main()