diff --git a/pkg/telemetry/integration_test.go b/pkg/telemetry/integration_test.go
index e65167f19..4d4f47291 100644
--- a/pkg/telemetry/integration_test.go
+++ b/pkg/telemetry/integration_test.go
@@ -483,3 +483,44 @@ func TestTelemetryIntegration_MultipleRequests(t *testing.T) {
 	assert.Contains(t, metricsBody, "toolhive_mcp_requests")
 	assert.Contains(t, metricsBody, `server="multi-test"`)
 }
+
+func TestTelemetryIntegration_ToolErrorDetection(t *testing.T) {
+	t.Parallel()
+	// Set up in-memory trace and metric providers
+	exporter := tracetest.NewInMemoryExporter()
+	tracerProvider := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter))
+	meterProvider := sdkmetric.NewMeterProvider()
+
+	config := Config{ServiceName: "test", ServiceVersion: "1.0.0"}
+	middleware := NewHTTPMiddleware(config, tracerProvider, meterProvider, "test", "stdio")
+
+	// Handler replies with a tool result flagged isError; the recorder-backed write cannot fail
+	testHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		_, _ = w.Write([]byte(`{"result":{"isError":true}}`))
+	})
+
+	mcpRequest := &mcp.ParsedMCPRequest{Method: "tools/call", ID: "test", IsRequest: true}
+	req := httptest.NewRequest("POST", "/messages", nil)
+	ctx := context.WithValue(req.Context(), mcp.MCPRequestContextKey, mcpRequest)
+	req = req.WithContext(ctx)
+
+	rec := httptest.NewRecorder()
+	middleware(testHandler).ServeHTTP(rec, req)
+
+	// Flush must succeed before the exporter contents are inspected
+	require.NoError(t, tracerProvider.ForceFlush(ctx))
+	spans := exporter.GetSpans()
+	require.Len(t, spans, 1)
+
+	span := spans[0]
+	assert.Equal(t, "mcp.tools/call", span.Name)
+
+	// Check for tool error attribute
+	for _, attr := range span.Attributes {
+		if attr.Key == "mcp.tool.error" {
+			assert.True(t, attr.Value.AsBool())
+			return
+		}
+	}
+	t.Error("Expected mcp.tool.error attribute not found")
+}
diff --git a/pkg/telemetry/middleware.go b/pkg/telemetry/middleware.go
index 19e943bc3..8dae47085 100644
--- a/pkg/telemetry/middleware.go
+++ b/pkg/telemetry/middleware.go
@@ -130,6 +130,7 @@ func (m *HTTPMiddleware) Handler(next http.Handler) http.Handler {
 			ResponseWriter: w,
 			statusCode:     http.StatusOK,
 			bytesWritten:   0,
+			isToolCall:     mcpparser.GetMCPMethod(ctx) == string(mcp.MethodToolsCall),
 		}
 
 		// Add HTTP attributes
@@ -147,6 +148,9 @@ func (m *HTTPMiddleware) Handler(next http.Handler) http.Handler {
 		// Call the next handler with the instrumented context
 		next.ServeHTTP(rw, r.WithContext(ctx))
 
+		// Finalize tool error detection now that response is complete
+		rw.finalizeToolErrorDetection()
+
 		// Record completion metrics and finalize span
 		duration := time.Since(startTime)
 		m.finalizeSpan(span, rw, duration)
@@ -390,19 +394,44 @@ func (*HTTPMiddleware) finalizeSpan(span trace.Span, rw *responseWriter, duratio
 		attribute.Float64("http.duration_ms", float64(duration.Nanoseconds())/1e6),
 	)
 
-	// Set span status based on HTTP status code
+	// Add MCP tool error indicator if detected
+	if rw.isToolCall {
+		span.SetAttributes(attribute.Bool("mcp.tool.error", rw.hasToolError))
+	}
+
+	// Set span status based on HTTP status code AND MCP tool errors
 	if rw.statusCode >= 400 {
 		span.SetStatus(codes.Error, fmt.Sprintf("HTTP %d", rw.statusCode))
+	} else if rw.hasToolError {
+		span.SetStatus(codes.Error, "MCP tool execution error")
 	} else {
 		span.SetStatus(codes.Ok, "")
 	}
 }
 
+// detectMCPToolError reports whether data decodes as a JSON object with result.isError set to true.
+// A body that cannot be decoded into that shape (for example invalid JSON) yields false.
+func detectMCPToolError(data []byte) bool {
+	// Only result.isError is decoded; every other field in the payload is ignored
+	var resp struct {
+		Result struct {
+			IsError bool `json:"isError"`
+		} `json:"result"`
+	}
+	if err := json.Unmarshal(data, &resp); err != nil {
+		return false
+	}
+	return resp.Result.IsError
+}
+
 // responseWriter wraps http.ResponseWriter to capture response details.
 type responseWriter struct {
 	http.ResponseWriter
-	statusCode   int
-	bytesWritten int64
+	statusCode     int
+	bytesWritten   int64
+	hasToolError   bool   // tracks if MCP tool execution error is detected
+	isToolCall     bool   // tracks if this is a tools/call request
+	responseBuffer []byte // buffer to collect response data for tool calls
 }
 
 // WriteHeader captures the status code.
