Skip to content

Commit

Permalink
Replace smallstep/assert with stretchr/testify for ACME provisioner
Browse files Browse the repository at this point in the history
  • Loading branch information
hslatman committed Feb 6, 2024
1 parent 37a9f36 commit e153be3
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 79 deletions.
6 changes: 5 additions & 1 deletion authority/provisioner/acme.go
Expand Up @@ -247,7 +247,11 @@ func (p *ACME) initializeWireOptions() error {
// at this point the Wire options have been validated, and (mostly)
// initialized. Remote keys will be loaded upon the first verification,
// currently.
// TODO(hs): can/should we "prime" the underlying remote keyset?
// TODO(hs): can/should we "prime" the underlying remote keyset, to verify
// auto discovery works as expected? Because of the current way provisioners
// are initialized, doing that as part of the initialization isn't the best
// time to do it, because it could result in operations not resulting in the
// expected result in all cases.

return nil
}
Expand Down
159 changes: 84 additions & 75 deletions authority/provisioner/acme_test.go
Expand Up @@ -11,10 +11,10 @@ import (
"testing"
"time"

"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner/wire"
sassert "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestACMEChallenge_Validate(t *testing.T) {
Expand All @@ -27,14 +27,20 @@ func TestACMEChallenge_Validate(t *testing.T) {
{"dns-01", DNS_01, false},
{"tls-alpn-01", TLS_ALPN_01, false},
{"device-attest-01", DEVICE_ATTEST_01, false},
{"wire-oidc-01", DEVICE_ATTEST_01, false},
{"wire-dpop-01", DEVICE_ATTEST_01, false},
{"uppercase", "HTTP-01", false},
{"fail", "http-02", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.c.Validate(); (err != nil) != tt.wantErr {
t.Errorf("ACMEChallenge.Validate() error = %v, wantErr %v", err, tt.wantErr)
err := tt.c.Validate()
if tt.wantErr {
assert.Error(t, err)
return
}

assert.NoError(t, err)
})
}
}
Expand All @@ -53,26 +59,24 @@ func TestACMEAttestationFormat_Validate(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.f.Validate(); (err != nil) != tt.wantErr {
t.Errorf("ACMEAttestationFormat.Validate() error = %v, wantErr %v", err, tt.wantErr)
err := tt.f.Validate()
if tt.wantErr {
assert.Error(t, err)
return
}

assert.NoError(t, err)
})
}
}

