Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,6 @@
}
],
"./splunklib/ai/tools.py": [
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 15,
"endColumn": 31,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 48,
"endColumn": 56,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
Expand Down
56 changes: 53 additions & 3 deletions splunklib/ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.

import asyncio
import os
from collections.abc import AsyncGenerator, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
Expand Down Expand Up @@ -46,6 +47,7 @@
_testing_app_id: str | None = None

DEFAULT_TOOL_SETTINGS = ToolSettings(local=False, remote=None)
_SPLUNK_SYSTEM_USER = "splunk-system-user"


@final
Expand Down Expand Up @@ -181,9 +183,14 @@ async def _start_agent(self) -> AsyncGenerator[Self]:
"internal error: _impl was not set to None after agent invocation"
)

splunk_username = await asyncio.to_thread(
lambda: _get_splunk_username(self._service)
)
_validate_agent_privileges(splunk_username)

self.logger.debug(f"Creating agent {self.name=}; {self.trace_id=}")

self._tools = await self._load_tools(stack)
self._tools = await self._load_tools(stack, splunk_username)

backend = get_backend()
self._impl = await backend.create_agent(self)
Expand All @@ -194,7 +201,9 @@ async def _start_agent(self) -> AsyncGenerator[Self]:

self._impl = None

async def _load_tools(self, stack: AsyncExitStack) -> list[Tool]:
async def _load_tools(
self, stack: AsyncExitStack, splunk_username: str
) -> list[Tool]:
tools: list[Tool] = []
if not self.tool_settings.local and not self.tool_settings.remote:
return tools
Expand Down Expand Up @@ -225,7 +234,9 @@ async def _load_tools(self, stack: AsyncExitStack) -> list[Tool]:
if self.tool_settings.remote:
self.logger.debug("Probing MCP Server App availability")
remote_session = await stack.enter_async_context(
connect_remote_mcp(self._service, app_id, self.trace_id)
connect_remote_mcp(
self._service, app_id, self.trace_id, splunk_username
)
)

if remote_session:
Expand Down Expand Up @@ -301,6 +312,10 @@ async def invoke_with_data(
)


class PrivilegedExecutionError(Exception):
    """Raised when an agent is started under a disallowed system account.

    The message passed by the raiser names the offending account.
    """


def _local_tools_path() -> tuple[str | None, str]:
local_tools_path = _testing_local_tools_path
app_id = _testing_app_id
Expand All @@ -317,3 +332,38 @@ def _local_tools_path() -> tuple[str | None, str]:
local_tools_path = None

return local_tools_path, app_id


def _get_splunk_username(service: Service) -> str:
    """Return the username Splunk reports for the current session.

    Issues a GET against ``authentication/current-context`` (JSON output) and
    reads the ``username`` field of the first entry in the response.

    Args:
        service: Splunk service used to perform the request.

    Returns:
        The username from the first entry, or an empty string when the
        response contains no entries.
    """

    # Minimal schema covering only the fields read from the response body.
    class Content(BaseModel):
        username: str

    class Entry(BaseModel):
        content: Content

    class ResponseBody(BaseModel):
        entry: list[Entry]

    response = service.get(
        path_segment="authentication/current-context",
        output_mode="json",
    )

    raw_json = str(response.body)  # pyright: ignore[reportUnknownArgumentType]
    entries = ResponseBody.model_validate_json(raw_json).entry

    # NOTE(review): an empty username is returned when Splunk reports no
    # entries; a comparison against the system-user name would then pass.
    # Confirm that this fallback is intended.
    return entries[0].content.username if entries else ""


def _validate_agent_privileges(username: str) -> None:
    """Reject execution when the agent runs as the Splunk system account.

    Args:
        username: Name of the user the agent is executing as.

    Raises:
        PrivilegedExecutionError: If ``username`` equals the reserved
            system account name (``_SPLUNK_SYSTEM_USER``).
    """
    # Any other account is allowed; only the exact system user is blocked.
    if username != _SPLUNK_SYSTEM_USER:
        return

    raise PrivilegedExecutionError(
        f"Agent must not be executed by the system user: {_SPLUNK_SYSTEM_USER}"
    )
