Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 103 additions & 11 deletions pkg/vmcp/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"log/slog"
"net"
"net/http"
"sync"
"time"

"github.com/mark3labs/mcp-go/client"
Expand Down Expand Up @@ -60,6 +61,14 @@ type httpBackendClient struct {
// registry manages authentication strategies for outgoing requests to backend MCP servers.
// Must not be nil - use UnauthenticatedStrategy for no authentication.
registry vmcpauth.OutgoingAuthRegistry

// transportMu protects transportCache.
transportMu sync.RWMutex

// transportCache holds one *http.Transport per backend ID so connections are
// reused across calls to the same backend. Call FlushIdleConnections to evict
// stale keep-alive connections (e.g., after a health check failure or backend replacement).
transportCache map[string]*http.Transport
}

// NewHTTPBackendClient creates a new HTTP-based backend client.
Expand All @@ -76,12 +85,75 @@ func NewHTTPBackendClient(registry vmcpauth.OutgoingAuthRegistry) (vmcp.BackendC
}

c := &httpBackendClient{
registry: registry,
registry: registry,
transportCache: make(map[string]*http.Transport),
}
c.clientFactory = c.defaultClientFactory
return c, nil
}

// newBackendTransport builds a fresh *http.Transport for use by a single backend.
//
// When http.DefaultTransport is still a *http.Transport, the result is a clone of it,
// so any process-wide customization (TLS config, proxy overrides, ...) carries over.
// When http.DefaultTransport has been swapped for some other http.RoundTripper (e.g. in
// tests or by a third-party library), the transport is assembled by hand with the same
// values the Go standard library uses for http.DefaultTransport. A zero-value
// &http.Transport{} would silently lose the proxy, dial timeout, HTTP/2 and
// idle-connection settings.
func newBackendTransport() *http.Transport {
	std, isStd := http.DefaultTransport.(*http.Transport)
	if !isStd {
		// Default transport was replaced: rebuild the standard-library defaults explicitly.
		dialer := &net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}
		return &http.Transport{
			Proxy:                 http.ProxyFromEnvironment,
			DialContext:           dialer.DialContext,
			ForceAttemptHTTP2:     true,
			MaxIdleConns:          100,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		}
	}
	return std.Clone()
}

// getOrCreateTransport returns the *http.Transport cached for backendID, creating and
// caching a new one via newBackendTransport on first use. Each backend ID maps to its
// own transport, so connection pools are isolated per backend.
//
// A cache hit only takes the read lock. On a miss the write lock is taken and the cache
// is consulted again, because another goroutine may have stored a transport for the same
// backend between the two lock acquisitions.
//
// The cache map is allocated lazily under the write lock, so the method also works on an
// httpBackendClient built as a struct literal (nil transportCache) instead of panicking
// on a write to a nil map.
func (h *httpBackendClient) getOrCreateTransport(backendID string) *http.Transport {
	h.transportMu.RLock()
	cached, hit := h.transportCache[backendID] // reading from a nil map is safe
	h.transportMu.RUnlock()
	if hit {
		return cached
	}

	h.transportMu.Lock()
	defer h.transportMu.Unlock()

	// Re-check: another goroutine may have populated the entry while we were unlocked.
	if existing, ok := h.transportCache[backendID]; ok {
		return existing
	}
	if h.transportCache == nil {
		// Writing to a nil map panics; allocate it on first use.
		h.transportCache = make(map[string]*http.Transport)
	}
	created := newBackendTransport()
	h.transportCache[backendID] = created
	return created
}

// FlushIdleConnections evicts the cached transport for backendID and closes its idle
// keep-alive connections, so the next request to that backend is served by a freshly
// created transport. It is a no-op when no transport is cached for backendID.
// This implements vmcp.ConnectionFlusher.
func (h *httpBackendClient) FlushIdleConnections(backendID string) {
	h.transportMu.Lock()
	evicted, found := h.transportCache[backendID]
	// Deleting a missing key is a no-op, so no need to branch on found here.
	delete(h.transportCache, backendID)
	h.transportMu.Unlock()

	if !found {
		return
	}

	// Close the idle connections outside the lock; the transport is no longer reachable
	// through the cache at this point.
	evicted.CloseIdleConnections()
	slog.Debug("flushed idle connections for backend", "backend", backendID)
}

