Skip to content

Commit

Permalink
Use X-Request-Id as canonical request identifier (if available)
Browse files Browse the repository at this point in the history
If `X-Request-Id` is available in an HTTP request made against the
CA server, it'll be used as the identifier for the request. This
slightly changes the existing behavior, which relied on the custom
`X-Smallstep-Id` header, but usage of that header is currently not
very widespread, and `X-Request-Id` is more generally known for
the use case `X-Smallstep-Id` is used for.

`X-Smallstep-Id` is currently still considered, but it'll only be
used if `X-Request-Id` is not set.
  • Loading branch information
hslatman committed Feb 27, 2024
1 parent 041b486 commit 4213a19
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 5 deletions.
23 changes: 18 additions & 5 deletions logging/context.go
Expand Up @@ -21,14 +21,27 @@ func NewRequestID() string {
return xid.New().String()
}

// RequestID returns a new middleware that gets the given header and sets it
// in the context so it can be written in the logger. If the header does not
// exists or it's the empty string, it uses github.com/rs/xid to create a new
// one.
// defaultRequestIDHeader is the header name used for propagating
// request IDs. If available in an HTTP request, it'll be used instead
// of the X-Smallstep-Id header.
const defaultRequestIDHeader = "X-Request-Id"

// RequestID returns a new middleware that obtains the current request ID
// and sets it in the context. It first tries to read the request ID from
// the "X-Request-Id" header. If that's not set, it tries to read it from
// the provided header name. If the header does not exist or its value is
// the empty string, it uses github.com/rs/xid to create a new one.
func RequestID(headerName string) func(next http.Handler) http.Handler {
if headerName == "" {
headerName = defaultTraceIDHeader
}
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, req *http.Request) {
requestID := req.Header.Get(headerName)
requestID := req.Header.Get(defaultRequestIDHeader)
if requestID == "" {
requestID = req.Header.Get(headerName)
}

if requestID == "" {
requestID = NewRequestID()
req.Header.Set(headerName, requestID)
Expand Down
94 changes: 94 additions & 0 deletions logging/context_test.go
@@ -0,0 +1,94 @@
package logging

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func newRequest(t *testing.T) *http.Request {
r, err := http.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
require.NoError(t, err)
return r
}

func TestRequestID(t *testing.T) {
requestWithID := newRequest(t)
requestWithID.Header.Set("X-Request-Id", "reqID")
requestWithoutID := newRequest(t)
requestWithEmptyHeader := newRequest(t)
requestWithEmptyHeader.Header.Set("X-Request-Id", "")
requestWithSmallstepID := newRequest(t)
requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID")

tests := []struct {
name string
headerName string
handler http.HandlerFunc
req *http.Request
}{
{
name: "default-request-id",
headerName: defaultTraceIDHeader,
handler: func(_ http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
assert.Equal(t, "reqID", r.Header.Get("X-Request-Id"))
reqID, ok := GetRequestID(r.Context())
if assert.True(t, ok) {
assert.Equal(t, "reqID", reqID)
}
},
req: requestWithID,
},
{
name: "no-request-id",
headerName: "X-Request-Id",
handler: func(_ http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
value := r.Header.Get("X-Request-Id")
assert.NotEmpty(t, value)
reqID, ok := GetRequestID(r.Context())
if assert.True(t, ok) {
assert.Equal(t, value, reqID)
}
},
req: requestWithoutID,
},
{
name: "empty-header-name",
headerName: "",
handler: func(_ http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Request-Id"))
value := r.Header.Get("X-Smallstep-Id")
assert.NotEmpty(t, value)
reqID, ok := GetRequestID(r.Context())
if assert.True(t, ok) {
assert.Equal(t, value, reqID)
}
},
req: requestWithEmptyHeader,
},
{
name: "fallback-header-name",
headerName: defaultTraceIDHeader,
handler: func(_ http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Request-Id"))
assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id"))
reqID, ok := GetRequestID(r.Context())
if assert.True(t, ok) {
assert.Equal(t, "smallstepID", reqID)
}
},
req: requestWithSmallstepID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := RequestID(tt.headerName)
h(tt.handler).ServeHTTP(httptest.NewRecorder(), tt.req)
})
}
}

0 comments on commit 4213a19

Please sign in to comment.