Skip to content

Commit

Permalink
Merge pull request #1743 from smallstep/herman/improve-request-id
Browse files Browse the repository at this point in the history
Improve end-to-end request ID propagation
  • Loading branch information
hslatman committed Mar 4, 2024
2 parents 0d5c692 + 2a47644 commit 10aa48c
Show file tree
Hide file tree
Showing 21 changed files with 1,060 additions and 688 deletions.
8 changes: 2 additions & 6 deletions api/api_test.go
Expand Up @@ -884,16 +884,12 @@ func Test_Sign(t *testing.T) {
CsrPEM: CertificateRequest{csr},
OTT: "foobarzar",
})
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
invalid, err := json.Marshal(SignRequest{
CsrPEM: CertificateRequest{csr},
OTT: "",
})
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)

expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
Expand Down
7 changes: 3 additions & 4 deletions authority/provisioner/webhook.go
Expand Up @@ -15,7 +15,7 @@ import (
"time"

"github.com/pkg/errors"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/templates"
"github.com/smallstep/certificates/webhook"
"go.step.sm/linkedca"
Expand Down Expand Up @@ -171,9 +171,8 @@ retry:
return nil, err
}

requestID, ok := logging.GetRequestID(ctx)
if ok {
req.Header.Set("X-Request-ID", requestID)
if requestID, ok := requestid.FromContext(ctx); ok {
req.Header.Set("X-Request-Id", requestID)
}

secret, err := base64.StdEncoding.DecodeString(w.Secret)
Expand Down
39 changes: 21 additions & 18 deletions authority/provisioner/webhook_test.go
Expand Up @@ -17,13 +17,15 @@ import (
"testing"
"time"

"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/webhook"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"go.step.sm/crypto/pemutil"
"go.step.sm/crypto/x509util"
"go.step.sm/linkedca"

"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/webhook"
)

func TestWebhookController_isCertTypeOK(t *testing.T) {
Expand Down Expand Up @@ -101,10 +103,11 @@ func TestWebhookController_isCertTypeOK(t *testing.T) {
}
}

// withRequestID is a helper that calls into [logging.WithRequestID] and returns
// a new context with the requestID added to the provided context.
func withRequestID(ctx context.Context, requestID string) context.Context {
return logging.WithRequestID(ctx, requestID)
// withRequestID is a helper that calls into [requestid.NewContext] and returns
// a new context with the requestID added.
func withRequestID(t *testing.T, ctx context.Context, requestID string) context.Context {
t.Helper()
return requestid.NewContext(ctx, requestID)
}

func TestWebhookController_Enrich(t *testing.T) {
Expand Down Expand Up @@ -138,7 +141,7 @@ func TestWebhookController_Enrich(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
TemplateData: x509util.TemplateData{},
},
ctx: withRequestID(context.Background(), "reqID"),
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}},
expectErr: false,
Expand All @@ -153,7 +156,7 @@ func TestWebhookController_Enrich(t *testing.T) {
},
TemplateData: x509util.TemplateData{},
},
ctx: withRequestID(context.Background(), "reqID"),
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{
{Allow: true, Data: map[string]any{"role": "bar"}},
Expand All @@ -177,7 +180,7 @@ func TestWebhookController_Enrich(t *testing.T) {
TemplateData: x509util.TemplateData{},
certType: linkedca.Webhook_X509,
},
ctx: withRequestID(context.Background(), "reqID"),
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{
{Allow: true, Data: map[string]any{"role": "bar"}},
Expand All @@ -197,7 +200,7 @@ func TestWebhookController_Enrich(t *testing.T) {
TemplateData: x509util.TemplateData{},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
},
ctx: withRequestID(context.Background(), "reqID"),
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}},
expectErr: false,
Expand All @@ -220,7 +223,7 @@ func TestWebhookController_Enrich(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}},
TemplateData: x509util.TemplateData{},
},
ctx: withRequestID(context.Background(), "reqID"),
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true,
Expand All @@ -235,7 +238,7 @@ func TestWebhookController_Enrich(t *testing.T) {
PublicKey: []byte("bad"),
})},
},
ctx: withRequestID(context.Background(), "reqID"),
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true,
Expand Down Expand Up @@ -296,7 +299,7 @@ func TestWebhookController_Authorize(t *testing.T) {
client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
},
ctx: withRequestID(context.Background(), "reqID"),
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true}},
expectErr: false,
Expand All @@ -307,7 +310,7 @@ func TestWebhookController_Authorize(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}},
certType: linkedca.Webhook_SSH,
},
ctx: withRequestID(context.Background(), "reqID"),
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: false,
Expand All @@ -318,7 +321,7 @@ func TestWebhookController_Authorize(t *testing.T) {
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)},
},
ctx: withRequestID(context.Background(), "reqID"),
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: true}},
expectErr: false,
Expand All @@ -339,7 +342,7 @@ func TestWebhookController_Authorize(t *testing.T) {
client: http.DefaultClient,
webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}},
},
ctx: withRequestID(context.Background(), "reqID"),
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true,
Expand All @@ -352,7 +355,7 @@ func TestWebhookController_Authorize(t *testing.T) {
PublicKey: []byte("bad"),
})},
},
ctx: withRequestID(context.Background(), "reqID"),
ctx: withRequestID(t, context.Background(), "reqID"),
req: &webhook.RequestBody{},
responses: []*webhook.ResponseBody{{Allow: false}},
expectErr: true,
Expand Down Expand Up @@ -568,7 +571,7 @@ func TestWebhook_Do(t *testing.T) {

ctx := context.Background()
if tc.requestID != "" {
ctx = withRequestID(context.Background(), tc.requestID)
ctx = withRequestID(t, ctx, tc.requestID)
}
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
Expand Down
3 changes: 3 additions & 0 deletions ca/acmeClient.go
Expand Up @@ -48,6 +48,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC
return nil, errors.Wrapf(err, "creating GET request %s failed", endpoint)
}
req.Header.Set("User-Agent", UserAgent)
enforceRequestID(req)
resp, err := ac.client.Do(req)
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", endpoint)
Expand Down Expand Up @@ -109,6 +110,7 @@ func (c *ACMEClient) GetNonce() (string, error) {
return "", errors.Wrapf(err, "creating GET request %s failed", c.dir.NewNonce)
}
req.Header.Set("User-Agent", UserAgent)
enforceRequestID(req)
resp, err := c.client.Do(req)
if err != nil {
return "", errors.Wrapf(err, "client GET %s failed", c.dir.NewNonce)
Expand Down Expand Up @@ -188,6 +190,7 @@ func (c *ACMEClient) post(payload []byte, url string, headerOps ...withHeaderOpt
}
req.Header.Set("Content-Type", "application/jose+json")
req.Header.Set("User-Agent", UserAgent)
enforceRequestID(req)
resp, err := c.client.Do(req)
if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", c.dir.NewOrder)
Expand Down
7 changes: 7 additions & 0 deletions ca/ca.go
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/internal/metrix"
"github.com/smallstep/certificates/internal/requestid"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/monitoring"
"github.com/smallstep/certificates/scep"
Expand Down Expand Up @@ -329,15 +330,21 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
}

