diff --git a/CHANGELOG.md b/CHANGELOG.md index e5499750a859d..cb4b7e6b8df3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Unreleased ### New Features +- Added sources to agent/chat engine responses (#6854) - Added basic chat buffer memory to agents / chat engines (#6857) ## [v0.7.5] - 2023-07-11 diff --git a/docs/core_modules/query_modules/chat_engines/root.md b/docs/core_modules/query_modules/chat_engines/root.md index 4a621bf5000fc..26371396dff49 100644 --- a/docs/core_modules/query_modules/chat_engines/root.md +++ b/docs/core_modules/query_modules/chat_engines/root.md @@ -24,7 +24,8 @@ To stream response: ```python chat_engine = index.as_chat_engine() streaming_response = chat_engine.stream_chat("Tell me a joke.") -streaming_response.print_response_stream() +for token in streaming_response.response_gen: + print(token, end="") ``` diff --git a/docs/core_modules/query_modules/chat_engines/usage_pattern.md b/docs/core_modules/query_modules/chat_engines/usage_pattern.md index bc14799c7e849..a5124236ac8ff 100644 --- a/docs/core_modules/query_modules/chat_engines/usage_pattern.md +++ b/docs/core_modules/query_modules/chat_engines/usage_pattern.md @@ -99,7 +99,8 @@ This somewhat inconsistent with query engine (where you pass in a `streaming=Tru ```python chat_engine = index.as_chat_engine() streaming_response = chat_engine.stream_chat("Tell me a joke.") -streaming_response.print_response_stream() +for token in streaming_response.response_gen: + print(token, end="") ``` See an [end-to-end tutorial](/examples/customization/streaming/chat_engine_condense_question_stream_response.ipynb) diff --git a/docs/examples/agent/openai_agent.ipynb b/docs/examples/agent/openai_agent.ipynb index 6aaf7f224d9ea..43af4d45bd3a8 100644 --- a/docs/examples/agent/openai_agent.ipynb +++ b/docs/examples/agent/openai_agent.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "id": "99cea58c-48bc-4af6-8358-df9695659983", "metadata": { @@ -11,6 +12,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "673df1fe-eb6c-46ea-9a73-a96e7ae7942e", "metadata": { @@ -23,6 +25,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "54b7bc2e-606f-411a-9490-fcfab9236dfc", "metadata": { @@ -33,6 +36,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "23e80e5b-aaee-4f23-b338-7ae62b08141f", "metadata": {}, @@ -47,21 +51,12 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "9d47283b-025e-4874-88ed-76245b22f82e", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/suo/miniconda3/envs/llama/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.7) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "import json\n", "from typing import Sequence, List\n", @@ -75,6 +70,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "6fe08eb1-e638-4c00-9103-5c305bfacccf", "metadata": {}, @@ -84,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "3dd3c4a6-f3e0-46f9-ad3b-7ba57d1bc992", "metadata": { "tags": [] @@ -101,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "bfcfb78b-7d4f-48d9-8d4c-ffcded23e7ac", "metadata": { "tags": [] @@ -117,6 +113,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "fbcbd5ea-f377-44a0-a492-4568daa8b0b6", "metadata": { @@ -127,6 +124,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "5b737e6c-64eb-4ae6-a8f7-350b1953e612", "metadata": {}, @@ -195,6 +193,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "fbc2cec5-6cc0-4814-92a1-ca0bd237528f", "metadata": {}, @@ -261,6 +260,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "707d30b8-6405-4187-a9ed-6146dcc42167", "metadata": { @@ -271,6 +271,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "798ca3fd-6711-4c0c-a853-d868dd14b484", "metadata": {}, @@ -287,7 +288,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 5, "id": "38ab3938-1138-43ea-b085-f430b42f5377", "metadata": { "tags": [] @@ -300,7 +301,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 6, "id": "d852ece7-e5a1-4368-9d59-c7014e0b5b4d", "metadata": { "tags": [] @@ -312,6 +313,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "500cbee4", "metadata": {}, @@ -321,7 +323,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 7, "id": "9fd1cad5", "metadata": {}, "outputs": [ @@ -343,7 +345,7 @@ "}\n", "Got output: 405\n", "========================\n", - "(121 * 3) + 42 = 405\n" + "(121 * 3) + 42 is equal to 405.\n" ] } ], @@ -353,6 +355,26 @@ ] }, { + "cell_type": "code", + "execution_count": 8, + "id": "538bf32f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ToolOutput(content='363', tool_name='multiply', raw_input={'args': (), 'kwargs': {'a': 121, 'b': 3}}, raw_output=363), ToolOutput(content='405', tool_name='add', raw_input={'args': (), 'kwargs': {'a': 363, 'b': 42}}, raw_output=405)]\n" + ] + } + ], + "source": [ + "# inspect sources\n", + "print(response.sources)" + ] + }, + { + "attachments": {}, "cell_type": "markdown", "id": "fb33983c", "metadata": {}, @@ -362,7 +384,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "id": "1d1fc974", "metadata": {}, "outputs": [ @@ -370,14 +392,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "=== Calling Function ===\n", - "Calling function: multiply with args: {\n", - " \"a\": 121,\n", - " \"b\": 3\n", - "}\n", - "Got output: 363\n", - "========================\n", - "121 * 3 = 363\n" + "(121 * 3) + 42 is equal to 405.\n" ] } ], @@ -387,6 +402,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "aae035cb", "metadata": {}, @@ -397,7 +413,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "id": "14217fb2", "metadata": {}, "outputs": [ @@ -408,48 +424,41 @@ "=== Calling Function ===\n", "Calling function: multiply with args: {\n", " \"a\": 121,\n", - " \"b\": 2\n", + " \"b\": 3\n", "}\n", - "Got output: 242\n", + "Got output: 363\n", "========================\n", - "121 * 2 = 242\n", - "\n", - "Once upon a time, in a small village, there was a group of mice who lived happily in a cozy little burrow. The leader of the group was a wise and courageous mouse named Milo. Milo was known for his intelligence and his ability to solve problems.\n", + "121 * 2 is equal to 242.\n", "\n", - "One day, Milo gathered all the mice together and announced that they were facing a shortage of food. The mice were worried because winter was approaching, and they needed to find a way to gather enough food to survive the cold months ahead.\n", + "Once upon a time, in a small village, there was a group of mice who lived happily in a cozy little burrow. The leader of the group was a wise and courageous mouse named Max. Max was known for his intelligence and his ability to solve problems.\n", "\n", - "Milo came up with a brilliant plan. He divided the group into teams of two, pairing each mouse with another. Each pair was assigned the task of collecting as much food as possible. The mice were excited and motivated to contribute to the group's survival.\n", + "One day, Max gathered all the mice together and announced that they were facing a shortage of food. The mice had been relying on the village's grain storage, but it was running low. Max knew that they needed to find a solution quickly to ensure the survival of their group.\n", "\n", - "One pair of mice, named Bella and Max, set out on their mission. They scurried through fields and forests, searching for food. With their combined efforts, they were able to gather twice as much food as they had ever collected before. They found an abundance of nuts, seeds, and grains, enough to fill their tiny paws.\n", + "Inspired by the number 242, Max came up with a brilliant plan. He organized the mice into teams and assigned each team a specific task. Some mice were tasked with scouting for new food sources, while others were responsible for gathering and storing food. Max himself took charge of coordinating the efforts and ensuring that everything ran smoothly.\n", "\n", - "As Bella and Max returned to the burrow, they were greeted with cheers and applause from the other mice. The food they had collected would ensure that everyone in the group would have enough to eat during the winter.\n", + "The mice worked tirelessly, using their small size and agility to their advantage. They explored every nook and cranny of the village, searching for hidden food supplies. They discovered forgotten pantries, secret caches, and even managed to negotiate with the village bakery for leftover crumbs.\n", "\n", - "The success of Bella and Max inspired the other mice to work together and multiply their efforts. They formed more pairs and ventured out into different directions, exploring new territories and discovering hidden food sources. With each pair's contribution, the group's food supply grew exponentially.\n", + "With their combined efforts, the mice were able to gather enough food to sustain their group for the coming months. They celebrated their success with a grand feast, sharing stories of their adventures and expressing gratitude for their leader, Max.\n", "\n", - "Thanks to the multiplication of their efforts, the mice not only survived the winter but thrived. They had enough food to share with neighboring animals and even helped other creatures in need.\n", + "The story of the group of mice and their resourcefulness spread throughout the village, inspiring others to work together and find creative solutions to their own challenges. Max became a symbol of leadership and ingenuity, and the mice continued to thrive under his guidance.\n", "\n", - "The story of the group of mice taught everyone the power of collaboration and the importance of working together towards a common goal. It showed that by multiplying their efforts, they could overcome any challenge and achieve great things.\n", - "\n", - "And so, the mice lived happily ever after, always remembering the lesson they had learned about the power of multiplication and unity." + "And so, the mice of the village lived happily ever after, thanks to their determination, teamwork, and the number 242 that sparked their journey to find a new food source." ] } ], "source": [ - "agent_stream = agent.stream_chat(\n", + "response = agent.stream_chat(\n", " \"What is 121 * 2? Once you have the answer, use that number to write a story about a group of mice.\"\n", ")\n", - "for response in agent_stream:\n", - " response_gen = response.response_gen\n", - " # NOTE: here, we skip any intermediate steps and wait until the last response\n", - " # intermediate steps usually only contain function calls though\n", - " # for token in response_gen:\n", - " # print(token, end=\"\")\n", + "\n", + "response_gen = response.response_gen\n", "\n", "for token in response_gen:\n", " print(token, end=\"\")" ] }, { + "attachments": {}, "cell_type": "markdown", "id": "3fac119f", "metadata": {}, @@ -459,7 +468,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "33ea069f-819b-4ec1-a93c-fcbaacb362a1", "metadata": { "tags": [] @@ -469,52 +478,39 @@ "name": "stdout", "output_type": "stream", "text": [ - "=== Calling Function ===\n", - "Calling function: add with args: {\n", - " \"a\": 121,\n", - " \"b\": 8\n", - "}\n", - "Got output: 129\n", - "========================\n", - "121 + 8 = 129\n", - "\n", - "Once upon a time, in a lush green meadow, there was a group of mice who lived harmoniously in a cozy burrow. The leader of the group was a wise and compassionate mouse named Oliver. Oliver was known for his kindness and his ability to bring the mice together.\n", + "121 + 8 is equal to 129.\n", "\n", - "One sunny morning, as the mice were going about their daily activities, they noticed a group of birds flying overhead. The birds were migrating to a warmer climate, and they dropped a small bag of seeds as they passed by. The mice were excited to find this unexpected gift and quickly gathered around to examine it.\n", + "Once upon a time, in a small village nestled in the countryside, there was a group of mice who lived in a cozy little burrow beneath a towering oak tree. The mice were known for their unity and their ability to overcome challenges together.\n", "\n", - "Oliver, being the leader, suggested that they share the seeds among themselves. He believed that by working together, they could make the most of this fortunate event. The mice agreed and decided to divide the seeds equally among all the members of the group.\n", + "One sunny morning, as the mice were going about their daily activities, they stumbled upon a mysterious object hidden among the fallen leaves. It was a small, golden key with intricate engravings. Curiosity piqued, the mice decided to unlock the secret that lay before them.\n", "\n", - "As they distributed the seeds, they realized that there were 129 seeds in total. Each mouse received an equal share of 129 divided by the number of mice in the group. They were delighted to have this additional food source, which would help sustain them during the upcoming winter.\n", + "The leader of the mouse group, a wise and adventurous mouse named Mia, took the key in her tiny paws and inserted it into a hidden lock they discovered nearby. With a gentle turn, a hidden door creaked open, revealing a hidden treasure room filled with delicious treats and treasures.\n", "\n", - "The mice decided to plant the seeds in a nearby field, hoping to grow a bountiful harvest. They worked tirelessly, digging small holes and carefully placing the seeds in the ground. They watered the seeds and nurtured them with love and care.\n", + "The mice couldn't believe their eyes! The room was filled with an abundance of cheese, nuts, and grains. It was a feast fit for a king! Overwhelmed with joy, the mice wasted no time and began feasting on the delectable treats, filling their bellies with delight.\n", "\n", - "Days turned into weeks, and the mice watched with anticipation as tiny sprouts emerged from the soil. The sprouts grew into beautiful plants, bearing fruits and vegetables that the mice had never seen before. The field became a vibrant garden, providing an abundance of food for the entire group.\n", + "As they enjoyed their newfound treasure, the mice realized that they had stumbled upon a secret stash that had been left behind by a kind-hearted villager. Grateful for their good fortune, the mice decided to share their treasure with the other creatures of the village.\n", "\n", - "The mice celebrated their successful harvest with a grand feast. They enjoyed the fruits of their labor, grateful for the unity and cooperation that had led to their prosperity. Oliver, with tears of joy in his eyes, thanked each and every mouse for their contribution and reminded them of the power of togetherness.\n", + "Word quickly spread throughout the village about the generous mice and their treasure trove. Animals from far and wide flocked to the burrow beneath the oak tree, forming a harmonious community where everyone shared and cared for one another.\n", "\n", - "From that day forward, the mice continued to work together, sharing their resources and supporting one another. They thrived in their meadow, living in harmony and spreading their message of unity to other creatures they encountered.\n", + "The mice became known as the guardians of abundance, and their story was passed down through generations. The village prospered, and the spirit of unity and generosity lived on.\n", "\n", - "The story of the group of mice taught everyone the importance of coming together and sharing resources. It showed that by adding their efforts and working as a team, they could overcome challenges and achieve abundance. The mice became an inspiration to all, reminding everyone of the power of collaboration and the strength that lies in unity." + "And so, the mice and the other creatures of the village lived happily ever after, thanks to the key that unlocked a world of abundance and the number 129 that brought them together in harmony." ] } ], "source": [ - "chat_gen = agent.astream_chat(\n", + "response = await agent.astream_chat(\n", " \"What is 121 + 8? Once you have the answer, use that number to write a story about a group of mice.\"\n", ")\n", "\n", - "async for response in chat_gen:\n", - " response_gen = response.response_gen\n", - " # NOTE: here, we skip any intermediate steps and wait until the last response\n", - " # intermediate steps usually only contain function calls though\n", - " # for token in response_gen:\n", - " # print(token, end=\"\")\n", + "response_gen = response.response_gen\n", "\n", "for token in response_gen:\n", " print(token, end=\"\")" ] }, { + "attachments": {}, "cell_type": "markdown", "id": "2fe399c5-6d07-4926-b701-b612efd56b30", "metadata": {}, @@ -523,6 +519,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "8b47c034-f948-4604-a8d8-828b617ea245", "metadata": {}, @@ -532,7 +529,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 13, "id": "bef36d1e-c26e-4b07-b3d0-3b7f314a45f5", "metadata": { "tags": [] @@ -614,7 +611,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.9.6" } }, "nbformat": 4, diff --git a/llama_index/agent/context_retriever_agent.py b/llama_index/agent/context_retriever_agent.py index 07fdc95c0467b..3568b525a3ddb 100644 --- a/llama_index/agent/context_retriever_agent.py +++ b/llama_index/agent/context_retriever_agent.py @@ -9,13 +9,13 @@ BaseOpenAIAgent, ) from llama_index.bridge.langchain import print_text +from llama_index.chat_engine.types import AgentChatResponse from llama_index.callbacks.base import CallbackManager from llama_index.indices.base_retriever import BaseRetriever from llama_index.llms.base import ChatMessage from llama_index.llms.openai import OpenAI from llama_index.memory import BaseMemory, ChatMemoryBuffer from llama_index.prompts.prompts import QuestionAnswerPrompt -from llama_index.response.schema import RESPONSE_TYPE from llama_index.schema import NodeWithScore from llama_index.tools import BaseTool @@ -149,7 +149,7 @@ def _get_tools(self, message: str) -> List[BaseTool]: def chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> RESPONSE_TYPE: + ) -> AgentChatResponse: """Chat.""" # augment user message retrieved_nodes_w_scores: List[NodeWithScore] = self._retriever.retrieve( diff --git a/llama_index/agent/openai_agent.py b/llama_index/agent/openai_agent.py index a6d102b4c94e7..b24e32f12e8d3 100644 --- a/llama_index/agent/openai_agent.py +++ b/llama_index/agent/openai_agent.py @@ -4,21 +4,19 @@ from abc import abstractmethod from threading import Thread from typing import ( - AsyncGenerator, Callable, - Generator, List, Tuple, Type, Optional, ) -from llama_index.callbacks.base import CallbackManager from llama_index.chat_engine.types import ( BaseChatEngine, - StreamingChatResponse, - STREAMING_CHAT_RESPONSE_TYPE, + AgentChatResponse, + StreamingAgentChatResponse, ) +from llama_index.callbacks.base import CallbackManager from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.query.schema import QueryBundle @@ -30,7 +28,7 @@ from llama_index.memory import BaseMemory, ChatMemoryBuffer from llama_index.response.schema import RESPONSE_TYPE, Response from llama_index.schema import BaseNode, NodeWithScore -from llama_index.tools import BaseTool +from llama_index.tools import BaseTool, ToolOutput DEFAULT_MAX_FUNCTION_CALLS = 5 DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613" @@ -50,7 +48,7 @@ def get_function_by_name(tools: List[BaseTool], name: str) -> BaseTool: def call_function( tools: List[BaseTool], function_call: dict, verbose: bool = False -) -> ChatMessage: +) -> Tuple[ChatMessage, ToolOutput]: """Call a function and return the output as a string.""" name = function_call["name"] arguments_str = function_call["arguments"] @@ -61,14 +59,17 @@ def call_function( argument_dict = json.loads(arguments_str) output = tool(**argument_dict) if verbose: - print(f"Got output: {output}") + print(f"Got output: {str(output)}") print("========================") - return ChatMessage( - content=str(output), - role=MessageRole.FUNCTION, - additional_kwargs={ - "name": function_call["name"], - }, + return ( + ChatMessage( + content=str(output), + role=MessageRole.FUNCTION, + additional_kwargs={ + "name": function_call["name"], + }, + ), + output, ) @@ -117,11 +118,12 @@ def _init_chat(self, message: str) -> Tuple[List[BaseTool], List[dict]]: def chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> RESPONSE_TYPE: + ) -> AgentChatResponse: if chat_history is not None: self._memory.set(chat_history) tools, functions = self._init_chat(message) + sources = [] # TODO: Support forced function call all_messages = self._prefix_messages + self._memory.get() @@ -136,9 +138,10 @@ def chat( print(f"Exceeded max function calls: {self._max_function_calls}.") break - function_message = call_function( + function_message, tool_output = call_function( tools, function_call, verbose=self._verbose ) + sources.append(tool_output) self._memory.put(function_message) n_function_calls += 1 @@ -149,20 +152,56 @@ def chat( self._memory.put(ai_message) function_call = self._get_latest_function_call(self._memory.get_all()) - return Response(ai_message.content) + return AgentChatResponse(response=str(ai_message.content), sources=sources) def stream_chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> STREAMING_CHAT_RESPONSE_TYPE: + ) -> StreamingAgentChatResponse: if chat_history is not None: self._memory.set(chat_history) tools, functions = self._init_chat(message) + all_messages = self._prefix_messages + self._memory.get() + sources = [] + + # TODO: Support forced function call + chat_stream_response = StreamingAgentChatResponse( + chat_stream=self._llm.stream_chat(all_messages, functions=functions) + ) + + # Get the response in a separate thread so we can yield the response + thread = Thread( + target=chat_stream_response.write_response_to_history, + args=(self._memory,), + ) + thread.start() + + while chat_stream_response._is_function is None: + # Wait until we know if the response is a function call or not + time.sleep(0.05) + if chat_stream_response._is_function is False: + return chat_stream_response + + thread.join() + + n_function_calls = 0 + function_call = self._get_latest_function_call(self._memory.get_all()) + while function_call is not None: + if n_function_calls >= self._max_function_calls: + print(f"Exceeded max function calls: {self._max_function_calls}.") + break - def gen() -> Generator[StreamingChatResponse, None, None]: - # TODO: Support forced function call + function_message, tool_output = call_function( + tools, function_call, verbose=self._verbose + ) + sources.append(tool_output) + self._memory.put(function_message) + n_function_calls += 1 + + # send function call & output back to get another response all_messages = self._prefix_messages + self._memory.get() - chat_stream_response = StreamingChatResponse( - self._llm.stream_chat(all_messages, functions=functions) + chat_stream_response = StreamingAgentChatResponse( + chat_stream=self._llm.stream_chat(all_messages, functions=functions), + sources=sources, ) # Get the response in a separate thread so we can yield the response @@ -171,64 +210,27 @@ def gen() -> Generator[StreamingChatResponse, None, None]: args=(self._memory,), ) thread.start() - yield chat_stream_response - while chat_stream_response._is_function is None: # Wait until we know if the response is a function call or not time.sleep(0.05) if chat_stream_response._is_function is False: - return + return chat_stream_response thread.join() - - n_function_calls = 0 function_call = self._get_latest_function_call(self._memory.get_all()) - while function_call is not None: - if n_function_calls >= self._max_function_calls: - print(f"Exceeded max function calls: {self._max_function_calls}.") - break - function_message = call_function( - tools, function_call, verbose=self._verbose - ) - self._memory.put(function_message) - n_function_calls += 1 - - all_messages = self._prefix_messages + self._memory.get() - # send function call & output back to get another response - chat_stream_response = StreamingChatResponse( - self._llm.stream_chat(all_messages, functions=functions) - ) - - # Get the response in a separate thread so we can yield the response - thread = Thread( - target=chat_stream_response.write_response_to_history, - args=(self._memory,), - ) - thread.start() - yield chat_stream_response - - while chat_stream_response._is_function is None: - # Wait until we know if the response is a function call or not - time.sleep(0.05) - if chat_stream_response._is_function is False: - return - - thread.join() - function_call = self._get_latest_function_call(self._memory.get_all()) - - return gen() + return chat_stream_response async def achat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> RESPONSE_TYPE: + ) -> AgentChatResponse: if chat_history is not None: self._memory.set(chat_history) - + all_messages = self._prefix_messages + self._memory.get() tools, functions = self._init_chat(message) + sources = [] # TODO: Support forced function call - all_messages = self._prefix_messages + self._memory.get() chat_response = await self._llm.achat(all_messages, functions=functions) ai_message = chat_response.message self._memory.put(ai_message) @@ -240,9 +242,10 @@ async def achat( print(f"Exceeded max function calls: {self._max_function_calls}.") continue - function_message = call_function( + function_message, tool_output = call_function( tools, function_call, verbose=self._verbose ) + sources.append(tool_output) self._memory.put(function_message) n_function_calls += 1 @@ -254,21 +257,60 @@ async def achat( self._memory.put(ai_message) function_call = self._get_latest_function_call(self._memory.get_all()) - return Response(ai_message.content) + return AgentChatResponse(response=str(ai_message.content), sources=sources) async def astream_chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> STREAMING_CHAT_RESPONSE_TYPE: + ) -> StreamingAgentChatResponse: if chat_history is not None: self._memory.set(chat_history) - tools, functions = self._init_chat(message) + all_messages = self._prefix_messages + self._memory.get() + sources = [] + + # TODO: Support forced function call + chat_stream_response = StreamingAgentChatResponse( + achat_stream=await self._llm.astream_chat(all_messages, functions=functions) + ) + + # Get the response in a separate thread so we can yield the response + thread = Thread( + target=lambda x: asyncio.run( + chat_stream_response.awrite_response_to_history(x) + ), + args=(self._memory,), + ) + thread.start() + + while chat_stream_response._is_function is None: + # Wait until we know if the response is a function call or not + time.sleep(0.05) + if chat_stream_response._is_function is False: + return chat_stream_response + + thread.join() + + n_function_calls = 0 + function_call = self._get_latest_function_call(self._memory.get_all()) + while function_call is not None: + if n_function_calls >= self._max_function_calls: + print(f"Exceeded max function calls: {self._max_function_calls}.") + break - async def gen() -> AsyncGenerator[StreamingChatResponse, None]: + function_message, tool_output = call_function( + tools, function_call, verbose=self._verbose + ) + sources.append(tool_output) + self._memory.put(function_message) + n_function_calls += 1 + + # send function call & output back to get another response all_messages = self._prefix_messages + self._memory.get() - # TODO: Support forced function call - chat_stream_response = StreamingChatResponse( - await self._llm.astream_chat(all_messages, functions=functions) + chat_stream_response = StreamingAgentChatResponse( + achat_stream=await self._llm.astream_chat( + all_messages, functions=functions + ), + sources=sources, ) # Get the response in a separate thread so we can yield the response @@ -279,68 +321,32 @@ async def gen() -> AsyncGenerator[StreamingChatResponse, None]: args=(self._memory,), ) thread.start() - yield chat_stream_response while chat_stream_response._is_function is None: # Wait until we know if the response is a function call or not time.sleep(0.05) if chat_stream_response._is_function is False: - return + return chat_stream_response thread.join() - - n_function_calls = 0 function_call = self._get_latest_function_call(self._memory.get_all()) - while function_call is not None: - if n_function_calls >= self._max_function_calls: - print(f"Exceeded max function calls: {self._max_function_calls}.") - break - - function_message = call_function( - tools, function_call, verbose=self._verbose - ) - self._memory.put(function_message) - n_function_calls += 1 - - # send function call & output back to get another response - all_messages = self._prefix_messages + self._memory.get() - chat_stream_response = StreamingChatResponse( - await self._llm.astream_chat(all_messages, functions=functions) - ) - - # Get the response in a separate thread so we can yield the response - thread = Thread( - target=lambda x: asyncio.run( - chat_stream_response.awrite_response_to_history(x) - ), - args=(self._memory,), - ) - thread.start() - yield chat_stream_response - - while chat_stream_response._is_function is None: - # Wait until we know if the response is a function call or not - time.sleep(0.05) - if chat_stream_response._is_function is False: - return - - thread.join() - function_call = self._get_latest_function_call(self._memory.get_all()) - return gen() + return chat_stream_response # ===== Query Engine Interface ===== def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: - return self.chat( + agent_response = self.chat( query_bundle.query_str, chat_history=[], ) + return Response(response=str(agent_response)) async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: - return await self.achat( + agent_response = await self.achat( query_bundle.query_str, chat_history=[], ) + return Response(response=str(agent_response)) class OpenAIAgent(BaseOpenAIAgent): diff --git a/llama_index/chat_engine/condense_question.py b/llama_index/chat_engine/condense_question.py index 5118629407027..306daa4098935 100644 --- a/llama_index/chat_engine/condense_question.py +++ b/llama_index/chat_engine/condense_question.py @@ -1,7 +1,11 @@ import logging from typing import Any, List, Type, Optional -from llama_index.chat_engine.types import BaseChatEngine, STREAMING_CHAT_RESPONSE_TYPE +from llama_index.chat_engine.types import ( + BaseChatEngine, + AgentChatResponse, + StreamingAgentChatResponse, +) from llama_index.chat_engine.utils import response_gen_with_chat_history from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.service_context import ServiceContext @@ -9,7 +13,8 @@ from llama_index.llms.generic_utils import messages_to_history_str from llama_index.memory import BaseMemory, ChatMemoryBuffer from llama_index.prompts.base import Prompt -from llama_index.response.schema import RESPONSE_TYPE, StreamingResponse +from llama_index.response.schema import StreamingResponse, RESPONSE_TYPE +from llama_index.tools import ToolOutput logger = logging.getLogger(__name__) @@ -123,9 +128,19 @@ async def _acondense_question( ) return response + def _get_tool_output_from_response( + self, query: str, response: RESPONSE_TYPE + ) -> ToolOutput: + return ToolOutput( + content=str(response), + tool_name="query_engine", + raw_input={"query": query}, + raw_output=response, + ) + def chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> RESPONSE_TYPE: + ) -> AgentChatResponse: chat_history = chat_history or self._memory.get() # Generate standalone question from conversation context and last message @@ -137,17 +152,22 @@ def chat( print(log_str) # Query with standalone question - response = self._query_engine.query(condensed_question) + query_response = self._query_engine.query(condensed_question) + tool_output = self._get_tool_output_from_response( + condensed_question, query_response + ) # Record response self._memory.put(ChatMessage(role=MessageRole.USER, content=message)) - self._memory.put(ChatMessage(role=MessageRole.ASSISTANT, content=str(response))) + self._memory.put( + ChatMessage(role=MessageRole.ASSISTANT, content=str(query_response)) + ) - return response + return AgentChatResponse(response=str(query_response), sources=[tool_output]) def stream_chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> STREAMING_CHAT_RESPONSE_TYPE: + ) -> StreamingAgentChatResponse: chat_history = chat_history or self._memory.get() # Generate standalone question from conversation context and last message @@ -159,20 +179,22 @@ def stream_chat( print(log_str) # Query with standalone question - response = self._query_engine.query(condensed_question) + query_response = self._query_engine.query(condensed_question) + tool_output = self._get_tool_output_from_response( + condensed_question, query_response + ) # Record response if ( - isinstance(response, StreamingResponse) - and response.response_gen is not None + isinstance(query_response, StreamingResponse) + and query_response.response_gen is not None ): # override the generator to include writing to chat history - response = StreamingResponse( - response_gen_with_chat_history( - message, self._memory, response.response_gen + response = StreamingAgentChatResponse( + chat_stream=response_gen_with_chat_history( + message, self._memory, query_response.response_gen ), - source_nodes=response.source_nodes, - metadata=response.metadata, + sources=[tool_output], ) else: raise ValueError("Streaming is not enabled. Please use chat() instead.") @@ -180,7 +202,7 @@ def stream_chat( async def achat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> RESPONSE_TYPE: + ) -> AgentChatResponse: chat_history = chat_history or self._memory.get() # Generate standalone question from conversation context and last message @@ -192,17 +214,22 @@ async def achat( print(log_str) # Query with standalone question - response = await self._query_engine.aquery(condensed_question) + query_response = await self._query_engine.aquery(condensed_question) + tool_output = self._get_tool_output_from_response( + condensed_question, query_response + ) # Record response self._memory.put(ChatMessage(role=MessageRole.USER, content=message)) - self._memory.put(ChatMessage(role=MessageRole.ASSISTANT, content=str(response))) + self._memory.put( + ChatMessage(role=MessageRole.ASSISTANT, content=str(query_response)) + ) - return response + return AgentChatResponse(response=str(query_response), sources=[tool_output]) async def astream_chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> STREAMING_CHAT_RESPONSE_TYPE: + ) -> StreamingAgentChatResponse: chat_history = chat_history or self._memory.get() # Generate standalone question from conversation context and last message @@ -214,20 +241,23 @@ async def astream_chat( print(log_str) # Query with standalone question - response = await self._query_engine.aquery(condensed_question) + query_response = await self._query_engine.aquery(condensed_question) + tool_output = self._get_tool_output_from_response( + condensed_question, query_response + ) # Record response if ( - isinstance(response, StreamingResponse) - and response.response_gen is not None + isinstance(query_response, StreamingResponse) + and query_response.response_gen is not None ): # override the generator to include writing to chat history - response = StreamingResponse( - response_gen_with_chat_history( - message, self._memory, response.response_gen + # TODO: query engine does not support async generator yet + response = StreamingAgentChatResponse( + chat_stream=response_gen_with_chat_history( + message, self._memory, query_response.response_gen ), - source_nodes=response.source_nodes, - metadata=response.metadata, + sources=[tool_output], ) else: raise ValueError("Streaming is not enabled. Please use achat() instead.") diff --git a/llama_index/chat_engine/react.py b/llama_index/chat_engine/react.py index 5e48e421e5244..9ac8d57a6b311 100644 --- a/llama_index/chat_engine/react.py +++ b/llama_index/chat_engine/react.py @@ -5,7 +5,11 @@ ChatMessageHistory, ConversationBufferMemory, ) -from llama_index.chat_engine.types import BaseChatEngine, STREAMING_CHAT_RESPONSE_TYPE +from llama_index.chat_engine.types import ( + BaseChatEngine, + AgentChatResponse, + StreamingAgentChatResponse, +) from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.service_context import ServiceContext from llama_index.langchain_helpers.agents.agents import ( @@ -21,7 +25,6 @@ is_chat_model, to_lc_messages, ) -from llama_index.response.schema import RESPONSE_TYPE, Response from llama_index.tools.query_engine import QueryEngineTool @@ -148,34 +151,34 @@ def chat_history(self) -> List[ChatMessage]: def chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> RESPONSE_TYPE: + ) -> AgentChatResponse: if chat_history is not None: raise NotImplementedError( "chat_history argument is not supported for ReActChatEngine." ) response = self._agent.run(input=message) - return Response(response=response) + return AgentChatResponse(response=response) async def achat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> RESPONSE_TYPE: + ) -> AgentChatResponse: if chat_history is not None: raise NotImplementedError( "chat_history argument is not supported for ReActChatEngine." ) response = await self._agent.arun(input=message) - return Response(response=response) + return AgentChatResponse(response=response) def stream_chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> STREAMING_CHAT_RESPONSE_TYPE: + ) -> StreamingAgentChatResponse: raise NotImplementedError("stream_chat() is not supported for ReActChatEngine.") async def astream_chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> STREAMING_CHAT_RESPONSE_TYPE: + ) -> StreamingAgentChatResponse: raise NotImplementedError( "astream_chat() is not supported for ReActChatEngine." ) diff --git a/llama_index/chat_engine/simple.py b/llama_index/chat_engine/simple.py index 53da4d363898e..8e71db486e17d 100644 --- a/llama_index/chat_engine/simple.py +++ b/llama_index/chat_engine/simple.py @@ -4,14 +4,13 @@ from llama_index.chat_engine.types import ( BaseChatEngine, - StreamingChatResponse, - STREAMING_CHAT_RESPONSE_TYPE, + AgentChatResponse, + StreamingAgentChatResponse, ) from llama_index.indices.service_context import ServiceContext from llama_index.llm_predictor.base import LLMPredictor from llama_index.llms.base import LLM, ChatMessage from llama_index.memory import BaseMemory, ChatMemoryBuffer -from llama_index.response.schema import RESPONSE_TYPE, Response, StreamingResponse class SimpleChatEngine(BaseChatEngine): @@ -64,7 +63,7 @@ def from_defaults( def chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> RESPONSE_TYPE: + ) -> AgentChatResponse: if chat_history is not None: self._memory.set(chat_history) self._memory.put(ChatMessage(content=message, role="user")) @@ -74,27 +73,29 @@ def chat( ai_message = chat_response.message self._memory.put(ai_message) - return Response(response=chat_response.message.content) + return AgentChatResponse(response=str(chat_response.message.content)) def stream_chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> STREAMING_CHAT_RESPONSE_TYPE: + ) -> StreamingAgentChatResponse: if chat_history is not None: self._memory.set(chat_history) self._memory.put(ChatMessage(content=message, role="user")) all_messages = self._prefix_messages + self._memory.get() - chat_response = StreamingChatResponse(self._llm.stream_chat(all_messages)) + chat_response = StreamingAgentChatResponse( + chat_stream=self._llm.stream_chat(all_messages) + ) thread = Thread( target=chat_response.write_response_to_history, args=(self._memory,) ) thread.start() - return StreamingResponse(response_gen=chat_response.response_gen) + return chat_response async def achat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> RESPONSE_TYPE: + ) -> AgentChatResponse: if chat_history is not None: self._memory.set(chat_history) self._memory.put(ChatMessage(content=message, role="user")) @@ -104,24 +105,26 @@ async def achat( ai_message = chat_response.message self._memory.put(ai_message) - return Response(response=chat_response.message.content) + return AgentChatResponse(response=str(chat_response.message.content)) async def astream_chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> STREAMING_CHAT_RESPONSE_TYPE: + ) -> StreamingAgentChatResponse: if chat_history is not None: self._memory.set(chat_history) self._memory.put(ChatMessage(content=message, role="user")) all_messages = self._prefix_messages + self._memory.get() - chat_response = StreamingChatResponse(self._llm.stream_chat(all_messages)) + chat_response = StreamingAgentChatResponse( + chat_stream=self._llm.stream_chat(all_messages) + ) thread = Thread( target=lambda x: asyncio.run(chat_response.awrite_response_to_history(x)), args=(self._memory,), ) thread.start() - return StreamingResponse(response_gen=chat_response.response_gen) + return chat_response def reset(self) -> None: self._memory.reset() diff --git a/llama_index/chat_engine/types.py b/llama_index/chat_engine/types.py index 3d808ed9f0ee5..b29e4ed132ae9 100644 --- a/llama_index/chat_engine/types.py +++ b/llama_index/chat_engine/types.py @@ -1,42 +1,54 @@ import logging import queue from abc import ABC, abstractmethod +from dataclasses import dataclass, field from enum import Enum -from typing import AsyncGenerator, Generator, List, Optional, Union +from typing import Generator, List, Optional +from llama_index.tools import ToolOutput from llama_index.llms.base import ChatMessage, ChatResponseGen, ChatResponseAsyncGen from llama_index.memory import BaseMemory -from llama_index.response.schema import RESPONSE_TYPE, StreamingResponse logger = logging.getLogger(__name__) -class StreamingChatResponse: +@dataclass +class AgentChatResponse: + """Agent chat response.""" + + response: str = "" + sources: List[ToolOutput] = field(default_factory=list) + + def __str__(self) -> str: + return self.response + + +@dataclass +class StreamingAgentChatResponse: """Streaming chat response to user and writing to chat history.""" - def __init__( - self, chat_stream: Union[ChatResponseGen, ChatResponseAsyncGen] - ) -> None: - self._chat_stream = chat_stream - self._queue: queue.Queue = queue.Queue() - self._is_done = False - self._is_function: Optional[bool] = None - self.response_str = "" + response: str = "" + sources: List[ToolOutput] = field(default_factory=list) + chat_stream: Optional[ChatResponseGen] = None + achat_stream: Optional[ChatResponseAsyncGen] = None + _queue: queue.Queue = queue.Queue() + _is_done = False + _is_function: Optional[bool] = None def __str__(self) -> str: if self._is_done and not self._queue.empty() and not self._is_function: for delta in self._queue.queue: - self.response_str += delta - return self.response_str + self.response += delta + return self.response def write_response_to_history(self, memory: BaseMemory) -> None: - if isinstance(self._chat_stream, AsyncGenerator): + if self.chat_stream is None: raise ValueError( - "Cannot write to history with async generator in sync function." + "chat_stream is None. Cannot write to history without chat_stream." ) final_message = None - for chat in self._chat_stream: + for chat in self.chat_stream: final_message = chat.message self._is_function = ( final_message.additional_kwargs.get("function_call", None) is not None @@ -49,13 +61,14 @@ def write_response_to_history(self, memory: BaseMemory) -> None: self._is_done = True async def awrite_response_to_history(self, memory: BaseMemory) -> None: - if isinstance(self._chat_stream, Generator): + if self.achat_stream is None: raise ValueError( - "Cannot write to history with sync generator in async function." + "achat_stream is None. Cannot asynchronously write to " + "history without achat_stream." ) final_message = None - async for chat in self._chat_stream: + async for chat in self.achat_stream: final_message = chat.message self._is_function = ( final_message.additional_kwargs.get("function_call", None) is not None @@ -72,21 +85,13 @@ def response_gen(self) -> Generator[str, None, None]: while not self._is_done or not self._queue.empty(): try: delta = self._queue.get(block=False) - self.response_str += delta + self.response += delta yield delta except queue.Empty: # Queue is empty, but we're not done yet continue -STREAMING_CHAT_RESPONSE_TYPE = Union[ - StreamingResponse, - StreamingChatResponse, - Generator[StreamingChatResponse, None, None], - AsyncGenerator[StreamingChatResponse, None], -] - - class BaseChatEngine(ABC): """Base Chat Engine.""" @@ -98,28 +103,28 @@ def reset(self) -> None: @abstractmethod def chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> RESPONSE_TYPE: + ) -> AgentChatResponse: """Main chat interface.""" pass @abstractmethod def stream_chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> STREAMING_CHAT_RESPONSE_TYPE: + ) -> StreamingAgentChatResponse: """Stream chat interface.""" pass @abstractmethod async def achat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> RESPONSE_TYPE: + ) -> AgentChatResponse: """Async version of main chat interface.""" pass @abstractmethod async def astream_chat( self, message: str, chat_history: Optional[List[ChatMessage]] = None - ) -> STREAMING_CHAT_RESPONSE_TYPE: + ) -> StreamingAgentChatResponse: """Async version of main chat interface.""" pass diff --git a/llama_index/chat_engine/utils.py b/llama_index/chat_engine/utils.py index 3873a733b4ab8..df444c478a81d 100644 --- a/llama_index/chat_engine/utils.py +++ b/llama_index/chat_engine/utils.py @@ -1,15 +1,22 @@ -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.base import ( + ChatMessage, + MessageRole, + ChatResponse, + ChatResponseGen, +) from llama_index.memory import BaseMemory from llama_index.types import TokenGen def response_gen_with_chat_history( message: str, memory: BaseMemory, response_gen: TokenGen -) -> TokenGen: +) -> ChatResponseGen: response_str = "" for token in response_gen: response_str += token - yield token + yield ChatResponse( + role=MessageRole.ASSISTANT, content=response_str, delta=token + ) # Record response memory.put(ChatMessage(role=MessageRole.USER, content=message)) diff --git a/llama_index/tools/__init__.py b/llama_index/tools/__init__.py index 7401454b1556e..6a12c6b4a9421 100644 --- a/llama_index/tools/__init__.py +++ b/llama_index/tools/__init__.py @@ -1,7 +1,7 @@ """Tools.""" from llama_index.tools.query_engine import QueryEngineTool -from llama_index.tools.types import BaseTool, ToolMetadata +from llama_index.tools.types import BaseTool, ToolMetadata, ToolOutput from llama_index.tools.function_tool import FunctionTool from llama_index.tools.query_plan import QueryPlanTool @@ -9,6 +9,7 @@ "BaseTool", "QueryEngineTool", "ToolMetadata", + "ToolOutput", "FunctionTool", "QueryPlanTool", ] diff --git a/llama_index/tools/function_tool.py b/llama_index/tools/function_tool.py index 698f53a781af9..0640cffbb68e1 100644 --- a/llama_index/tools/function_tool.py +++ b/llama_index/tools/function_tool.py @@ -1,7 +1,7 @@ from typing import Any, Optional, Callable, Type from pydantic import BaseModel -from llama_index.tools.types import BaseTool, ToolMetadata +from llama_index.tools.types import BaseTool, ToolMetadata, ToolOutput from llama_index.bridge.langchain import Tool, StructuredTool from inspect import signature from llama_index.tools.utils import create_schema_from_function @@ -50,9 +50,15 @@ def fn(self) -> Callable[..., Any]: """Function.""" return self._fn - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> ToolOutput: """Call.""" - return self._fn(*args, **kwargs) + tool_output = self._fn(*args, **kwargs) + return ToolOutput( + content=str(tool_output), + tool_name=self.metadata.name, + raw_input={"args": args, "kwargs": kwargs}, + raw_output=tool_output, + ) def to_langchain_tool( self, diff --git a/llama_index/tools/ondemand_loader_tool.py b/llama_index/tools/ondemand_loader_tool.py index a82cc1448d515..1299f499678b0 100644 --- a/llama_index/tools/ondemand_loader_tool.py +++ b/llama_index/tools/ondemand_loader_tool.py @@ -5,7 +5,7 @@ """ -from llama_index.tools.types import BaseTool, ToolMetadata +from llama_index.tools.types import BaseTool, ToolMetadata, ToolOutput from llama_index.readers.base import BaseReader from typing import Any, Optional, Dict, Type, Callable, List from llama_index.readers.schema.base import Document @@ -107,7 +107,7 @@ def from_tool( ) metadata = ToolMetadata(name=name, description=description, fn_schema=fn_schema) return cls( - loader=tool, + loader=tool._fn, index_cls=index_cls, index_kwargs=index_kwargs, use_query_str_in_loader=use_query_str_in_loader, @@ -115,7 +115,7 @@ def from_tool( metadata=metadata, ) - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> ToolOutput: """Call.""" if self._query_str_kwargs_key not in kwargs: raise ValueError( @@ -133,4 +133,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: # TODO: add query kwargs query_engine = index.as_query_engine() response = query_engine.query(query_str) - return str(response) + return ToolOutput( + content=str(response), + tool_name=self.metadata.name, + raw_input={"query": query_str}, + raw_output=response, + ) diff --git a/llama_index/tools/query_engine.py b/llama_index/tools/query_engine.py index 90d4f4bbb8d22..ff1b1d6fdd5f5 100644 --- a/llama_index/tools/query_engine.py +++ b/llama_index/tools/query_engine.py @@ -2,7 +2,7 @@ from llama_index.indices.query.base import BaseQueryEngine from llama_index.langchain_helpers.agents.tools import IndexToolConfig, LlamaIndexTool -from llama_index.tools.types import BaseTool, ToolMetadata +from llama_index.tools.types import BaseTool, ToolMetadata, ToolOutput DEFAULT_NAME = "Query Engine Tool" DEFAULT_DESCRIPTION = """Useful for running a natural language query @@ -48,10 +48,15 @@ def query_engine(self) -> BaseQueryEngine: def metadata(self) -> ToolMetadata: return self._metadata - def __call__(self, input: Any) -> Any: + def __call__(self, input: Any) -> ToolOutput: query_str = cast(str, input) response = self._query_engine.query(query_str) - return str(response) + return ToolOutput( + content=str(response), + tool_name=self.metadata.name, + raw_input={"input": input}, + raw_output=response, + ) def as_langchain_tool(self) -> LlamaIndexTool: tool_config = IndexToolConfig( diff --git a/llama_index/tools/query_plan.py b/llama_index/tools/query_plan.py index c88a80260fcaa..69edbf089b5f3 100644 --- a/llama_index/tools/query_plan.py +++ b/llama_index/tools/query_plan.py @@ -6,8 +6,7 @@ from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer from llama_index.schema import NodeWithScore, TextNode -from llama_index.tools.types import BaseTool -from llama_index.tools.types import ToolMetadata +from llama_index.tools.types import BaseTool, ToolMetadata, ToolOutput DEFAULT_NAME = "query_plan_tool" @@ -137,7 +136,9 @@ def metadata(self) -> ToolMetadata: return metadata - def _execute_node(self, node: QueryNode, nodes_dict: Dict[int, QueryNode]) -> str: + def _execute_node( + self, node: QueryNode, nodes_dict: Dict[int, QueryNode] + ) -> ToolOutput: """Execute node.""" print_text(f"Executing node {node.json()}\n", color="blue") if len(node.dependencies) > 0: @@ -148,7 +149,7 @@ def _execute_node(self, node: QueryNode, nodes_dict: Dict[int, QueryNode]) -> st nodes_dict[dep] for dep in node.dependencies ] # execute the child nodes first - child_responses: List[str] = [ + child_responses: List[ToolOutput] = [ self._execute_node(child, nodes_dict) for child in child_query_nodes ] # form the child Node/NodeWithScore objects @@ -158,7 +159,7 @@ def _execute_node(self, node: QueryNode, nodes_dict: Dict[int, QueryNode]) -> st ): node_text = ( f"Query: {child_query_node.query_str}\n" - f"Response: {child_response}\n" + f"Response: {str(child_response)}\n" ) child_node = TextNode(text=node_text) child_nodes.append(child_node) @@ -170,7 +171,12 @@ def _execute_node(self, node: QueryNode, nodes_dict: Dict[int, QueryNode]) -> st query=node.query_str, nodes=child_nodes_with_scores, ) - response = str(response_obj) + response = ToolOutput( + content=str(response_obj), + tool_name=node.query_str, + raw_input={"query": node.query_str}, + raw_output=response_obj, + ) else: # this is a leaf request, execute the query string using the specified tool @@ -180,7 +186,7 @@ def _execute_node(self, node: QueryNode, nodes_dict: Dict[int, QueryNode]) -> st print_text( "Executed query, got response.\n" f"Query: {node.query_str}\n" - f"Response: {response}\n", + f"Response: {str(response)}\n", color="blue", ) return response @@ -197,7 +203,7 @@ def _find_root_nodes(self, nodes_dict: Dict[int, QueryNode]) -> List[QueryNode]: ] return [nodes_dict[node_id] for node_id in root_node_ids] - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> ToolOutput: """Call.""" # the kwargs represented as a JSON object # should be a QueryPlan object diff --git a/llama_index/tools/types.py b/llama_index/tools/types.py index db20598aedfbf..5722979aa52bc 100644 --- a/llama_index/tools/types.py +++ b/llama_index/tools/types.py @@ -32,6 +32,19 @@ def to_openai_function(self) -> Dict[str, Any]: } +class ToolOutput(BaseModel): + """Tool output.""" + + content: str + tool_name: str + raw_input: Dict[str, Any] + raw_output: Any + + def __str__(self) -> str: + """String.""" + return str(self.content) + + class BaseTool: @property @abstractmethod @@ -39,7 +52,7 @@ def metadata(self) -> ToolMetadata: pass @abstractmethod - def __call__(self, input: Any) -> Any: + def __call__(self, input: Any) -> ToolOutput: pass def _process_langchain_tool_kwargs( diff --git a/tests/llms/test_anthropic.py b/tests/llms/test_anthropic.py index ab3003dce6291..938468f279ca2 100644 --- a/tests/llms/test_anthropic.py +++ b/tests/llms/test_anthropic.py @@ -5,7 +5,7 @@ try: import anthropic except ImportError: - anthropic = None + anthropic = None # type: ignore @pytest.mark.skipif(anthropic is None, reason="anthropic not installed") diff --git a/tests/objects/test_node_mapping.py b/tests/objects/test_node_mapping.py index dd40fd7ef8b71..638295ccb700c 100644 --- a/tests/objects/test_node_mapping.py +++ b/tests/objects/test_node_mapping.py @@ -55,7 +55,7 @@ def test_tool_object_node_mapping() -> None: "Tool name: test_tool2\n" "Tool description: test\n" ) in node_mapping.to_node(tool2).get_text() recon_tool2 = node_mapping.from_node(node_mapping.to_node(tool2)) - assert recon_tool2(1, 2) == 3 + assert recon_tool2(1, 2).raw_output == 3 tool3 = FunctionTool.from_defaults( fn=lambda x, y: x * y, name="test_tool3", description="test3" diff --git a/tests/tools/test_base.py b/tests/tools/test_base.py index f91aa2fefe744..8dfd318f9f364 100644 --- a/tests/tools/test_base.py +++ b/tests/tools/test_base.py @@ -18,7 +18,7 @@ def test_function_tool() -> None: assert "x" in actual_schema["properties"] result = function_tool(1) - assert result == "1" + assert str(result) == "1" # test adding typing to function def tmp_function(x: int) -> str: @@ -48,7 +48,7 @@ class TestSchema(BaseModel): description="bar", fn_schema=TestSchema, ) - assert function_tool(1, 2) == "1,2" + assert str(function_tool(1, 2)) == "1,2" langchain_tool2 = function_tool.to_langchain_structured_tool() assert langchain_tool2.run({"x": 1, "y": 2}) == "1,2" assert langchain_tool2.args_schema == TestSchema diff --git a/tests/tools/test_ondemand_loader.py b/tests/tools/test_ondemand_loader.py index c2db6c5297208..b84e5df229b52 100644 --- a/tests/tools/test_ondemand_loader.py +++ b/tests/tools/test_ondemand_loader.py @@ -35,10 +35,10 @@ class TestSchemaSpec(BaseModel): fn_schema=TestSchemaSpec, ) response = tool(["Hello world."], query_str="What is?") - assert response == "What is?:Hello world." + assert str(response) == "What is?:Hello world." # convert tool to structured langchain tool lc_tool = tool.to_langchain_structured_tool() assert lc_tool.args_schema == TestSchemaSpec response = lc_tool.run({"texts": ["Hello world."], "query_str": "What is?"}) - assert response == "What is?:Hello world." + assert str(response) == "What is?:Hello world." diff --git a/tests/tools/tool_spec/test_base.py b/tests/tools/tool_spec/test_base.py index 8a6aad7302846..b8702161a719b 100644 --- a/tests/tools/tool_spec/test_base.py +++ b/tests/tools/tool_spec/test_base.py @@ -57,7 +57,7 @@ def test_tool_spec() -> None: assert tools[0].fn("hello", 1) == "foo hello 1" assert tools[1].metadata.name == "bar" assert tools[1].metadata.description == "bar(arg1: bool) -> str\nBar." - assert tools[1](True) == "bar True" + assert str(tools[1](True)) == "bar True" assert tools[2].metadata.name == "abc" assert tools[2].metadata.description == "abc(arg1: str) -> str\n" assert tools[2].metadata.fn_schema == AbcSchema