In [None]:
schema = """
definition user {}

definition tenant {}

definition risk {}

definition agent {
  relation owner: user
  relation tenant: tenant
  relation assumed_user: user
}

definition role {
  relation member: user | agent
}

definition document {
  relation tenant: tenant
  relation owner: user
  // `role#member` grants the role's members, not the role object itself
  relation viewer: user | agent | role#member
  relation editor: user | agent | role#member

  permission can_rag = viewer + editor
}

definition datasource {
  relation tenant: tenant
  relation risk_label: risk
}

definition session {
  relation actor: agent
  relation tenant: tenant
  relation touched_resource: document | datasource
}

definition capability {
  relation tenant: tenant
}

definition tool {
  relation tenant: tenant

  // Direct ACL. Caveats attach to relation subject types (not permissions),
  // so every grant carries the taint/policy/CFG/input gate here.
  relation can_invoke: user with taint_and_policy_gate | agent with taint_and_policy_gate | role#member with taint_and_policy_gate
  relation risk_tier: risk
  relation provides_capability: capability

  // Learned policies that govern this tool
  relation governed_by_policy: policy

  // ACL holder who also holds a role that some governing policy was learned for
  permission base_invoke = can_invoke & governed_by_policy->applies_to

  // Final gate: ACL + policy + taint + CFG + input constraints
  // (the caveat on can_invoke flows through base_invoke)
  permission invoke = base_invoke
}

definition policy {
  relation tenant: tenant
  relation for_role: role
  relation for_tool: tool
  relation parent_policy: policy

  // Members of the role this policy was learned for
  permission applies_to = for_role->member
}

definition tool_cfg_edge {
  relation tenant: tenant
  relation prev: tool
  relation next: tool
  relation for_role: role
}

caveat taint_and_policy_gate(
  session_taint int,
  allowed_taint int,
  policy_id string,
  cfg_ok bool,
  input_constraints_ok bool
) {
  session_taint <= allowed_taint && cfg_ok && input_constraints_ok
}
"""

# client

In [None]:
from authzed.api.v1 import Client

def make_spicedb_client(endpoint: str = "localhost:50051", token: str = "somerandomkey"):
    """
    Create a SpiceDB client that talks plaintext (non-TLS) gRPC to `endpoint`.

    Intended for a local/dev SpiceDB only.

    Args:
        endpoint: "host:port" of the SpiceDB gRPC API.
        token: SpiceDB preshared key, sent as a bearer token. The default is a
            development placeholder; pass a real key (e.g. read from an
            environment variable) for anything beyond a local sandbox.

    Returns:
        An authzed `InsecureClient` targeting `endpoint`.
    """
    # `Client(target, credentials, ...)` expects gRPC channel credentials and
    # has no `insecure=` keyword; `InsecureClient` accepts a bare token and
    # skips TLS, which is what a local SpiceDB needs.
    from authzed.api.v1 import InsecureClient

    return InsecureClient(endpoint, token)

# discovery

In [None]:
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple
from authzed.api.v1 import (
    Client,
    Relationship,
    ObjectRef,
    SubjectReference,
)

@dataclass
class TraceEvent:
    """
    A single recorded tool invocation taken from an agent trace.

    Field order is part of the positional constructor signature; keep it stable.
    """

    session_id: str  # session the invocation belongs to
    role: str  # role the agent acted under
    tool_name: str  # name of the tool that was invoked
    args: Dict  # arguments passed to the tool
    success: bool  # whether the invocation succeeded
    step: int  # position of the invocation within its session

