Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions src/openhuman/agent/harness/session/turn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ use crate::openhuman::agent::harness;
use crate::openhuman::agent::hooks::{self, ToolCallRecord, TurnContext};
use crate::openhuman::agent::memory_loader::collect_recall_citations;
use crate::openhuman::agent::progress::AgentProgress;
use crate::openhuman::agent::tool_policy::{ToolPolicyDecision, ToolPolicyRequest};
use crate::openhuman::agent::tool_policy::{
ToolCallContext, ToolPolicyDecision, ToolPolicyRequest,
};
use crate::openhuman::agent_experience::{
prepend_experience_block, render_experience_hits, AgentExperienceStore, ExperienceQuery,
};
Expand Down Expand Up @@ -1163,13 +1165,15 @@ impl Agent {
false,
)
} else {
let policy_request = ToolPolicyRequest {
tool_name: call.name.clone(),
arguments: call.arguments.clone(),
session_id: self.event_session_id().to_string(),
channel: self.event_channel().to_string(),
agent_definition_id: self.agent_definition_id.to_string(),
};
let context = ToolCallContext::session(
self.event_session_id(),
self.event_channel(),
self.agent_definition_id.to_string(),
call_id.clone(),
(iteration + 1) as u32,
);
let policy_request =
ToolPolicyRequest::new(call.name.clone(), call.arguments.clone(), context);
if let ToolPolicyDecision::Deny { reason } =
self.tool_policy.check(&policy_request).await
{
Expand Down
8 changes: 5 additions & 3 deletions src/openhuman/agent/harness/session/turn_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,11 @@ impl ToolPolicy for DenyCountingPolicy {

async fn check(&self, request: &ToolPolicyRequest) -> ToolPolicyDecision {
assert_eq!(request.tool_name, "counting");
assert_eq!(request.session_id, "turn-test-session");
assert_eq!(request.channel, "turn-test-channel");
assert_eq!(request.agent_definition_id, "main");
assert_eq!(request.context.session_id, "turn-test-session");
assert_eq!(request.context.channel, "turn-test-channel");
assert_eq!(request.context.agent_definition_id, "main");
assert_eq!(request.context.call_id, "policy-1");
assert_eq!(request.context.iteration, 1);
ToolPolicyDecision::deny("locked by test policy")
}
}
Expand Down
156 changes: 148 additions & 8 deletions src/openhuman/agent/tool_policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,126 @@
//! deny a tool before any side effect reaches the tool implementation.

use async_trait::async_trait;
use std::fmt;

/// Structured context for a tool call before it reaches the tool
/// implementation.
#[derive(Clone, PartialEq, Eq)]
pub struct ToolCallContext {
Comment thread
vaddisrinivas marked this conversation as resolved.
pub session_id: String,
pub channel: String,
pub agent_definition_id: String,
pub call_id: String,
pub iteration: u32,
pub source: ToolCallSource,
}

impl ToolCallContext {
pub fn session(
session_id: impl Into<String>,
channel: impl Into<String>,
agent_definition_id: impl Into<String>,
call_id: impl Into<String>,
iteration: u32,
) -> Self {
Self {
session_id: session_id.into(),
channel: channel.into(),
agent_definition_id: agent_definition_id.into(),
call_id: call_id.into(),
iteration,
source: ToolCallSource::Session,
}
}
}

impl fmt::Debug for ToolCallContext {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ToolCallContext")
.field("session_id", &redact_for_debug(&self.session_id))
.field("channel", &redact_for_debug(&self.channel))
.field("agent_definition_id", &self.agent_definition_id)
.field("call_id", &self.call_id)
.field("iteration", &self.iteration)
.field("source", &self.source)
.finish()
}
}

/// Entry point that produced a tool call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)] // Reserved for non-session tool ingress paths wired in follow-up PRs.
pub enum ToolCallSource {
Session,
Bus,
Channel,
Comment thread
vaddisrinivas marked this conversation as resolved.
Cron,
Webhook,
Unknown,
}

/// Snapshot of the tool call and session context a policy can inspect.
#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct ToolPolicyRequest {
pub tool_name: String,
pub arguments: serde_json::Value,
pub context: ToolCallContext,
/// Backward-compatible mirror of `context.session_id`.
#[deprecated(note = "use context.session_id")]
pub session_id: String,
/// Backward-compatible mirror of `context.channel`.
Comment thread
vaddisrinivas marked this conversation as resolved.
#[deprecated(note = "use context.channel")]
pub channel: String,
/// Backward-compatible mirror of `context.agent_definition_id`.
#[deprecated(note = "use context.agent_definition_id")]
pub agent_definition_id: String,
}

impl fmt::Debug for ToolPolicyRequest {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
#[allow(deprecated)]
{
f.debug_struct("ToolPolicyRequest")
.field("tool_name", &self.tool_name)
.field("arguments", &"<redacted>")
.field("context", &self.context)
.field("session_id", &redact_for_debug(&self.session_id))
.field("channel", &redact_for_debug(&self.channel))
.field("agent_definition_id", &self.agent_definition_id)
.finish()
}
}
}

impl ToolPolicyRequest {
pub fn new(
tool_name: impl Into<String>,
arguments: serde_json::Value,
context: ToolCallContext,
) -> Self {
#[allow(deprecated)]
{
Self {
tool_name: tool_name.into(),
arguments,
session_id: context.session_id.clone(),
channel: context.channel.clone(),
agent_definition_id: context.agent_definition_id.clone(),
context,
}
}
}
}

fn redact_for_debug(value: &str) -> String {
let trimmed = value.trim();
if trimmed.is_empty() {
return "<empty>".to_string();
}
let prefix: String = trimmed.chars().take(4).collect();
format!("{prefix}...")
}

/// Decision returned by a [`ToolPolicy`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolPolicyDecision {
Expand Down Expand Up @@ -63,14 +172,45 @@ mod tests {
#[tokio::test]
async fn allow_all_policy_allows_every_call() {
let policy = AllowAllToolPolicy;
let request = ToolPolicyRequest {
tool_name: "echo".into(),
arguments: serde_json::json!({ "value": 1 }),
session_id: "session".into(),
channel: "chat".into(),
agent_definition_id: "orchestrator".into(),
};
let request = ToolPolicyRequest::new(
"echo",
serde_json::json!({ "value": 1 }),
ToolCallContext::session("session", "chat", "orchestrator", "call-1", 1),
);

assert_eq!(policy.check(&request).await, ToolPolicyDecision::Allow);
#[allow(deprecated)]
{
assert_eq!(request.session_id, request.context.session_id);
assert_eq!(request.channel, request.context.channel);
assert_eq!(
request.agent_definition_id,
request.context.agent_definition_id
);
}
assert_eq!(request.context.source, ToolCallSource::Session);
assert_eq!(request.context.call_id, "call-1");
}

#[test]
fn debug_redacts_sensitive_context_fields() {
let request = ToolPolicyRequest::new(
"secrets.lookup",
serde_json::json!({ "secret": "super-secret-token" }),
ToolCallContext::session(
"session-secret-123",
"private-channel",
"orchestrator",
"call-1",
1,
),
);

let rendered = format!("{request:?}");
assert!(rendered.contains("sess..."));
assert!(rendered.contains("priv..."));
assert!(!rendered.contains("session-secret-123"));
assert!(!rendered.contains("private-channel"));
assert!(!rendered.contains("super-secret-token"));
}
}
Loading