-
Notifications
You must be signed in to change notification settings - Fork 14
/
agents.py
77 lines (58 loc) · 3.71 KB
/
agents.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from typing import List
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.agents import initialize_agent, AgentType
from langchain.schema import SystemMessage
from utils import convert_chat_history_to_normal_data_structure, format_messages_history, convert_entities_to_formatted_string
from memory import messages_history, summarizer, entities, summaries
from llm import llm_non_stream
class AICompanionAgent:
    """Conversational agent that answers user input via an LLM chain while
    maintaining a shared message history, entity store, and rolling summaries.

    Relies on module-level state imported from `memory` (messages_history,
    summarizer, entities, summaries) and the `llm_non_stream` model.
    """

    def __init__(self, prompt_template, messages_history_threshold: int = 15, verbose: bool = False):
        """
        :param prompt_template: template string expecting the `entities`,
            `input`, `messages_history` and `summary` variables.
        :param messages_history_threshold: both the size of the history window
            fed to the prompt and the number of turns between summary refreshes.
        :param verbose: whether the underlying chain logs its prompts.
        """
        self.prompt = PromptTemplate.from_template(template=prompt_template)
        # Fix: honor the `verbose` argument (was hard-coded to True, so the
        # chain was verbose before the first talk() call regardless of it).
        self.conversation_chain = LLMChain(
            llm=llm_non_stream, verbose=verbose, prompt=self.prompt, output_key="response")
        self.messages_history_threshold = messages_history_threshold
        # Number of turns since the last summary refresh.
        self.messages_history_counter = 0
        self.verbose = verbose

    def talk(self, user_input: str) -> str:
        """
        The `talk` function takes user input, processes it through a conversation
        chain, adds the user input and AI response to the message history, and
        returns the AI response. Every `messages_history_threshold` turns it
        also appends a fresh summary of the recent history to `summaries`.

        :param user_input: The `user_input` parameter is a string that represents
            the input message from the user. It is the text that the user wants
            to communicate or ask the AI
        :type user_input: str
        :return: the response generated by the AI model.
        """
        self.messages_history_counter += 1
        previous_messages_summary = summaries[-1] if summaries else ""
        formatted_messages_history = format_messages_history(convert_chat_history_to_normal_data_structure(
            messages_history.messages), k=self.messages_history_threshold)
        formatted_entities = convert_entities_to_formatted_string(entities)
        # Kept so that toggling self.verbose after construction still takes effect.
        self.conversation_chain.verbose = self.verbose
        ai_response = self.conversation_chain(
            {"entities": formatted_entities, "input": user_input, "messages_history": formatted_messages_history, "summary": previous_messages_summary})
        messages_history.add_user_message(user_input)
        messages_history.add_ai_message(ai_response['response'])
        # Refresh the running summary once enough new turns have accumulated.
        if self.messages_history_counter >= self.messages_history_threshold:
            new_summary = summarizer.predict_new_summary(
                messages_history.messages[-self.messages_history_threshold:], previous_messages_summary)
            summaries.append(new_summary)
            self.messages_history_counter = 0
        return ai_response['response']
class EntitiesExtractionAgent:
    """Wraps an OpenAI-functions agent that inspects a user message and, via
    the provided tools, updates the stored entity profile with any new facts.
    """

    def __init__(self, tools: List, is_agent_verbose: bool = False, max_iterations: int = 3, return_thought_process: bool = False):
        """
        :param tools: tools the agent may call to update the human's profile.
        :param is_agent_verbose: whether the agent logs its reasoning.
        :param max_iterations: cap on agent tool-use iterations.
        :param return_thought_process: include intermediate steps in the result.
        """
        # NOTE(review): `entities` is interpolated once, at construction time,
        # so the system message reflects the profile as of agent creation —
        # confirm this snapshot behavior is intended.
        # Fix: removed a stray orphan "." line that was accidentally appended
        # to the prompt text.
        entities_extraction_message = SystemMessage(
            content=f"""Use your judgement and pick out the information you need from the human's new message, based on the following list of human's data:
{entities}
Use a tool to update the human's profile if new information is presented.
Don't update if there isn't new information from the human's message.""")
        agent_kwargs = {
            "system_message": entities_extraction_message,
        }
        self.agent = initialize_agent(tools, llm_non_stream, agent=AgentType.OPENAI_FUNCTIONS, verbose=is_agent_verbose,
                                      max_iterations=max_iterations, return_intermediate_steps=return_thought_process, agent_kwargs=agent_kwargs)

    def update_user_profile(self, user_input):
        """Run the extraction agent over `user_input`; returns the agent's
        result dict (including intermediate steps when configured)."""
        return self.agent({"input": user_input})