Skip to content

Commit

Permalink
Implement EncryptedData struct (#3302)
Browse files Browse the repository at this point in the history
Relates to #3292

Adds a new struct for representing an encrypted piece of data. As per
@jhrozek's design doc, this struct will be serialized and stored in the
database. This will be implemented in a later PR. For now, the Engine
interface has been refactored to accept and return instances of these
structs.
  • Loading branch information
dmjb committed May 13, 2024
1 parent b575d92 commit a231e84
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 94 deletions.
14 changes: 10 additions & 4 deletions internal/controlplane/handlers_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (s *Server) GetAuthorizationURL(ctx context.Context,
if err != nil {
return nil, status.Errorf(codes.Internal, "error encrypting redirect URL: %s", err)
}
redirectUrl = sql.NullString{Valid: true, String: encryptedRedirectUrl}
redirectUrl = sql.NullString{Valid: true, String: encryptedRedirectUrl.EncodedData}
}

// Insert the new session state into the database along with the user's project ID
Expand Down Expand Up @@ -246,7 +246,10 @@ func (s *Server) processOAuthCallback(ctx context.Context, w http.ResponseWriter
logger.BusinessRecord(ctx).ProviderID = p.ID

if stateData.RedirectUrl.Valid {
redirectUrl, err := s.cryptoEngine.DecryptString(stateData.RedirectUrl.String)
// TODO: get rid of this once we store the EncryptedData struct in
// the database.
encryptedData := mcrypto.NewBackwardsCompatibleEncryptedData(stateData.RedirectUrl.String)
redirectUrl, err := s.cryptoEngine.DecryptString(encryptedData)
if err != nil {
return fmt.Errorf("error decrypting redirect URL: %w", err)
}
Expand Down Expand Up @@ -325,7 +328,10 @@ func (s *Server) processAppCallback(ctx context.Context, w http.ResponseWriter,

// If we have a redirect URL, redirect the user, otherwise show a success page
if stateData.RedirectUrl.Valid {
redirectUrl, err := s.cryptoEngine.DecryptString(stateData.RedirectUrl.String)
// TODO: get rid of this once we store the EncryptedData struct in
// the database.
encryptedData := mcrypto.NewBackwardsCompatibleEncryptedData(stateData.RedirectUrl.String)
redirectUrl, err := s.cryptoEngine.DecryptString(encryptedData)
if err != nil {
return fmt.Errorf("error decrypting redirect URL: %w", err)
}
Expand Down Expand Up @@ -473,7 +479,7 @@ func (s *Server) StoreProviderToken(ctx context.Context,
_, err = s.store.UpsertAccessToken(ctx, db.UpsertAccessTokenParams{
ProjectID: projectID,
Provider: provider.Name,
EncryptedToken: encryptedToken,
EncryptedToken: encryptedToken.EncodedData,
OwnerFilter: owner,
})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/controlplane/handlers_oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ func TestProviderCallback(t *testing.T) {
}
encryptedUrl := sql.NullString{
Valid: true,
String: encryptedUrlString,
String: encryptedUrlString.EncodedData,
}

tx := sql.Tx{}
Expand Down
41 changes: 13 additions & 28 deletions internal/crypto/algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"crypto/cipher"
"crypto/rand"
"errors"
"fmt"
"io"

"golang.org/x/crypto/argon2"
Expand All @@ -29,36 +28,22 @@ import (

// EncryptionAlgorithm represents a crypto algorithm used by the Engine
type EncryptionAlgorithm interface {
Encrypt(data []byte) ([]byte, error)
Decrypt(data []byte) ([]byte, error)
Encrypt(data []byte, salt []byte) ([]byte, error)
Decrypt(data []byte, salt []byte) ([]byte, error)
}

const maxSize = 32 * 1024 * 1024

// In a real application, you should use a unique salt for
// each key and save it with the encrypted data.
var (
salt = []byte("somesalt")
errUnknownAlgorithm = errors.New("unexpected encryption algorithm")
)

// EncryptionAlgorithmType is an enum of supported encryption algorithms
type EncryptionAlgorithmType string

const (
// AESCFB is the AES-CFB algorithm
AESCFB EncryptionAlgorithmType = "aes-cfb"
// Aes256Cfb is the AES-256-CFB algorithm
Aes256Cfb EncryptionAlgorithmType = "aes-256-cfb"
)

// AlgorithmTypeFromString converts a string to an EncryptionAlgorithmType
// or returns errUnknownAlgorithm.
func AlgorithmTypeFromString(input string) (EncryptionAlgorithmType, error) {
// for backwards compatibility - default to AES-CFB if string is empty
if input == "" || input == string(AESCFB) {
return AESCFB, nil
}
return "", fmt.Errorf("%w: %s", errUnknownAlgorithm, input)
}
const maxSize = 32 * 1024 * 1024

// ErrUnknownAlgorithm is used when an incorrect algorithm name is used.
var ErrUnknownAlgorithm = errors.New("unexpected encryption algorithm")

func newAlgorithm(key []byte) EncryptionAlgorithm {
// TODO: Make the type of algorithm selectable
Expand All @@ -70,11 +55,11 @@ type aesCFBSAlgorithm struct {
}

// Encrypt encrypts a row of data.
func (a *aesCFBSAlgorithm) Encrypt(data []byte) ([]byte, error) {
func (a *aesCFBSAlgorithm) Encrypt(data []byte, salt []byte) ([]byte, error) {
if len(data) > maxSize {
return nil, status.Errorf(codes.InvalidArgument, "data is too large (>32MB)")
}
block, err := aes.NewCipher(a.deriveKey())
block, err := aes.NewCipher(a.deriveKey(salt))
if err != nil {
return nil, status.Errorf(codes.Unknown, "failed to create cipher: %s", err)
}
Expand All @@ -93,8 +78,8 @@ func (a *aesCFBSAlgorithm) Encrypt(data []byte) ([]byte, error) {
}

// Decrypt decrypts a row of data.
func (a *aesCFBSAlgorithm) Decrypt(ciphertext []byte) ([]byte, error) {
block, err := aes.NewCipher(a.deriveKey())
func (a *aesCFBSAlgorithm) Decrypt(ciphertext []byte, salt []byte) ([]byte, error) {
block, err := aes.NewCipher(a.deriveKey(salt))
if err != nil {
return nil, status.Errorf(codes.Unknown, "failed to create cipher: %s", err)
}
Expand All @@ -110,6 +95,6 @@ func (a *aesCFBSAlgorithm) Decrypt(ciphertext []byte) ([]byte, error) {
}

// Function to derive a key from a passphrase using Argon2
func (a *aesCFBSAlgorithm) deriveKey() []byte {
func (a *aesCFBSAlgorithm) deriveKey(salt []byte) []byte {
return argon2.IDKey(a.encryptionKey, salt, 1, 64*1024, 4, 32)
}
105 changes: 64 additions & 41 deletions internal/crypto/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,25 @@ import (

// Engine provides all functions to encrypt and decrypt data
type Engine interface {
EncryptOAuthToken(token *oauth2.Token) (string, error)
DecryptOAuthToken(encToken string) (oauth2.Token, error)
EncryptString(data string) (string, error)
DecryptString(encData string) (string, error)
// EncryptOAuthToken takes an OAuth2 token, serializes to JSON and encrypts it.
EncryptOAuthToken(token *oauth2.Token) (EncryptedData, error)
// DecryptOAuthToken takes an OAuth2 token encrypted using EncryptOAuthToken and decrypts it.
DecryptOAuthToken(encryptedToken EncryptedData) (oauth2.Token, error)
// EncryptString encrypts a string.
EncryptString(data string) (EncryptedData, error)
// DecryptString decrypts a string encrypted with EncryptString.
DecryptString(encryptedString EncryptedData) (string, error)
}

var (
// TODO: get rid of this when we allow per-secret salting.
legacySalt = []byte("somesalt")
// ErrDecrypt is returned when we cannot decrypt a secret.
ErrDecrypt = errors.New("unable to decrypt")
// ErrEncrypt is returned when we cannot encrypt a secret.
ErrEncrypt = errors.New("unable to encrypt")
)

type engine struct {
algorithm EncryptionAlgorithm
}
Expand All @@ -60,69 +73,79 @@ func NewEngine(key []byte) Engine {
return &engine{algorithm: newAlgorithm(key)}
}

// EncryptOAuthToken encrypts an oauth token
func (e *engine) EncryptOAuthToken(token *oauth2.Token) (string, error) {
// Convert token to JSON
func (e *engine) EncryptOAuthToken(token *oauth2.Token) (EncryptedData, error) {
// Convert token to JSON.
jsonData, err := json.Marshal(token)
if err != nil {
return "", fmt.Errorf("unable to marshal token to json: %w", err)
return EncryptedData{}, fmt.Errorf("unable to marshal token to json: %w", err)
}
encrypted, err := e.algorithm.Encrypt(jsonData)

// Encrypt the JSON.
encrypted, err := e.encrypt(jsonData)
if err != nil {
return "", fmt.Errorf("unable to encrypt token: %w", err)
return EncryptedData{}, fmt.Errorf("unable to encrypt token: %w", err)
}
return base64.StdEncoding.EncodeToString(encrypted), nil
return encrypted, nil
}

// DecryptOAuthToken decrypts an encrypted oauth token
func (e *engine) DecryptOAuthToken(encToken string) (oauth2.Token, error) {
var decryptedToken oauth2.Token
func (e *engine) DecryptOAuthToken(encryptedToken EncryptedData) (result oauth2.Token, err error) {
// Decrypt the token.
token, err := e.decrypt(encryptedToken)
if err != nil {
return result, err
}

// base64 decode the token
decodeToken, err := base64.StdEncoding.DecodeString(encToken)
// Deserialize to token struct.
err = json.Unmarshal(token, &result)
if err != nil {
return decryptedToken, err
return result, err
}
return result, nil
}

// decrypt the token
token, err := e.algorithm.Decrypt(decodeToken)
func (e *engine) EncryptString(data string) (EncryptedData, error) {
encrypted, err := e.encrypt([]byte(data))
if err != nil {
return decryptedToken, err
return EncryptedData{}, err
}
return encrypted, nil
}

// serialise token *oauth.Token
err = json.Unmarshal(token, &decryptedToken)
func (e *engine) DecryptString(encryptedString EncryptedData) (string, error) {
decrypted, err := e.decrypt(encryptedString)
if err != nil {
return decryptedToken, err
return "", fmt.Errorf("%w: %w", ErrDecrypt, err)
}
return decryptedToken, nil
return string(decrypted), nil
}

// EncryptString encrypts a string
func (e *engine) EncryptString(data string) (string, error) {
encrypted, err := e.algorithm.Encrypt([]byte(data))
func (e *engine) encrypt(data []byte) (EncryptedData, error) {
encrypted, err := e.algorithm.Encrypt(data, legacySalt)
if err != nil {
return "", err
return EncryptedData{}, err
}

return base64.StdEncoding.EncodeToString(encrypted), nil
encoded := base64.StdEncoding.EncodeToString(encrypted)
// TODO:
// 1. when we support more than one algorithm, remove hard-coding.
// 2. Allow salt to be randomly generated per secret.
// 3. Set key version.
return NewBackwardsCompatibleEncryptedData(encoded), nil
}

// DecryptString decrypts an encrypted string
func (e *engine) DecryptString(encData string) (string, error) {
var decrypted string

// base64 decode the string
decodeToken, err := base64.StdEncoding.DecodeString(encData)
if err != nil {
return decrypted, err
func (e *engine) decrypt(data EncryptedData) ([]byte, error) {
// TODO: Select algorithm based on Algorithm field when we support
// more than one algorithm.
if data.Algorithm != Aes256Cfb {
return nil, fmt.Errorf("%w: %s", ErrUnknownAlgorithm, data.Algorithm)
}

// decrypt the string
token, err := e.algorithm.Decrypt(decodeToken)
// base64 decode the string
encrypted, err := base64.StdEncoding.DecodeString(data.EncodedData)
if err != nil {
return decrypted, err
return nil, err
}

return string(token), nil
// decrypt the data
return e.algorithm.Decrypt(encrypted, data.Salt)
}
25 changes: 13 additions & 12 deletions internal/crypto/mock/engine.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit a231e84

Please sign in to comment.