From 08ddc1c42baf3e0efb43a18ae8d5e6cb34162e86 Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Sun, 24 Aug 2025 16:45:55 +0000 Subject: [PATCH 1/2] feat: refactor backend architecture with interface pattern and trusted proxy support - Extract backend interface for better separation of concerns - Split monolithic backend.go into separate proxy and transparent backends - Add trusted proxy configuration support with IP/CIDR validation - Implement transparent backend for HTTP/HTTPS proxy targets - Add comprehensive test coverage for new backend implementations - Maintain backward compatibility with existing proxy functionality --- main.go | 11 ++ pkg/backend/interface.go | 12 ++ pkg/backend/main_test.go | 13 ++ pkg/backend/{backend.go => proxy.go} | 2 +- .../{backend_test.go => proxy_test.go} | 25 +++- pkg/backend/transparent.go | 100 +++++++++++++ pkg/backend/transparent_test.go | 138 ++++++++++++++++++ pkg/mcp-proxy/main.go | 15 +- 8 files changed, 311 insertions(+), 5 deletions(-) create mode 100644 pkg/backend/interface.go create mode 100644 pkg/backend/main_test.go rename pkg/backend/{backend.go => proxy.go} (98%) rename pkg/backend/{backend_test.go => proxy_test.go} (86%) create mode 100644 pkg/backend/transparent.go create mode 100644 pkg/backend/transparent_test.go diff --git a/main.go b/main.go index 65fb1b1..f622bfa 100644 --- a/main.go +++ b/main.go @@ -53,6 +53,7 @@ func main() { var passwordHash string var proxyBearerToken string var proxyHeaders string + var trustedProxies string rootCmd := &cobra.Command{ Use: "mcp-warp", @@ -107,6 +108,14 @@ func main() { oidcScopesList = []string{"openid", "profile", "email"} } + var trustedProxiesList []string + if trustedProxies != "" { + trustedProxiesList = strings.Split(trustedProxies, ",") + for i := range trustedProxiesList { + trustedProxiesList[i] = strings.TrimSpace(trustedProxiesList[i]) + } + } + // Parse proxy headers into slice var proxyHeadersList []string if proxyHeaders != "" { @@ -142,6 +151,7 @@ func main() { oidcAllowedUsersList, password, passwordHash, + trustedProxiesList, proxyHeadersList, proxyBearerToken, args, @@ -187,6 +197,7 @@ func main() { // Proxy headers configuration rootCmd.Flags().StringVar(&proxyBearerToken, "proxy-bearer-token", getEnvWithDefault("PROXY_BEARER_TOKEN", ""), "Bearer token to add to Authorization header when proxying requests") + rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", getEnvWithDefault("TRUSTED_PROXIES", ""), "Comma-separated list of trusted proxies (IP addresses or CIDR ranges)") rootCmd.Flags().StringVar(&proxyHeaders, "proxy-headers", getEnvWithDefault("PROXY_HEADERS", ""), "Comma-separated list of headers to add when proxying requests (format: Header1:Value1,Header2:Value2)") if err := rootCmd.Execute(); err != nil { diff --git a/pkg/backend/interface.go b/pkg/backend/interface.go new file mode 100644 index 0000000..849d68a --- /dev/null +++ b/pkg/backend/interface.go @@ -0,0 +1,12 @@ +package backend + +import ( + "context" + "net/http" +) + +type Backend interface { + Run(context.Context) (http.Handler, error) + Wait() error + Close() error +} diff --git a/pkg/backend/main_test.go b/pkg/backend/main_test.go new file mode 100644 index 0000000..4dd912c --- /dev/null +++ b/pkg/backend/main_test.go @@ -0,0 +1,13 @@ +package backend + +import ( + "os" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestMain(m *testing.M) { + gin.SetMode(gin.TestMode) + os.Exit(m.Run()) +} diff --git a/pkg/backend/backend.go b/pkg/backend/proxy.go similarity index 98% rename from pkg/backend/backend.go rename to pkg/backend/proxy.go index f024846..8ee7662 100644 --- a/pkg/backend/backend.go +++ b/pkg/backend/proxy.go @@ -25,7 +25,7 @@ type ProxyBackend struct { client *client.Client } -func NewProxyBackend(logger *zap.Logger, cmd []string) *ProxyBackend { +func NewProxyBackend(logger *zap.Logger, cmd []string) Backend { return &ProxyBackend{ logger: logger, cmd: cmd, diff --git a/pkg/backend/backend_test.go b/pkg/backend/proxy_test.go similarity index 86% rename from pkg/backend/backend_test.go rename to pkg/backend/proxy_test.go index a0f401b..7a201f1 100644 --- a/pkg/backend/backend_test.go +++ b/pkg/backend/proxy_test.go @@ -78,9 +78,32 @@ func TestProxyBackendRun(t *testing.T) { defer pb.Close() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() handler, err := pb.Run(ctx) require.NoError(t, err, "Run should not return error") require.NotNil(t, handler, "handler should not be nil") + + checkCh := make(chan struct{}) + go func() { + <-ctx.Done() + close(checkCh) + }() + + timeout := time.After(10 * time.Millisecond) + select { + case <-checkCh: + t.Error("Test completed too early") + case <-timeout: + // Test timed out + } + + cancel() + + timeout = time.After(10 * time.Second) + select { + case <-checkCh: + // Test completed successfully + case <-timeout: + t.Error("Test timed out") + } } diff --git a/pkg/backend/transparent.go b/pkg/backend/transparent.go new file mode 100644 index 0000000..af68a9d --- /dev/null +++ b/pkg/backend/transparent.go @@ -0,0 +1,100 @@ +package backend + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/netip" + "net/url" + "sync" + + "go.uber.org/zap" +) + +type TransparentBackend struct { + logger *zap.Logger + url *url.URL + trusted []netip.Prefix + ctx context.Context + ctxLock sync.Mutex +} + +func NewTransparentBackend(logger *zap.Logger, u *url.URL, trusted []string) (Backend, error) { + trn := make([]netip.Prefix, 0, len(trusted)) + for _, c := range trusted { + p, err := netip.ParsePrefix(c) + if err != nil { + return nil, err + } + trn = append(trn, p) + } + + return &TransparentBackend{ + logger: logger, + url: u, + trusted: trn, + }, nil +} + +func (p *TransparentBackend) Run(ctx context.Context) (http.Handler, error) { + p.ctxLock.Lock() + defer p.ctxLock.Unlock() + if p.ctx != nil { + return nil, fmt.Errorf("transparent backend is already running") + } + p.ctx = ctx + rp := httputil.ReverseProxy{ + Rewrite: func(pr *httputil.ProxyRequest) { + pr.SetURL(p.url) + if p.isTrusted(pr.In.RemoteAddr) { + pr.Out.Header["X-Forwarded-For"] = pr.In.Header["X-Forwarded-For"] + } + pr.SetXForwarded() + if p.isTrusted(pr.In.RemoteAddr) { + if v := pr.In.Header.Get("X-Forwarded-Host"); v != "" { + pr.Out.Header.Set("X-Forwarded-Host", v) + } + if v := pr.In.Header.Get("X-Forwarded-Proto"); v != "" { + pr.Out.Header.Set("X-Forwarded-Proto", v) + } + if v := pr.In.Header.Get("X-Forwarded-Port"); v != "" { + pr.Out.Header.Set("X-Forwarded-Port", v) + } + } + }, + } + return &rp, nil +} + +func (p *TransparentBackend) isTrusted(hostport string) bool { + if host, _, err := net.SplitHostPort(hostport); err == nil { + hostport = host + } + ip, err := netip.ParseAddr(hostport) + if err != nil { + return false + } + if ip.Is4In6() { + ip = ip.Unmap() + } + for _, p := range p.trusted { + if p.Contains(ip) { + return true + } + } + return false +} + +func (p *TransparentBackend) Wait() error { + if p.ctx == nil { + return nil + } + <-p.ctx.Done() + return nil +} + +func (p *TransparentBackend) Close() error { + return nil +} diff --git a/pkg/backend/transparent_test.go b/pkg/backend/transparent_test.go new file mode 100644 index 0000000..db4f415 --- /dev/null +++ b/pkg/backend/transparent_test.go @@ -0,0 +1,138 @@ +package backend + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestTransparentBackend(t *testing.T) { + r := gin.New() + r.GET("/", func(c *gin.Context) { + c.JSON(http.StatusOK, c.Request.Header) + }) + ts := httptest.NewServer(r) + u, _ := url.Parse(ts.URL) + + be, err := NewTransparentBackend(zap.NewNop(), u, []string{}) + require.NoError(t, err) + handler, err := be.Run(context.Background()) + require.NoError(t, err) + require.NotNil(t, handler) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + var header http.Header + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &header)) + require.Equal(t, "192.0.2.1", header.Get(("X-Forwarded-For"))) + require.Equal(t, "example.com", header.Get(("X-Forwarded-Host"))) + require.Equal(t, "http", header.Get(("X-Forwarded-Proto"))) +} + +func TestTransparentBackendWithProxy(t *testing.T) { + r := gin.New() + r.GET("/", func(c *gin.Context) { + c.JSON(http.StatusOK, c.Request.Header) + }) + ts := httptest.NewServer(r) + u, _ := url.Parse(ts.URL) + + be, err := NewTransparentBackend(zap.NewNop(), u, []string{"0.0.0.0/0"}) + require.NoError(t, err) + handler, err := be.Run(context.Background()) + require.NoError(t, err) + require.NotNil(t, handler) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Forwarded-For", "192.0.3.1") + req.Header.Set("X-Forwarded-Host", "example.org") + req.Header.Set("X-Forwarded-Proto", "https") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + var header http.Header + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &header)) + require.Equal(t, "192.0.3.1, 192.0.2.1", header.Get(("X-Forwarded-For"))) + require.Equal(t, "example.org", header.Get(("X-Forwarded-Host"))) + require.Equal(t, "https", header.Get(("X-Forwarded-Proto"))) +} + +func TestTransparentBackendWithInvalidProxy(t *testing.T) { + r := gin.New() + r.GET("/", func(c *gin.Context) { + c.JSON(http.StatusOK, c.Request.Header) + }) + ts := httptest.NewServer(r) + u, _ := url.Parse(ts.URL) + + be, err := NewTransparentBackend(zap.NewNop(), u, []string{"1.1.1.1/32"}) + require.NoError(t, err) + handler, err := be.Run(context.Background()) + require.NoError(t, err) + require.NotNil(t, handler) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Forwarded-For", "192.0.3.1") + req.Header.Set("X-Forwarded-Host", "example.org") + req.Header.Set("X-Forwarded-Proto", "https") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code) + var header http.Header + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &header)) + require.Equal(t, "192.0.2.1", header.Get(("X-Forwarded-For"))) + require.Equal(t, "example.com", header.Get(("X-Forwarded-Host"))) + require.Equal(t, "http", header.Get(("X-Forwarded-Proto"))) +} + +func TestTransparentBackendRun(t *testing.T) { + r := gin.New() + r.GET("/", func(c *gin.Context) { + c.JSON(http.StatusOK, c.Request.Header) + }) + ts := httptest.NewServer(r) + u, _ := url.Parse(ts.URL) + + be, err := NewTransparentBackend(zap.NewNop(), u, []string{}) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + _, err = be.Run(ctx) + require.NoError(t, err) + + checkCh := make(chan struct{}) + go func() { + <-ctx.Done() + close(checkCh) + }() + + timeout := time.After(10 * time.Millisecond) + select { + case <-checkCh: + t.Error("Test completed too early") + case <-timeout: + // Test timed out + } + + cancel() + + timeout = time.After(10 * time.Second) + select { + case <-checkCh: + // Test completed successfully + case <-timeout: + t.Error("Test timed out") + } +} diff --git a/pkg/mcp-proxy/main.go b/pkg/mcp-proxy/main.go index 086fd40..73507d3 100644 --- a/pkg/mcp-proxy/main.go +++ b/pkg/mcp-proxy/main.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/http" - "net/http/httputil" "net/url" "os" "os/signal" @@ -59,6 +58,7 @@ func Run( oidcAllowedUsers []string, password string, passwordHash string, + trustedProxy []string, proxyHeaders []string, proxyBearerToken string, proxyTarget []string, @@ -98,10 +98,18 @@ func Run( if len(proxyTarget) == 0 { return fmt.Errorf("proxy target must be specified") } - var be *backend.ProxyBackend + var be backend.Backend var beHandler http.Handler if proxyURL, err := url.Parse(proxyTarget[0]); err == nil && (proxyURL.Scheme == "http" || proxyURL.Scheme == "https") { - beHandler = httputil.NewSingleHostReverseProxy(proxyURL) + var err error + be, err = backend.NewTransparentBackend(logger, proxyURL, trustedProxy) + if err != nil { + return fmt.Errorf("failed to create transparent backend: %w", err) + } + beHandler, err = be.Run(ctx) + if err != nil { + return fmt.Errorf("failed to create transparent backend: %w", err) + } } else { be = backend.NewProxyBackend(logger, proxyTarget) beHandler, err = be.Run(ctx) @@ -205,6 +213,7 @@ func Run( } router := gin.New() + // router.SetTrustedProxies(trustedProxy) router.Use(ginzap.Ginzap(logger, time.RFC3339, true)) router.Use(ginzap.RecoveryWithZap(logger, true)) From fee67c582651030c9c6b38445f607ec0609f8588 Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Sun, 24 Aug 2025 17:13:19 +0000 Subject: [PATCH 2/2] fix: enable trusted proxy configuration Uncommented the SetTrustedProxies call to properly configure trusted proxies --- pkg/mcp-proxy/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/mcp-proxy/main.go b/pkg/mcp-proxy/main.go index 73507d3..13b04b9 100644 --- a/pkg/mcp-proxy/main.go +++ b/pkg/mcp-proxy/main.go @@ -213,7 +213,7 @@ func Run( } router := gin.New() - // router.SetTrustedProxies(trustedProxy) + router.SetTrustedProxies(trustedProxy) router.Use(ginzap.Ginzap(logger, time.RFC3339, true)) router.Use(ginzap.RecoveryWithZap(logger, true))