Skip to content

Commit

Permalink
Decouple request ID middleware from logging middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
hslatman committed Feb 28, 2024
1 parent 535e2a9 commit 7e5f109
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 124 deletions.
7 changes: 3 additions & 4 deletions authority/provisioner/webhook.go
Expand Up @@ -15,7 +15,7 @@ import (
"time"

"github.com/pkg/errors"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/templates"
"github.com/smallstep/certificates/webhook"
"go.step.sm/linkedca"
Expand Down Expand Up @@ -171,9 +171,8 @@ retry:
return nil, err
}

requestID, ok := logging.GetRequestID(ctx)
if ok {
req.Header.Set("X-Request-ID", requestID)
if requestID, ok := requestid.FromContext(ctx); ok {
req.Header.Set("X-Request-Id", requestID)
}

secret, err := base64.StdEncoding.DecodeString(w.Secret)
Expand Down
8 changes: 4 additions & 4 deletions authority/provisioner/webhook_test.go
Expand Up @@ -17,7 +17,7 @@ import (
"testing"
"time"

"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/webhook"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -101,10 +101,10 @@ func TestWebhookController_isCertTypeOK(t *testing.T) {
}
}

// withRequestID is a helper that calls into [logging.WithRequestID] and returns
// a new context with the requestID added to the provided context.
// withRequestID is a helper that calls into [requestid.NewContext] and returns
// a new context with the requestID added.
func withRequestID(ctx context.Context, requestID string) context.Context {
return logging.WithRequestID(ctx, requestID)
return requestid.NewContext(ctx, requestID)
}

func TestWebhookController_Enrich(t *testing.T) {
Expand Down
7 changes: 7 additions & 0 deletions ca/ca.go
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/internal/metrix"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/monitoring"
"github.com/smallstep/certificates/scep"
Expand Down Expand Up @@ -329,15 +330,21 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
}

// Add logger if configured
var legacyTraceHeader string
if len(cfg.Logger) > 0 {
logger, err := logging.New("ca", cfg.Logger)
if err != nil {
return nil, err
}
legacyTraceHeader = logger.GetTraceHeader()
handler = logger.Middleware(handler)
insecureHandler = logger.Middleware(insecureHandler)
}

// always use request ID middleware; traceHeader is provided for backwards compatibility (for now)
handler = requestid.New(legacyTraceHeader).Middleware(handler)
insecureHandler = requestid.New(legacyTraceHeader).Middleware(insecureHandler)

// Create context with all the necessary values.
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)

Expand Down
27 changes: 15 additions & 12 deletions errs/errors_test.go
Expand Up @@ -2,8 +2,9 @@ package errs

import (
"fmt"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
)

func TestError_MarshalJSON(t *testing.T) {
Expand All @@ -27,13 +28,14 @@ func TestError_MarshalJSON(t *testing.T) {
Err: tt.fields.Err,
}
got, err := e.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Errorf("Error.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
if tt.wantErr {
assert.Error(t, err)
assert.Empty(t, got)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Error.MarshalJSON() = %s, want %s", got, tt.want)
}

assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
Expand All @@ -54,13 +56,14 @@ func TestError_UnmarshalJSON(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := new(Error)
if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
//nolint:govet // best option
if !reflect.DeepEqual(tt.expected, e) {
t.Errorf("Error.UnmarshalJSON() wants = %+v, got %+v", tt.expected, e)
err := e.UnmarshalJSON(tt.args.data)
if tt.wantErr {
assert.Error(t, err)
return
}

assert.NoError(t, err)
assert.Equal(t, tt.expected, e)
})
}
}
82 changes: 82 additions & 0 deletions internal/requestid/requestid.go
@@ -0,0 +1,82 @@
package requestid

import (
"context"
"net/http"

"github.com/rs/xid"
)

const (
// requestIDHeader is the header name used for propagating request IDs. If
// available in an HTTP request, it'll be used instead of the X-Smallstep-Id
// header. It'll always be used in response and set to the request ID.
requestIDHeader = "X-Request-Id"

// defaultTraceHeader is the default Smallstep tracing header that's currently
// in use. It is used as a fallback to retrieve a request ID from, if the
// "X-Request-Id" request header is not set.
defaultTraceHeader = "X-Smallstep-Id"
)

type Handler struct {
legacyTraceHeader string
}

// New creates a new request ID [handler]. It takes a trace header,
// which is used keep the legacy behavior intact, which relies on the
// X-Smallstep-Id header instead of X-Request-Id.
func New(legacyTraceHeader string) *Handler {
if legacyTraceHeader == "" {
legacyTraceHeader = defaultTraceHeader
}

return &Handler{legacyTraceHeader: legacyTraceHeader}
}

