diff --git a/main.py b/main.py index 3b05882..76e3655 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ import uuid import uvicorn +from functools import partial from fastapi import FastAPI, UploadFile from fastapi.encoders import jsonable_encoder from fastapi.responses import PlainTextResponse @@ -17,6 +18,9 @@ parser.add_argument('--langchain', action='store_true') parser.add_argument('--towhee', action='store_true') parser.add_argument('--moniter', action='store_true') +parser.add_argument('--agent', action='store_true', + help='The default is False, which only works when `--langchain` is enabled.' + ' It means using the agent in langchain to dynamically select tools.') parser.add_argument('--max_observation', default=1000) parser.add_argument('--name', default=str(uuid.uuid4())) args = parser.parse_args() @@ -29,6 +33,7 @@ USE_TOWHEE = args.towhee MAX_OBSERVATION = args.max_observation ENABLE_MONITER = args.moniter +ENABLE_AGENT = args.agent NAME = args.name assert (USE_LANGCHAIN and not USE_TOWHEE ) or (USE_TOWHEE and not USE_LANGCHAIN), \ @@ -36,6 +41,7 @@ if USE_LANGCHAIN: from src_langchain.operations import chat, insert, drop, check, get_history, clear_history, count # pylint: disable=C0413 + chat = partial(chat, enable_agent=ENABLE_AGENT) if USE_TOWHEE: from src_towhee.operations import chat, insert, drop, check, get_history, clear_history, count # pylint: disable=C0413 if ENABLE_MONITER: diff --git a/requirements.txt b/requirements.txt index 35c55cf..f73fc10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -langchain==0.0.230 +langchain==0.0.322 unstructured pexpect pdf2image diff --git a/src_langchain/agent/chat_agent.py b/src_langchain/agent/chat_agent.py index 2b85ce7..e8a2c0d 100644 --- a/src_langchain/agent/chat_agent.py +++ b/src_langchain/agent/chat_agent.py @@ -1,11 +1,11 @@ from typing import Any, List, Optional, Sequence -from pydantic import Field -from langchain.agents.conversational_chat.prompt import PREFIX, SUFFIX +from pydantic import Field +from langchain.schema.language_model import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.agents import ConversationalChatAgent, AgentOutputParser, Agent from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseOutputParser, BaseLanguageModel +from langchain.schema import BaseOutputParser from langchain.tools.base import BaseTool from .output_parser import OutputParser diff --git a/src_langchain/data_loader/data_parser.py b/src_langchain/data_loader/data_parser.py index 5efa94c..132cade 100644 --- a/src_langchain/data_loader/data_parser.py +++ b/src_langchain/data_loader/data_parser.py @@ -42,7 +42,7 @@ def __call__(self, data_src, source_type: str = 'file') -> List[str]: token_count += len(self.enc.encode(doc)) return docs, token_count - def from_files(self, files: list, encoding: Optional[str] = None) -> List[Document]: + def from_files(self, files: list, encoding: Optional[str] = 'utf-8') -> List[Document]: '''Load documents from path or file-like object, return a list of unsplit LangChain Documents''' docs = [] for file in files: diff --git a/src_langchain/operations.py b/src_langchain/operations.py index ff6cce7..8d6ab61 100644 --- a/src_langchain/operations.py +++ b/src_langchain/operations.py @@ -4,6 +4,7 @@ from typing import List from langchain.agents import Tool, AgentExecutor +from langchain.chains import ConversationalRetrievalChain sys.path.append(os.path.dirname(__file__)) @@ -13,7 +14,6 @@ from store import MemoryStore, DocStore # pylint: disable=C0413 from data_loader import DataParser # pylint: disable=C0413 - logger = logging.getLogger(__name__) encoder = TextEncoder() @@ -21,7 +21,7 @@ load_data = DataParser() -def chat(session_id, project, question): +def chat(session_id, project, question, enable_agent=False): '''Chat API''' doc_db = DocStore( table_name=project, @@ -29,25 +29,37 @@ def chat(session_id, project, question): ) memory_db = MemoryStore(table_name=project, session_id=session_id) - tools = [ - Tool( - name='Search', - func=doc_db.search, - description='Search through Milvus.' + if enable_agent: # use agent + memory_db.memory.output_key = None + tools = [ + Tool( + name='Search', + func=doc_db.search, + description='useful for search professional knowledge and information' + ) + ] + agent = ChatAgent.from_llm_and_tools(llm=chat_llm, tools=tools) + agent_chain = AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + memory=memory_db.memory, + verbose=False ) - ] - agent = ChatAgent.from_llm_and_tools(llm=chat_llm, tools=tools) - agent_chain = AgentExecutor.from_agent_and_tools( - agent=agent, - tools=tools, - memory=memory_db.memory, - verbose=False - ) - try: - final_answer = agent_chain.run(input=question) - return question, final_answer - except Exception as e: # pylint: disable=W0703 - return question, f'Something went wrong:\n{e}' + try: + final_answer = agent_chain.run(input=question) + return question, final_answer + except Exception as e: # pylint: disable=W0703 + return question, f'Something went wrong:\n{e}' + else: # use chain + memory_db.memory.output_key = 'answer' + qa = ConversationalRetrievalChain.from_llm( + llm=chat_llm, + retriever=doc_db.vector_db.as_retriever(), + memory=memory_db.memory, + return_generated_question=True + ) + qa_result = qa(question) + return qa_result['generated_question'], qa_result['answer'] def insert(data_src, project, source_type: str = 'file'): @@ -93,6 +105,7 @@ def check(project): raise RuntimeError from e return {'store': doc_check, 'memory': memory_check} + def count(project): '''Count entities.''' try: @@ -130,7 +143,6 @@ def load(document_strs: List[str], project: str): num = doc_db.insert(document_strs) return num - # if __name__ == '__main__': # project = 'akcio' # data_src = 'https://docs.towhee.io/'