diff --git a/wolfcrypt/src/wc_mlkem.c b/wolfcrypt/src/wc_mlkem.c index ccdd671256..7a3b04338a 100644 --- a/wolfcrypt/src/wc_mlkem.c +++ b/wolfcrypt/src/wc_mlkem.c @@ -28,30 +28,30 @@ * post-quantum-cryptography-standardization/round-3-submissions */ -/* Possible Kyber options: +/* Possible ML-KEM options: * - * WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM Default: OFF + * WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM Default: OFF * Uses less dynamic memory to perform key generation. * Has a small performance trade-off. * Only usable with C implementation. * - * WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM Default: OFF + * WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM Default: OFF * Uses less dynamic memory to perform encapsulation. * Affects decapsulation too as encapsulation called. * Has a small performance trade-off. * Only usable with C implementation. * - * WOLFSSL_MLKEM_NO_MAKE_KEY Default: OFF + * WOLFSSL_MLKEM_NO_MAKE_KEY Default: OFF * Disable the make key or key generation API. * Reduces the code size. * Turn on when only doing encapsulation. * - * WOLFSSL_MLKEM_NO_ENCAPSULATE Default: OFF + * WOLFSSL_MLKEM_NO_ENCAPSULATE Default: OFF * Disable the encapsulation API. * Reduces the code size. * Turn on when doing make key/decapsulation. * - * WOLFSSL_MLKEM_NO_DECAPSULATE Default: OFF + * WOLFSSL_MLKEM_NO_DECAPSULATE Default: OFF * Disable the decapsulation API. * Reduces the code size. * Turn on when only doing encapsulation. @@ -59,7 +59,7 @@ * WOLFSSL_MLKEM_CACHE_A Default: OFF * Stores the matrix A during key generation for use in encapsulation when * performing decapsulation. - * KyberKey is 8KB larger but decapsulation is significantly faster. + * MlKemKey is 8KB larger but decapsulation is significantly faster. * Turn on when performing make key and decapsulation with same object. * * WOLFSSL_MLKEM_DYNAMIC_KEYS Default: OFF @@ -282,6 +282,8 @@ static int mlkemkey_alloc_pub(MlKemKey* key, unsigned int k) */ static int mlkemkey_alloc_a(MlKemKey* key, unsigned int k) { + int ret = 0; + if (key->a != NULL) { XFREE(key->a, key->heap, DYNAMIC_TYPE_TMP_BUFFER); key->a = NULL; @@ -289,9 +291,10 @@ static int mlkemkey_alloc_a(MlKemKey* key, unsigned int k) key->a = (sword16*)XMALLOC(k * k * MLKEM_N * sizeof(sword16), key->heap, DYNAMIC_TYPE_TMP_BUFFER); if (key->a == NULL) { - return MEMORY_E; + ret = MEMORY_E; } - return 0; + + return ret; } #endif /* WOLFSSL_MLKEM_CACHE_A */ #endif /* WOLFSSL_MLKEM_DYNAMIC_KEYS */ @@ -304,19 +307,20 @@ static int mlkemkey_alloc_a(MlKemKey* key, unsigned int k) * * Allocates and initializes a ML-KEM key object. * - * @param [in] type Type of key: - * WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024, - * KYBER512, KYBER768, KYBER1024. - * @param [in] heap Dynamic memory hint. - * @param [in] devId Device Id. - * @return Pointer to new MlKemKey object, or NULL on failure. + * @param [in] type Type of key: + * WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024, + * KYBER512, KYBER768, KYBER1024. + * @param [in] heap Dynamic memory hint. + * @param [in] devId Device Id. + * @return Pointer to new MlKemKey object on success. + * @return NULL on failure. */ - MlKemKey* wc_MlKemKey_New(int type, void* heap, int devId) { int ret; - MlKemKey* key = (MlKemKey*)XMALLOC(sizeof(MlKemKey), heap, - DYNAMIC_TYPE_TMP_BUFFER); + MlKemKey* key; + + key = (MlKemKey*)XMALLOC(sizeof(MlKemKey), heap, DYNAMIC_TYPE_TMP_BUFFER); if (key != NULL) { ret = wc_MlKemKey_Init(key, type, heap, devId); if (ret != 0) { @@ -333,31 +337,36 @@ MlKemKey* wc_MlKemKey_New(int type, void* heap, int devId) * * Frees resources associated with a ML-KEM key object and sets pointer to NULL. * - * @param [in] key ML-KEM key object to delete. - * @param [in, out] key_p Pointer to key pointer to set to NULL. + * @param [in] key ML-KEM key object to delete. + * @param [in, out] key_p Pointer to key pointer to set to NULL. * @return 0 on success. * @return BAD_FUNC_ARG when key is NULL. */ - int wc_MlKemKey_Delete(MlKemKey* key, MlKemKey** key_p) { - void* heap; - if (key == NULL) - return BAD_FUNC_ARG; - heap = key->heap; - wc_MlKemKey_Free(key); - XFREE(key, heap, DYNAMIC_TYPE_TMP_BUFFER); - if (key_p != NULL) - *key_p = NULL; + int ret = 0; - return 0; + if (key == NULL) { + ret = BAD_FUNC_ARG; + } + else { + void* heap = key->heap; + + wc_MlKemKey_Free(key); + XFREE(key, heap, DYNAMIC_TYPE_TMP_BUFFER); + if (key_p != NULL) { + *key_p = NULL; + } + } + + return ret; } #endif /* !WC_NO_CONSTRUCTORS */ /** - * Initialize the Kyber key. + * Initialize the ML-KEM key. * - * @param [out] key Kyber key object to initialize. + * @param [out] key ML-KEM key object to initialize. * @param [in] type Type of key: * WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024, * KYBER512, KYBER768, KYBER1024. @@ -381,19 +390,19 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId) #ifndef WOLFSSL_NO_ML_KEM case WC_ML_KEM_512: #ifndef WOLFSSL_WC_ML_KEM_512 - /* Code not compiled in for Kyber-512. */ + /* Code not compiled in for ML-KEM-512. */ ret = NOT_COMPILED_IN; #endif break; case WC_ML_KEM_768: #ifndef WOLFSSL_WC_ML_KEM_768 - /* Code not compiled in for Kyber-768. */ + /* Code not compiled in for ML-KEM-768. */ ret = NOT_COMPILED_IN; #endif break; case WC_ML_KEM_1024: #ifndef WOLFSSL_WC_ML_KEM_1024 - /* Code not compiled in for Kyber-1024. */ + /* Code not compiled in for ML-KEM-1024. */ ret = NOT_COMPILED_IN; #endif break; @@ -468,22 +477,42 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId) } #ifdef WOLF_PRIVATE_KEY_ID +/** + * Initialize the ML-KEM key with an id. + * + * @param [out] key ML-KEM key object to initialize. + * @param [in] type Type of key: + * WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024, + * KYBER512, KYBER768, KYBER1024. + * @param [in] id Identifier of key. + * @param [in] len Length of key identifier in bytes. + * @param [in] heap Dynamic memory hint. + * @param [in] devId Device Id. + * @return 0 on success. + * @return BAD_FUNC_ARG when key is NULL, id is NULL but len is not zero, or + * type is unrecognized. + * @return BUFFER_E when len is out of range. + * @return NOT_COMPILED_IN when key type is not supported. + */ int wc_MlKemKey_Init_Id(MlKemKey* key, int type, const unsigned char* id, int len, void* heap, int devId) { int ret = 0; - if (key == NULL || (id == NULL && len != 0)) { + /* Validate parameters. */ + if ((key == NULL) || (id == NULL && len != 0)) { ret = BAD_FUNC_ARG; } - if (ret == 0 && (len < 0 || len > MLKEM_MAX_ID_LEN)) { + if ((ret == 0) && ((len < 0) || (len > MLKEM_MAX_ID_LEN))) { ret = BUFFER_E; } if (ret == 0) { + /* Initialize key. */ ret = wc_MlKemKey_Init(key, type, heap, devId); } - if (ret == 0 && id != NULL && len != 0) { + if ((ret == 0) && (id != NULL) && (len != 0)) { + /* Store key identifier. */ XMEMCPY(key->id, id, (size_t)len); key->idLen = len; } @@ -491,16 +520,33 @@ int wc_MlKemKey_Init_Id(MlKemKey* key, int type, const unsigned char* id, return ret; } +/** + * Initialize the ML-KEM key with a label. + * + * @param [out] key ML-KEM key object to initialize. + * @param [in] type Type of key: + * WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024, + * KYBER512, KYBER768, KYBER1024. + * @param [in] label Label of key. Must be a null-terminated string. + * @param [in] heap Dynamic memory hint. + * @param [in] devId Device Id. + * @return 0 on success. + * @return BAD_FUNC_ARG when key or label is NULL, or type is unrecognized. + * @return BUFFER_E when label is too small or big. + * @return NOT_COMPILED_IN when key type is not supported. + */ int wc_MlKemKey_Init_Label(MlKemKey* key, int type, const char* label, void* heap, int devId) { int ret = 0; int labelLen = 0; - if (key == NULL || label == NULL) { + /* Validate parameters. */ + if ((key == NULL) || (label == NULL)) { ret = BAD_FUNC_ARG; } if (ret == 0) { + /* Validate label length. */ labelLen = (int)XSTRLEN(label); if ((labelLen == 0) || (labelLen > MLKEM_MAX_LABEL_LEN)) { ret = BUFFER_E; @@ -508,10 +554,11 @@ int wc_MlKemKey_Init_Label(MlKemKey* key, int type, const char* label, } if (ret == 0) { + /* Initialize key. */ ret = wc_MlKemKey_Init(key, type, heap, devId); } if (ret == 0) { - /* The string in key->label is not necessarily null-terminated. + /* Don't save string in key->label with null terminator. * Use key->labelLen to get the length if required. */ XMEMCPY(key->label, label, (size_t)labelLen); key->labelLen = labelLen; @@ -522,9 +569,9 @@ int wc_MlKemKey_Init_Label(MlKemKey* key, int type, const char* label, #endif /** - * Free the Kyber key object. + * Free the ML-KEM key object. * - * @param [in, out] key Kyber key object to dispose of. + * @param [in, out] key ML-KEM key object to dispose of. * @return 0 on success. */ int wc_MlKemKey_Free(MlKemKey* key) @@ -533,9 +580,7 @@ int wc_MlKemKey_Free(MlKemKey* key) #if defined(WOLF_CRYPTO_CB) && defined(WOLF_CRYPTO_CB_FREE) if (key->devId != INVALID_DEVID) { (void)wc_CryptoCb_Free(key->devId, WC_ALGO_TYPE_PK, - WC_PK_TYPE_PQC_KEM_KEYGEN, - WC_PQC_KEM_TYPE_KYBER, - (void*)key); + WC_PK_TYPE_PQC_KEM_KEYGEN, WC_PQC_KEM_TYPE_KYBER, (void*)key); /* always continue to software cleanup */ } #endif @@ -567,6 +612,9 @@ int wc_MlKemKey_Free(MlKemKey* key) ForceZero(key->priv, sizeof(key->priv)); #endif ForceZero(key->z, sizeof(key->z)); + + /* Clear flags as values are no longer set. */ + key->flags = 0; } return 0; @@ -576,7 +624,7 @@ int wc_MlKemKey_Free(MlKemKey* key) #ifndef WOLFSSL_MLKEM_NO_MAKE_KEY /** - * Make a Kyber key object using a random number generator. + * Make a ML-KEM key object using a random number generator. * * FIPS 203 - Algorithm 19: ML-KEM.KeyGen() * Generates an encapsulation key and a corresponding decapsulation key. @@ -590,13 +638,17 @@ int wc_MlKemKey_Free(MlKemKey* key) * > run internal key generation algorithm * 7: return (ek,dk) * - * @param [in, out] key Kyber key object. + * @param [in, out] key ML-KEM key object. * @param [in] rng Random number generator. * @return 0 on success. * @return BAD_FUNC_ARG when key or rng is NULL. * @return MEMORY_E when dynamic memory allocation failed. * @return RNG_FAILURE_E when generating random numbers failed. * @return DRBG_CONT_FAILURE when random number generator health check fails. + * @return ML_KEM_PCT_E when pairwise consistency test fails. FIPS only. + * @return BAD_COND_E when fault attack detected. + * @return NOT_COMPILED_IN when no random number generator is compiled in or + * key type is not supported. */ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng) { @@ -615,8 +667,8 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng) #else if (ret == 0) { #endif - ret = wc_CryptoCb_MakePqcKemKey(rng, WC_PQC_KEM_TYPE_KYBER, - key->type, key); + ret = wc_CryptoCb_MakePqcKemKey(rng, WC_PQC_KEM_TYPE_KYBER, key->type, + key); if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE)) return ret; /* fall-through when unavailable */ @@ -637,7 +689,7 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng) * Step 6. run internal key generation algorithm * Step 7. public and private key are stored in key */ - ret = wc_KyberKey_MakeKeyWithRandom(key, rand, sizeof(rand)); + ret = wc_MlKemKey_MakeKeyWithRandom(key, rand, sizeof(rand)); } #ifdef HAVE_FIPS @@ -697,7 +749,7 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng) } /** - * Make a Kyber key object using random data. + * Make a ML-KEM key object using random data. * * FIPS 203 - Algorithm 16: ML-KEM.KeyGen_internal(d,z) * Uses randomness to generate an encapsulation key and a corresponding @@ -717,7 +769,7 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng) * 16-18: calculate t_hat from A_hat, s and e * ... * - * @param [in, out] key Kyber key object. + * @param [in, out] key ML-KEM key object. * @param [in] rand Random data. * @param [in] len Length of random data in bytes. * @return 0 on success. @@ -725,6 +777,7 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng) * @return BUFFER_E when length is not WC_ML_KEM_MAKEKEY_RAND_SZ. * @return NOT_COMPILED_IN when key type is not supported. * @return MEMORY_E when dynamic memory allocation failed. + * @return BAD_COND_E when fault attack detected. */ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand, int len) @@ -846,11 +899,12 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand, #ifdef WC_MLKEM_FAULT_HARDEN if (ret == 0) { XMEMCPY(sigma, buf + WC_ML_KEM_SYM_SZ, WC_ML_KEM_SYM_SZ); - /* Check that correct data was copied and pointer not changed. */ + /* Check that correct data was copied and pointer was not faulted. */ if (XMEMCMP(sigma, rho, WC_ML_KEM_SYM_SZ) == 0) { ret = BAD_COND_E; } - /* Check that rho is sigma - rho may have been modified. */ + /* Check that sigma is after rho - rho pointer may have been modified. + */ if (XMEMCMP(sigma, rho + WC_ML_KEM_SYM_SZ, WC_ML_KEM_SYM_SZ) != 0) { ret = BAD_COND_E; } @@ -928,7 +982,7 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand, /** * Get the size in bytes of cipher text for key. * - * @param [in] key Kyber key object. + * @param [in] key ML-KEM key object. * @param [out] len Length of cipher text in bytes. * @return 0 on success. * @return BAD_FUNC_ARG when key or len is NULL. @@ -991,10 +1045,10 @@ int wc_MlKemKey_CipherTextSize(MlKemKey* key, word32* len) } /** - * Size of a shared secret in bytes. Always KYBER_SS_SZ. + * Size of a shared secret in bytes. Always WC_ML_KEM_SS_SZ. * - * @param [in] key Kyber key object. Not used. - * @param [out] len Size of the shared secret created with a Kyber key. + * @param [in] key ML-KEM key object. Not used. + * @param [out] len Size of the shared secret created with a ML-KEM key. * @return 0 on success. * @return BAD_FUNC_ARG when len is NULL. */ @@ -1037,7 +1091,7 @@ int wc_MlKemKey_SharedSecretSize(MlKemKey* key, word32* len) * 23: c_2 <- ByteEncode_d_v(Compress_d_v(v)) * 24: return c <- (c_1||c_2) * - * @param [in] key Kyber key object. + * @param [in] key ML-KEM key object. * @param [in] m Random bytes. * @param [in] r Seed to feed to PRF when generating y, e1 and e2. * @param [out] c Calculated cipher text. @@ -1270,7 +1324,7 @@ static int wc_mlkemkey_check_h(MlKemKey* key) #endif /* Determine how big an encoded public key will be. */ - ret = wc_KyberKey_PublicKeySize(key, &pubKeyLen); + ret = wc_MlKemKey_PublicKeySize(key, &pubKeyLen); if (ret == 0) { #ifndef WOLFSSL_NO_MALLOC /* Allocate dynamic memory for encoded public key. */ @@ -1283,15 +1337,15 @@ static int wc_mlkemkey_check_h(MlKemKey* key) if (ret == 0) { #endif /* Encode public key - h is hash of encoded public key. */ - ret = wc_KyberKey_EncodePublicKey(key, pubKey, pubKeyLen); + ret = wc_MlKemKey_EncodePublicKey(key, pubKey, pubKeyLen); } #ifndef WOLFSSL_NO_MALLOC /* Dispose of encoded public key. */ XFREE(pubKey, key->heap, DYNAMIC_TYPE_TMP_BUFFER); - #endif + #endif } if ((ret == 0) && ((key->flags & MLKEM_FLAG_H_SET) == 0)) { - /* Implementation issue if h not cached and flag set. */ + /* Implementation issue if h not cached and flag not set. */ ret = BAD_STATE_E; } @@ -1314,16 +1368,17 @@ static int wc_mlkemkey_check_h(MlKemKey* key) * > run internal encapsulation algorithm * 6: return (K,c) * - * @param [in] key Kyber key object. - * @param [out] c Cipher text. - * @param [out] k Shared secret generated. + * @param [in] key ML-KEM key object. + * @param [out] ct Cipher text. + * @param [out] ss Shared secret generated. * @param [in] rng Random number generator. * @return 0 on success. - * @return BAD_FUNC_ARG when key, c, k or rng is NULL. + * @return BAD_FUNC_ARG when key, ct, ss or rng is NULL. + * @return BAD_STATE_E when public key not set. * @return NOT_COMPILED_IN when key type is not supported. * @return MEMORY_E when dynamic memory allocation failed. */ -int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k, +int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* ct, unsigned char* ss, WC_RNG* rng) { #ifndef WC_NO_RNG @@ -1334,9 +1389,13 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k, #endif /* Validate parameters. */ - if ((key == NULL) || (c == NULL) || (k == NULL) || (rng == NULL)) { + if ((key == NULL) || (ct == NULL) || (ss == NULL) || (rng == NULL)) { ret = BAD_FUNC_ARG; } + /* Check the public key has been set. */ + else if ((key->flags & MLKEM_FLAG_PUB_SET) == 0) { + ret = BAD_STATE_E; + } #ifdef WOLF_CRYPTO_CB if (ret == 0) { @@ -1347,8 +1406,8 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k, #else if (ret == 0) { #endif - ret = wc_CryptoCb_PqcEncapsulate(c, ctlen, k, KYBER_SS_SZ, rng, - WC_PQC_KEM_TYPE_KYBER, key); + ret = wc_CryptoCb_PqcEncapsulate(ct, ctlen, ss, WC_ML_KEM_SS_SZ, rng, + WC_PQC_KEM_TYPE_KYBER, key); if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE)) return ret; /* fall-through when unavailable */ @@ -1367,15 +1426,15 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k, /* Encapsulate with the random. * Step 5: run internal encapsulation algorithm */ - ret = wc_KyberKey_EncapsulateWithRandom(key, c, k, m, sizeof(m)); + ret = wc_MlKemKey_EncapsulateWithRandom(key, ct, ss, m, sizeof(m)); } /* Step 3: return ret != 0 on falsum or internal key generation failure. */ return ret; #else (void)key; - (void)c; - (void)k; + (void)ct; + (void)ss; (void)rng; return NOT_COMPILED_IN; #endif /* WC_NO_RNG */ @@ -1393,35 +1452,41 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k, * > encrypt m using K-PKE with randomness r * Step 3: return (K,c) * - * @param [out] c Cipher text. - * @param [out] k Shared secret generated. - * @param [in] m Random bytes. - * @param [in] len Length of random bytes. + * @param [in] key ML-KEM key object. + * @param [out] ct Cipher text. + * @param [out] ss Shared secret generated. + * @param [in] rand Random bytes. + * @param [in] len Length of random bytes. * @return 0 on success. - * @return BAD_FUNC_ARG when key, c, k or m is NULL. + * @return BAD_FUNC_ARG when key, ct, ss or rand is NULL. * @return BUFFER_E when len is not WC_ML_KEM_ENC_RAND_SZ. + * @return BAD_STATE_E when public key not set. * @return NOT_COMPILED_IN when key type is not supported. * @return MEMORY_E when dynamic memory allocation failed. */ -int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c, - unsigned char* k, const unsigned char* m, int len) +int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* ct, + unsigned char* ss, const unsigned char* rand, int len) { #ifdef WOLFSSL_MLKEM_KYBER - byte msg[KYBER_SYM_SZ]; + byte msg[WC_ML_KEM_SYM_SZ]; #endif - byte kr[2 * KYBER_SYM_SZ + 1]; + byte kr[2 * WC_ML_KEM_SYM_SZ + 1]; int ret = 0; #ifdef WOLFSSL_MLKEM_KYBER unsigned int cSz = 0; #endif /* Validate parameters. */ - if ((key == NULL) || (c == NULL) || (k == NULL) || (m == NULL)) { + if ((key == NULL) || (ct == NULL) || (ss == NULL) || (rand == NULL)) { ret = BAD_FUNC_ARG; } if ((ret == 0) && (len != WC_ML_KEM_ENC_RAND_SZ)) { ret = BUFFER_E; } + /* Check the public key has been set. */ + if ((ret == 0) && ((key->flags & MLKEM_FLAG_PUB_SET) == 0)) { + ret = BAD_STATE_E; + } #ifdef WOLFSSL_MLKEM_KYBER if (ret == 0) { @@ -1473,7 +1538,7 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c, #endif { /* Hash random to anonymize as seed data. */ - ret = MLKEM_HASH_H(&key->hash, m, WC_ML_KEM_SYM_SZ, msg); + ret = MLKEM_HASH_H(&key->hash, rand, WC_ML_KEM_SYM_SZ, msg); } } #endif @@ -1494,7 +1559,7 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c, #ifndef WOLFSSL_NO_ML_KEM { /* Step 1: (K,r) <- G(m||H(ek)) */ - ret = MLKEM_HASH_G(&key->hash, m, WC_ML_KEM_SYM_SZ, key->h, + ret = MLKEM_HASH_G(&key->hash, rand, WC_ML_KEM_SYM_SZ, key->h, WC_ML_KEM_SYM_SZ, kr); } #endif @@ -1507,7 +1572,7 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c, #endif #ifdef WOLFSSL_MLKEM_KYBER { - ret = mlkemkey_encapsulate(key, msg, kr + WC_ML_KEM_SYM_SZ, c); + ret = mlkemkey_encapsulate(key, msg, kr + WC_ML_KEM_SYM_SZ, ct); } #endif #if defined(WOLFSSL_MLKEM_KYBER) && !defined(WOLFSSL_NO_ML_KEM) @@ -1516,7 +1581,7 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c, #ifndef WOLFSSL_NO_ML_KEM { /* Step 2: c <- K-PKE.Encrypt(ek,m,r) */ - ret = mlkemkey_encapsulate(key, m, kr + WC_ML_KEM_SYM_SZ, c); + ret = mlkemkey_encapsulate(key, rand, kr + WC_ML_KEM_SYM_SZ, ct); } #endif } @@ -1528,11 +1593,11 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c, { if (ret == 0) { /* Hash the cipher text after the seed. */ - ret = MLKEM_HASH_H(&key->hash, c, cSz, kr + WC_ML_KEM_SYM_SZ); + ret = MLKEM_HASH_H(&key->hash, ct, cSz, kr + WC_ML_KEM_SYM_SZ); } if (ret == 0) { /* Derive the secret from the seed and hash of cipher text. */ - ret = MLKEM_KDF(kr, 2 * WC_ML_KEM_SYM_SZ, k, WC_ML_KEM_SS_SZ); + ret = MLKEM_KDF(kr, 2 * WC_ML_KEM_SYM_SZ, ss, WC_ML_KEM_SS_SZ); } } #endif @@ -1543,7 +1608,7 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c, { if (ret == 0) { /* return (K,c) */ - XMEMCPY(k, kr, WC_ML_KEM_SS_SZ); + XMEMCPY(ss, kr, WC_ML_KEM_SS_SZ); } } #endif @@ -1570,7 +1635,7 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c, * 7: m <- ByteEncode_1(Compress_1(w)) * 8: return m * - * @param [in] key Kyber key object. + * @param [in] key ML-KEM key object. * @param [out] m Message that was encapsulated. * @param [in] c Cipher text. * @return 0 on success. @@ -1739,12 +1804,13 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m, * 11: end if * 12: return K' * - * @param [in] key Kyber key object. + * @param [in] key ML-KEM key object. * @param [out] ss Shared secret. * @param [in] ct Cipher text. * @param [in] len Length of cipher text. * @return 0 on success. * @return BAD_FUNC_ARG when key, ss or ct are NULL. + * @return BAD_STATE_E when private key is not set. * @return NOT_COMPILED_IN when key type is not supported. * @return BUFFER_E when len is not the length of cipher text for the key type. * @return MEMORY_E when dynamic memory allocation failed. @@ -1827,8 +1893,8 @@ int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss, #else if (ret == 0) { #endif - ret = wc_CryptoCb_PqcDecapsulate(ct, ctSz, ss, KYBER_SS_SZ, - WC_PQC_KEM_TYPE_KYBER, key); + ret = wc_CryptoCb_PqcDecapsulate(ct, ctSz, ss, WC_ML_KEM_SS_SZ, + WC_PQC_KEM_TYPE_KYBER, key); if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE)) return ret; /* fall-through when unavailable */ @@ -1968,13 +2034,16 @@ static void mlkemkey_decode_public(sword16* pub, byte* pubSeed, const byte* p, * 5: s_hat <- ByteDecode_12(dk_PKE) * ... * - * @param [in, out] key Kyber key object. + * @param [in, out] key ML-KEM key object. * @param [in] in Buffer holding encoded key. * @param [in] len Length of data in buffer. * @return 0 on success. * @return BAD_FUNC_ARG when key or in is NULL. * @return NOT_COMPILED_IN when key type is not supported. * @return BUFFER_E when len is not the correct size. + * @return PUBLIC_KEY_E when public key data doesn't match parameters. + * @return MLKEM_PUB_HASH_E when public key hash doesn't match stored hash. + * @return MEMORY_E when dynamic memory allocation failed. */ int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in, word32 len) @@ -2067,6 +2136,12 @@ int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in, /* Decode the public key that is after the private key. */ mlkemkey_decode_public(key->pub, key->pubSeed, p, k); + ret = mlkem_check_public(key->pub, (int)k); + if (ret != 0) { + ForceZero(key->priv, k * MLKEM_N * sizeof(sword16)); + } + } + if (ret == 0) { /* Compute the hash of the public key. */ ret = MLKEM_HASH_H(&key->hash, p, pubLen, key->h); if (ret != 0) { @@ -2102,13 +2177,15 @@ int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in, * * Public vector | Public Seed * - * @param [in, out] key Kyber key object. + * @param [in, out] key ML-KEM key object. * @param [in] in Buffer holding encoded key. * @param [in] len Length of data in buffer. * @return 0 on success. * @return BAD_FUNC_ARG when key or in is NULL. * @return NOT_COMPILED_IN when key type is not supported. * @return BUFFER_E when len is not the correct size. + * @return PUBLIC_KEY_E when public key data doesn't match parameters. + * @return MEMORY_E when dynamic memory allocation failed. */ int wc_MlKemKey_DecodePublicKey(MlKemKey* key, const unsigned char* in, word32 len) @@ -2182,6 +2259,7 @@ int wc_MlKemKey_DecodePublicKey(MlKemKey* key, const unsigned char* in, } #endif if (ret == 0) { + /* Decode public key and check public key matches parameters. */ mlkemkey_decode_public(key->pub, key->pubSeed, p, k); ret = mlkem_check_public(key->pub, (int)k); } @@ -2200,7 +2278,7 @@ int wc_MlKemKey_DecodePublicKey(MlKemKey* key, const unsigned char* in, /** * Get the size in bytes of encoded private key for the key. * - * @param [in] key Kyber key object. + * @param [in] key ML-KEM key object. * @param [out] len Length of encoded private key in bytes. * @return 0 on success. * @return BAD_FUNC_ARG when key or len is NULL. @@ -2266,7 +2344,7 @@ int wc_MlKemKey_PrivateKeySize(MlKemKey* key, word32* len) /** * Get the size in bytes of encoded public key for the key. * - * @param [in] key Kyber key object. + * @param [in] key ML-KEM key object. * @param [out] len Length of encoded public key in bytes. * @return 0 on success. * @return BAD_FUNC_ARG when key or len is NULL. @@ -2343,12 +2421,12 @@ int wc_MlKemKey_PublicKeySize(MlKemKey* key, word32* len) * 20: dk_PKE <- ByteEncode_12(s_hat) * ... * - * @param [in] key Kyber key object. + * @param [in] key ML-KEM key object. * @param [out] out Buffer to hold data. * @param [in] len Size of buffer in bytes. * @return 0 on success. - * @return BAD_FUNC_ARG when key or out is NULL or private/public key not - * available. + * @return BAD_FUNC_ARG when key or out is NULL. + * @return BAD_STATE_E when private/public key not available. * @return NOT_COMPILED_IN when key type is not supported. */ int wc_MlKemKey_EncodePrivateKey(MlKemKey* key, unsigned char* out, word32 len) @@ -2364,7 +2442,7 @@ int wc_MlKemKey_EncodePrivateKey(MlKemKey* key, unsigned char* out, word32 len) } if ((ret == 0) && ((key->flags & MLKEM_FLAG_BOTH_SET) != MLKEM_FLAG_BOTH_SET)) { - ret = BAD_FUNC_ARG; + ret = BAD_STATE_E; } if (ret == 0) { @@ -2431,17 +2509,11 @@ int wc_MlKemKey_EncodePrivateKey(MlKemKey* key, unsigned char* out, word32 len) mlkem_to_bytes(p, key->priv, (int)k); p += WC_ML_KEM_POLY_SIZE * k; - /* Encode public key. */ - ret = wc_KyberKey_EncodePublicKey(key, p, pubLen); + /* Encode public key - calculates hash of public key. */ + ret = wc_MlKemKey_EncodePublicKey(key, p, pubLen); p += pubLen; } - /* Ensure hash of public key is available. */ - if ((ret == 0) && ((key->flags & MLKEM_FLAG_H_SET) == 0)) { - ret = MLKEM_HASH_H(&key->hash, p - pubLen, pubLen, key->h); - } if (ret == 0) { - /* Public hash is available. */ - key->flags |= MLKEM_FLAG_H_SET; /* Append public hash. */ XMEMCPY(p, key->h, sizeof(key->h)); p += WC_ML_KEM_SYM_SZ; @@ -2466,11 +2538,12 @@ int wc_MlKemKey_EncodePrivateKey(MlKemKey* key, unsigned char* out, word32 len) * 19: ek_PKE <- ByteEncode_12(t_hat)||rho * ... * - * @param [in] key Kyber key object. + * @param [in] key ML-KEM key object. * @param [out] out Buffer to hold data. * @param [in] len Size of buffer in bytes. * @return 0 on success. - * @return BAD_FUNC_ARG when key or out is NULL or public key not available. + * @return BAD_FUNC_ARG when key or out is NULL. + * @return BAD_STATE_E when public key not available. * @return NOT_COMPILED_IN when key type is not supported. */ int wc_MlKemKey_EncodePublicKey(MlKemKey* key, unsigned char* out, word32 len) @@ -2485,7 +2558,7 @@ int wc_MlKemKey_EncodePublicKey(MlKemKey* key, unsigned char* out, word32 len) } if ((ret == 0) && ((key->flags & MLKEM_FLAG_PUB_SET) != MLKEM_FLAG_PUB_SET)) { - ret = BAD_FUNC_ARG; + ret = BAD_STATE_E; } if (ret == 0) { diff --git a/wolfcrypt/src/wc_mlkem_poly.c b/wolfcrypt/src/wc_mlkem_poly.c index 1df80823cc..1a0f3e2213 100644 --- a/wolfcrypt/src/wc_mlkem_poly.c +++ b/wolfcrypt/src/wc_mlkem_poly.c @@ -32,9 +32,9 @@ * polynomials. */ -/* Possible Kyber options: +/* Possible ML-KEM options: * - * WOLFSSL_HAVE_MLKEM Default: OFF + * WOLFSSL_HAVE_MLKEM Default: OFF * Enables this code, wolfSSL implementation, to be built. * * WOLFSSL_WC_ML_KEM_512 Default: OFF @@ -112,7 +112,7 @@ static cpuid_flags_t cpuid_flags = WC_CPUID_INITIALIZER; #define MLKEM_Q_HALF (MLKEM_Q / 2) -/* q^-1 mod 2^16 (inverse of 3329 mod 16384) */ +/* q^-1 mod 2^16 (inverse of 3329 mod 65536) */ #define MLKEM_QINV 62209 /* Used in Barrett Reduction: @@ -1062,7 +1062,7 @@ static void mlkem_basemul(sword16* r, const sword16* a, const sword16* b, * 1: for (i <- 0; i < 128; i++) * 2: (h_hat[2i],h_hat[2i+1]) <- * BaseCaseMultiply(f_hat[2i],f_hat[2i+1],g_hat[2i],g_hat[2i+1], - * zetas^(BitRev_7(i)+1) + * zetas^(BitRev_7(i)+1)) * 3: end for * 4: return h_hat * @@ -1115,7 +1115,7 @@ static void mlkem_basemul_mont(sword16* r, const sword16* a, const sword16* b) * 1: for (i <- 0; i < 128; i++) * 2: (h_hat[2i],h_hat[2i+1]) <- * BaseCaseMultiply(f_hat[2i],f_hat[2i+1],g_hat[2i],g_hat[2i+1], - * zetas^(BitRev_7(i)+1) + * zetas^(BitRev_7(i)+1)) * 3: end for * 4: return h_hat * Add h_hat to r. @@ -1237,7 +1237,7 @@ static void mlkem_pointwise_acc_mont(sword16* r, const sword16* a, /******************************************************************************/ -/* Initialize Kyber implementation. +/* Initialize ML-KEM implementation. */ void mlkem_init(void) { @@ -1285,7 +1285,7 @@ void mlkem_keygen(sword16* s, sword16* t, sword16* e, const sword16* a, int k) /* Multiply a by private into public polynomial. * Step 18: ... A_hat o s_hat ... */ mlkem_pointwise_acc_mont(t + i * MLKEM_N, a + i * k * MLKEM_N, s, - k); + (unsigned int)k); /* Convert public polynomial to Montgomery form. * Step 18: ... MontRed(A_hat o s_hat) ... */ mlkem_to_mont_sqrdmlsh(t + i * MLKEM_N); @@ -1312,7 +1312,7 @@ void mlkem_keygen(sword16* s, sword16* t, sword16* e, const sword16* a, int k) /* Multiply a by private into public polynomial. * Step 18: ... A_hat o s_hat ... */ mlkem_pointwise_acc_mont(t + i * MLKEM_N, a + i * k * MLKEM_N, s, - k); + (unsigned int)k); /* Convert public polynomial to Montgomery form. * Step 18: ... MontRed(A_hat o s_hat) ... */ mlkem_to_mont(t + i * MLKEM_N); @@ -1349,7 +1349,7 @@ void mlkem_keygen(sword16* s, sword16* t, sword16* e, const sword16* a, int k) * @param [in] m Message polynomial. * @param [in] k Number of polynomials in vector. */ -void mlkem_encapsulate(const sword16* t, sword16* u , sword16* v, +void mlkem_encapsulate(const sword16* t, sword16* u, sword16* v, const sword16* a, sword16* y, const sword16* e1, const sword16* e2, const sword16* m, int k) { @@ -1364,25 +1364,25 @@ void mlkem_encapsulate(const sword16* t, sword16* u , sword16* v, } /* For each polynomial in the vectors. - * Step 19: u <- InvNTT(A_hat_trans o y_hat) + e_1) */ + * Step 19: u <- InvNTT(A_hat_trans o y_hat) + e_1 */ for (i = 0; i < k; ++i) { /* Multiply at by y into u polynomial. * Step 19: ... A_hat_trans o y_hat ... */ mlkem_pointwise_acc_mont(u + i * MLKEM_N, a + i * k * MLKEM_N, y, - k); - /* Inverse transform u polynomial. + (unsigned int)k); + /* Inverse transform u polynomial. * Step 19: ... InvNTT(A_hat_trans o y_hat) ... */ - mlkem_invntt_sqrdmlsh(u + i * MLKEM_N); - /* Add errors to u and reduce. - * Step 19: u <- InvNTT(A_hat_trans o y_hat) + e_1) */ - mlkem_add_reduce(u + i * MLKEM_N, e1 + i * MLKEM_N); + mlkem_invntt_sqrdmlsh(u + i * MLKEM_N); + /* Add errors to u and reduce. + * Step 19: u <- InvNTT(A_hat_trans o y_hat) + e_1 */ + mlkem_add_reduce(u + i * MLKEM_N, e1 + i * MLKEM_N); } /* Multiply public key by y into v polynomial. * Step 21: ... t_hat_trans o y_hat ... */ - mlkem_pointwise_acc_mont(v, t, y, k); + mlkem_pointwise_acc_mont(v, t, y, (unsigned int)k); /* Inverse transform v. - * Step 22: ... InvNTT(t_hat_trans o y_hat) ... */ + * Step 21: ... InvNTT(t_hat_trans o y_hat) ... */ mlkem_invntt_sqrdmlsh(v); } else @@ -1400,8 +1400,8 @@ void mlkem_encapsulate(const sword16* t, sword16* u , sword16* v, /* Multiply at by y into u polynomial. * Step 19: ... A_hat_trans o y_hat ... */ mlkem_pointwise_acc_mont(u + i * MLKEM_N, a + i * k * MLKEM_N, y, - k); - /* Inverse transform u polynomial. + (unsigned int)k); + /* Inverse transform u polynomial. * Step 19: ... InvNTT(A_hat_trans o y_hat) ... */ mlkem_invntt(u + i * MLKEM_N); /* Add errors to u and reduce. @@ -1411,9 +1411,9 @@ void mlkem_encapsulate(const sword16* t, sword16* u , sword16* v, /* Multiply public key by y into v polynomial. * Step 21: ... t_hat_trans o y_hat ... */ - mlkem_pointwise_acc_mont(v, t, y, k); + mlkem_pointwise_acc_mont(v, t, y, (unsigned int)k); /* Inverse transform v. - * Step 22: ... InvNTT(t_hat_trans o y_hat) ... */ + * Step 21: ... InvNTT(t_hat_trans o y_hat) ... */ mlkem_invntt(v); } /* Add errors and message to v and reduce. @@ -1452,7 +1452,7 @@ void mlkem_decapsulate(const sword16* s, sword16* w, sword16* u, /* Multiply private key by u into w polynomial. * Step 6: ... s_hat_trans o NTT(u') */ - mlkem_pointwise_acc_mont(w, s, u, k); + mlkem_pointwise_acc_mont(w, s, u, (unsigned int)k); /* Inverse transform w. * Step 6: ... InvNTT(s_hat_trans o NTT(u')) */ mlkem_invntt_sqrdmlsh(w); @@ -1468,7 +1468,7 @@ void mlkem_decapsulate(const sword16* s, sword16* w, sword16* u, /* Multiply private key by u into w polynomial. * Step 6: ... s_hat_trans o NTT(u') */ - mlkem_pointwise_acc_mont(w, s, u, k); + mlkem_pointwise_acc_mont(w, s, u, (unsigned int)k); /* Inverse transform w. * Step 6: ... InvNTT(s_hat_trans o NTT(u')) */ mlkem_invntt(w); @@ -1863,7 +1863,7 @@ static void mlkem_keygen_c(sword16* s, sword16* t, sword16* e, const sword16* a, void mlkem_keygen(sword16* s, sword16* t, sword16* e, const sword16* a, int k) { #ifdef USE_INTEL_SPEEDUP - if ((IS_INTEL_AVX2(cpuid_flags)) && (SAVE_VECTOR_REGISTERS2() == 0)) { + if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) { /* Alg 13: Steps 16-18 */ mlkem_keygen_avx2(s, t, e, a, k); RESTORE_VECTOR_REGISTERS(); @@ -1898,7 +1898,11 @@ void mlkem_keygen(sword16* s, sword16* t, sword16* e, const sword16* a, int k) * @param [in] tv Temporary vector of polynomials. * @param [in] k Number of polynomials in vector. * @param [in] rho Random seed to generate matrix A from. - * @param [in] sigma Random seed to generate noise from. + * @param [in, out] sigma Random seed to generate noise from. + * @return 0 on success. + * @return MEMORY_E when dynamic memory allocation fails. Only possible when + * WOLFSSL_SMALL_STACK is defined. + * @return Other negative value when a hash error occurred. */ int mlkem_keygen_seeds(sword16* s, sword16* t, MLKEM_PRF_T* prf, sword16* tv, int k, byte* rho, byte* sigma) @@ -2087,7 +2091,11 @@ void mlkem_encapsulate(const sword16* pub, sword16* u, sword16* v, * @param [in] k Number of polynomials in vector. * @param [in] msg Message to encapsulate. * @param [in] seed Random seed to generate matrix A from. - * @param [in] coins Random seed to generate noise from. + * @param [in, out] coins Random seed to generate noise from. + * @return 0 on success. + * @return MEMORY_E when dynamic memory allocation fails. Only possible when + * WOLFSSL_SMALL_STACK is defined. + * @return Other negative value when a hash error occurred. */ int mlkem_encapsulate_seeds(const sword16* pub, MLKEM_PRF_T* prf, sword16* u, sword16* tp, sword16* y, int k, const byte* msg, byte* seed, byte* coins) @@ -2283,7 +2291,7 @@ void mlkem_decapsulate(const sword16* s, sword16* w, sword16* u, * @param [in] transposed Whether A or A^T is generated. * @return 0 on success. * @return MEMORY_E when dynamic memory allocation fails. Only possible when - * WOLFSSL_SMALL_STACK is defined. + * WOLFSSL_SMALL_STACK is defined. */ static int mlkem_gen_matrix_k2_avx2(sword16* a, byte* seed, int transposed) { @@ -2395,7 +2403,7 @@ static int mlkem_gen_matrix_k2_avx2(sword16* a, byte* seed, int transposed) * @param [in] transposed Whether A or A^T is generated. * @return 0 on success. * @return MEMORY_E when dynamic memory allocation fails. Only possible when - * WOLFSSL_SMALL_STACK is defined. + * WOLFSSL_SMALL_STACK is defined. */ static int mlkem_gen_matrix_k3_avx2(sword16* a, byte* seed, int transposed) { @@ -2553,7 +2561,7 @@ static int mlkem_gen_matrix_k3_avx2(sword16* a, byte* seed, int transposed) * @param [in] transposed Whether A or A^T is generated. * @return 0 on success. * @return MEMORY_E when dynamic memory allocation fails. Only possible when - * WOLFSSL_SMALL_STACK is defined. + * WOLFSSL_SMALL_STACK is defined. */ static int mlkem_gen_matrix_k4_avx2(sword16* a, byte* seed, int transposed) { @@ -2665,8 +2673,6 @@ static int mlkem_gen_matrix_k4_avx2(sword16* a, byte* seed, int transposed) * @param [in] seed Bytes to seed XOF generation. * @param [in] transposed Whether A or A^T is generated. * @return 0 on success. - * @return MEMORY_E when dynamic memory allocation fails. Only possible when - * WOLFSSL_SMALL_STACK is defined. */ static int mlkem_gen_matrix_k2_aarch64(sword16* a, byte* seed, int transposed) { @@ -2739,8 +2745,6 @@ static int mlkem_gen_matrix_k2_aarch64(sword16* a, byte* seed, int transposed) * @param [in] seed Bytes to seed XOF generation. * @param [in] transposed Whether A or A^T is generated. * @return 0 on success. - * @return MEMORY_E when dynamic memory allocation fails. Only possible when - * WOLFSSL_SMALL_STACK is defined. */ static int mlkem_gen_matrix_k3_aarch64(sword16* a, byte* seed, int transposed) { @@ -2805,8 +2809,6 @@ static int mlkem_gen_matrix_k3_aarch64(sword16* a, byte* seed, int transposed) * @param [in] seed Bytes to seed XOF generation. * @param [in] transposed Whether A or A^T is generated. * @return 0 on success. - * @return MEMORY_E when dynamic memory allocation fails. Only possible when - * WOLFSSL_SMALL_STACK is defined. */ static int mlkem_gen_matrix_k4_aarch64(sword16* a, byte* seed, int transposed) { @@ -2891,7 +2893,7 @@ static int mlkem_gen_matrix_k4_aarch64(sword16* a, byte* seed, int transposed) * @param [in] len Length of data to absorb in bytes. * @return 0 on success always. */ -static int mlkem_xof_absorb(wc_Shake* shake128, byte* seed, int len) +static int mlkem_xof_absorb(wc_Shake* shake128, const byte* seed, int len) { int ret; @@ -2992,7 +2994,7 @@ int mlkem_hash512(wc_Sha3* hash, const byte* data1, word32 data1Len, /* Process first block of data. */ ret = wc_Sha3_512_Update(hash, data1, data1Len); /* Check if there is a second block of data. */ - if ((ret == 0) && (data2Len > 0)) { + if ((ret == 0) && (data2 != NULL) && (data2Len > 0)) { /* Process second block of data. */ ret = wc_Sha3_512_Update(hash, data2, data2Len); } @@ -3125,7 +3127,7 @@ static int mlkem_prf(wc_Shake* shake256, byte* out, unsigned int outLen, * @param [in] outLen Number of bytes to derive. * @return 0 on success always. */ -int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen) +int mlkem_kdf(const byte* seed, int seedLen, byte* out, int outLen) { word64 state[25]; word32 len64 = seedLen / 8; @@ -3163,7 +3165,7 @@ int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen) * @param [in] outLen Number of bytes to derive. * @return 0 on success always. */ -int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen) +int mlkem_kdf(const byte* seed, int seedLen, byte* out, int outLen) { word64 state[25]; word32 len64 = seedLen / 8; @@ -3184,41 +3186,41 @@ int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen) #ifndef WOLFSSL_NO_ML_KEM /* Derive the secret from z and cipher text. * - * @param [in, out] shake256 SHAKE-256 object. - * @param [in] z Implicit rejection value. - * @param [in] ct Cipher text. - * @param [in] ctSz Length of cipher text in bytes. - * @param [out] ss Shared secret. + * @param [in, out] prf SHAKE-256 object. + * @param [in] z Implicit rejection value. + * @param [in] ct Cipher text. + * @param [in] ctSz Length of cipher text in bytes. + * @param [out] ss Shared secret. * @return 0 on success. * @return MEMORY_E when dynamic memory allocation failed. * @return Other negative value when a hash error occurred. */ -int mlkem_derive_secret(wc_Shake* shake256, const byte* z, const byte* ct, +int mlkem_derive_secret(wc_Shake* prf, const byte* z, const byte* ct, word32 ctSz, byte* ss) { int ret; #ifdef USE_INTEL_SPEEDUP - XMEMCPY(shake256->t, z, WC_ML_KEM_SYM_SZ); - XMEMCPY(shake256->t + WC_ML_KEM_SYM_SZ, ct, + XMEMCPY(prf->t, z, WC_ML_KEM_SYM_SZ); + XMEMCPY(prf->t + WC_ML_KEM_SYM_SZ, ct, WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ); - shake256->i = WC_ML_KEM_SYM_SZ + WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ; + prf->i = WC_ML_KEM_SYM_SZ + WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ; ct += WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ; ctSz -= WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ; - ret = wc_Shake256_Update(shake256, ct, ctSz); + ret = wc_Shake256_Update(prf, ct, ctSz); if (ret == 0) { - ret = wc_Shake256_Final(shake256, ss, WC_ML_KEM_SS_SZ); + ret = wc_Shake256_Final(prf, ss, WC_ML_KEM_SS_SZ); } #else - ret = wc_InitShake256(shake256, NULL, INVALID_DEVID); + ret = wc_InitShake256(prf, NULL, INVALID_DEVID); if (ret == 0) { - ret = wc_Shake256_Update(shake256, z, WC_ML_KEM_SYM_SZ); + ret = wc_Shake256_Update(prf, z, WC_ML_KEM_SYM_SZ); } if (ret == 0) { - ret = wc_Shake256_Update(shake256, ct, ctSz); + ret = wc_Shake256_Update(prf, ct, ctSz); } if (ret == 0) { - ret = wc_Shake256_Final(shake256, ss, WC_ML_KEM_SS_SZ); + ret = wc_Shake256_Final(prf, ss, WC_ML_KEM_SS_SZ); } #endif @@ -3427,7 +3429,7 @@ static unsigned int mlkem_rej_uniform_c(sword16* p, unsigned int len, * @param [in] transposed Whether A or A^T is generated. * @return 0 on success. * @return MEMORY_E when dynamic memory allocation fails. Only possible when - * WOLFSSL_SMALL_STACK is defined. + * WOLFSSL_SMALL_STACK is defined. */ static int mlkem_gen_matrix_c(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed, int transposed) @@ -3530,7 +3532,7 @@ static int mlkem_gen_matrix_c(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed, * @param [in] transposed Whether A or A^T is generated. * @return 0 on success. * @return MEMORY_E when dynamic memory allocation fails. Only possible when - * WOLFSSL_SMALL_STACK is defined. + * WOLFSSL_SMALL_STACK is defined. */ int mlkem_gen_matrix(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed, int transposed) @@ -3634,7 +3636,7 @@ int mlkem_gen_matrix(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed, * @param [in] transposed Whether A or A^T is generated. * @return 0 on success. * @return MEMORY_E when dynamic memory allocation fails. Only possible when - * WOLFSSL_SMALL_STACK is defined. + * WOLFSSL_SMALL_STACK is defined. */ static int mlkem_gen_matrix_i(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed, int i, int transposed) @@ -3729,7 +3731,7 @@ static int mlkem_gen_matrix_i(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed, * * @param [in] d Value containing sequential 2 bit values. * @param [in] i Start index of the two values in 2 bits each. - * @return Difference of the two values with range 0..2. + * @return Difference of the two values with range -2..2. */ #define ETA2_SUB(d, i) \ (sword16)(((sword16)(((d) >> ((i) * 4 + 0)) & 0x3)) - \ @@ -3845,7 +3847,7 @@ static void mlkem_cbd_eta2(sword16* p, const byte* r) * * @param [in] d Value containing sequential 3 bit values. * @param [in] i Start index of the two values in 3 bits each. - * @return Difference of the two values with range 0..3. + * @return Difference of the two values with range -3..3. */ #define ETA3_SUB(d, i) \ (sword16)(((sword16)(((d) >> ((i) * 6 + 0)) & 0x7)) - \ @@ -4220,6 +4222,8 @@ static void mlkem_get_noise_x4_eta3_avx2(byte* rand, byte* seed) * @param [out] poly Polynomial. * @param [in, out] seed Seed to use when calculating random. * @return 0 on success. + * @return MEMORY_E when dynamic memory allocation fails. Only possible when + * WOLFSSL_SMALL_STACK is defined. */ static int mlkem_get_noise_k2_avx2(MLKEM_PRF_T* prf, sword16* vec1, sword16* vec2, sword16* poly, byte* seed) @@ -4559,7 +4563,7 @@ static int mlkem_get_noise_k4_aarch64(sword16* vec1, sword16* vec2, * @param [out] vec2 Second Vector of polynomials. * @param [in] eta2 Size of noise/error integers with second vector. * @param [out] poly Polynomial. - * @param [in] seed Seed to use when calculating random. + * @param [in, out] seed Seed to use when calculating random. * @return 0 on success. */ static int mlkem_get_noise_c(MLKEM_PRF_T* prf, int k, sword16* vec1, int eta1, @@ -4598,7 +4602,7 @@ static int mlkem_get_noise_c(MLKEM_PRF_T* prf, int k, sword16* vec1, int eta1, return ret; } -#endif /* __aarch64__ && WOLFSSL_ARMASM */ +#endif /* !(__aarch64__ && WOLFSSL_ARMASM) */ /* Get the noise/error by calculating random bytes and sampling to a binomial * distribution. @@ -4697,7 +4701,7 @@ int mlkem_get_noise(MLKEM_PRF_T* prf, int k, sword16* vec1, sword16* vec2, * @param [in, out] prf Pseudo-random function object. * @param [in] k Number of polynomials in vector. * @param [out] vec2 Second Vector of polynomials. - * @param [in] seed Seed to use when calculating random. + * @param [in, out] seed Seed to use when calculating random. * @param [in] i Index of vector to generate. * @param [in] make Indicates generation is for making a key. * @return 0 on success. @@ -5147,8 +5151,8 @@ static void mlkem_vec_compress_11_c(byte* r, sword16* v) * * FIPS 203, Section 4.2.1, Compression and decompression * - * @param [out] r Array of bytes. - * @param [in] v Vector of polynomials. + * @param [out] r Array of bytes. + * @param [in, out] v Vector of polynomials. */ void mlkem_vec_compress_11(byte* r, sword16* v) { @@ -5839,7 +5843,7 @@ void mlkem_from_msg(sword16* p, const byte* msg) * * Uses div operator that may be slow. * - * FIPS 203, Algorithm 6: ByteEncode_d(F) + * FIPS 203, Algorithm 5: ByteEncode_d(F) * * @param [in, out] m Message. * @param [in] p Polynomial. @@ -5862,7 +5866,7 @@ void mlkem_from_msg(sword16* p, const byte* msg) * * Uses mul instead of div. * - * FIPS 203, Algorithm 6: ByteEncode_d(F) + * FIPS 203, Algorithm 5: ByteEncode_d(F) * * @param [in, out] m Message. * @param [in] p Polynomial. @@ -5877,7 +5881,7 @@ void mlkem_from_msg(sword16* p, const byte* msg) /* Convert polynomial to message. * - * FIPS 203, Algorithm 6: ByteEncode_d(F) + * FIPS 203, Algorithm 5: ByteEncode_d(F) * * @param [out] msg Message as a byte array. * @param [in, out] p Polynomial. @@ -5913,7 +5917,7 @@ static void mlkem_to_msg_c(byte* msg, sword16* p) /* Convert polynomial to message. * - * FIPS 203, Algorithm 6: ByteEncode_d(F) + * FIPS 203, Algorithm 5: ByteEncode_d(F) * * @param [out] msg Message as a byte array. * @param [in, out] p Polynomial. @@ -5952,7 +5956,7 @@ void mlkem_from_msg(sword16* p, const byte* msg) #ifndef WOLFSSL_MLKEM_NO_DECAPSULATE /* Convert polynomial to message. * - * FIPS 203, Algorithm 6: ByteEncode_d(F) + * FIPS 203, Algorithm 5: ByteEncode_d(F) * * @param [out] msg Message as a byte array. * @param [in, out] p Polynomial. @@ -6031,7 +6035,7 @@ void mlkem_from_bytes(sword16* p, const byte* b, int k) * Consecutive 12 bits hold each coefficient of polynomial. * Used in encoding private and public keys. * - * FIPS 203, Algorithm 6: ByteEncode_d(F) + * FIPS 203, Algorithm 5: ByteEncode_d(F) * * @param [out] b Array of bytes. * @param [in, out] p Polynomial. @@ -6064,7 +6068,7 @@ static void mlkem_to_bytes_c(byte* b, sword16* p, int k) * Consecutive 12 bits hold each coefficient of polynomial. * Used in encoding private and public keys. * - * FIPS 203, Algorithm 6: ByteEncode_d(F) + * FIPS 203, Algorithm 5: ByteEncode_d(F) * * @param [out] b Array of bytes. * @param [in, out] p Polynomial. @@ -6094,18 +6098,18 @@ void mlkem_to_bytes(byte* b, sword16* p, int k) /** * Check the public key values are smaller than the modulus. * - * @param [in] pub Public key - vector. - * @param [in] k Number of polynomials in vector. + * @param [in] p Public key - vector. + * @param [in] k Number of polynomials in vector. * @return 0 when all values are in range. * @return PUBLIC_KEY_E when at least one value is out of range. */ -int mlkem_check_public(sword16* pub, int k) +int mlkem_check_public(const sword16* p, int k) { int ret = 0; int i; for (i = 0; i < k * MLKEM_N; i++) { - if (pub[i] >= MLKEM_Q) { + if (p[i] >= MLKEM_Q) { ret = PUBLIC_KEY_E; break; } diff --git a/wolfssl/wolfcrypt/wc_mlkem.h b/wolfssl/wolfcrypt/wc_mlkem.h index b8a0efa1d9..4d02a6252a 100644 --- a/wolfssl/wolfcrypt/wc_mlkem.h +++ b/wolfssl/wolfcrypt/wc_mlkem.h @@ -422,7 +422,7 @@ typedef struct MlKemKey { WOLFSSL_API MlKemKey* wc_MlKemKey_New(int type, void* heap, int devId); -WOLFSSL_API int wc_MlKemKey_Delete(MlKemKey* key, MlKemKey** key_p); +WOLFSSL_API int wc_MlKemKey_Delete(MlKemKey* key, MlKemKey** key_p); WOLFSSL_API int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId); @@ -522,11 +522,9 @@ int mlkem_get_noise(MLKEM_PRF_T* prf, int kp, sword16* vec1, sword16* vec2, #if defined(USE_INTEL_SPEEDUP) || \ (defined(WOLFSSL_ARMASM) && defined(__aarch64__)) WOLFSSL_LOCAL -int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen); +int mlkem_kdf(const byte* seed, int seedLen, byte* out, int outLen); #endif WOLFSSL_LOCAL -void mlkem_hash_init(MLKEM_HASH_T* hash); -WOLFSSL_LOCAL int mlkem_hash_new(MLKEM_HASH_T* hash, void* heap, int devId); WOLFSSL_LOCAL void mlkem_hash_free(MLKEM_HASH_T* hash); @@ -578,7 +576,7 @@ void mlkem_from_bytes(sword16* p, const byte* b, int k); WOLFSSL_LOCAL void mlkem_to_bytes(byte* b, sword16* p, int k); WOLFSSL_LOCAL -int mlkem_check_public(sword16* p, int k); +int mlkem_check_public(const sword16* p, int k); #ifdef USE_INTEL_SPEEDUP WOLFSSL_LOCAL @@ -601,10 +599,13 @@ unsigned int mlkem_rej_uniform_avx2(sword16* p, unsigned int len, const byte* r, WOLFSSL_LOCAL void mlkem_redistribute_21_rand_avx2(const word64* s, byte* r0, byte* r1, byte* r2, byte* r3); +WOLFSSL_LOCAL void mlkem_redistribute_17_rand_avx2(const word64* s, byte* r0, byte* r1, byte* r2, byte* r3); +WOLFSSL_LOCAL void mlkem_redistribute_16_rand_avx2(const word64* s, byte* r0, byte* r1, byte* r2, byte* r3); +WOLFSSL_LOCAL void mlkem_redistribute_8_rand_avx2(const word64* s, byte* r0, byte* r1, byte* r2, byte* r3);