diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index 01642d6641..5ef95ef03c 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -15,6 +15,7 @@ import ( "log/slog" "net" "net/http" + "sync" "time" "github.com/mark3labs/mcp-go/client" @@ -60,6 +61,14 @@ type httpBackendClient struct { // registry manages authentication strategies for outgoing requests to backend MCP servers. // Must not be nil - use UnauthenticatedStrategy for no authentication. registry vmcpauth.OutgoingAuthRegistry + + // transportMu protects transportCache. + transportMu sync.RWMutex + + // transportCache holds one *http.Transport per backend ID so connections are + // reused across calls to the same backend. Call FlushIdleConnections to evict + // stale keep-alive connections (e.g., after a health check failure or backend replacement). + transportCache map[string]*http.Transport } // NewHTTPBackendClient creates a new HTTP-based backend client. @@ -76,12 +85,75 @@ func NewHTTPBackendClient(registry vmcpauth.OutgoingAuthRegistry) (vmcp.BackendC } c := &httpBackendClient{ - registry: registry, + registry: registry, + transportCache: make(map[string]*http.Transport), } c.clientFactory = c.defaultClientFactory return c, nil } +// newBackendTransport creates a *http.Transport with the same defaults as http.DefaultTransport. +// If http.DefaultTransport is a *http.Transport, it is cloned directly (preserving any +// environment-specific settings like TLS config or proxy overrides). Otherwise a transport +// with the standard Go defaults is constructed, preserving proxy, dial timeout, HTTP/2, and +// idle-connection settings that a zero-value &http.Transport{} would drop. +func newBackendTransport() *http.Transport { + if dt, ok := http.DefaultTransport.(*http.Transport); ok { + return dt.Clone() + } + // http.DefaultTransport has been replaced (e.g. in tests or by a third-party library). 
+ // Construct a transport with the same defaults as the Go standard library uses for + // http.DefaultTransport so we don't silently drop proxy, timeout, or HTTP/2 settings. + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +} + +// getOrCreateTransport returns the cached *http.Transport for a backend, creating one if needed. +// Each backend gets its own transport so connection pools are isolated per backend. +func (h *httpBackendClient) getOrCreateTransport(backendID string) *http.Transport { + h.transportMu.RLock() + if t, ok := h.transportCache[backendID]; ok { + h.transportMu.RUnlock() + return t + } + h.transportMu.RUnlock() + + h.transportMu.Lock() + defer h.transportMu.Unlock() + if t, ok := h.transportCache[backendID]; ok { + return t + } + t := newBackendTransport() + h.transportCache[backendID] = t + return t +} + +// FlushIdleConnections closes all idle keep-alive connections for the given backend +// and removes its transport from the cache so the next request gets a fresh connection. +// This implements vmcp.ConnectionFlusher. +func (h *httpBackendClient) FlushIdleConnections(backendID string) { + h.transportMu.Lock() + t, ok := h.transportCache[backendID] + if ok { + delete(h.transportCache, backendID) + } + h.transportMu.Unlock() + if ok { + t.CloseIdleConnections() + slog.Debug("flushed idle connections for backend", "backend", backendID) + } +} + // roundTripperFunc is a function adapter for http.RoundTripper. 
type roundTripperFunc func(*http.Request) (*http.Response, error) @@ -173,7 +245,12 @@ func (h *httpBackendClient) resolveAuthStrategy(target *vmcp.BackendTarget) (vmc func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) { // Build transport chain (outermost to innermost, request execution order): // size limit (response body) → trace propagation → identity propagation → authentication → HTTP - var baseTransport = http.DefaultTransport + // + // Use the per-backend cached transport so connections are reused across calls + // to the same backend. Each backend has an isolated pool, preventing stale + // keep-alive connections to one backend from affecting others. + // Call FlushIdleConnections after a failure to evict stale connections. + var baseTransport http.RoundTripper = h.getOrCreateTransport(target.WorkloadID) // Resolve authentication strategy ONCE at client creation time authStrategy, err := h.resolveAuthStrategy(target) @@ -280,8 +357,9 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm // This enables type-safe error checking with errors.Is() instead of string matching. // // Error detection strategy (in order of preference): -// 1. Check for standard Go error types (context errors, net.Error, url.Error) -// 2. Fall back to string pattern matching for library-specific errors (MCP SDK, HTTP libs) +// Check for standard Go error types (context errors, io.EOF, net.Error), +// then mcp-go transport sentinels, then fall back to string pattern matching +// for library-specific errors (MCP SDK, HTTP libs). // // Error chain preservation: // The returned error wraps the sentinel error (ErrTimeout, ErrBackendUnavailable, etc.) with %w @@ -296,7 +374,7 @@ func wrapBackendError(err error, backendID string, operation string) error { return nil } - // 1. 
Type-based detection: Check for context deadline/cancellation + // Type-based detection: context deadline/cancellation if errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("%w: failed to %s for backend %s (timeout): %v", vmcp.ErrTimeout, operation, backendID, err) @@ -306,23 +384,37 @@ func wrapBackendError(err error, backendID string, operation string) error { vmcp.ErrCancelled, operation, backendID, err) } - // 2. Type-based detection: Check for io.EOF errors - // These indicate the connection was closed unexpectedly + // Type-based detection: io.EOF errors indicate the connection was closed unexpectedly if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { return fmt.Errorf("%w: failed to %s for backend %s (connection closed): %v", vmcp.ErrBackendUnavailable, operation, backendID, err) } - // 3. Type-based detection: Check for net.Error with Timeout() method - // This handles network timeouts from the standard library + // Type-based detection: net.Error with Timeout() handles network timeouts from the standard library var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { return fmt.Errorf("%w: failed to %s for backend %s (timeout): %v", vmcp.ErrTimeout, operation, backendID, err) } - // 4. String-based detection: Fall back to pattern matching for cases where - // we don't have structured error types (MCP SDK, HTTP libraries with embedded status codes) + // mcp-go transport sentinel errors: check these before string-based fallbacks + // to ensure accurate classification of protocol-level errors. + if errors.Is(err, transport.ErrUnauthorized) { + return fmt.Errorf("%w: failed to %s for backend %s: %v", + vmcp.ErrAuthenticationFailed, operation, backendID, err) + } + // ErrLegacySSEServer is returned for any 4xx (except 401) on initialize POST. + // This includes 403 (auth rejection) and 404/405 (endpoint not found/method not allowed). 
+ // We cannot distinguish auth failures from routing errors without the raw status code, + // so we surface a clear message and classify as backend unavailable to allow recovery. + if errors.Is(err, transport.ErrLegacySSEServer) { + const legacyMsg = "server rejected MCP initialize — possible auth rejection or legacy SSE-only server" + return fmt.Errorf("%w: failed to %s for backend %s (%s): %v", + vmcp.ErrBackendUnavailable, operation, backendID, legacyMsg, err) + } + + // String-based detection: fall back to pattern matching for cases where + // we don't have structured error types (MCP SDK, HTTP libraries with embedded status codes). // Authentication errors (401, 403, auth failures) if vmcp.IsAuthenticationError(err) { return fmt.Errorf("%w: failed to %s for backend %s: %v", diff --git a/pkg/vmcp/client/client_test.go b/pkg/vmcp/client/client_test.go index 5fae752c91..cb6aada488 100644 --- a/pkg/vmcp/client/client_test.go +++ b/pkg/vmcp/client/client_test.go @@ -8,11 +8,13 @@ package client import ( "context" "errors" + "fmt" "net/http" "net/http/httptest" "testing" "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/propagation" @@ -659,6 +661,73 @@ func TestTracePropagatingRoundTripper_ParentChildSpan(t *testing.T) { "traceparent should contain child span ID, not parent") } +// TestWrapBackendError verifies that wrapBackendError maps mcp-go transport sentinel +// errors to the correct vmcp sentinel errors for downstream health monitoring. +func TestWrapBackendError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + wantSentinel error + wantMsgContains string + }{ + { + name: "nil error returns nil", + err: nil, + wantSentinel: nil, + }, + { + // mcp-go returns ErrUnauthorized for 401 on initialize POST. 
+ // Must map to ErrAuthenticationFailed so health monitors classify + // the backend as BackendUnauthenticated, not BackendUnhealthy. + name: "ErrUnauthorized maps to ErrAuthenticationFailed", + err: transport.ErrUnauthorized, + wantSentinel: vmcp.ErrAuthenticationFailed, + }, + { + // errors.Is traverses the error chain, so wrapping ErrUnauthorized + // in another error must still produce ErrAuthenticationFailed. + name: "wrapped ErrUnauthorized maps to ErrAuthenticationFailed", + err: fmt.Errorf("transport layer: %w", transport.ErrUnauthorized), + wantSentinel: vmcp.ErrAuthenticationFailed, + }, + { + // mcp-go returns ErrLegacySSEServer for non-401 4xx on initialize POST + // (e.g. 403, 404, 405). Classified as backend unavailable so the health + // monitor can recover if the backend is later corrected. + name: "ErrLegacySSEServer maps to ErrBackendUnavailable", + err: transport.ErrLegacySSEServer, + wantSentinel: vmcp.ErrBackendUnavailable, + wantMsgContains: "legacy SSE", + }, + { + name: "context.DeadlineExceeded maps to ErrTimeout", + err: context.DeadlineExceeded, + wantSentinel: vmcp.ErrTimeout, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := wrapBackendError(tt.err, "test-backend", "initialize") + + if tt.err == nil { + assert.NoError(t, result) + return + } + + require.Error(t, result) + assert.ErrorIs(t, result, tt.wantSentinel) + if tt.wantMsgContains != "" { + assert.Contains(t, result.Error(), tt.wantMsgContains) + } + }) + } +} + func TestResolveAuthStrategy(t *testing.T) { t.Parallel() diff --git a/pkg/vmcp/errors.go b/pkg/vmcp/errors.go index 1bc038ca1f..f2b8e58868 100644 --- a/pkg/vmcp/errors.go +++ b/pkg/vmcp/errors.go @@ -98,9 +98,12 @@ func IsAuthenticationError(err error) bool { return true } - // Check for HTTP 401/403 status codes with context - // Match patterns like "401 Unauthorized", "HTTP 401", "status code 401" + // Check for HTTP 401/403 status codes with context. 
+ // Match patterns like "401 Unauthorized", "HTTP 401", "status code 401". + // Also match mcp-go's ErrUnauthorized = "unauthorized (401)" which uses + // reversed order compared to the "401 unauthorized" pattern above. if strings.Contains(errLower, "401 unauthorized") || + strings.Contains(errLower, "unauthorized (401)") || strings.Contains(errLower, "403 forbidden") || strings.Contains(errLower, "http 401") || strings.Contains(errLower, "http 403") || diff --git a/pkg/vmcp/health/checker.go b/pkg/vmcp/health/checker.go index 8bec4ffc0f..303a69cc16 100644 --- a/pkg/vmcp/health/checker.go +++ b/pkg/vmcp/health/checker.go @@ -115,6 +115,15 @@ func (h *healthChecker) CheckHealth(ctx context.Context, target *vmcp.BackendTar return vmcp.BackendHealthy, nil } +// FlushIdleConnections closes idle connections for the given backend. +// Implements vmcp.ConnectionFlusher by delegating to the underlying client if it supports it. +// Called by the Monitor after a health check failure and when a backend is removed from monitoring. +func (h *healthChecker) FlushIdleConnections(backendID string) { + if flusher, ok := h.client.(vmcp.ConnectionFlusher); ok { + flusher.FlushIdleConnections(backendID) + } +} + // categorizeError determines the appropriate health status based on the error type. // This uses sentinel error checking with errors.Is() for type-safe error categorization. // Falls back to string-based detection for backwards compatibility with non-wrapped errors. 
diff --git a/pkg/vmcp/health/checker_test.go b/pkg/vmcp/health/checker_test.go index dc39162c88..754940d6a5 100644 --- a/pkg/vmcp/health/checker_test.go +++ b/pkg/vmcp/health/checker_test.go @@ -334,6 +334,9 @@ func TestIsAuthenticationError(t *testing.T) { {name: "request unauthorized", err: errors.New("request unauthorized"), expectErr: true}, {name: "access denied", err: errors.New("access denied"), expectErr: true}, + // mcp-go ErrUnauthorized format: "unauthorized (401)" (reversed order vs "401 unauthorized") + {name: "unauthorized (401) - mcp-go ErrUnauthorized format", err: errors.New("unauthorized (401)"), expectErr: true}, + // Negative cases - should NOT be detected as auth errors {name: "connection refused", err: errors.New("connection refused"), expectErr: false}, {name: "timeout", err: errors.New("request timeout"), expectErr: false}, @@ -515,6 +518,13 @@ func TestHealthChecker_CheckHealth_AuthErrorsCategorizesAsUnauthenticated(t *tes name: "wrapped sentinel auth error", err: fmt.Errorf("client credentials grant failed: %w", vmcp.ErrAuthenticationFailed), }, + { + // transport.ErrUnauthorized is wrapped with ErrAuthenticationFailed in wrapBackendError, + // so a 401 from the mcp-go transport layer reaches health monitoring as + // BackendUnauthenticated instead of BackendUnhealthy. 
+ name: "mcp-go ErrUnauthorized wrapped as ErrAuthenticationFailed by wrapBackendError", + err: fmt.Errorf("%w: failed to initialize for backend my-backend: unauthorized (401)", vmcp.ErrAuthenticationFailed), + }, } for _, tt := range tests { diff --git a/pkg/vmcp/health/monitor.go b/pkg/vmcp/health/monitor.go index 6d7588a7cb..8879209291 100644 --- a/pkg/vmcp/health/monitor.go +++ b/pkg/vmcp/health/monitor.go @@ -351,6 +351,10 @@ func (m *Monitor) UpdateBackends(newBackends []vmcp.Backend) { } // Remove backend from status tracker so it no longer appears in status reports m.statusTracker.RemoveBackend(id) + // Evict the cached transport for this backend to free connections and memory + if flusher, ok := m.checker.(vmcp.ConnectionFlusher); ok { + flusher.FlushIdleConnections(id) + } } } } @@ -424,6 +428,11 @@ func (m *Monitor) performHealthCheck(ctx context.Context, backend *vmcp.Backend) if err != nil { slog.Debug("health check failed for backend", "backend", backend.Name, "error", err, "status", status) m.statusTracker.RecordFailure(backend.ID, backend.Name, status, err) + // Flush idle connections so the next attempt gets a fresh TCP connection. + // This recovers from stale keep-alive connections to pods that have been replaced. 
+ if flusher, ok := m.checker.(vmcp.ConnectionFlusher); ok { + flusher.FlushIdleConnections(backend.ID) + } } else { // Pass status to RecordSuccess - it may be healthy or degraded (from slow response) // RecordSuccess will further check for recovering state (had recent failures) diff --git a/pkg/vmcp/health/monitor_test.go b/pkg/vmcp/health/monitor_test.go index 4e741bb996..bcfd552110 100644 --- a/pkg/vmcp/health/monitor_test.go +++ b/pkg/vmcp/health/monitor_test.go @@ -6,6 +6,7 @@ package health import ( "context" "errors" + "sync" "testing" "time" @@ -1165,3 +1166,88 @@ func TestMonitor_CircuitBreakerStatusReporting(t *testing.T) { err = monitor.Stop() require.NoError(t, err) } + +// flushingChecker is a test HealthChecker that also implements ConnectionFlusher. +// It returns configurable results and records FlushIdleConnections calls. +type flushingChecker struct { + mu sync.Mutex + shouldFail bool + flushCalls []string +} + +func (f *flushingChecker) CheckHealth(_ context.Context, _ *vmcp.BackendTarget) (vmcp.BackendHealthStatus, error) { + f.mu.Lock() + defer f.mu.Unlock() + if f.shouldFail { + return vmcp.BackendUnhealthy, errors.New("connection refused") + } + return vmcp.BackendHealthy, nil +} + +func (f *flushingChecker) FlushIdleConnections(backendID string) { + f.mu.Lock() + defer f.mu.Unlock() + f.flushCalls = append(f.flushCalls, backendID) +} + +func (f *flushingChecker) getFlushCalls() []string { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]string, len(f.flushCalls)) + copy(out, f.flushCalls) + return out +} + +// TestMonitor_FlushIdleConnections verifies that FlushIdleConnections is called on +// the checker after a health check failure and is not called after a success. +func TestMonitor_FlushIdleConnections(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Use a stub BackendClient so NewMonitor is happy; we replace checker below. 
+ stubClient := mocks.NewMockBackendClient(ctrl) + backend := vmcp.Backend{ID: "backend-1", Name: "Backend 1", BaseURL: "http://localhost:8080", TransportType: "sse"} + + config := MonitorConfig{ + CheckInterval: 50 * time.Millisecond, + UnhealthyThreshold: 1, + Timeout: 10 * time.Millisecond, + } + + monitor, err := NewMonitor(stubClient, []vmcp.Backend{backend}, config) + require.NoError(t, err) + + // Replace the internal checker with one that records flush calls. + checker := &flushingChecker{shouldFail: true} + monitor.checker = checker + + ctx := context.Background() + require.NoError(t, monitor.Start(ctx)) + defer func() { _ = monitor.Stop() }() + + // Wait for at least one failure to be recorded and the flush to be triggered. + require.Eventually(t, func() bool { + return len(checker.getFlushCalls()) >= 1 + }, 500*time.Millisecond, 10*time.Millisecond, "FlushIdleConnections should be called after a health check failure") + + assert.Equal(t, "backend-1", checker.getFlushCalls()[0], + "FlushIdleConnections should be called with the correct backend ID") + + // Switch to success and record the flush count at that point. + checker.mu.Lock() + checker.shouldFail = false + flushCountAtSwitch := len(checker.flushCalls) + checker.mu.Unlock() + + // Wait for a successful health check to be recorded. + require.Eventually(t, func() bool { + s, statusErr := monitor.GetBackendStatus("backend-1") + return statusErr == nil && s == vmcp.BackendHealthy + }, 500*time.Millisecond, 10*time.Millisecond, "backend should recover to healthy") + + // No additional flush calls should have been made during successful checks. 
+ assert.Equal(t, flushCountAtSwitch, len(checker.getFlushCalls()), + "FlushIdleConnections should not be called after a successful health check") +} diff --git a/pkg/vmcp/mocks/mock_backend_client.go b/pkg/vmcp/mocks/mock_backend_client.go index 1a187df3f8..20de0c2ec4 100644 --- a/pkg/vmcp/mocks/mock_backend_client.go +++ b/pkg/vmcp/mocks/mock_backend_client.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -destination=mocks/mock_backend_client.go -package=mocks -source=types.go BackendClient HealthChecker +// mockgen -destination=mocks/mock_backend_client.go -package=mocks -source=types.go BackendClient HealthChecker ConnectionFlusher // // Package mocks is a generated GoMock package. @@ -139,3 +139,39 @@ func (mr *MockBackendClientMockRecorder) ReadResource(ctx, target, uri any) *gom mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadResource", reflect.TypeOf((*MockBackendClient)(nil).ReadResource), ctx, target, uri) } + +// MockConnectionFlusher is a mock of ConnectionFlusher interface. +type MockConnectionFlusher struct { + ctrl *gomock.Controller + recorder *MockConnectionFlusherMockRecorder + isgomock struct{} +} + +// MockConnectionFlusherMockRecorder is the mock recorder for MockConnectionFlusher. +type MockConnectionFlusherMockRecorder struct { + mock *MockConnectionFlusher +} + +// NewMockConnectionFlusher creates a new mock instance. +func NewMockConnectionFlusher(ctrl *gomock.Controller) *MockConnectionFlusher { + mock := &MockConnectionFlusher{ctrl: ctrl} + mock.recorder = &MockConnectionFlusherMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnectionFlusher) EXPECT() *MockConnectionFlusherMockRecorder { + return m.recorder +} + +// FlushIdleConnections mocks base method. 
+func (m *MockConnectionFlusher) FlushIdleConnections(backendID string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "FlushIdleConnections", backendID) +} + +// FlushIdleConnections indicates an expected call of FlushIdleConnections. +func (mr *MockConnectionFlusherMockRecorder) FlushIdleConnections(backendID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FlushIdleConnections", reflect.TypeOf((*MockConnectionFlusher)(nil).FlushIdleConnections), backendID) +} diff --git a/pkg/vmcp/server/status_reporting.go b/pkg/vmcp/server/status_reporting.go index 89f428d57c..53e45046c1 100644 --- a/pkg/vmcp/server/status_reporting.go +++ b/pkg/vmcp/server/status_reporting.go @@ -67,7 +67,25 @@ func (s *Server) periodicStatusReporting(ctx context.Context, config StatusRepor ticker := time.NewTicker(interval) defer ticker.Stop() - // Report status immediately after initial health checks complete + // Only start the version-polling ticker when the registry supports dynamic + // discovery. For static registries the ticker would fire every 2s only to + // type-assert and continue, wasting wakeups in the steady state. + dynamicReg, isDynamic := s.backendRegistry.(vmcp.DynamicRegistry) + var versionTickerC <-chan time.Time + var lastRegistryVersion uint64 + if isDynamic { + const versionPollInterval = 2 * time.Second + versionTicker := time.NewTicker(versionPollInterval) + defer versionTicker.Stop() + versionTickerC = versionTicker.C + } + + // Snapshot the version before reporting so that any mutation that races with + // reportStatus is visible to the version ticker on the next poll cycle, rather + // than being silently absorbed by a post-report version update. 
+ if isDynamic { + lastRegistryVersion = dynamicReg.Version() + } s.reportStatus(ctx, config.Reporter) for { @@ -77,7 +95,18 @@ func (s *Server) periodicStatusReporting(ctx context.Context, config StatusRepor return case <-ticker.C: + if isDynamic { + lastRegistryVersion = dynamicReg.Version() + } s.reportStatus(ctx, config.Reporter) + + case <-versionTickerC: + if v := dynamicReg.Version(); v != lastRegistryVersion { + slog.Debug("backend registry changed, triggering immediate status report", + "old_version", lastRegistryVersion, "new_version", v) + lastRegistryVersion = v + s.reportStatus(ctx, config.Reporter) + } } } } diff --git a/pkg/vmcp/server/status_reporting_test.go b/pkg/vmcp/server/status_reporting_test.go index c25dae39ae..b2bd4a641c 100644 --- a/pkg/vmcp/server/status_reporting_test.go +++ b/pkg/vmcp/server/status_reporting_test.go @@ -132,6 +132,83 @@ func TestDefaultStatusReportingConfig(t *testing.T) { assert.Nil(t, config.Reporter, "Default reporter should be nil") } +// testDynamicRegistry is a minimal vmcp.DynamicRegistry for testing version-change detection. 
+type testDynamicRegistry struct { + mu sync.Mutex + version uint64 +} + +func (r *testDynamicRegistry) Version() uint64 { + r.mu.Lock() + defer r.mu.Unlock() + return r.version +} + +func (*testDynamicRegistry) List(_ context.Context) []vmcp.Backend { return nil } +func (*testDynamicRegistry) Get(_ context.Context, _ string) *vmcp.Backend { return nil } +func (*testDynamicRegistry) Count() int { return 0 } + +func (r *testDynamicRegistry) Upsert(_ vmcp.Backend) error { + r.mu.Lock() + defer r.mu.Unlock() + r.version++ + return nil +} + +func (r *testDynamicRegistry) Remove(_ string) error { + r.mu.Lock() + defer r.mu.Unlock() + r.version++ + return nil +} + +// TestPeriodicStatusReporting_ReactsToVersionChange verifies that when the backend +// registry version changes, an immediate status report is triggered via the version-polling +// ticker rather than waiting for the full reporting interval. +func TestPeriodicStatusReporting_ReactsToVersionChange(t *testing.T) { + t.Parallel() + + reporter := &mockReporter{} + reg := &testDynamicRegistry{} + server := &Server{ + backendRegistry: reg, + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Use a long interval so the periodic tick never fires during the test. + config := StatusReportingConfig{ + Interval: 30 * time.Second, + Reporter: reporter, + } + + done := make(chan struct{}) + go func() { + defer close(done) + server.periodicStatusReporting(ctx, config) + }() + + // Wait for the initial immediate report before triggering a version change. + require.Eventually(t, func() bool { + return reporter.getCallCount() >= 1 + }, 3*time.Second, 10*time.Millisecond, "expected initial immediate status report") + + countAfterInit := reporter.getCallCount() + + // Trigger a version bump to simulate a backend being removed from the registry. + require.NoError(t, reg.Remove("some-backend")) + + // The version-polling ticker fires every 2 seconds; allow up to 5 seconds. 
+ require.Eventually(t, func() bool { + return reporter.getCallCount() > countAfterInit + }, 5*time.Second, 10*time.Millisecond, + "version change should trigger an immediate status report without waiting for the 30s interval") + + cancel() + <-done +} + // TestReportStatus tests the reportStatus method. func TestReportStatus(t *testing.T) { t.Parallel() diff --git a/pkg/vmcp/types.go b/pkg/vmcp/types.go index 2a525691d9..5db4b52332 100644 --- a/pkg/vmcp/types.go +++ b/pkg/vmcp/types.go @@ -516,7 +516,7 @@ type HealthChecker interface { // Note: Resource _meta forwarding is not currently supported due to MCP SDK handler // signature limitations; the Meta field is preserved for future SDK improvements. // -//go:generate mockgen -destination=mocks/mock_backend_client.go -package=mocks -source=types.go BackendClient HealthChecker +//go:generate mockgen -destination=mocks/mock_backend_client.go -package=mocks -source=types.go BackendClient HealthChecker ConnectionFlusher type BackendClient interface { // CallTool invokes a tool on the backend MCP server. // The meta parameter contains _meta fields from the client request that should be forwarded to the backend. @@ -538,6 +538,14 @@ type BackendClient interface { ListCapabilities(ctx context.Context, target *BackendTarget) (*CapabilityList, error) } +// ConnectionFlusher can flush idle connections for a specific backend. +// Implemented by BackendClient when the underlying transport supports connection pooling. +// Flushing evicts stale keep-alive connections, forcing a fresh dial on the next request. +// Used after health check failures (stale connections to replaced pods) and when a backend is removed. +type ConnectionFlusher interface { + FlushIdleConnections(backendID string) +} + // CapabilityList contains the capabilities from a backend's MCP server. // This is returned by BackendClient.ListCapabilities(). type CapabilityList struct {