Skip to content
Merged

Amp #418

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 163 additions & 28 deletions internal/api/modules/amp/amp.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,20 @@ type Option func(*AmpModule)
type AmpModule struct {
secretSource SecretSource
proxy *httputil.ReverseProxy
proxyMu sync.RWMutex // protects proxy for hot-reload
accessManager *sdkaccess.Manager
authMiddleware_ gin.HandlerFunc
modelMapper *DefaultModelMapper
enabled bool
registerOnce sync.Once

// restrictToLocalhost controls localhost-only access for management routes (hot-reloadable)
restrictToLocalhost bool
restrictMu sync.RWMutex

// configMu protects lastConfig for partial reload comparison
configMu sync.RWMutex
lastConfig *config.AmpCode
}

// New creates a new Amp routing module with the given options.
Expand Down Expand Up @@ -107,6 +116,13 @@ func (m *AmpModule) Register(ctx modules.Context) error {
// Initialize model mapper from config (for routing unavailable models to alternatives)
m.modelMapper = NewModelMapper(settings.ModelMappings)

// Store initial config for partial reload comparison
settingsCopy := settings
m.lastConfig = &settingsCopy

// Initialize localhost restriction setting (hot-reloadable)
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)

// Always register provider aliases - these work without an upstream
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)

Expand All @@ -131,13 +147,12 @@ func (m *AmpModule) Register(ctx modules.Context) error {
return
}

m.proxy = proxy
m.setProxy(proxy)
m.enabled = true

// Register management proxy routes (requires upstream)
// Restrict to localhost by default for security (prevents drive-by browser attacks)
handler := proxyHandler(proxy)
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, handler, settings.RestrictManagementToLocalhost)
// Uses dynamic middleware that checks m.IsRestrictedToLocalhost() for hot-reload support
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler)

log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
log.Debug("amp provider alias routes registered")
Expand All @@ -162,45 +177,165 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
}
}

// OnConfigUpdated handles configuration updates.
// Currently requires restart for URL changes (could be enhanced for dynamic updates).
// OnConfigUpdated handles configuration updates with partial reload support.
// Only updates components that have actually changed to avoid unnecessary work.
// Supports hot-reload for: model-mappings, upstream-api-key, upstream-url, restrict-management-to-localhost.
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
settings := cfg.AmpCode
newSettings := cfg.AmpCode

// Get previous config for comparison
m.configMu.RLock()
oldSettings := m.lastConfig
m.configMu.RUnlock()

// Track what changed for logging
var changes []string

// Check model mappings change
modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings)
if modelMappingsChanged {
if m.modelMapper != nil {
m.modelMapper.UpdateMappings(newSettings.ModelMappings)
changes = append(changes, "model-mappings")
if m.enabled {
log.Infof("amp config partial reload: model mappings updated (%d entries)", len(newSettings.ModelMappings))
}
} else if m.enabled {
log.Warnf("amp model mapper not initialized, skipping model mapping update")
}
}

// Update model mappings (hot-reload supported)
if m.modelMapper != nil {
m.modelMapper.UpdateMappings(settings.ModelMappings)
if m.enabled {
log.Infof("amp config updated: reloading %d model mapping(s)", len(settings.ModelMappings))
if m.enabled {
// Check upstream URL change - now supports hot-reload
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
oldUpstreamURL := ""
if oldSettings != nil {
oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL)
}

if newUpstreamURL == "" && oldUpstreamURL != "" {
log.Warn("amp upstream URL removed from config, proxy has been disabled")
m.setProxy(nil)
changes = append(changes, "upstream-url(disabled)")
} else if newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
// Recreate proxy with new URL
proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
if err != nil {
log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
} else {
m.setProxy(proxy)
changes = append(changes, "upstream-url")
log.Infof("amp config partial reload: upstream URL updated (%s -> %s)", oldUpstreamURL, newUpstreamURL)
}
}

// Check API key change
apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
if apiKeyChanged {
if m.secretSource != nil {
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
ms.InvalidateCache()
changes = append(changes, "upstream-api-key")
log.Debug("amp config partial reload: secret cache invalidated")
}
}
}