func TestACME_Getters(t *testing.T) {
p, err := generateACME()
assert.FatalError(t, err)
id := "acme/" + p.Name
if got := p.GetID(); got != id {
t.Errorf("ACME.GetID() = %v, want %v", got, id)
}
if got := p.GetName(); got != p.Name {
t.Errorf("ACME.GetName() = %v, want %v", got, p.Name)
}
if got := p.GetType(); got != TypeACME {
t.Errorf("ACME.GetType() = %v, want %v", got, TypeACME)
}
require.NoError(t, err)
id := "acme/test@acme-provisioner.com"
assert.Equal(t, id, p.GetID())
assert.Equal(t, "test@acme-provisioner.com", p.GetName())
assert.Equal(t, TypeACME, p.GetType())
kid, key, ok := p.GetEncryptedKey()
if kid != "" || key != "" || ok == true {
t.Errorf("ACME.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
Expand All @@ -82,13 +86,9 @@ func TestACME_Getters(t *testing.T) {

func TestACME_Init(t *testing.T) {
appleCA, err := os.ReadFile("testdata/certs/apple-att-ca.crt")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
yubicoCA, err := os.ReadFile("testdata/certs/yubico-piv-ca.crt")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
fakeWireDPoPKey := []byte(`-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
-----END PUBLIC KEY-----`)
Expand Down Expand Up @@ -224,11 +224,11 @@ MCowBQYDK2VwAyEA5c+4NKZSNQcR1T8qN6SjwgdPZQ0Ge12Ylx/YeGAJ35k=
t.Log(string(tc.p.AttestationRoots))
err := tc.p.Init(config)
if tc.err != nil {
sassert.EqualError(t, err, tc.err.Error())
assert.EqualError(t, err, tc.err.Error())
return
}

sassert.NoError(t, err)
assert.NoError(t, err)
})
}
}
Expand All @@ -244,12 +244,12 @@ func TestACME_AuthorizeRenew(t *testing.T) {
tests := map[string]func(*testing.T) test{
"fail/renew-disabled": func(t *testing.T) test {
p, err := generateACME()
assert.FatalError(t, err)
require.NoError(t, err)
// disable renewal
disable := true
p.Claims = &Claims{DisableRenewal: &disable}
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err)
require.NoError(t, err)
return test{
p: p,
cert: &x509.Certificate{
Expand All @@ -262,7 +262,7 @@ func TestACME_AuthorizeRenew(t *testing.T) {
},
"ok": func(t *testing.T) test {
p, err := generateACME()
assert.FatalError(t, err)
require.NoError(t, err)
return test{
p: p,
cert: &x509.Certificate{
Expand All @@ -275,16 +275,19 @@ func TestACME_AuthorizeRenew(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
err := tc.p.AuthorizeRenew(context.Background(), tc.cert)
if tc.err != nil {
if assert.Implements(t, (*render.StatusCodedError)(nil), err) {
var sc render.StatusCodedError
if errors.As(err, &sc) {
assert.Equal(t, tc.code, sc.StatusCode())
}
}
} else {
assert.Nil(t, tc.err)
assert.EqualError(t, err, tc.err.Error())
return
}

assert.NoError(t, err)
})
}
}
Expand All @@ -299,7 +302,7 @@ func TestACME_AuthorizeSign(t *testing.T) {
tests := map[string]func(*testing.T) test{
"ok": func(t *testing.T) test {
p, err := generateACME()
assert.FatalError(t, err)
require.NoError(t, err)
return test{
p: p,
token: "foo",
Expand All @@ -309,39 +312,43 @@ func TestACME_AuthorizeSign(t *testing.T) {
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tc := tt(t)
if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
if assert.NotNil(t, tc.err) {
sc, ok := err.(render.StatusCodedError)
assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tc.code)
assert.HasPrefix(t, err.Error(), tc.err.Error())
opts, err := tc.p.AuthorizeSign(context.Background(), tc.token)
if tc.err != nil {
if assert.Implements(t, (*render.StatusCodedError)(nil), err) {
var sc render.StatusCodedError
if errors.As(err, &sc) {
assert.Equal(t, tc.code, sc.StatusCode())
}
}
} else {
if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
assert.Equals(t, 8, len(opts)) // number of SignOptions returned
for _, o := range opts {
switch v := o.(type) {
case *ACME:
case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeACME)
assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs)
case *forceCNOption:
assert.Equals(t, v.ForceCN, tc.p.ForceCN)
case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
case *WebhookController:
assert.Len(t, 0, v.webhooks)
default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
assert.EqualError(t, err, tc.err.Error())
return
}

assert.NoError(t, err)
if assert.NotNil(t, opts) {
assert.Len(t, opts, 8) // number of SignOptions returned
for _, o := range opts {
switch v := o.(type) {
case *ACME:
case *provisionerExtensionOption:
assert.Equal(t, v.Type, TypeACME)
assert.Equal(t, v.Name, tc.p.GetName())
assert.Equal(t, v.CredentialID, "")
assert.Len(t, v.KeyValuePairs, 0)
case *forceCNOption:
assert.Equal(t, v.ForceCN, tc.p.ForceCN)
case profileDefaultDuration:
assert.Equal(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator:
case *validityValidator:
assert.Equal(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equal(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator:
assert.Equal(t, nil, v.policyEngine)
case *WebhookController:
assert.Len(t, v.webhooks, 0)
default:
require.NoError(t, fmt.Errorf("unexpected sign option of type %T", v))
}
}
}
Expand Down Expand Up @@ -372,20 +379,23 @@ func TestACME_IsChallengeEnabled(t *testing.T) {
{"ok dns-01 enabled", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, DNS_01}, true},
{"ok tls-alpn-01 enabled", fields{[]ACMEChallenge{"http-01", "dns-01", "tls-alpn-01"}}, args{ctx, TLS_ALPN_01}, true},
{"ok device-attest-01 enabled", fields{[]ACMEChallenge{"device-attest-01", "dns-01"}}, args{ctx, DEVICE_ATTEST_01}, true},
{"ok wire-oidc-01 enabled", fields{[]ACMEChallenge{"wire-oidc-01"}}, args{ctx, WIREOIDC_01}, true},
{"ok wire-dpop-01 enabled", fields{[]ACMEChallenge{"wire-dpop-01"}}, args{ctx, WIREDPOP_01}, true},
{"fail http-01", fields{[]ACMEChallenge{"dns-01"}}, args{ctx, "http-01"}, false},
{"fail dns-01", fields{[]ACMEChallenge{"http-01", "tls-alpn-01"}}, args{ctx, "dns-01"}, false},
{"fail tls-alpn-01", fields{[]ACMEChallenge{"http-01", "dns-01", "device-attest-01"}}, args{ctx, "tls-alpn-01"}, false},
{"fail device-attest-01", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, "device-attest-01"}, false},
{"fail wire-oidc-01", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, "wire-oidc-01"}, false},
{"fail wire-dpop-01", fields{[]ACMEChallenge{"http-01", "dns-01"}}, args{ctx, "wire-dpop-01"}, false},
{"fail unknown", fields{[]ACMEChallenge{"http-01", "dns-01", "tls-alpn-01", "device-attest-01"}}, args{ctx, "unknown"}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &ACME{
Challenges: tt.fields.Challenges,
}
if got := p.IsChallengeEnabled(tt.args.ctx, tt.args.challenge); got != tt.want {
t.Errorf("ACME.AuthorizeChallenge() = %v, want %v", got, tt.want)
}
got := p.IsChallengeEnabled(tt.args.ctx, tt.args.challenge)
assert.Equal(t, tt.want, got)
})
}
}
Expand Down Expand Up @@ -419,9 +429,8 @@ func TestACME_IsAttestationFormatEnabled(t *testing.T) {
p := &ACME{
AttestationFormats: tt.fields.AttestationFormats,
}
if got := p.IsAttestationFormatEnabled(tt.args.ctx, tt.args.format); got != tt.want {
t.Errorf("ACME.IsAttestationFormatEnabled() = %v, want %v", got, tt.want)
}
got := p.IsAttestationFormatEnabled(tt.args.ctx, tt.args.format)
assert.Equal(t, tt.want, got)
})
}
}
6 changes: 3 additions & 3 deletions authority/provisioner/wire/oidc_options.go
Expand Up @@ -53,11 +53,11 @@ func (o *OIDCOptions) GetVerifier(ctx context.Context) (*oidc.IDTokenVerifier, e
switch {
case o.Provider.DiscoveryBaseURL != "":
// creates a new OIDC provider using automatic discovery and the default HTTP client
if provider, err := oidc.NewProvider(ctx, o.Provider.DiscoveryBaseURL); err != nil {
provider, err := oidc.NewProvider(ctx, o.Provider.DiscoveryBaseURL)
if err != nil {
return nil, fmt.Errorf("failed creating new OIDC provider using discovery: %w", err)
} else {
o.provider = provider
}
o.provider = provider
default:
o.provider = o.oidcProviderConfig.NewProvider(ctx)
}
Expand Down

0 comments on commit e153be3

Please sign in to comment.