Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 77 additions & 71 deletions apikey.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mfa

import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
Expand All @@ -25,8 +26,6 @@ const ApiKeyTablePK = "value"
const (
paramNewKeyId = "newKeyId"
paramNewKeySecret = "newKeySecret"
paramOldKeyId = "oldKeyId"
paramOldKeySecret = "oldKeySecret"
)

const (
Expand All @@ -48,14 +47,26 @@ type ApiKey struct {
Store *Storage `dynamodbav:"-" json:"-"`
}

// CompletionStats stores the number of records processed in a batch operation
type CompletionStats struct {
Complete int `json:"complete"`
Incomplete int `json:"incomplete"`
}

// BatchStats stores the number of TOTP and Webauth records processed in a batch operation
type BatchStats struct {
TOTP CompletionStats `json:"totp"`
Webauthn CompletionStats `json:"webauthn"`
}

// Load refreshes an ApiKey from the database record
func (k *ApiKey) Load() error {
return k.Store.Load(envConfig.ApiKeyTable, ApiKeyTablePK, k.Key, k)
}

// Save an ApiKey to the database
func (k *ApiKey) Save() error {
return k.Store.Store(envConfig.ApiKeyTable, k)
func (k *ApiKey) Save(ctx context.Context) error {
return k.Store.StoreCtx(ctx, envConfig.ApiKeyTable, k)
}

// Hash generates a bcrypt hash from the Secret field and stores it in HashedSecret
Expand Down Expand Up @@ -223,61 +234,60 @@ func (k *ApiKey) Activate() error {

// ReEncryptTOTPs loads each TOTP record that was encrypted using the old key, re-encrypts it using the new
// key, and writes the updated data back to the database.
func (k *ApiKey) ReEncryptTOTPs(storage *Storage, oldKey ApiKey) (complete, incomplete int, err error) {
func (k *ApiKey) ReEncryptTOTPs(ctx context.Context, storage *Storage, oldKey ApiKey) (CompletionStats, error) {
var records []TOTP
err = storage.ScanApiKey(envConfig.TotpTable, oldKey.Key, &records)
err := storage.ScanApiKey(envConfig.TotpTable, oldKey.Key, &records)
if err != nil {
err = fmt.Errorf("failed to query %s table for key %s: %w", envConfig.TotpTable, oldKey.Key, err)
return
return CompletionStats{}, err
}

incomplete = len(records)
stats := CompletionStats{Incomplete: len(records)}
for _, r := range records {
err = k.ReEncryptLegacy(oldKey, &r.EncryptedTotpKey)
if err != nil {
err = fmt.Errorf("failed to re-encrypt TOTP %v: %w", r.UUID, err)
return
return stats, fmt.Errorf("failed to re-encrypt TOTP %v: %w", r.UUID, err)
}

r.ApiKey = k.Key

err = storage.Store(envConfig.TotpTable, &r)
err = storage.StoreCtx(ctx, envConfig.TotpTable, &r)
if err != nil {
err = fmt.Errorf("failed to store TOTP %v: %w", r.UUID, err)
return
return stats, fmt.Errorf("failed to store TOTP %v: %w", r.UUID, err)
}
complete++
incomplete--

stats.Complete++
stats.Incomplete--
}
return
return stats, nil
}

// ReEncryptWebAuthnUsers loads each WebAuthn record that was encrypted using the old key, re-encrypts it using the new
// key, and writes the updated data back to the database.
func (k *ApiKey) ReEncryptWebAuthnUsers(storage *Storage, oldKey ApiKey) (complete, incomplete int, err error) {
func (k *ApiKey) ReEncryptWebAuthnUsers(ctx context.Context, storage *Storage, oldKey ApiKey) (CompletionStats, error) {
var users []WebauthnUser
err = storage.ScanApiKey(envConfig.WebauthnTable, oldKey.Key, &users)
err := storage.ScanApiKey(envConfig.WebauthnTable, oldKey.Key, &users)
if err != nil {
err = fmt.Errorf("failed to query %s table for key %s: %w", envConfig.WebauthnTable, oldKey.Key, err)
return
return CompletionStats{}, err
}

incomplete = len(users)
stats := CompletionStats{Incomplete: len(users)}
for _, user := range users {
user.ApiKey = oldKey
err = k.ReEncryptWebAuthnUser(storage, user)
err = k.ReEncryptWebAuthnUser(ctx, storage, user)
if err != nil {
err = fmt.Errorf("failed to re-encrypt Webauthn %v: %w", user.ID, err)
return
return stats, fmt.Errorf("reencryption failed %v: %w", user.ID, err)
}
complete++
incomplete--

stats.Complete++
stats.Incomplete--
}
return
return stats, nil
}

// ReEncryptWebAuthnUser re-encrypts a WebAuthnUser using the new key, and writes the updated data back to the database.
func (k *ApiKey) ReEncryptWebAuthnUser(storage *Storage, user WebauthnUser) error {
func (k *ApiKey) ReEncryptWebAuthnUser(ctx context.Context, storage *Storage, user WebauthnUser) error {
oldKey := user.ApiKey
err := k.ReEncrypt(oldKey, &user.EncryptedSessionData)
if err != nil {
Expand All @@ -299,7 +309,7 @@ func (k *ApiKey) ReEncryptWebAuthnUser(storage *Storage, user WebauthnUser) erro
user.ApiKey = *k
user.ApiKeyValue = k.Key

err = storage.Store(envConfig.WebauthnTable, &user)
err = storage.StoreCtx(ctx, envConfig.WebauthnTable, &user)
if err != nil {
return err
}
Expand Down Expand Up @@ -350,6 +360,8 @@ func (k *ApiKey) ReEncryptLegacy(oldKey ApiKey, v *string) error {
// ActivateApiKey is the handler for the POST /api-key/activate endpoint. It creates the key secret and updates the
// database record.
func (a *App) ActivateApiKey(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

var requestBody struct {
ApiKeyValue string `json:"apiKeyValue"`
Email string `json:"email"`
Expand Down Expand Up @@ -395,18 +407,27 @@ func (a *App) ActivateApiKey(w http.ResponseWriter, r *http.Request) {
return
}

err = newKey.Save()
err = newKey.Save(ctx)
if err != nil {
log.Printf("failed to save key: %s", err)
jsonResponse(w, internalServerError, http.StatusInternalServerError)
return
}

jsonResponse(w, map[string]string{"apiSecret": newKey.Secret}, http.StatusOK)
response := map[string]string{
"email": newKey.Email,
"apiKeyValue": newKey.Key,
"apiSecret": newKey.Secret,
"activatedAt": time.Unix(int64(newKey.ActivatedAt)/1000, 0).UTC().Format(time.RFC3339),
"createdAt": time.Unix(int64(newKey.CreatedAt)/1000, 0).UTC().Format(time.RFC3339),
}
jsonResponse(w, response, http.StatusOK)
}

// CreateApiKey is the handler for the POST /api-key endpoint. It creates a new API Key and saves it to the database.
func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

var requestBody struct {
Email string `json:"email"`
}
Expand All @@ -431,7 +452,7 @@ func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) {
}

key.Store = a.db
err = key.Save()
err = key.Save(ctx)
if err != nil {
log.Printf("failed to save key: %s", err)
jsonResponse(w, internalServerError, http.StatusInternalServerError)
Expand All @@ -451,71 +472,56 @@ func (a *App) CreateApiKey(w http.ResponseWriter, r *http.Request) {
// any number of times to continue the process. A status of 200 does not indicate that all keys were encrypted using the
// new key. Check the response data to determine if the rotation process is complete.
func (a *App) RotateApiKey(w http.ResponseWriter, r *http.Request) {
requestBody, err := parseRotateKeyRequestBody(r.Body)
ctx := r.Context()
var requestBody map[string]string
err := json.NewDecoder(r.Body).Decode(&requestBody)
if err != nil {
if strings.HasSuffix(err.Error(), "is required") {
jsonResponse(w, err, http.StatusBadRequest)
} else {
log.Printf("invalid request in RotateApiKey: %s", err)
jsonResponse(w, invalidRequest, http.StatusBadRequest)
}
log.Printf("invalid request in ActivateApiKey: %s", err)
jsonResponse(w, invalidRequest, http.StatusBadRequest)
return
}

oldKey := ApiKey{Key: requestBody[paramOldKeyId], Store: a.GetDB()}
err = oldKey.loadAndCheck(requestBody[paramOldKeySecret])
if err != nil {
log.Printf("old key is not valid: %s", err)
jsonResponse(w, apiKeyNotFound, http.StatusNotFound)
if requestBody[paramNewKeyId] == "" {
jsonResponse(w, paramNewKeyId+" is required", http.StatusBadRequest)
return
}

newKey := ApiKey{Key: requestBody[paramNewKeyId], Store: a.GetDB()}
err = newKey.loadAndCheck(requestBody[paramNewKeySecret])
if err != nil {
log.Printf("new key is not valid: %s", err)
jsonResponse(w, apiKeyNotFound, http.StatusNotFound)
if requestBody[paramNewKeySecret] == "" {
jsonResponse(w, paramNewKeySecret+" is required", http.StatusBadRequest)
return
}

totpComplete, totpIncomplete, err := newKey.ReEncryptTOTPs(a.GetDB(), oldKey)
oldKey, err := getAPIKey(r)
if err != nil {
log.Printf("failed to re-encrypt TOTP data: %s", err)
log.Printf("Rotate API key error: %v", err)
jsonResponse(w, internalServerError, http.StatusInternalServerError)
return
}

webauthnComplete, webauthnIncomplete, err := newKey.ReEncryptWebAuthnUsers(a.GetDB(), oldKey)
newKey := ApiKey{Key: requestBody[paramNewKeyId], Store: a.GetDB()}
err = newKey.loadAndCheck(requestBody[paramNewKeySecret])
if err != nil {
log.Printf("failed to re-encrypt WebAuthn data: %s", err)
jsonResponse(w, internalServerError, http.StatusInternalServerError)
log.Printf("new key is not valid: %s", err)
jsonResponse(w, apiKeyNotFound, http.StatusNotFound)
return
}

responseBody := map[string]int{
"totpComplete": totpComplete,
"totpIncomplete": totpIncomplete,
"webauthnComplete": webauthnComplete,
"webauthnIncomplete": webauthnIncomplete,
webauthnStats, err := newKey.ReEncryptWebAuthnUsers(ctx, a.GetDB(), oldKey)
if err != nil {
log.Printf("failed to re-encrypt one or more WebAuthn record: %s", err)
}

jsonResponse(w, responseBody, http.StatusOK)
}

func parseRotateKeyRequestBody(body io.Reader) (map[string]string, error) {
var requestBody map[string]string
err := json.NewDecoder(body).Decode(&requestBody)
totpStats, err := newKey.ReEncryptTOTPs(ctx, a.GetDB(), oldKey)
if err != nil {
return nil, fmt.Errorf("invalid request in RotateApiKey: %w", err)
log.Printf("failed to re-encrypt one or more TOTP record: %s", err)
}

fields := []string{paramNewKeyId, paramNewKeySecret, paramOldKeyId, paramOldKeySecret}
for _, field := range fields {
if _, ok := requestBody[field]; !ok {
return nil, fmt.Errorf("%s is required", field)
}
responseBody := BatchStats{
TOTP: totpStats,
Webauthn: webauthnStats,
}
return requestBody, nil

jsonResponse(w, responseBody, http.StatusOK)
}

func (k *ApiKey) loadAndCheck(secret string) error {
Expand Down
Loading