diff --git a/apikey.go b/apikey.go index ee3c58d..1b98bd9 100644 --- a/apikey.go +++ b/apikey.go @@ -1,6 +1,7 @@ package mfa import ( + "context" "crypto/aes" "crypto/cipher" "crypto/rand" @@ -46,14 +47,26 @@ type ApiKey struct { Store *Storage `dynamodbav:"-" json:"-"` } +// CompletionStats stores the number of records processed in a batch operation +type CompletionStats struct { + Complete int `json:"complete"` + Incomplete int `json:"incomplete"` +} + +// BatchStats stores the number of TOTP and Webauth records processed in a batch operation +type BatchStats struct { + TOTP CompletionStats `json:"totp"` + Webauthn CompletionStats `json:"webauthn"` +} + // Load refreshes an ApiKey from the database record func (k *ApiKey) Load() error { return k.Store.Load(envConfig.ApiKeyTable, ApiKeyTablePK, k.Key, k) } // Save an ApiKey to the database -func (k *ApiKey) Save() error { - return k.Store.Store(envConfig.ApiKeyTable, k) +func (k *ApiKey) Save(ctx context.Context) error { + return k.Store.StoreCtx(ctx, envConfig.ApiKeyTable, k) } // Hash generates a bcrypt hash from the Secret field and stores it in HashedSecret @@ -221,61 +234,60 @@ func (k *ApiKey) Activate() error { // ReEncryptTOTPs loads each TOTP record that was encrypted using the old key, re-encrypts it using the new // key, and writes the updated data back to the database. -func (k *ApiKey) ReEncryptTOTPs(storage *Storage, oldKey ApiKey) (complete, incomplete int, err error) { +func (k *ApiKey) ReEncryptTOTPs(ctx context.Context, storage *Storage, oldKey ApiKey) (CompletionStats, error) { var records []TOTP - err = storage.ScanApiKey(envConfig.TotpTable, oldKey.Key, &records) + err := storage.ScanApiKey(envConfig.TotpTable, oldKey.Key, &records) if err != nil { err = fmt.Errorf("failed to query %s table for key %s: %w", envConfig.TotpTable, oldKey.Key, err) - return + return CompletionStats{}, err } - incomplete = len(records) + stats := CompletionStats{Incomplete: len(records)} for _, r := range records { err = k.ReEncryptLegacy(oldKey, &r.EncryptedTotpKey) if err != nil { - err = fmt.Errorf("failed to re-encrypt TOTP %v: %w", r.UUID, err) - return + return stats, fmt.Errorf("failed to re-encrypt TOTP %v: %w", r.UUID, err) } r.ApiKey = k.Key - err = storage.Store(envConfig.TotpTable, &r) + err = storage.StoreCtx(ctx, envConfig.TotpTable, &r) if err != nil { - err = fmt.Errorf("failed to store TOTP %v: %w", r.UUID, err) - return + return stats, fmt.Errorf("failed to store TOTP %v: %w", r.UUID, err) } - complete++ - incomplete-- + + stats.Complete++ + stats.Incomplete-- } - return + return stats, nil } // ReEncryptWebAuthnUsers loads each WebAuthn record that was encrypted using the old key, re-encrypts it using the new // key, and writes the updated data back to the database. -func (k *ApiKey) ReEncryptWebAuthnUsers(storage *Storage, oldKey ApiKey) (complete, incomplete int, err error) { +func (k *ApiKey) ReEncryptWebAuthnUsers(ctx context.Context, storage *Storage, oldKey ApiKey) (CompletionStats, error) { var users []WebauthnUser - err = storage.ScanApiKey(envConfig.WebauthnTable, oldKey.Key, &users) + err := storage.ScanApiKey(envConfig.WebauthnTable, oldKey.Key, &users) if err != nil { err = fmt.Errorf("failed to query %s table for key %s: %w", envConfig.WebauthnTable, oldKey.Key, err) - return + return CompletionStats{}, err } - incomplete = len(users) + stats := CompletionStats{Incomplete: len(users)} for _, user := range users { user.ApiKey = oldKey - err = k.ReEncryptWebAuthnUser(storage, user) + err = k.ReEncryptWebAuthnUser(ctx, storage, user) if err != nil { - err = fmt.Errorf("failed to re-encrypt Webauthn %v: %w", user.ID, err) - return + return stats, fmt.Errorf("reencryption failed %v: %w", user.ID, err) } - complete++ - incomplete-- + + stats.Complete++ + stats.Incomplete-- } - return + return stats, nil } // ReEncryptWebAuthnUser re-encrypts a WebAuthnUser using the new key, and writes the updated data back to the database. -func (k *ApiKey) ReEncryptWebAuthnUser(storage *Storage, user WebauthnUser) error { +func (k *ApiKey) ReEncryptWebAuthnUser(ctx context.Context, storage *Storage, user WebauthnUser) error { oldKey := user.ApiKey err := k.ReEncrypt(oldKey, &user.EncryptedSessionData) if err != nil { @@ -297,7 +309,7 @@ func (k *ApiKey) ReEncryptWebAuthnUser(storage *Storage, user WebauthnUser) erro user.ApiKey = *k user.ApiKeyValue = k.Key - err = storage.Store(envConfig.WebauthnTable, &user) + err = storage.StoreCtx(ctx, envConfig.WebauthnTable, &user) if err != nil { return err } @@ -348,6 +360,8 @@ func (k *ApiKey) ReEncryptLegacy(oldKey ApiKey, v *string) error { // ActivateApiKey is the handler for the POST /api-key/activate endpoint. It creates the key secret and updates the // database record. func (a *App) ActivateApiKey(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + var requestBody struct { ApiKeyValue string `json:"apiKeyValue"` Email string `json:"email"` @@ -393,7 +407,7 @@ func (a *App) ActivateApiKey(w http.ResponseWriter, r *http.Request) { return } - err = newKey.Save() + err = newKey.Save(ctx) if err != nil { log.Printf("failed to save key: %s", err) jsonResponse(w, internalServerError, http.StatusInternalServerError) @@ -412,6 +426,8 @@ func (a *App) ActivateApiKey(w http.ResponseWriter, r *http.Request) { // CreateApiKey is the handler for the POST /api-key endpoint. It creates a new API Key and saves it to the database. func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + var requestBody struct { Email string `json:"email"` } @@ -436,7 +452,7 @@ func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) { } key.Store = a.db - err = key.Save() + err = key.Save(ctx) if err != nil { log.Printf("failed to save key: %s", err) jsonResponse(w, internalServerError, http.StatusInternalServerError) @@ -456,6 +472,7 @@ func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) { // any number of times to continue the process. A status of 200 does not indicate that all keys were encrypted using the // new key. Check the response data to determine if the rotation process is complete. func (a *App) RotateApiKey(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() var requestBody map[string]string err := json.NewDecoder(r.Body).Decode(&requestBody) if err != nil { @@ -489,25 +506,19 @@ func (a *App) RotateApiKey(w http.ResponseWriter, r *http.Request) { return } - totpComplete, totpIncomplete, err := newKey.ReEncryptTOTPs(a.GetDB(), oldKey) + webauthnStats, err := newKey.ReEncryptWebAuthnUsers(ctx, a.GetDB(), oldKey) if err != nil { - log.Printf("failed to re-encrypt TOTP data: %s", err) - jsonResponse(w, internalServerError, http.StatusInternalServerError) - return + log.Printf("failed to re-encrypt one or more WebAuthn record: %s", err) } - webauthnComplete, webauthnIncomplete, err := newKey.ReEncryptWebAuthnUsers(a.GetDB(), oldKey) + totpStats, err := newKey.ReEncryptTOTPs(ctx, a.GetDB(), oldKey) if err != nil { - log.Printf("failed to re-encrypt WebAuthn data: %s", err) - jsonResponse(w, internalServerError, http.StatusInternalServerError) - return + log.Printf("failed to re-encrypt one or more TOTP record: %s", err) } - responseBody := map[string]int{ - "totpComplete": totpComplete, - "totpIncomplete": totpIncomplete, - "webauthnComplete": webauthnComplete, - "webauthnIncomplete": webauthnIncomplete, + responseBody := BatchStats{ + TOTP: totpStats, + Webauthn: webauthnStats, } jsonResponse(w, responseBody, http.StatusOK) diff --git a/apikey_test.go b/apikey_test.go index 918e2b6..365845b 100644 --- a/apikey_test.go +++ b/apikey_test.go @@ -366,7 +366,11 @@ func (ms *MfaSuite) TestAppRotateApiKey() { key := user.ApiKey must(db.Store(config.ApiKeyTable, key)) - totp := ms.newTOTP(key) + const numberOfTOTPs = 1000 + totpList := make([]TOTP, numberOfTOTPs) + for i := range totpList { + totpList[i] = ms.newTOTP(key) + } newKey := newTestKey() must(db.Store(config.ApiKeyTable, newKey)) @@ -425,7 +429,9 @@ func (ms *MfaSuite) TestAppRotateApiKey() { request.Header.Set(HeaderAPISecret, tt.key.Secret) ctxWithUser := context.WithValue(request.Context(), UserContextKey, tt.key) - request = request.WithContext(ctxWithUser) + ctxWithDeadline, cancel := context.WithTimeout(ctxWithUser, time.Second) + defer cancel() + request = request.WithContext(ctxWithDeadline) res := httptest.NewRecorder() Router(ms.app).ServeHTTP(res, request) @@ -436,14 +442,24 @@ func (ms *MfaSuite) TestAppRotateApiKey() { return } - var response map[string]int + var response BatchStats ms.decodeBody(res.Body.Bytes(), &response) - ms.Equal(1, response["totpComplete"]) - ms.Equal(1, response["webauthnComplete"]) - - totpFromDB := TOTP{UUID: totp.UUID, ApiKey: newKey.Key} - must(db.Load(config.TotpTable, "uuid", totp.UUID, &totpFromDB)) - ms.Equal(newKey.Key, totpFromDB.ApiKey) + ms.Greater(response.TOTP.Complete, 0, "none of the TOTPs were re-encrypted") + ms.Less(response.TOTP.Complete, numberOfTOTPs, "test didn't cancel before completion") + ms.Equalf(numberOfTOTPs, response.TOTP.Complete+response.TOTP.Incomplete, + "total of TOTP.Complete (%d) and TOTP.Incomplete (%d) should equal the total number of TOTPs (%d)", + response.TOTP.Complete, response.TOTP.Incomplete, numberOfTOTPs) + ms.Equal(1, response.Webauthn.Complete) + + foundOne := false + for i := range totpList { + totpFromDB := TOTP{UUID: totpList[i].UUID, ApiKey: newKey.Key} + must(db.Load(config.TotpTable, "uuid", totpList[i].UUID, &totpFromDB)) + if newKey.Key == totpFromDB.ApiKey { + foundOne = true + } + } + ms.True(foundOne, "did not find a TOTP with the new key") dbUser := WebauthnUser{ID: user.ID, Store: db, ApiKey: newKey} must(dbUser.Load()) @@ -522,10 +538,10 @@ func (ms *MfaSuite) TestApiKey_ReEncryptTOTPs() { _ = ms.newTOTP(oldKey) - complete, incomplete, err := newKey.ReEncryptTOTPs(storage, oldKey) + stats, err := newKey.ReEncryptTOTPs(ms.T().Context(), storage, oldKey) ms.NoError(err) - ms.Equal(1, complete) - ms.Equal(0, incomplete) + ms.Equal(1, stats.Complete) + ms.Equal(0, stats.Incomplete) } func (ms *MfaSuite) TestReEncryptWebAuthnUsers() { @@ -540,10 +556,10 @@ func (ms *MfaSuite) TestReEncryptWebAuthnUsers() { newKey := newTestKey() must(ms.app.GetDB().Store(ms.app.GetConfig().ApiKeyTable, newKey)) - complete, incomplete, err := newKey.ReEncryptWebAuthnUsers(storage, users[0].ApiKey) + stats, err := newKey.ReEncryptWebAuthnUsers(ms.T().Context(), storage, users[0].ApiKey) ms.NoError(err) - ms.Equal(0, incomplete) - ms.Equal(1, complete) + ms.Equal(0, stats.Incomplete) + ms.Equal(1, stats.Complete) // verify only users[0] is affected because each test user belongs to a different key for i, user := range users { @@ -585,7 +601,7 @@ func (ms *MfaSuite) TestReEncryptWebAuthnUser() { must(ms.app.GetDB().Store(ms.app.GetConfig().ApiKeyTable, newKey)) ms.NotEqual(newKey.Secret, tt.user.ApiKey.Secret) - err = newKey.ReEncryptWebAuthnUser(storage, tt.user) + err = newKey.ReEncryptWebAuthnUser(ms.T().Context(), storage, tt.user) ms.NoError(err) dbUser := WebauthnUser{ID: tt.user.ID, ApiKey: newKey, Store: storage} diff --git a/openapi.yaml b/openapi.yaml index 89ef141..e6f343b 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -519,22 +519,28 @@ paths: schema: type: object properties: - totpComplete: - type: integer - description: the number of TOTP codes that were encrypted with the new key - required: true - totpIncomplete: - type: integer - description: the number of TOTP codes that have not yet been encrypted with the new key - required: true - webauthnComplete: - type: integer - description: the number of Webauthn passkeys that were encrypted with the new key - required: true - webauthnIncomplete: - type: integer - description: the number of Webauthn passkeys that have not yet been encrypted with the new key - required: true + totp: + type: object + properties: + complete: + type: integer + description: the number of TOTP codes that were encrypted with the new key + required: true + incomplete: + type: integer + description: the number of TOTP codes that have not yet been encrypted with the new key + required: true + webauthn: + type: object + properties: + complete: + type: integer + description: the number of Webauthn passkeys that were encrypted with the new key + required: true + incomplete: + type: integer + description: the number of Webauthn passkeys that have not yet been encrypted with the new key + required: true 400: description: Bad Request content: diff --git a/storage.go b/storage.go index 3a17915..ebd2777 100644 --- a/storage.go +++ b/storage.go @@ -33,6 +33,11 @@ func NewStorage(config aws.Config) (*Storage, error) { // Store puts item at key. func (s *Storage) Store(table string, item interface{}) error { + return s.StoreCtx(context.Background(), table, item) +} + +// StoreCtx puts item at key. This is a context-aware equivalent of Store. +func (s *Storage) StoreCtx(ctx context.Context, table string, item interface{}) error { av, err := attributevalue.MarshalMap(item) if err != nil { return err @@ -43,7 +48,6 @@ func (s *Storage) Store(table string, item interface{}) error { TableName: aws.String(table), } - ctx := context.Background() _, err = s.client.PutItem(ctx, input) return err }