// roundTripperFunc is a function adapter for http.RoundTripper.
type roundTripperFunc func(*http.Request) (*http.Response, error)

Expand Down Expand Up @@ -173,7 +245,12 @@ func (h *httpBackendClient) resolveAuthStrategy(target *vmcp.BackendTarget) (vmc
func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vmcp.BackendTarget) (*client.Client, error) {
// Build transport chain (outermost to innermost, request execution order):
// size limit (response body) → trace propagation → identity propagation → authentication → HTTP
var baseTransport = http.DefaultTransport
//
// Use the per-backend cached transport so connections are reused across calls
// to the same backend. Each backend has an isolated pool, preventing stale
// keep-alive connections to one backend from affecting others.
// Call FlushIdleConnections after a failure to evict stale connections.
var baseTransport http.RoundTripper = h.getOrCreateTransport(target.WorkloadID)

// Resolve authentication strategy ONCE at client creation time
authStrategy, err := h.resolveAuthStrategy(target)
Expand Down Expand Up @@ -280,8 +357,9 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm
// This enables type-safe error checking with errors.Is() instead of string matching.
//
// Error detection strategy (in order of preference):
// 1. Check for standard Go error types (context errors, net.Error, url.Error)
// 2. Fall back to string pattern matching for library-specific errors (MCP SDK, HTTP libs)
// Check for standard Go error types (context errors, io.EOF, net.Error),
// then mcp-go transport sentinels, then fall back to string pattern matching
// for library-specific errors (MCP SDK, HTTP libs).
//
// Error chain preservation:
// The returned error wraps the sentinel error (ErrTimeout, ErrBackendUnavailable, etc.) with %w
Expand All @@ -296,7 +374,7 @@ func wrapBackendError(err error, backendID string, operation string) error {
return nil
}

// 1. Type-based detection: Check for context deadline/cancellation
// Type-based detection: context deadline/cancellation
if errors.Is(err, context.DeadlineExceeded) {
return fmt.Errorf("%w: failed to %s for backend %s (timeout): %v",
vmcp.ErrTimeout, operation, backendID, err)
Expand All @@ -306,23 +384,37 @@ func wrapBackendError(err error, backendID string, operation string) error {
vmcp.ErrCancelled, operation, backendID, err)
}

// 2. Type-based detection: Check for io.EOF errors
// These indicate the connection was closed unexpectedly
// Type-based detection: io.EOF errors indicate the connection was closed unexpectedly
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
return fmt.Errorf("%w: failed to %s for backend %s (connection closed): %v",
vmcp.ErrBackendUnavailable, operation, backendID, err)
}

// 3. Type-based detection: Check for net.Error with Timeout() method
// This handles network timeouts from the standard library
// Type-based detection: net.Error with Timeout() handles network timeouts from the standard library
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return fmt.Errorf("%w: failed to %s for backend %s (timeout): %v",
vmcp.ErrTimeout, operation, backendID, err)
}

// 4. String-based detection: Fall back to pattern matching for cases where
// we don't have structured error types (MCP SDK, HTTP libraries with embedded status codes)
// mcp-go transport sentinel errors: check these before string-based fallbacks
// to ensure accurate classification of protocol-level errors.
if errors.Is(err, transport.ErrUnauthorized) {
return fmt.Errorf("%w: failed to %s for backend %s: %v",
vmcp.ErrAuthenticationFailed, operation, backendID, err)
}
// ErrLegacySSEServer is returned for any 4xx (except 401) on initialize POST.
// This includes 403 (auth rejection) and 404/405 (endpoint not found/method not allowed).
// We cannot distinguish auth failures from routing errors without the raw status code,
// so we surface a clear message and classify as backend unavailable to allow recovery.
if errors.Is(err, transport.ErrLegacySSEServer) {
const legacyMsg = "server rejected MCP initialize — possible auth rejection or legacy SSE-only server"
return fmt.Errorf("%w: failed to %s for backend %s (%s): %v",
vmcp.ErrBackendUnavailable, operation, backendID, legacyMsg, err)
}

