Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 159 additions & 6 deletions pkg/vmcp/discovery/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ package discovery

import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"sort"
"sync"
"time"

"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/logger"
Expand All @@ -18,6 +23,15 @@ import (

//go:generate mockgen -destination=mocks/mock_manager.go -package=mocks -source=manager.go Manager

const (
// cacheTTL is the time-to-live for cached capability entries.
cacheTTL = 5 * time.Minute
// maxCacheSize is the maximum number of entries allowed in the cache.
maxCacheSize = 1000
// cleanupInterval is how often expired cache entries are removed.
cleanupInterval = 1 * time.Minute
)

var (
// ErrAggregatorNil is returned when aggregator is nil.
ErrAggregatorNil = errors.New("aggregator cannot be nil")
Expand All @@ -31,26 +45,47 @@ var (
type Manager interface {
// Discover performs capability aggregation for the given backends with user context.
Discover(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error)
// Stop gracefully stops the manager and cleans up resources.
Stop()
}

// cacheEntry represents a cached capability discovery result.
type cacheEntry struct {
capabilities *aggregator.AggregatedCapabilities
expiresAt time.Time
}

// DefaultManager is the default implementation of Manager.
type DefaultManager struct {
aggregator aggregator.Aggregator
cache map[string]*cacheEntry
cacheMu sync.RWMutex
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}

// NewManager creates a new discovery manager with the given aggregator.
func NewManager(agg aggregator.Aggregator) (Manager, error) {
if agg == nil {
return nil, ErrAggregatorNil
}
return &DefaultManager{

m := &DefaultManager{
aggregator: agg,
}, nil
cache: make(map[string]*cacheEntry),
stopCh: make(chan struct{}),
}

// Start background cleanup goroutine
m.wg.Add(1)
go m.cleanupExpiredEntries()

return m, nil
}

// Discover performs capability aggregation by delegating to the aggregator.
// Currently a simple passthrough; future enhancement will add caching layer here
// to share discovered capabilities across sessions for the same user+backend set.
// Discover performs capability aggregation with per-user caching.
// Results are cached by (user, backend-set) combination for improved performance.
//
// The context must contain an authenticated user identity (set by auth middleware).
// Returns ErrNoIdentity if user identity is not found in context.
Expand All @@ -62,12 +97,130 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend)
return nil, fmt.Errorf("%w: ensure auth middleware runs before discovery middleware", ErrNoIdentity)
}

logger.Debugf("Performing capability discovery for user: %s", identity.Subject)
// Generate cache key from user identity and backend set
cacheKey := m.generateCacheKey(identity.Subject, backends)

// Check cache first
if caps := m.getCachedCapabilities(cacheKey); caps != nil {
logger.Debugf("Cache hit for user %s (key: %s)", identity.Subject, cacheKey)
return caps, nil
}

logger.Debugf("Cache miss - performing capability discovery for user: %s", identity.Subject)

// Cache miss - perform aggregation
caps, err := m.aggregator.AggregateCapabilities(ctx, backends)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrDiscoveryFailed, err)
}

// Cache the result (skips caching if at capacity and key doesn't exist)
m.cacheCapabilities(cacheKey, caps)

return caps, nil
}

// Stop gracefully stops the manager and cleans up resources.
// This method is safe to call multiple times.
func (m *DefaultManager) Stop() {
m.stopOnce.Do(func() {
close(m.stopCh)
})
m.wg.Wait()
}

// generateCacheKey creates a cache key from user ID and backend set.
// The key format is: userID:hash(sorted-backend-ids)
func (*DefaultManager) generateCacheKey(userID string, backends []vmcp.Backend) string {
// Extract and sort backend IDs for stable hashing
backendIDs := make([]string, len(backends))
for i, b := range backends {
backendIDs[i] = b.ID
}
sort.Strings(backendIDs)

// Hash the sorted backend IDs
h := sha256.New()
for _, id := range backendIDs {
h.Write([]byte(id))
h.Write([]byte{0}) // Separator to avoid collisions
}
backendHash := hex.EncodeToString(h.Sum(nil))

return fmt.Sprintf("%s:%s", userID, backendHash)
}

// getCachedCapabilities retrieves capabilities from cache if valid and not expired.
func (m *DefaultManager) getCachedCapabilities(key string) *aggregator.AggregatedCapabilities {
m.cacheMu.RLock()
defer m.cacheMu.RUnlock()

entry, ok := m.cache[key]
if !ok {
return nil
}

// Check if entry has expired
if time.Now().After(entry.expiresAt) {
return nil
}

return entry.capabilities
}

// cacheCapabilities stores capabilities in cache if under size limit.
func (m *DefaultManager) cacheCapabilities(key string, caps *aggregator.AggregatedCapabilities) {
m.cacheMu.Lock()
defer m.cacheMu.Unlock()

// Simple eviction: reject caching when at capacity
if len(m.cache) >= maxCacheSize {
_, exists := m.cache[key]
if !exists {
logger.Debugf("Cache at capacity (%d entries), not caching new entry", maxCacheSize)
return
}
}

m.cache[key] = &cacheEntry{
capabilities: caps,
expiresAt: time.Now().Add(cacheTTL),
}
}

// cleanupExpiredEntries periodically removes expired cache entries.
func (m *DefaultManager) cleanupExpiredEntries() {
defer m.wg.Done()

ticker := time.NewTicker(cleanupInterval)
defer ticker.Stop()

for {
select {
case <-ticker.C:
m.removeExpiredEntries()
case <-m.stopCh:
return
}
}
}

// removeExpiredEntries removes all expired entries from the cache.
func (m *DefaultManager) removeExpiredEntries() {
m.cacheMu.Lock()
defer m.cacheMu.Unlock()

now := time.Now()
removed := 0

for key, entry := range m.cache {
if now.After(entry.expiresAt) {
delete(m.cache, key)
removed++
}
}

if removed > 0 {
logger.Debugf("Removed %d expired cache entries (%d remaining)", removed, len(m.cache))
}
}
Loading
Loading