Skip to content

Commit

Permalink
Fix cosign generate Azure key pair bug (#1525)
Browse files Browse the repository at this point in the history
* handle the hash func later on, handle 404 response

Signed-off-by: Meredith Lancaster <malancas@github.com>

* use shorthand

Signed-off-by: Meredith Lancaster <malancas@github.com>

* add unit test for create key

Signed-off-by: Meredith Lancaster <malancas@github.com>

* try counting GetKey call

Signed-off-by: Meredith Lancaster <malancas@github.com>

* uncomment test

Signed-off-by: Meredith Lancaster <malancas@github.com>

* add more comments

Signed-off-by: Meredith Lancaster <malancas@github.com>

* remove redundant method definition

Signed-off-by: Meredith Lancaster <malancas@github.com>

* clean up test and error messages

Signed-off-by: Meredith Lancaster <malancas@github.com>

* use the signer's default context

Signed-off-by: Meredith Lancaster <malancas@github.com>

* comment

Signed-off-by: Meredith Lancaster <malancas@github.com>

* create key reference

Signed-off-by: Meredith Lancaster <malancas@github.com>

* use new key ref

Signed-off-by: Meredith Lancaster <malancas@github.com>

* err message

Signed-off-by: Meredith Lancaster <malancas@github.com>

---------

Signed-off-by: Meredith Lancaster <malancas@github.com>
  • Loading branch information
malancas committed Dec 4, 2023
1 parent 0d8f0bc commit f6b3cde
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 23 deletions.
22 changes: 22 additions & 0 deletions pkg/signature/kms/azure/client.go
Expand Up @@ -25,6 +25,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"regexp"
"strings"
Expand Down Expand Up @@ -319,11 +320,32 @@ func (a *azureVaultClient) public(ctx context.Context) (crypto.PublicKey, error)
}

