<a href="https://colab.research.google.com/github/zHazyl/ml-from-scratch/blob/main/custsom-agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install llama-index-agent-openai llama-index-llms-openai llama-index -q

In [3]:
import json
from typing import Sequence, List

from llama_index.llms.openai import OpenAI
from llama_index.core.llms import ChatMessage
from llama_index.core.tools import BaseTool, FunctionTool
from llama_index.agent.openai import OpenAIAgent

In [4]:
from google.colab import userdata
import os

os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')

In [5]:
def add(a: int, b:int) -> int:
  """Add two integers and returns the result integer"""
  return a + b

add_tool = FunctionTool.from_defaults(fn=add)

def useless_tool() -> int:
  """This is a useless tool."""
  return "This is a useless output"

useless_tool = FunctionTool.from_defaults(fn=useless_tool)

In [6]:
llm = OpenAI(model='gpt-3.5-turbo-0613')
agent = OpenAIAgent.from_tools([useless_tool, add_tool], llm=llm, verbose=True)

In [7]:
response = agent.chat(
    "What is 5 + 2", tool_choice="auto"
) # note function_call param is deprecated use tool_choice instead

Added user message to memory: What is 5 + 2
=== Calling Function ===
Calling function: add with args: {
  "a": 5,
  "b": 2
}
Got output: 7



In [8]:
response = agent.chat("What is 5 * 2?", tool_choice="useless_tool")

Added user message to memory: What is 5 * 2?
=== Calling Function ===
Calling function: useless_tool with args: {}
Got output: This is a useless output

=== Calling Function ===
Calling function: add with args: {
  "a": 5,
  "b": 2
}
Got output: 7



In [9]:
print(response)

The product of 5 and 2 is 10.


In [10]:
response = agent.chat("What is 5 * 2", tool_choice="none")

Added user message to memory: What is 5 * 2


In [11]:
print(response)

The product of 5 and 2 is 10.


In [None]:
%pip install llama-index-readers-wikipedia

In [18]:
from llama_index.core.agent import (
    CustomSimpleAgentWorker,
    Task,
    AgentChatResponse
)
from typing import Dict, Any, List, Tuple, Optional
from llama_index.core.tools import BaseTool, QueryEngineTool
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.core.query_engine import RouterQueryEngine
from llama_index.core import ChatPromptTemplate, PromptTemplate
from llama_index.core.selectors import PydanticSingleSelector
from llama_index.core.bridge.pydantic import Field, BaseModel

In [17]:
from llama_index.core.llms import ChatMessage, MessageRole

DEFAULT_PROMPT_STR = """
Given previous question/response pairs, please determine if an error has occured in the response, and suggest\
 a modified question that will not trigger the error.

Examples of modified questions:
- The question itself is modified to elicit a non-erroneous response
- The question is augmented with context that will help the downstream system better answer the queston.
- The question is augmented with examples of negative response, or other negative questions.

An error means that either an exception has triggered, or the response is completely irrelevant to the question.

Please return the evaluation of the response in the following JSON format
"""

def get_chat_prompt_template(
    system_prompt: str, current_reasoning: Tuple[str, str]
) -> ChatPromptTemplate:
  system_msg = ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
  messages = [system_msg]
  for raw_msg in current_reasoning:
    if raw_msg[0] == "user":
      messages.append(
          ChatMessage(role=MessageRole.USER, content=raw_msg[1])
      )
    else:
      messages.append(
          ChatMessage(role=MessageRole.ASSISTANT, content=raw_msg[1])
      )
    return ChatPromptTemplate(message_templates=messages)

class ResponseEval(BaseModel):
  """Evaluation of whether the response has an error."""

  has_error: bool = Field(
      ..., description="Whether the response has an error."
  )
  new_question: str = Field(..., description="The suggested new question")
  explanation: str = Field(
      ...,
      description=(
          "The explanation for the error as well as for the question."
          "Can include the direct stack trace as well"
      )
  )

In [67]:
from llama_index.core.bridge.pydantic import PrivateAttr

