From 8e2cf81c095920f2a15cbe7f8e7e8342572437f2 Mon Sep 17 00:00:00 2001 From: Trey Date: Mon, 17 Nov 2025 10:44:00 -0800 Subject: [PATCH 1/3] Add per-user caching to vMCP discovery manager Implements issue #2503 by adding an in-memory cache to the discovery manager that caches capability aggregation results per (user, backend-set) combination. **Implementation Details:** - Cache keyed by `userID:sha256(sorted-backend-ids)` for stability - 5-minute TTL per cache entry (hardcoded) - 1000 entry maximum capacity (hardcoded) - Simple eviction: rejects new entries when at capacity - Background cleanup goroutine removes expired entries every minute - Thread-safe with sync.RWMutex protecting all cache operations - Graceful shutdown via new Stop() method on Manager interface **Cache Behavior:** - Cache hit: Returns cached capabilities without calling aggregator - Cache miss: Calls aggregator, caches result (if under size limit) - Expired entries: Treated as cache miss, triggers re-aggregation - Backend order: Hash normalized via sorting for stable keys - User isolation: Separate cache entries per user identity **Changes:** - `pkg/vmcp/discovery/manager.go`: - Added cache infrastructure to DefaultManager - Modified Discover() to check cache before aggregation - Added Stop() method to Manager interface - Added background cleanup goroutine - Added cache management helper methods - `pkg/vmcp/discovery/manager_test.go`: - Added 9 comprehensive test cases covering: - Cache hits and misses (user/backend variations) - Cache key stability (backend order independence) - Concurrent access thread safety - Expiration and cleanup - Size limit enforcement - Graceful shutdown - Added defer mgr.Stop() to prevent goroutine leaks - Regenerated mocks for updated Manager interface --- pkg/vmcp/discovery/manager.go | 161 ++++++++- pkg/vmcp/discovery/manager_test.go | 428 +++++++++++++++++++++++ pkg/vmcp/discovery/mocks/mock_manager.go | 12 + 3 files changed, 595 insertions(+), 6 deletions(-) diff --git a/pkg/vmcp/discovery/manager.go b/pkg/vmcp/discovery/manager.go index 9af44a701..b8b8d3e56 100644 --- a/pkg/vmcp/discovery/manager.go +++ b/pkg/vmcp/discovery/manager.go @@ -7,8 +7,13 @@ package discovery import ( "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" + "sort" + "sync" + "time" "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/logger" @@ -18,6 +23,15 @@ import ( //go:generate mockgen -destination=mocks/mock_manager.go -package=mocks -source=manager.go Manager +const ( + // cacheTTL is the time-to-live for cached capability entries. + cacheTTL = 5 * time.Minute + // maxCacheSize is the maximum number of entries allowed in the cache. + maxCacheSize = 1000 + // cleanupInterval is how often expired cache entries are removed. + cleanupInterval = 1 * time.Minute +) + var ( // ErrAggregatorNil is returned when aggregator is nil. ErrAggregatorNil = errors.New("aggregator cannot be nil") @@ -31,11 +45,23 @@ var ( type Manager interface { // Discover performs capability aggregation for the given backends with user context. Discover(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) + // Stop gracefully stops the manager and cleans up resources. + Stop() +} + +// cacheEntry represents a cached capability discovery result. +type cacheEntry struct { + capabilities *aggregator.AggregatedCapabilities + expiresAt time.Time } // DefaultManager is the default implementation of Manager. 
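+// It caches aggregated capabilities per (user, backend-set) combination with a
+// fixed TTL; the cache fields are guarded by cacheMu.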
type DefaultManager struct { aggregator aggregator.Aggregator + cache map[string]*cacheEntry + cacheMu sync.RWMutex + stopCh chan struct{} + wg sync.WaitGroup } // NewManager creates a new discovery manager with the given aggregator. @@ -43,14 +69,22 @@ func NewManager(agg aggregator.Aggregator) (Manager, error) { if agg == nil { return nil, ErrAggregatorNil } - return &DefaultManager{ + + m := &DefaultManager{ aggregator: agg, - }, nil + cache: make(map[string]*cacheEntry), + stopCh: make(chan struct{}), + } + + // Start background cleanup goroutine + m.wg.Add(1) + go m.cleanupExpiredEntries() + + return m, nil } -// Discover performs capability aggregation by delegating to the aggregator. -// Currently a simple passthrough; future enhancement will add caching layer here -// to share discovered capabilities across sessions for the same user+backend set. +// Discover performs capability aggregation with per-user caching. +// Results are cached by (user, backend-set) combination for improved performance. // // The context must contain an authenticated user identity (set by auth middleware). // Returns ErrNoIdentity if user identity is not found in context. @@ -62,12 +96,127 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) return nil, fmt.Errorf("%w: ensure auth middleware runs before discovery middleware", ErrNoIdentity) } - logger.Debugf("Performing capability discovery for user: %s", identity.Subject) + // Generate cache key from user identity and backend set + cacheKey := m.generateCacheKey(identity.Subject, backends) + + // Check cache first + if caps := m.getCachedCapabilities(cacheKey); caps != nil { + logger.Debugf("Cache hit for user %s (key: %s)", identity.Subject, cacheKey) + return caps, nil + } + + logger.Debugf("Cache miss - performing capability discovery for user: %s", identity.Subject) + // Cache miss - perform aggregation caps, err := m.aggregator.AggregateCapabilities(ctx, backends) if err != nil { return nil, fmt.Errorf("%w: %v", ErrDiscoveryFailed, err) } + // Cache the result (evicts soonest-expiring entry if at capacity) + m.cacheCapabilities(cacheKey, caps) + return caps, nil } + +// Stop gracefully stops the manager and cleans up resources. +func (m *DefaultManager) Stop() { + close(m.stopCh) + m.wg.Wait() +} + +// generateCacheKey creates a cache key from user ID and backend set. +// The key format is: userID:hash(sorted-backend-ids) +func (*DefaultManager) generateCacheKey(userID string, backends []vmcp.Backend) string { + // Extract and sort backend IDs for stable hashing + backendIDs := make([]string, len(backends)) + for i, b := range backends { + backendIDs[i] = b.ID + } + sort.Strings(backendIDs) + + // Hash the sorted backend IDs + h := sha256.New() + for _, id := range backendIDs { + h.Write([]byte(id)) + h.Write([]byte{0}) // Separator to avoid collisions + } + backendHash := hex.EncodeToString(h.Sum(nil)) + + return fmt.Sprintf("%s:%s", userID, backendHash) +} + +// getCachedCapabilities retrieves capabilities from cache if valid and not expired. +func (m *DefaultManager) getCachedCapabilities(key string) *aggregator.AggregatedCapabilities { + m.cacheMu.RLock() + defer m.cacheMu.RUnlock() + + entry, ok := m.cache[key] + if !ok { + return nil + } + + // Check if entry has expired + if time.Now().After(entry.expiresAt) { + return nil + } + + return entry.capabilities +} + +// cacheCapabilities stores capabilities in cache if under size limit. 
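+// At capacity, existing keys are still refreshed with a new TTL, but entries
+// for new keys are dropped rather than evicting old ones.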
+func (m *DefaultManager) cacheCapabilities(key string, caps *aggregator.AggregatedCapabilities) { + m.cacheMu.Lock() + defer m.cacheMu.Unlock() + + // Simple eviction: reject caching when at capacity + if len(m.cache) >= maxCacheSize { + _, exists := m.cache[key] + if !exists { + logger.Debugf("Cache at capacity (%d entries), not caching new entry", maxCacheSize) + return + } + } + + m.cache[key] = &cacheEntry{ + capabilities: caps, + expiresAt: time.Now().Add(cacheTTL), + } +} + +// cleanupExpiredEntries periodically removes expired cache entries. +func (m *DefaultManager) cleanupExpiredEntries() { + defer m.wg.Done() + + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + m.removeExpiredEntries() + case <-m.stopCh: + return + } + } +} + +// removeExpiredEntries removes all expired entries from the cache. +func (m *DefaultManager) removeExpiredEntries() { + m.cacheMu.Lock() + defer m.cacheMu.Unlock() + + now := time.Now() + removed := 0 + + for key, entry := range m.cache { + if now.After(entry.expiresAt) { + delete(m.cache, key) + removed++ + } + } + + if removed > 0 { + logger.Debugf("Removed %d expired cache entries (%d remaining)", removed, len(m.cache)) + } +} diff --git a/pkg/vmcp/discovery/manager_test.go b/pkg/vmcp/discovery/manager_test.go index ebde42412..d1ce2b9bf 100644 --- a/pkg/vmcp/discovery/manager_test.go +++ b/pkg/vmcp/discovery/manager_test.go @@ -3,7 +3,9 @@ package discovery import ( "context" "errors" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -62,6 +64,7 @@ func TestDefaultManager_Discover(t *testing.T) { mgr, err := NewManager(mockAgg) require.NoError(t, err) + defer mgr.Stop() // Create context with user identity identity := &auth.Identity{Subject: "user123", Name: "Test User"} @@ -113,6 +116,7 @@ func TestDefaultManager_Discover(t *testing.T) { mgr, err := NewManager(mockAgg) require.NoError(t, err) + defer mgr.Stop() // Create context with user identity identity := &auth.Identity{Subject: "user456"} @@ -126,6 +130,429 @@ func TestDefaultManager_Discover(t *testing.T) { }) } +func TestDefaultManager_Caching(t *testing.T) { + t.Parallel() + + t.Run("cache hit for same user and backends", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAgg := aggmocks.NewMockAggregator(ctrl) + backends := []vmcp.Backend{newTestBackend("backend1")} + expectedCaps := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{newTestTool("tool1", "backend1")}, + } + + // Expect only one call to aggregator + mockAgg.EXPECT(). + AggregateCapabilities(gomock.Any(), backends). + Return(expectedCaps, nil). 
+ Times(1) + + mgr, err := NewManager(mockAgg) + require.NoError(t, err) + defer mgr.Stop() + + identity := &auth.Identity{Subject: "user123", Name: "Test User"} + ctx := auth.WithIdentity(context.Background(), identity) + + // First call - should hit aggregator + caps1, err := mgr.Discover(ctx, backends) + require.NoError(t, err) + assert.Equal(t, expectedCaps, caps1) + + // Second call - should hit cache + caps2, err := mgr.Discover(ctx, backends) + require.NoError(t, err) + assert.Equal(t, expectedCaps, caps2) + }) + + t.Run("cache miss for different user", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAgg := aggmocks.NewMockAggregator(ctrl) + backends := []vmcp.Backend{newTestBackend("backend1")} + expectedCaps := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{newTestTool("tool1", "backend1")}, + } + + // Expect two calls to aggregator (one per user) + mockAgg.EXPECT(). + AggregateCapabilities(gomock.Any(), backends). + Return(expectedCaps, nil). + Times(2) + + mgr, err := NewManager(mockAgg) + require.NoError(t, err) + defer mgr.Stop() + + // User 1 + identity1 := &auth.Identity{Subject: "user123"} + ctx1 := auth.WithIdentity(context.Background(), identity1) + caps1, err := mgr.Discover(ctx1, backends) + require.NoError(t, err) + assert.Equal(t, expectedCaps, caps1) + + // User 2 - different user, should not hit cache + identity2 := &auth.Identity{Subject: "user456"} + ctx2 := auth.WithIdentity(context.Background(), identity2) + caps2, err := mgr.Discover(ctx2, backends) + require.NoError(t, err) + assert.Equal(t, expectedCaps, caps2) + }) + + t.Run("cache miss for different backends", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAgg := aggmocks.NewMockAggregator(ctrl) + backends1 := []vmcp.Backend{newTestBackend("backend1")} + backends2 := []vmcp.Backend{newTestBackend("backend2")} + expectedCaps := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{newTestTool("tool1", "backend1")}, + } + + // Expect two calls to aggregator (one per backend set) + mockAgg.EXPECT(). + AggregateCapabilities(gomock.Any(), backends1). + Return(expectedCaps, nil). + Times(1) + mockAgg.EXPECT(). + AggregateCapabilities(gomock.Any(), backends2). + Return(expectedCaps, nil). 
+ Times(1) + + mgr, err := NewManager(mockAgg) + require.NoError(t, err) + defer mgr.Stop() + + dm := mgr.(*DefaultManager) + + identity := &auth.Identity{Subject: "user123"} + ctx := auth.WithIdentity(context.Background(), identity) + + // First backend set + caps1, err := mgr.Discover(ctx, backends1) + require.NoError(t, err) + assert.NotNil(t, caps1) + + // Different backend set - should not hit cache + caps2, err := mgr.Discover(ctx, backends2) + require.NoError(t, err) + assert.NotNil(t, caps2) + + // Verify cache contains 2 entries (one per backend set) + dm.cacheMu.RLock() + cacheSize := len(dm.cache) + dm.cacheMu.RUnlock() + assert.Equal(t, 2, cacheSize) + }) + + t.Run("cache key stable regardless of backend order", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAgg := aggmocks.NewMockAggregator(ctrl) + backends1 := []vmcp.Backend{ + newTestBackend("backend1"), + newTestBackend("backend2"), + } + backends2 := []vmcp.Backend{ + newTestBackend("backend2"), // Reversed order + newTestBackend("backend1"), + } + expectedCaps := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{newTestTool("tool1", "backend1")}, + } + + // Expect only one call - cache should hit on second call despite order + mockAgg.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(expectedCaps, nil). + Times(1) + + mgr, err := NewManager(mockAgg) + require.NoError(t, err) + defer mgr.Stop() + + identity := &auth.Identity{Subject: "user123"} + ctx := auth.WithIdentity(context.Background(), identity) + + // First call + caps1, err := mgr.Discover(ctx, backends1) + require.NoError(t, err) + assert.Equal(t, expectedCaps, caps1) + + // Second call with reversed backend order - should hit cache + caps2, err := mgr.Discover(ctx, backends2) + require.NoError(t, err) + assert.Equal(t, expectedCaps, caps2) + }) + + t.Run("concurrent access is thread-safe", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAgg := aggmocks.NewMockAggregator(ctrl) + backends := []vmcp.Backend{newTestBackend("backend1")} + expectedCaps := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{newTestTool("tool1", "backend1")}, + } + + // Should only call aggregator once due to caching + mockAgg.EXPECT(). + AggregateCapabilities(gomock.Any(), backends). + Return(expectedCaps, nil). + MinTimes(1). 
+ MaxTimes(10) // Allow some race condition calls + + mgr, err := NewManager(mockAgg) + require.NoError(t, err) + defer mgr.Stop() + + dm := mgr.(*DefaultManager) + + identity := &auth.Identity{Subject: "user123"} + ctx := auth.WithIdentity(context.Background(), identity) + + var wg sync.WaitGroup + numGoroutines := 10 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + caps, err := mgr.Discover(ctx, backends) + assert.NoError(t, err) + assert.NotNil(t, caps) + }() + } + + wg.Wait() + + // Verify cache contains only one entry for this user+backend combination + dm.cacheMu.RLock() + cacheSize := len(dm.cache) + dm.cacheMu.RUnlock() + assert.Equal(t, 1, cacheSize) + }) +} + +func TestDefaultManager_CacheExpiration(t *testing.T) { + t.Parallel() + + t.Run("expired entries are not returned", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAgg := aggmocks.NewMockAggregator(ctrl) + backends := []vmcp.Backend{newTestBackend("backend1")} + expectedCaps := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{newTestTool("tool1", "backend1")}, + } + + // Expect two calls - once for initial, once after expiration + mockAgg.EXPECT(). + AggregateCapabilities(gomock.Any(), backends). + Return(expectedCaps, nil). + Times(2) + + mgr, err := NewManager(mockAgg) + require.NoError(t, err) + defer mgr.Stop() + + // Get the underlying manager to manipulate cache directly + dm := mgr.(*DefaultManager) + + identity := &auth.Identity{Subject: "user123"} + ctx := auth.WithIdentity(context.Background(), identity) + + // First call + caps1, err := dm.Discover(ctx, backends) + require.NoError(t, err) + assert.Equal(t, expectedCaps, caps1) + + // Manually expire the cache entry + cacheKey := dm.generateCacheKey(identity.Subject, backends) + dm.cacheMu.Lock() + dm.cache[cacheKey].expiresAt = time.Now().Add(-1 * time.Second) + dm.cacheMu.Unlock() + + // Second call - should not use expired cache + caps2, err := dm.Discover(ctx, backends) + require.NoError(t, err) + assert.Equal(t, expectedCaps, caps2) + }) + + t.Run("background cleanup removes expired entries", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAgg := aggmocks.NewMockAggregator(ctrl) + backends := []vmcp.Backend{newTestBackend("backend1")} + expectedCaps := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{newTestTool("tool1", "backend1")}, + } + + mockAgg.EXPECT(). + AggregateCapabilities(gomock.Any(), backends). + Return(expectedCaps, nil). 
+ Times(1) + + mgr, err := NewManager(mockAgg) + require.NoError(t, err) + defer mgr.Stop() + + dm := mgr.(*DefaultManager) + + identity := &auth.Identity{Subject: "user123"} + ctx := auth.WithIdentity(context.Background(), identity) + + // Add entry to cache + _, err = dm.Discover(ctx, backends) + require.NoError(t, err) + + // Verify entry is in cache + dm.cacheMu.RLock() + initialCount := len(dm.cache) + dm.cacheMu.RUnlock() + assert.Equal(t, 1, initialCount) + + // Manually expire the entry + cacheKey := dm.generateCacheKey(identity.Subject, backends) + dm.cacheMu.Lock() + dm.cache[cacheKey].expiresAt = time.Now().Add(-1 * time.Second) + dm.cacheMu.Unlock() + + // Manually trigger cleanup + dm.removeExpiredEntries() + + // Verify entry was removed + dm.cacheMu.RLock() + finalCount := len(dm.cache) + dm.cacheMu.RUnlock() + assert.Equal(t, 0, finalCount) + }) +} + +func TestDefaultManager_CacheSizeLimit(t *testing.T) { + t.Parallel() + + t.Run("stops caching at size limit", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAgg := aggmocks.NewMockAggregator(ctrl) + expectedCaps := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{newTestTool("tool1", "backend1")}, + } + + // Expect many calls since we'll exceed cache size + mockAgg.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(expectedCaps, nil). + AnyTimes() + + mgr, err := NewManager(mockAgg) + require.NoError(t, err) + defer mgr.Stop() + + dm := mgr.(*DefaultManager) + ctx := context.Background() + + // Fill cache to capacity + for i := 0; i < maxCacheSize; i++ { + identity := &auth.Identity{Subject: "user" + string(rune(i))} + ctxWithIdentity := auth.WithIdentity(ctx, identity) + backends := []vmcp.Backend{newTestBackend("backend1")} + _, err := dm.Discover(ctxWithIdentity, backends) + require.NoError(t, err) + } + + // Verify cache is at capacity + dm.cacheMu.RLock() + cacheSize := len(dm.cache) + dm.cacheMu.RUnlock() + assert.Equal(t, maxCacheSize, cacheSize) + + // Try to add one more - should not be cached + newIdentity := &auth.Identity{Subject: "user-overflow"} + ctxWithNewIdentity := auth.WithIdentity(ctx, newIdentity) + backends := []vmcp.Backend{newTestBackend("backend2")} + _, err = dm.Discover(ctxWithNewIdentity, backends) + require.NoError(t, err) + + // Verify cache size didn't increase + dm.cacheMu.RLock() + finalSize := len(dm.cache) + dm.cacheMu.RUnlock() + assert.Equal(t, maxCacheSize, finalSize) + + // Verify new entry is not in cache + cacheKey := dm.generateCacheKey(newIdentity.Subject, backends) + dm.cacheMu.RLock() + _, exists := dm.cache[cacheKey] + dm.cacheMu.RUnlock() + assert.False(t, exists) + }) +} + +func TestDefaultManager_Stop(t *testing.T) { + t.Parallel() + + t.Run("stop terminates cleanup goroutine", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAgg := aggmocks.NewMockAggregator(ctrl) + + mgr, err := NewManager(mockAgg) + require.NoError(t, err) + + dm := mgr.(*DefaultManager) + + // Verify cleanup goroutine is running + select { + case <-dm.stopCh: + t.Fatal("stopCh should not be closed yet") + default: + // Expected - stopCh is still open + } + + // Stop should complete without hanging + done := make(chan struct{}) + go func() { + dm.Stop() + close(done) + }() + + select { + case <-done: + // Success - Stop() completed + case <-time.After(2 * time.Second): + t.Fatal("Stop() did not complete within timeout") + } + + // Verify stopCh is closed (which signals cleanup 
goroutine to exit) + select { + case <-dm.stopCh: + // Expected - stopCh is now closed + default: + t.Fatal("stopCh should be closed after Stop()") + } + }) +} + // Test helpers func newTestBackend(id string) vmcp.Backend { @@ -138,6 +565,7 @@ func newTestBackend(id string) vmcp.Backend { } } +//nolint:unparam // name parameter kept for flexibility in future tests func newTestTool(name, backendID string) vmcp.Tool { return vmcp.Tool{ Name: name, diff --git a/pkg/vmcp/discovery/mocks/mock_manager.go b/pkg/vmcp/discovery/mocks/mock_manager.go index f1f446ae9..06f24778c 100644 --- a/pkg/vmcp/discovery/mocks/mock_manager.go +++ b/pkg/vmcp/discovery/mocks/mock_manager.go @@ -56,3 +56,15 @@ func (mr *MockManagerMockRecorder) Discover(ctx, backends any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Discover", reflect.TypeOf((*MockManager)(nil).Discover), ctx, backends) } + +// Stop mocks base method. +func (m *MockManager) Stop() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Stop") +} + +// Stop indicates an expected call of Stop. +func (mr *MockManagerMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockManager)(nil).Stop)) +} From 4a5415f7618c369f502ac5c51a1f5aac93d20035 Mon Sep 17 00:00:00 2001 From: Trey Date: Mon, 17 Nov 2025 11:08:01 -0800 Subject: [PATCH 2/3] Address internal review feedback --- pkg/vmcp/discovery/manager.go | 2 +- pkg/vmcp/server/health_test.go | 3 +++ pkg/vmcp/server/integration_test.go | 3 +++ pkg/vmcp/server/server.go | 5 +++++ pkg/vmcp/server/server_test.go | 1 + 5 files changed, 13 insertions(+), 1 deletion(-) diff --git a/pkg/vmcp/discovery/manager.go b/pkg/vmcp/discovery/manager.go index b8b8d3e56..d12d81d43 100644 --- a/pkg/vmcp/discovery/manager.go +++ b/pkg/vmcp/discovery/manager.go @@ -113,7 +113,7 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) return nil, fmt.Errorf("%w: %v", ErrDiscoveryFailed, err) } - // Cache the result (evicts soonest-expiring entry if at capacity) + // Cache the result (skips caching if at capacity and key doesn't exist) m.cacheCapabilities(cacheKey, caps) return caps, nil diff --git a/pkg/vmcp/server/health_test.go b/pkg/vmcp/server/health_test.go index 5fe27c8f3..812b11b03 100644 --- a/pkg/vmcp/server/health_test.go +++ b/pkg/vmcp/server/health_test.go @@ -56,6 +56,9 @@ func createTestServer(t *testing.T) *server.Server { }, nil). AnyTimes() + // Mock Stop to be called during server shutdown + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + srv, err := server.New(&server.Config{ Name: "test-vmcp", Version: "1.0.0", diff --git a/pkg/vmcp/server/integration_test.go b/pkg/vmcp/server/integration_test.go index 5d3480ac0..80909d8d5 100644 --- a/pkg/vmcp/server/integration_test.go +++ b/pkg/vmcp/server/integration_test.go @@ -189,6 +189,9 @@ func TestIntegration_AggregatorToRouterToServer(t *testing.T) { Return(aggregatedCaps, nil). 
AnyTimes() + // Mock Stop to be called during server shutdown + mockDiscoveryMgr.EXPECT().Stop().Times(1) + srv, err := server.New(&server.Config{ Name: "test-vmcp", Version: "1.0.0", diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index c7da8abc4..63d1bbafd 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -436,6 +436,11 @@ func (s *Server) Stop(ctx context.Context) error { } } + // Stop discovery manager to clean up background goroutines + if s.discoveryMgr != nil { + s.discoveryMgr.Stop() + } + if len(errs) > 0 { logger.Errorf("Errors during shutdown: %v", errs) return errors.Join(errs...) diff --git a/pkg/vmcp/server/server_test.go b/pkg/vmcp/server/server_test.go index 4cdca0989..0a0a74422 100644 --- a/pkg/vmcp/server/server_test.go +++ b/pkg/vmcp/server/server_test.go @@ -153,6 +153,7 @@ func TestServer_Stop(t *testing.T) { mockRouter := routerMocks.NewMockRouter(ctrl) mockBackendClient := mocks.NewMockBackendClient(ctrl) mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT().Stop().Times(1) s, err := server.New(&server.Config{}, mockRouter, mockBackendClient, mockDiscoveryMgr, []vmcp.Backend{}, nil) require.NoError(t, err) From d64b4dc9d15684df4b36d0445b225e28cb45bef7 Mon Sep 17 00:00:00 2001 From: Trey Date: Tue, 18 Nov 2025 06:36:11 -0800 Subject: [PATCH 3/3] add sync.Once to Stop function --- pkg/vmcp/discovery/manager.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pkg/vmcp/discovery/manager.go b/pkg/vmcp/discovery/manager.go index d12d81d43..827f47d4e 100644 --- a/pkg/vmcp/discovery/manager.go +++ b/pkg/vmcp/discovery/manager.go @@ -61,6 +61,7 @@ type DefaultManager struct { cache map[string]*cacheEntry cacheMu sync.RWMutex stopCh chan struct{} + stopOnce sync.Once wg sync.WaitGroup } @@ -120,8 +121,11 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) } // Stop gracefully stops the manager and cleans up resources. +// This method is safe to call multiple times. func (m *DefaultManager) Stop() { - close(m.stopCh) + m.stopOnce.Do(func() { + close(m.stopCh) + }) m.wg.Wait() }
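
For reference, a minimal caller-side sketch (not part of the patches above) showing how `NewManager` is paired with `Stop`, mirroring the `server.Stop` wiring from patch 2. Import paths for the `vmcp`, `aggregator`, and `discovery` packages are inferred from the file layout in the diffs, and `discoverOnce` is a hypothetical helper name.

```go
// Illustrative sketch only; assumes the package layout implied by the diffs.
package example

import (
	"context"

	"github.com/stacklok/toolhive/pkg/auth"
	"github.com/stacklok/toolhive/pkg/vmcp"
	"github.com/stacklok/toolhive/pkg/vmcp/aggregator"
	"github.com/stacklok/toolhive/pkg/vmcp/discovery"
)

// discoverOnce wires a cached discovery manager, runs one discovery, and shuts it down.
func discoverOnce(agg aggregator.Aggregator, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) {
	mgr, err := discovery.NewManager(agg)
	if err != nil {
		return nil, err
	}
	// Stop is idempotent (guarded by sync.Once) and terminates the background
	// cleanup goroutine; the vMCP server calls it during shutdown.
	defer mgr.Stop()

	// Discover requires an authenticated identity in the context. Repeated calls
	// for the same user and backend set within the 5-minute TTL are served from cache.
	ctx := auth.WithIdentity(context.Background(), &auth.Identity{Subject: "user123"})
	return mgr.Discover(ctx, backends)
}
```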