35 changes: 6 additions & 29 deletions splunklib/ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,37 +247,11 @@ def _convert_tool_result(
)


def _get_splunk_username(service: Service) -> str:
if service.username:
return service.username

class Content(BaseModel):
username: str

class Entry(BaseModel):
content: Content

class ResponseBody(BaseModel):
entry: list[Entry]

# In case service.username is unavailable, query Splunk API for the username.
# This can happen when a service is created with a token, without username/password.
res = service.get(
path_segment="authentication/current-context",
output_mode="json",
)

body = ResponseBody.model_validate_json(str(res.body))
if len(body.entry) == 0:
return ""
return body.entry[0].content.username


def _get_mcp_token(service: Service) -> str | None:
def _get_mcp_token(splunk_username: str, service: Service) -> str | None:
try:
res = service.get(
path_segment="mcp_token",
username=_get_splunk_username(service),
username=splunk_username,
output_mode="json",
)
except HTTPError as e:
Expand Down Expand Up @@ -324,10 +298,13 @@ async def connect_remote_mcp(
service: Service,
app_id: str,
trace_id: str,
splunk_username: str,
) -> AsyncGenerator[ClientSession | None]:
management_url = f"{service.scheme}://{service.host}:{service.port}"
mcp_url = f"{management_url}/services/mcp"
mcp_token = await asyncio.to_thread(lambda: _get_mcp_token(service))
mcp_token = await asyncio.to_thread(
lambda: _get_mcp_token(splunk_username, service)
)
if mcp_token is not None:
async with streamable_http_client(
url=mcp_url,
Expand Down
48 changes: 40 additions & 8 deletions tests/integration/ai/test_agent_mcp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from starlette.routing import Mount, Route

from splunklib.ai import Agent
from splunklib.ai.agent import (
_get_splunk_username, # pyright: ignore[reportPrivateUsage]
)
from splunklib.ai.engines.langchain import LOCAL_TOOL_PREFIX
from splunklib.ai.messages import (
AIMessage,
Expand Down Expand Up @@ -50,7 +53,6 @@
)
from splunklib.ai.tools import (
ToolType,
_get_splunk_username, # pyright: ignore[reportPrivateUsage]
locate_app,
)
from splunklib.client import connect
Expand Down Expand Up @@ -296,6 +298,12 @@ async def mcp_token_handler(_: Request) -> Response:
return JSONResponse(content={"token": AUTH_TOKEN}, status_code=200)


async def current_context_handler(_: Request) -> Response:
    """Stub handler that answers with a single entry for user ``admin``.

    The incoming request is ignored; the reply is always HTTP 200 with a
    fixed JSON body.
    """
    payload = {"entry": [{"content": {"username": "admin"}}]}
    return JSONResponse(content=payload, status_code=200)


class TestRemoteTools(AITestCase):
@patch(
"splunklib.ai.agent._testing_local_tools_path",
Expand Down Expand Up @@ -364,6 +372,11 @@ async def dispatch(
routes=[
Mount("/services/mcp", app=mcp.streamable_http_app()),
Route("/services/mcp_token", mcp_token_handler, methods=["GET"]),
Route(
"/services/authentication/current-context",
current_context_handler,
methods=["GET"],
),
],
lifespan=lifespan,
middleware=[Middleware(MCPMiddleware)],
Expand All @@ -376,7 +389,6 @@ async def dispatch(
port=port,
splunkToken=AUTH_TOKEN,
autologin=True,
username="admin", # not required, but set to avoid mocking the authentication/current-context endpoint
),
)

Expand Down Expand Up @@ -427,15 +439,24 @@ async def dispatch(
async def test_remote_tools_mcp_app_unavailable(self) -> None:
pytest.importorskip("langchain_openai")

async with run_http_server(Starlette(routes=[])) as (host, port):
async with run_http_server(
Starlette(
routes=[
Route(
"/services/authentication/current-context",
current_context_handler,
methods=["GET"],
),
]
)
) as (host, port):
service = await asyncio.to_thread(
lambda: connect(
scheme="http",
host=host,
port=port,
splunkToken=AUTH_TOKEN,
autologin=True,
username="admin", # not required, but set to avoid mocking the authentication/current-context endpoint
),
)

Expand Down Expand Up @@ -489,6 +510,11 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, Any]:
routes=[
Mount("/services/mcp", app=mcp.streamable_http_app()),
Route("/services/mcp_token", mcp_token_handler, methods=["GET"]),
Route(
"/services/authentication/current-context",
current_context_handler,
methods=["GET"],
),
],
lifespan=lifespan,
)
Expand All @@ -500,7 +526,6 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, Any]:
port=port,
splunkToken=AUTH_TOKEN,
autologin=True,
username="admin", # not required, but set to avoid mocking the authentication/current-context endpoint
),
)

