Skip to content

Commit

Permalink
Add support for user provider X-Request-Id header value
Browse files Browse the repository at this point in the history
  • Loading branch information
hslatman committed Feb 28, 2024
1 parent cf8a501 commit 5c2572c
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 8 deletions.
16 changes: 13 additions & 3 deletions ca/client.go
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/client"
"github.com/smallstep/certificates/ca/identity"
"github.com/smallstep/certificates/errs"
"go.step.sm/cli-utils/step"
Expand Down Expand Up @@ -105,10 +106,19 @@ func (c *uaClient) PostWithContext(ctx context.Context, u, contentType string, b
const requestIDHeader = "X-Request-Id"

// enforceRequestID checks if the X-Request-Id HTTP header is filled. If it's
// empty, it'll generate a new request ID and set the header.
// empty, the context is searched for a request ID. If that's also empty, a new
// request ID is generated.
func enforceRequestID(r *http.Request) {
if r.Header.Get(requestIDHeader) == "" {
r.Header.Set(requestIDHeader, xid.New().String())
requestID := r.Header.Get(requestIDHeader)
if requestID == "" {
if reqID, ok := client.GetRequestID(r.Context()); ok && reqID != "" {
// TODO(hs): ensure the request ID from the context is fresh, and thus hasn't been
// used before by the client (unless it's a retry for the same request)?
requestID = reqID
} else {
requestID = xid.New().String()
}
r.Header.Set(requestIDHeader, requestID)
}
}

Expand Down
17 changes: 17 additions & 0 deletions ca/client/requestid.go
@@ -0,0 +1,17 @@
package client

import "context"

type requestIDKey struct{}

// WithRequestID returns a new context with the given requestID added to the
// context.
func WithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDKey{}, requestID)
}

// GetRequestID returns the request id from the context if it exists.
func GetRequestID(ctx context.Context) (string, bool) {
v, ok := ctx.Value(requestIDKey{}).(string)
return v, ok
}
26 changes: 21 additions & 5 deletions test/e2e/requestid_test.go
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/ca"
"github.com/smallstep/certificates/ca/client"
"github.com/smallstep/certificates/errs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -57,8 +58,8 @@ func TestXxx(t *testing.T) {
c, err := ca.New(cfg)
require.NoError(t, err)

// instantiate a client for the CA
client, err := ca.NewClient(
// instantiate a client for the CA running at the random address
caClient, err := ca.NewClient(
fmt.Sprintf("https://%s", randomAddress),
ca.WithRootFile(rootFilepath),
)
Expand All @@ -75,12 +76,12 @@ func TestXxx(t *testing.T) {

// require OK health response as the baseline
ctx := context.Background()
healthResponse, err := client.HealthWithContext(ctx)
healthResponse, err := caClient.HealthWithContext(ctx)
assert.NoError(t, err)
require.Equal(t, "ok", healthResponse.Status)
assert.Equal(t, "ok", healthResponse.Status)

// expect an error when retrieving an invalid root
rootResponse, err := client.RootWithContext(ctx, "invalid")
rootResponse, err := caClient.RootWithContext(ctx, "invalid")
if assert.Error(t, err) {
apiErr := &errs.Error{}
if assert.ErrorAs(t, err, &apiErr) {
Expand All @@ -94,6 +95,21 @@ func TestXxx(t *testing.T) {
}
assert.Nil(t, rootResponse)

// expect an error when retrieving an invalid root and provided request ID
rootResponse, err = caClient.RootWithContext(client.WithRequestID(ctx, "reqID"), "invalid")
if assert.Error(t, err) {
apiErr := &errs.Error{}
if assert.ErrorAs(t, err, &apiErr) {
assert.Equal(t, 404, apiErr.StatusCode())
assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error())
assert.Equal(t, "reqID", apiErr.RequestID)

// TODO: include the below error in the JSON? It's currently only output to the CA logs
//assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg)
}
}
assert.Nil(t, rootResponse)

// done testing; stop and wait for the server to quit
err = c.Stop()
require.NoError(t, err)
Expand Down

0 comments on commit 5c2572c

Please sign in to comment.