// Check restrict-management-to-localhost change - now supports hot-reload
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
changes = append(changes, "restrict-management-to-localhost")
if newSettings.RestrictManagementToLocalhost {
log.Infof("amp config partial reload: management routes now restricted to localhost")
} else {
log.Warnf("amp config partial reload: management routes now accessible from any IP - this is insecure!")
}
}
} else if m.enabled {
log.Warnf("amp model mapper not initialized, skipping model mapping update")
}

if !m.enabled {
return nil
// Store current config for next comparison
m.configMu.Lock()
settingsCopy := newSettings // copy struct
m.lastConfig = &settingsCopy
m.configMu.Unlock()

// Log summary if any changes detected
if len(changes) > 0 {
log.Debugf("amp config partial reload completed: %v", changes)
} else {
log.Debug("amp config checked: no changes detected")
}

upstreamURL := strings.TrimSpace(settings.UpstreamURL)
if upstreamURL == "" {
log.Warn("amp upstream URL removed from config, restart required to disable")
return nil
return nil
}

// hasModelMappingsChanged compares old and new model mappings.
func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.AmpCode) bool {
if old == nil {
return len(new.ModelMappings) > 0
}

// If API key changed, invalidate the cache
if m.secretSource != nil {
if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
ms.UpdateExplicitKey(settings.UpstreamAPIKey)
ms.InvalidateCache()
log.Debug("amp secret cache invalidated due to config update")
if len(old.ModelMappings) != len(new.ModelMappings) {
return true
}

// Build map for efficient comparison
oldMap := make(map[string]string, len(old.ModelMappings))
for _, mapping := range old.ModelMappings {
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To)
}

for _, mapping := range new.ModelMappings {
from := strings.TrimSpace(mapping.From)
to := strings.TrimSpace(mapping.To)
if oldTo, exists := oldMap[from]; !exists || oldTo != to {
return true
}
}

log.Debug("amp config updated (restart required for URL changes)")
return nil
return false
}

// hasAPIKeyChanged compares old and new API keys.
func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool {
oldKey := ""
if old != nil {
oldKey = strings.TrimSpace(old.UpstreamAPIKey)
}
newKey := strings.TrimSpace(new.UpstreamAPIKey)
return oldKey != newKey
}

// GetModelMapper returns the model mapper instance (for testing/debugging).
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
return m.modelMapper
}

// getProxy returns the current proxy instance (thread-safe for hot-reload).
func (m *AmpModule) getProxy() *httputil.ReverseProxy {
m.proxyMu.RLock()
defer m.proxyMu.RUnlock()
return m.proxy
}

// setProxy updates the proxy instance (thread-safe for hot-reload).
func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) {
m.proxyMu.Lock()
defer m.proxyMu.Unlock()
m.proxy = proxy
}

// IsRestrictedToLocalhost returns whether management routes are restricted to localhost.
func (m *AmpModule) IsRestrictedToLocalhost() bool {
m.restrictMu.RLock()
defer m.restrictMu.RUnlock()
return m.restrictToLocalhost
}

