diff --git a/splunklib/ai/engines/langchain.py b/splunklib/ai/engines/langchain.py index bd103724..d5d305ab 100644 --- a/splunklib/ai/engines/langchain.py +++ b/splunklib/ai/engines/langchain.py @@ -15,6 +15,7 @@ import json import logging import os +import string import uuid from collections.abc import Awaitable, Callable, Sequence from dataclasses import asdict, dataclass @@ -1416,11 +1417,22 @@ def _denormalize_tool_name(name: str) -> str: return name +def _is_agent_name_valid(name: str) -> bool: + AGENT_NAME_ALLOWED_CHARS = string.ascii_letters + string.digits + "_-" + if not (1 <= len(name) <= 128): + return False + + return set(name).issubset(AGENT_NAME_ALLOWED_CHARS) + + def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool: if not agent.name: raise AssertionError("Agent must have a name to be used by other Agents") - # TODO: restrict subagent names + if not _is_agent_name_valid(agent.name): + raise AssertionError( + "Agent name is invalid, must contain only letters, numbers, '_' or '-' and have max 128 characters" + ) async def invoke_agent( message: HumanMessage, thread_id: str | None diff --git a/tests/integration/ai/test_agent.py b/tests/integration/ai/test_agent.py index f587c927..1f9ea591 100644 --- a/tests/integration/ai/test_agent.py +++ b/tests/integration/ai/test_agent.py @@ -440,6 +440,66 @@ async def test_duplicated_subagent_name(self) -> None: ): pass + @pytest.mark.asyncio + async def test_subagent_with_invalid_name(self) -> None: + pytest.importorskip("langchain_openai") + + async with ( + Agent( + model=(await self.model()), + system_prompt="", + service=self.service, + name="invalid name", + ) as subagent_invalid, + Agent( + model=(await self.model()), + system_prompt="", + service=self.service, + name="invalid@name", + ) as subagent_invalid2, + Agent( + model=(await self.model()), + system_prompt="", + service=self.service, + name="a" * 129, + ) as subagent_too_long, + ): + with pytest.raises( + AssertionError, + match="Agent name is invalid", + ): + async with Agent( + model=(await self.model()), + system_prompt="", + service=self.service, + agents=[subagent_invalid], + ): + pass + + with pytest.raises( + AssertionError, + match="Agent name is invalid", + ): + async with Agent( + model=(await self.model()), + system_prompt="", + service=self.service, + agents=[subagent_invalid2], + ): + pass + + with pytest.raises( + AssertionError, + match="Agent name is invalid", + ): + async with Agent( + model=(await self.model()), + system_prompt="", + service=self.service, + agents=[subagent_too_long], + ): + pass + @pytest.mark.asyncio async def test_subagent_soft_failure_with_invalid_args(self) -> None: pytest.importorskip("langchain_openai")