Skip to content

Commit

Permalink
fix(xforwarder): Set remote proxy IP to HTTP request context when PRO…
Browse files Browse the repository at this point in the history
…XY Protocol is enabled
  • Loading branch information
Andi Sardina Ramos authored and andsarr committed Apr 20, 2024
1 parent e3729ec commit 177bbbc
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 13 deletions.
40 changes: 28 additions & 12 deletions pkg/middlewares/forwardedheaders/forwarded_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,27 @@ var xHeaders = []string{
xRealIP,
}

type contextKey struct {
name string
}

var ProxyAddrKey = &contextKey{"proxy-ip"}

// XForwarded is an HTTP handler wrapper that sets the X-Forwarded headers,
// and other relevant headers for a reverse-proxy.
// Unless insecure is set,
// it first removes all the existing values for those headers if the remote address is not one of the trusted ones.
type XForwarded struct {
insecure bool
trustedIps []string
ipChecker *ip.Checker
next http.Handler
hostname string
insecure bool
proxyProtocolEnabled bool
trustedIps []string
ipChecker *ip.Checker
next http.Handler
hostname string
}

// NewXForwarded creates a new XForwarded.
func NewXForwarded(insecure bool, trustedIps []string, next http.Handler) (*XForwarded, error) {
func NewXForwarded(insecure bool, proxyProtocolEnabled bool, trustedIps []string, next http.Handler) (*XForwarded, error) {
var ipChecker *ip.Checker
if len(trustedIps) > 0 {
var err error
Expand All @@ -66,11 +73,12 @@ func NewXForwarded(insecure bool, trustedIps []string, next http.Handler) (*XFor
}

return &XForwarded{
insecure: insecure,
trustedIps: trustedIps,
ipChecker: ipChecker,
next: next,
hostname: hostname,
insecure: insecure,
proxyProtocolEnabled: proxyProtocolEnabled,
trustedIps: trustedIps,
ipChecker: ipChecker,
next: next,
hostname: hostname,
}, nil
}

Expand Down Expand Up @@ -181,7 +189,15 @@ func (x *XForwarded) rewrite(outreq *http.Request) {

// ServeHTTP implements http.Handler.
func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !x.insecure && !x.isTrustedIP(r.RemoteAddr) {
remoteAddr := r.RemoteAddr

if x.proxyProtocolEnabled {
if proxyAddr, ok := r.Context().Value(ProxyAddrKey).(string); ok {
remoteAddr = proxyAddr
}
}

if !x.insecure && !x.isTrustedIP(remoteAddr) {
for _, h := range xHeaders {
unsafeHeader(r.Header).Del(h)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/middlewares/forwardedheaders/forwarded_header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ func TestServeHTTP(t *testing.T) {
}
}

m, err := NewXForwarded(test.insecure, test.trustedIps,
m, err := NewXForwarded(test.insecure, false, test.trustedIps,
http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
require.NoError(t, err)

Expand Down
25 changes: 25 additions & 0 deletions pkg/server/server_entrypoint_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,11 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati
return nil, err
}

proxyProtocolEnabled := configuration.ProxyProtocol != nil
var handler http.Handler
handler, err = forwardedheaders.NewXForwarded(
configuration.ForwardedHeaders.Insecure,
proxyProtocolEnabled,
configuration.ForwardedHeaders.TrustedIPs,
next)
if err != nil {
Expand Down Expand Up @@ -616,6 +618,29 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati

prevConnContext := serverHTTP.ConnContext
serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
// This adds the remote address of the server making the connection if the PROXY Protocol is enabled
getProxyIP := func(c net.Conn) string {
for {
switch conn := c.(type) {
case *tcprouter.Conn:
c = conn.WriteCloser
case *trackedConnection:
c = conn.WriteCloser
case *writeCloserWrapper:
c = conn.writeCloser
case *net.TCPConn:
return c.RemoteAddr().String()
default:
return conn.RemoteAddr().String()
}
}
}

if proxyProtocolEnabled {
proxyIP := getProxyIP(c)
ctx = context.WithValue(ctx, forwardedheaders.ProxyAddrKey, proxyIP)
}

// This adds an empty struct in order to store a RoundTripper in the ConnContext in case of Kerberos or NTLM.
ctx = service.AddTransportOnContext(ctx)
if prevConnContext != nil {
Expand Down

0 comments on commit 177bbbc

Please sign in to comment.