From f3229d3e3ca2bf226fae706aac4443ad0926497c Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 19 Sep 2023 13:48:43 +0200 Subject: [PATCH 01/31] Propagate (original) request ID to webhook requests Technically the webhook request is a new request, so maybe the `X-Request-ID` should not be set to the value of the original request? But then the original request ID should be propageted in the webhook request body, or using a different header. The way the request ID is used in this functionality is actually more like a tracing ID, so that may be an option too. --- authority/provisioner/webhook.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index 407b84d83..1097c0039 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -15,6 +15,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/webhook" "go.step.sm/linkedca" @@ -169,6 +170,11 @@ retry: return nil, err } + requestID, ok := logging.GetRequestID(ctx) + if ok { + req.Header.Set("X-Request-ID", requestID) + } + secret, err := base64.StdEncoding.DecodeString(w.Secret) if err != nil { return nil, err From b2301ea12731a35f3795505a8b5b2f3ec736e83f Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 19 Sep 2023 15:39:54 +0200 Subject: [PATCH 02/31] Remove the webhook `Do` method --- authority/provisioner/webhook.go | 20 +++++++++++--------- authority/provisioner/webhook_test.go | 19 ++++++++++++++++--- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index 1097c0039..14d357f10 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -56,7 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { if !wc.isCertTypeOK(wh) { continue } - resp, err := wh.Do(wc.client, req, wc.TemplateData) + // TODO(hs): propagate context from above + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) if err != nil { return err } @@ -88,7 +92,12 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { if !wc.isCertTypeOK(wh) { continue } - resp, err := wh.Do(wc.client, req, wc.TemplateData) + + // TODO(hs): propagate context from above + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) if err != nil { return err } @@ -124,13 +133,6 @@ type Webhook struct { } `json:"-"` } -func (w *Webhook) Do(client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - return w.DoWithContext(ctx, client, reqBody, data) -} - func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL) if err != nil { diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 656d75d86..a61da39cd 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/hmac" "crypto/sha256" "crypto/tls" @@ -13,6 +14,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/pkg/errors" "github.com/smallstep/assert" @@ -522,7 +524,11 @@ func TestWebhook_Do(t *testing.T) { reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) assert.FatalError(t, err) - got, err := tc.webhook.Do(http.DefaultClient, reqBody, tc.dataArg) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg) if tc.expectErr != nil { assert.Equals(t, tc.expectErr.Error(), err.Error()) return @@ -553,11 +559,18 @@ func TestWebhook_Do(t *testing.T) { } reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) assert.FatalError(t, err) - _, err = wh.Do(client, reqBody, nil) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + _, err = wh.DoWithContext(ctx, client, reqBody, nil) assert.FatalError(t, err) + ctx, cancel = context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + wh.DisableTLSClientAuth = true - _, err = wh.Do(client, reqBody, nil) + _, err = wh.DoWithContext(ctx, client, reqBody, nil) assert.Error(t, err) }) } From 4e06bdbc514826ee65983e7f7d5f201f18023130 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 19 Sep 2023 16:17:36 +0200 Subject: [PATCH 03/31] Add `SignWithContext` method to authority and mocks --- acme/api/revoke_test.go | 4 ++++ acme/common.go | 1 + acme/order_test.go | 10 ++++++++++ api/api.go | 1 + api/api_test.go | 8 ++++++++ authority/provisioner/webhook.go | 19 +++++++++---------- authority/provisioner/webhook_test.go | 4 ++-- authority/ssh.go | 14 +++++++------- authority/tls.go | 23 ++++++++++++++++------- authority/webhook.go | 10 +++++++--- authority/webhook_test.go | 6 ++++-- 11 files changed, 69 insertions(+), 31 deletions(-) diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index a225aa194..e8edcc418 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -285,6 +285,10 @@ func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...prov return nil, nil } +func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { + return nil, nil +} + func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error { if m.MockAreSANsallowed != nil { return m.MockAreSANsallowed(ctx, sans) diff --git a/acme/common.go b/acme/common.go index 7d58305fa..afab13b20 100644 --- a/acme/common.go +++ b/acme/common.go @@ -22,6 +22,7 @@ var clock Clock // CertificateAuthority is the interface implemented by a CA authority. type CertificateAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) AreSANsAllowed(ctx context.Context, sans []string) error IsRevoked(sn string) (bool, error) Revoke(context.Context, *authority.RevokeOptions) error diff --git a/acme/order_test.go b/acme/order_test.go index 2851bb190..3fa99b9b8 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -272,6 +272,7 @@ func TestOrder_UpdateStatus(t *testing.T) { type mockSignAuth struct { sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) areSANsAllowed func(ctx context.Context, sans []string) error loadProvisionerByName func(string) (provisioner.Interface, error) ret1, ret2 interface{} @@ -287,6 +288,15 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } +func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + if m.signWithContext != nil { + return m.signWithContext(ctx, csr, signOpts, extraOpts...) + } else if m.err != nil { + return nil, m.err + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error { if m.areSANsAllowed != nil { return m.areSANsAllowed(ctx, sans) diff --git a/api/api.go b/api/api.go index c9820351d..2d6c0bf76 100644 --- a/api/api.go +++ b/api/api.go @@ -42,6 +42,7 @@ type Authority interface { GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error) RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) diff --git a/api/api_test.go b/api/api_test.go index d96015f94..90acf759f 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -193,6 +193,7 @@ type mockAuthority struct { getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) @@ -261,6 +262,13 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignO return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } +func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + if m.signWithContext != nil { + return m.signWithContext(ctx, cr, opts, signOpts...) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { if m.renew != nil { return m.renew(cert) diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index 14d357f10..1cc2047c3 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -37,7 +37,7 @@ type WebhookController struct { // Enrich fetches data from remote servers and adds returned data to the // templateData -func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { +func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBody) error { if wc == nil { return nil } @@ -56,11 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { if !wc.isCertTypeOK(wh) { continue } - // TODO(hs): propagate context from above - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) + whCtx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() //nolint:gocritic // every request canceled with its own timeout + + resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData) if err != nil { return err } @@ -73,7 +73,7 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { } // Authorize checks that all remote servers allow the request -func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { +func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.RequestBody) error { if wc == nil { return nil } @@ -93,11 +93,10 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { continue } - // TODO(hs): propagate context from above - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() + whCtx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() //nolint:gocritic // every request canceled with its own timeout - resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) + resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData) if err != nil { return err } diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index a61da39cd..cc79a09b7 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -242,7 +242,7 @@ func TestWebhookController_Enrich(t *testing.T) { wh.URL = ts.URL } - err := test.ctl.Enrich(test.req) + err := test.ctl.Enrich(context.Background(), test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } @@ -352,7 +352,7 @@ func TestWebhookController_Authorize(t *testing.T) { wh.URL = ts.URL } - err := test.ctl.Authorize(test.req) + err := test.ctl.Authorize(context.Background(), test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } diff --git a/authority/ssh.go b/authority/ssh.go index f9371d60e..688bfd762 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -146,7 +146,7 @@ func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (* } // SignSSH creates a signed SSH certificate with the given public key and options. -func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { +func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { var ( certOptions []sshutil.Option mods []provisioner.SSHCertModifier @@ -205,7 +205,7 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision } // Call enriching webhooks - if err := callEnrichingWebhooksSSH(webhookCtl, cr); err != nil { + if err := callEnrichingWebhooksSSH(ctx, webhookCtl, cr); err != nil { return nil, errs.ApplyOptions( errs.ForbiddenErr(err, err.Error()), errs.WithKeyVal("signOptions", signOpts), @@ -277,7 +277,7 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision } // Send certificate to webhooks for authorization - if err := callAuthorizingWebhooksSSH(webhookCtl, certificate, certTpl); err != nil { + if err := callAuthorizingWebhooksSSH(ctx, webhookCtl, certificate, certTpl); err != nil { return nil, errs.ApplyOptions( errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"), ) @@ -653,7 +653,7 @@ func (a *Authority) getAddUserCommand(principal string) string { return strings.ReplaceAll(cmd, "", principal) } -func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.CertificateRequest) error { +func callEnrichingWebhooksSSH(ctx context.Context, webhookCtl webhookController, cr sshutil.CertificateRequest) error { if webhookCtl == nil { return nil } @@ -663,10 +663,10 @@ func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.Certifica if err != nil { return err } - return webhookCtl.Enrich(whEnrichReq) + return webhookCtl.Enrich(ctx, whEnrichReq) } -func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error { +func callAuthorizingWebhooksSSH(ctx context.Context, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error { if webhookCtl == nil { return nil } @@ -676,5 +676,5 @@ func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Cert if err != nil { return err } - return webhookCtl.Authorize(whAuthBody) + return webhookCtl.Authorize(ctx, whAuthBody) } diff --git a/authority/tls.go b/authority/tls.go index 6e9679209..900b1ff85 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -91,8 +91,17 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc { } } -// Sign creates a signed certificate from a certificate signing request. +// Sign creates a signed certificate from a certificate signing request. It +// creates a new context.Context, and calls into SignWithContext. +// +// Deprecated: Use authority.SignWithContext with an actual context.Context. func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + return a.SignWithContext(context.Background(), csr, signOpts, extraOpts...) +} + +// SignWithContext creates a signed certificate from a certificate signing request, +// taking the provided context.Context. +func (a *Authority) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { var ( certOptions []x509util.Option certValidators []provisioner.CertificateValidator @@ -163,7 +172,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } } - if err := callEnrichingWebhooksX509(webhookCtl, attData, csr); err != nil { + if err := callEnrichingWebhooksX509(ctx, webhookCtl, attData, csr); err != nil { return nil, errs.ApplyOptions( errs.ForbiddenErr(err, err.Error()), errs.WithKeyVal("csr", csr), @@ -256,7 +265,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } // Send certificate to webhooks for authorization - if err := callAuthorizingWebhooksX509(webhookCtl, cert, leaf, attData); err != nil { + if err := callAuthorizingWebhooksX509(ctx, webhookCtl, cert, leaf, attData); err != nil { return nil, errs.ApplyOptions( errs.ForbiddenErr(err, "error creating certificate"), opts..., @@ -952,7 +961,7 @@ func templatingError(err error) error { return errors.Wrap(cause, "error applying certificate template") } -func callEnrichingWebhooksX509(webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) error { +func callEnrichingWebhooksX509(ctx context.Context, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) error { if webhookCtl == nil { return nil } @@ -969,10 +978,10 @@ func callEnrichingWebhooksX509(webhookCtl webhookController, attData *provisione if err != nil { return err } - return webhookCtl.Enrich(whEnrichReq) + return webhookCtl.Enrich(ctx, whEnrichReq) } -func callAuthorizingWebhooksX509(webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) error { +func callAuthorizingWebhooksX509(ctx context.Context, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) error { if webhookCtl == nil { return nil } @@ -989,5 +998,5 @@ func callAuthorizingWebhooksX509(webhookCtl webhookController, cert *x509util.Ce if err != nil { return err } - return webhookCtl.Authorize(whAuthBody) + return webhookCtl.Authorize(ctx, whAuthBody) } diff --git a/authority/webhook.go b/authority/webhook.go index d887e0775..29e3e6c31 100644 --- a/authority/webhook.go +++ b/authority/webhook.go @@ -1,8 +1,12 @@ package authority -import "github.com/smallstep/certificates/webhook" +import ( + "context" + + "github.com/smallstep/certificates/webhook" +) type webhookController interface { - Enrich(*webhook.RequestBody) error - Authorize(*webhook.RequestBody) error + Enrich(context.Context, *webhook.RequestBody) error + Authorize(context.Context, *webhook.RequestBody) error } diff --git a/authority/webhook_test.go b/authority/webhook_test.go index 0e713af7d..75b59f63f 100644 --- a/authority/webhook_test.go +++ b/authority/webhook_test.go @@ -1,6 +1,8 @@ package authority import ( + "context" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/webhook" ) @@ -14,7 +16,7 @@ type mockWebhookController struct { var _ webhookController = &mockWebhookController{} -func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error { +func (wc *mockWebhookController) Enrich(context.Context, *webhook.RequestBody) error { for key, data := range wc.respData { wc.templateData.SetWebhook(key, data) } @@ -22,6 +24,6 @@ func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error { return wc.enrichErr } -func (wc *mockWebhookController) Authorize(*webhook.RequestBody) error { +func (wc *mockWebhookController) Authorize(context.Context, *webhook.RequestBody) error { return wc.authorizeErr } From 9e3807eaa3096d633e6f28437387fef4256d36d7 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 19 Sep 2023 16:34:29 +0200 Subject: [PATCH 04/31] Use `SignWithContext` in the critical paths --- acme/order.go | 2 +- api/sign.go | 2 +- api/ssh.go | 2 +- scep/authority.go | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/acme/order.go b/acme/order.go index 8dfcf97a6..5a86c2c8a 100644 --- a/acme/order.go +++ b/acme/order.go @@ -263,7 +263,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques signOps = append(signOps, extraOptions...) // Sign a new certificate. - certChain, err := auth.Sign(csr, provisioner.SignOptions{ + certChain, err := auth.SignWithContext(ctx, csr, provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(o.NotBefore), NotAfter: provisioner.NewTimeDuration(o.NotAfter), }, signOps...) diff --git a/api/sign.go b/api/sign.go index c0c83ce21..26b3c396f 100644 --- a/api/sign.go +++ b/api/sign.go @@ -78,7 +78,7 @@ func Sign(w http.ResponseWriter, r *http.Request) { return } - certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) + certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) return diff --git a/api/ssh.go b/api/ssh.go index fbaa8c5a0..a07dab294 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -330,7 +330,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) { NotAfter: time.Unix(int64(cert.ValidBefore), 0), }) - certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...) + certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) return diff --git a/scep/authority.go b/scep/authority.go index 23c288133..a7333aa75 100644 --- a/scep/authority.go +++ b/scep/authority.go @@ -65,6 +65,7 @@ type AuthorityOptions struct { // SignAuthority is the interface for a signing authority type SignAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) LoadProvisionerByName(string) (provisioner.Interface, error) } @@ -296,7 +297,7 @@ func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, m } signOps = append(signOps, templateOptions) - certChain, err := a.signAuth.Sign(csr, opts, signOps...) + certChain, err := a.signAuth.SignWithContext(ctx, csr, opts, signOps...) if err != nil { return nil, fmt.Errorf("error generating certificate for order: %w", err) } From 4ef093dc4b478c89b17ce761e0260a99265fa39c Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 19 Sep 2023 16:55:59 +0200 Subject: [PATCH 05/31] Fix broken tests relying on `Sign` in mocks --- acme/order_test.go | 18 +++++++++--------- api/ssh_test.go | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/acme/order_test.go b/acme/order_test.go index 3fa99b9b8..17060f11e 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -588,7 +588,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return nil, errors.New("force") }, @@ -638,7 +638,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, @@ -695,7 +695,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, @@ -780,7 +780,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, @@ -873,7 +873,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, @@ -983,7 +983,7 @@ func TestOrder_Finalize(t *testing.T) { // using the mocking functions as a wrapper for actual test helpers generated per test case or per // function that's tested. ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, @@ -1054,7 +1054,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, @@ -1118,7 +1118,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, @@ -1185,7 +1185,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, diff --git a/api/ssh_test.go b/api/ssh_test.go index 57dd6775e..2b90dc12e 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -325,7 +325,7 @@ func Test_SSHSign(t *testing.T) { signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { return tt.addUserCert, tt.addUserErr }, - sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { return tt.tlsSignCerts, tt.tlsSignErr }, }) From c2dfe595f17252e4eba6cef2f9347ba9ae2006b7 Mon Sep 17 00:00:00 2001 From: Anton Patsev Date: Sat, 24 Feb 2024 11:50:30 +0600 Subject: [PATCH 06/31] =?UTF-8?q?=D0=A1orrection=20of=20spelling=20errors?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/docker/renewer/entrypoint.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/docker/renewer/entrypoint.sh b/examples/docker/renewer/entrypoint.sh index dc84dcbfa..545f7fdab 100755 --- a/examples/docker/renewer/entrypoint.sh +++ b/examples/docker/renewer/entrypoint.sh @@ -7,12 +7,12 @@ sleep 5 rm -f /var/local/step/root_ca.crt rm -f /var/local/step/site.crt /var/local/step/site.key -# Donwload the root certificate +# Download the root certificate step ca root /var/local/step/root_ca.crt # Get token STEP_TOKEN=$(step ca token $COMMON_NAME) -# Donwload the root certificate +# Download the root certificate step ca certificate --token $STEP_TOKEN $COMMON_NAME /var/local/step/site.crt /var/local/step/site.key -exec "$@" \ No newline at end of file +exec "$@" From fa941dc96724f5dd88ae23f875e89f5ac16f0ec8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 15:49:17 +0000 Subject: [PATCH 07/31] Bump github.com/googleapis/gax-go/v2 from 2.12.0 to 2.12.2 Bumps [github.com/googleapis/gax-go/v2](https://github.com/googleapis/gax-go) from 2.12.0 to 2.12.2. - [Release notes](https://github.com/googleapis/gax-go/releases) - [Commits](https://github.com/googleapis/gax-go/compare/v2.12.0...v2.12.2) --- updated-dependencies: - dependency-name: github.com/googleapis/gax-go/v2 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 11d9d775b..1f63781f3 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/google/go-cmp v0.6.0 github.com/google/go-tpm v0.9.0 github.com/google/uuid v1.6.0 - github.com/googleapis/gax-go/v2 v2.12.0 + github.com/googleapis/gax-go/v2 v2.12.2 github.com/hashicorp/vault/api v1.12.0 github.com/hashicorp/vault/api/auth/approle v0.6.0 github.com/hashicorp/vault/api/auth/kubernetes v0.6.0 @@ -162,7 +162,7 @@ require ( golang.org/x/time v0.5.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto v0.0.0-20240125205218-1f4bbc51befe // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240125205218-1f4bbc51befe // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240205150955-31a09d347014 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240205150955-31a09d347014 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index a470c1eef..4348efebb 100644 --- a/go.sum +++ b/go.sum @@ -227,8 +227,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= -github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas= -github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU= +github.com/googleapis/gax-go/v2 v2.12.2 h1:mhN09QQW1jEWeMF74zGR81R30z4VJzjZsfkUhuHF+DA= +github.com/googleapis/gax-go/v2 v2.12.2/go.mod h1:61M8vcyyXR2kqKFxKrfA22jaA8JGF7Dc8App1U3H6jc= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -652,8 +652,8 @@ google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98 google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20240125205218-1f4bbc51befe h1:USL2DhxfgRchafRvt/wYyyQNzwgL7ZiURcozOE/Pkvo= google.golang.org/genproto v0.0.0-20240125205218-1f4bbc51befe/go.mod h1:cc8bqMqtv9gMOr0zHg2Vzff5ULhhL2IXP4sbcn32Dro= -google.golang.org/genproto/googleapis/api v0.0.0-20240125205218-1f4bbc51befe h1:0poefMBYvYbs7g5UkjS6HcxBPaTRAmznle9jnxYoAI8= -google.golang.org/genproto/googleapis/api v0.0.0-20240125205218-1f4bbc51befe/go.mod h1:4jWUdICTdgc3Ibxmr8nAJiiLHwQBY0UI0XZcEMaFKaA= +google.golang.org/genproto/googleapis/api v0.0.0-20240205150955-31a09d347014 h1:x9PwdEgd11LgK+orcck69WVRo7DezSO4VUMPI4xpc8A= +google.golang.org/genproto/googleapis/api v0.0.0-20240205150955-31a09d347014/go.mod h1:rbHMSEDyoYX62nRVLOCc4Qt1HbsdytAYoVwgjiOhF3I= google.golang.org/genproto/googleapis/rpc v0.0.0-20240205150955-31a09d347014 h1:FSL3lRCkhaPFxqi0s9o+V4UI2WTzAVOvkgbd4kVV4Wg= google.golang.org/genproto/googleapis/rpc v0.0.0-20240205150955-31a09d347014/go.mod h1:SaPjaZGWb0lPqs6Ittu0spdfrOArqji4ZdeP5IC/9N4= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= From 0b196b0b81b8a5ea3887c3ca8832f50a0aa79a42 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 15:49:39 +0000 Subject: [PATCH 08/31] Bump github.com/fxamacker/cbor/v2 from 2.5.0 to 2.6.0 Bumps [github.com/fxamacker/cbor/v2](https://github.com/fxamacker/cbor) from 2.5.0 to 2.6.0. - [Release notes](https://github.com/fxamacker/cbor/releases) - [Commits](https://github.com/fxamacker/cbor/compare/v2.5.0...v2.6.0) --- updated-dependencies: - dependency-name: github.com/fxamacker/cbor/v2 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 11d9d775b..0c8c1588d 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/Masterminds/sprig/v3 v3.2.3 github.com/dgraph-io/badger v1.6.2 github.com/dgraph-io/badger/v2 v2.2007.4 - github.com/fxamacker/cbor/v2 v2.5.0 + github.com/fxamacker/cbor/v2 v2.6.0 github.com/go-chi/chi/v5 v5.0.11 github.com/go-jose/go-jose/v3 v3.0.1 github.com/golang/mock v1.6.0 diff --git a/go.sum b/go.sum index a470c1eef..07d37f43c 100644 --- a/go.sum +++ b/go.sum @@ -134,8 +134,8 @@ github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE= -github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= +github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= +github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/go-chi/chi/v5 v5.0.11 h1:BnpYbFZ3T3S1WMpD79r7R5ThWX40TaFB7L31Y8xqSwA= github.com/go-chi/chi/v5 v5.0.11/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= From e4bbe8970edaba4856cb271d74716631b01c1277 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 Feb 2024 09:59:58 +0000 Subject: [PATCH 09/31] Bump google.golang.org/grpc from 1.61.0 to 1.62.0 Bumps [google.golang.org/grpc](https://github.com/grpc/grpc-go) from 1.61.0 to 1.62.0. - [Release notes](https://github.com/grpc/grpc-go/releases) - [Commits](https://github.com/grpc/grpc-go/compare/v1.61.0...v1.62.0) --- updated-dependencies: - dependency-name: google.golang.org/grpc dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- go.mod | 4 ++-- go.sum | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index 1f63781f3..5f59fc1b0 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,7 @@ require ( golang.org/x/exp v0.0.0-20230310171629-522b1b587ee0 golang.org/x/net v0.21.0 google.golang.org/api v0.165.0 - google.golang.org/grpc v1.61.0 + google.golang.org/grpc v1.62.0 google.golang.org/protobuf v1.32.0 ) @@ -93,7 +93,7 @@ require ( github.com/go-piv/piv-go v1.11.0 // indirect github.com/go-sql-driver/mysql v1.7.1 // indirect github.com/golang-jwt/jwt/v5 v5.2.0 // indirect - github.com/golang/glog v1.1.2 // indirect + github.com/golang/glog v1.2.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.4 // indirect diff --git a/go.sum b/go.sum index 4348efebb..ee123dc96 100644 --- a/go.sum +++ b/go.sum @@ -94,7 +94,7 @@ github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/cncf/xds/go v0.0.0-20231109132714-523115ebc101 h1:7To3pQ+pZo0i3dsWEbinPNFs5gPSBOsJtx3wTT94VBY= +github.com/cncf/xds/go v0.0.0-20231128003011-0fa0005c9caa h1:jQCWAUqqlij9Pgj2i/PB79y4KOPYVyFYdROxgaCwdTQ= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= @@ -128,7 +128,7 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/envoyproxy/protoc-gen-validate v1.0.2 h1:QkIBuU5k+x7/QXPvPPnWXWlCdaBFApVqftFV6k087DA= +github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -168,8 +168,8 @@ github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/glog v1.1.2 h1:DVjP2PbBOzHyzA+dn3WhHIq4NdVu3Q+pvivFICf/7fo= -github.com/golang/glog v1.1.2/go.mod h1:zR+okUeTbrL6EL3xHUDxZuEtGv04p5shwip1+mL/rLQ= +github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68= +github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -661,8 +661,8 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.61.0 h1:TOvOcuXn30kRao+gfcvsebNEa5iZIiLkisYEkf7R7o0= -google.golang.org/grpc v1.61.0/go.mod h1:VUbo7IFqmF1QtCAstipjG0GIoq49KvMe9+h1jFLBNJs= +google.golang.org/grpc v1.62.0 h1:HQKZ/fa1bXkX1oFOvSjmZEUL8wLSaZTjCcLAlmZRtdk= +google.golang.org/grpc v1.62.0/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= From 5ee2e0274c57c12d974fac2dd0c633ad71e8444b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 Feb 2024 10:01:00 +0000 Subject: [PATCH 10/31] Bump github.com/go-jose/go-jose/v3 from 3.0.1 to 3.0.2 Bumps [github.com/go-jose/go-jose/v3](https://github.com/go-jose/go-jose) from 3.0.1 to 3.0.2. - [Release notes](https://github.com/go-jose/go-jose/releases) - [Changelog](https://github.com/go-jose/go-jose/blob/main/CHANGELOG.md) - [Commits](https://github.com/go-jose/go-jose/compare/v3.0.1...v3.0.2) --- updated-dependencies: - dependency-name: github.com/go-jose/go-jose/v3 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 31f88c21f..51095221e 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/dgraph-io/badger/v2 v2.2007.4 github.com/fxamacker/cbor/v2 v2.6.0 github.com/go-chi/chi/v5 v5.0.11 - github.com/go-jose/go-jose/v3 v3.0.1 + github.com/go-jose/go-jose/v3 v3.0.2 github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/go-tpm v0.9.0 diff --git a/go.sum b/go.sum index 4fa349750..c054c15af 100644 --- a/go.sum +++ b/go.sum @@ -138,8 +138,9 @@ github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1t github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/go-chi/chi/v5 v5.0.11 h1:BnpYbFZ3T3S1WMpD79r7R5ThWX40TaFB7L31Y8xqSwA= github.com/go-chi/chi/v5 v5.0.11/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= -github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= +github.com/go-jose/go-jose/v3 v3.0.2 h1:2Edjn8Nrb44UvTdp84KU0bBPs1cO7noRCybtS3eJEUQ= +github.com/go-jose/go-jose/v3 v3.0.2/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= github.com/go-kit/kit v0.4.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.13.0 h1:OoneCcHKHQ03LfBpoQCUfCluwd2Vt3ohz+kvbJneZAU= github.com/go-kit/kit v0.13.0/go.mod h1:phqEHMMUbyrCFCTgH48JueqrM3md2HcAZ8N3XE4FKDg= @@ -207,6 +208,7 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-sev-guest v0.9.3 h1:GOJ+EipURdeWFl/YYdgcCxyPeMgQUWlI056iFkBD8UU= @@ -604,6 +606,7 @@ golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From 96895087098cdcc95299e2a6126c23f1c5260282 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 27 Feb 2024 13:39:21 +0100 Subject: [PATCH 11/31] Add tests for webhook request IDs --- authority/provisioner/webhook_test.go | 126 ++++++++++++++++++-------- 1 file changed, 86 insertions(+), 40 deletions(-) diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 0ce3f36d3..ced713d1a 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -17,8 +17,11 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/assert" + sassert "github.com/smallstep/assert" + "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/webhook" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" "go.step.sm/linkedca" @@ -94,19 +97,24 @@ func TestWebhookController_isCertTypeOK(t *testing.T) { } for name, test := range tests { t.Run(name, func(t *testing.T) { - assert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh)) + sassert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh)) }) } } +// withRequestID is a helper that calls into [logging.WithRequestID] and returns +// a new context with the requestID added to the provided context. +func withRequestID(ctx context.Context, requestID string) context.Context { + return logging.WithRequestID(ctx, requestID) +} + func TestWebhookController_Enrich(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type test struct { ctl *WebhookController + ctx context.Context req *webhook.RequestBody responses []*webhook.ResponseBody expectErr bool @@ -131,6 +139,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, @@ -145,6 +154,7 @@ func TestWebhookController_Enrich(t *testing.T) { }, TemplateData: x509util.TemplateData{}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -168,6 +178,7 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, certType: linkedca.Webhook_X509, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -187,14 +198,15 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}}, assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) - assert.FatalError(t, err) - assert.Equals(t, &webhook.X5CCertificate{ + sassert.FatalError(t, err) + sassert.Equals(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), @@ -209,6 +221,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -223,6 +236,7 @@ func TestWebhookController_Enrich(t *testing.T) { PublicKey: []byte("bad"), })}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -234,19 +248,21 @@ func TestWebhookController_Enrich(t *testing.T) { for i, wh := range test.ctl.webhooks { var j = i ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "reqID", r.Header.Get("X-Request-ID")) + err := json.NewEncoder(w).Encode(test.responses[j]) - assert.FatalError(t, err) + require.NoError(t, err) })) // nolint: gocritic // defer in loop isn't a memory leak defer ts.Close() wh.URL = ts.URL } - err := test.ctl.Enrich(context.Background(), test.req) + err := test.ctl.Enrich(test.ctx, test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } - assert.Equals(t, test.expectTemplateData, test.ctl.TemplateData) + sassert.Equals(t, test.expectTemplateData, test.ctl.TemplateData) if test.assertRequest != nil { test.assertRequest(t, test.req) } @@ -256,12 +272,11 @@ func TestWebhookController_Enrich(t *testing.T) { func TestWebhookController_Authorize(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type test struct { ctl *WebhookController + ctx context.Context req *webhook.RequestBody responses []*webhook.ResponseBody expectErr bool @@ -282,6 +297,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, @@ -292,6 +308,7 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}}, certType: linkedca.Webhook_SSH, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: false, @@ -302,13 +319,14 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) - assert.FatalError(t, err) - assert.Equals(t, &webhook.X5CCertificate{ + require.NoError(t, err) + sassert.Equals(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), @@ -322,6 +340,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -334,6 +353,7 @@ func TestWebhookController_Authorize(t *testing.T) { PublicKey: []byte("bad"), })}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -344,15 +364,17 @@ func TestWebhookController_Authorize(t *testing.T) { for i, wh := range test.ctl.webhooks { var j = i ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "reqID", r.Header.Get("X-Request-ID")) + err := json.NewEncoder(w).Encode(test.responses[j]) - assert.FatalError(t, err) + require.NoError(t, err) })) // nolint: gocritic // defer in loop isn't a memory leak defer ts.Close() wh.URL = ts.URL } - err := test.ctl.Authorize(context.Background(), test.req) + err := test.ctl.Authorize(test.ctx, test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } @@ -368,6 +390,7 @@ func TestWebhook_Do(t *testing.T) { type test struct { webhook Webhook dataArg any + requestID string webhookResponse webhook.ResponseBody expectPath string errStatusCode int @@ -377,6 +400,16 @@ func TestWebhook_Do(t *testing.T) { } tests := map[string]test{ "ok": { + webhook: Webhook{ + ID: "abc123", + Secret: "c2VjcmV0Cg==", + }, + requestID: "reqID", + webhookResponse: webhook.ResponseBody{ + Data: map[string]interface{}{"role": "dba"}, + }, + }, + "ok/no-request-id": { webhook: Webhook{ ID: "abc123", Secret: "c2VjcmV0Cg==", @@ -391,6 +424,7 @@ func TestWebhook_Do(t *testing.T) { Secret: "c2VjcmV0Cg==", BearerToken: "mytoken", }, + requestID: "reqID", webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, @@ -407,6 +441,7 @@ func TestWebhook_Do(t *testing.T) { Password: "mypass", }, }, + requestID: "reqID", webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, @@ -418,7 +453,8 @@ func TestWebhook_Do(t *testing.T) { URL: "/users/{{ .username }}?region={{ .region }}", Secret: "c2VjcmV0Cg==", }, - dataArg: map[string]interface{}{"username": "areed", "region": "central"}, + requestID: "reqID", + dataArg: map[string]interface{}{"username": "areed", "region": "central"}, webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, @@ -453,6 +489,7 @@ func TestWebhook_Do(t *testing.T) { ID: "abc123", Secret: "c2VjcmV0Cg==", }, + requestID: "reqID", webhookResponse: webhook.ResponseBody{ Allow: true, }, @@ -465,6 +502,7 @@ func TestWebhook_Do(t *testing.T) { webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, + requestID: "reqID", errStatusCode: 404, serverErrMsg: "item not found", expectErr: errors.New("Webhook server responded with 404"), @@ -473,38 +511,42 @@ func TestWebhook_Do(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if tc.requestID != "" { + assert.Equal(t, tc.requestID, r.Header.Get("X-Request-ID")) + } + id := r.Header.Get("X-Smallstep-Webhook-ID") - assert.Equals(t, tc.webhook.ID, id) + sassert.Equals(t, tc.webhook.ID, id) sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature")) - assert.FatalError(t, err) + assert.NoError(t, err) body, err := io.ReadAll(r.Body) - assert.FatalError(t, err) + assert.NoError(t, err) secret, err := base64.StdEncoding.DecodeString(tc.webhook.Secret) - assert.FatalError(t, err) + assert.NoError(t, err) h := hmac.New(sha256.New, secret) h.Write(body) mac := h.Sum(nil) - assert.True(t, hmac.Equal(sig, mac)) + sassert.True(t, hmac.Equal(sig, mac)) switch { case tc.webhook.BearerToken != "": ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken) - assert.Equals(t, ah, r.Header.Get("Authorization")) + sassert.Equals(t, ah, r.Header.Get("Authorization")) case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "": whReq, err := http.NewRequest("", "", http.NoBody) - assert.FatalError(t, err) + assert.NoError(t, err) whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password) ah := whReq.Header.Get("Authorization") - assert.Equals(t, ah, whReq.Header.Get("Authorization")) + sassert.Equals(t, ah, whReq.Header.Get("Authorization")) default: - assert.Equals(t, "", r.Header.Get("Authorization")) + sassert.Equals(t, "", r.Header.Get("Authorization")) } if tc.expectPath != "" { - assert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) + sassert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) } if tc.errStatusCode != 0 { @@ -514,30 +556,34 @@ func TestWebhook_Do(t *testing.T) { reqBody := new(webhook.RequestBody) err = json.Unmarshal(body, reqBody) - assert.FatalError(t, err) - // assert.Equals(t, tc.expectToken, reqBody.Token) + require.NoError(t, err) + // sassert.Equals(t, tc.expectToken, reqBody.Token) err = json.NewEncoder(w).Encode(tc.webhookResponse) - assert.FatalError(t, err) + require.NoError(t, err) })) defer ts.Close() tc.webhook.URL = ts.URL + tc.webhook.URL reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) - assert.FatalError(t, err) + require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx := context.Background() + if tc.requestID != "" { + ctx = withRequestID(context.Background(), tc.requestID) + } + ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg) if tc.expectErr != nil { - assert.Equals(t, tc.expectErr.Error(), err.Error()) + sassert.Equals(t, tc.expectErr.Error(), err.Error()) return } - assert.FatalError(t, err) + assert.NoError(t, err) - assert.Equals(t, got, &tc.webhookResponse) + sassert.Equals(t, got, &tc.webhookResponse) }) } @@ -550,7 +596,7 @@ func TestWebhook_Do(t *testing.T) { URL: ts.URL, } cert, err := tls.LoadX509KeyPair("testdata/certs/foo.crt", "testdata/secrets/foo.key") - assert.FatalError(t, err) + require.NoError(t, err) transport := http.DefaultTransport.(*http.Transport).Clone() transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, @@ -560,19 +606,19 @@ func TestWebhook_Do(t *testing.T) { Transport: transport, } reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) - assert.FatalError(t, err) + require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() _, err = wh.DoWithContext(ctx, client, reqBody, nil) - assert.FatalError(t, err) + require.NoError(t, err) ctx, cancel = context.WithTimeout(context.Background(), time.Second*10) defer cancel() wh.DisableTLSClientAuth = true _, err = wh.DoWithContext(ctx, client, reqBody, nil) - assert.Error(t, err) + require.Error(t, err) }) } From c16a0b70ee31ef59fdb652dadb8705f6f0a49012 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 27 Feb 2024 13:44:44 +0100 Subject: [PATCH 12/31] Remove `smallstep/assert` and `pkg/errors` from webhook tests --- authority/provisioner/webhook_test.go | 33 ++++++++++++--------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index ced713d1a..60dcdbc71 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -9,6 +9,7 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -16,8 +17,6 @@ import ( "testing" "time" - "github.com/pkg/errors" - sassert "github.com/smallstep/assert" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/webhook" "github.com/stretchr/testify/assert" @@ -97,7 +96,7 @@ func TestWebhookController_isCertTypeOK(t *testing.T) { } for name, test := range tests { t.Run(name, func(t *testing.T) { - sassert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh)) + assert.Equal(t, test.want, test.wc.isCertTypeOK(test.wh)) }) } } @@ -205,8 +204,8 @@ func TestWebhookController_Enrich(t *testing.T) { expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}}, assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) - sassert.FatalError(t, err) - sassert.Equals(t, &webhook.X5CCertificate{ + require.NoError(t, err) + assert.Equal(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), @@ -262,7 +261,7 @@ func TestWebhookController_Enrich(t *testing.T) { if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } - sassert.Equals(t, test.expectTemplateData, test.ctl.TemplateData) + assert.Equal(t, test.expectTemplateData, test.ctl.TemplateData) if test.assertRequest != nil { test.assertRequest(t, test.req) } @@ -326,7 +325,7 @@ func TestWebhookController_Authorize(t *testing.T) { assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) require.NoError(t, err) - sassert.Equals(t, &webhook.X5CCertificate{ + assert.Equal(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), @@ -515,8 +514,7 @@ func TestWebhook_Do(t *testing.T) { assert.Equal(t, tc.requestID, r.Header.Get("X-Request-ID")) } - id := r.Header.Get("X-Smallstep-Webhook-ID") - sassert.Equals(t, tc.webhook.ID, id) + assert.Equal(t, tc.webhook.ID, r.Header.Get("X-Smallstep-Webhook-ID")) sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature")) assert.NoError(t, err) @@ -529,24 +527,24 @@ func TestWebhook_Do(t *testing.T) { h := hmac.New(sha256.New, secret) h.Write(body) mac := h.Sum(nil) - sassert.True(t, hmac.Equal(sig, mac)) + assert.True(t, hmac.Equal(sig, mac)) switch { case tc.webhook.BearerToken != "": ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken) - sassert.Equals(t, ah, r.Header.Get("Authorization")) + assert.Equal(t, ah, r.Header.Get("Authorization")) case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "": whReq, err := http.NewRequest("", "", http.NoBody) - assert.NoError(t, err) + require.NoError(t, err) whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password) ah := whReq.Header.Get("Authorization") - sassert.Equals(t, ah, whReq.Header.Get("Authorization")) + assert.Equal(t, ah, whReq.Header.Get("Authorization")) default: - sassert.Equals(t, "", r.Header.Get("Authorization")) + assert.Equal(t, "", r.Header.Get("Authorization")) } if tc.expectPath != "" { - sassert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) + assert.Equal(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) } if tc.errStatusCode != 0 { @@ -557,7 +555,6 @@ func TestWebhook_Do(t *testing.T) { reqBody := new(webhook.RequestBody) err = json.Unmarshal(body, reqBody) require.NoError(t, err) - // sassert.Equals(t, tc.expectToken, reqBody.Token) err = json.NewEncoder(w).Encode(tc.webhookResponse) require.NoError(t, err) @@ -578,12 +575,12 @@ func TestWebhook_Do(t *testing.T) { got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg) if tc.expectErr != nil { - sassert.Equals(t, tc.expectErr.Error(), err.Error()) + assert.Equal(t, tc.expectErr.Error(), err.Error()) return } assert.NoError(t, err) - sassert.Equals(t, got, &tc.webhookResponse) + assert.Equal(t, &tc.webhookResponse, got) }) } From 041b486c556017aac05a3dc12c1b5681190ac55d Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 27 Feb 2024 14:00:09 +0100 Subject: [PATCH 13/31] Remove usages of `Sign` without context --- acme/api/revoke_test.go | 4 ---- acme/common.go | 1 - acme/order_test.go | 10 ---------- api/api.go | 1 - api/api_test.go | 8 -------- authority/authority_test.go | 3 ++- authority/authorize_test.go | 2 +- authority/provisioners_test.go | 2 +- authority/tls_test.go | 8 ++++---- scep/authority.go | 1 - 10 files changed, 8 insertions(+), 32 deletions(-) diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index 5d274faf1..85b9a0326 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -281,10 +281,6 @@ type mockCA struct { MockAreSANsallowed func(ctx context.Context, sans []string) error } -func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { - return nil, nil -} - func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { return nil, nil } diff --git a/acme/common.go b/acme/common.go index 46e86ae69..e86b23e9b 100644 --- a/acme/common.go +++ b/acme/common.go @@ -21,7 +21,6 @@ var clock Clock // CertificateAuthority is the interface implemented by a CA authority. type CertificateAuthority interface { - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) AreSANsAllowed(ctx context.Context, sans []string) error IsRevoked(sn string) (bool, error) diff --git a/acme/order_test.go b/acme/order_test.go index 17060f11e..07372af07 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -271,7 +271,6 @@ func TestOrder_UpdateStatus(t *testing.T) { } type mockSignAuth struct { - sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) areSANsAllowed func(ctx context.Context, sans []string) error loadProvisionerByName func(string) (provisioner.Interface, error) @@ -279,15 +278,6 @@ type mockSignAuth struct { err error } -func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { - if m.sign != nil { - return m.sign(csr, signOpts, extraOpts...) - } else if m.err != nil { - return nil, m.err - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { if m.signWithContext != nil { return m.signWithContext(ctx, csr, signOpts, extraOpts...) diff --git a/api/api.go b/api/api.go index 1d367f7d6..a12e7e19a 100644 --- a/api/api.go +++ b/api/api.go @@ -42,7 +42,6 @@ type Authority interface { AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error) RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) diff --git a/api/api_test.go b/api/api_test.go index 4266dff34..cf9885933 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -189,7 +189,6 @@ type mockAuthority struct { authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) - sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) @@ -252,13 +251,6 @@ func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) { return m.ret1.(*x509.Certificate), m.err } -func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { - if m.sign != nil { - return m.sign(cr, opts, signOpts...) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { if m.signWithContext != nil { return m.signWithContext(ctx, cr, opts, signOpts...) diff --git a/authority/authority_test.go b/authority/authority_test.go index 45c7cd861..3787dab7c 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -1,6 +1,7 @@ package authority import ( + "context" "crypto" "crypto/rand" "crypto/sha256" @@ -414,7 +415,7 @@ func TestNewEmbedded_Sign(t *testing.T) { csr, err := x509.ParseCertificateRequest(cr) assert.FatalError(t, err) - cert, err := a.Sign(csr, provisioner.SignOptions{}) + cert, err := a.SignWithContext(context.Background(), csr, provisioner.SignOptions{}) assert.FatalError(t, err) assert.Equals(t, []string{"foo.bar.zar"}, cert[0].DNSNames) assert.Equals(t, crt, cert[1]) diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 3d748f69a..8f3c1ae28 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -1375,7 +1375,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { } generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) { - chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...) + chain, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...) if err != nil { t.Fatal(err) } diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index f6af6f548..f62f81273 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -149,7 +149,7 @@ func TestAuthority_LoadProvisionerByCertificate(t *testing.T) { opts, err := a.Authorize(ctx, token) require.NoError(t, err) opts = append(opts, extraOpts...) - certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...) + certs, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...) require.NoError(t, err) return certs[0] } diff --git a/authority/tls_test.go b/authority/tls_test.go index 1fb8411a5..b481ca68c 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -239,7 +239,7 @@ func (e *testEnforcer) Enforce(cert *x509.Certificate) error { return nil } -func TestAuthority_Sign(t *testing.T) { +func TestAuthority_SignWithContext(t *testing.T) { pub, priv, err := keyutil.GenerateDefaultKeyPair() require.NoError(t, err) @@ -848,7 +848,7 @@ ZYtQ9Ot36qc= t.Run(name, func(t *testing.T) { tc := genTestCase(t) - certChain, err := tc.auth.Sign(tc.csr, tc.signOpts, tc.extraOpts...) + certChain, err := tc.auth.SignWithContext(context.Background(), tc.csr, tc.signOpts, tc.extraOpts...) if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) @@ -1797,9 +1797,9 @@ func TestAuthority_constraints(t *testing.T) { t.Fatal(err) } - _, err = auth.Sign(csr, provisioner.SignOptions{}, templateOption) + _, err = auth.SignWithContext(context.Background(), csr, provisioner.SignOptions{}, templateOption) if (err != nil) != tt.wantErr { - t.Errorf("Authority.Sign() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("Authority.SignWithContext() error = %v, wantErr %v", err, tt.wantErr) } _, err = auth.Renew(cert) diff --git a/scep/authority.go b/scep/authority.go index e2aa759eb..8ed065fbb 100644 --- a/scep/authority.go +++ b/scep/authority.go @@ -60,7 +60,6 @@ func MustFromContext(ctx context.Context) *Authority { // SignAuthority is the interface for a signing authority type SignAuthority interface { - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) LoadProvisionerByName(string) (provisioner.Interface, error) } From 4213a190d5204176132e2f27e7df235639d4adbf Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 27 Feb 2024 16:17:09 +0100 Subject: [PATCH 14/31] Use `X-Request-Id` as canonical request identifier (if available) If `X-Request-Id` is available in an HTTP request made against the CA server, it'll be used as the identifier for the request. This slightly changes the existing behavior, which relied on the custom `X-Smallstep-Id` header, but usage of that header is currently not very widespread, and `X-Request-Id` is more generally known for the use case `X-Smallstep-Id` is used for. `X-Smallstep-Id` is currently still considered, but it'll only be used if `X-Request-Id` is not set. --- logging/context.go | 23 +++++++--- logging/context_test.go | 94 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 5 deletions(-) create mode 100644 logging/context_test.go diff --git a/logging/context.go b/logging/context.go index b24b3638e..ab8464d0c 100644 --- a/logging/context.go +++ b/logging/context.go @@ -21,14 +21,27 @@ func NewRequestID() string { return xid.New().String() } -// RequestID returns a new middleware that gets the given header and sets it -// in the context so it can be written in the logger. If the header does not -// exists or it's the empty string, it uses github.com/rs/xid to create a new -// one. +// defaultRequestIDHeader is the header name used for propagating +// request IDs. If available in an HTTP request, it'll be used instead +// of the X-Smallstep-Id header. +const defaultRequestIDHeader = "X-Request-Id" + +// RequestID returns a new middleware that obtains the current request ID +// and sets it in the context. It first tries to read the request ID from +// the "X-Request-Id" header. If that's not set, it tries to read it from +// the provided header name. If the header does not exist or its value is +// the empty string, it uses github.com/rs/xid to create a new one. func RequestID(headerName string) func(next http.Handler) http.Handler { + if headerName == "" { + headerName = defaultTraceIDHeader + } return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, req *http.Request) { - requestID := req.Header.Get(headerName) + requestID := req.Header.Get(defaultRequestIDHeader) + if requestID == "" { + requestID = req.Header.Get(headerName) + } + if requestID == "" { requestID = NewRequestID() req.Header.Set(headerName, requestID) diff --git a/logging/context_test.go b/logging/context_test.go new file mode 100644 index 000000000..c519539da --- /dev/null +++ b/logging/context_test.go @@ -0,0 +1,94 @@ +package logging + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newRequest(t *testing.T) *http.Request { + r, err := http.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + require.NoError(t, err) + return r +} + +func TestRequestID(t *testing.T) { + requestWithID := newRequest(t) + requestWithID.Header.Set("X-Request-Id", "reqID") + requestWithoutID := newRequest(t) + requestWithEmptyHeader := newRequest(t) + requestWithEmptyHeader.Header.Set("X-Request-Id", "") + requestWithSmallstepID := newRequest(t) + requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID") + + tests := []struct { + name string + headerName string + handler http.HandlerFunc + req *http.Request + }{ + { + name: "default-request-id", + headerName: defaultTraceIDHeader, + handler: func(_ http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Smallstep-Id")) + assert.Equal(t, "reqID", r.Header.Get("X-Request-Id")) + reqID, ok := GetRequestID(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, "reqID", reqID) + } + }, + req: requestWithID, + }, + { + name: "no-request-id", + headerName: "X-Request-Id", + handler: func(_ http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Smallstep-Id")) + value := r.Header.Get("X-Request-Id") + assert.NotEmpty(t, value) + reqID, ok := GetRequestID(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, value, reqID) + } + }, + req: requestWithoutID, + }, + { + name: "empty-header-name", + headerName: "", + handler: func(_ http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Request-Id")) + value := r.Header.Get("X-Smallstep-Id") + assert.NotEmpty(t, value) + reqID, ok := GetRequestID(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, value, reqID) + } + }, + req: requestWithEmptyHeader, + }, + { + name: "fallback-header-name", + headerName: defaultTraceIDHeader, + handler: func(_ http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Request-Id")) + assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id")) + reqID, ok := GetRequestID(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, "smallstepID", reqID) + } + }, + req: requestWithSmallstepID, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := RequestID(tt.headerName) + h(tt.handler).ServeHTTP(httptest.NewRecorder(), tt.req) + }) + } +} From c1c2e73475f4333267aa855d758800dae0255278 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 27 Feb 2024 17:04:21 +0100 Subject: [PATCH 15/31] Add `X-Request-Id` to all requests made by our CA clients --- ca/acmeClient.go | 3 +++ ca/client.go | 20 ++++++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/ca/acmeClient.go b/ca/acmeClient.go index bb3b1d84a..3ef2f1910 100644 --- a/ca/acmeClient.go +++ b/ca/acmeClient.go @@ -48,6 +48,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC return nil, errors.Wrapf(err, "creating GET request %s failed", endpoint) } req.Header.Set("User-Agent", UserAgent) + enforceRequestID(req) resp, err := ac.client.Do(req) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", endpoint) @@ -109,6 +110,7 @@ func (c *ACMEClient) GetNonce() (string, error) { return "", errors.Wrapf(err, "creating GET request %s failed", c.dir.NewNonce) } req.Header.Set("User-Agent", UserAgent) + enforceRequestID(req) resp, err := c.client.Do(req) if err != nil { return "", errors.Wrapf(err, "client GET %s failed", c.dir.NewNonce) @@ -188,6 +190,7 @@ func (c *ACMEClient) post(payload []byte, url string, headerOps ...withHeaderOpt } req.Header.Set("Content-Type", "application/jose+json") req.Header.Set("User-Agent", UserAgent) + enforceRequestID(req) resp, err := c.client.Do(req) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", c.dir.NewOrder) diff --git a/ca/client.go b/ca/client.go index ac13e1fea..5e2d98c8d 100644 --- a/ca/client.go +++ b/ca/client.go @@ -24,6 +24,7 @@ import ( "strings" "github.com/pkg/errors" + "github.com/rs/xid" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" @@ -83,8 +84,7 @@ func (c *uaClient) GetWithContext(ctx context.Context, u string) (*http.Response if err != nil { return nil, errors.Wrapf(err, "create GET %s request failed", u) } - req.Header.Set("User-Agent", UserAgent) - return c.Client.Do(req) + return c.Do(req) } func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) { @@ -97,12 +97,24 @@ func (c *uaClient) PostWithContext(ctx context.Context, u, contentType string, b return nil, errors.Wrapf(err, "create POST %s request failed", u) } req.Header.Set("Content-Type", contentType) - req.Header.Set("User-Agent", UserAgent) - return c.Client.Do(req) + return c.Do(req) +} + +// requestIDHeader is the header name used for propagating request IDs from +// the CA client to the CA and back again. +const requestIDHeader = "X-Request-Id" + +// enforceRequestID checks if the X-Request-Id HTTP header is filled. If it's +// empty, it'll generate a new request ID and set the header. +func enforceRequestID(r *http.Request) { + if r.Header.Get(requestIDHeader) == "" { + r.Header.Set(requestIDHeader, xid.New().String()) + } } func (c *uaClient) Do(req *http.Request) (*http.Response, error) { req.Header.Set("User-Agent", UserAgent) + enforceRequestID(req) return c.Client.Do(req) } From a58f5956e31255b8784be5da1484f3afc634a32c Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 27 Feb 2024 20:48:56 +0100 Subject: [PATCH 16/31] Add reflection of request ID in `X-Request-Id` response header --- logging/context.go | 14 +++++++++----- logging/context_test.go | 12 ++++++++---- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/logging/context.go b/logging/context.go index ab8464d0c..9d7a70713 100644 --- a/logging/context.go +++ b/logging/context.go @@ -21,10 +21,10 @@ func NewRequestID() string { return xid.New().String() } -// defaultRequestIDHeader is the header name used for propagating -// request IDs. If available in an HTTP request, it'll be used instead -// of the X-Smallstep-Id header. -const defaultRequestIDHeader = "X-Request-Id" +// requestIDHeader is the header name used for propagating request IDs. If +// available in an HTTP request, it'll be used instead of the X-Smallstep-Id +// header. It'll always be used in response and set to the request ID. +const requestIDHeader = "X-Request-Id" // RequestID returns a new middleware that obtains the current request ID // and sets it in the context. It first tries to read the request ID from @@ -37,7 +37,7 @@ func RequestID(headerName string) func(next http.Handler) http.Handler { } return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, req *http.Request) { - requestID := req.Header.Get(defaultRequestIDHeader) + requestID := req.Header.Get(requestIDHeader) if requestID == "" { requestID = req.Header.Get(headerName) } @@ -47,6 +47,10 @@ func RequestID(headerName string) func(next http.Handler) http.Handler { req.Header.Set(headerName, requestID) } + // immediately set the request ID to be reflected in the response + w.Header().Set(requestIDHeader, requestID) + + // continue down the handler chain ctx := WithRequestID(req.Context(), requestID) next.ServeHTTP(w, req.WithContext(ctx)) } diff --git a/logging/context_test.go b/logging/context_test.go index c519539da..da993f7bd 100644 --- a/logging/context_test.go +++ b/logging/context_test.go @@ -33,20 +33,21 @@ func TestRequestID(t *testing.T) { { name: "default-request-id", headerName: defaultTraceIDHeader, - handler: func(_ http.ResponseWriter, r *http.Request) { + handler: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Smallstep-Id")) assert.Equal(t, "reqID", r.Header.Get("X-Request-Id")) reqID, ok := GetRequestID(r.Context()) if assert.True(t, ok) { assert.Equal(t, "reqID", reqID) } + assert.Equal(t, "reqID", w.Header().Get("X-Request-Id")) }, req: requestWithID, }, { name: "no-request-id", headerName: "X-Request-Id", - handler: func(_ http.ResponseWriter, r *http.Request) { + handler: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Smallstep-Id")) value := r.Header.Get("X-Request-Id") assert.NotEmpty(t, value) @@ -54,13 +55,14 @@ func TestRequestID(t *testing.T) { if assert.True(t, ok) { assert.Equal(t, value, reqID) } + assert.Equal(t, value, w.Header().Get("X-Request-Id")) }, req: requestWithoutID, }, { name: "empty-header-name", headerName: "", - handler: func(_ http.ResponseWriter, r *http.Request) { + handler: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Request-Id")) value := r.Header.Get("X-Smallstep-Id") assert.NotEmpty(t, value) @@ -68,19 +70,21 @@ func TestRequestID(t *testing.T) { if assert.True(t, ok) { assert.Equal(t, value, reqID) } + assert.Equal(t, value, w.Header().Get("X-Request-Id")) }, req: requestWithEmptyHeader, }, { name: "fallback-header-name", headerName: defaultTraceIDHeader, - handler: func(_ http.ResponseWriter, r *http.Request) { + handler: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Request-Id")) assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id")) reqID, ok := GetRequestID(r.Context()) if assert.True(t, ok) { assert.Equal(t, "smallstepID", reqID) } + assert.Equal(t, "smallstepID", w.Header().Get("X-Request-Id")) }, req: requestWithSmallstepID, }, From fb4cd6fe81205f72726030ce52d50ee429243bc3 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Tue, 27 Feb 2024 22:43:45 +0200 Subject: [PATCH 17/31] fix: Webhook-related instruments * fix: also instrument webhooks that do not reach the wire * fix: register the webhook instrumentation --- authority/ssh.go | 6 ++---- authority/tls.go | 6 ++---- internal/metrix/meter.go | 4 ++++ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/authority/ssh.go b/authority/ssh.go index 26e8eebc0..55f4f4a21 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -675,14 +675,13 @@ func (a *Authority) callEnrichingWebhooksSSH(ctx context.Context, prov provision if webhookCtl == nil { return } + defer func() { a.meter.SSHWebhookEnriched(prov, err) }() var whEnrichReq *webhook.RequestBody if whEnrichReq, err = webhook.NewRequestBody( webhook.WithSSHCertificateRequest(cr), ); err == nil { err = webhookCtl.Enrich(ctx, whEnrichReq) - - a.meter.SSHWebhookEnriched(prov, err) } return @@ -692,14 +691,13 @@ func (a *Authority) callAuthorizingWebhooksSSH(ctx context.Context, prov provisi if webhookCtl == nil { return } + defer func() { a.meter.SSHWebhookAuthorized(prov, err) }() var whAuthBody *webhook.RequestBody if whAuthBody, err = webhook.NewRequestBody( webhook.WithSSHCertificate(cert, certTpl), ); err == nil { err = webhookCtl.Authorize(ctx, whAuthBody) - - a.meter.SSHWebhookAuthorized(prov, err) } return diff --git a/authority/tls.go b/authority/tls.go index 082513c89..1f3f51308 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -999,6 +999,7 @@ func (a *Authority) callEnrichingWebhooksX509(ctx context.Context, prov provisio if webhookCtl == nil { return } + defer func() { a.meter.X509WebhookEnriched(prov, err) }() var attested *webhook.AttestationData if attData != nil { @@ -1013,8 +1014,6 @@ func (a *Authority) callEnrichingWebhooksX509(ctx context.Context, prov provisio webhook.WithAttestationData(attested), ); err == nil { err = webhookCtl.Enrich(ctx, whEnrichReq) - - a.meter.X509WebhookEnriched(prov, err) } return @@ -1024,6 +1023,7 @@ func (a *Authority) callAuthorizingWebhooksX509(ctx context.Context, prov provis if webhookCtl == nil { return } + defer func() { a.meter.X509WebhookAuthorized(prov, err) }() var attested *webhook.AttestationData if attData != nil { @@ -1038,8 +1038,6 @@ func (a *Authority) callAuthorizingWebhooksX509(ctx context.Context, prov provis webhook.WithAttestationData(attested), ); err == nil { err = webhookCtl.Authorize(ctx, whAuthBody) - - a.meter.X509WebhookAuthorized(prov, err) } return diff --git a/internal/metrix/meter.go b/internal/metrix/meter.go index a867b197b..334cf883f 100644 --- a/internal/metrix/meter.go +++ b/internal/metrix/meter.go @@ -42,9 +42,13 @@ func New() (m *Meter) { m.ssh.rekeyed, m.ssh.renewed, m.ssh.signed, + m.ssh.webhookAuthorized, + m.ssh.webhookEnriched, m.x509.rekeyed, m.x509.renewed, m.x509.signed, + m.x509.webhookAuthorized, + m.x509.webhookEnriched, m.kms.signed, m.kms.errors, ) From cf8a50157f7a662501a4f6b4728ad771af6052bb Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 28 Feb 2024 01:05:38 +0100 Subject: [PATCH 18/31] Add a basic e2e test for `X-Request-Id` reflection --- api/api_test.go | 8 +-- ca/ca_test.go | 25 +++++++-- ca/client.go | 51 ++++++++++--------- errs/error.go | 9 ++-- test/e2e/requestid_test.go | 102 +++++++++++++++++++++++++++++++++++++ 5 files changed, 155 insertions(+), 40 deletions(-) create mode 100644 test/e2e/requestid_test.go diff --git a/api/api_test.go b/api/api_test.go index cf9885933..8090c6d43 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -884,16 +884,12 @@ func Test_Sign(t *testing.T) { CsrPEM: CertificateRequest{csr}, OTT: "foobarzar", }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) invalid, err := json.Marshal(SignRequest{ CsrPEM: CertificateRequest{csr}, OTT: "", }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) diff --git a/ca/ca_test.go b/ca/ca_test.go index 7ad25cc6d..a8c173c4d 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -289,6 +289,9 @@ ZEp7knvU2psWRw== if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var sign api.SignResponse assert.FatalError(t, readJSON(body, &sign)) @@ -325,7 +328,7 @@ ZEp7knvU2psWRw== assert.FatalError(t, err) assert.Equals(t, intermediate, realIntermediate) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } @@ -369,6 +372,9 @@ func TestCAProvisioners(t *testing.T) { if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var resp api.ProvisionersResponse @@ -379,7 +385,7 @@ func TestCAProvisioners(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, a, b) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } @@ -436,12 +442,15 @@ func TestCAProvisionerEncryptedKey(t *testing.T) { if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var ek api.ProvisionerKeyResponse assert.FatalError(t, readJSON(body, &ek)) assert.Equals(t, ek.Key, tc.expectedKey) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } @@ -498,12 +507,15 @@ func TestCARoot(t *testing.T) { if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var root api.RootResponse assert.FatalError(t, readJSON(body, &root)) assert.Equals(t, root.RootPEM.Certificate, rootCrt) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } @@ -641,6 +653,9 @@ func TestCARenew(t *testing.T) { if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var sign api.SignResponse assert.FatalError(t, readJSON(body, &sign)) @@ -673,7 +688,7 @@ func TestCARenew(t *testing.T) { assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } diff --git a/ca/client.go b/ca/client.go index 5e2d98c8d..8930d8ee1 100644 --- a/ca/client.go +++ b/ca/client.go @@ -622,7 +622,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var version api.VersionResponse if err := readJSON(resp.Body, &version); err != nil { @@ -652,7 +652,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var health api.HealthResponse if err := readJSON(resp.Body, &health); err != nil { @@ -687,7 +687,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var root api.RootResponse if err := readJSON(resp.Body, &root); err != nil { @@ -726,7 +726,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -765,7 +765,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -802,7 +802,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -842,7 +842,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -883,7 +883,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var revoke api.RevokeResponse if err := readJSON(resp.Body, &revoke); err != nil { @@ -926,7 +926,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var provisioners api.ProvisionersResponse if err := readJSON(resp.Body, &provisioners); err != nil { @@ -958,7 +958,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var key api.ProvisionerKeyResponse if err := readJSON(resp.Body, &key); err != nil { @@ -988,7 +988,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var roots api.RootsResponse if err := readJSON(resp.Body, &roots); err != nil { @@ -1018,7 +1018,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var federation api.FederationResponse if err := readJSON(resp.Body, &federation); err != nil { @@ -1052,7 +1052,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SSHSignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -1086,7 +1086,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var renew api.SSHRenewResponse if err := readJSON(resp.Body, &renew); err != nil { @@ -1120,7 +1120,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var rekey api.SSHRekeyResponse if err := readJSON(resp.Body, &rekey); err != nil { @@ -1154,7 +1154,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var revoke api.SSHRevokeResponse if err := readJSON(resp.Body, &revoke); err != nil { @@ -1184,7 +1184,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var keys api.SSHRootsResponse if err := readJSON(resp.Body, &keys); err != nil { @@ -1214,7 +1214,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var keys api.SSHRootsResponse if err := readJSON(resp.Body, &keys); err != nil { @@ -1248,7 +1248,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var cfg api.SSHConfigResponse if err := readJSON(resp.Body, &cfg); err != nil { @@ -1287,7 +1287,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var check api.SSHCheckPrincipalResponse if err := readJSON(resp.Body, &check); err != nil { @@ -1316,7 +1316,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var hosts api.SSHGetHostsResponse if err := readJSON(resp.Body, &hosts); err != nil { @@ -1348,7 +1348,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var bastion api.SSHBastionResponse if err := readJSON(resp.Body, &bastion); err != nil { @@ -1516,12 +1516,13 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error { return protojson.Unmarshal(data, m) } -func readError(r io.ReadCloser) error { - defer r.Close() +func readError(r *http.Response) error { + defer r.Body.Close() apiErr := new(errs.Error) - if err := json.NewDecoder(r).Decode(apiErr); err != nil { + if err := json.NewDecoder(r.Body).Decode(apiErr); err != nil { return err } + apiErr.RequestID = r.Header.Get("X-Request-Id") return apiErr } diff --git a/errs/error.go b/errs/error.go index ba0669256..c9ad92a69 100644 --- a/errs/error.go +++ b/errs/error.go @@ -49,10 +49,11 @@ func WithKeyVal(key string, val interface{}) Option { // Error represents the CA API errors. type Error struct { - Status int - Err error - Msg string - Details map[string]interface{} + Status int + Err error + Msg string + Details map[string]interface{} + RequestID string `json:"-"` } // ErrorResponse represents an error in JSON format. diff --git a/test/e2e/requestid_test.go b/test/e2e/requestid_test.go new file mode 100644 index 000000000..7eccb4f40 --- /dev/null +++ b/test/e2e/requestid_test.go @@ -0,0 +1,102 @@ +package e2e + +import ( + "context" + "encoding/json" + "fmt" + "net" + "path/filepath" + "sync" + "testing" + + "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/ca" + "github.com/smallstep/certificates/errs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/minica" + "go.step.sm/crypto/pemutil" +) + +func TestXxx(t *testing.T) { + dir := t.TempDir() + m, err := minica.New(minica.WithName("Step E2E")) + require.NoError(t, err) + + rootFilepath := filepath.Join(dir, "root.crt") + _, err = pemutil.Serialize(m.Root, pemutil.WithFilename(rootFilepath)) + require.NoError(t, err) + + intermediateCertFilepath := filepath.Join(dir, "intermediate.crt") + _, err = pemutil.Serialize(m.Intermediate, pemutil.WithFilename(intermediateCertFilepath)) + require.NoError(t, err) + + intermediateKeyFilepath := filepath.Join(dir, "intermediate.key") + _, err = pemutil.Serialize(m.Signer, pemutil.WithFilename(intermediateKeyFilepath)) + require.NoError(t, err) + + // get a random address to listen on and connect to; currently no nicer way to get one before starting the server + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + randomAddress := l.Addr().String() + err = l.Close() + require.NoError(t, err) + + cfg := &config.Config{ + Root: []string{rootFilepath}, + IntermediateCert: intermediateCertFilepath, + IntermediateKey: intermediateKeyFilepath, + Address: randomAddress, // reuse the address that was just "reserved" + DNSNames: []string{"127.0.0.1", "stepca.localhost"}, + AuthorityConfig: &config.AuthConfig{ + AuthorityID: "stepca-test", + DeploymentType: "standalone-test", + }, + Logger: json.RawMessage(`{"format": "text"}`), + } + c, err := ca.New(cfg) + require.NoError(t, err) + + // instantiate a client for the CA + client, err := ca.NewClient( + fmt.Sprintf("https://%s", randomAddress), + ca.WithRootFile(rootFilepath), + ) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + err = c.Run() + require.Error(t, err) // expect error when server is stopped + }() + + // require OK health response as the baseline + ctx := context.Background() + healthResponse, err := client.HealthWithContext(ctx) + assert.NoError(t, err) + require.Equal(t, "ok", healthResponse.Status) + + // expect an error when retrieving an invalid root + rootResponse, err := client.RootWithContext(ctx, "invalid") + if assert.Error(t, err) { + apiErr := &errs.Error{} + if assert.ErrorAs(t, err, &apiErr) { + assert.Equal(t, 404, apiErr.StatusCode()) + assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error()) + assert.NotEmpty(t, apiErr.RequestID) + + // TODO: include the below error in the JSON? It's currently only output to the CA logs + //assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg) + } + } + assert.Nil(t, rootResponse) + + // done testing; stop and wait for the server to quit + err = c.Stop() + require.NoError(t, err) + + wg.Wait() +} From 5c2572c44397bbaf77a4e744e22a43aeb3dc30cf Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 28 Feb 2024 01:55:35 +0100 Subject: [PATCH 19/31] Add support for user provider `X-Request-Id` header value --- ca/client.go | 16 +++++++++++++--- ca/client/requestid.go | 17 +++++++++++++++++ test/e2e/requestid_test.go | 26 +++++++++++++++++++++----- 3 files changed, 51 insertions(+), 8 deletions(-) create mode 100644 ca/client/requestid.go diff --git a/ca/client.go b/ca/client.go index 8930d8ee1..d7ec28752 100644 --- a/ca/client.go +++ b/ca/client.go @@ -28,6 +28,7 @@ import ( "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/ca/client" "github.com/smallstep/certificates/ca/identity" "github.com/smallstep/certificates/errs" "go.step.sm/cli-utils/step" @@ -105,10 +106,19 @@ func (c *uaClient) PostWithContext(ctx context.Context, u, contentType string, b const requestIDHeader = "X-Request-Id" // enforceRequestID checks if the X-Request-Id HTTP header is filled. If it's -// empty, it'll generate a new request ID and set the header. +// empty, the context is searched for a request ID. If that's also empty, a new +// request ID is generated. func enforceRequestID(r *http.Request) { - if r.Header.Get(requestIDHeader) == "" { - r.Header.Set(requestIDHeader, xid.New().String()) + requestID := r.Header.Get(requestIDHeader) + if requestID == "" { + if reqID, ok := client.GetRequestID(r.Context()); ok && reqID != "" { + // TODO(hs): ensure the request ID from the context is fresh, and thus hasn't been + // used before by the client (unless it's a retry for the same request)? + requestID = reqID + } else { + requestID = xid.New().String() + } + r.Header.Set(requestIDHeader, requestID) } } diff --git a/ca/client/requestid.go b/ca/client/requestid.go new file mode 100644 index 000000000..de92f8c0f --- /dev/null +++ b/ca/client/requestid.go @@ -0,0 +1,17 @@ +package client + +import "context" + +type requestIDKey struct{} + +// WithRequestID returns a new context with the given requestID added to the +// context. +func WithRequestID(ctx context.Context, requestID string) context.Context { + return context.WithValue(ctx, requestIDKey{}, requestID) +} + +// GetRequestID returns the request id from the context if it exists. +func GetRequestID(ctx context.Context) (string, bool) { + v, ok := ctx.Value(requestIDKey{}).(string) + return v, ok +} diff --git a/test/e2e/requestid_test.go b/test/e2e/requestid_test.go index 7eccb4f40..a1afd4234 100644 --- a/test/e2e/requestid_test.go +++ b/test/e2e/requestid_test.go @@ -11,6 +11,7 @@ import ( "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/ca" + "github.com/smallstep/certificates/ca/client" "github.com/smallstep/certificates/errs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -57,8 +58,8 @@ func TestXxx(t *testing.T) { c, err := ca.New(cfg) require.NoError(t, err) - // instantiate a client for the CA - client, err := ca.NewClient( + // instantiate a client for the CA running at the random address + caClient, err := ca.NewClient( fmt.Sprintf("https://%s", randomAddress), ca.WithRootFile(rootFilepath), ) @@ -75,12 +76,12 @@ func TestXxx(t *testing.T) { // require OK health response as the baseline ctx := context.Background() - healthResponse, err := client.HealthWithContext(ctx) + healthResponse, err := caClient.HealthWithContext(ctx) assert.NoError(t, err) - require.Equal(t, "ok", healthResponse.Status) + assert.Equal(t, "ok", healthResponse.Status) // expect an error when retrieving an invalid root - rootResponse, err := client.RootWithContext(ctx, "invalid") + rootResponse, err := caClient.RootWithContext(ctx, "invalid") if assert.Error(t, err) { apiErr := &errs.Error{} if assert.ErrorAs(t, err, &apiErr) { @@ -94,6 +95,21 @@ func TestXxx(t *testing.T) { } assert.Nil(t, rootResponse) + // expect an error when retrieving an invalid root and provided request ID + rootResponse, err = caClient.RootWithContext(client.WithRequestID(ctx, "reqID"), "invalid") + if assert.Error(t, err) { + apiErr := &errs.Error{} + if assert.ErrorAs(t, err, &apiErr) { + assert.Equal(t, 404, apiErr.StatusCode()) + assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error()) + assert.Equal(t, "reqID", apiErr.RequestID) + + // TODO: include the below error in the JSON? It's currently only output to the CA logs + //assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg) + } + } + assert.Nil(t, rootResponse) + // done testing; stop and wait for the server to quit err = c.Stop() require.NoError(t, err) From 2255857b3a59a6e9bb8a665e9543770889451e41 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 28 Feb 2024 10:50:49 +0100 Subject: [PATCH 20/31] Fix `client` shadowing and e2e request ID test case --- ca/client.go | 20 ++++++++++---------- test/e2e/requestid_test.go | 10 ++++++---- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/ca/client.go b/ca/client.go index d7ec28752..0c0f9907e 100644 --- a/ca/client.go +++ b/ca/client.go @@ -397,8 +397,8 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) { if err != nil { return nil, err } - client := &Client{endpoint: u} - root, err := client.Root(sum) + caClient := &Client{endpoint: u} + root, err := caClient.Root(sum) if err != nil { return nil, err } @@ -759,14 +759,14 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { func (c *Client) RenewWithContext(ctx context.Context, tr http.RoundTripper) (*api.SignResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) - client := &http.Client{Transport: tr} + caClient := &http.Client{Transport: tr} retry: req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") - resp, err := client.Do(req) + resp, err := caClient.Do(req) if err != nil { return nil, clientError(err) } @@ -836,14 +836,14 @@ func (c *Client) RekeyWithContext(ctx context.Context, req *api.RekeyRequest, tr return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"}) - client := &http.Client{Transport: tr} + caClient := &http.Client{Transport: tr} retry: httpReq, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body)) if err != nil { return nil, err } httpReq.Header.Set("Content-Type", "application/json") - resp, err := client.Do(httpReq) + resp, err := caClient.Do(httpReq) if err != nil { return nil, clientError(err) } @@ -875,16 +875,16 @@ func (c *Client) RevokeWithContext(ctx context.Context, req *api.RevokeRequest, if err != nil { return nil, errors.Wrap(err, "error marshaling request") } - var client *uaClient + var uaClient *uaClient retry: if tr != nil { - client = newClient(tr) + uaClient = newClient(tr) } else { - client = c.client + uaClient = c.client } u := c.endpoint.ResolveReference(&url.URL{Path: "/revoke"}) - resp, err := client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) + resp, err := uaClient.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } diff --git a/test/e2e/requestid_test.go b/test/e2e/requestid_test.go index a1afd4234..2653039c1 100644 --- a/test/e2e/requestid_test.go +++ b/test/e2e/requestid_test.go @@ -19,7 +19,7 @@ import ( "go.step.sm/crypto/pemutil" ) -func TestXxx(t *testing.T) { +func Test_reflectRequestID(t *testing.T) { dir := t.TempDir() m, err := minica.New(minica.WithName("Step E2E")) require.NoError(t, err) @@ -37,9 +37,11 @@ func TestXxx(t *testing.T) { require.NoError(t, err) // get a random address to listen on and connect to; currently no nicer way to get one before starting the server - l, err := net.Listen("tcp", "127.0.0.1:0") + l, err := net.Listen("tcp4", ":0") require.NoError(t, err) randomAddress := l.Addr().String() + _, port, err := net.SplitHostPort(l.Addr().String()) + require.NoError(t, err) err = l.Close() require.NoError(t, err) @@ -48,7 +50,7 @@ func TestXxx(t *testing.T) { IntermediateCert: intermediateCertFilepath, IntermediateKey: intermediateKeyFilepath, Address: randomAddress, // reuse the address that was just "reserved" - DNSNames: []string{"127.0.0.1", "stepca.localhost"}, + DNSNames: []string{"127.0.0.1", "[::1]", "localhost"}, AuthorityConfig: &config.AuthConfig{ AuthorityID: "stepca-test", DeploymentType: "standalone-test", @@ -60,7 +62,7 @@ func TestXxx(t *testing.T) { // instantiate a client for the CA running at the random address caClient, err := ca.NewClient( - fmt.Sprintf("https://%s", randomAddress), + fmt.Sprintf("https://localhost:%s", port), ca.WithRootFile(rootFilepath), ) require.NoError(t, err) From b83b8aa079b6c6711e27a2bfb2fd3404efefb402 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 28 Feb 2024 11:09:40 +0100 Subject: [PATCH 21/31] Make random TCP address reservation more contained --- test/e2e/requestid_test.go | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/test/e2e/requestid_test.go b/test/e2e/requestid_test.go index 2653039c1..e87b46e5e 100644 --- a/test/e2e/requestid_test.go +++ b/test/e2e/requestid_test.go @@ -19,6 +19,25 @@ import ( "go.step.sm/crypto/pemutil" ) +// reserveAddress "reserves" a TCP address by opening a listener on a random +// port and immediately closing it. The address can then be assumed to be +// available for running a server on. +func reserveAddress(t *testing.T) string { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + require.NoError(t, err, "failed to listen on a port") + } + } + + address := l.Addr().String() + err = l.Close() + require.NoError(t, err) + + return address +} + func Test_reflectRequestID(t *testing.T) { dir := t.TempDir() m, err := minica.New(minica.WithName("Step E2E")) @@ -37,19 +56,14 @@ func Test_reflectRequestID(t *testing.T) { require.NoError(t, err) // get a random address to listen on and connect to; currently no nicer way to get one before starting the server - l, err := net.Listen("tcp4", ":0") - require.NoError(t, err) - randomAddress := l.Addr().String() - _, port, err := net.SplitHostPort(l.Addr().String()) - require.NoError(t, err) - err = l.Close() - require.NoError(t, err) + // TODO(hs): find/implement a nicer way to expose the CA URL, similar to how e.g. httptest.Server exposes it? + address := reserveAddress(t) cfg := &config.Config{ Root: []string{rootFilepath}, IntermediateCert: intermediateCertFilepath, IntermediateKey: intermediateKeyFilepath, - Address: randomAddress, // reuse the address that was just "reserved" + Address: address, // reuse the address that was just "reserved" DNSNames: []string{"127.0.0.1", "[::1]", "localhost"}, AuthorityConfig: &config.AuthConfig{ AuthorityID: "stepca-test", @@ -62,7 +76,7 @@ func Test_reflectRequestID(t *testing.T) { // instantiate a client for the CA running at the random address caClient, err := ca.NewClient( - fmt.Sprintf("https://localhost:%s", port), + fmt.Sprintf("https://%s", address), ca.WithRootFile(rootFilepath), ) require.NoError(t, err) @@ -91,7 +105,7 @@ func Test_reflectRequestID(t *testing.T) { assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error()) assert.NotEmpty(t, apiErr.RequestID) - // TODO: include the below error in the JSON? It's currently only output to the CA logs + // TODO: include the below error in the JSON? It's currently only output to the CA logs. Also see https://github.com/smallstep/certificates/pull/759 //assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg) } } @@ -105,9 +119,6 @@ func Test_reflectRequestID(t *testing.T) { assert.Equal(t, 404, apiErr.StatusCode()) assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error()) assert.Equal(t, "reqID", apiErr.RequestID) - - // TODO: include the below error in the JSON? It's currently only output to the CA logs - //assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg) } } assert.Nil(t, rootResponse) From 535e2a96d5eb099f7af88ce622d13b89c0a25878 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 28 Feb 2024 11:23:51 +0100 Subject: [PATCH 22/31] Fix the e2e request ID test (again) --- test/e2e/requestid_test.go | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/test/e2e/requestid_test.go b/test/e2e/requestid_test.go index e87b46e5e..62b2feb10 100644 --- a/test/e2e/requestid_test.go +++ b/test/e2e/requestid_test.go @@ -19,23 +19,22 @@ import ( "go.step.sm/crypto/pemutil" ) -// reserveAddress "reserves" a TCP address by opening a listener on a random -// port and immediately closing it. The address can then be assumed to be +// reservePort "reserves" a TCP port by opening a listener on a random +// port and immediately closing it. The port can then be assumed to be // available for running a server on. -func reserveAddress(t *testing.T) string { +func reservePort(t *testing.T) (host, port string) { t.Helper() - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { - require.NoError(t, err, "failed to listen on a port") - } - } + l, err := net.Listen("tcp", ":0") + require.NoError(t, err) address := l.Addr().String() err = l.Close() require.NoError(t, err) - return address + host, port, err = net.SplitHostPort(address) + require.NoError(t, err) + + return } func Test_reflectRequestID(t *testing.T) { @@ -57,13 +56,13 @@ func Test_reflectRequestID(t *testing.T) { // get a random address to listen on and connect to; currently no nicer way to get one before starting the server // TODO(hs): find/implement a nicer way to expose the CA URL, similar to how e.g. httptest.Server exposes it? - address := reserveAddress(t) + host, port := reservePort(t) cfg := &config.Config{ Root: []string{rootFilepath}, IntermediateCert: intermediateCertFilepath, IntermediateKey: intermediateKeyFilepath, - Address: address, // reuse the address that was just "reserved" + Address: net.JoinHostPort(host, port), // reuse the address that was just "reserved" DNSNames: []string{"127.0.0.1", "[::1]", "localhost"}, AuthorityConfig: &config.AuthConfig{ AuthorityID: "stepca-test", @@ -76,7 +75,7 @@ func Test_reflectRequestID(t *testing.T) { // instantiate a client for the CA running at the random address caClient, err := ca.NewClient( - fmt.Sprintf("https://%s", address), + fmt.Sprintf("https://localhost:%s", port), ca.WithRootFile(rootFilepath), ) require.NoError(t, err) @@ -93,8 +92,10 @@ func Test_reflectRequestID(t *testing.T) { // require OK health response as the baseline ctx := context.Background() healthResponse, err := caClient.HealthWithContext(ctx) - assert.NoError(t, err) - assert.Equal(t, "ok", healthResponse.Status) + require.NoError(t, err) + if assert.NotNil(t, healthResponse) { + require.Equal(t, "ok", healthResponse.Status) + } // expect an error when retrieving an invalid root rootResponse, err := caClient.RootWithContext(ctx, "invalid") From 7e5f10927feb34d446e2b28ca394d65f9bbb72d8 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 28 Feb 2024 13:18:10 +0100 Subject: [PATCH 23/31] Decouple request ID middleware from logging middleware --- authority/provisioner/webhook.go | 7 +- authority/provisioner/webhook_test.go | 8 +- ca/ca.go | 7 ++ errs/errors_test.go | 27 +++--- internal/requestid/requestid.go | 82 +++++++++++++++++++ .../requestid/requestid_test.go | 53 ++++++------ logging/context.go | 72 +--------------- logging/handler.go | 20 ++--- monitoring/monitoring.go | 3 +- 9 files changed, 155 insertions(+), 124 deletions(-) create mode 100644 internal/requestid/requestid.go rename logging/context_test.go => internal/requestid/requestid_test.go (65%) diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index c33dfa230..1e08b8b78 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -15,7 +15,7 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/webhook" "go.step.sm/linkedca" @@ -171,9 +171,8 @@ retry: return nil, err } - requestID, ok := logging.GetRequestID(ctx) - if ok { - req.Header.Set("X-Request-ID", requestID) + if requestID, ok := requestid.FromContext(ctx); ok { + req.Header.Set("X-Request-Id", requestID) } secret, err := base64.StdEncoding.DecodeString(w.Secret) diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 60dcdbc71..4c80796f1 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -17,7 +17,7 @@ import ( "testing" "time" - "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/webhook" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -101,10 +101,10 @@ func TestWebhookController_isCertTypeOK(t *testing.T) { } } -// withRequestID is a helper that calls into [logging.WithRequestID] and returns -// a new context with the requestID added to the provided context. +// withRequestID is a helper that calls into [requestid.NewContext] and returns +// a new context with the requestID added. func withRequestID(ctx context.Context, requestID string) context.Context { - return logging.WithRequestID(ctx, requestID) + return requestid.NewContext(ctx, requestID) } func TestWebhookController_Enrich(t *testing.T) { diff --git a/ca/ca.go b/ca/ca.go index 4146466db..ab4a1a9b9 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -29,6 +29,7 @@ import ( "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/internal/metrix" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/monitoring" "github.com/smallstep/certificates/scep" @@ -329,15 +330,21 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } // Add logger if configured + var legacyTraceHeader string if len(cfg.Logger) > 0 { logger, err := logging.New("ca", cfg.Logger) if err != nil { return nil, err } + legacyTraceHeader = logger.GetTraceHeader() handler = logger.Middleware(handler) insecureHandler = logger.Middleware(insecureHandler) } + // always use request ID middleware; traceHeader is provided for backwards compatibility (for now) + handler = requestid.New(legacyTraceHeader).Middleware(handler) + insecureHandler = requestid.New(legacyTraceHeader).Middleware(insecureHandler) + // Create context with all the necessary values. baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) diff --git a/errs/errors_test.go b/errs/errors_test.go index 7b83c8d9c..11590d7d6 100644 --- a/errs/errors_test.go +++ b/errs/errors_test.go @@ -2,8 +2,9 @@ package errs import ( "fmt" - "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestError_MarshalJSON(t *testing.T) { @@ -27,13 +28,14 @@ func TestError_MarshalJSON(t *testing.T) { Err: tt.fields.Err, } got, err := e.MarshalJSON() - if (err != nil) != tt.wantErr { - t.Errorf("Error.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Empty(t, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Error.MarshalJSON() = %s, want %s", got, tt.want) - } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) }) } } @@ -54,13 +56,14 @@ func TestError_UnmarshalJSON(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := new(Error) - if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { - t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) - } - //nolint:govet // best option - if !reflect.DeepEqual(tt.expected, e) { - t.Errorf("Error.UnmarshalJSON() wants = %+v, got %+v", tt.expected, e) + err := e.UnmarshalJSON(tt.args.data) + if tt.wantErr { + assert.Error(t, err) + return } + + assert.NoError(t, err) + assert.Equal(t, tt.expected, e) }) } } diff --git a/internal/requestid/requestid.go b/internal/requestid/requestid.go new file mode 100644 index 000000000..97f58f8ca --- /dev/null +++ b/internal/requestid/requestid.go @@ -0,0 +1,82 @@ +package requestid + +import ( + "context" + "net/http" + + "github.com/rs/xid" +) + +const ( + // requestIDHeader is the header name used for propagating request IDs. If + // available in an HTTP request, it'll be used instead of the X-Smallstep-Id + // header. It'll always be used in response and set to the request ID. + requestIDHeader = "X-Request-Id" + + // defaultTraceHeader is the default Smallstep tracing header that's currently + // in use. It is used as a fallback to retrieve a request ID from, if the + // "X-Request-Id" request header is not set. + defaultTraceHeader = "X-Smallstep-Id" +) + +type Handler struct { + legacyTraceHeader string +} + +// New creates a new request ID [handler]. It takes a trace header, +// which is used keep the legacy behavior intact, which relies on the +// X-Smallstep-Id header instead of X-Request-Id. +func New(legacyTraceHeader string) *Handler { + if legacyTraceHeader == "" { + legacyTraceHeader = defaultTraceHeader + } + + return &Handler{legacyTraceHeader: legacyTraceHeader} +} + +// Middleware wraps an [http.Handler] with request ID extraction +// from the X-Reqeust-Id header by default, or from the X-Smallstep-Id +// header if not set. If both are not set, a new request ID is generated. +// In all cases, the request ID is added to the request context, and +// set to be reflected in the response. +func (h *Handler) Middleware(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, req *http.Request) { + requestID := req.Header.Get(requestIDHeader) + if requestID == "" { + requestID = req.Header.Get(h.legacyTraceHeader) + } + + if requestID == "" { + requestID = newRequestID() + req.Header.Set(h.legacyTraceHeader, requestID) // legacy behavior + } + + // immediately set the request ID to be reflected in the response + w.Header().Set(requestIDHeader, requestID) + + // continue down the handler chain + ctx := NewContext(req.Context(), requestID) + next.ServeHTTP(w, req.WithContext(ctx)) + } + return http.HandlerFunc(fn) +} + +// newRequestID creates a new request ID using github.com/rs/xid. +func newRequestID() string { + return xid.New().String() +} + +type requestIDKey struct{} + +// NewContext returns a new context with the given request ID added to the +// context. +func NewContext(ctx context.Context, requestID string) context.Context { + return context.WithValue(ctx, requestIDKey{}, requestID) +} + +// FromContext returns the request ID from the context if it exists and +// is not the empty value. +func FromContext(ctx context.Context) (string, bool) { + v, ok := ctx.Value(requestIDKey{}).(string) + return v, ok && v != "" +} diff --git a/logging/context_test.go b/internal/requestid/requestid_test.go similarity index 65% rename from logging/context_test.go rename to internal/requestid/requestid_test.go index da993f7bd..4d0e872dd 100644 --- a/logging/context_test.go +++ b/internal/requestid/requestid_test.go @@ -1,4 +1,4 @@ -package logging +package requestid import ( "net/http" @@ -10,12 +10,13 @@ import ( ) func newRequest(t *testing.T) *http.Request { + t.Helper() r, err := http.NewRequest(http.MethodGet, "https://example.com", http.NoBody) require.NoError(t, err) return r } -func TestRequestID(t *testing.T) { +func Test_Middleware(t *testing.T) { requestWithID := newRequest(t) requestWithID.Header.Set("X-Request-Id", "reqID") requestWithoutID := newRequest(t) @@ -23,20 +24,19 @@ func TestRequestID(t *testing.T) { requestWithEmptyHeader.Header.Set("X-Request-Id", "") requestWithSmallstepID := newRequest(t) requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID") - tests := []struct { - name string - headerName string - handler http.HandlerFunc - req *http.Request + name string + traceHeader string + next http.HandlerFunc + req *http.Request }{ { - name: "default-request-id", - headerName: defaultTraceIDHeader, - handler: func(w http.ResponseWriter, r *http.Request) { + name: "default-request-id", + traceHeader: defaultTraceHeader, + next: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Smallstep-Id")) assert.Equal(t, "reqID", r.Header.Get("X-Request-Id")) - reqID, ok := GetRequestID(r.Context()) + reqID, ok := FromContext(r.Context()) if assert.True(t, ok) { assert.Equal(t, "reqID", reqID) } @@ -45,13 +45,13 @@ func TestRequestID(t *testing.T) { req: requestWithID, }, { - name: "no-request-id", - headerName: "X-Request-Id", - handler: func(w http.ResponseWriter, r *http.Request) { + name: "no-request-id", + traceHeader: "X-Request-Id", + next: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Smallstep-Id")) value := r.Header.Get("X-Request-Id") assert.NotEmpty(t, value) - reqID, ok := GetRequestID(r.Context()) + reqID, ok := FromContext(r.Context()) if assert.True(t, ok) { assert.Equal(t, value, reqID) } @@ -60,13 +60,13 @@ func TestRequestID(t *testing.T) { req: requestWithoutID, }, { - name: "empty-header-name", - headerName: "", - handler: func(w http.ResponseWriter, r *http.Request) { + name: "empty-header", + traceHeader: "", + next: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Request-Id")) value := r.Header.Get("X-Smallstep-Id") assert.NotEmpty(t, value) - reqID, ok := GetRequestID(r.Context()) + reqID, ok := FromContext(r.Context()) if assert.True(t, ok) { assert.Equal(t, value, reqID) } @@ -75,12 +75,12 @@ func TestRequestID(t *testing.T) { req: requestWithEmptyHeader, }, { - name: "fallback-header-name", - headerName: defaultTraceIDHeader, - handler: func(w http.ResponseWriter, r *http.Request) { + name: "fallback-header-name", + traceHeader: defaultTraceHeader, + next: func(w http.ResponseWriter, r *http.Request) { assert.Empty(t, r.Header.Get("X-Request-Id")) assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id")) - reqID, ok := GetRequestID(r.Context()) + reqID, ok := FromContext(r.Context()) if assert.True(t, ok) { assert.Equal(t, "smallstepID", reqID) } @@ -91,8 +91,11 @@ func TestRequestID(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := RequestID(tt.headerName) - h(tt.handler).ServeHTTP(httptest.NewRecorder(), tt.req) + handler := New(tt.traceHeader).Middleware(tt.next) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, tt.req) + assert.NotEmpty(t, w.Header().Get("X-Request-Id")) }) } } diff --git a/logging/context.go b/logging/context.go index 9d7a70713..212e2560a 100644 --- a/logging/context.go +++ b/logging/context.go @@ -2,82 +2,18 @@ package logging import ( "context" - "net/http" - - "github.com/rs/xid" -) - -type key int - -const ( - // RequestIDKey is the context key that should store the request identifier. - RequestIDKey key = iota - // UserIDKey is the context key that should store the user identifier. - UserIDKey ) -// NewRequestID creates a new request id using github.com/rs/xid. -func NewRequestID() string { - return xid.New().String() -} - -// requestIDHeader is the header name used for propagating request IDs. If -// available in an HTTP request, it'll be used instead of the X-Smallstep-Id -// header. It'll always be used in response and set to the request ID. -const requestIDHeader = "X-Request-Id" - -// RequestID returns a new middleware that obtains the current request ID -// and sets it in the context. It first tries to read the request ID from -// the "X-Request-Id" header. If that's not set, it tries to read it from -// the provided header name. If the header does not exist or its value is -// the empty string, it uses github.com/rs/xid to create a new one. -func RequestID(headerName string) func(next http.Handler) http.Handler { - if headerName == "" { - headerName = defaultTraceIDHeader - } - return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, req *http.Request) { - requestID := req.Header.Get(requestIDHeader) - if requestID == "" { - requestID = req.Header.Get(headerName) - } - - if requestID == "" { - requestID = NewRequestID() - req.Header.Set(headerName, requestID) - } - - // immediately set the request ID to be reflected in the response - w.Header().Set(requestIDHeader, requestID) - - // continue down the handler chain - ctx := WithRequestID(req.Context(), requestID) - next.ServeHTTP(w, req.WithContext(ctx)) - } - return http.HandlerFunc(fn) - } -} - -// WithRequestID returns a new context with the given requestID added to the -// context. -func WithRequestID(ctx context.Context, requestID string) context.Context { - return context.WithValue(ctx, RequestIDKey, requestID) -} - -// GetRequestID returns the request id from the context if it exists. -func GetRequestID(ctx context.Context) (string, bool) { - v, ok := ctx.Value(RequestIDKey).(string) - return v, ok -} +type userIDKey struct{} // WithUserID decodes the token, extracts the user from the payload and stores // it in the context. func WithUserID(ctx context.Context, userID string) context.Context { - return context.WithValue(ctx, UserIDKey, userID) + return context.WithValue(ctx, userIDKey{}, userID) } // GetUserID returns the request id from the context if it exists. func GetUserID(ctx context.Context) (string, bool) { - v, ok := ctx.Value(UserIDKey).(string) - return v, ok + v, ok := ctx.Value(userIDKey{}).(string) + return v, ok && v != "" } diff --git a/logging/handler.go b/logging/handler.go index a8b77d603..77287690b 100644 --- a/logging/handler.go +++ b/logging/handler.go @@ -9,6 +9,7 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/smallstep/certificates/internal/requestid" ) // LoggerHandler creates a logger handler @@ -29,16 +30,15 @@ type options struct { // NewLoggerHandler returns the given http.Handler with the logger integrated. func NewLoggerHandler(name string, logger *Logger, next http.Handler) http.Handler { - h := RequestID(logger.GetTraceHeader()) onlyTraceHealthEndpoint, _ := strconv.ParseBool(os.Getenv("STEP_LOGGER_ONLY_TRACE_HEALTH_ENDPOINT")) - return h(&LoggerHandler{ + return &LoggerHandler{ name: name, logger: logger.GetImpl(), options: options{ onlyTraceHealthEndpoint: onlyTraceHealthEndpoint, }, next: next, - }) + } } // ServeHTTP implements the http.Handler and call to the handler to log with a @@ -54,14 +54,14 @@ func (l *LoggerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // writeEntry writes to the Logger writer the request information in the logger. func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Time, d time.Duration) { - var reqID, user string + var requestID, userID string ctx := r.Context() - if v, ok := ctx.Value(RequestIDKey).(string); ok && v != "" { - reqID = v + if v, ok := requestid.FromContext(ctx); ok { + requestID = v } - if v, ok := ctx.Value(UserIDKey).(string); ok && v != "" { - user = v + if v, ok := GetUserID(ctx); ok && v != "" { + userID = v } // Remote hostname @@ -85,10 +85,10 @@ func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Tim status := w.StatusCode() fields := logrus.Fields{ - "request-id": reqID, + "request-id": requestID, "remote-address": addr, "name": l.name, - "user-id": user, + "user-id": userID, "time": t.Format(time.RFC3339), "duration-ns": d.Nanoseconds(), "duration": d.String(), diff --git a/monitoring/monitoring.go b/monitoring/monitoring.go index a0d0886b3..7c88ab3b9 100644 --- a/monitoring/monitoring.go +++ b/monitoring/monitoring.go @@ -9,6 +9,7 @@ import ( "github.com/newrelic/go-agent/v3/newrelic" "github.com/pkg/errors" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/logging" ) @@ -82,7 +83,7 @@ func newRelicMiddleware(app *newrelic.Application) Middleware { txn.AddAttribute("httpResponseCode", strconv.Itoa(status)) // Add custom attributes - if v, ok := logging.GetRequestID(r.Context()); ok { + if v, ok := requestid.FromContext(r.Context()); ok { txn.AddAttribute("request.id", v) } From 06696e64926f5247b9e1dea2a8839b398f731fbd Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 28 Feb 2024 13:37:51 +0100 Subject: [PATCH 24/31] Move user ID handling to `userid` package --- internal/userid/userid.go | 20 ++++++++++++++++++++ logging/context.go | 19 ------------------- logging/handler.go | 3 ++- 3 files changed, 22 insertions(+), 20 deletions(-) create mode 100644 internal/userid/userid.go delete mode 100644 logging/context.go diff --git a/internal/userid/userid.go b/internal/userid/userid.go new file mode 100644 index 000000000..bab4908f3 --- /dev/null +++ b/internal/userid/userid.go @@ -0,0 +1,20 @@ +package userid + +import "context" + +type userIDKey struct{} + +// NewContext returns a new context with the given user ID added to the +// context. +// TODO(hs): this doesn't seem to be used / set currently; implement +// when/where it makes sense. +func NewContext(ctx context.Context, userID string) context.Context { + return context.WithValue(ctx, userIDKey{}, userID) +} + +// FromContext returns the user ID from the context if it exists +// and is not empty. +func FromContext(ctx context.Context) (string, bool) { + v, ok := ctx.Value(userIDKey{}).(string) + return v, ok && v != "" +} diff --git a/logging/context.go b/logging/context.go deleted file mode 100644 index 212e2560a..000000000 --- a/logging/context.go +++ /dev/null @@ -1,19 +0,0 @@ -package logging - -import ( - "context" -) - -type userIDKey struct{} - -// WithUserID decodes the token, extracts the user from the payload and stores -// it in the context. -func WithUserID(ctx context.Context, userID string) context.Context { - return context.WithValue(ctx, userIDKey{}, userID) -} - -// GetUserID returns the request id from the context if it exists. -func GetUserID(ctx context.Context) (string, bool) { - v, ok := ctx.Value(userIDKey{}).(string) - return v, ok && v != "" -} diff --git a/logging/handler.go b/logging/handler.go index 77287690b..a29383b28 100644 --- a/logging/handler.go +++ b/logging/handler.go @@ -10,6 +10,7 @@ import ( "github.com/sirupsen/logrus" "github.com/smallstep/certificates/internal/requestid" + "github.com/smallstep/certificates/internal/userid" ) // LoggerHandler creates a logger handler @@ -60,7 +61,7 @@ func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Tim if v, ok := requestid.FromContext(ctx); ok { requestID = v } - if v, ok := GetUserID(ctx); ok && v != "" { + if v, ok := userid.FromContext(ctx); ok { userID = v } From 532b9df0a3cbf312ef0e54aa8b350c00309e6bab Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 28 Feb 2024 13:57:37 +0100 Subject: [PATCH 25/31] Improve CA client request ID handling --- ca/client.go | 15 +++-- ca/client/requestid.go | 11 ++-- ca/client_test.go | 120 +++++++++++++++++++++++++------------ test/e2e/requestid_test.go | 2 +- 4 files changed, 95 insertions(+), 53 deletions(-) diff --git a/ca/client.go b/ca/client.go index 0c0f9907e..9e245cd70 100644 --- a/ca/client.go +++ b/ca/client.go @@ -109,9 +109,8 @@ const requestIDHeader = "X-Request-Id" // empty, the context is searched for a request ID. If that's also empty, a new // request ID is generated. func enforceRequestID(r *http.Request) { - requestID := r.Header.Get(requestIDHeader) - if requestID == "" { - if reqID, ok := client.GetRequestID(r.Context()); ok && reqID != "" { + if requestID := r.Header.Get(requestIDHeader); requestID == "" { + if reqID, ok := client.RequestIDFromContext(r.Context()); ok { // TODO(hs): ensure the request ID from the context is fresh, and thus hasn't been // used before by the client (unless it's a retry for the same request)? requestID = reqID @@ -759,14 +758,14 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { func (c *Client) RenewWithContext(ctx context.Context, tr http.RoundTripper) (*api.SignResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) - caClient := &http.Client{Transport: tr} + httpClient := &http.Client{Transport: tr} retry: req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") - resp, err := caClient.Do(req) + resp, err := httpClient.Do(req) if err != nil { return nil, clientError(err) } @@ -836,14 +835,14 @@ func (c *Client) RekeyWithContext(ctx context.Context, req *api.RekeyRequest, tr return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"}) - caClient := &http.Client{Transport: tr} + httpClient := &http.Client{Transport: tr} retry: httpReq, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body)) if err != nil { return nil, err } httpReq.Header.Set("Content-Type", "application/json") - resp, err := caClient.Do(httpReq) + resp, err := httpClient.Do(httpReq) if err != nil { return nil, clientError(err) } @@ -1530,7 +1529,7 @@ func readError(r *http.Response) error { defer r.Body.Close() apiErr := new(errs.Error) if err := json.NewDecoder(r.Body).Decode(apiErr); err != nil { - return err + return fmt.Errorf("failed decoding CA error response: %w", err) } apiErr.RequestID = r.Header.Get("X-Request-Id") return apiErr diff --git a/ca/client/requestid.go b/ca/client/requestid.go index de92f8c0f..2bebb7e53 100644 --- a/ca/client/requestid.go +++ b/ca/client/requestid.go @@ -4,14 +4,15 @@ import "context" type requestIDKey struct{} -// WithRequestID returns a new context with the given requestID added to the +// NewRequestIDContext returns a new context with the given request ID added to the // context. -func WithRequestID(ctx context.Context, requestID string) context.Context { +func NewRequestIDContext(ctx context.Context, requestID string) context.Context { return context.WithValue(ctx, requestIDKey{}, requestID) } -// GetRequestID returns the request id from the context if it exists. -func GetRequestID(ctx context.Context) (string, bool) { +// RequestIDFromContext returns the request ID from the context if it exists. +// and is not empty. +func RequestIDFromContext(ctx context.Context) (string, bool) { v, ok := ctx.Value(requestIDKey{}).(string) - return v, ok + return v, ok && v != "" } diff --git a/ca/client_test.go b/ca/client_test.go index 6292e3eae..6fe8a135d 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -17,16 +17,17 @@ import ( "testing" "time" - "go.step.sm/crypto/x509util" - "golang.org/x/crypto/ssh" - - "github.com/smallstep/assert" + sassert "github.com/smallstep/assert" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/ca/client" "github.com/smallstep/certificates/errs" + "github.com/stretchr/testify/assert" + "go.step.sm/crypto/x509util" + "golang.org/x/crypto/ssh" ) const ( @@ -196,7 +197,7 @@ func TestClient_Version(t *testing.T) { if got != nil { t.Errorf("Client.Version() = %v, want nil", got) } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) + sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Version() = %v, want %v", got, tt.response) @@ -247,7 +248,7 @@ func TestClient_Health(t *testing.T) { if got != nil { t.Errorf("Client.Health() = %v, want nil", got) } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) + sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Health() = %v, want %v", got, tt.response) @@ -304,7 +305,7 @@ func TestClient_Root(t *testing.T) { if got != nil { t.Errorf("Client.Root() = %v, want nil", got) } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) + sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Root() = %v, want %v", got, tt.response) @@ -359,7 +360,7 @@ func TestClient_Sign(t *testing.T) { body := new(api.SignRequest) if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) - assert.Fatal(t, ok, "response expected to be error type") + sassert.Fatal(t, ok, "response expected to be error type") render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { @@ -386,7 +387,7 @@ func TestClient_Sign(t *testing.T) { if got != nil { t.Errorf("Client.Sign() = %v, want nil", got) } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) + sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Sign() = %v, want %v", got, tt.response) @@ -431,7 +432,7 @@ func TestClient_Revoke(t *testing.T) { body := new(api.RevokeRequest) if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) - assert.Fatal(t, ok, "response expected to be error type") + sassert.Fatal(t, ok, "response expected to be error type") render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { @@ -458,7 +459,7 @@ func TestClient_Revoke(t *testing.T) { if got != nil { t.Errorf("Client.Revoke() = %v, want nil", got) } - assert.HasPrefix(t, err.Error(), tt.expectedErr.Error()) + sassert.HasPrefix(t, err.Error(), tt.expectedErr.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) @@ -520,10 +521,10 @@ func TestClient_Renew(t *testing.T) { } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, err.Error(), tt.err.Error()) + sassert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Renew() = %v, want %v", got, tt.response) @@ -589,10 +590,10 @@ func TestClient_RenewWithToken(t *testing.T) { } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, err.Error(), tt.err.Error()) + sassert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response) @@ -659,10 +660,10 @@ func TestClient_Rekey(t *testing.T) { } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, err.Error(), tt.err.Error()) + sassert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Renew() = %v, want %v", got, tt.response) @@ -722,7 +723,7 @@ func TestClient_Provisioners(t *testing.T) { if got != nil { t.Errorf("Client.Provisioners() = %v, want nil", got) } - assert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error()) + sassert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response) @@ -781,10 +782,10 @@ func TestClient_ProvisionerKey(t *testing.T) { } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, tt.err.Error(), err.Error()) + sassert.HasPrefix(t, tt.err.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response) @@ -841,10 +842,10 @@ func TestClient_Roots(t *testing.T) { t.Errorf("Client.Roots() = %v, want nil", got) } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, err.Error(), tt.err.Error()) + sassert.HasPrefix(t, err.Error(), tt.err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Roots() = %v, want %v", got, tt.response) @@ -900,10 +901,10 @@ func TestClient_Federation(t *testing.T) { t.Errorf("Client.Federation() = %v, want nil", got) } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, tt.err.Error(), err.Error()) + sassert.HasPrefix(t, tt.err.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.Federation() = %v, want %v", got, tt.response) @@ -963,10 +964,10 @@ func TestClient_SSHRoots(t *testing.T) { t.Errorf("Client.SSHKeys() = %v, want nil", got) } var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, tt.err.Error(), err.Error()) + sassert.HasPrefix(t, tt.err.Error(), err.Error()) default: if !reflect.DeepEqual(got, tt.response) { t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response) @@ -1069,11 +1070,11 @@ func TestClient_RootFingerprintWithServer(t *testing.T) { defer srv.Close() client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) - assert.FatalError(t, err) + sassert.FatalError(t, err) fp, err := client.RootFingerprint() - assert.FatalError(t, err) - assert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) + sassert.FatalError(t, err) + sassert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) } func TestClient_SSHBastion(t *testing.T) { @@ -1126,10 +1127,10 @@ func TestClient_SSHBastion(t *testing.T) { } if tt.responseCode != 200 { var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { + sassert.Equals(t, sc.StatusCode(), tt.responseCode) } - assert.HasPrefix(t, err.Error(), tt.err.Error()) + sassert.HasPrefix(t, err.Error(), tt.err.Error()) } default: if !reflect.DeepEqual(got, tt.response) { @@ -1164,3 +1165,44 @@ func TestClient_GetCaURL(t *testing.T) { }) } } + +func Test_enforceRequestID(t *testing.T) { + set := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + set.Header.Set("X-Request-Id", "already-set") + inContext := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + inContext = inContext.WithContext(client.NewRequestIDContext(inContext.Context(), "from-context")) + new := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + + tests := []struct { + name string + r *http.Request + want string + }{ + { + name: "set", + r: set, + want: "already-set", + }, + { + name: "context", + r: inContext, + want: "from-context", + }, + { + name: "new", + r: new, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + enforceRequestID(tt.r) + + v := tt.r.Header.Get("X-Request-Id") + if assert.NotEmpty(t, v) { + if tt.want != "" { + assert.Equal(t, tt.want, v) + } + } + }) + } +} diff --git a/test/e2e/requestid_test.go b/test/e2e/requestid_test.go index 62b2feb10..d2f968c37 100644 --- a/test/e2e/requestid_test.go +++ b/test/e2e/requestid_test.go @@ -113,7 +113,7 @@ func Test_reflectRequestID(t *testing.T) { assert.Nil(t, rootResponse) // expect an error when retrieving an invalid root and provided request ID - rootResponse, err = caClient.RootWithContext(client.WithRequestID(ctx, "reqID"), "invalid") + rootResponse, err = caClient.RootWithContext(client.NewRequestIDContext(ctx, "reqID"), "invalid") if assert.Error(t, err) { apiErr := &errs.Error{} if assert.ErrorAs(t, err, &apiErr) { From b9d6bfc1eb5a5476530c5ce26b91aade6ed36bad Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 28 Feb 2024 14:39:38 +0100 Subject: [PATCH 26/31] Cleanup CA client tests by removing `smallstep/assert` --- ca/client_test.go | 582 +++++++++++++++-------------------------- ca/tls_options_test.go | 22 +- ca/tls_test.go | 30 +-- 3 files changed, 242 insertions(+), 392 deletions(-) diff --git a/ca/client_test.go b/ca/client_test.go index 6fe8a135d..5fd111794 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -9,15 +9,14 @@ import ( "encoding/json" "encoding/pem" "errors" - "fmt" "net/http" "net/http/httptest" "net/url" "reflect" + "strings" "testing" "time" - sassert "github.com/smallstep/assert" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" @@ -26,6 +25,7 @@ import ( "github.com/smallstep/certificates/ca/client" "github.com/smallstep/certificates/errs" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.step.sm/crypto/x509util" "golang.org/x/crypto/ssh" ) @@ -107,52 +107,49 @@ DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w== -----END CERTIFICATE REQUEST-----` ) -func mustKey() *ecdsa.PrivateKey { +func mustKey(t *testing.T) *ecdsa.PrivateKey { + t.Helper() priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - panic(err) - } + require.NoError(t, err) return priv } -func parseCertificate(data string) *x509.Certificate { +func parseCertificate(t *testing.T, data string) *x509.Certificate { + t.Helper() block, _ := pem.Decode([]byte(data)) if block == nil { - panic("failed to parse certificate PEM") + require.Fail(t, "failed to parse certificate PEM") + return nil } cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - panic("failed to parse certificate: " + err.Error()) - } + require.NoError(t, err, "failed to parse certificate") return cert } -func parseCertificateRequest(string) *x509.CertificateRequest { +func parseCertificateRequest(t *testing.T, csrPEM string) *x509.CertificateRequest { + t.Helper() block, _ := pem.Decode([]byte(csrPEM)) if block == nil { - panic("failed to parse certificate request PEM") + require.Fail(t, "failed to parse certificate request PEM") + return nil } csr, err := x509.ParseCertificateRequest(block.Bytes) - if err != nil { - panic("failed to parse certificate request: " + err.Error()) - } + require.NoError(t, err, "failed to parse certificate request") return csr } func equalJSON(t *testing.T, a, b interface{}) bool { + t.Helper() if reflect.DeepEqual(a, b) { return true } + ab, err := json.Marshal(a) - if err != nil { - t.Error(err) - return false - } + require.NoError(t, err) + bb, err := json.Marshal(b) - if err != nil { - t.Error(err) - return false - } + require.NoError(t, err) + return bytes.Equal(ab, bb) } @@ -177,32 +174,23 @@ func TestClient_Version(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Version() - if (err != nil) != tt.wantErr { - t.Errorf("Client.Version() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.EqualError(t, err, tt.expectedErr.Error()) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Version() = %v, want nil", got) - } - sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Version() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -227,40 +215,30 @@ func TestClient_Health(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Health() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Health() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.EqualError(t, err, tt.expectedErr.Error()) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Health() = %v, want nil", got) - } - sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Health() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_Root(t *testing.T) { ok := &api.RootResponse{ - RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + RootPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, } tests := []struct { @@ -281,10 +259,7 @@ func TestClient_Root(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { expected := "/root/" + tt.shasum @@ -295,37 +270,31 @@ func TestClient_Root(t *testing.T) { }) got, err := c.Root(tt.shasum) - if (err != nil) != tt.wantErr { - t.Errorf("Client.Root() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.EqualError(t, err, tt.expectedErr.Error()) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Root() = %v, want nil", got) - } - sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Root() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_Sign(t *testing.T) { ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } request := &api.SignRequest{ - CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, + CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)}, OTT: "the-ott", NotBefore: api.NewTimeDuration(time.Now()), NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)), @@ -351,16 +320,13 @@ func TestClient_Sign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { body := new(api.SignRequest) if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) - sassert.Fatal(t, ok, "response expected to be error type") + require.True(t, ok, "response expected to be error type") render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { @@ -376,23 +342,16 @@ func TestClient_Sign(t *testing.T) { }) got, err := c.Sign(tt.request) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Sign() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.EqualError(t, err, tt.expectedErr.Error()) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Sign() = %v, want nil", got) - } - sassert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Sign() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -423,16 +382,13 @@ func TestClient_Revoke(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { body := new(api.RevokeRequest) if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) - sassert.Fatal(t, ok, "response expected to be error type") + require.True(t, ok, "response expected to be error type") render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { @@ -448,34 +404,27 @@ func TestClient_Revoke(t *testing.T) { }) got, err := c.Revoke(tt.request, nil) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Revoke() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.True(t, strings.HasPrefix(err.Error(), tt.expectedErr.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Revoke() = %v, want nil", got) - } - sassert.HasPrefix(t, err.Error(), tt.expectedErr.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_Renew(t *testing.T) { ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } @@ -498,49 +447,38 @@ func TestClient_Renew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Renew(nil) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Renew() = %v, want nil", got) - } - - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, err.Error(), tt.err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Renew() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_RenewWithToken(t *testing.T) { ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } @@ -563,10 +501,7 @@ func TestClient_RenewWithToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.Header.Get("Authorization") != "Bearer token" { @@ -577,44 +512,36 @@ func TestClient_RenewWithToken(t *testing.T) { }) got, err := c.RenewWithToken("token") - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.RenewWithToken() = %v, want nil", got) - } - - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, err.Error(), tt.err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_Rekey(t *testing.T) { ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } request := &api.RekeyRequest{ - CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, + CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)}, } tests := []struct { @@ -637,38 +564,27 @@ func TestClient_Rekey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Rekey(tt.request, nil) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Renew() = %v, want nil", got) - } - - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, err.Error(), tt.err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Renew() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -700,10 +616,7 @@ func TestClient_Provisioners(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.RequestURI != tt.expectedURI { @@ -713,22 +626,16 @@ func TestClient_Provisioners(t *testing.T) { }) got, err := c.Provisioners(tt.args...) - if (err != nil) != tt.wantErr { - t.Errorf("Client.Provisioners() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.True(t, strings.HasPrefix(err.Error(), errs.InternalServerErrorDefaultMsg)) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Provisioners() = %v, want nil", got) - } - sassert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -756,10 +663,7 @@ func TestClient_ProvisionerKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { expected := "/provisioners/" + tt.kid + "/encrypted-key" @@ -770,27 +674,20 @@ func TestClient_ProvisionerKey(t *testing.T) { }) got, err := c.ProvisionerKey(tt.kid) - if (err != nil) != tt.wantErr { - t.Errorf("Client.ProvisionerKey() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.ProvisionerKey() = %v, want nil", got) - } - - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, tt.err.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -798,7 +695,7 @@ func TestClient_ProvisionerKey(t *testing.T) { func TestClient_Roots(t *testing.T) { ok := &api.RootsResponse{ Certificates: []api.Certificate{ - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } @@ -820,37 +717,27 @@ func TestClient_Roots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Roots() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Roots() = %v, want nil", got) - } - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, err.Error(), tt.err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Roots() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -858,7 +745,7 @@ func TestClient_Roots(t *testing.T) { func TestClient_Federation(t *testing.T) { ok := &api.FederationResponse{ Certificates: []api.Certificate{ - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } @@ -879,46 +766,34 @@ func TestClient_Federation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Federation() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Federation() = %v, want nil", got) - } - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, tt.err.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Federation() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_SSHRoots(t *testing.T) { - key, err := ssh.NewPublicKey(mustKey().Public()) - if err != nil { - t.Fatal(err) - } + key, err := ssh.NewPublicKey(mustKey(t).Public()) + require.NoError(t, err) ok := &api.SSHRootsResponse{ HostKeys: []api.SSHPublicKey{{PublicKey: key}}, @@ -942,37 +817,27 @@ func TestClient_SSHRoots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHRoots() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.SSHKeys() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.SSHKeys() = %v, want nil", got) - } - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) - } - sassert.HasPrefix(t, tt.err.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -1004,13 +869,14 @@ func Test_parseEndpoint(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := parseEndpoint(tt.args.endpoint) - if (err != nil) != tt.wantErr { - t.Errorf("parseEndpoint() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("parseEndpoint() = %v, want %v", got, tt.want) - } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) }) } } @@ -1043,24 +909,21 @@ func TestClient_RootFingerprint(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tr := tt.server.Client().Transport c, err := NewClient(tt.server.URL, WithTransport(tr)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.RootFingerprint() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.RootFingerprint() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Empty(t, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Client.RootFingerprint() = %v, want %v", got, tt.want) - } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) }) } } @@ -1069,12 +932,12 @@ func TestClient_RootFingerprintWithServer(t *testing.T) { srv := startCABootstrapServer() defer srv.Close() - client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) - sassert.FatalError(t, err) + caClient, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) + require.NoError(t, err) - fp, err := client.RootFingerprint() - sassert.FatalError(t, err) - sassert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) + fp, err := caClient.RootFingerprint() + assert.NoError(t, err) + assert.Equal(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) } func TestClient_SSHBastion(t *testing.T) { @@ -1104,39 +967,29 @@ func TestClient_SSHBastion(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHBastion(tt.request) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.SSHBastion() error = %v, wantErr %v", err, tt.wantErr) - return - } - - switch { - case err != nil: - if got != nil { - t.Errorf("Client.SSHBastion() = %v, want nil", got) - } - if tt.responseCode != 200 { - var sc render.StatusCodedError - if sassert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - sassert.Equals(t, sc.StatusCode(), tt.responseCode) + if tt.wantErr { + if assert.Error(t, err) { + if tt.responseCode != 200 { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) } - sassert.HasPrefix(t, err.Error(), tt.err.Error()) - } - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.SSHBastion() = %v, want %v", got, tt.response) } + assert.Nil(t, got) + return } + + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -1155,13 +1008,10 @@ func TestClient_GetCaURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } - if got := c.GetCaURL(); got != tt.want { - t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want) - } + require.NoError(t, err) + + got := c.GetCaURL() + assert.Equal(t, tt.want, got) }) } } @@ -1171,7 +1021,7 @@ func Test_enforceRequestID(t *testing.T) { set.Header.Set("X-Request-Id", "already-set") inContext := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) inContext = inContext.WithContext(client.NewRequestIDContext(inContext.Context(), "from-context")) - new := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + newRequestID := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) tests := []struct { name string @@ -1190,7 +1040,7 @@ func Test_enforceRequestID(t *testing.T) { }, { name: "new", - r: new, + r: newRequestID, }, } for _, tt := range tests { diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 7dea3dc8f..c29947ad2 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -130,7 +130,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) { //nolint:gosec // test tls config func TestAddRootCA(t *testing.T) { - cert := parseCertificate(rootPEM) + cert := parseCertificate(t, rootPEM) pool := x509.NewCertPool() pool.AddCert(cert) @@ -163,7 +163,7 @@ func TestAddRootCA(t *testing.T) { //nolint:gosec // test tls config func TestAddClientCA(t *testing.T) { - cert := parseCertificate(rootPEM) + cert := parseCertificate(t, rootPEM) pool := x509.NewCertPool() pool.AddCert(cert) @@ -214,7 +214,7 @@ func TestAddRootsToRootCAs(t *testing.T) { t.Fatal(err) } - cert := parseCertificate(string(root)) + cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() pool.AddCert(cert) @@ -269,7 +269,7 @@ func TestAddRootsToClientCAs(t *testing.T) { t.Fatal(err) } - cert := parseCertificate(string(root)) + cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() pool.AddCert(cert) @@ -329,8 +329,8 @@ func TestAddFederationToRootCAs(t *testing.T) { t.Fatal(err) } - crt1 := parseCertificate(string(root)) - crt2 := parseCertificate(string(federated)) + crt1 := parseCertificate(t, string(root)) + crt2 := parseCertificate(t, string(federated)) pool := x509.NewCertPool() pool.AddCert(crt1) pool.AddCert(crt2) @@ -394,8 +394,8 @@ func TestAddFederationToClientCAs(t *testing.T) { t.Fatal(err) } - crt1 := parseCertificate(string(root)) - crt2 := parseCertificate(string(federated)) + crt1 := parseCertificate(t, string(root)) + crt2 := parseCertificate(t, string(federated)) pool := x509.NewCertPool() pool.AddCert(crt1) pool.AddCert(crt2) @@ -454,7 +454,7 @@ func TestAddRootsToCAs(t *testing.T) { t.Fatal(err) } - cert := parseCertificate(string(root)) + cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() pool.AddCert(cert) @@ -514,8 +514,8 @@ func TestAddFederationToCAs(t *testing.T) { t.Fatal(err) } - crt1 := parseCertificate(string(root)) - crt2 := parseCertificate(string(federated)) + crt1 := parseCertificate(t, string(root)) + crt2 := parseCertificate(t, string(federated)) pool := x509.NewCertPool() pool.AddCert(crt1) pool.AddCert(crt2) diff --git a/ca/tls_test.go b/ca/tls_test.go index dbcc6023d..a19685ce3 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -401,13 +401,13 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { } func TestCertificate(t *testing.T) { - cert := parseCertificate(certPEM) + cert := parseCertificate(t, certPEM) ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: cert}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ {Certificate: cert}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } tests := []struct { @@ -434,12 +434,12 @@ func TestCertificate(t *testing.T) { } func TestIntermediateCertificate(t *testing.T) { - intermediate := parseCertificate(rootPEM) + intermediate := parseCertificate(t, rootPEM) ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, CaPEM: api.Certificate{Certificate: intermediate}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, + {Certificate: parseCertificate(t, certPEM)}, {Certificate: intermediate}, }, } @@ -467,24 +467,24 @@ func TestIntermediateCertificate(t *testing.T) { } func TestRootCertificateCertificate(t *testing.T) { - root := parseCertificate(rootPEM) + root := parseCertificate(t, rootPEM) ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, TLS: &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{ {root, root}, }}, } noTLS := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } tests := []struct { From cd3e91b198bcdef0081ea4ae0869b32daec1cbbc Mon Sep 17 00:00:00 2001 From: Carl Tashian Date: Wed, 28 Feb 2024 14:36:25 -0800 Subject: [PATCH 27/31] Updated README --- README.md | 72 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 42 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 4505a7efc..6303ff0fc 100644 --- a/README.md +++ b/README.md @@ -1,49 +1,62 @@ -# Step Certificates +# step-ca -`step-ca` is an online certificate authority for secure, automated certificate management. It's the server counterpart to the [`step` CLI tool](https://github.com/smallstep/cli). +[![GitHub release](https://img.shields.io/github/release/smallstep/certificates.svg)](https://github.com/smallstep/certificates/releases/latest) +[![Go Report Card](https://goreportcard.com/badge/github.com/smallstep/certificates)](https://goreportcard.com/report/github.com/smallstep/certificates) +[![Build Status](https://github.com/smallstep/certificates/actions/workflows/test.yml/badge.svg)](https://github.com/smallstep/certificates) +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) +[![CLA assistant](https://cla-assistant.io/readme/badge/smallstep/certificates)](https://cla-assistant.io/smallstep/certificates) -You can use it to: -- Issue X.509 certificates for your internal infrastructure: - - HTTPS certificates that [work in browsers](https://smallstep.com/blog/step-v0-8-6-valid-HTTPS-certificates-for-dev-pre-prod.html) ([RFC5280](https://tools.ietf.org/html/rfc5280) and [CA/Browser Forum](https://cabforum.org/baseline-requirements-documents/) compliance) - - TLS certificates for VMs, containers, APIs, mobile clients, database connections, printers, wifi networks, toaster ovens... - - Client certificates to [enable mutual TLS (mTLS)](https://smallstep.com/hello-mtls) in your infra. mTLS is an optional feature in TLS where both client and server authenticate each other. Why add the complexity of a VPN when you can safely use mTLS over the public internet? +`step-ca` is an online certificate authority for secure, automated certificate management for DevOps. +It's the server counterpart to the [`step` CLI tool](https://github.com/smallstep/cli) for working with certificates and keys. +Both projects are maintained by [Smallstep Labs](https://smallstep.com). + +You can use `step-ca` to: +- Issue HTTPS server and client certificates that [work in browsers](https://smallstep.com/blog/step-v0-8-6-valid-HTTPS-certificates-for-dev-pre-prod.html) ([RFC5280](https://tools.ietf.org/html/rfc5280) and [CA/Browser Forum](https://cabforum.org/baseline-requirements-documents/) compliance) +- Issue TLS certificates for DevOps: VMs, containers, APIs, database connections, Kubernetes pods... - Issue SSH certificates: - - For people, in exchange for single sign-on ID tokens + - For people, in exchange for single sign-on identity tokens - For hosts, in exchange for cloud instance identity documents - Easily automate certificate management: - - It's an ACME v2 server - - It has a JSON API + - It's an [ACME server](https://smallstep.com/docs/step-ca/acme-basics/) that supports all [popular ACME challenge types](https://smallstep.com/docs/step-ca/acme-basics/#acme-challenge-types) - It comes with a [Go wrapper](./examples#user-content-basic-client-usage) - ... and there's a [command-line client](https://github.com/smallstep/cli) you can use in scripts! -Whatever your use case, `step-ca` is easy to use and hard to misuse, thanks to [safe, sane defaults](https://smallstep.com/docs/step-ca/certificate-authority-server-production#sane-cryptographic-defaults). - --- -**Don't want to run your own CA?** -To get up and running quickly, or as an alternative to running your own `step-ca` server, consider creating a [free hosted smallstep Certificate Manager authority](https://info.smallstep.com/certificate-manager-early-access-mvp/). +### Comparison with Smallstep's commercial product + +`step-ca` is optimized for a two-tier PKI serving common DevOps use cases. + +As you design your PKI, if you need any of the following, [consider our commerical CA](http://smallstep.com): +- Multiple certificate authorities +- Active revocation (CRL, OSCP) +- Turnkey high-volume, high availability CA +- An API for seamless IaC management of your PKI +- Integrated support for SCEP & NDES, for migrating from legacy Active Directory Certificate Services deployments +- Device identity — cross-platform device inventory and attestation using Secure Enclave & TPM 2.0 +- Highly automated PKI — managed certificate renewal, monitoring, TPM-based attested enrollment +- Seamless client deployments of EAP-TLS Wi-Fi, VPN, SSH, and browser certificates +- Jamf, Intune, or other MDM for root distribution and client enrollment +- Web Admin UI — history, issuance, and metrics +- ACME External Account Binding (EAB) +- Deep integration with an identity provider +- Fine-grained, role-based access control +- FIPS-compliant software +- HSM-bound private keys + +See our [full feature comparison](https://smallstep.com/step-ca-vs-smallstep-certificate-manager/) for more. + +You can [start a free trial](https://smallstep.com/signup) or [set up a call with us](https://go.smallstep.com/request-demo) to learn more. --- **Questions? Find us in [Discussions](https://github.com/smallstep/certificates/discussions) or [Join our Discord](https://u.step.sm/discord).** [Website](https://smallstep.com/certificates) | -[Documentation](https://smallstep.com/docs) | +[Documentation](https://smallstep.com/docs/step-ca) | [Installation](https://smallstep.com/docs/step-ca/installation) | -[Getting Started](https://smallstep.com/docs/step-ca/getting-started) | [Contributor's Guide](./CONTRIBUTING.md) -[![GitHub release](https://img.shields.io/github/release/smallstep/certificates.svg)](https://github.com/smallstep/certificates/releases/latest) -[![Go Report Card](https://goreportcard.com/badge/github.com/smallstep/certificates)](https://goreportcard.com/report/github.com/smallstep/certificates) -[![Build Status](https://github.com/smallstep/certificates/actions/workflows/test.yml/badge.svg)](https://github.com/smallstep/certificates) -[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) -[![CLA assistant](https://cla-assistant.io/readme/badge/smallstep/certificates)](https://cla-assistant.io/smallstep/certificates) - -[![GitHub stars](https://img.shields.io/github/stars/smallstep/certificates.svg?style=social)](https://github.com/smallstep/certificates/stargazers) -[![Twitter followers](https://img.shields.io/twitter/follow/smallsteplabs.svg?label=Follow&style=social)](https://twitter.com/intent/follow?screen_name=smallsteplabs) - -![star us](https://github.com/smallstep/certificates/raw/master/docs/images/star.gif) - ## Features ### 🦾 A fast, stable, flexible private CA @@ -52,7 +65,6 @@ Setting up a *public key infrastructure* (PKI) is out of reach for many small te - Choose key types (RSA, ECDSA, EdDSA) and lifetimes to suit your needs - [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with automated enrollment, renewal, and passive revocation -- Capable of high availability (HA) deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries - Can operate as [an online intermediate CA for an existing root CA](https://smallstep.com/docs/tutorials/intermediate-ca-new-ca) - [Badger, BoltDB, Postgres, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases) @@ -127,5 +139,5 @@ and visiting http://localhost:8080. ## Feedback? -* Tell us what you like and don't like about managing your PKI - we're eager to help solve problems in this space. -* Tell us about a feature you'd like to see! [Add a feature request Issue](https://github.com/smallstep/certificates/issues/new?assignees=&labels=enhancement%2C+needs+triage&template=enhancement.md&title=), [ask on Discussions](https://github.com/smallstep/certificates/discussions), or hit us up on [Twitter](https://twitter.com/smallsteplabs). +* Tell us what you like and don't like about managing your PKI - we're eager to help solve problems in this space. [Join our Discord](https://u.step.sm/discord) or [GitHub Discussions](https://github.com/smallstep/certificates/discussions) +* Tell us about a feature you'd like to see! [Request a Feature](https://github.com/smallstep/certificates/issues/new?assignees=&labels=enhancement%2C+needs+triage&template=enhancement.md&title=) From 0898c6db972b2c823eb873b719d31eca96cff613 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 29 Feb 2024 20:26:27 +0100 Subject: [PATCH 28/31] Use UUIDv4 as automatically generated client request identifier --- ca/client.go | 15 +++++++++++++-- ca/client_test.go | 10 ++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/ca/client.go b/ca/client.go index 9e245cd70..b18efbaf4 100644 --- a/ca/client.go +++ b/ca/client.go @@ -24,7 +24,6 @@ import ( "strings" "github.com/pkg/errors" - "github.com/rs/xid" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" @@ -35,6 +34,7 @@ import ( "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" + "go.step.sm/crypto/randutil" "go.step.sm/crypto/x509util" "golang.org/x/net/http2" "google.golang.org/protobuf/encoding/protojson" @@ -105,6 +105,17 @@ func (c *uaClient) PostWithContext(ctx context.Context, u, contentType string, b // the CA client to the CA and back again. const requestIDHeader = "X-Request-Id" +// newRequestID generates a new random UUIDv4 request ID. If it fails, +// the request ID will be the empty string. +func newRequestID() string { + requestID, err := randutil.UUIDv4() + if err != nil { + return "" + } + + return requestID +} + // enforceRequestID checks if the X-Request-Id HTTP header is filled. If it's // empty, the context is searched for a request ID. If that's also empty, a new // request ID is generated. @@ -115,7 +126,7 @@ func enforceRequestID(r *http.Request) { // used before by the client (unless it's a retry for the same request)? requestID = reqID } else { - requestID = xid.New().String() + requestID = newRequestID() } r.Header.Set(requestIDHeader, requestID) } diff --git a/ca/client_test.go b/ca/client_test.go index 5fd111794..44d24c6ee 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -17,6 +17,7 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" @@ -1056,3 +1057,12 @@ func Test_enforceRequestID(t *testing.T) { }) } } + +func Test_newRequestID(t *testing.T) { + requestID := newRequestID() + u, err := uuid.Parse(requestID) + assert.NoError(t, err) + assert.Equal(t, uuid.Version(0x4), u.Version()) + assert.Equal(t, uuid.RFC4122, u.Variant()) + assert.Equal(t, requestID, u.String()) +} From 7fd524f70b82021df5e7d58f2a6d5fc483ecafe1 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Fri, 1 Mar 2024 01:04:50 +0100 Subject: [PATCH 29/31] Default to generating request IDs using UUIDv4 format in CA --- internal/requestid/requestid.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/internal/requestid/requestid.go b/internal/requestid/requestid.go index 97f58f8ca..7008d4696 100644 --- a/internal/requestid/requestid.go +++ b/internal/requestid/requestid.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/rs/xid" + "go.step.sm/crypto/randutil" ) const ( @@ -61,9 +62,16 @@ func (h *Handler) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(fn) } -// newRequestID creates a new request ID using github.com/rs/xid. +// newRequestID generates a new random UUIDv4 request ID. If UUIDv4 +// generation fails, it'll fallback to generating a random ID using +// github.com/rs/xid. func newRequestID() string { - return xid.New().String() + requestID, err := randutil.UUIDv4() + if err != nil { + requestID = xid.New().String() + } + + return requestID } type requestIDKey struct{} From d392c169fce826a32ab562102f3e1ccba1fb8abc Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Mon, 4 Mar 2024 12:00:08 +0100 Subject: [PATCH 30/31] Improve functional coverage of request ID integration test --- authority/provisioner/webhook_test.go | 33 +-- ca/client/requestid.go | 6 +- ca/provisioner_test.go | 8 +- ca/tls_options_test.go | 98 +++------ ca/tls_test.go | 103 ++++----- internal/requestid/requestid.go | 7 +- internal/requestid/requestid_test.go | 4 + internal/userid/userid.go | 6 +- logging/handler.go | 1 + monitoring/monitoring.go | 1 + test/e2e/requestid_test.go | 132 ------------ test/integration/requestid_test.go | 289 ++++++++++++++++++++++++++ 12 files changed, 402 insertions(+), 286 deletions(-) delete mode 100644 test/e2e/requestid_test.go create mode 100644 test/integration/requestid_test.go diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 4c80796f1..905834183 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -17,13 +17,15 @@ import ( "testing" "time" - "github.com/smallstep/certificates/internal/requestid" - "github.com/smallstep/certificates/webhook" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" "go.step.sm/linkedca" + + "github.com/smallstep/certificates/internal/requestid" + "github.com/smallstep/certificates/webhook" ) func TestWebhookController_isCertTypeOK(t *testing.T) { @@ -103,7 +105,8 @@ func TestWebhookController_isCertTypeOK(t *testing.T) { // withRequestID is a helper that calls into [requestid.NewContext] and returns // a new context with the requestID added. -func withRequestID(ctx context.Context, requestID string) context.Context { +func withRequestID(t *testing.T, ctx context.Context, requestID string) context.Context { + t.Helper() return requestid.NewContext(ctx, requestID) } @@ -138,7 +141,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, @@ -153,7 +156,7 @@ func TestWebhookController_Enrich(t *testing.T) { }, TemplateData: x509util.TemplateData{}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -177,7 +180,7 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, certType: linkedca.Webhook_X509, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -197,7 +200,7 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, @@ -220,7 +223,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -235,7 +238,7 @@ func TestWebhookController_Enrich(t *testing.T) { PublicKey: []byte("bad"), })}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -296,7 +299,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, @@ -307,7 +310,7 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}}, certType: linkedca.Webhook_SSH, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: false, @@ -318,7 +321,7 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, @@ -339,7 +342,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -352,7 +355,7 @@ func TestWebhookController_Authorize(t *testing.T) { PublicKey: []byte("bad"), })}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -568,7 +571,7 @@ func TestWebhook_Do(t *testing.T) { ctx := context.Background() if tc.requestID != "" { - ctx = withRequestID(context.Background(), tc.requestID) + ctx = withRequestID(t, ctx, tc.requestID) } ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() diff --git a/ca/client/requestid.go b/ca/client/requestid.go index 2bebb7e53..1fb785ebf 100644 --- a/ca/client/requestid.go +++ b/ca/client/requestid.go @@ -2,17 +2,17 @@ package client import "context" -type requestIDKey struct{} +type contextKey struct{} // NewRequestIDContext returns a new context with the given request ID added to the // context. func NewRequestIDContext(ctx context.Context, requestID string) context.Context { - return context.WithValue(ctx, requestIDKey{}, requestID) + return context.WithValue(ctx, contextKey{}, requestID) } // RequestIDFromContext returns the request ID from the context if it exists. // and is not empty. func RequestIDFromContext(ctx context.Context) (string, bool) { - v, ok := ctx.Value(requestIDKey{}).(string) + v, ok := ctx.Value(contextKey{}).(string) return v, ok && v != "" } diff --git a/ca/provisioner_test.go b/ca/provisioner_test.go index 39193f3f9..5a754f084 100644 --- a/ca/provisioner_test.go +++ b/ca/provisioner_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" @@ -41,14 +43,12 @@ func getTestProvisioner(t *testing.T, caURL string) *Provisioner { } func TestNewProvisioner(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() want := getTestProvisioner(t, ca.URL) caBundle, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type args struct { name string diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index c29947ad2..4ac6ff85e 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -10,6 +10,8 @@ import ( "sort" "testing" + "github.com/stretchr/testify/require" + "github.com/smallstep/certificates/api" ) @@ -196,23 +198,17 @@ func TestAddClientCA(t *testing.T) { //nolint:gosec // test tls config func TestAddRootsToRootCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() @@ -251,23 +247,17 @@ func TestAddRootsToRootCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddRootsToClientCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() @@ -306,28 +296,20 @@ func TestAddRootsToClientCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddFederationToRootCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) crt1 := parseCertificate(t, string(root)) crt2 := parseCertificate(t, string(federated)) @@ -371,28 +353,20 @@ func TestAddFederationToRootCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddFederationToClientCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) crt1 := parseCertificate(t, string(root)) crt2 := parseCertificate(t, string(federated)) @@ -436,23 +410,17 @@ func TestAddFederationToClientCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddRootsToCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() @@ -491,28 +459,20 @@ func TestAddRootsToCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddFederationToCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) crt1 := parseCertificate(t, string(root)) crt2 := parseCertificate(t, string(federated)) diff --git a/ca/tls_test.go b/ca/tls_test.go index a19685ce3..d1ce11ea7 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -17,27 +17,28 @@ import ( "testing" "time" - "github.com/smallstep/certificates/api" - "github.com/smallstep/certificates/authority" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" + + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" ) -func generateOTT(subject string) string { +func generateOTT(t *testing.T, subject string) string { + t.Helper() now := time.Now() jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password"))) - if err != nil { - panic(err) - } + require.NoError(t, err) + opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID) sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, opts) - if err != nil { - panic(err) - } + require.NoError(t, err) + id, err := randutil.ASCII(64) - if err != nil { - panic(err) - } + require.NoError(t, err) + cl := struct { jose.Claims SANS []string `json:"sans"` @@ -53,9 +54,8 @@ func generateOTT(subject string) string { SANS: []string{subject}, } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() - if err != nil { - panic(err) - } + require.NoError(t, err) + return raw } @@ -72,32 +72,28 @@ func startTestServer(baseContext context.Context, tlsConfig *tls.Config, handler return srv } -func startCATestServer() *httptest.Server { +func startCATestServer(t *testing.T) *httptest.Server { config, err := authority.LoadConfiguration("testdata/ca.json") - if err != nil { - panic(err) - } + require.NoError(t, err) ca, err := New(config) - if err != nil { - panic(err) - } + require.NoError(t, err) // Use a httptest.Server instead baseContext := buildContext(ca.auth, nil, nil, nil) srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler) return srv } -func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) { - srv := startCATestServer() +func sign(t *testing.T, domain string) (*Client, *api.SignResponse, crypto.PrivateKey) { + t.Helper() + srv := startCATestServer(t) defer srv.Close() - return signDuration(srv, domain, 0) + return signDuration(t, srv, domain, 0) } -func signDuration(srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) { - req, pk, err := CreateSignRequest(generateOTT(domain)) - if err != nil { - panic(err) - } +func signDuration(t *testing.T, srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) { + t.Helper() + req, pk, err := CreateSignRequest(generateOTT(t, domain)) + require.NoError(t, err) if duration > 0 { req.NotBefore = api.NewTimeDuration(time.Now()) @@ -105,13 +101,11 @@ func signDuration(srv *httptest.Server, domain string, duration time.Duration) ( } client, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - panic(err) - } + require.NoError(t, err) + sr, err := client.Sign(req) - if err != nil { - panic(err) - } + require.NoError(t, err) + return client, sr, pk } @@ -145,7 +139,7 @@ func serverHandler(t *testing.T, clientDomain string) http.Handler { func TestClient_GetServerTLSConfig_http(t *testing.T) { clientDomain := "test.domain" - client, sr, pk := sign("127.0.0.1") + client, sr, pk := sign(t, "127.0.0.1") // Create mTLS server ctx, cancel := context.WithCancel(context.Background()) @@ -212,7 +206,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client, sr, pk := sign(clientDomain) + client, sr, pk := sign(t, clientDomain) cli := tt.getClient(t, client, sr, pk) if cli == nil { return @@ -246,19 +240,18 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { defer reset() // Start CA - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() clientDomain := "test.domain" - client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second) + client, sr, pk := signDuration(t, ca, "127.0.0.1", 5*time.Second) // Start mTLS server ctx, cancel := context.WithCancel(context.Background()) defer cancel() tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk) - if err != nil { - t.Fatalf("Client.GetServerTLSConfig() error = %v", err) - } + require.NoError(t, err) + srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvMTLS.Close() @@ -266,30 +259,26 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { ctx, cancel = context.WithCancel(context.Background()) defer cancel() tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven()) - if err != nil { - t.Fatalf("Client.GetServerTLSConfig() error = %v", err) - } + require.NoError(t, err) + srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvTLS.Close() // Transport - client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) + client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second) tr1, err := client.Transport(context.Background(), sr, pk) - if err != nil { - t.Fatalf("Client.Transport() error = %v", err) - } + require.NoError(t, err) + // Transport with tlsConfig - client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) + client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second) tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk) - if err != nil { - t.Fatalf("Client.GetClientTLSConfig() error = %v", err) - } + require.NoError(t, err) + tr2 := getDefaultTransport(tlsConfig) // No client cert root, err := RootCertificate(sr) - if err != nil { - t.Fatalf("RootCertificate() error = %v", err) - } + require.NoError(t, err) + tlsConfig = getDefaultTLSConfig(sr) tlsConfig.RootCAs = x509.NewCertPool() tlsConfig.RootCAs.AddCert(root) diff --git a/internal/requestid/requestid.go b/internal/requestid/requestid.go index 7008d4696..ace08f167 100644 --- a/internal/requestid/requestid.go +++ b/internal/requestid/requestid.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/rs/xid" + "go.step.sm/crypto/randutil" ) @@ -74,17 +75,17 @@ func newRequestID() string { return requestID } -type requestIDKey struct{} +type contextKey struct{} // NewContext returns a new context with the given request ID added to the // context. func NewContext(ctx context.Context, requestID string) context.Context { - return context.WithValue(ctx, requestIDKey{}, requestID) + return context.WithValue(ctx, contextKey{}, requestID) } // FromContext returns the request ID from the context if it exists and // is not the empty value. func FromContext(ctx context.Context) (string, bool) { - v, ok := ctx.Value(requestIDKey{}).(string) + v, ok := ctx.Value(contextKey{}).(string) return v, ok && v != "" } diff --git a/internal/requestid/requestid_test.go b/internal/requestid/requestid_test.go index 4d0e872dd..84a9021f4 100644 --- a/internal/requestid/requestid_test.go +++ b/internal/requestid/requestid_test.go @@ -19,11 +19,15 @@ func newRequest(t *testing.T) *http.Request { func Test_Middleware(t *testing.T) { requestWithID := newRequest(t) requestWithID.Header.Set("X-Request-Id", "reqID") + requestWithoutID := newRequest(t) + requestWithEmptyHeader := newRequest(t) requestWithEmptyHeader.Header.Set("X-Request-Id", "") + requestWithSmallstepID := newRequest(t) requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID") + tests := []struct { name string traceHeader string diff --git a/internal/userid/userid.go b/internal/userid/userid.go index bab4908f3..48087da89 100644 --- a/internal/userid/userid.go +++ b/internal/userid/userid.go @@ -2,19 +2,19 @@ package userid import "context" -type userIDKey struct{} +type contextKey struct{} // NewContext returns a new context with the given user ID added to the // context. // TODO(hs): this doesn't seem to be used / set currently; implement // when/where it makes sense. func NewContext(ctx context.Context, userID string) context.Context { - return context.WithValue(ctx, userIDKey{}, userID) + return context.WithValue(ctx, contextKey{}, userID) } // FromContext returns the user ID from the context if it exists // and is not empty. func FromContext(ctx context.Context) (string, bool) { - v, ok := ctx.Value(userIDKey{}).(string) + v, ok := ctx.Value(contextKey{}).(string) return v, ok && v != "" } diff --git a/logging/handler.go b/logging/handler.go index a29383b28..06fc56d3f 100644 --- a/logging/handler.go +++ b/logging/handler.go @@ -9,6 +9,7 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/internal/userid" ) diff --git a/monitoring/monitoring.go b/monitoring/monitoring.go index 7c88ab3b9..2ca2ef546 100644 --- a/monitoring/monitoring.go +++ b/monitoring/monitoring.go @@ -9,6 +9,7 @@ import ( "github.com/newrelic/go-agent/v3/newrelic" "github.com/pkg/errors" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/logging" ) diff --git a/test/e2e/requestid_test.go b/test/e2e/requestid_test.go deleted file mode 100644 index d2f968c37..000000000 --- a/test/e2e/requestid_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package e2e - -import ( - "context" - "encoding/json" - "fmt" - "net" - "path/filepath" - "sync" - "testing" - - "github.com/smallstep/certificates/authority/config" - "github.com/smallstep/certificates/ca" - "github.com/smallstep/certificates/ca/client" - "github.com/smallstep/certificates/errs" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.step.sm/crypto/minica" - "go.step.sm/crypto/pemutil" -) - -// reservePort "reserves" a TCP port by opening a listener on a random -// port and immediately closing it. The port can then be assumed to be -// available for running a server on. -func reservePort(t *testing.T) (host, port string) { - t.Helper() - l, err := net.Listen("tcp", ":0") - require.NoError(t, err) - - address := l.Addr().String() - err = l.Close() - require.NoError(t, err) - - host, port, err = net.SplitHostPort(address) - require.NoError(t, err) - - return -} - -func Test_reflectRequestID(t *testing.T) { - dir := t.TempDir() - m, err := minica.New(minica.WithName("Step E2E")) - require.NoError(t, err) - - rootFilepath := filepath.Join(dir, "root.crt") - _, err = pemutil.Serialize(m.Root, pemutil.WithFilename(rootFilepath)) - require.NoError(t, err) - - intermediateCertFilepath := filepath.Join(dir, "intermediate.crt") - _, err = pemutil.Serialize(m.Intermediate, pemutil.WithFilename(intermediateCertFilepath)) - require.NoError(t, err) - - intermediateKeyFilepath := filepath.Join(dir, "intermediate.key") - _, err = pemutil.Serialize(m.Signer, pemutil.WithFilename(intermediateKeyFilepath)) - require.NoError(t, err) - - // get a random address to listen on and connect to; currently no nicer way to get one before starting the server - // TODO(hs): find/implement a nicer way to expose the CA URL, similar to how e.g. httptest.Server exposes it? - host, port := reservePort(t) - - cfg := &config.Config{ - Root: []string{rootFilepath}, - IntermediateCert: intermediateCertFilepath, - IntermediateKey: intermediateKeyFilepath, - Address: net.JoinHostPort(host, port), // reuse the address that was just "reserved" - DNSNames: []string{"127.0.0.1", "[::1]", "localhost"}, - AuthorityConfig: &config.AuthConfig{ - AuthorityID: "stepca-test", - DeploymentType: "standalone-test", - }, - Logger: json.RawMessage(`{"format": "text"}`), - } - c, err := ca.New(cfg) - require.NoError(t, err) - - // instantiate a client for the CA running at the random address - caClient, err := ca.NewClient( - fmt.Sprintf("https://localhost:%s", port), - ca.WithRootFile(rootFilepath), - ) - require.NoError(t, err) - - var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() - err = c.Run() - require.Error(t, err) // expect error when server is stopped - }() - - // require OK health response as the baseline - ctx := context.Background() - healthResponse, err := caClient.HealthWithContext(ctx) - require.NoError(t, err) - if assert.NotNil(t, healthResponse) { - require.Equal(t, "ok", healthResponse.Status) - } - - // expect an error when retrieving an invalid root - rootResponse, err := caClient.RootWithContext(ctx, "invalid") - if assert.Error(t, err) { - apiErr := &errs.Error{} - if assert.ErrorAs(t, err, &apiErr) { - assert.Equal(t, 404, apiErr.StatusCode()) - assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error()) - assert.NotEmpty(t, apiErr.RequestID) - - // TODO: include the below error in the JSON? It's currently only output to the CA logs. Also see https://github.com/smallstep/certificates/pull/759 - //assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg) - } - } - assert.Nil(t, rootResponse) - - // expect an error when retrieving an invalid root and provided request ID - rootResponse, err = caClient.RootWithContext(client.NewRequestIDContext(ctx, "reqID"), "invalid") - if assert.Error(t, err) { - apiErr := &errs.Error{} - if assert.ErrorAs(t, err, &apiErr) { - assert.Equal(t, 404, apiErr.StatusCode()) - assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error()) - assert.Equal(t, "reqID", apiErr.RequestID) - } - } - assert.Nil(t, rootResponse) - - // done testing; stop and wait for the server to quit - err = c.Stop() - require.NoError(t, err) - - wg.Wait() -} diff --git a/test/integration/requestid_test.go b/test/integration/requestid_test.go new file mode 100644 index 000000000..f15db12f0 --- /dev/null +++ b/test/integration/requestid_test.go @@ -0,0 +1,289 @@ +package integration + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/minica" + "go.step.sm/crypto/pemutil" + "go.step.sm/crypto/randutil" + "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/ca" + "github.com/smallstep/certificates/ca/client" + "github.com/smallstep/certificates/errs" +) + +// reservePort "reserves" a TCP port by opening a listener on a random +// port and immediately closing it. The port can then be assumed to be +// available for running a server on. +func reservePort(t *testing.T) (host, port string) { + t.Helper() + l, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + address := l.Addr().String() + err = l.Close() + require.NoError(t, err) + + host, port, err = net.SplitHostPort(address) + require.NoError(t, err) + + return +} + +func Test_reflectRequestID(t *testing.T) { + dir := t.TempDir() + m, err := minica.New(minica.WithName("Step E2E")) + require.NoError(t, err) + + rootFilepath := filepath.Join(dir, "root.crt") + _, err = pemutil.Serialize(m.Root, pemutil.WithFilename(rootFilepath)) + require.NoError(t, err) + + intermediateCertFilepath := filepath.Join(dir, "intermediate.crt") + _, err = pemutil.Serialize(m.Intermediate, pemutil.WithFilename(intermediateCertFilepath)) + require.NoError(t, err) + + intermediateKeyFilepath := filepath.Join(dir, "intermediate.key") + _, err = pemutil.Serialize(m.Signer, pemutil.WithFilename(intermediateKeyFilepath)) + require.NoError(t, err) + + // get a random address to listen on and connect to; currently no nicer way to get one before starting the server + // TODO(hs): find/implement a nicer way to expose the CA URL, similar to how e.g. httptest.Server exposes it? + host, port := reservePort(t) + + authorizingSrv := newAuthorizingServer(t, m) + defer authorizingSrv.Close() + authorizingSrv.StartTLS() + + password := []byte("1234") + jwk, jwe, err := jose.GenerateDefaultKeyPair(password) + require.NoError(t, err) + encryptedKey, err := jwe.CompactSerialize() + require.NoError(t, err) + prov := &provisioner.JWK{ + ID: "jwk", + Name: "jwk", + Type: "JWK", + Key: jwk, + EncryptedKey: encryptedKey, + Claims: &config.GlobalProvisionerClaims, + Options: &provisioner.Options{ + Webhooks: []*provisioner.Webhook{ + { + ID: "webhook", + Name: "webhook-test", + URL: fmt.Sprintf("%s/authorize", authorizingSrv.URL), + Kind: "AUTHORIZING", + CertType: "X509", + }, + }, + }, + } + err = prov.Init(provisioner.Config{}) + require.NoError(t, err) + + cfg := &config.Config{ + Root: []string{rootFilepath}, + IntermediateCert: intermediateCertFilepath, + IntermediateKey: intermediateKeyFilepath, + Address: net.JoinHostPort(host, port), // reuse the address that was just "reserved" + DNSNames: []string{"127.0.0.1", "[::1]", "localhost"}, + AuthorityConfig: &config.AuthConfig{ + AuthorityID: "stepca-test", + DeploymentType: "standalone-test", + Provisioners: provisioner.List{prov}, + }, + Logger: json.RawMessage(`{"format": "text"}`), + } + c, err := ca.New(cfg) + require.NoError(t, err) + + // instantiate a client for the CA running at the random address + caClient, err := ca.NewClient( + fmt.Sprintf("https://localhost:%s", port), + ca.WithRootFile(rootFilepath), + ) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + err = c.Run() + require.ErrorIs(t, err, http.ErrServerClosed) + }() + + // require OK health response as the baseline + ctx := context.Background() + healthResponse, err := caClient.HealthWithContext(ctx) + require.NoError(t, err) + if assert.NotNil(t, healthResponse) { + require.Equal(t, "ok", healthResponse.Status) + } + + // expect an error when retrieving an invalid root + rootResponse, err := caClient.RootWithContext(ctx, "invalid") + var firstErr *errs.Error + if assert.ErrorAs(t, err, &firstErr) { + assert.Equal(t, 404, firstErr.StatusCode()) + assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", firstErr.Err.Error()) + assert.NotEmpty(t, firstErr.RequestID) + + // TODO: include the below error in the JSON? It's currently only output to the CA logs. Also see https://github.com/smallstep/certificates/pull/759 + //assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg) + } + assert.Nil(t, rootResponse) + + // expect an error when retrieving an invalid root and provided request ID + rootResponse, err = caClient.RootWithContext(client.NewRequestIDContext(ctx, "reqID"), "invalid") + var secondErr *errs.Error + if assert.ErrorAs(t, err, &secondErr) { + assert.Equal(t, 404, secondErr.StatusCode()) + assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", secondErr.Err.Error()) + assert.Equal(t, "reqID", secondErr.RequestID) + } + assert.Nil(t, rootResponse) + + // prepare a Sign request + subject := "test" + decryptedJWK := decryptPrivateKey(t, jwe, password) + ott := generateOTT(t, decryptedJWK, subject) + + signer, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + csr, err := x509util.CreateCertificateRequest(subject, []string{subject}, signer) + require.NoError(t, err) + + // perform the Sign request using the OTT and CSR + signResponse, err := caClient.SignWithContext(client.NewRequestIDContext(ctx, "signRequestID"), &api.SignRequest{ + CsrPEM: api.CertificateRequest{CertificateRequest: csr}, + OTT: ott, + NotAfter: api.NewTimeDuration(time.Now().Add(1 * time.Hour)), + NotBefore: api.NewTimeDuration(time.Now().Add(-1 * time.Hour)), + }) + assert.NoError(t, err) + + // assert a certificate was returned for the subject "test" + if assert.NotNil(t, signResponse) { + assert.Len(t, signResponse.CertChainPEM, 2) + cert, err := x509.ParseCertificate(signResponse.CertChainPEM[0].Raw) + assert.NoError(t, err) + if assert.NotNil(t, cert) { + assert.Equal(t, "test", cert.Subject.CommonName) + assert.Contains(t, cert.DNSNames, "test") + } + } + + // done testing; stop and wait for the server to quit + err = c.Stop() + require.NoError(t, err) + + wg.Wait() +} + +func decryptPrivateKey(t *testing.T, jwe *jose.JSONWebEncryption, pass []byte) *jose.JSONWebKey { + t.Helper() + d, err := jwe.Decrypt(pass) + require.NoError(t, err) + + jwk := &jose.JSONWebKey{} + err = json.Unmarshal(d, jwk) + require.NoError(t, err) + + return jwk +} + +func generateOTT(t *testing.T, jwk *jose.JSONWebKey, subject string) string { + t.Helper() + now := time.Now() + + keyID, err := jose.Thumbprint(jwk) + require.NoError(t, err) + + opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", keyID) + signer, err := jose.NewSigner(jose.SigningKey{Key: jwk.Key}, opts) + require.NoError(t, err) + + id, err := randutil.ASCII(64) + require.NoError(t, err) + + cl := struct { + jose.Claims + SANS []string `json:"sans"` + }{ + Claims: jose.Claims{ + ID: id, + Subject: subject, + Issuer: "jwk", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(time.Minute)), + Audience: []string{"https://127.0.0.1/1.0/sign"}, + }, + SANS: []string{subject}, + } + raw, err := jose.Signed(signer).Claims(cl).CompactSerialize() + require.NoError(t, err) + + return raw +} + +func newAuthorizingServer(t *testing.T, ca *minica.CA) *httptest.Server { + t.Helper() + + key, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + csr, err := x509util.CreateCertificateRequest("127.0.0.1", []string{"127.0.0.1"}, key) + require.NoError(t, err) + + crt, err := ca.SignCSR(csr) + require.NoError(t, err) + + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if assert.Equal(t, "signRequestID", r.Header.Get("X-Request-Id")) { + json.NewEncoder(w).Encode(struct{ Allow bool }{Allow: true}) + w.WriteHeader(http.StatusOK) + return + } + + w.WriteHeader(http.StatusBadRequest) + })) + trustedRoots := x509.NewCertPool() + trustedRoots.AddCert(ca.Root) + + srv.TLS = &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{crt.Raw, ca.Intermediate.Raw}, + PrivateKey: key, + Leaf: crt, + }, + }, + ClientCAs: trustedRoots, + ClientAuth: tls.RequireAndVerifyClientCert, + ServerName: "localhost", + } + + return srv +} From 2a47644d31458041b5e3f02ffc595d1ca10b6f6d Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Mon, 4 Mar 2024 12:01:25 +0100 Subject: [PATCH 31/31] Fix linting issue --- test/integration/requestid_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/integration/requestid_test.go b/test/integration/requestid_test.go index f15db12f0..54fd2eb07 100644 --- a/test/integration/requestid_test.go +++ b/test/integration/requestid_test.go @@ -248,7 +248,7 @@ func generateOTT(t *testing.T, jwk *jose.JSONWebKey, subject string) string { return raw } -func newAuthorizingServer(t *testing.T, ca *minica.CA) *httptest.Server { +func newAuthorizingServer(t *testing.T, mca *minica.CA) *httptest.Server { t.Helper() key, err := keyutil.GenerateDefaultSigner() @@ -257,7 +257,7 @@ func newAuthorizingServer(t *testing.T, ca *minica.CA) *httptest.Server { csr, err := x509util.CreateCertificateRequest("127.0.0.1", []string{"127.0.0.1"}, key) require.NoError(t, err) - crt, err := ca.SignCSR(csr) + crt, err := mca.SignCSR(csr) require.NoError(t, err) srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -270,12 +270,12 @@ func newAuthorizingServer(t *testing.T, ca *minica.CA) *httptest.Server { w.WriteHeader(http.StatusBadRequest) })) trustedRoots := x509.NewCertPool() - trustedRoots.AddCert(ca.Root) + trustedRoots.AddCert(mca.Root) srv.TLS = &tls.Config{ Certificates: []tls.Certificate{ { - Certificate: [][]byte{crt.Raw, ca.Intermediate.Raw}, + Certificate: [][]byte{crt.Raw, mca.Intermediate.Raw}, PrivateKey: key, Leaf: crt, },