// setRestrictToLocalhost updates the localhost restriction setting.
func (m *AmpModule) setRestrictToLocalhost(restrict bool) {
m.restrictMu.Lock()
defer m.restrictMu.Unlock()
m.restrictToLocalhost = restrict
}
6 changes: 3 additions & 3 deletions internal/api/modules/amp/model_mapping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) {
mapper := NewModelMapper(nil)

mapper.UpdateMappings([]config.AmpModelMapping{
{From: "", To: "model-b"}, // Invalid: empty from
{From: "model-a", To: ""}, // Invalid: empty to
{From: " ", To: "model-b"}, // Invalid: whitespace from
{From: "", To: "model-b"}, // Invalid: empty from
{From: "model-a", To: ""}, // Invalid: empty to
{From: " ", To: "model-b"}, // Invalid: whitespace from
{From: "model-c", To: "model-d"}, // Valid
})

Expand Down
51 changes: 30 additions & 21 deletions internal/api/modules/amp/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ import (
log "github.com/sirupsen/logrus"
)

// localhostOnlyMiddleware restricts access to localhost (127.0.0.1, ::1) only.
// Returns 403 Forbidden for non-localhost clients.
//
// Security: Uses RemoteAddr (actual TCP connection) instead of ClientIP() to prevent
// header spoofing attacks via X-Forwarded-For or similar headers. This means the
// middleware will not work correctly behind reverse proxies - users deploying behind
// nginx/Cloudflare should disable this feature and use firewall rules instead.
func localhostOnlyMiddleware() gin.HandlerFunc {
// localhostOnlyMiddleware returns a middleware that dynamically checks the module's
// localhost restriction setting. This allows hot-reload of the restriction without restarting.
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Check current setting (hot-reloadable)
if !m.IsRestrictedToLocalhost() {
c.Next()
return
}

// Use actual TCP connection address (RemoteAddr) to prevent header spoofing
// This cannot be forged by X-Forwarded-For or other client-controlled headers
remoteAddr := c.Request.RemoteAddr
Expand Down Expand Up @@ -79,21 +80,32 @@ func noCORSMiddleware() gin.HandlerFunc {

// registerManagementRoutes registers Amp management proxy routes
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
// If restrictToLocalhost is true, routes will only accept connections from 127.0.0.1/::1.
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, proxyHandler gin.HandlerFunc, restrictToLocalhost bool) {
// Uses dynamic middleware and proxy getter for hot-reload support.
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler) {
ampAPI := engine.Group("/api")

// Always disable CORS for management routes to prevent browser-based attacks
ampAPI.Use(noCORSMiddleware())

// Apply localhost-only restriction if configured
if restrictToLocalhost {
ampAPI.Use(localhostOnlyMiddleware())
// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
ampAPI.Use(m.localhostOnlyMiddleware())

if m.IsRestrictedToLocalhost() {
log.Info("amp management routes restricted to localhost only (CORS disabled)")
} else {
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
}

// Dynamic proxy handler that uses m.getProxy() for hot-reload support
proxyHandler := func(c *gin.Context) {
proxy := m.getProxy()
if proxy == nil {
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
return
}
proxy.ServeHTTP(c.Writer, c.Request)
}

// Management routes - these are proxied directly to Amp upstream
ampAPI.Any("/internal", proxyHandler)
ampAPI.Any("/internal/*path", proxyHandler)
Expand All @@ -114,11 +126,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
ampAPI.Any("/tab/*path", proxyHandler)

// Root-level routes that AMP CLI expects without /api prefix
// These need the same security middleware as the /api/* routes
rootMiddleware := []gin.HandlerFunc{noCORSMiddleware()}
if restrictToLocalhost {
rootMiddleware = append(rootMiddleware, localhostOnlyMiddleware())
}
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
rootMiddleware := []gin.HandlerFunc{noCORSMiddleware(), m.localhostOnlyMiddleware()}
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)

// Root-level auth routes for CLI login flow
Expand All @@ -134,7 +143,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
geminiBridge := createGeminiBridgeHandler(geminiHandlers)
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
return m.proxy
return m.getProxy()
})
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)

Expand Down Expand Up @@ -177,10 +186,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)

// Create fallback handler wrapper that forwards to ampcode.com when provider not found
// Uses lazy evaluation to access proxy (which is created after routes are registered)
// Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
// Also includes model mapping support for routing unavailable models to alternatives
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.proxy
return m.getProxy()
}, m.modelMapper)

// Provider-specific routes under /api/provider/:provider
Expand Down
Loading