Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Marius Neugebauer
committed
Aug 27, 2020
1 parent
8cf3393
commit b8dbd8d
Showing
5 changed files
with
174 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
// Copyright © 2020 by PACE Telematics GmbH. All rights reserved. | ||
// Created at 2020/08/27 by Marius Neugebauer | ||
|
||
package http | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net" | ||
"net/http" | ||
) | ||
|
||
// RequestInContextMiddleware stores a representation of the request in the | ||
// context of said request. Some information of that request can then be | ||
// accessed through the context using functions of this package. | ||
func RequestInContextMiddleware(next http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
ctxReq := ctxRequest{ | ||
RemoteAddr: r.RemoteAddr, | ||
XForwardedFor: r.Header.Get("X-Forwarded-For"), | ||
} | ||
r = r.WithContext(contextWithRequest(r.Context(), &ctxReq)) | ||
next.ServeHTTP(w, r) | ||
}) | ||
} | ||
|
||
// ContextTransfer copies a request representation from one context to another. | ||
func ContextTransfer(ctx, targetCtx context.Context) context.Context { | ||
if r := requestFromContext(ctx); r != nil { | ||
return contextWithRequest(targetCtx, r) | ||
} | ||
return targetCtx | ||
} | ||
|
||
type ctxRequest struct { | ||
RemoteAddr string // requester IP:port | ||
XForwardedFor string // X-Forwarded-For header | ||
} | ||
|
||
func contextWithRequest(ctx context.Context, ctxReq *ctxRequest) context.Context { | ||
return context.WithValue(ctx, (*ctxRequest)(nil), ctxReq) | ||
} | ||
|
||
func requestFromContext(ctx context.Context) *ctxRequest { | ||
if v := ctx.Value((*ctxRequest)(nil)); v != nil { | ||
return v.(*ctxRequest) | ||
} | ||
return nil | ||
} | ||
|
||
// GetXForwardedForHeaderFromContext returns the X-Forwarded-For header value | ||
// that would express that you forwarded the request that is stored in the | ||
// context. | ||
// | ||
// If the remote address of the request is 12.34.56.78:9999 then the value is | ||
// that remote ip without the port. If the request already includes this header, | ||
// the remote ip is appended to the value of that header. For example if the | ||
// request on top of the remote ip also includes the header "X-Forwarded-For: | ||
// 100.100.100.100" then the resulting value is "100.100.100.100, 12.34.56.78". | ||
// | ||
// Returns ErrNotFound if the context does not have a request. Returns | ||
// ErrInvalidRequest if the request in the context is malformed, for example | ||
// because it does not have a remote address, which should never happen. | ||
func GetXForwardedForHeaderFromContext(ctx context.Context) (string, error) { | ||
ctxReq := requestFromContext(ctx) | ||
if ctxReq == nil { | ||
return "", fmt.Errorf("getting request from context: %w", ErrNotFound) | ||
} | ||
xForwardedFor := ctxReq.XForwardedFor | ||
ip, _, err := net.SplitHostPort(ctxReq.RemoteAddr) | ||
if err != nil { | ||
return "", fmt.Errorf( | ||
"%w (from context): could not get ip from remote address: %s", | ||
ErrInvalidRequest, err) | ||
} | ||
if ip == "" { | ||
return "", fmt.Errorf( | ||
"%w (from context): could not get ip from remote address: %q", | ||
ErrInvalidRequest, ctxReq.RemoteAddr) | ||
} | ||
if xForwardedFor != "" { | ||
xForwardedFor += ", " | ||
} | ||
return xForwardedFor + ip, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
// Copyright © 2020 by PACE Telematics GmbH. All rights reserved. | ||
// Created at 2020/08/27 by Marius Neugebauer | ||
|
||
package http_test | ||
|
||
import ( | ||
"errors" | ||
"net/http" | ||
"testing" | ||
|
||
. "github.com/pace/bricks/http" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestGetXForwardedForHeaderFromContext(t *testing.T) { | ||
cases := map[string]struct { | ||
RemoteAddr string // input IP:port | ||
XForwardedFor string // input X-Forwarded-For header | ||
ExpectErr error | ||
ExpectXForwardedFor string // output X-Forwarded-For header | ||
}{ | ||
"direct request": { | ||
RemoteAddr: "12.34.56.78:9999", | ||
ExpectXForwardedFor: "12.34.56.78", | ||
}, | ||
"behind a proxy": { | ||
RemoteAddr: "12.34.56.78:9999", | ||
XForwardedFor: "100.100.100.100", | ||
ExpectXForwardedFor: "100.100.100.100, 12.34.56.78", | ||
}, | ||
"behind multiple proxies": { | ||
RemoteAddr: "4.4.4.4:1234", | ||
XForwardedFor: "1.1.1.1, 2.2.2.2, 3.3.3.3", | ||
ExpectXForwardedFor: "1.1.1.1, 2.2.2.2, 3.3.3.3, 4.4.4.4", | ||
}, | ||
"ipv6": { | ||
RemoteAddr: "[d953:7242:7970:566c:ee3a:0581:36cd:4fd6]:1234", | ||
XForwardedFor: "7342:57fb:4188:fd49:1eed:644f:22d6:69a2", | ||
ExpectXForwardedFor: "7342:57fb:4188:fd49:1eed:644f:22d6:69a2, d953:7242:7970:566c:ee3a:0581:36cd:4fd6", | ||
}, | ||
"missing remote address": { | ||
RemoteAddr: "", | ||
ExpectErr: ErrInvalidRequest, | ||
}, | ||
"broken remote address": { | ||
RemoteAddr: "1234567890ß", | ||
ExpectErr: ErrInvalidRequest, | ||
}, | ||
} | ||
for name, c := range cases { | ||
t.Run(name, func(t *testing.T) { | ||
r, err := http.NewRequest("GET", "http://example.com/", nil) | ||
require.NoError(t, err) | ||
r.RemoteAddr = c.RemoteAddr | ||
if c.XForwardedFor != "" { | ||
r.Header.Set("X-Forwarded-For", c.XForwardedFor) | ||
} | ||
RequestInContextMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { | ||
ctx := r.Context() | ||
xForwardedFor, err := GetXForwardedForHeaderFromContext(ctx) | ||
if c.ExpectErr != nil { | ||
assert.True(t, errors.Is(err, c.ExpectErr)) | ||
} else { | ||
assert.NoError(t, err) | ||
} | ||
assert.Equal(t, c.ExpectXForwardedFor, xForwardedFor) | ||
})).ServeHTTP(nil, r) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// Copyright © 2020 by PACE Telematics GmbH. All rights reserved. | ||
// Created at 2020/08/27 by Marius Neugebauer | ||
|
||
package http | ||
|
||
import "errors" | ||
|
||
// All exported package errors. | ||
var ( | ||
ErrNotFound = errors.New("not found") | ||
ErrInvalidRequest = errors.New("request is invalid") | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters