Skip to content

Commit

Permalink
save request in context
Browse files Browse the repository at this point in the history
  • Loading branch information
Marius Neugebauer committed Aug 27, 2020
1 parent 8cf3393 commit b8dbd8d
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 1 deletion.
85 changes: 85 additions & 0 deletions http/context.go
@@ -0,0 +1,85 @@
// Copyright © 2020 by PACE Telematics GmbH. All rights reserved.
// Created at 2020/08/27 by Marius Neugebauer

package http

import (
"context"
"fmt"
"net"
"net/http"
)

// RequestInContextMiddleware stores a representation of the request in the
// context of said request. Some information of that request can then be
// accessed through the context using functions of this package.
func RequestInContextMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctxReq := ctxRequest{
RemoteAddr: r.RemoteAddr,
XForwardedFor: r.Header.Get("X-Forwarded-For"),
}
r = r.WithContext(contextWithRequest(r.Context(), &ctxReq))
next.ServeHTTP(w, r)
})
}

// ContextTransfer copies a request representation from one context to another.
func ContextTransfer(ctx, targetCtx context.Context) context.Context {
if r := requestFromContext(ctx); r != nil {
return contextWithRequest(targetCtx, r)
}
return targetCtx
}

type ctxRequest struct {
RemoteAddr string // requester IP:port
XForwardedFor string // X-Forwarded-For header
}

func contextWithRequest(ctx context.Context, ctxReq *ctxRequest) context.Context {
return context.WithValue(ctx, (*ctxRequest)(nil), ctxReq)
}

func requestFromContext(ctx context.Context) *ctxRequest {
if v := ctx.Value((*ctxRequest)(nil)); v != nil {
return v.(*ctxRequest)
}
return nil
}

// GetXForwardedForHeaderFromContext returns the X-Forwarded-For header value
// that would express that you forwarded the request that is stored in the
// context.
//
// If the remote address of the request is 12.34.56.78:9999 then the value is
// that remote ip without the port. If the request already includes this header,
// the remote ip is appended to the value of that header. For example if the
// request on top of the remote ip also includes the header "X-Forwarded-For:
// 100.100.100.100" then the resulting value is "100.100.100.100, 12.34.56.78".
//
// Returns ErrNotFound if the context does not have a request. Returns
// ErrInvalidRequest if the request in the context is malformed, for example
// because it does not have a remote address, which should never happen.
func GetXForwardedForHeaderFromContext(ctx context.Context) (string, error) {
ctxReq := requestFromContext(ctx)
if ctxReq == nil {
return "", fmt.Errorf("getting request from context: %w", ErrNotFound)
}
xForwardedFor := ctxReq.XForwardedFor
ip, _, err := net.SplitHostPort(ctxReq.RemoteAddr)
if err != nil {
return "", fmt.Errorf(
"%w (from context): could not get ip from remote address: %s",
ErrInvalidRequest, err)
}
if ip == "" {
return "", fmt.Errorf(
"%w (from context): could not get ip from remote address: %q",
ErrInvalidRequest, ctxReq.RemoteAddr)
}
if xForwardedFor != "" {
xForwardedFor += ", "
}
return xForwardedFor + ip, nil
}
71 changes: 71 additions & 0 deletions http/context_test.go
@@ -0,0 +1,71 @@
// Copyright © 2020 by PACE Telematics GmbH. All rights reserved.
// Created at 2020/08/27 by Marius Neugebauer

package http_test

import (
"errors"
"net/http"
"testing"

. "github.com/pace/bricks/http"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetXForwardedForHeaderFromContext(t *testing.T) {
cases := map[string]struct {
RemoteAddr string // input IP:port
XForwardedFor string // input X-Forwarded-For header
ExpectErr error
ExpectXForwardedFor string // output X-Forwarded-For header
}{
"direct request": {
RemoteAddr: "12.34.56.78:9999",
ExpectXForwardedFor: "12.34.56.78",
},
"behind a proxy": {
RemoteAddr: "12.34.56.78:9999",
XForwardedFor: "100.100.100.100",
ExpectXForwardedFor: "100.100.100.100, 12.34.56.78",
},
"behind multiple proxies": {
RemoteAddr: "4.4.4.4:1234",
XForwardedFor: "1.1.1.1, 2.2.2.2, 3.3.3.3",
ExpectXForwardedFor: "1.1.1.1, 2.2.2.2, 3.3.3.3, 4.4.4.4",
},
"ipv6": {
RemoteAddr: "[d953:7242:7970:566c:ee3a:0581:36cd:4fd6]:1234",
XForwardedFor: "7342:57fb:4188:fd49:1eed:644f:22d6:69a2",
ExpectXForwardedFor: "7342:57fb:4188:fd49:1eed:644f:22d6:69a2, d953:7242:7970:566c:ee3a:0581:36cd:4fd6",
},
"missing remote address": {
RemoteAddr: "",
ExpectErr: ErrInvalidRequest,
},
"broken remote address": {
RemoteAddr: "1234567890ß",
ExpectErr: ErrInvalidRequest,
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
r, err := http.NewRequest("GET", "http://example.com/", nil)
require.NoError(t, err)
r.RemoteAddr = c.RemoteAddr
if c.XForwardedFor != "" {
r.Header.Set("X-Forwarded-For", c.XForwardedFor)
}
RequestInContextMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
ctx := r.Context()
xForwardedFor, err := GetXForwardedForHeaderFromContext(ctx)
if c.ExpectErr != nil {
assert.True(t, errors.Is(err, c.ExpectErr))
} else {
assert.NoError(t, err)
}
assert.Equal(t, c.ExpectXForwardedFor, xForwardedFor)
})).ServeHTTP(nil, r)
})
}
}
12 changes: 12 additions & 0 deletions http/errors.go
@@ -0,0 +1,12 @@
// Copyright © 2020 by PACE Telematics GmbH. All rights reserved.
// Created at 2020/08/27 by Marius Neugebauer

package http

import "errors"

// All exported package errors.
var (
ErrNotFound = errors.New("not found")
ErrInvalidRequest = errors.New("request is invalid")
)
3 changes: 3 additions & 0 deletions http/router.go
Expand Up @@ -42,6 +42,9 @@ func Router() *mux.Router {

r.Use(locale.Handler())

// makes some infos about the request accessable from the context
r.Use(RequestInContextMiddleware)

// for prometheus
r.Handle("/metrics", metric.Handler())

Expand Down
4 changes: 3 additions & 1 deletion pkg/context/transfer.go
Expand Up @@ -3,13 +3,14 @@ package context
import (
"context"

"github.com/pace/bricks/http"
"github.com/pace/bricks/http/oauth2"
"github.com/pace/bricks/locale"
"github.com/pace/bricks/maintenance/errors"
"github.com/pace/bricks/maintenance/log"
)

// Transfer takes the logger, log.Sink, authentication and
// Transfer takes the logger, log.Sink, authentication, request and
// error info from the given context and returns a complete
// new context with all these objects.
func Transfer(in context.Context) context.Context {
Expand All @@ -18,5 +19,6 @@ func Transfer(in context.Context) context.Context {
out = log.SinkContextTransfer(in, out)
out = oauth2.ContextTransfer(in, out)
out = errors.ContextTransfer(in, out)
out = http.ContextTransfer(in, out)
return locale.ContextTransfer(in, out)
}

0 comments on commit b8dbd8d

Please sign in to comment.