class PolicyDiscoveryEngine:
    """
    AgentGuardian-style discovery:
    - Learn which (role, tool) pairs are benign.
    - Learn allowed tool transitions (CFG edges) per role.
    - Bind tools to policy objects in SpiceDB.

    All relationships learned from one batch are sent in a single
    WriteRelationships request using TOUCH, so re-learning from overlapping
    batches does not fail on relationships that already exist.
    """

    def __init__(self, client: Client, tenant_id: str):
        self.client = client
        self.tenant_id = tenant_id

    # NOTE(review): authzed.api.v1 documents the object message as
    # `ObjectReference`; confirm the `ObjectRef` name imported in this cell resolves.
    def _tool_ref(self, tool_name: str) -> ObjectRef:
        return ObjectRef(object_type="tool", object_id=tool_name)

    def _role_ref(self, role_name: str) -> ObjectRef:
        return ObjectRef(object_type="role", object_id=role_name)

    def _policy_ref(self, policy_id: str) -> ObjectRef:
        return ObjectRef(object_type="policy", object_id=policy_id)

    def _tool_cfg_ref(self, edge_id: str) -> ObjectRef:
        return ObjectRef(object_type="tool_cfg_edge", object_id=edge_id)

    def _tenant_ref(self) -> ObjectRef:
        return ObjectRef(object_type="tenant", object_id=self.tenant_id)

    @staticmethod
    def _relationship(resource: ObjectRef, relation: str, subject: ObjectRef) -> Relationship:
        """Build the relationship `resource#relation@subject`."""
        return Relationship(
            resource=resource,
            relation=relation,
            subject=SubjectReference(object=subject),
        )

    @staticmethod
    def _policy_id(role_name: str, tool_name: str) -> str:
        """Deterministic ID of the policy learned for a (role, tool) pair."""
        return f"{role_name}_{tool_name}_v1"

    @staticmethod
    def _cfg_edge_id(role_name: str, prev_tool: str, next_tool: str) -> str:
        """Deterministic ID of the learned transition prev_tool -> next_tool for a role."""
        # SpiceDB object IDs only allow [a-zA-Z0-9/_|\-=+]; ">" is rejected,
        # so the transition is spelled "--" rather than "->".
        return f"{role_name}_{prev_tool}--{next_tool}"

    @staticmethod
    def _successful_events_by_role_session(
        traces: List[TraceEvent],
    ) -> Dict[Tuple[str, str], List[TraceEvent]]:
        """Group successful events by (role, session_id); each group is sorted by step."""
        grouped: Dict[Tuple[str, str], List[TraceEvent]] = {}
        for ev in traces:
            if ev.success:
                grouped.setdefault((ev.role, ev.session_id), []).append(ev)
        for events in grouped.values():
            events.sort(key=lambda e: e.step)
        return grouped

    def _policy_relationships(self, role_name: str, tool_name: str) -> List[Relationship]:
        """Relationships declaring the (role, tool) policy and binding the tool to it."""
        policy = self._policy_ref(self._policy_id(role_name, tool_name))
        return [
            self._relationship(policy, "for_role", self._role_ref(role_name)),
            self._relationship(policy, "for_tool", self._tool_ref(tool_name)),
            self._relationship(policy, "tenant", self._tenant_ref()),
            self._relationship(self._tool_ref(tool_name), "governed_by_policy", policy),
        ]

    def _cfg_edge_relationships(self, role_name: str, prev_tool: str, next_tool: str) -> List[Relationship]:
        """Relationships describing one CFG edge prev_tool -> next_tool for a role."""
        edge = self._tool_cfg_ref(self._cfg_edge_id(role_name, prev_tool, next_tool))
        return [
            self._relationship(edge, "tenant", self._tenant_ref()),
            self._relationship(edge, "prev", self._tool_ref(prev_tool)),
            self._relationship(edge, "next", self._tool_ref(next_tool)),
            self._relationship(edge, "for_role", self._role_ref(role_name)),
        ]

    def learn_from_traces(self, traces: List[TraceEvent]) -> None:
        """
        Take a batch of benign traces and:
        - Create policy objects for (role, tool) pairs.
        - Bind tools to those policies.
        - Create CFG edges prev -> next for each role.

        Events with success=False are ignored. Each policy and each CFG edge
        is emitted once per batch even when it occurs in several sessions,
        because SpiceDB rejects a write request containing duplicate updates.
        Nothing is written when no relationship was learned.
        """
        policy_ids_seen: set = set()
        edge_ids_seen: set = set()
        relationships: List[Relationship] = []

        for (role_name, _session_id), events in self._successful_events_by_role_session(traces).items():
            # 1) Learn (role, tool) policies; sorted only to keep write order stable.
            for tool_name in sorted({e.tool_name for e in events}):
                policy_id = self._policy_id(role_name, tool_name)
                if policy_id not in policy_ids_seen:
                    policy_ids_seen.add(policy_id)
                    relationships.extend(self._policy_relationships(role_name, tool_name))

            # 2) Learn CFG edges prev -> next between consecutive steps of the session.
            for prev_ev, next_ev in zip(events, events[1:]):
                edge_id = self._cfg_edge_id(role_name, prev_ev.tool_name, next_ev.tool_name)
                if edge_id not in edge_ids_seen:
                    edge_ids_seen.add(edge_id)
                    relationships.extend(
                        self._cfg_edge_relationships(role_name, prev_ev.tool_name, next_ev.tool_name)
                    )

        if not relationships:
            return

        # WriteRelationships takes a request message of RelationshipUpdates,
        # not bare Relationship objects as keyword arguments.
        from authzed.api.v1 import RelationshipUpdate, WriteRelationshipsRequest

        touch = RelationshipUpdate.Operation.OPERATION_TOUCH  # create-or-overwrite, idempotent
        self.client.WriteRelationships(
            WriteRelationshipsRequest(
                updates=[RelationshipUpdate(operation=touch, relationship=rel) for rel in relationships]
            )
        )


