Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion splunklib/ai/engines/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import logging
import os
import string
import uuid
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import asdict, dataclass
Expand Down Expand Up @@ -1416,11 +1417,22 @@ def _denormalize_tool_name(name: str) -> str:
return name


def _is_agent_name_valid(name: str) -> bool:
AGENT_NAME_ALLOWED_CHARS = string.ascii_letters + string.digits + "_-"
if not (1 <= len(name) <= 128):
return False

return set(name).issubset(AGENT_NAME_ALLOWED_CHARS)


def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:
if not agent.name:
raise AssertionError("Agent must have a name to be used by other Agents")

# TODO: restrict subagent names
if not _is_agent_name_valid(agent.name):
raise AssertionError(
"Agent name is invalid, must contain only letters, numbers, '_' or '-' and have max 128 characters"
)

async def invoke_agent(
message: HumanMessage, thread_id: str | None
Expand Down
60 changes: 60 additions & 0 deletions tests/integration/ai/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,66 @@ async def test_duplicated_subagent_name(self) -> None:
):
pass

@pytest.mark.asyncio
async def test_subagent_with_invalid_name(self) -> None:
pytest.importorskip("langchain_openai")

async with (
Agent(
model=(await self.model()),
system_prompt="",
service=self.service,
name="invalid name",
) as subagent_invalid,
Agent(
model=(await self.model()),
system_prompt="",
service=self.service,
name="invalid@name",
) as subagent_invalid2,
Agent(
model=(await self.model()),
system_prompt="",
service=self.service,
name="a" * 129,
) as subagent_too_long,
):
with pytest.raises(
AssertionError,
match="Agent name is invalid",
):
async with Agent(
model=(await self.model()),
system_prompt="",
service=self.service,
agents=[subagent_invalid],
):
pass

with pytest.raises(
AssertionError,
match="Agent name is invalid",
):
async with Agent(
model=(await self.model()),
system_prompt="",
service=self.service,
agents=[subagent_invalid2],
):
pass

with pytest.raises(
AssertionError,
match="Agent name is invalid",
):
async with Agent(
model=(await self.model()),
system_prompt="",
service=self.service,
agents=[subagent_too_long],
):
pass

@pytest.mark.asyncio
async def test_subagent_soft_failure_with_invalid_args(self) -> None:
pytest.importorskip("langchain_openai")
Expand Down
Loading