Skip to content

Commit

Permalink
Merge pull request #1542 from smallstep/herman/webhook-request-id
Browse files Browse the repository at this point in the history
Propagate request ID when webhook requests are made
  • Loading branch information
hslatman committed Feb 27, 2024
2 parents 6ce502c + 041b486 commit c798735
Show file tree
Hide file tree
Showing 20 changed files with 187 additions and 108 deletions.
2 changes: 1 addition & 1 deletion acme/api/revoke_test.go
Expand Up @@ -281,7 +281,7 @@ type mockCA struct {
MockAreSANsallowed func(ctx context.Context, sans []string) error
}

func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) {
func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) {
return nil, nil
}

Expand Down
2 changes: 1 addition & 1 deletion acme/common.go
Expand Up @@ -21,7 +21,7 @@ var clock Clock

// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error
Expand Down
2 changes: 1 addition & 1 deletion acme/order.go
Expand Up @@ -263,7 +263,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
signOps = append(signOps, extraOptions...)

// Sign a new certificate.
certChain, err := auth.Sign(csr, provisioner.SignOptions{
certChain, err := auth.SignWithContext(ctx, csr, provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(o.NotBefore),
NotAfter: provisioner.NewTimeDuration(o.NotAfter),
}, signOps...)
Expand Down
26 changes: 13 additions & 13 deletions acme/order_test.go
Expand Up @@ -271,16 +271,16 @@ func TestOrder_UpdateStatus(t *testing.T) {
}

type mockSignAuth struct {
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
areSANsAllowed func(ctx context.Context, sans []string) error
loadProvisionerByName func(string) (provisioner.Interface, error)
ret1, ret2 interface{}
err error
}

func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.sign != nil {
return m.sign(csr, signOpts, extraOpts...)
func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.signWithContext != nil {
return m.signWithContext(ctx, csr, signOpts, extraOpts...)
} else if m.err != nil {
return nil, m.err
}
Expand Down Expand Up @@ -578,7 +578,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return nil, errors.New("force")
},
Expand Down Expand Up @@ -628,7 +628,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil
},
Expand Down Expand Up @@ -685,7 +685,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil
},
Expand Down Expand Up @@ -770,7 +770,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil
},
Expand Down Expand Up @@ -863,7 +863,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil
},
Expand Down Expand Up @@ -973,7 +973,7 @@ func TestOrder_Finalize(t *testing.T) {
// using the mocking functions as a wrapper for actual test helpers generated per test case or per
// function that's tested.
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{leaf, inter, root}, nil
},
Expand Down Expand Up @@ -1044,7 +1044,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil
},
Expand Down Expand Up @@ -1108,7 +1108,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil
},
Expand Down Expand Up @@ -1175,7 +1175,7 @@ func TestOrder_Finalize(t *testing.T) {
},
},
ca: &mockSignAuth{
sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
assert.Equals(t, _csr, csr)
return []*x509.Certificate{foo, bar, baz}, nil
},
Expand Down
2 changes: 1 addition & 1 deletion api/api.go
Expand Up @@ -42,7 +42,7 @@ type Authority interface {
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Expand Down
8 changes: 4 additions & 4 deletions api/api_test.go
Expand Up @@ -189,7 +189,7 @@ type mockAuthority struct {
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Expand Down Expand Up @@ -251,9 +251,9 @@ func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) {
return m.ret1.(*x509.Certificate), m.err
}

func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.sign != nil {
return m.sign(cr, opts, signOpts...)
func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.signWithContext != nil {
return m.signWithContext(ctx, cr, opts, signOpts...)
}
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
Expand Down
2 changes: 1 addition & 1 deletion api/sign.go
Expand Up @@ -78,7 +78,7 @@ func Sign(w http.ResponseWriter, r *http.Request) {
return
}

certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
return
Expand Down
2 changes: 1 addition & 1 deletion api/ssh.go
Expand Up @@ -330,7 +330,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
NotAfter: time.Unix(int64(cert.ValidBefore), 0),
})

certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...)
certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...)
if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
return
Expand Down
2 changes: 1 addition & 1 deletion api/ssh_test.go
Expand Up @@ -325,7 +325,7 @@ func Test_SSHSign(t *testing.T) {
signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
return tt.addUserCert, tt.addUserErr
},
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
signWithContext: func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return tt.tlsSignCerts, tt.tlsSignErr
},
})
Expand Down
3 changes: 2 additions & 1 deletion authority/authority_test.go
@@ -1,6 +1,7 @@
package authority

import (
"context"
"crypto"
"crypto/rand"
"crypto/sha256"
Expand Down Expand Up @@ -414,7 +415,7 @@ func TestNewEmbedded_Sign(t *testing.T) {
csr, err := x509.ParseCertificateRequest(cr)
assert.FatalError(t, err)

cert, err := a.Sign(csr, provisioner.SignOptions{})
cert, err := a.SignWithContext(context.Background(), csr, provisioner.SignOptions{})
assert.FatalError(t, err)
assert.Equals(t, []string{"foo.bar.zar"}, cert[0].DNSNames)
assert.Equals(t, crt, cert[1])
Expand Down
2 changes: 1 addition & 1 deletion authority/authorize_test.go
Expand Up @@ -1375,7 +1375,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
}

generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) {
chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
chain, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...)
if err != nil {
t.Fatal(err)
}
Expand Down
29 changes: 18 additions & 11 deletions authority/provisioner/webhook.go
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/pkg/errors"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/templates"
"github.com/smallstep/certificates/webhook"
"go.step.sm/linkedca"
Expand All @@ -36,7 +37,7 @@ type WebhookController struct {

// Enrich fetches data from remote servers and adds returned data to the
// templateData
func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil {
return nil
}
Expand All @@ -55,7 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
if !wc.isCertTypeOK(wh) {
continue
}
resp, err := wh.Do(wc.client, req, wc.TemplateData)

whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout

resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil {
return err
}
Expand All @@ -68,7 +73,7 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
}

// Authorize checks that all remote servers allow the request
func (wc *WebhookController) Authorize(req *webhook.RequestBody) error {
func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil {
return nil
}
Expand All @@ -87,7 +92,11 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error {
if !wc.isCertTypeOK(wh) {
continue
}
resp, err := wh.Do(wc.client, req, wc.TemplateData)

whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout

resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil {
return err
}
Expand Down Expand Up @@ -123,13 +132,6 @@ type Webhook struct {
} `json:"-"`
}

func (w *Webhook) Do(client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

return w.DoWithContext(ctx, client, reqBody, data)
}

func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) {
tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL)
if err != nil {
Expand Down Expand Up @@ -169,6 +171,11 @@ retry:
return nil, err
}

requestID, ok := logging.GetRequestID(ctx)
if ok {
req.Header.Set("X-Request-ID", requestID)
}

secret, err := base64.StdEncoding.DecodeString(w.Secret)
if err != nil {
return nil, err
Expand Down

0 comments on commit c798735

Please sign in to comment.