func (a *azureVaultClient) createKey(ctx context.Context) (crypto.PublicKey, error) {
// check if the key already exists by attempting to fetch it
_, err := a.getKey(ctx)
// if the error is nil, this means the key already exists
// and we can return the public key
if err == nil {
return a.public(ctx)
}

// If the returned error is not nil, set the error to the
// custom azcore.ResponseError error implementation
// this custom error allows us to check the status code
// returned by the GetKey operation. If the operation
// returned a 404, we know that the key does not exist
// and we can create it.
var respErr *azcore.ResponseError
if ok := errors.As(err, &respErr); !ok {
return nil, fmt.Errorf("unexpected error returned by get key operation: %w", err)
}

// if a non-404 status code is returned, return the error
// since this is an unexpected error response
if respErr.StatusCode != http.StatusNotFound {
return nil, fmt.Errorf("unexpected status code returned by get key operation: %w", err)
}

// if a 404 was returned, then we can create the key
_, err = a.client.CreateKey(
ctx,
a.keyName,
Expand Down
133 changes: 125 additions & 8 deletions pkg/signature/kms/azure/client_test.go
Expand Up @@ -17,14 +17,20 @@ package azure

import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"net/http"
"os"
"testing"

"github.com/jellydator/ttlcache/v3"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys"
)
Expand All @@ -33,21 +39,26 @@ type testKVClient struct {
key azkeys.JSONWebKey
}

func (c *testKVClient) CreateKey(_ context.Context, _ string, _ azkeys.CreateKeyParameters, _ *azkeys.CreateKeyOptions) (result azkeys.CreateKeyResponse, err error) {
func (c *testKVClient) CreateKey(_ context.Context, _ string, _ azkeys.CreateKeyParameters, _ *azkeys.CreateKeyOptions) (azkeys.CreateKeyResponse, error) {
key, err := generatePublicKey("EC")
if err != nil {
return result, err
return azkeys.CreateKeyResponse{}, err
}
c.key = key

result.Key = &key
return result, nil
return azkeys.CreateKeyResponse{
KeyBundle: azkeys.KeyBundle{
Key: &key,
},
}, nil
}

func (c *testKVClient) GetKey(_ context.Context, _, _ string, _ *azkeys.GetKeyOptions) (result azkeys.GetKeyResponse, err error) {
result.Key = &c.key

return result, nil
func (c *testKVClient) GetKey(_ context.Context, _, _ string, _ *azkeys.GetKeyOptions) (azkeys.GetKeyResponse, error) {
return azkeys.GetKeyResponse{
KeyBundle: azkeys.KeyBundle{
Key: &c.key,
},
}, nil
}

func (c *testKVClient) Sign(_ context.Context, _, _ string, _ azkeys.SignParameters, _ *azkeys.SignOptions) (result azkeys.SignResponse, err error) {
Expand All @@ -58,6 +69,53 @@ func (c *testKVClient) Verify(_ context.Context, _, _ string, _ azkeys.VerifyPar
return result, nil
}

type keyNotFoundClient struct {
testKVClient
key azkeys.JSONWebKey
getKeyReturnsErr bool
getKeyCallThreshold int
getKeyCallCount int
}

func (c *keyNotFoundClient) GetKey(_ context.Context, _, _ string, _ *azkeys.GetKeyOptions) (azkeys.GetKeyResponse, error) {
if c.getKeyReturnsErr && c.getKeyCallCount < c.getKeyCallThreshold {
c.getKeyCallCount++
return azkeys.GetKeyResponse{}, &azcore.ResponseError{
StatusCode: http.StatusNotFound,
RawResponse: &http.Response{},
}
}

return azkeys.GetKeyResponse{
KeyBundle: azkeys.KeyBundle{
Key: &c.key,
},
}, nil
}

type nonResponseErrClient struct {
testKVClient
keyCache *ttlcache.Cache[string, crypto.PublicKey]
}

func (c *nonResponseErrClient) GetKey(_ context.Context, _, _ string, _ *azkeys.GetKeyOptions) (result azkeys.GetKeyResponse, err error) {
err = errors.New("unexpected error")
return result, err
}

type non404RespClient struct {
testKVClient
keyCache *ttlcache.Cache[string, crypto.PublicKey]
}

func (c *non404RespClient) GetKey(_ context.Context, _, _ string, _ *azkeys.GetKeyOptions) (result azkeys.GetKeyResponse, err error) {
err = &azcore.ResponseError{
StatusCode: http.StatusServiceUnavailable,
}

return result, err
}

func generatePublicKey(azureKeyType string) (azkeys.JSONWebKey, error) {
keyOps := []*azkeys.KeyOperation{to.Ptr(azkeys.KeyOperationSign), to.Ptr(azkeys.KeyOperationVerify)}
kid := "https://honk-vault.vault.azure.net/keys/honk-key/abc123"
Expand Down Expand Up @@ -152,6 +210,65 @@ func TestAzureVaultClientFetchPublicKey(t *testing.T) {
}
}

func TestAzureVaultClientCreateKey(t *testing.T) {
type test struct {
name string
client kvClient
expectSuccess bool
}

key, err := generatePublicKey("EC")
if err != nil {
t.Fatalf("unexpected error while generating public key for testing: %v", err)
}

tests := []test{
{
name: "Successfully create key if it doesn't exist",
client: &keyNotFoundClient{
key: key,
getKeyReturnsErr: true,
getKeyCallThreshold: 1,
},
expectSuccess: true,
},
{
name: "Return public key if it already exists",
client: &testKVClient{
key: key,
},
expectSuccess: true,
},
{
name: "Fail to create key due to unknown error",
client: &nonResponseErrClient{},
expectSuccess: false,
},
{
name: "Fail to create key due to non-404 status code error",
client: &non404RespClient{},
expectSuccess: false,
},
}

for _, tc := range tests {
client := azureVaultClient{
client: tc.client,
keyCache: ttlcache.New[string, crypto.PublicKey](
ttlcache.WithDisableTouchOnHit[string, crypto.PublicKey](),
),
}

_, err = client.createKey(context.Background())
if err != nil && tc.expectSuccess {
t.Fatalf("Test '%s' failed. Expected nil error, actual value: %v", tc.name, err)
}
if err == nil && !tc.expectSuccess {
t.Fatalf("Test '%s' failed. Expected non-nil error", tc.name)
}
}
}

func TestGetAuthenticationMethod(t *testing.T) {
clearEnv := map[string]string{
"AZURE_TENANT_ID": "",
Expand Down
26 changes: 26 additions & 0 deletions pkg/signature/kms/azure/integration_test.go
Expand Up @@ -112,6 +112,32 @@ func TestLoadSignerVerifier(t *testing.T) {
}
}

func TestCreateKey(t *testing.T) {
azureVaultURL := os.Getenv("VAULT_URL")
if azureVaultURL == "" {
t.Fatalf("VAULT_URL must be set")
}

newKeyRef := fmt.Sprintf("azurekms://%s.vault.azure.net/%s", azureVaultURL, "new-test-key")

sv, err := LoadSignerVerifier(context.Background(), newKeyRef)
if err != nil {
t.Fatalf("LoadSignerVerifier unexpectedly returned non-nil error: %v", err)
}

publicKey, err := sv.client.createKey(context.Background())
if err != nil {
t.Errorf("getKey failed with error: %v", err)
}
if publicKey == nil {
t.Errorf("public key is nil")
}

if _, ok := publicKey.(*ecdsa.PublicKey); !ok {
t.Errorf("expected public key to be of type *ecdsa.PublicKey")
}
}

func TestGetKey(t *testing.T) {
azureKeyRef := os.Getenv("AZURE_KEY_REF")

Expand Down
33 changes: 18 additions & 15 deletions pkg/signature/kms/azure/signer.go
Expand Up @@ -70,11 +70,6 @@ func LoadSignerVerifier(defaultCtx context.Context, referenceStr string) (*Signe
return nil, err
}

a.hashFunc, _, err = a.client.getKeyVaultHashFunc(defaultCtx)
if err != nil {
return nil, err
}

return a, nil
}

Expand All @@ -92,14 +87,13 @@ func LoadSignerVerifier(defaultCtx context.Context, referenceStr string) (*Signe
//
// All other options are ignored if specified.
func (a *SignerVerifier) SignMessage(message io.Reader, opts ...signature.SignOption) ([]byte, error) {
ctx := context.Background()
var digest []byte

for _, opt := range opts {
opt.ApplyDigest(&digest)
}

hashFunc, _, err := a.client.getKeyVaultHashFunc(ctx)
hashFunc, _, err := a.client.getKeyVaultHashFunc(a.defaultCtx)
if err != nil {
return nil, err
}
Expand All @@ -109,7 +103,7 @@ func (a *SignerVerifier) SignMessage(message io.Reader, opts ...signature.SignOp
return nil, err
}

rawSig, err := a.client.sign(ctx, digest)
rawSig, err := a.client.sign(a.defaultCtx, digest)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -142,14 +136,18 @@ func (a *SignerVerifier) SignMessage(message io.Reader, opts ...signature.SignOp
//
// All other options are ignored if specified.
func (a *SignerVerifier) VerifySignature(sig, message io.Reader, opts ...signature.VerifyOption) error {
ctx := context.Background()
hashFunc, _, err := a.client.getKeyVaultHashFunc(a.defaultCtx)
if err != nil {
return err
}

var digest []byte
var signerOpts crypto.SignerOpts = a.hashFunc
var signerOpts crypto.SignerOpts = hashFunc
for _, opt := range opts {
opt.ApplyDigest(&digest)
}

digest, _, err := signature.ComputeDigestForVerifying(message, signerOpts.HashFunc(), azureSupportedHashFuncs, opts...)
digest, _, err = signature.ComputeDigestForVerifying(message, signerOpts.HashFunc(), azureSupportedHashFuncs, opts...)
if err != nil {
return err
}
Expand Down Expand Up @@ -177,13 +175,13 @@ func (a *SignerVerifier) VerifySignature(sig, message io.Reader, opts ...signatu
rawSigBytes := []byte{}
rawSigBytes = append(rawSigBytes, r.Bytes()...)
rawSigBytes = append(rawSigBytes, s.Bytes()...)
return a.client.verify(ctx, rawSigBytes, digest)
return a.client.verify(a.defaultCtx, rawSigBytes, digest)
}

// PublicKey returns the public key that can be used to verify signatures created by
// this signer. All options provided in arguments to this method are ignored.
func (a *SignerVerifier) PublicKey(_ ...signature.PublicKeyOption) (crypto.PublicKey, error) {
return a.client.public(context.Background())
return a.client.public(a.defaultCtx)
}

// CreateKey attempts to create a new key in Vault with the specified algorithm.
Expand Down Expand Up @@ -223,14 +221,19 @@ func (c cryptoSignerWrapper) Sign(_ io.Reader, digest []byte, opts crypto.Signer
// CryptoSigner returns a crypto.Signer object that uses the underlying SignerVerifier, along with a crypto.SignerOpts object
// that allows the KMS to be used in APIs that only accept the standard golang objects
func (a *SignerVerifier) CryptoSigner(ctx context.Context, errFunc func(error)) (crypto.Signer, crypto.SignerOpts, error) {
hashFunc, _, err := a.client.getKeyVaultHashFunc(a.defaultCtx)
if err != nil {
return nil, nil, err
}

csw := &cryptoSignerWrapper{
ctx: ctx,
sv: a,
hashFunc: a.hashFunc,
hashFunc: hashFunc,
errFunc: errFunc,
}

return csw, a.hashFunc, nil
return csw, hashFunc, nil
}

// SupportedAlgorithms returns the list of algorithms supported by the Azure KMS service
Expand Down

0 comments on commit f6b3cde

Please sign in to comment.