Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 35 additions & 28 deletions pkg/authz/response_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,35 +187,42 @@ func (rfw *ResponseFilteringWriter) processSSEResponse(rawResponse []byte) error
var written bool
if data, ok := bytes.CutPrefix(line, []byte("data:")); ok {
message, err := jsonrpc2.DecodeMessage(data)
if err != nil {
rfw.ResponseWriter.WriteHeader(rfw.statusCode)
_, err := rfw.ResponseWriter.Write(rawResponse)
return err
}

response, ok := message.(*jsonrpc2.Response)
if !ok {
rfw.ResponseWriter.WriteHeader(rfw.statusCode)
_, err := rfw.ResponseWriter.Write(rawResponse)
return err
}

filteredResponse, err := rfw.filterListResponse(response)
if err != nil {
return rfw.writeErrorResponse(response.ID, err)
}

filteredData, err := jsonrpc2.EncodeMessage(filteredResponse)
if err != nil {
return rfw.writeErrorResponse(response.ID, err)
switch {
case err != nil:
// Pass this line through unfiltered. Earlier revisions wrote
// rawResponse and returned here, which leaked every subsequent
// data line on the stream past the filter (issue #5257).
if rfw.method == string(mcp.MethodToolsList) {
slog.Warn("SSE data line could not be decoded as JSON-RPC; passing through unfiltered",
"method", rfw.method, "error", err)
}
default:
if response, ok := message.(*jsonrpc2.Response); ok {
filteredResponse, err := rfw.filterListResponse(response)
if err != nil {
return rfw.writeErrorResponse(response.ID, err)
}

filteredData, err := jsonrpc2.EncodeMessage(filteredResponse)
if err != nil {
return rfw.writeErrorResponse(response.ID, err)
}

_, err = rfw.ResponseWriter.Write([]byte("data: " + string(filteredData) + "\n"))
if err != nil {
return fmt.Errorf("%w: %w", errBug, err)
}

written = true
} else if rfw.method == string(mcp.MethodToolsList) {
// Non-Response message (e.g. a notifications/* frame
// interleaved on the stream). Pass through unfiltered for
// this line only; the next data line may still be the real
// tools/list response and must reach the filter.
slog.Warn("SSE data line was not a JSON-RPC Response; passing through unfiltered",
"method", rfw.method)
}
}

_, err = rfw.ResponseWriter.Write([]byte("data: " + string(filteredData) + "\n"))
if err != nil {
return fmt.Errorf("%w: %w", errBug, err)
}

written = true
}

if !written {
Expand Down
123 changes: 123 additions & 0 deletions pkg/authz/response_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -863,3 +863,126 @@ func TestOptimizerPassThroughToolsInResponseFilter(t *testing.T) {
"admin_tool has no permit policy and is not a pass-through tool")
})
}

// TestResponseFilteringWriter_SSE_PerLineFallthrough is a regression test for
// issue #5257: when an SSE upstream interleaves a non-Response data line (e.g.
// an MCP notification) or an undecodable data line with a real tools/list
// response, the filter previously wrote the entire raw upstream payload and
// returned, leaking the unfiltered tools/list past Cedar. It must instead pass
// only the offending line through and continue filtering the rest of the
// stream.
func TestResponseFilteringWriter_SSE_PerLineFallthrough(t *testing.T) {
t.Parallel()

authorizer, err := cedar.NewCedarAuthorizer(cedar.ConfigOptions{
Policies: []string{
`permit(principal, action == Action::"call_tool", resource == Tool::"weather");`,
},
EntitiesJSON: `[]`,
}, "")
require.NoError(t, err)

identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{
Subject: "user1",
Claims: map[string]interface{}{"sub": "user1"},
}}

// Real tools/list response. The caller is only authorized for "weather",
// so a working filter must drop "admin_tool" from the response.
toolsResult := mcp.ListToolsResult{
Tools: []mcp.Tool{
{Name: "weather", Description: "Get weather information"},
{Name: "admin_tool", Description: "Sensitive admin operations"},
},
}
resultJSON, err := json.Marshal(toolsResult)
require.NoError(t, err)
encodedResp, err := jsonrpc2.EncodeMessage(&jsonrpc2.Response{
ID: jsonrpc2.Int64ID(1),
Result: json.RawMessage(resultJSON),
})
require.NoError(t, err)
respLine := "data: " + string(encodedResp)

testCases := []struct {
name string
precedingLines []string
}{
{
name: "non-response data line before tools/list",
precedingLines: []string{
// A notifications/* frame is a valid JSON-RPC notification
// (no id), so jsonrpc2.DecodeMessage returns a non-Response
// message. The buggy path treated this as a signal to dump
// rawResponse and return.
`data: {"jsonrpc":"2.0","method":"notifications/message","params":{"level":"info","data":"warming up"}}`,
},
},
{
name: "undecodable data line before tools/list",
precedingLines: []string{
`data: this is not json at all`,
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

req, err := http.NewRequest(http.MethodPost, "/messages", nil)
require.NoError(t, err)
req = req.WithContext(auth.WithIdentity(req.Context(), identity))

rr := httptest.NewRecorder()
rfw := NewResponseFilteringWriter(rr, authorizer, req, string(mcp.MethodToolsList), nil, nil)
rfw.ResponseWriter.Header().Set("Content-Type", "text/event-stream")

body := strings.Join(append(tc.precedingLines, respLine, ""), "\n")
_, err = rfw.Write([]byte(body))
require.NoError(t, err)

require.NoError(t, rfw.FlushAndFilter())

out := rr.Body.String()

// Each preceding line must still appear verbatim; pass-through
// is the whole point of the fix.
for _, pl := range tc.precedingLines {
assert.Contains(t, out, pl, "non-response/undecodable preceding line must pass through unchanged")
}

// The real tools/list response must have been filtered. Pull the
// last data line out and decode it.
var filteredLine string
for _, line := range strings.Split(out, "\n") {
if strings.HasPrefix(line, "data: {\"jsonrpc\"") && strings.Contains(line, `"result"`) {
filteredLine = line
}
}
require.NotEmpty(t, filteredLine, "no JSON-RPC Response data line found in output")

payload := strings.TrimPrefix(filteredLine, "data: ")
msg, err := jsonrpc2.DecodeMessage([]byte(payload))
require.NoError(t, err)
resp, ok := msg.(*jsonrpc2.Response)
require.True(t, ok)

var filtered mcp.ListToolsResult
require.NoError(t, json.Unmarshal(resp.Result, &filtered))

names := make([]string, len(filtered.Tools))
for i, tool := range filtered.Tools {
names[i] = tool.Name
}
assert.Contains(t, names, "weather", "authorized tool must be retained")
assert.NotContains(t, names, "admin_tool",
"unauthorized tool must be filtered; presence indicates the cedar bypass from #5257 is back")

// And the raw unfiltered payload (the bug used to dump it)
// must not appear in the wire output.
assert.NotContains(t, out, `"admin_tool"`,
"unfiltered tools/list payload leaked into SSE output")
})
}
}