class RetryAgentWorker(CustomSimpleAgentWorker):
  """Agent worker that adds a retry layer on top of a router.
  Continues iterating until there's no errors / task is done.
  """

  prompt_str: str = Field(default=DEFAULT_PROMPT_STR)
  max_iterations: int = Field(default=10)
  _router_query_engine: RouterQueryEngine = PrivateAttr()

  def __init__(self, tools: List[BaseTool], **kwargs: Any) -> None:
    """Init params."""
    # validate that all tools are query engine tools
    for tool in tools:
      if not isinstance(tool, QueryEngineTool):
        raise ValueError(
            f"Tool {tool.metadata.name} is not a query engine tool."
        )
    self._router_query_engine = RouterQueryEngine(
        selector=PydanticSingleSelector.from_defaults(),
        query_engine_tools=tools,
        verbose=kwargs.get("verbose", False),
    )
    super().__init__(
        tools=tools,
        **kwargs,
    )

  def _initialize_state(self, task: Task, **kwargs: Any) -> Dict[str, Any]:
    """Initialize state"""
    return {"count": 0, "current_reasoning": []}

  def _run_step(
      self, state: Dict[str, Any], task: Task, input: Optional[str] = None
  ) -> Tuple[AgentChatResponse, bool]:
    """Run step.
    Returns:
      Tuple of (agent_response, is_done)

    """
    if "new_input" not in state:
      new_input = task.input
    else:
      new_input = state["new_input"]

    # first run router query engine
    response = self._router_query_engine.query(new_input)

    # append to current reasoning
    state['current_reasoning'].extend(
        [("user", new_input), ("assistant_reasoning")]
    )

    # Then, check for errors
    # dynamically create pydantic program for structured output extraction based on template
    chat_prompt_tmpl = get_chat_prompt_template(
        self.prompt_str, state['current_reasoning']
    )
    llm_program = LLMTextCompletionProgram.from_defaults(
        output_parser=PydanticOutputParser(output_cls=ResponseEval),
        prompt=chat_prompt_tmpl,
        llm=self.llm,
    )
    # run program, look at the result
    response_eval = llm_program(
        query_str=new_input, response_str=str(response)
    )
    if not response_eval.has_error:
      is_done = True
    else:
      is_done = False
    state["new_input"] = response_eval.new_question

    if self.verbose:
      print(f"> Question: {new_input}")
      print(f"> Response: {response}")
      print(f"> Response eval: {response_eval.dict()}")

    # return response
    return AgentChatResponse(response=str(response)), is_done

  def _finalize_task(self, state: Dict[str, Any], **kwargs) -> None:
    """Finalize task."""
    # nothing to finalize here
    # this is usually if you want to modify any sort of
    # internal state beyond what is set in '_initialize_state'
    pass

In [23]:
from llama_index.core.tools import QueryEngineTool

In [51]:
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    column
)
from llama_index.core import SQLDatabase

engine = create_engine("sqlite:///:memory", future=True)
metadata_obj = MetaData()
#crate city SQL table
table_name = 'city_stats'
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)

metadata_obj.create_all(engine)

In [None]:
from sqlalchemy import insert

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Berlin", "population": 3645000, "country": "Germany"},
]

for row in rows:
  stmt = insert(city_stats_table).values(**row)
  with engine.begin() as connection:
    cursor = connection.execute(stmt)

In [62]:
from llama_index.core.query_engine import NLSQLTableQueryEngine

sql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database, tables=["city_stats"], verbose=True
)
sql_tool = QueryEngineTool.from_defaults(
    query_engine=sql_query_engine,
    description=(
        "Useful for translating a natural language query into a SQL query over"
        " a table containing: city_stats, containing the population/country of"
        " each city"
    ),
)

In [63]:
from llama_index.readers.wikipedia import WikipediaReader
from llama_index.core import VectorStoreIndex

In [None]:
%pip install wikipedia

In [56]:
cities = ["Toronto", "Berlin", "Tokyo"]
wiki_docs = WikipediaReader().load_data(pages=cities)

In [64]:
# build a separate vector index per city
# You could also choose to define a single vector index across all docs, and annotate each chunk by metadata
vector_tools = []
for city, wiki_doc in zip(cities, wiki_docs):
  vector_index = VectorStoreIndex.from_documents([wiki_doc])
  vector_query_engine = vector_index.as_query_engine()
  vector_tool = QueryEngineTool.from_defaults(
      query_engine=vector_query_engine,
      description=f"Useful for answer semantic questions about {city}",
  )
  vector_tools.append(vector_tool)

In [58]:
from llama_index.core.agent import AgentRunner
from llama_index.llms.openai import OpenAI

In [68]:
llm = OpenAI(model="gpt-3.5-turbo")
callback_manager = llm.callback_manager

query_engine_tools = [sql_tool] + vector_tools
agent_worker = RetryAgentWorker.from_tools(
    query_engine_tools,
    llm=llm,
    verbose=True,
    callback_manager=callback_manager
)
agent = AgentRunner(agent_worker, callback_manager=callback_manager)

In [None]:
response = agent.chat("Which countries are each city from?")
print(str(response))


[1;3;38;5;200mSelecting query engine 0: The choice is relevant because it involves translating a natural language query into a SQL query over a table containing city_stats, which likely includes information about the country of each city..
[0m> Question: Which countries are each city from?
> Response: Toronto is from Canada, Berlin is from Germany, and Tokyo is from Japan.
> Response eval: {'has_error': True, 'new_question': 'Which countries are Paris, London, and New York City from?', 'explanation': 'The response is irrelevant because the question was asking about the countries for each city, but the response did not provide any information.'}
[1;3;38;5;200mSelecting query engine 0: The choice is useful for translating a natural language query into a SQL query over a table containing city_stats, which includes the population/country of each city. This choice can help identify the countries of Paris, London, and New York City based on the city_stats table..
[0m> Question: Which cou