// Middleware wraps an [http.Handler] with request ID extraction
// from the X-Reqeust-Id header by default, or from the X-Smallstep-Id
// header if not set. If both are not set, a new request ID is generated.
// In all cases, the request ID is added to the request context, and
// set to be reflected in the response.
func (h *Handler) Middleware(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, req *http.Request) {
requestID := req.Header.Get(requestIDHeader)
if requestID == "" {
requestID = req.Header.Get(h.legacyTraceHeader)
}

if requestID == "" {
requestID = newRequestID()
req.Header.Set(h.legacyTraceHeader, requestID) // legacy behavior
}

// immediately set the request ID to be reflected in the response
w.Header().Set(requestIDHeader, requestID)

// continue down the handler chain
ctx := NewContext(req.Context(), requestID)
next.ServeHTTP(w, req.WithContext(ctx))
}
return http.HandlerFunc(fn)
}

// newRequestID creates a new request ID using github.com/rs/xid.
func newRequestID() string {
return xid.New().String()
}

type requestIDKey struct{}

// NewContext returns a new context with the given request ID added to the
// context.
func NewContext(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDKey{}, requestID)
}

// FromContext returns the request ID from the context if it exists and
// is not the empty value.
func FromContext(ctx context.Context) (string, bool) {
v, ok := ctx.Value(requestIDKey{}).(string)
return v, ok && v != ""
}
53 changes: 28 additions & 25 deletions logging/context_test.go → internal/requestid/requestid_test.go
@@ -1,4 +1,4 @@
package logging
package requestid

import (
"net/http"
Expand All @@ -10,33 +10,33 @@ import (
)

func newRequest(t *testing.T) *http.Request {
t.Helper()
r, err := http.NewRequest(http.MethodGet, "https://example.com", http.NoBody)
require.NoError(t, err)
return r
}

func TestRequestID(t *testing.T) {
func Test_Middleware(t *testing.T) {
requestWithID := newRequest(t)
requestWithID.Header.Set("X-Request-Id", "reqID")
requestWithoutID := newRequest(t)
requestWithEmptyHeader := newRequest(t)
requestWithEmptyHeader.Header.Set("X-Request-Id", "")
requestWithSmallstepID := newRequest(t)
requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID")

tests := []struct {
name string
headerName string
handler http.HandlerFunc
req *http.Request
name string
traceHeader string
next http.HandlerFunc
req *http.Request
}{
{
name: "default-request-id",
headerName: defaultTraceIDHeader,
handler: func(w http.ResponseWriter, r *http.Request) {
name: "default-request-id",
traceHeader: defaultTraceHeader,
next: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
assert.Equal(t, "reqID", r.Header.Get("X-Request-Id"))
reqID, ok := GetRequestID(r.Context())
reqID, ok := FromContext(r.Context())
if assert.True(t, ok) {
assert.Equal(t, "reqID", reqID)
}
Expand All @@ -45,13 +45,13 @@ func TestRequestID(t *testing.T) {
req: requestWithID,
},
{
name: "no-request-id",
headerName: "X-Request-Id",
handler: func(w http.ResponseWriter, r *http.Request) {
name: "no-request-id",
traceHeader: "X-Request-Id",
next: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
value := r.Header.Get("X-Request-Id")
assert.NotEmpty(t, value)
reqID, ok := GetRequestID(r.Context())
reqID, ok := FromContext(r.Context())
if assert.True(t, ok) {
assert.Equal(t, value, reqID)
}
Expand All @@ -60,13 +60,13 @@ func TestRequestID(t *testing.T) {
req: requestWithoutID,
},
{
name: "empty-header-name",
headerName: "",
handler: func(w http.ResponseWriter, r *http.Request) {
name: "empty-header",
traceHeader: "",
next: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Request-Id"))
value := r.Header.Get("X-Smallstep-Id")
assert.NotEmpty(t, value)
reqID, ok := GetRequestID(r.Context())
reqID, ok := FromContext(r.Context())
if assert.True(t, ok) {
assert.Equal(t, value, reqID)
}
Expand All @@ -75,12 +75,12 @@ func TestRequestID(t *testing.T) {
req: requestWithEmptyHeader,
},
{
name: "fallback-header-name",
headerName: defaultTraceIDHeader,
handler: func(w http.ResponseWriter, r *http.Request) {
name: "fallback-header-name",
traceHeader: defaultTraceHeader,
next: func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Request-Id"))
assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id"))
reqID, ok := GetRequestID(r.Context())
reqID, ok := FromContext(r.Context())
if assert.True(t, ok) {
assert.Equal(t, "smallstepID", reqID)
}
Expand All @@ -91,8 +91,11 @@ func TestRequestID(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := RequestID(tt.headerName)
h(tt.handler).ServeHTTP(httptest.NewRecorder(), tt.req)
handler := New(tt.traceHeader).Middleware(tt.next)

w := httptest.NewRecorder()
handler.ServeHTTP(w, tt.req)
assert.NotEmpty(t, w.Header().Get("X-Request-Id"))
})
}
}

0 comments on commit 7e5f109

Please sign in to comment.