From 177bbbc5d3b9c40313730c34c611b14fec952016 Mon Sep 17 00:00:00 2001 From: Andi Sardina Ramos Date: Sun, 3 Mar 2024 15:05:36 +0200 Subject: [PATCH] fix(xforwarder): Set remote proxy IP to HTTP request context when PROXY Protocol is enabled --- .../forwardedheaders/forwarded_header.go | 40 +++++++++++++------ .../forwardedheaders/forwarded_header_test.go | 2 +- pkg/server/server_entrypoint_tcp.go | 25 ++++++++++++ 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/pkg/middlewares/forwardedheaders/forwarded_header.go b/pkg/middlewares/forwardedheaders/forwarded_header.go index 19881ad748..a3b3cfdbcd 100644 --- a/pkg/middlewares/forwardedheaders/forwarded_header.go +++ b/pkg/middlewares/forwardedheaders/forwarded_header.go @@ -37,20 +37,27 @@ var xHeaders = []string{ xRealIP, } +type contextKey struct { + name string +} + +var ProxyAddrKey = &contextKey{"proxy-ip"} + // XForwarded is an HTTP handler wrapper that sets the X-Forwarded headers, // and other relevant headers for a reverse-proxy. // Unless insecure is set, // it first removes all the existing values for those headers if the remote address is not one of the trusted ones. type XForwarded struct { - insecure bool - trustedIps []string - ipChecker *ip.Checker - next http.Handler - hostname string + insecure bool + proxyProtocolEnabled bool + trustedIps []string + ipChecker *ip.Checker + next http.Handler + hostname string } // NewXForwarded creates a new XForwarded. -func NewXForwarded(insecure bool, trustedIps []string, next http.Handler) (*XForwarded, error) { +func NewXForwarded(insecure bool, proxyProtocolEnabled bool, trustedIps []string, next http.Handler) (*XForwarded, error) { var ipChecker *ip.Checker if len(trustedIps) > 0 { var err error @@ -66,11 +73,12 @@ func NewXForwarded(insecure bool, trustedIps []string, next http.Handler) (*XFor } return &XForwarded{ - insecure: insecure, - trustedIps: trustedIps, - ipChecker: ipChecker, - next: next, - hostname: hostname, + insecure: insecure, + proxyProtocolEnabled: proxyProtocolEnabled, + trustedIps: trustedIps, + ipChecker: ipChecker, + next: next, + hostname: hostname, }, nil } @@ -181,7 +189,15 @@ func (x *XForwarded) rewrite(outreq *http.Request) { // ServeHTTP implements http.Handler. func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if !x.insecure && !x.isTrustedIP(r.RemoteAddr) { + remoteAddr := r.RemoteAddr + + if x.proxyProtocolEnabled { + if proxyAddr, ok := r.Context().Value(ProxyAddrKey).(string); ok { + remoteAddr = proxyAddr + } + } + + if !x.insecure && !x.isTrustedIP(remoteAddr) { for _, h := range xHeaders { unsafeHeader(r.Header).Del(h) } diff --git a/pkg/middlewares/forwardedheaders/forwarded_header_test.go b/pkg/middlewares/forwardedheaders/forwarded_header_test.go index 8e1d109253..bc50c9d055 100644 --- a/pkg/middlewares/forwardedheaders/forwarded_header_test.go +++ b/pkg/middlewares/forwardedheaders/forwarded_header_test.go @@ -299,7 +299,7 @@ func TestServeHTTP(t *testing.T) { } } - m, err := NewXForwarded(test.insecure, test.trustedIps, + m, err := NewXForwarded(test.insecure, false, test.trustedIps, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) require.NoError(t, err) diff --git a/pkg/server/server_entrypoint_tcp.go b/pkg/server/server_entrypoint_tcp.go index 1fb371d081..ea4f839733 100644 --- a/pkg/server/server_entrypoint_tcp.go +++ b/pkg/server/server_entrypoint_tcp.go @@ -558,9 +558,11 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati return nil, err } + proxyProtocolEnabled := configuration.ProxyProtocol != nil var handler http.Handler handler, err = forwardedheaders.NewXForwarded( configuration.ForwardedHeaders.Insecure, + proxyProtocolEnabled, configuration.ForwardedHeaders.TrustedIPs, next) if err != nil { @@ -616,6 +618,29 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati prevConnContext := serverHTTP.ConnContext serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + // This adds the remote address of the server making the connection if the PROXY Protocol is enabled + getProxyIP := func(c net.Conn) string { + for { + switch conn := c.(type) { + case *tcprouter.Conn: + c = conn.WriteCloser + case *trackedConnection: + c = conn.WriteCloser + case *writeCloserWrapper: + c = conn.writeCloser + case *net.TCPConn: + return c.RemoteAddr().String() + default: + return conn.RemoteAddr().String() + } + } + } + + if proxyProtocolEnabled { + proxyIP := getProxyIP(c) + ctx = context.WithValue(ctx, forwardedheaders.ProxyAddrKey, proxyIP) + } + // This adds an empty struct in order to store a RoundTripper in the ConnContext in case of Kerberos or NTLM. ctx = service.AddTransportOnContext(ctx) if prevConnContext != nil {