// String-based detection: fall back to pattern matching for cases where
// we don't have structured error types (MCP SDK, HTTP libraries with embedded status codes).
// Authentication errors (401, 403, auth failures)
if vmcp.IsAuthenticationError(err) {
return fmt.Errorf("%w: failed to %s for backend %s: %v",
Expand Down
69 changes: 69 additions & 0 deletions pkg/vmcp/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ package client
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/propagation"
Expand Down Expand Up @@ -659,6 +661,73 @@ func TestTracePropagatingRoundTripper_ParentChildSpan(t *testing.T) {
"traceparent should contain child span ID, not parent")
}

// TestWrapBackendError checks that wrapBackendError translates mcp-go transport
// sentinel errors (and a context deadline) into the vmcp sentinel errors that
// downstream health monitoring relies on.
func TestWrapBackendError(t *testing.T) {
	t.Parallel()

	type wrapCase struct {
		name            string
		err             error
		wantSentinel    error
		wantMsgContains string
	}

	cases := []wrapCase{
		{
			name:         "nil error returns nil",
			err:          nil,
			wantSentinel: nil,
		},
		{
			// A 401 on the initialize POST surfaces from mcp-go as ErrUnauthorized.
			// It has to become ErrAuthenticationFailed so health monitors report
			// BackendUnauthenticated rather than BackendUnhealthy.
			name:         "ErrUnauthorized maps to ErrAuthenticationFailed",
			err:          transport.ErrUnauthorized,
			wantSentinel: vmcp.ErrAuthenticationFailed,
		},
		{
			// errors.Is walks the chain, so an ErrUnauthorized buried inside another
			// error must still be classified as ErrAuthenticationFailed.
			name:         "wrapped ErrUnauthorized maps to ErrAuthenticationFailed",
			err:          fmt.Errorf("transport layer: %w", transport.ErrUnauthorized),
			wantSentinel: vmcp.ErrAuthenticationFailed,
		},
		{
			// Non-401 4xx responses (e.g. 403, 404, 405) on the initialize POST come
			// back from mcp-go as ErrLegacySSEServer. They are treated as backend
			// unavailable so the health monitor can recover once the backend is fixed.
			name:            "ErrLegacySSEServer maps to ErrBackendUnavailable",
			err:             transport.ErrLegacySSEServer,
			wantSentinel:    vmcp.ErrBackendUnavailable,
			wantMsgContains: "legacy SSE",
		},
		{
			name:         "context.DeadlineExceeded maps to ErrTimeout",
			err:          context.DeadlineExceeded,
			wantSentinel: vmcp.ErrTimeout,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			wrapped := wrapBackendError(tc.err, "test-backend", "initialize")

			// A nil input must pass through as nil; nothing else to verify.
			if tc.err == nil {
				assert.NoError(t, wrapped)
				return
			}

			require.Error(t, wrapped)
			assert.ErrorIs(t, wrapped, tc.wantSentinel)
			if tc.wantMsgContains != "" {
				assert.Contains(t, wrapped.Error(), tc.wantMsgContains)
			}
		})
	}
}

