From ca39242fe5a534937afe9ba5d40d97d3727a07ea Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 21 Mar 2024 11:50:22 -0700 Subject: [PATCH] Allow to compare kms errors with errors.Is This commit implements the "Is(target error) bool" interface to apiv1 errors so we can compare them with errors.Is even if the message is not empty. --- kms/apiv1/options.go | 32 ++++++++++++++++++++-- kms/apiv1/options_test.go | 56 ++++++++++++++++++++++++++++++++++++++- kms/mackms/mackms.go | 33 ++++++++++++++++------- kms/mackms/mackms_test.go | 39 ++++++++++++++++++++++++++- 4 files changed, 147 insertions(+), 13 deletions(-) diff --git a/kms/apiv1/options.go b/kms/apiv1/options.go index e1ffb797..eb13f91a 100644 --- a/kms/apiv1/options.go +++ b/kms/apiv1/options.go @@ -72,8 +72,13 @@ func (e NotImplementedError) Error() string { return "not implemented" } +func (e NotImplementedError) Is(target error) bool { + _, ok := target.(NotImplementedError) + return ok +} + // AlreadyExistsError is the type of error returned if a key already exists. This -// is currently only implmented for pkcs11 and tpmkms. +// is currently only implemented for pkcs11, tpmkms, and mackms. type AlreadyExistsError struct { Message string } @@ -82,7 +87,30 @@ func (e AlreadyExistsError) Error() string { if e.Message != "" { return e.Message } - return "key already exists" + return "already exists" +} + +func (e AlreadyExistsError) Is(target error) bool { + _, ok := target.(AlreadyExistsError) + return ok +} + +// NotFoundError is the type of error returned if a key or certificate does not +// exist. This is currently only implemented for mackms. +type NotFoundError struct { + Message string +} + +func (e NotFoundError) Error() string { + if e.Message != "" { + return e.Message + } + return "not found" +} + +func (e NotFoundError) Is(target error) bool { + _, ok := target.(NotFoundError) + return ok } // Type represents the KMS type used. diff --git a/kms/apiv1/options_test.go b/kms/apiv1/options_test.go index d1f1aede..4e159b08 100644 --- a/kms/apiv1/options_test.go +++ b/kms/apiv1/options_test.go @@ -3,8 +3,12 @@ package apiv1 import ( "context" "crypto" + "errors" + "fmt" "os" "testing" + + "github.com/stretchr/testify/assert" ) type fakeKM struct{} @@ -124,7 +128,7 @@ func TestErrAlreadyExists_Error(t *testing.T) { fields fields want string }{ - {"default", fields{}, "key already exists"}, + {"default", fields{}, "already exists"}, {"custom", fields{"custom message: key already exists"}, "custom message: key already exists"}, } for _, tt := range tests { @@ -139,6 +143,30 @@ func TestErrAlreadyExists_Error(t *testing.T) { } } +func TestNotFoundError_Error(t *testing.T) { + type fields struct { + msg string + } + tests := []struct { + name string + fields fields + want string + }{ + {"default", fields{}, "not found"}, + {"custom", fields{"custom message: not found"}, "custom message: not found"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NotFoundError{ + Message: tt.fields.msg, + } + if got := e.Error(); got != tt.want { + t.Errorf("ErrAlreadyExists.Error() = %v, want %v", got, tt.want) + } + }) + } +} + func TestTypeOf(t *testing.T) { type args struct { rawuri string @@ -176,3 +204,29 @@ func TestTypeOf(t *testing.T) { }) } } + +func TestError_Is(t *testing.T) { + tests := []struct { + name string + err error + target error + want bool + }{ + {"ok not implemented", NotImplementedError{}, NotImplementedError{}, true}, + {"ok not implemented with message", NotImplementedError{Message: "something"}, NotImplementedError{}, true}, + {"ok already exists", AlreadyExistsError{}, AlreadyExistsError{}, true}, + {"ok already exists with message", AlreadyExistsError{Message: "something"}, AlreadyExistsError{}, true}, + {"ok not found", NotFoundError{}, NotFoundError{}, true}, + {"ok not found with message", NotFoundError{Message: "something"}, NotFoundError{}, true}, + {"fail not implemented", errors.New("not implemented"), NotImplementedError{}, false}, + {"fail already exists", errors.New("already exists"), AlreadyExistsError{}, false}, + {"fail not found", errors.New("not found"), NotFoundError{}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, errors.Is(tt.err, tt.target)) + assert.Equal(t, tt.want, errors.Is(fmt.Errorf("wrap 1: %w", tt.err), tt.target)) + assert.Equal(t, tt.want, errors.Is(fmt.Errorf("wrap 1: %w", fmt.Errorf("wrap 2: %w", tt.err)), tt.target)) + }) + } +} diff --git a/kms/mackms/mackms.go b/kms/mackms/mackms.go index f45924ca..32e86989 100644 --- a/kms/mackms/mackms.go +++ b/kms/mackms/mackms.go @@ -141,7 +141,7 @@ func (k *MacKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, key, err := getPrivateKey(u) if err != nil { - return nil, fmt.Errorf("mackms GetPublicKey failed: %w", err) + return nil, fmt.Errorf("mackms GetPublicKey failed: %w", apiv1Error(err)) } defer key.Release() @@ -263,7 +263,7 @@ func (k *MacKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons secKeyRef, err := security.SecKeyCreateRandomKey(attrs) if err != nil { - return nil, fmt.Errorf("mackms CreateKey failed: %w", err) + return nil, fmt.Errorf("mackms CreateKey failed: %w", apiv1Error(err)) } defer secKeyRef.Release() @@ -307,7 +307,7 @@ func (k *MacKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, er key, err := getPrivateKey(u) if err != nil { - return nil, fmt.Errorf("mackms CreateSigner failed: %w", err) + return nil, fmt.Errorf("mackms CreateSigner failed: %w", apiv1Error(err)) } defer key.Release() @@ -343,7 +343,7 @@ func (k *MacKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certi cert, err := loadCertificate(u.label, u.serialNumber, nil) if err != nil { - return nil, fmt.Errorf("mackms LoadCertificate failed: %w", err) + return nil, fmt.Errorf("mackms LoadCertificate failed: %w", apiv1Error(err)) } return cert, nil @@ -375,7 +375,7 @@ func (k *MacKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { // Store the certificate and update the label if required if err := storeCertificate(u.label, req.Certificate); err != nil { - return fmt.Errorf("mackms StoreCertificate failed: %w", err) + return fmt.Errorf("mackms StoreCertificate failed: %w", apiv1Error(err)) } return nil @@ -402,7 +402,7 @@ func (k *MacKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([ cert, err := loadCertificate(u.label, u.serialNumber, nil) if err != nil { - return nil, fmt.Errorf("mackms LoadCertificateChain failed1: %w", err) + return nil, fmt.Errorf("mackms LoadCertificateChain failed1: %w", apiv1Error(err)) } chain := []*x509.Certificate{cert} @@ -453,7 +453,7 @@ func (k *MacKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) // Store the certificate and update the label if required if err := storeCertificate(u.label, req.CertificateChain[0]); err != nil { - return fmt.Errorf("mackms StoreCertificateChain failed: %w", err) + return fmt.Errorf("mackms StoreCertificateChain failed: %w", apiv1Error(err)) } // Store the rest of the chain but do not fail if already exists @@ -503,7 +503,7 @@ func (*MacKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error { } // Extract logic to deleteItem to avoid defer on loops if err := deleteItem(dict, u.hash); err != nil { - return fmt.Errorf("mackms DeleteKey failed: %w", err) + return fmt.Errorf("mackms DeleteKey failed: %w", apiv1Error(err)) } } @@ -548,7 +548,7 @@ func (*MacKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { } if err := deleteItem(query, nil); err != nil { - return fmt.Errorf("mackms DeleteCertificate failed: %w", err) + return fmt.Errorf("mackms DeleteCertificate failed: %w", apiv1Error(err)) } return nil @@ -1003,3 +1003,18 @@ func ecdhToECDSAPublicKey(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) { return nil, errors.New("failed to convert *ecdh.PublicKey to *ecdsa.PublicKey") } } + +func apiv1Error(err error) error { + switch { + case errors.Is(err, security.ErrNotFound): + return apiv1.NotFoundError{ + Message: err.Error(), + } + case errors.Is(err, security.ErrAlreadyExists): + return apiv1.AlreadyExistsError{ + Message: err.Error(), + } + default: + return err + } +} diff --git a/kms/mackms/mackms_test.go b/kms/mackms/mackms_test.go index 341f6128..a93b6d59 100644 --- a/kms/mackms/mackms_test.go +++ b/kms/mackms/mackms_test.go @@ -29,6 +29,8 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/hex" + "fmt" + "io" "math/big" "net/url" "testing" @@ -1143,7 +1145,7 @@ func TestMacKMS_DeleteCertificate(t *testing.T) { _, err := kms.LoadCertificate(&apiv1.LoadCertificateRequest{ Name: "mackms:serial=" + hex.EncodeToString(cert.SerialNumber.Bytes()), }) - assert.ErrorIs(t, err, security.ErrNotFound) + assert.ErrorIs(t, err, apiv1.NotFoundError{}) } kms := &MacKMS{} @@ -1196,3 +1198,38 @@ func TestMacKMS_DeleteCertificate(t *testing.T) { }) } } + +func Test_apiv1Error(t *testing.T) { + type args struct { + err error + } + tests := []struct { + name string + args args + assertion assert.ErrorAssertionFunc + }{ + {"ok not found", args{security.ErrNotFound}, func(t assert.TestingT, err error, msg ...interface{}) bool { + return assert.ErrorIs(t, err, apiv1.NotFoundError{}, msg...) + }}, + {"ok not found wrapped", args{fmt.Errorf("something happened: %w", security.ErrNotFound)}, func(t assert.TestingT, err error, msg ...interface{}) bool { + return assert.ErrorIs(t, err, apiv1.NotFoundError{}, msg...) + }}, + {"ok already exists", args{security.ErrAlreadyExists}, func(t assert.TestingT, err error, msg ...interface{}) bool { + return assert.ErrorIs(t, err, apiv1.AlreadyExistsError{}, msg...) + }}, + {"ok already exists wrapped", args{fmt.Errorf("something happened: %w", security.ErrAlreadyExists)}, func(t assert.TestingT, err error, msg ...interface{}) bool { + return assert.ErrorIs(t, err, apiv1.AlreadyExistsError{}, msg...) + }}, + {"ok other", args{io.ErrUnexpectedEOF}, func(t assert.TestingT, err error, msg ...interface{}) bool { + return assert.ErrorIs(t, err, io.ErrUnexpectedEOF, msg...) + }}, + {"ok other wrapped", args{fmt.Errorf("something happened: %w", io.ErrUnexpectedEOF)}, func(t assert.TestingT, err error, msg ...interface{}) bool { + return assert.ErrorIs(t, err, io.ErrUnexpectedEOF, msg...) + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.assertion(t, apiv1Error(tt.args.err)) + }) + } +}