diff --git a/internal/api/modules/amp/amp.go b/internal/api/modules/amp/amp.go index 2bd83865..9b256ac5 100644 --- a/internal/api/modules/amp/amp.go +++ b/internal/api/modules/amp/amp.go @@ -27,11 +27,20 @@ type Option func(*AmpModule) type AmpModule struct { secretSource SecretSource proxy *httputil.ReverseProxy + proxyMu sync.RWMutex // protects proxy for hot-reload accessManager *sdkaccess.Manager authMiddleware_ gin.HandlerFunc modelMapper *DefaultModelMapper enabled bool registerOnce sync.Once + + // restrictToLocalhost controls localhost-only access for management routes (hot-reloadable) + restrictToLocalhost bool + restrictMu sync.RWMutex + + // configMu protects lastConfig for partial reload comparison + configMu sync.RWMutex + lastConfig *config.AmpCode } // New creates a new Amp routing module with the given options. @@ -107,6 +116,13 @@ func (m *AmpModule) Register(ctx modules.Context) error { // Initialize model mapper from config (for routing unavailable models to alternatives) m.modelMapper = NewModelMapper(settings.ModelMappings) + // Store initial config for partial reload comparison + settingsCopy := settings + m.lastConfig = &settingsCopy + + // Initialize localhost restriction setting (hot-reloadable) + m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost) + // Always register provider aliases - these work without an upstream m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth) @@ -131,13 +147,12 @@ func (m *AmpModule) Register(ctx modules.Context) error { return } - m.proxy = proxy + m.setProxy(proxy) m.enabled = true // Register management proxy routes (requires upstream) - // Restrict to localhost by default for security (prevents drive-by browser attacks) - handler := proxyHandler(proxy) - m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, handler, settings.RestrictManagementToLocalhost) + // Uses dynamic middleware that checks m.IsRestrictedToLocalhost() for hot-reload support + m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler) log.Infof("amp upstream proxy enabled for: %s", upstreamURL) log.Debug("amp provider alias routes registered") @@ -162,45 +177,165 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc { } } -// OnConfigUpdated handles configuration updates. -// Currently requires restart for URL changes (could be enhanced for dynamic updates). +// OnConfigUpdated handles configuration updates with partial reload support. +// Only updates components that have actually changed to avoid unnecessary work. +// Supports hot-reload for: model-mappings, upstream-api-key, upstream-url, restrict-management-to-localhost. func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error { - settings := cfg.AmpCode + newSettings := cfg.AmpCode + + // Get previous config for comparison + m.configMu.RLock() + oldSettings := m.lastConfig + m.configMu.RUnlock() + + // Track what changed for logging + var changes []string + + // Check model mappings change + modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings) + if modelMappingsChanged { + if m.modelMapper != nil { + m.modelMapper.UpdateMappings(newSettings.ModelMappings) + changes = append(changes, "model-mappings") + if m.enabled { + log.Infof("amp config partial reload: model mappings updated (%d entries)", len(newSettings.ModelMappings)) + } + } else if m.enabled { + log.Warnf("amp model mapper not initialized, skipping model mapping update") + } + } - // Update model mappings (hot-reload supported) - if m.modelMapper != nil { - m.modelMapper.UpdateMappings(settings.ModelMappings) - if m.enabled { - log.Infof("amp config updated: reloading %d model mapping(s)", len(settings.ModelMappings)) + if m.enabled { + // Check upstream URL change - now supports hot-reload + newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL) + oldUpstreamURL := "" + if oldSettings != nil { + oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL) + } + + if newUpstreamURL == "" && oldUpstreamURL != "" { + log.Warn("amp upstream URL removed from config, proxy has been disabled") + m.setProxy(nil) + changes = append(changes, "upstream-url(disabled)") + } else if newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" { + // Recreate proxy with new URL + proxy, err := createReverseProxy(newUpstreamURL, m.secretSource) + if err != nil { + log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err) + } else { + m.setProxy(proxy) + changes = append(changes, "upstream-url") + log.Infof("amp config partial reload: upstream URL updated (%s -> %s)", oldUpstreamURL, newUpstreamURL) + } + } + + // Check API key change + apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings) + if apiKeyChanged { + if m.secretSource != nil { + if ms, ok := m.secretSource.(*MultiSourceSecret); ok { + ms.UpdateExplicitKey(newSettings.UpstreamAPIKey) + ms.InvalidateCache() + changes = append(changes, "upstream-api-key") + log.Debug("amp config partial reload: secret cache invalidated") + } + } + } + + // Check restrict-management-to-localhost change - now supports hot-reload + if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost { + m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost) + changes = append(changes, "restrict-management-to-localhost") + if newSettings.RestrictManagementToLocalhost { + log.Infof("amp config partial reload: management routes now restricted to localhost") + } else { + log.Warnf("amp config partial reload: management routes now accessible from any IP - this is insecure!") + } } - } else if m.enabled { - log.Warnf("amp model mapper not initialized, skipping model mapping update") } - if !m.enabled { - return nil + // Store current config for next comparison + m.configMu.Lock() + settingsCopy := newSettings // copy struct + m.lastConfig = &settingsCopy + m.configMu.Unlock() + + // Log summary if any changes detected + if len(changes) > 0 { + log.Debugf("amp config partial reload completed: %v", changes) + } else { + log.Debug("amp config checked: no changes detected") } - upstreamURL := strings.TrimSpace(settings.UpstreamURL) - if upstreamURL == "" { - log.Warn("amp upstream URL removed from config, restart required to disable") - return nil + return nil +} + +// hasModelMappingsChanged compares old and new model mappings. +func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.AmpCode) bool { + if old == nil { + return len(new.ModelMappings) > 0 } - // If API key changed, invalidate the cache - if m.secretSource != nil { - if ms, ok := m.secretSource.(*MultiSourceSecret); ok { - ms.UpdateExplicitKey(settings.UpstreamAPIKey) - ms.InvalidateCache() - log.Debug("amp secret cache invalidated due to config update") + if len(old.ModelMappings) != len(new.ModelMappings) { + return true + } + + // Build map for efficient comparison + oldMap := make(map[string]string, len(old.ModelMappings)) + for _, mapping := range old.ModelMappings { + oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To) + } + + for _, mapping := range new.ModelMappings { + from := strings.TrimSpace(mapping.From) + to := strings.TrimSpace(mapping.To) + if oldTo, exists := oldMap[from]; !exists || oldTo != to { + return true } } - log.Debug("amp config updated (restart required for URL changes)") - return nil + return false +} + +// hasAPIKeyChanged compares old and new API keys. +func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool { + oldKey := "" + if old != nil { + oldKey = strings.TrimSpace(old.UpstreamAPIKey) + } + newKey := strings.TrimSpace(new.UpstreamAPIKey) + return oldKey != newKey } // GetModelMapper returns the model mapper instance (for testing/debugging). func (m *AmpModule) GetModelMapper() *DefaultModelMapper { return m.modelMapper } + +// getProxy returns the current proxy instance (thread-safe for hot-reload). +func (m *AmpModule) getProxy() *httputil.ReverseProxy { + m.proxyMu.RLock() + defer m.proxyMu.RUnlock() + return m.proxy +} + +// setProxy updates the proxy instance (thread-safe for hot-reload). +func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) { + m.proxyMu.Lock() + defer m.proxyMu.Unlock() + m.proxy = proxy +} + +// IsRestrictedToLocalhost returns whether management routes are restricted to localhost. +func (m *AmpModule) IsRestrictedToLocalhost() bool { + m.restrictMu.RLock() + defer m.restrictMu.RUnlock() + return m.restrictToLocalhost +} + +// setRestrictToLocalhost updates the localhost restriction setting. +func (m *AmpModule) setRestrictToLocalhost(restrict bool) { + m.restrictMu.Lock() + defer m.restrictMu.Unlock() + m.restrictToLocalhost = restrict +} diff --git a/internal/api/modules/amp/model_mapping_test.go b/internal/api/modules/amp/model_mapping_test.go index c11d61bd..4f4e5a8e 100644 --- a/internal/api/modules/amp/model_mapping_test.go +++ b/internal/api/modules/amp/model_mapping_test.go @@ -152,9 +152,9 @@ func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) { mapper := NewModelMapper(nil) mapper.UpdateMappings([]config.AmpModelMapping{ - {From: "", To: "model-b"}, // Invalid: empty from - {From: "model-a", To: ""}, // Invalid: empty to - {From: " ", To: "model-b"}, // Invalid: whitespace from + {From: "", To: "model-b"}, // Invalid: empty from + {From: "model-a", To: ""}, // Invalid: empty to + {From: " ", To: "model-b"}, // Invalid: whitespace from {From: "model-c", To: "model-d"}, // Valid }) diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index b7105a14..b986a53a 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -14,15 +14,16 @@ import ( log "github.com/sirupsen/logrus" ) -// localhostOnlyMiddleware restricts access to localhost (127.0.0.1, ::1) only. -// Returns 403 Forbidden for non-localhost clients. -// -// Security: Uses RemoteAddr (actual TCP connection) instead of ClientIP() to prevent -// header spoofing attacks via X-Forwarded-For or similar headers. This means the -// middleware will not work correctly behind reverse proxies - users deploying behind -// nginx/Cloudflare should disable this feature and use firewall rules instead. -func localhostOnlyMiddleware() gin.HandlerFunc { +// localhostOnlyMiddleware returns a middleware that dynamically checks the module's +// localhost restriction setting. This allows hot-reload of the restriction without restarting. +func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc { return func(c *gin.Context) { + // Check current setting (hot-reloadable) + if !m.IsRestrictedToLocalhost() { + c.Next() + return + } + // Use actual TCP connection address (RemoteAddr) to prevent header spoofing // This cannot be forged by X-Forwarded-For or other client-controlled headers remoteAddr := c.Request.RemoteAddr @@ -79,21 +80,32 @@ func noCORSMiddleware() gin.HandlerFunc { // registerManagementRoutes registers Amp management proxy routes // These routes proxy through to the Amp control plane for OAuth, user management, etc. -// If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1. -func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) { +// Uses dynamic middleware and proxy getter for hot-reload support. +func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler) { ampAPI := engine.Group("/api") // Always disable CORS for management routes to prevent browser-based attacks ampAPI.Use(noCORSMiddleware()) - // Apply localhost-only restriction if configured - if restrictToLocalhost { - ampAPI.Use(localhostOnlyMiddleware()) + // Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost()) + ampAPI.Use(m.localhostOnlyMiddleware()) + + if m.IsRestrictedToLocalhost() { log.Info("amp management routes restricted to localhost only (CORS disabled)") } else { log.Warn("amp management routes are NOT restricted to localhost - this is insecure!") } + // Dynamic proxy handler that uses m.getProxy() for hot-reload support + proxyHandler := func(c *gin.Context) { + proxy := m.getProxy() + if proxy == nil { + c.JSON(503, gin.H{"error": "amp upstream proxy not available"}) + return + } + proxy.ServeHTTP(c.Writer, c.Request) + } + // Management routes - these are proxied directly to Amp upstream ampAPI.Any("/internal", proxyHandler) ampAPI.Any("/internal/*path", proxyHandler) @@ -114,11 +126,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha ampAPI.Any("/tab/*path", proxyHandler) // Root-level routes that AMP CLI expects without /api prefix - // These need the same security middleware as the /api/* routes - rootMiddleware := []gin.HandlerFunc{noCORSMiddleware()} - if restrictToLocalhost { - rootMiddleware = append(rootMiddleware, localhostOnlyMiddleware()) - } + // These need the same security middleware as the /api/* routes (dynamic for hot-reload) + rootMiddleware := []gin.HandlerFunc{noCORSMiddleware(), m.localhostOnlyMiddleware()} engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) // Root-level auth routes for CLI login flow @@ -134,7 +143,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler) geminiBridge := createGeminiBridgeHandler(geminiHandlers) geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy { - return m.proxy + return m.getProxy() }) geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge) @@ -177,10 +186,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler) // Create fallback handler wrapper that forwards to ampcode.com when provider not found - // Uses lazy evaluation to access proxy (which is created after routes are registered) + // Uses m.getProxy() for hot-reload support (proxy can be updated at runtime) // Also includes model mapping support for routing unavailable models to alternatives fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { - return m.proxy + return m.getProxy() }, m.modelMapper) // Provider-specific routes under /api/provider/:provider diff --git a/internal/api/modules/amp/routes_test.go b/internal/api/modules/amp/routes_test.go index 89e43506..a40852c0 100644 --- a/internal/api/modules/amp/routes_test.go +++ b/internal/api/modules/amp/routes_test.go @@ -13,16 +13,26 @@ func TestRegisterManagementRoutes(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() - // Spy to track if proxy handler was called + // Create module with proxy for testing + m := &AmpModule{ + restrictToLocalhost: false, // disable localhost restriction for tests + } + + // Create a mock proxy that tracks calls proxyCalled := false - proxyHandler := func(c *gin.Context) { + mockProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proxyCalled = true - c.String(200, "proxied") - } + w.WriteHeader(200) + w.Write([]byte("proxied")) + })) + defer mockProxy.Close() + + // Create real proxy to mock server + proxy, _ := createReverseProxy(mockProxy.URL, NewStaticSecretSource("")) + m.setProxy(proxy) - m := &AmpModule{} base := &handlers.BaseAPIHandler{} - m.registerManagementRoutes(r, base, proxyHandler, false) // false = don't restrict to localhost in tests + m.registerManagementRoutes(r, base) managementPaths := []struct { path string @@ -41,9 +51,9 @@ func TestRegisterManagementRoutes(t *testing.T) { {"/api/otel", http.MethodGet}, {"/api/tab", http.MethodGet}, {"/api/tab/some/path", http.MethodGet}, - {"/auth", http.MethodGet}, // Root-level auth route - {"/auth/cli-login", http.MethodGet}, // CLI login flow - {"/auth/callback", http.MethodGet}, // OAuth callback + {"/auth", http.MethodGet}, // Root-level auth route + {"/auth/cli-login", http.MethodGet}, // CLI login flow + {"/auth/callback", http.MethodGet}, // OAuth callback // Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST {"/api/provider/google/v1beta1/models", http.MethodGet}, {"/api/provider/google/v1beta1/models", http.MethodPost}, @@ -231,8 +241,13 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() - // Apply localhost-only middleware - r.Use(localhostOnlyMiddleware()) + // Create module with localhost restriction enabled + m := &AmpModule{ + restrictToLocalhost: true, + } + + // Apply dynamic localhost-only middleware + r.Use(m.localhostOnlyMiddleware()) r.GET("/test", func(c *gin.Context) { c.String(http.StatusOK, "ok") }) @@ -305,3 +320,53 @@ func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) { }) } } + +func TestLocalhostOnlyMiddleware_HotReload(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + + // Create module with localhost restriction initially enabled + m := &AmpModule{ + restrictToLocalhost: true, + } + + // Apply dynamic localhost-only middleware + r.Use(m.localhostOnlyMiddleware()) + r.GET("/test", func(c *gin.Context) { + c.String(http.StatusOK, "ok") + }) + + // Test 1: Remote IP should be blocked when restriction is enabled + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "192.168.1.100:12345" + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403 when restriction enabled, got %d", w.Code) + } + + // Test 2: Hot-reload - disable restriction + m.setRestrictToLocalhost(false) + + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "192.168.1.100:12345" + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200 after disabling restriction, got %d", w.Code) + } + + // Test 3: Hot-reload - re-enable restriction + m.setRestrictToLocalhost(true) + + req = httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "192.168.1.100:12345" + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected 403 after re-enabling restriction, got %d", w.Code) + } +}