Add Support for OpenAI Assistants in SAML (#755)
* fix duplicate editor instances

* fix typo

* improve logging - adding stack trace

* refactor openai assistant

* uploading data files to openai assistants api

* fix list workflows endpoint

* added openai assistant

* saving tools and datasources in yaml

* add autocomplete for openai assistant tool names

* refactored code

* add deprecation alert in agents page

* fix metadata json errors

* revert agent chat ui page

* add agent icon to sidebar

* Small tweaks

* Minor tweaks

* Small tweaks

* fix openai assistant not ending stream

* fix deleting all agents

* fix the build type error

* Small tweak

---------

Co-authored-by: Ismail Pelaseyed <homanp@gmail.com>
elisalimli and homanp committed Feb 6, 2024
1 parent 18f00e5 commit 6e9df80
Showing 31 changed files with 2,276 additions and 425 deletions.
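
The dispatch in base.py below hinges on a new Prisma enum. As a reading aid, here is a minimal sketch of the shape prisma.enums.AgentType appears to have; only the two variants visible in this commit are listed, and the generated class may differ:

    # Sketch of prisma.enums.AgentType as implied by the base.py diff below.
    from enum import Enum

    class AgentType(str, Enum):
        SUPERAGENT = "SUPERAGENT"  # the string literal the old code compared against
        OPENAI_ASSISTANT = "OPENAI_ASSISTANT"  # the new variant this commit dispatches on
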
33 changes: 25 additions & 8 deletions libs/superagent/app/agents/base.py
@@ -2,7 +2,8 @@
 
 from app.models.request import LLMParams
 from app.utils.streaming import CustomAsyncIteratorCallbackHandler
-from prisma.models import Agent, AgentDatasource, AgentLLM, AgentTool
+from prisma.enums import AgentType
+from prisma.models import Agent
 
 DEFAULT_PROMPT = (
     "You are a helpful AI Assistant, answer the users questions to "
@@ -30,21 +31,38 @@ def __init__(
         self.agent_config = agent_config
 
     async def _get_tools(
-        self, agent_datasources: List[AgentDatasource], agent_tools: List[AgentTool]
+        self,
     ) -> List:
         raise NotImplementedError
 
-    async def _get_llm(self, agent_llm: AgentLLM, model: str) -> Any:
+    async def _get_llm(
+        self,
+    ) -> Any:
         raise NotImplementedError
 
-    async def _get_prompt(self, agent: Agent) -> str:
+    async def _get_prompt(
+        self,
+    ) -> str:
         raise NotImplementedError
 
     async def _get_memory(self) -> List:
         raise NotImplementedError
 
     async def get_agent(self):
-        if self.agent_config.type == "SUPERAGENT":
+        if self.agent_config.type == AgentType.OPENAI_ASSISTANT:
+            from app.agents.openai import OpenAiAssistant
+
+            agent = OpenAiAssistant(
+                agent_id=self.agent_id,
+                session_id=self.session_id,
+                enable_streaming=self.enable_streaming,
+                output_schema=self.output_schema,
+                callbacks=self.callbacks,
+                llm_params=self.llm_params,
+                agent_config=self.agent_config,
+            )
+
+        else:
             from app.agents.langchain import LangchainAgent
 
             agent = LangchainAgent(
@@ -54,8 +72,7 @@ async def get_agent(self):
                 output_schema=self.output_schema,
                 callbacks=self.callbacks,
                 llm_params=self.llm_params,
+                agent_config=self.agent_config,
             )
-        else:
-            pass
 
-        return await agent.get_agent(config=self.agent_config)
+        return await agent.get_agent()
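
With this change, get_agent() acts as a small factory: it inspects agent_config.type and returns an OpenAiAssistant for OPENAI_ASSISTANT agents and a LangchainAgent for everything else, both built from the same constructor fields. A hedged usage sketch; the "input"/"output" keys and the surrounding wiring are assumptions, not code from this commit:

    # Assumes `base` is an AgentBase built with the fields shown in the diff above.
    async def answer(base, message: str) -> str:
        executor = await base.get_agent()  # dispatches on base.agent_config.type
        result = await executor.ainvoke({"input": message})
        return result["output"]
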
47 changes: 29 additions & 18 deletions libs/superagent/app/agents/langchain.py
@@ -73,12 +73,16 @@ async def _get_tools(
             metadata = (
                 {
                     "datasource_id": agent_datasource.datasource.id,
-                    "options": agent_datasource.datasource.vectorDb.options
-                    if agent_datasource.datasource.vectorDb
-                    else {},
-                    "provider": agent_datasource.datasource.vectorDb.provider
-                    if agent_datasource.datasource.vectorDb
-                    else None,
+                    "options": (
+                        agent_datasource.datasource.vectorDb.options
+                        if agent_datasource.datasource.vectorDb
+                        else {}
+                    ),
+                    "provider": (
+                        agent_datasource.datasource.vectorDb.provider
+                        if agent_datasource.datasource.vectorDb
+                        else None
+                    ),
                     "query_type": "document",
                 }
                 if tool_type == DatasourceTool
@@ -115,9 +119,11 @@ async def _get_tools(
                 description=agent_tool.tool.description,
                 metadata=agent_tool.tool.metadata,
                 args_schema=tool_info["schema"],
-                session_id=f"{self.agent_id}-{self.session_id}"
-                if self.session_id
-                else f"{self.agent_id}",
+                session_id=(
+                    f"{self.agent_id}-{self.session_id}"
+                    if self.session_id
+                    else f"{self.agent_id}"
+                ),
                 return_direct=agent_tool.tool.returnDirect,
             )
             tools.append(tool)
@@ -176,9 +182,11 @@ async def _get_prompt(self, agent: Agent) -> str:
 
     async def _get_memory(self) -> List:
         memory = MotorheadMemory(
-            session_id=f"{self.agent_id}-{self.session_id}"
-            if self.session_id
-            else f"{self.agent_id}",
+            session_id=(
+                f"{self.agent_id}-{self.session_id}"
+                if self.session_id
+                else f"{self.agent_id}"
+            ),
             memory_key="chat_history",
             url=config("MEMORY_API_URL"),
             return_messages=True,
@@ -187,12 +195,15 @@ async def _get_memory(self) -> List:
         await memory.init()
         return memory
 
-    async def get_agent(self, config: Agent):
-        llm = await self._get_llm(agent_llm=config.llms[0], model=config.llmModel)
+    async def get_agent(self):
+        llm = await self._get_llm(
+            agent_llm=self.agent_config.llms[0], model=self.agent_config.llmModel
+        )
         tools = await self._get_tools(
-            agent_datasources=config.datasources, agent_tools=config.tools
+            agent_datasources=self.agent_config.datasources,
+            agent_tools=self.agent_config.tools,
         )
-        prompt = await self._get_prompt(agent=config)
+        prompt = await self._get_prompt(agent=self.agent_config)
         memory = await self._get_memory()
 
         if len(tools) > 0:
if len(tools) > 0:
Expand All @@ -213,8 +224,8 @@ async def get_agent(self, config: Agent):
return agent
else:
prompt_base = (
f"{config.prompt.replace('{', '{{').replace('}', '}}')}"
if config.prompt
f"{self.agent_config.prompt.replace('{', '{{').replace('}', '}}')}"
if self.agent_config.prompt
else None
)
prompt_base = prompt_base or DEFAULT_PROMPT
Expand Down
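
Both the tool construction and the Motorhead memory above derive one state key from the agent/session pair. The scoping rule, extracted into a standalone sketch for clarity:

    def scoped_session_id(agent_id: str, session_id: str | None) -> str:
        # Memory and tool state are isolated per (agent, session); without a
        # session_id, state falls back to being shared agent-wide.
        return f"{agent_id}-{session_id}" if session_id else f"{agent_id}"

    assert scoped_session_id("agent_1", "sess_9") == "agent_1-sess_9"
    assert scoped_session_id("agent_1", None) == "agent_1"
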
55 changes: 55 additions & 0 deletions libs/superagent/app/agents/openai.py
@@ -0,0 +1,55 @@
+import asyncio
+
+from langchain.agents import AgentExecutor
+from langchain.agents.openai_assistant import OpenAIAssistantRunnable
+from langchain.schema.messages import AIMessage
+from langchain.schema.output import ChatGeneration, LLMResult
+
+from app.agents.base import AgentBase
+
+
+class OpenAiAssistant(AgentBase):
+    async def get_agent(self):
+        assistant_id = self.agent_config.metadata.get("id")
+        agent = OpenAIAssistantRunnable(assistant_id=assistant_id, as_agent=True)
+        enable_streaming = self.enable_streaming
+
+        class CustomAgentExecutor(AgentExecutor):
+            async def ainvoke(self, *args, **kwargs):
+                res = await super().ainvoke(*args, **kwargs)
+
+                if enable_streaming:
+                    output = res.get("output").split(" ")
+                    # TODO: find a better way to get the streaming callback
+                    streaming = kwargs["config"]["callbacks"][0]
+                    await streaming.on_llm_start()
+
+                    # Stream the tokens one by one, then signal completion;
+                    # on_llm_end must fire only after every token has been sent.
+                    tasks = []
+
+                    for token in output:
+                        task = streaming.on_llm_new_token(token + " ")
+                        tasks.append(task)
+
+                    await asyncio.gather(*tasks)
+
+                    await streaming.on_llm_end(
+                        response=LLMResult(
+                            generations=[
+                                [
+                                    ChatGeneration(
+                                        message=AIMessage(
+                                            content=res.get("output"),
+                                        )
+                                    )
+                                ]
+                            ],
+                        )
+                    )
+
+                return res
+
+        agent_executor = CustomAgentExecutor(agent=agent, tools=[])
+
+        return agent_executor
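
OpenAIAssistantRunnable returns the whole answer in one shot, so CustomAgentExecutor above fakes streaming by replaying the finished output through the first registered callback handler. A self-contained sketch of that replay pattern; StubStreamingHandler is a stand-in for the project's CustomAsyncIteratorCallbackHandler, whose real signatures may differ:

    import asyncio

    class StubStreamingHandler:
        # Stand-in exposing only the three hooks the executor calls.
        async def on_llm_start(self) -> None:
            print("<stream open>")

        async def on_llm_new_token(self, token: str) -> None:
            print(token, end="", flush=True)

        async def on_llm_end(self, response=None) -> None:
            print("\n<stream closed>")

    async def replay(output: str, handler: StubStreamingHandler) -> None:
        # Same trick as the diff: split the completed answer on spaces and
        # emit each word as though it were a streamed token.
        await handler.on_llm_start()
        await asyncio.gather(
            *(handler.on_llm_new_token(token + " ") for token in output.split(" "))
        )
        await handler.on_llm_end()

    asyncio.run(replay("assistants do not stream here yet", StubStreamingHandler()))
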