Skip to content

Commit

Permalink
Merge pull request #225 from pace/user-agent-in-context
Browse files Browse the repository at this point in the history
store User-Agent header in context
  • Loading branch information
Marius Neugebauer committed Sep 11, 2020
2 parents 54fb72f + 86bc63e commit c2d4de4
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
13 changes: 13 additions & 0 deletions http/context.go
Expand Up @@ -18,6 +18,7 @@ func RequestInContextMiddleware(next http.Handler) http.Handler {
ctxReq := ctxRequest{
RemoteAddr: r.RemoteAddr,
XForwardedFor: r.Header.Get("X-Forwarded-For"),
UserAgent: r.Header.Get("User-Agent"),
}
r = r.WithContext(contextWithRequest(r.Context(), &ctxReq))
next.ServeHTTP(w, r)
Expand All @@ -35,6 +36,7 @@ func ContextTransfer(ctx, targetCtx context.Context) context.Context {
type ctxRequest struct {
RemoteAddr string // requester IP:port
XForwardedFor string // X-Forwarded-For header
UserAgent string // User-Agent header
}

func contextWithRequest(ctx context.Context, ctxReq *ctxRequest) context.Context {
Expand Down Expand Up @@ -83,3 +85,14 @@ func GetXForwardedForHeaderFromContext(ctx context.Context) (string, error) {
}
return xForwardedFor + ip, nil
}

// GetUserAgentFromContext returns the User-Agent header value from the request
// that is stored in the context. Returns ErrNotFound if the context does not
// have a request.
func GetUserAgentFromContext(ctx context.Context) (string, error) {
ctxReq := requestFromContext(ctx)
if ctxReq == nil {
return "", fmt.Errorf("getting request from context: %w", ErrNotFound)
}
return ctxReq.UserAgent, nil
}
45 changes: 45 additions & 0 deletions http/context_test.go
Expand Up @@ -4,6 +4,7 @@
package http_test

import (
"context"
"errors"
"net/http"
"testing"
Expand All @@ -13,6 +14,24 @@ import (
"github.com/stretchr/testify/require"
)

func TestContextTransfer(t *testing.T) {
r, err := http.NewRequest("GET", "http://example.com/", nil)
require.NoError(t, err)
r.Header.Set("User-Agent", "Foobar")
RequestInContextMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
ctx := ContextTransfer(r.Context(), context.Background())
userAgent, err := GetUserAgentFromContext(ctx)
assert.NoError(t, err)
assert.Equal(t, "Foobar", userAgent)
})).ServeHTTP(nil, r)

// without request
ctx := ContextTransfer(context.Background(), context.Background())
userAgent, err := GetUserAgentFromContext(ctx)
assert.True(t, errors.Is(err, ErrNotFound), err)
assert.Empty(t, userAgent)
}

func TestGetXForwardedForHeaderFromContext(t *testing.T) {
cases := map[string]struct {
RemoteAddr string // input IP:port
Expand Down Expand Up @@ -43,6 +62,10 @@ func TestGetXForwardedForHeaderFromContext(t *testing.T) {
RemoteAddr: "",
ExpectErr: ErrInvalidRequest,
},
"missing remote ip": {
RemoteAddr: ":80",
ExpectErr: ErrInvalidRequest,
},
"broken remote address": {
RemoteAddr: "1234567890ß",
ExpectErr: ErrInvalidRequest,
Expand All @@ -68,4 +91,26 @@ func TestGetXForwardedForHeaderFromContext(t *testing.T) {
})).ServeHTTP(nil, r)
})
}

// no request in context
xForwardedFor, err := GetXForwardedForHeaderFromContext(context.Background())
assert.True(t, errors.Is(err, ErrNotFound), err)
assert.Empty(t, xForwardedFor)
}

func TestGetUserAgentFromContext(t *testing.T) {
r, err := http.NewRequest("GET", "http://example.com/", nil)
require.NoError(t, err)
r.Header.Set("User-Agent", "Foobar")
RequestInContextMiddleware(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAgent, err := GetUserAgentFromContext(ctx)
assert.NoError(t, err)
assert.Equal(t, "Foobar", userAgent)
})).ServeHTTP(nil, r)

// no request in context
userAgent, err := GetUserAgentFromContext(context.Background())
assert.True(t, errors.Is(err, ErrNotFound), err)
assert.Empty(t, userAgent)
}

0 comments on commit c2d4de4

Please sign in to comment.