func TestResolveAuthStrategy(t *testing.T) {
t.Parallel()

Expand Down
7 changes: 5 additions & 2 deletions pkg/vmcp/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,12 @@ func IsAuthenticationError(err error) bool {
return true
}

// Check for HTTP 401/403 status codes with context
// Match patterns like "401 Unauthorized", "HTTP 401", "status code 401"
// Check for HTTP 401/403 status codes with context.
// Match patterns like "401 Unauthorized", "HTTP 401", "status code 401".
// Also match mcp-go's ErrUnauthorized = "unauthorized (401)" which uses
// reversed order compared to the "401 unauthorized" pattern above.
if strings.Contains(errLower, "401 unauthorized") ||
strings.Contains(errLower, "unauthorized (401)") ||
strings.Contains(errLower, "403 forbidden") ||
strings.Contains(errLower, "http 401") ||
strings.Contains(errLower, "http 403") ||
Expand Down
9 changes: 9 additions & 0 deletions pkg/vmcp/health/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@ func (h *healthChecker) CheckHealth(ctx context.Context, target *vmcp.BackendTar
return vmcp.BackendHealthy, nil
}

// FlushIdleConnections asks the underlying backend client to close idle connections
// for the given backend. It implements vmcp.ConnectionFlusher by delegation: when the
// wrapped client does not itself implement vmcp.ConnectionFlusher, the call is a no-op.
// The Monitor invokes this after a failed health check to evict stale keep-alive
// connections.
func (h *healthChecker) FlushIdleConnections(backendID string) {
	delegate, supported := h.client.(vmcp.ConnectionFlusher)
	if !supported {
		return
	}
	delegate.FlushIdleConnections(backendID)
}

// categorizeError determines the appropriate health status based on the error type.
// This uses sentinel error checking with errors.Is() for type-safe error categorization.
// Falls back to string-based detection for backwards compatibility with non-wrapped errors.
Expand Down
10 changes: 10 additions & 0 deletions pkg/vmcp/health/checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,9 @@ func TestIsAuthenticationError(t *testing.T) {
{name: "request unauthorized", err: errors.New("request unauthorized"), expectErr: true},
{name: "access denied", err: errors.New("access denied"), expectErr: true},

// mcp-go ErrUnauthorized format: "unauthorized (401)" (reversed order vs "401 unauthorized")
{name: "unauthorized (401) - mcp-go ErrUnauthorized format", err: errors.New("unauthorized (401)"), expectErr: true},

// Negative cases - should NOT be detected as auth errors
{name: "connection refused", err: errors.New("connection refused"), expectErr: false},
{name: "timeout", err: errors.New("request timeout"), expectErr: false},
Expand Down Expand Up @@ -515,6 +518,13 @@ func TestHealthChecker_CheckHealth_AuthErrorsCategorizesAsUnauthenticated(t *tes
name: "wrapped sentinel auth error",
err: fmt.Errorf("client credentials grant failed: %w", vmcp.ErrAuthenticationFailed),
},
{
// transport.ErrUnauthorized is wrapped with ErrAuthenticationFailed in wrapBackendError,
// so a 401 from the mcp-go transport layer reaches health monitoring as
// BackendUnauthenticated instead of BackendUnhealthy.
name: "mcp-go ErrUnauthorized wrapped as ErrAuthenticationFailed by wrapBackendError",
err: fmt.Errorf("%w: failed to initialize for backend my-backend: unauthorized (401)", vmcp.ErrAuthenticationFailed),
},
}

for _, tt := range tests {
Expand Down
9 changes: 9 additions & 0 deletions pkg/vmcp/health/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,10 @@ func (m *Monitor) UpdateBackends(newBackends []vmcp.Backend) {
}
// Remove backend from status tracker so it no longer appears in status reports
m.statusTracker.RemoveBackend(id)
// Evict the cached transport for this backend to free connections and memory
if flusher, ok := m.checker.(vmcp.ConnectionFlusher); ok {
flusher.FlushIdleConnections(id)
}
}
}
}
Expand Down Expand Up @@ -424,6 +428,11 @@ func (m *Monitor) performHealthCheck(ctx context.Context, backend *vmcp.Backend)
if err != nil {
slog.Debug("health check failed for backend", "backend", backend.Name, "error", err, "status", status)
m.statusTracker.RecordFailure(backend.ID, backend.Name, status, err)
// Flush idle connections so the next attempt gets a fresh TCP connection.
// This recovers from stale keep-alive connections to pods that have been replaced.
if flusher, ok := m.checker.(vmcp.ConnectionFlusher); ok {
flusher.FlushIdleConnections(backend.ID)
}
} else {
// Pass status to RecordSuccess - it may be healthy or degraded (from slow response)
// RecordSuccess will further check for recovering state (had recent failures)
Expand Down
Loading
Loading