diff --git a/changelog/13716.txt b/changelog/13716.txt new file mode 100644 index 0000000000000..e7d35a017f25f --- /dev/null +++ b/changelog/13716.txt @@ -0,0 +1,3 @@ +```release-note:bug +identity/oidc: Check for a nil signing key on rotation to prevent panics. +``` diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 6ff810c750786..a935ee6ec0950 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -548,19 +548,11 @@ func (i *IdentityStore) pathOIDCCreateUpdateKey(ctx context.Context, req *logica // generate current and next keys if creating a new key or changing algorithms if key.Algorithm != prevAlgorithm { - signingKey, err := generateKeys(key.Algorithm) + err = key.generateAndSetKey(ctx, i.Logger(), req.Storage) if err != nil { return nil, err } - key.SigningKey = signingKey - key.KeyRing = append(key.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID}) - - if err := saveOIDCPublicKey(ctx, req.Storage, signingKey.Public()); err != nil { - return nil, err - } - i.Logger().Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID) - err = key.generateAndSetNextKey(ctx, i.Logger(), req.Storage) if err != nil { return nil, err @@ -1013,6 +1005,24 @@ func mergeJSONTemplates(logger hclog.Logger, output map[string]interface{}, temp return nil } +// generateAndSetKey will generate new signing and public key pairs and set +// them as the SigningKey. +func (k *namedKey) generateAndSetKey(ctx context.Context, logger hclog.Logger, s logical.Storage) error { + signingKey, err := generateKeys(k.Algorithm) + if err != nil { + return err + } + + k.SigningKey = signingKey + k.KeyRing = append(k.KeyRing, &expireableKey{KeyID: signingKey.Public().KeyID}) + + if err := saveOIDCPublicKey(ctx, s, signingKey.Public()); err != nil { + return err + } + logger.Debug("generated OIDC public key to sign JWTs", "key_id", signingKey.Public().KeyID) + return nil +} + // generateAndSetNextKey will generate new signing and public key pairs and set // them as the NextSigningKey. func (k *namedKey) generateAndSetNextKey(ctx context.Context, logger hclog.Logger, s logical.Storage) error { @@ -1032,6 +1042,9 @@ func (k *namedKey) generateAndSetNextKey(ctx context.Context, logger hclog.Logge } func (k *namedKey) signPayload(payload []byte) (string, error) { + if k.SigningKey == nil { + return "", fmt.Errorf("signing key is nil; rotate the key and try again") + } signingKey := jose.SigningKey{Key: k.SigningKey, Algorithm: jose.SignatureAlgorithm(k.Algorithm)} signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) if err != nil { @@ -1482,21 +1495,27 @@ func (i *IdentityStore) pathOIDCIntrospect(ctx context.Context, req *logical.Req // verification_ttl can be overridden with an overrideVerificationTTL value >= 0 func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.Storage, overrideVerificationTTL time.Duration) error { verificationTTL := k.VerificationTTL - if overrideVerificationTTL >= 0 { verificationTTL = overrideVerificationTTL } now := time.Now() - // set the previous public key's expiry time - for _, key := range k.KeyRing { - if key.KeyID == k.SigningKey.KeyID { - key.ExpireAt = now.Add(verificationTTL) - break + if k.SigningKey != nil { + // set the previous public key's expiry time + for _, key := range k.KeyRing { + if key.KeyID == k.SigningKey.KeyID { + key.ExpireAt = now.Add(verificationTTL) + break + } } + } else { + // this can occur for keys generated before vault 1.9.0 but rotated on + // vault 1.9.0 + logger.Debug("nil signing key detected on rotation") } if k.NextSigningKey == nil { + logger.Debug("nil next signing key detected on rotation") // keys will not have a NextSigningKey if they were generated before // vault 1.9 err := k.generateAndSetNextKey(ctx, logger, s) @@ -1504,6 +1523,7 @@ func (k *namedKey) rotate(ctx context.Context, logger hclog.Logger, s logical.St return err } } + // do the rotation k.SigningKey = k.NextSigningKey k.NextRotation = now.Add(k.RotationPeriod) @@ -1695,21 +1715,21 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor return now, err } - namedKeys, err := s.List(ctx, namedKeyConfigPath) + keyNames, err := s.List(ctx, namedKeyConfigPath) if err != nil { return now, err } usedKeys := make([]string, 0) - for _, k := range namedKeys { - entry, err := s.Get(ctx, namedKeyConfigPath+k) + for _, keyName := range keyNames { + entry, err := s.Get(ctx, namedKeyConfigPath+keyName) if err != nil { return now, err } if entry == nil { - i.Logger().Warn("could not find key to update", "key", k) + i.Logger().Warn("could not find key to update", "key", keyName) continue } @@ -1722,14 +1742,14 @@ func (i *IdentityStore) expireOIDCPublicKeys(ctx context.Context, s logical.Stor keyRing := key.KeyRing var keyringUpdated bool - for i := 0; i < len(keyRing); i++ { - k := keyRing[i] + for j := 0; j < len(keyRing); j++ { + k := keyRing[j] if !k.ExpireAt.IsZero() && k.ExpireAt.Before(now) { - keyRing[i] = keyRing[len(keyRing)-1] + keyRing[j] = keyRing[len(keyRing)-1] keyRing = keyRing[:len(keyRing)-1] keyringUpdated = true - i-- + j-- continue } diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index d52ae7a14c760..8d23d63b3cf84 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -2,8 +2,6 @@ package vault import ( "context" - "crypto/rand" - "crypto/rsa" "encoding/json" "strconv" "strings" @@ -11,7 +9,7 @@ import ( "time" "github.com/go-test/deep" - uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/framework" @@ -893,6 +891,79 @@ func TestOIDC_SignIDToken(t *testing.T) { } } +// TestOIDC_SignIDToken_NilSigningKey tests that an error is returned when +// attempting to sign an ID token with a nil signing key +func TestOIDC_SignIDToken_NilSigningKey(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + ctx := namespace.RootContext(nil) + + // Create and load an entity, an entity is required to generate an ID token + testEntity := &identity.Entity{ + Name: "test-entity-name", + ID: "test-entity-id", + BucketKey: "test-entity-bucket-key", + } + + txn := c.identityStore.db.Txn(true) + defer txn.Abort() + err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true) + if err != nil { + t.Fatal(err) + } + txn.Commit() + + // Create a test key "test-key" with a nil SigningKey + namedKey := &namedKey{ + name: "test-key", + AllowedClientIDs: []string{"*"}, + Algorithm: "RS256", + VerificationTTL: 60 * time.Second, + RotationPeriod: 60 * time.Second, + KeyRing: nil, + SigningKey: nil, + NextSigningKey: nil, + NextRotation: time.Now(), + } + s := c.router.MatchingStorageByAPIPath(ctx, "identity/oidc") + if err := namedKey.generateAndSetNextKey(ctx, hclog.NewNullLogger(), s); err != nil { + t.Fatalf("failed to set next signing key") + } + // Store namedKey + entry, _ := logical.StorageEntryJSON(namedKeyConfigPath+namedKey.name, namedKey) + if err := s.Put(ctx, entry); err != nil { + t.Fatalf("writing to in mem storage failed") + } + + // Create a test role "test-role" -- expect no warning + resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/role/test-role", + Operation: logical.CreateOperation, + Data: map[string]interface{}{ + "key": "test-key", + "ttl": "1m", + }, + Storage: s, + }) + expectSuccess(t, resp, err) + if resp != nil { + t.Fatalf("was expecting a nil response but instead got: %#v", resp) + } + + // Generate a token against the role "test-role" -- should fail + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/token/test-role", + Operation: logical.ReadOperation, + Storage: s, + EntityID: "test-entity-id", + }) + expectError(t, resp, err) + // validate error message + expectedStrings := map[string]interface{}{ + "error signing OIDC token: signing key is nil; rotate the key and try again": true, + } + expectStrings(t, []string{err.Error()}, expectedStrings) +} + // TestOIDC_PeriodicFunc tests timing logic for running key // rotations and expiration actions. func TestOIDC_PeriodicFunc(t *testing.T) { @@ -900,72 +971,111 @@ func TestOIDC_PeriodicFunc(t *testing.T) { c, _, _ := TestCoreUnsealed(t) ctx := namespace.RootContext(nil) - // Prepare a dummy signing key - key, _ := rsa.GenerateKey(rand.Reader, 2048) - id, _ := uuid.GenerateUUID() - jwk := &jose.JSONWebKey{ - Key: key, - KeyID: id, - Algorithm: "RS256", - Use: "sig", - } - cyclePeriod := 2 * time.Second testSets := []struct { - namedKey *namedKey - testCases []struct { - cycle int - numKeys int - numPublicKeys int - } + namedKey *namedKey + expectedKeyCount int + setSigningKey bool + setNextSigningKey bool + cycles int }{ { - // don't set NextSigningKey to ensure its non-existence can be handled - &namedKey{ + namedKey: &namedKey{ name: "test-key", Algorithm: "RS256", VerificationTTL: 1 * cyclePeriod, RotationPeriod: 1 * cyclePeriod, KeyRing: nil, - SigningKey: jwk, + SigningKey: nil, + NextSigningKey: nil, + NextRotation: time.Now(), + }, + expectedKeyCount: 3, + setSigningKey: true, + setNextSigningKey: true, + cycles: 4, + }, + { + // don't set SigningKey to ensure its non-existence can be handled + namedKey: &namedKey{ + name: "test-key-nil-signing-key", + Algorithm: "RS256", + VerificationTTL: 1 * cyclePeriod, + RotationPeriod: 1 * cyclePeriod, + KeyRing: nil, + SigningKey: nil, + NextSigningKey: nil, + NextRotation: time.Now(), + }, + expectedKeyCount: 2, + setSigningKey: false, + setNextSigningKey: true, + cycles: 2, + }, + { + // don't set NextSigningKey to ensure its non-existence can be handled + namedKey: &namedKey{ + name: "test-key-nil-next-signing-key", + Algorithm: "RS256", + VerificationTTL: 1 * cyclePeriod, + RotationPeriod: 1 * cyclePeriod, + KeyRing: nil, + SigningKey: nil, + NextSigningKey: nil, NextRotation: time.Now(), }, - []struct { - cycle int - numKeys int - numPublicKeys int - }{ - {1, 2, 2}, - {2, 3, 3}, - {3, 3, 3}, - {4, 3, 3}, - {5, 3, 3}, - {6, 3, 3}, - {7, 3, 3}, + expectedKeyCount: 2, + setSigningKey: true, + setNextSigningKey: false, + cycles: 2, + }, + { + // don't set keys to ensure non-existence can be handled + namedKey: &namedKey{ + name: "test-key-nil-signing-and-next-signing-key", + Algorithm: "RS256", + VerificationTTL: 1 * cyclePeriod, + RotationPeriod: 1 * cyclePeriod, + KeyRing: nil, + SigningKey: nil, + NextSigningKey: nil, + NextRotation: time.Now(), }, + expectedKeyCount: 2, + setSigningKey: false, + setNextSigningKey: false, + cycles: 2, }, } for _, testSet := range testSets { - // Store namedKey storage := c.router.MatchingStorageByAPIPath(ctx, "identity/oidc") + if testSet.setSigningKey { + if err := testSet.namedKey.generateAndSetKey(ctx, hclog.NewNullLogger(), storage); err != nil { + t.Fatalf("failed to set signing key") + } + } + if testSet.setNextSigningKey { + if err := testSet.namedKey.generateAndSetNextKey(ctx, hclog.NewNullLogger(), storage); err != nil { + t.Fatalf("failed to set next signing key") + } + } + // Store namedKey entry, _ := logical.StorageEntryJSON(namedKeyConfigPath+testSet.namedKey.name, testSet.namedKey) if err := storage.Put(ctx, entry); err != nil { t.Fatalf("writing to in mem storage failed") } - currentCycle := 1 - numCases := len(testSet.testCases) - lastCycle := testSet.testCases[numCases-1].cycle - namedKeySamples := make([]*logical.StorageEntry, numCases) - publicKeysSamples := make([][]string, numCases) + currentCycle := 0 + lastCycle := testSet.cycles - 1 + namedKeySamples := make([]*logical.StorageEntry, testSet.cycles) + publicKeysSamples := make([][]string, testSet.cycles) i := 0 - // var start time.Time for currentCycle <= lastCycle { c.identityStore.oidcPeriodicFunc(ctx) - if currentCycle == testSet.testCases[i].cycle { + if currentCycle == i { namedKeyEntry, _ := storage.Get(ctx, namedKeyConfigPath+testSet.namedKey.name) publicKeysEntry, _ := storage.List(ctx, publicKeysConfigPath) namedKeySamples[i] = namedKeyEntry @@ -985,15 +1095,34 @@ func TestOIDC_PeriodicFunc(t *testing.T) { } // measure collected samples - for i := range testSet.testCases { + for i := 0; i < testSet.cycles; i++ { + cycle := i + 1 namedKeySamples[i].DecodeJSON(&testSet.namedKey) - if len(testSet.namedKey.KeyRing) != testSet.testCases[i].numKeys { - t.Fatalf("At cycle: %d expected namedKey's KeyRing to be of length %d but was: %d", testSet.testCases[i].cycle, testSet.testCases[i].numKeys, len(testSet.namedKey.KeyRing)) + actualKeyRingLen := len(testSet.namedKey.KeyRing) + if actualKeyRingLen < testSet.expectedKeyCount { + t.Errorf( + "For key: %s at cycle: %d expected namedKey's KeyRing to be at least of length %d but was: %d", + testSet.namedKey.name, + cycle, + testSet.expectedKeyCount, + actualKeyRingLen, + ) } - if len(publicKeysSamples[i]) != testSet.testCases[i].numPublicKeys { - t.Fatalf("At cycle: %d expected public keys to be of length %d but was: %d", testSet.testCases[i].cycle, testSet.testCases[i].numPublicKeys, len(publicKeysSamples[i])) + actualPubKeysLen := len(publicKeysSamples[i]) + if actualPubKeysLen < testSet.expectedKeyCount { + t.Errorf( + "For key: %s at cycle: %d expected public keys to be at least of length %d but was: %d", + testSet.namedKey.name, + cycle, + testSet.expectedKeyCount, + actualPubKeysLen, + ) } } + + if err := storage.Delete(ctx, namedKeyConfigPath+testSet.namedKey.name); err != nil { + t.Fatalf("deleting from in mem storage failed") + } } }