// Add logger if configured
var legacyTraceHeader string
if len(cfg.Logger) > 0 {
logger, err := logging.New("ca", cfg.Logger)
if err != nil {
return nil, err
}
legacyTraceHeader = logger.GetTraceHeader()
handler = logger.Middleware(handler)
insecureHandler = logger.Middleware(insecureHandler)
}

// always use request ID middleware; traceHeader is provided for backwards compatibility (for now)
handler = requestid.New(legacyTraceHeader).Middleware(handler)
insecureHandler = requestid.New(legacyTraceHeader).Middleware(insecureHandler)

// Create context with all the necessary values.
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)

Expand Down
25 changes: 20 additions & 5 deletions ca/ca_test.go
Expand Up @@ -289,6 +289,9 @@ ZEp7knvU2psWRw==

if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body}
resp := &http.Response{
Body: body,
}
if rr.Code < http.StatusBadRequest {
var sign api.SignResponse
assert.FatalError(t, readJSON(body, &sign))
Expand Down Expand Up @@ -325,7 +328,7 @@ ZEp7knvU2psWRw==
assert.FatalError(t, err)
assert.Equals(t, intermediate, realIntermediate)
} else {
err := readError(body)
err := readError(resp)
if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error"))
}
Expand Down Expand Up @@ -369,6 +372,9 @@ func TestCAProvisioners(t *testing.T) {

if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body}
resp := &http.Response{
Body: body,
}
if rr.Code < http.StatusBadRequest {
var resp api.ProvisionersResponse

Expand All @@ -379,7 +385,7 @@ func TestCAProvisioners(t *testing.T) {
assert.FatalError(t, err)
assert.Equals(t, a, b)
} else {
err := readError(body)
err := readError(resp)
if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error"))
}
Expand Down Expand Up @@ -436,12 +442,15 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {

if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body}
resp := &http.Response{
Body: body,
}
if rr.Code < http.StatusBadRequest {
var ek api.ProvisionerKeyResponse
assert.FatalError(t, readJSON(body, &ek))
assert.Equals(t, ek.Key, tc.expectedKey)
} else {
err := readError(body)
err := readError(resp)
if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error"))
}
Expand Down Expand Up @@ -498,12 +507,15 @@ func TestCARoot(t *testing.T) {

if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body}
resp := &http.Response{
Body: body,
}
if rr.Code < http.StatusBadRequest {
var root api.RootResponse
assert.FatalError(t, readJSON(body, &root))
assert.Equals(t, root.RootPEM.Certificate, rootCrt)
} else {
err := readError(body)
err := readError(resp)
if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error"))
}
Expand Down Expand Up @@ -641,6 +653,9 @@ func TestCARenew(t *testing.T) {

if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body}
resp := &http.Response{
Body: body,
}
if rr.Code < http.StatusBadRequest {
var sign api.SignResponse
assert.FatalError(t, readJSON(body, &sign))
Expand Down Expand Up @@ -673,7 +688,7 @@ func TestCARenew(t *testing.T) {

assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions)
} else {
err := readError(body)
err := readError(resp)
if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error"))
}
Expand Down

0 comments on commit 10aa48c

Please sign in to comment.