# enforcement

In [None]:
from typing import Optional
from authzed.api.v1 import (
    Client,
    CheckPermissionRequest,
    ObjectRef,
    SubjectReference,
    LookupResourcesRequest,
)

class EnforcementContext:
    """
    Mutable per-session state kept by the enforcer.

    Tracks which session and role it belongs to, the taint accumulated so far,
    and the most recent tool that was allowed to run.
    """

    def __init__(self, session_id: str, role: str):
        self.session_id, self.role = session_id, role
        # Accumulated taint on a 0–100 scale; starts clean.
        self.session_taint = 0
        # Last authorized tool; None until the first allowed call.
        self.prev_tool: Optional[str] = None

class PolicyEnforcer:
    """
    Wraps tool calls:
    - Computes taint from datasource/document risk.
    - Checks CFG consistency (prev_tool → next_tool) against learned tool_cfg_edge objects.
    - Checks that at least one policy in this tenant governs this (role, tool).
    - Calls SpiceDB 'invoke' permission with taint_and_policy_gate context.

    The CFG and policy lookups fail closed: if SpiceDB cannot be queried for
    them, the tool call is denied.
    """

    def __init__(self, client: Optional[Client] = None, tenant_id: str = "acme"):
        # Without an explicit client, fall back to `make_spicedb_client()` (client cell).
        self.client = client or make_spicedb_client()
        self.tenant_id = tenant_id
        self.sessions: dict[str, EnforcementContext] = {}

    # NOTE(review): authzed.api.v1 documents the object message as
    # `ObjectReference`; confirm the `ObjectRef` name imported in this cell resolves.
    def _tool_ref(self, tool_name: str) -> ObjectRef:
        return ObjectRef(object_type="tool", object_id=tool_name)

    def _agent_ref(self, agent_id: str) -> ObjectRef:
        return ObjectRef(object_type="agent", object_id=agent_id)

    def _role_ref(self, role_name: str) -> ObjectRef:
        return ObjectRef(object_type="role", object_id=role_name)

    def _tenant_ref(self) -> ObjectRef:
        return ObjectRef(object_type="tenant", object_id=self.tenant_id)

    def _lookup_ids(self, resource_type: str, relation: str, subject: ObjectRef) -> set[str]:
        """IDs of `resource_type` objects whose `relation` includes `subject` (via LookupResources)."""
        request = LookupResourcesRequest(
            resource_object_type=resource_type,
            permission=relation,
            subject=SubjectReference(object=subject),
        )
        return {res.resource_object_id for res in self.client.LookupResources(request)}

    def _find_ids(self, resource_type: str, constraints: list[tuple[str, ObjectRef]]) -> set[str]:
        """
        IDs of `resource_type` objects that satisfy every (relation, subject) constraint.

        Stops issuing lookups once the running intersection is empty. SpiceDB
        errors propagate to the caller.
        """
        matches: Optional[set[str]] = None
        for relation, subject in constraints:
            found = self._lookup_ids(resource_type, relation, subject)
            matches = found if matches is None else matches & found
            if not matches:
                break
        return matches or set()

    def _cfg_edge_exists(self, role: str, prev_tool: Optional[str], next_tool: str) -> bool:
        """
        Check if there exists a tool_cfg_edge in this tenant with:
          prev = prev_tool, next = next_tool, for_role = role

        The first tool of a session (prev_tool is None) is always allowed.
        Returns False when SpiceDB cannot be queried (fail closed).
        """
        if prev_tool is None:
            return True

        try:
            edges = self._find_ids(
                "tool_cfg_edge",
                [
                    ("prev", self._tool_ref(prev_tool)),
                    ("next", self._tool_ref(next_tool)),
                    ("for_role", self._role_ref(role)),
                    ("tenant", self._tenant_ref()),
                ],
            )
        except Exception:
            return False
        return bool(edges)

    def _pick_policy_for(self, role: str, tool_name: str) -> Optional[str]:
        """
        Find one policy in this tenant governing this (role, tool) pair.

        Returns the lexicographically smallest matching policy ID (so the pick
        is deterministic), or None when nothing matches or SpiceDB cannot be
        queried (fail closed).
        """
        try:
            policies = self._find_ids(
                "policy",
                [
                    ("for_tool", self._tool_ref(tool_name)),
                    ("for_role", self._role_ref(role)),
                    ("tenant", self._tenant_ref()),
                ],
            )
        except Exception:
            return None
        return min(policies) if policies else None

    def _risk_to_taint(self, risk_label: str) -> int:
        """Map a risk label to a taint score; unknown labels map to 50."""
        mapping = {
            "low": 10,
            "medium": 40,
            "high": 70,
            "critical": 90,
        }
        return mapping.get(risk_label, 50)

    def get_session(self, session_id: str, role: str) -> EnforcementContext:
        """Return the context for `session_id`, creating it with `role` on first use."""
        if session_id not in self.sessions:
            self.sessions[session_id] = EnforcementContext(session_id=session_id, role=role)
        return self.sessions[session_id]

    def observe_datasource_read(self, session_id: str, role: str, risk_label: str):
        """Record a datasource read; session taint only ratchets up to the highest seen."""
        ctx = self.get_session(session_id, role)
        read_taint = self._risk_to_taint(risk_label)
        ctx.session_taint = max(ctx.session_taint, read_taint)

    def check_and_call_tool(
        self,
        session_id: str,
        agent_id: str,
        role: str,
        tool_name: str,
        allowed_taint: int = 30,
        input_constraints_ok: bool = True,
    ) -> bool:
        """
        Returns True if tool is allowed and should be executed, False otherwise.

        This only authorizes; it does not run the tool. On success the tool is
        recorded as the session's previous tool for the next CFG check.
        """
        from authzed.api.v1 import CheckPermissionResponse

        ctx = self.get_session(session_id, role)

        # 1) CFG check (its result is enforced by the caveat, not locally)
        cfg_ok = self._cfg_edge_exists(role, ctx.prev_tool, tool_name)

        # 2) Policy selection: no governing policy means deny outright
        policy_id = self._pick_policy_for(role, tool_name)
        if policy_id is None:
            return False

        # 3) SpiceDB invoke check (ACL + taint + CFG + input constraints)
        # NOTE(review): `context` is a google.protobuf.Struct field; confirm the installed
        # protobuf accepts a plain dict here (otherwise build a Struct and .update() it).
        request = CheckPermissionRequest(
            resource=self._tool_ref(tool_name),
            permission="invoke",
            subject=SubjectReference(object=self._agent_ref(agent_id)),
            context={
                "session_taint": ctx.session_taint,
                "allowed_taint": allowed_taint,
                "policy_id": policy_id,
                "cfg_ok": cfg_ok,
                "input_constraints_ok": input_constraints_ok,
            },
        )

        resp = self.client.CheckPermission(request)
        # The enum value 1 is PERMISSIONSHIP_NO_PERMISSION. Only an unconditional
        # grant counts; CONDITIONAL_PERMISSION (caveat context incomplete) is a denial.
        allowed = resp.permissionship == CheckPermissionResponse.PERMISSIONSHIP_HAS_PERMISSION

        if allowed:
            ctx.prev_tool = tool_name

        return allowed