@@ -411,13 +440,29 @@ func (rw *responseWriter) WriteHeader(statusCode int) {
 	rw.ResponseWriter.WriteHeader(statusCode)
 }
 
-// Write captures the number of bytes written.
+// Write captures the number of bytes written and buffers data for tool calls.
 func (rw *responseWriter) Write(data []byte) (int, error) {
 	n, err := rw.ResponseWriter.Write(data)
 	rw.bytesWritten += int64(n)
+
+	// Buffer response data for tool calls to enable proper error detection
+	if rw.isToolCall && !rw.hasToolError {
+		rw.responseBuffer = append(rw.responseBuffer, data...)
+	}
+
 	return n, err
 }
 
+// finalizeToolErrorDetection performs error detection on the complete buffered response.
+// This should be called after the response is completely written.
+func (rw *responseWriter) finalizeToolErrorDetection() {
+	if rw.isToolCall && !rw.hasToolError && len(rw.responseBuffer) > 0 {
+		rw.hasToolError = detectMCPToolError(rw.responseBuffer)
+		// Clear buffer to free memory
+		rw.responseBuffer = nil
+	}
+}
+
 // recordMetrics records request metrics.
 func (m *HTTPMiddleware) recordMetrics(ctx context.Context, r *http.Request, rw *responseWriter, duration time.Duration) {
 	// Get MCP method from context if available
@@ -426,10 +471,12 @@ func (m *HTTPMiddleware) recordMetrics(ctx context.Context, r *http.Request, rw
 		mcpMethod = "unknown"
 	}
 
-	// Determine status (success/error)
+	// Determine status (success/error/tool_error)
 	status := "success"
 	if rw.statusCode >= 400 {
 		status = "error"
+	} else if rw.hasToolError {
+		status = "tool_error"
 	}
 
 	// Common attributes for all metrics
diff --git a/pkg/telemetry/middleware_test.go b/pkg/telemetry/middleware_test.go
index 5a7a36822..1da7dc745 100644
--- a/pkg/telemetry/middleware_test.go
+++ b/pkg/telemetry/middleware_test.go
@@ -1491,3 +1491,44 @@ func TestFactoryMiddleware_Integration(t *testing.T) {
 		assert.NoError(t, err)
 	})
 }
+
+func TestDetectMCPToolError(t *testing.T) {
+	t.Parallel()
+	assert.False(t, detectMCPToolError([]byte(`{"result":{"isError":false}}`)))
+	assert.True(t, detectMCPToolError([]byte(`{"result":{"isError":true}}`)))
+	assert.False(t, detectMCPToolError([]byte(`{"result":{"content":"test"}}`)))
+
+	// Test invalid JSON - should return false, not panic
+	assert.False(t, detectMCPToolError([]byte(`invalid json`)))
+	assert.False(t, detectMCPToolError([]byte(`{"malformed": json}`)))
+}
+
+func TestResponseWriter_ToolErrorDetection(t *testing.T) {
+	t.Parallel()
+	rec := httptest.NewRecorder() // recorder writes never return an error, so Write results are discarded below
+
+	// Tool call with error
+	rw := &responseWriter{ResponseWriter: rec, isToolCall: true}
+	_, _ = rw.Write([]byte(`{"result":{"isError":true}}`))
+	rw.finalizeToolErrorDetection() // detection only runs on explicit finalize
+	assert.True(t, rw.hasToolError)
+
+	// Tool call without error
+	rw = &responseWriter{ResponseWriter: rec, isToolCall: true}
+	_, _ = rw.Write([]byte(`{"result":{"isError":false}}`))
+	rw.finalizeToolErrorDetection()
+	assert.False(t, rw.hasToolError)
+
+	// Non-tool call should not detect errors
+	rw = &responseWriter{ResponseWriter: rec, isToolCall: false}
+	_, _ = rw.Write([]byte(`{"result":{"isError":true}}`))
+	rw.finalizeToolErrorDetection()
+	assert.False(t, rw.hasToolError)
+
+	// Test chunked writes (multiple Write calls)
+	rw = &responseWriter{ResponseWriter: rec, isToolCall: true}
+	_, _ = rw.Write([]byte(`{"result":{"isError":`))
+	_, _ = rw.Write([]byte(`true}}`))
+	rw.finalizeToolErrorDetection()
+	assert.True(t, rw.hasToolError, "Should detect error in chunked response")
+}