Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions cmd/triagent-mcp/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,9 @@ func runProm(ctx context.Context, f serveFlags) error {

// runCloud wires the read-only cloud-context MCP. --provider selects the
// concrete backend; New plugs it in behind cloud.Provider. The launcher passes
// the allowlist override path and target scope through the subprocess env
// (cloud.EnvAllowlistPath, cloud.EnvScope), never argv.
// the allowlist override path, target scope, and pinned identity through the
// subprocess env (cloud.EnvAllowlistPath, cloud.EnvScope,
// cloud.EnvExpectedIdentity), never argv.
func runCloud(ctx context.Context, f serveFlags) error {
if f.cloudProvider == "" {
return fmt.Errorf("--provider is required (gcp or aws) (set --provider or $%s)", cloud.EnvProvider)
Expand All @@ -450,9 +451,10 @@ func runCloud(ctx context.Context, f serveFlags) error {
return err
}
srv, err := cloud.New(cloud.Options{
Provider: provider,
AllowlistPath: os.Getenv(cloud.EnvAllowlistPath),
Scope: parseCloudScope(os.Getenv(cloud.EnvScope)),
Provider: provider,
AllowlistPath: os.Getenv(cloud.EnvAllowlistPath),
Scope: parseCloudScope(os.Getenv(cloud.EnvScope)),
ExpectedIdentity: os.Getenv(cloud.EnvExpectedIdentity),
})
if err != nil {
return fmt.Errorf("build cloud mcp server: %w", err)
Expand Down
23 changes: 12 additions & 11 deletions internal/preflight/mcpconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,21 @@ func kubeEnv(in mcpConfigInputs) map[string]string {

// cloudSourceEnv builds the subprocess env for one triagent-cloud-<alias>
// server: the provider selector, the optional allowlist-override path, the
// JSON-encoded scope the cloud package decodes, and the per-provider
// pinned-identity env.
// JSON-encoded scope the cloud package decodes, the pinned identity the probe
// validates against, and the per-provider credential env the CLI authenticates
// with.
//
// The two clouds pin identity through different env, by mechanism. GCP
// impersonates the assumed identity directly, so a single env
// (CLOUDSDK_AUTH_IMPERSONATE_SERVICE_ACCOUNT) is both the impersonation target
// and the expected identity. AWS selects an assume-role profile (AWS_PROFILE)
// for credentials and checks the role ARN (TRIAGENT_CLOUD_AWS_EXPECTED_ROLE_ARN)
// for strict validity, so it needs both a profile selector and the expected ARN.
// The env-name constants come from the provider packages, never raw literals.
// The pinned identity is uniform: TRIAGENT_CLOUD_EXPECTED_IDENTITY carries it
// for both clouds, and the probe validates the resolved identity against it. The
// credential env differs by mechanism: GCP impersonates the assumed identity
// directly (CLOUDSDK_AUTH_IMPERSONATE_SERVICE_ACCOUNT), AWS selects an
// assume-role profile (AWS_PROFILE) whose role_arn is the deployment's read-only
// role. The env-name constants come from the provider packages, never raw
// literals.
func cloudSourceEnv(src profile.CloudSource) (map[string]string, error) {
env := map[string]string{
cloud.EnvProvider: src.Provider,
cloud.EnvProvider: src.Provider,
cloud.EnvExpectedIdentity: src.AssumedIdentity,
}
if src.CommandAllowlistPath != "" {
env[cloud.EnvAllowlistPath] = src.CommandAllowlistPath
Expand All @@ -209,7 +211,6 @@ func cloudSourceEnv(src profile.CloudSource) (map[string]string, error) {
env[gcp.EnvImpersonate] = src.AssumedIdentity
case "aws":
env[aws.EnvProfile] = src.Profile
env[aws.EnvExpectedRoleARN] = src.AssumedIdentity
}
return env, nil
}
Expand Down
11 changes: 6 additions & 5 deletions internal/preflight/mcpconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,12 +422,12 @@ func TestWriteMCPConfig_GCPCloudSource_RegistersServerWithImpersonationEnv(t *te
require.NotNil(t, env)
assert.Equal(t, "gcp", env[cloud.EnvProvider])
assert.Equal(t, "/etc/triagent/gcp-allow.json", env[cloud.EnvAllowlistPath])
// gcp impersonates the assumed identity directly; that one env is both
// the impersonation target and the expected identity.
// The pinned identity is uniform across providers.
assert.Equal(t, "triage-ro@prod.iam.gserviceaccount.com", env[cloud.EnvExpectedIdentity])
// gcp impersonates the assumed identity directly as its credential env.
assert.Equal(t, "triage-ro@prod.iam.gserviceaccount.com", env[gcp.EnvImpersonate])
// AWS-specific env must not leak onto a gcp source.
assert.NotContains(t, env, aws.EnvProfile)
assert.NotContains(t, env, aws.EnvExpectedRoleARN)

rawScope, _ := env[cloud.EnvScope].(string)
require.NotEmpty(t, rawScope, "scope must be JSON-encoded into the env")
Expand Down Expand Up @@ -460,9 +460,10 @@ func TestWriteMCPConfig_AWSCloudSource_RegistersServerWithProfileAndExpectedRole
env, _ := srv["env"].(map[string]any)
require.NotNil(t, env)
assert.Equal(t, "aws", env[cloud.EnvProvider])
// aws needs BOTH a profile selector and the expected role ARN.
// The pinned identity is uniform across providers.
assert.Equal(t, "arn:aws:iam::123456789012:role/triage-ro", env[cloud.EnvExpectedIdentity])
// aws selects an assume-role profile as its credential env.
assert.Equal(t, "triage-ro", env[aws.EnvProfile])
assert.Equal(t, "arn:aws:iam::123456789012:role/triage-ro", env[aws.EnvExpectedRoleARN])
// gcp impersonation env must not leak onto an aws source.
assert.NotContains(t, env, gcp.EnvImpersonate)
}
6 changes: 6 additions & 0 deletions pkg/mcp/cloud/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,10 @@ const (
// EnvScope carries the target scope allowlist the launcher froze for this
// session, as JSON the cloud package decodes into ScopeAllowlist.
EnvScope = "TRIAGENT_CLOUD_SCOPE"
// EnvExpectedIdentity carries the identity the launcher pinned for this
// session, uniform across providers: the impersonation target for gcp, the
// expected role ARN for aws. The serve subprocess reads it once at startup and
// threads it into the identity probe; the provider validates the resolved
// identity against it.
EnvExpectedIdentity = "TRIAGENT_CLOUD_EXPECTED_IDENTITY"
)
2 changes: 1 addition & 1 deletion pkg/mcp/cloud/fake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ func (f *fakeProvider) Inventory(context.Context, RunFunc) (Inventory, error) {
return f.inventory, nil
}

func (f *fakeProvider) Identity(context.Context, RunFunc) (IdentityStatus, error) {
func (f *fakeProvider) Identity(context.Context, RunFunc, string) (IdentityStatus, error) {
return f.identity, f.identityErr
}
18 changes: 14 additions & 4 deletions pkg/mcp/cloud/probe.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
package cloud

import "context"
import (
"context"
"fmt"
)

// Probe runs the read-only whoami for one provider: which pinned identity is
// active and whether it is valid. It is the single probe the launcher's
// connections panel, the session preflight gate, and the session_status tool
// all call, so those surfaces can never disagree.
//
// expected is the identity the launcher pinned for this session, threaded
// explicitly so the probe validates against it without reading process-global
// env; env is the exact subprocess environment the whoami exec runs under,
// passed in by the caller rather than read from os.Environ here.
//
// Probe never returns a Go error for an unreachable or invalid identity — that
// is a degrade, reported through IdentityStatus.Valid and Hint, so a stale cloud
// credential surfaces visibly instead of failing the caller. A Go error is
// reserved for a caller contract violation (a nil provider).
func Probe(ctx context.Context, p Provider) (IdentityStatus, error) {
env := minimalEnv(p.EnvPassthrough())
func Probe(ctx context.Context, p Provider, expected string, env []string) (IdentityStatus, error) {
if p == nil {
return IdentityStatus{}, fmt.Errorf("cloud: Probe requires a provider")
}
run := func(ctx context.Context, argv []string) (CLIResult, error) {
return execCLI(ctx, p.Binary(), argv, env, defaultOutputLimit)
}

st, err := p.Identity(ctx, run)
st, err := p.Identity(ctx, run, expected)
if err != nil {
return IdentityStatus{
Provider: p.Name(),
Expand Down
45 changes: 26 additions & 19 deletions pkg/mcp/cloud/probe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cloud
import (
"context"
"errors"
"os"
"strings"
"testing"

Expand All @@ -29,7 +30,7 @@ func (p *envProbeProvider) Inventory(context.Context, RunFunc) (Inventory, error
return Inventory{}, nil
}

func (p *envProbeProvider) Identity(ctx context.Context, run RunFunc) (IdentityStatus, error) {
func (p *envProbeProvider) Identity(ctx context.Context, run RunFunc, _ string) (IdentityStatus, error) {
res, err := run(ctx, nil)
if err != nil {
return IdentityStatus{}, err
Expand All @@ -51,56 +52,62 @@ func TestProbeReturnsProviderIdentity(t *testing.T) {
Valid: true,
},
}
st, err := Probe(context.Background(), p)
st, err := Probe(context.Background(), p, "", nil)
require.NoError(t, err)
assert.True(t, st.Valid)
assert.Equal(t, "ro-sa@proj.iam.gserviceaccount.com", st.AssumedIdentity)
}

func TestProbeErrorsOnNilProvider(t *testing.T) {
t.Parallel()
_, err := Probe(context.Background(), nil, "", nil)
require.Error(t, err, "a nil provider is a caller contract violation, not a degrade")
}

func TestProbeSurfacesProviderErrorAsInvalid(t *testing.T) {
t.Parallel()
p := &fakeProvider{name: "aws", identityErr: errors.New("token expired")}
st, err := Probe(context.Background(), p)
st, err := Probe(context.Background(), p, "", nil)
require.NoError(t, err, "Probe should degrade, not error")
assert.False(t, st.Valid, "expected Valid=false when the provider errors")
assert.Equal(t, "aws", st.Provider, "expected provider name carried through")
assert.NotEmpty(t, st.Hint, "expected the provider error surfaced as a hint")
}

// TestProbeUsesMinimalSubprocessEnv proves the probe path forwards only the
// base passthrough plus the provider's declared names to the whoami subprocess,
// dropping the launcher's ambient secrets. A parent canary must not cross the
// boundary while a declared passthrough var survives.
func TestProbeUsesMinimalSubprocessEnv(t *testing.T) {
// TestProbeExecsWithExactlyTheGivenEnv proves the probe execs the whoami
// subprocess under exactly the env the caller passed, with no read of
// os.Environ inside Probe: a parent canary set in the process env must not
// cross the boundary, while a var present only in the passed env survives.
func TestProbeExecsWithExactlyTheGivenEnv(t *testing.T) {
t.Setenv("TRIAGENT_CLOUD_LEAK_CANARY", "should-not-appear")
t.Setenv("CLOUDSDK_AUTH_IMPERSONATE_SERVICE_ACCOUNT", "ro-sa@proj.iam.gserviceaccount.com")
p := &envProbeProvider{
name: "gcp",
envPassthrough: []string{"CLOUDSDK_AUTH_IMPERSONATE_SERVICE_ACCOUNT"},
}
p := &envProbeProvider{name: "gcp"}

st, err := Probe(context.Background(), p)
env := []string{
"PATH=" + os.Getenv("PATH"),
"CLOUDSDK_AUTH_IMPERSONATE_SERVICE_ACCOUNT=ro-sa@proj.iam.gserviceaccount.com",
}
st, err := Probe(context.Background(), p, "", env)
require.NoError(t, err)

seen := st.AssumedIdentity
assert.NotContains(t, seen, "TRIAGENT_CLOUD_LEAK_CANARY",
"parent-env secret must not reach the probe subprocess")
"a var present only in the process env, not the passed env, must not reach the subprocess")
assert.Contains(t, seen, "CLOUDSDK_AUTH_IMPERSONATE_SERVICE_ACCOUNT=ro-sa@proj.iam.gserviceaccount.com",
"declared passthrough var must reach the probe subprocess")
"the passed env must reach the probe subprocess")
for _, line := range strings.Split(seen, "\n") {
if line == "" {
continue
}
name, _, _ := strings.Cut(line, "=")
assert.Contains(t, []string{"PATH", "HOME", "CLOUDSDK_AUTH_IMPERSONATE_SERVICE_ACCOUNT"}, name,
"only base + declared passthrough names may cross the boundary")
assert.Contains(t, []string{"PATH", "CLOUDSDK_AUTH_IMPERSONATE_SERVICE_ACCOUNT"}, name,
"only the names in the passed env may cross the boundary")
}
}

func TestProbeInvalidWhenIdentityEmpty(t *testing.T) {
t.Parallel()
p := &fakeProvider{name: "gcp", identity: IdentityStatus{Provider: "gcp", Valid: true}}
st, err := Probe(context.Background(), p)
st, err := Probe(context.Background(), p, "", nil)
require.NoError(t, err)
assert.False(t, st.Valid, "an empty resolved identity must not be reported valid")
}
7 changes: 5 additions & 2 deletions pkg/mcp/cloud/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ type Provider interface {
// accounts for aws). It execs only through run, never directly.
Inventory(ctx context.Context, run RunFunc) (Inventory, error)
// Identity is the read-only whoami: which pinned identity is active and
// whether it is valid. It execs only through run, never directly.
Identity(ctx context.Context, run RunFunc) (IdentityStatus, error)
// whether it is valid. expected is the identity the launcher pinned for this
// session (the impersonation target for gcp, the expected role ARN for aws,
// empty when none is pinned); the provider validates the resolved identity
// against it. It execs only through run, never directly.
Identity(ctx context.Context, run RunFunc, expected string) (IdentityStatus, error)
}

// RunFunc is the harness exec core, injected into providers so they never exec
Expand Down
23 changes: 7 additions & 16 deletions pkg/mcp/cloud/providers/aws/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,11 @@ import (
"context"
"encoding/json"
"fmt"
"os"
"strings"

"github.com/sourcehawk/triagent/pkg/mcp/cloud"
)

// EnvExpectedRoleARN optionally pins the IAM role ARN the assumed-role caller
// must resolve to. When set, Identity rejects any caller whose underlying role
// does not match it, the strict check. When unset, Identity falls back to the
// structural check (the caller must be an assumed-role ARN at all, proving the
// AWS_PROFILE assume-role pin took effect rather than the operator's plain base
// identity leaking through).
const EnvExpectedRoleARN = "TRIAGENT_CLOUD_AWS_EXPECTED_ROLE_ARN"

// callerIdentity is the projection of `aws sts get-caller-identity --output
// json`. Only the fields the probe and inventory fallback use are decoded.
type callerIdentity struct {
Expand All @@ -31,12 +22,12 @@ type callerIdentity struct {
// the command is also allowlisted so it works under the validated core), parses
// the caller ARN, and reports whether the pinned assume-role identity is active.
//
// Validity has two modes. With TRIAGENT_CLOUD_AWS_EXPECTED_ROLE_ARN set, the
// caller's underlying role must match it exactly. Without it, the structural
// check applies: the caller must be an assumed-role ARN, which proves the
// AWS_PROFILE pin took effect — a plain user/root ARN means base credentials
// leaked through unimpersonated, so the session is not valid.
func (p *Provider) Identity(ctx context.Context, run cloud.RunFunc) (cloud.IdentityStatus, error) {
// Validity has two modes. With expected set to a role ARN, the caller's
// underlying role must match it exactly. Without it, the structural check
// applies: the caller must be an assumed-role ARN, which proves the AWS_PROFILE
// pin took effect — a plain user/root ARN means base credentials leaked through
// unimpersonated, so the session is not valid.
func (p *Provider) Identity(ctx context.Context, run cloud.RunFunc, expected string) (cloud.IdentityStatus, error) {
res, err := run(ctx, []string{"sts", "get-caller-identity", "--output", "json"})
if err != nil {
return cloud.IdentityStatus{Provider: "aws", Valid: false, Hint: err.Error()}, nil
Expand All @@ -59,7 +50,7 @@ func (p *Provider) Identity(ctx context.Context, run cloud.RunFunc) (cloud.Ident
}

st := cloud.IdentityStatus{Provider: "aws", AssumedIdentity: caller.Arn}
st.Valid, st.Hint = evaluateIdentity(caller.Arn, os.Getenv(EnvExpectedRoleARN))
st.Valid, st.Hint = evaluateIdentity(caller.Arn, expected)
return st, nil
}

Expand Down
14 changes: 6 additions & 8 deletions pkg/mcp/cloud/providers/aws/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestIdentityBuildsCallerIdentityArgv(t *testing.T) {
p, err := newWithBinary("/usr/bin/aws")
require.NoError(t, err)

_, err = p.Identity(context.Background(), f.run)
_, err = p.Identity(context.Background(), f.run, "")
require.NoError(t, err)

require.Len(t, f.calls, 1)
Expand All @@ -42,7 +42,7 @@ func TestIdentityValidWhenAssumedRole(t *testing.T) {
p, err := newWithBinary("/usr/bin/aws")
require.NoError(t, err)

st, err := p.Identity(context.Background(), f.run)
st, err := p.Identity(context.Background(), f.run, "")
require.NoError(t, err)

assert.Equal(t, "aws", st.Provider)
Expand All @@ -57,7 +57,7 @@ func TestIdentityInvalidWhenNotAssumedRole(t *testing.T) {
p, err := newWithBinary("/usr/bin/aws")
require.NoError(t, err)

st, err := p.Identity(context.Background(), f.run)
st, err := p.Identity(context.Background(), f.run, "")
require.NoError(t, err)

assert.Equal(t, "arn:aws:iam::111122223333:user/operator", st.AssumedIdentity)
Expand All @@ -66,27 +66,25 @@ func TestIdentityInvalidWhenNotAssumedRole(t *testing.T) {
}

func TestIdentityMatchesExpectedRoleArnWhenPinned(t *testing.T) {
t.Setenv(EnvExpectedRoleARN, "arn:aws:iam::111122223333:role/triagent-readonly")
f := &fakeRun{results: map[string]cloud.CLIResult{
"sts get-caller-identity": {Stdout: callerIdentityAssumedRole},
}}
p, err := newWithBinary("/usr/bin/aws")
require.NoError(t, err)

st, err := p.Identity(context.Background(), f.run)
st, err := p.Identity(context.Background(), f.run, "arn:aws:iam::111122223333:role/triagent-readonly")
require.NoError(t, err)
assert.True(t, st.Valid, "assumed-role ARN whose role matches the pinned expectation is valid")
}

func TestIdentityRejectsMismatchedExpectedRoleArn(t *testing.T) {
t.Setenv(EnvExpectedRoleARN, "arn:aws:iam::111122223333:role/some-other-role")
f := &fakeRun{results: map[string]cloud.CLIResult{
"sts get-caller-identity": {Stdout: callerIdentityAssumedRole},
}}
p, err := newWithBinary("/usr/bin/aws")
require.NoError(t, err)

st, err := p.Identity(context.Background(), f.run)
st, err := p.Identity(context.Background(), f.run, "arn:aws:iam::111122223333:role/some-other-role")
require.NoError(t, err)
assert.False(t, st.Valid, "assumed role not matching the pinned expectation is invalid")
assert.NotEmpty(t, st.Hint)
Expand All @@ -99,7 +97,7 @@ func TestIdentityInvalidOnNonZeroExit(t *testing.T) {
p, err := newWithBinary("/usr/bin/aws")
require.NoError(t, err)

st, err := p.Identity(context.Background(), f.run)
st, err := p.Identity(context.Background(), f.run, "")
require.NoError(t, err)
assert.False(t, st.Valid)
assert.NotEmpty(t, st.Hint)
Expand Down
Loading
Loading