Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 52 additions & 41 deletions apikey.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mfa

import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
Expand Down Expand Up @@ -46,14 +47,26 @@ type ApiKey struct {
Store *Storage `dynamodbav:"-" json:"-"`
}

// CompletionStats stores the number of records processed in a batch operation
type CompletionStats struct {
Complete int `json:"complete"`
Incomplete int `json:"incomplete"`
}

// BatchStats stores the number of TOTP and Webauth records processed in a batch operation
type BatchStats struct {
TOTP CompletionStats `json:"totp"`
Webauthn CompletionStats `json:"webauthn"`
}

// Load refreshes an ApiKey from the database record
func (k *ApiKey) Load() error {
return k.Store.Load(envConfig.ApiKeyTable, ApiKeyTablePK, k.Key, k)
}

// Save an ApiKey to the database
func (k *ApiKey) Save() error {
return k.Store.Store(envConfig.ApiKeyTable, k)
func (k *ApiKey) Save(ctx context.Context) error {
return k.Store.StoreCtx(ctx, envConfig.ApiKeyTable, k)
}

// Hash generates a bcrypt hash from the Secret field and stores it in HashedSecret
Expand Down Expand Up @@ -221,61 +234,60 @@ func (k *ApiKey) Activate() error {

// ReEncryptTOTPs loads each TOTP record that was encrypted using the old key, re-encrypts it using the new
// key, and writes the updated data back to the database.
func (k *ApiKey) ReEncryptTOTPs(storage *Storage, oldKey ApiKey) (complete, incomplete int, err error) {
func (k *ApiKey) ReEncryptTOTPs(ctx context.Context, storage *Storage, oldKey ApiKey) (CompletionStats, error) {
var records []TOTP
err = storage.ScanApiKey(envConfig.TotpTable, oldKey.Key, &records)
err := storage.ScanApiKey(envConfig.TotpTable, oldKey.Key, &records)
if err != nil {
err = fmt.Errorf("failed to query %s table for key %s: %w", envConfig.TotpTable, oldKey.Key, err)
return
return CompletionStats{}, err
}

incomplete = len(records)
stats := CompletionStats{Incomplete: len(records)}
for _, r := range records {
err = k.ReEncryptLegacy(oldKey, &r.EncryptedTotpKey)
if err != nil {
err = fmt.Errorf("failed to re-encrypt TOTP %v: %w", r.UUID, err)
return
return stats, fmt.Errorf("failed to re-encrypt TOTP %v: %w", r.UUID, err)
}

r.ApiKey = k.Key

err = storage.Store(envConfig.TotpTable, &r)
err = storage.StoreCtx(ctx, envConfig.TotpTable, &r)
if err != nil {
err = fmt.Errorf("failed to store TOTP %v: %w", r.UUID, err)
return
return stats, fmt.Errorf("failed to store TOTP %v: %w", r.UUID, err)
}
complete++
incomplete--

stats.Complete++
stats.Incomplete--
}
return
return stats, nil
}

// ReEncryptWebAuthnUsers loads each WebAuthn record that was encrypted using the old key, re-encrypts it using the new
// key, and writes the updated data back to the database.
func (k *ApiKey) ReEncryptWebAuthnUsers(storage *Storage, oldKey ApiKey) (complete, incomplete int, err error) {
func (k *ApiKey) ReEncryptWebAuthnUsers(ctx context.Context, storage *Storage, oldKey ApiKey) (CompletionStats, error) {
var users []WebauthnUser
err = storage.ScanApiKey(envConfig.WebauthnTable, oldKey.Key, &users)
err := storage.ScanApiKey(envConfig.WebauthnTable, oldKey.Key, &users)
if err != nil {
err = fmt.Errorf("failed to query %s table for key %s: %w", envConfig.WebauthnTable, oldKey.Key, err)
return
return CompletionStats{}, err
}

incomplete = len(users)
stats := CompletionStats{Incomplete: len(users)}
for _, user := range users {
user.ApiKey = oldKey
err = k.ReEncryptWebAuthnUser(storage, user)
err = k.ReEncryptWebAuthnUser(ctx, storage, user)
if err != nil {
err = fmt.Errorf("failed to re-encrypt Webauthn %v: %w", user.ID, err)
return
return stats, fmt.Errorf("reencryption failed %v: %w", user.ID, err)
}
complete++
incomplete--

stats.Complete++
stats.Incomplete--
}
return
return stats, nil
}

// ReEncryptWebAuthnUser re-encrypts a WebAuthnUser using the new key, and writes the updated data back to the database.
func (k *ApiKey) ReEncryptWebAuthnUser(storage *Storage, user WebauthnUser) error {
func (k *ApiKey) ReEncryptWebAuthnUser(ctx context.Context, storage *Storage, user WebauthnUser) error {
oldKey := user.ApiKey
err := k.ReEncrypt(oldKey, &user.EncryptedSessionData)
if err != nil {
Expand All @@ -297,7 +309,7 @@ func (k *ApiKey) ReEncryptWebAuthnUser(storage *Storage, user WebauthnUser) erro
user.ApiKey = *k
user.ApiKeyValue = k.Key

err = storage.Store(envConfig.WebauthnTable, &user)
err = storage.StoreCtx(ctx, envConfig.WebauthnTable, &user)
if err != nil {
return err
}
Expand Down Expand Up @@ -348,6 +360,8 @@ func (k *ApiKey) ReEncryptLegacy(oldKey ApiKey, v *string) error {
// ActivateApiKey is the handler for the POST /api-key/activate endpoint. It creates the key secret and updates the
// database record.
func (a *App) ActivateApiKey(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

var requestBody struct {
ApiKeyValue string `json:"apiKeyValue"`
Email string `json:"email"`
Expand Down Expand Up @@ -393,7 +407,7 @@ func (a *App) ActivateApiKey(w http.ResponseWriter, r *http.Request) {
return
}

err = newKey.Save()
err = newKey.Save(ctx)
if err != nil {
log.Printf("failed to save key: %s", err)
jsonResponse(w, internalServerError, http.StatusInternalServerError)
Expand All @@ -412,6 +426,8 @@ func (a *App) ActivateApiKey(w http.ResponseWriter, r *http.Request) {

// CreateApiKey is the handler for the POST /api-key endpoint. It creates a new API Key and saves it to the database.
func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

var requestBody struct {
Email string `json:"email"`
}
Expand All @@ -436,7 +452,7 @@ func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) {
}

key.Store = a.db
err = key.Save()
err = key.Save(ctx)
if err != nil {
log.Printf("failed to save key: %s", err)
jsonResponse(w, internalServerError, http.StatusInternalServerError)
Expand All @@ -456,6 +472,7 @@ func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) {
// any number of times to continue the process. A status of 200 does not indicate that all keys were encrypted using the
// new key. Check the response data to determine if the rotation process is complete.
func (a *App) RotateApiKey(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var requestBody map[string]string
err := json.NewDecoder(r.Body).Decode(&requestBody)
if err != nil {
Expand Down Expand Up @@ -489,25 +506,19 @@ func (a *App) RotateApiKey(w http.ResponseWriter, r *http.Request) {
return
}

totpComplete, totpIncomplete, err := newKey.ReEncryptTOTPs(a.GetDB(), oldKey)
webauthnStats, err := newKey.ReEncryptWebAuthnUsers(ctx, a.GetDB(), oldKey)
if err != nil {
log.Printf("failed to re-encrypt TOTP data: %s", err)
jsonResponse(w, internalServerError, http.StatusInternalServerError)
return
log.Printf("failed to re-encrypt one or more WebAuthn record: %s", err)
}

webauthnComplete, webauthnIncomplete, err := newKey.ReEncryptWebAuthnUsers(a.GetDB(), oldKey)
totpStats, err := newKey.ReEncryptTOTPs(ctx, a.GetDB(), oldKey)
if err != nil {
log.Printf("failed to re-encrypt WebAuthn data: %s", err)
jsonResponse(w, internalServerError, http.StatusInternalServerError)
return
log.Printf("failed to re-encrypt one or more TOTP record: %s", err)
}

responseBody := map[string]int{
"totpComplete": totpComplete,
"totpIncomplete": totpIncomplete,
"webauthnComplete": webauthnComplete,
"webauthnIncomplete": webauthnIncomplete,
responseBody := BatchStats{
TOTP: totpStats,
Webauthn: webauthnStats,
}

jsonResponse(w, responseBody, http.StatusOK)
Expand Down
48 changes: 32 additions & 16 deletions apikey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,11 @@ func (ms *MfaSuite) TestAppRotateApiKey() {
key := user.ApiKey
must(db.Store(config.ApiKeyTable, key))

totp := ms.newTOTP(key)
const numberOfTOTPs = 1000
totpList := make([]TOTP, numberOfTOTPs)
for i := range totpList {
totpList[i] = ms.newTOTP(key)
}

newKey := newTestKey()
must(db.Store(config.ApiKeyTable, newKey))
Expand Down Expand Up @@ -425,7 +429,9 @@ func (ms *MfaSuite) TestAppRotateApiKey() {
request.Header.Set(HeaderAPISecret, tt.key.Secret)

ctxWithUser := context.WithValue(request.Context(), UserContextKey, tt.key)
request = request.WithContext(ctxWithUser)
ctxWithDeadline, cancel := context.WithTimeout(ctxWithUser, time.Second)
defer cancel()
request = request.WithContext(ctxWithDeadline)

res := httptest.NewRecorder()
Router(ms.app).ServeHTTP(res, request)
Expand All @@ -436,14 +442,24 @@ func (ms *MfaSuite) TestAppRotateApiKey() {
return
}

var response map[string]int
var response BatchStats
ms.decodeBody(res.Body.Bytes(), &response)
ms.Equal(1, response["totpComplete"])
ms.Equal(1, response["webauthnComplete"])

totpFromDB := TOTP{UUID: totp.UUID, ApiKey: newKey.Key}
must(db.Load(config.TotpTable, "uuid", totp.UUID, &totpFromDB))
ms.Equal(newKey.Key, totpFromDB.ApiKey)
ms.Greater(response.TOTP.Complete, 0, "none of the TOTPs were re-encrypted")
ms.Less(response.TOTP.Complete, numberOfTOTPs, "test didn't cancel before completion")
ms.Equalf(numberOfTOTPs, response.TOTP.Complete+response.TOTP.Incomplete,
"total of TOTP.Complete (%d) and TOTP.Incomplete (%d) should equal the total number of TOTPs (%d)",
response.TOTP.Complete, response.TOTP.Incomplete, numberOfTOTPs)
ms.Equal(1, response.Webauthn.Complete)

foundOne := false
for i := range totpList {
totpFromDB := TOTP{UUID: totpList[i].UUID, ApiKey: newKey.Key}
must(db.Load(config.TotpTable, "uuid", totpList[i].UUID, &totpFromDB))
if newKey.Key == totpFromDB.ApiKey {
foundOne = true
}
}
ms.True(foundOne, "did not find a TOTP with the new key")

dbUser := WebauthnUser{ID: user.ID, Store: db, ApiKey: newKey}
must(dbUser.Load())
Expand Down Expand Up @@ -522,10 +538,10 @@ func (ms *MfaSuite) TestApiKey_ReEncryptTOTPs() {

_ = ms.newTOTP(oldKey)

complete, incomplete, err := newKey.ReEncryptTOTPs(storage, oldKey)
stats, err := newKey.ReEncryptTOTPs(ms.T().Context(), storage, oldKey)
ms.NoError(err)
ms.Equal(1, complete)
ms.Equal(0, incomplete)
ms.Equal(1, stats.Complete)
ms.Equal(0, stats.Incomplete)
}

func (ms *MfaSuite) TestReEncryptWebAuthnUsers() {
Expand All @@ -540,10 +556,10 @@ func (ms *MfaSuite) TestReEncryptWebAuthnUsers() {
newKey := newTestKey()
must(ms.app.GetDB().Store(ms.app.GetConfig().ApiKeyTable, newKey))

complete, incomplete, err := newKey.ReEncryptWebAuthnUsers(storage, users[0].ApiKey)
stats, err := newKey.ReEncryptWebAuthnUsers(ms.T().Context(), storage, users[0].ApiKey)
ms.NoError(err)
ms.Equal(0, incomplete)
ms.Equal(1, complete)
ms.Equal(0, stats.Incomplete)
ms.Equal(1, stats.Complete)

// verify only users[0] is affected because each test user belongs to a different key
for i, user := range users {
Expand Down Expand Up @@ -585,7 +601,7 @@ func (ms *MfaSuite) TestReEncryptWebAuthnUser() {
must(ms.app.GetDB().Store(ms.app.GetConfig().ApiKeyTable, newKey))
ms.NotEqual(newKey.Secret, tt.user.ApiKey.Secret)

err = newKey.ReEncryptWebAuthnUser(storage, tt.user)
err = newKey.ReEncryptWebAuthnUser(ms.T().Context(), storage, tt.user)
ms.NoError(err)

dbUser := WebauthnUser{ID: tt.user.ID, ApiKey: newKey, Store: storage}
Expand Down
38 changes: 22 additions & 16 deletions openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -519,22 +519,28 @@ paths:
schema:
type: object
properties:
totpComplete:
type: integer
description: the number of TOTP codes that were encrypted with the new key
required: true
totpIncomplete:
type: integer
description: the number of TOTP codes that have not yet been encrypted with the new key
required: true
webauthnComplete:
type: integer
description: the number of Webauthn passkeys that were encrypted with the new key
required: true
webauthnIncomplete:
type: integer
description: the number of Webauthn passkeys that have not yet been encrypted with the new key
required: true
totp:
type: object
properties:
complete:
type: integer
description: the number of TOTP codes that were encrypted with the new key
required: true
incomplete:
type: integer
description: the number of TOTP codes that have not yet been encrypted with the new key
required: true
webauthn:
type: object
properties:
complete:
type: integer
description: the number of Webauthn passkeys that were encrypted with the new key
required: true
incomplete:
type: integer
description: the number of Webauthn passkeys that have not yet been encrypted with the new key
required: true
400:
description: Bad Request
content:
Expand Down
6 changes: 5 additions & 1 deletion storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ func NewStorage(config aws.Config) (*Storage, error) {

// Store puts item at key.
func (s *Storage) Store(table string, item interface{}) error {
return s.StoreCtx(context.Background(), table, item)
}

// StoreCtx puts item at key. This is a context-aware equivalent of Store.
func (s *Storage) StoreCtx(ctx context.Context, table string, item interface{}) error {
av, err := attributevalue.MarshalMap(item)
if err != nil {
return err
Expand All @@ -43,7 +48,6 @@ func (s *Storage) Store(table string, item interface{}) error {
TableName: aws.String(table),
}

ctx := context.Background()
_, err = s.client.PutItem(ctx, input)
return err
}
Expand Down