Expand Down Expand Up @@ -579,6 +604,11 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, Any]:
routes=[
Mount("/services/mcp", app=mcp.streamable_http_app()),
Route("/services/mcp_token", mcp_token_handler, methods=["GET"]),
Route(
"/services/authentication/current-context",
current_context_handler,
methods=["GET"],
),
],
lifespan=lifespan,
)
Expand All @@ -590,7 +620,6 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, Any]:
port=port,
splunkToken=AUTH_TOKEN,
autologin=True,
username="admin", # not required, but set to avoid mocking the authentication/current-context endpoint
),
)

Expand Down Expand Up @@ -732,6 +761,11 @@ async def lifespan(_app: Starlette) -> AsyncGenerator[None, Any]:
routes=[
Mount("/services/mcp", app=mcp.streamable_http_app()),
Route("/services/mcp_token", mcp_token_handler, methods=["GET"]),
Route(
"/services/authentication/current-context",
current_context_handler,
methods=["GET"],
),
],
lifespan=lifespan,
)
Expand All @@ -743,8 +777,6 @@ async def lifespan(_app: Starlette) -> AsyncGenerator[None, Any]:
port=port,
splunkToken=AUTH_TOKEN,
autologin=True,
# To avoid mocking `authentication/current-context` endpoint
username="admin",
),
)

Expand Down
35 changes: 35 additions & 0 deletions tests/unit/ai/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import pytest

from splunklib.ai import Agent, OpenAIModel
from splunklib.ai.agent import PrivilegedExecutionError
from splunklib.ai.messages import AgentResponse, AIMessage, HumanMessage
from splunklib.ai.middleware import (
AgentMiddlewareHandler,
Expand All @@ -28,6 +30,8 @@
detect_injection,
truncate_input,
)
from splunklib.client import Service
from splunklib.data import Record


class TestDetectInjection(unittest.TestCase):
Expand Down Expand Up @@ -168,3 +172,34 @@ async def handler(_request: AgentRequest) -> AgentResponse[Any]:
)
await middleware.agent_middleware(request, handler)
assert called


class TestPrivilegedExecution(unittest.IsolatedAsyncioTestCase):
    """Verifies that an agent refuses to start as the Splunk system user."""

    @pytest.mark.asyncio
    async def test_agent_with_system_user(self) -> None:
        """Entering the agent context must raise PrivilegedExecutionError."""
        expected_url = "https://localhost:8089/services/authentication/current-context?output_mode=json"
        system_user_body = (
            '{"entry": [{"content": {"username": "splunk-system-user"}}]}'
        )

        model = OpenAIModel(
            model="test-model", base_url="test-url", api_key="test-api-key"
        )

        # Fake HTTP handler: accepts only the current-context lookup and
        # answers as if the session belonged to the system user.
        def handler(url: str, _message: dict[str, Any], **_kwargs: dict[str, Any]):
            assert url == expected_url
            canned_response = {
                "status": 200,
                "headers": [],
                "body": system_user_body,
            }
            return Record(canned_response)

        service = Service(token="test-token", handler=handler)

        with pytest.raises(PrivilegedExecutionError, match="splunk-system-user"):
            async with Agent(
                model=model,
                system_prompt="Your name is stefan",
                service=service,
            ):
                ...