In [None]:
from policy_engine.spicedb_client import make_spicedb_client
from policy_engine.discovery import PolicyDiscoveryEngine, TraceEvent
from policy_engine.enforcement import PolicyEnforcer

client = make_spicedb_client()
tenant_id = "acme"
ROLE = "coding_bot_executor"
agent_id = "coding_agent"

# Learn policies / CFG from one benign session: view_file, then run_tests
traces = [
    TraceEvent(session_id="s1", role=ROLE, tool_name=tool, args={}, success=True, step=step)
    for step, tool in enumerate(("view_file", "run_tests"))
]

discovery = PolicyDiscoveryEngine(client, tenant_id)
discovery.learn_from_traces(traces)

# Runtime enforcement
enforcer = PolicyEnforcer(client, tenant_id)


def run_scenario(session: str, risk_label: str) -> tuple[bool, bool]:
    """
    Record one datasource read of `risk_label` in `session`, then ask the
    enforcer about view_file (allowed_taint=80) followed by run_tests
    (allowed_taint=40).

    Returns the pair (view_file allowed, run_tests allowed).
    """
    enforcer.observe_datasource_read(session, role=ROLE, risk_label=risk_label)
    view_ok = enforcer.check_and_call_tool(session, agent_id, ROLE, "view_file", allowed_taint=80)
    tests_ok = enforcer.check_and_call_tool(session, agent_id, ROLE, "run_tests", allowed_taint=40)
    return view_ok, tests_ok


# Low taint (internal repository content)
session_id = "s2"
ok1, ok2 = run_scenario(session_id, "low")
print("Internal repo:", ok1, ok2)

# High taint (scraped web content)
session_id2 = "s3"
ok3, ok4 = run_scenario(session_id2, "critical")
print("Scraped web:", ok3, ok4)
