Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 61 additions & 47 deletions wolfcrypt/src/wc_mlkem.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
* Stores the matrix A during key generation for use in encapsulation when
* performing decapsulation.
* KyberKey is 8KB larger but decapsulation is significantly faster.
* Turn on when performing make key and decapsualtion with same object.
* Turn on when performing make key and decapsulation with same object.
*/

#include <wolfssl/wolfcrypt/libwolfssl_sources.h>
Expand Down Expand Up @@ -219,10 +219,10 @@ int wc_MlKemKey_Delete(MlKemKey* key, MlKemKey** key_p)
/**
* Initialize the Kyber key.
*
* @param [out] key Kyber key object to initialize.
* @param [in] type Type of key:
* WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024,
* KYBER512, KYBER768, KYBER1024.
* @param [out] key Kyber key object to initialize.
* @param [in] heap Dynamic memory hint.
* @param [in] devId Device Id.
* @return 0 on success.
Expand Down Expand Up @@ -292,7 +292,7 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
/* Cache heap pointer. */
key->heap = heap;
#ifdef WOLF_CRYPTO_CB
/* Cache device id - not used in for this algorithm yet. */
/* Cache device id - not used in this algorithm yet. */
key->devId = devId;
#endif
key->flags = 0;
Expand Down Expand Up @@ -353,17 +353,16 @@ int wc_MlKemKey_Free(MlKemKey* key)
* 4: return falsum
* > return an error indication if random bit generation failed
* 5: end if
* 6: (ek,dk) <- ML-KEM.KeyGen_Interal(d, z)
* 6: (ek,dk) <- ML-KEM.KeyGen_Internal(d, z)
* > run internal key generation algorithm
* &: return (ek,dk)
* 7: return (ek,dk)
*
* @param [in, out] key Kyber key object.
* @param [in] rng Random number generator.
* @return 0 on success.
* @return BAD_FUNC_ARG when key or rng is NULL.
* @return MEMORY_E when dynamic memory allocation failed.
* @return MEMORY_E when dynamic memory allocation failed.
* @return RNG_FAILURE_E when generating random numbers failed.
* @return RNG_FAILURE_E when generating random numbers failed.
* @return DRBG_CONT_FAILURE when random number generator health check fails.
*/
int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
Expand Down Expand Up @@ -405,13 +404,13 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
* FIPS 203 - Algorithm 16: ML-KEM.KeyGen_internal(d,z)
* Uses randomness to generate an encapsulation key and a corresponding
* decapsulation key.
* 1: (ek_PKE,dk_PKE) < K-PKE.KeyGen(d) > run key generation for K-PKE
* 1: (ek_PKE,dk_PKE) <- K-PKE.KeyGen(d) > run key generation for K-PKE
* ...
*
* FIPS 203 - Algorithm 13: K-PKE.KeyGen(d)
* Uses randomness to generate an encryption key and a corresponding decryption
* key.
* 1: (rho,sigma) <- G(d||k)A
* 1: (rho,sigma) <- G(d||k)
* > expand 32+1 bytes to two pseudorandom 32-byte seeds
* 2: N <- 0
* 3-7: generate matrix A_hat
Expand All @@ -420,7 +419,7 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
* 16-18: calculate t_hat from A_hat, s and e
* ...
*
* @param [in, out] key Kyber key ovject.
* @param [in, out] key Kyber key object.
* @param [in] rand Random data.
* @param [in] len Length of random data in bytes.
* @return 0 on success.
Expand Down Expand Up @@ -552,7 +551,7 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
#endif
#ifdef WOLFSSL_MLKEM_KYBER
{
/* Expand 32 bytes of random to 32. */
/* Expand 32 bytes of random to 64. */
ret = MLKEM_HASH_G(&key->hash, d, WC_ML_KEM_SYM_SZ, NULL, 0, buf);
}
#endif
Expand All @@ -562,7 +561,7 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
#ifndef WOLFSSL_NO_ML_KEM
{
buf[0] = k;
/* Expand 33 bytes of random to 32.
/* Expand 33 bytes of random to 64.
* Alg 13: Step 1: (rho,sigma) <- G(d||k)
*/
ret = MLKEM_HASH_G(&key->hash, d, WC_ML_KEM_SYM_SZ, buf, 1, buf);
Expand All @@ -572,9 +571,11 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
#ifdef WC_MLKEM_FAULT_HARDEN
if (ret == 0) {
XMEMCPY(sigma, buf + WC_ML_KEM_SYM_SZ, WC_ML_KEM_SYM_SZ);
/* Check that correct data was copied and pointer not changed. */
if (XMEMCMP(sigma, rho, WC_ML_KEM_SYM_SZ) == 0) {
ret = BAD_COND_E;
}
/* Check that rho is sigma - rho may have been modified. */
if (XMEMCMP(sigma, rho + WC_ML_KEM_SYM_SZ, WC_ML_KEM_SYM_SZ) != 0) {
ret = BAD_COND_E;
}
Expand Down Expand Up @@ -619,8 +620,8 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
if (ret == 0) {
/* Generate key pair from private vector and seeds.
* Alg 13: Steps 3-7: generate matrix A_hat
* Alg 13: 12-15: generate e
* Alg 13: 16-18: calculate t_hat from A_hat, s and e
* Alg 13: Steps 12-15: generate e
* Alg 13: Steps 16-18: calculate t_hat from A_hat, s and e
*/
ret = mlkem_keygen_seeds(s, t, &key->prf, e, k, rho, sigma);
}
Expand Down Expand Up @@ -715,17 +716,23 @@ int wc_MlKemKey_CipherTextSize(MlKemKey* key, word32* len)
* Size of a shared secret in bytes. Always KYBER_SS_SZ.
*
* @param [in] key Kyber key object. Not used.
* @param [out] Size of the shared secret created with a Kyber key.
* @param [out] len Size of the shared secret created with a Kyber key.
* @return 0 on success.
* @return 0 to indicate success.
* @return BAD_FUNC_ARG when len is NULL.
*/
int wc_MlKemKey_SharedSecretSize(MlKemKey* key, word32* len)
{
(void)key;
int ret = 0;

*len = WC_ML_KEM_SS_SZ;
if (len == NULL) {
ret = BAD_FUNC_ARG;
}
else {
*len = WC_ML_KEM_SS_SZ;
}

return 0;
(void)key;
return ret;
}

#if !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) || \
Expand All @@ -738,7 +745,7 @@ int wc_MlKemKey_SharedSecretSize(MlKemKey* key, word32* len)
* 1: N <- 0
* 2: t_hat <- ByteDecode_12(ek_PKE[0:384k])
* > run ByteDecode_12 k times to decode t_hat
* 3: rho <- ek_PKE[384k : 384K + 32]
* 3: rho <- ek_PKE[384k : 384k + 32]
* > extract 32-byte seed from ek_PKE
* 4-8: generate matrix A_hat
* 9-12: generate y
Expand Down Expand Up @@ -889,7 +896,7 @@ static int mlkemkey_encapsulate(MlKemKey* key, const byte* m, byte* r, byte* c)
}
if (ret == 0) {
/* Assign remaining allocated dynamic memory to pointers.
* y (v) | a (m) | mu (p) | e1 (p) | r2 (v) | u (v) | v (p)*/
* y (b) | a (m) | mu (p) | e1 (p) | e2 (v) | u (v) | v (p) */
u = e2 + MLKEM_N;
v = u + MLKEM_N * k;

Expand Down Expand Up @@ -1034,7 +1041,7 @@ static int wc_mlkemkey_check_h(MlKemKey* key)
* @param [out] k Shared secret generated.
* @param [in] rng Random number generator.
* @return 0 on success.
* @return BAD_FUNC_ARG when key, ct, ss or RNG is NULL.
* @return BAD_FUNC_ARG when key, c, k or rng is NULL.
* @return NOT_COMPILED_IN when key type is not supported.
* @return MEMORY_E when dynamic memory allocation failed.
*/
Expand Down Expand Up @@ -1075,7 +1082,7 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
* ciphertext.
* Step 1: (K,r) <- G(m||H(ek))
* > derive shared secret key K and randomness r
* Step 2: c <- K-PPKE.Encrypt(ek, m, r)
* Step 2: c <- K-PKE.Encrypt(ek, m, r)
* > encrypt m using K-PKE with randomness r
* Step 3: return (K,c)
*
Expand All @@ -1084,7 +1091,7 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
* @param [in] m Random bytes.
* @param [in] len Length of random bytes.
* @return 0 on success.
* @return BAD_FUNC_ARG when key, c, k or RNG is NULL.
* @return BAD_FUNC_ARG when key, c, k or m is NULL.
* @return BUFFER_E when len is not WC_ML_KEM_ENC_RAND_SZ.
* @return NOT_COMPILED_IN when key type is not supported.
* @return MEMORY_E when dynamic memory allocation failed.
Expand Down Expand Up @@ -1248,16 +1255,16 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c,
* FIPS 203, Algorithm 15: K-PKE.Decrypt(dk_PKE,c)
* Uses the decryption key to decrypt a ciphertext.
* 1: c1 <- c[0 : 32.d_u.k]
* 2: c2 <= c[32.d_u.k : 32(d_u.k + d_v)]
* 3: u' <= Decompress_d_u(ByteDecode_d_u(c1))
* 4: v' <= Decompress_d_v(ByteDecode_d_v(c2))
* 2: c2 <- c[32.d_u.k : 32(d_u.k + d_v)]
* 3: u' <- Decompress_d_u(ByteDecode_d_u(c1))
* 4: v' <- Decompress_d_v(ByteDecode_d_v(c2))
* ...
* 6: w <- v' - InvNTT(s_hat_trans o NTT(u'))
* 7: m <- ByteEncode_1(Compress_1(w))
* 8: return m
*
* @param [in] key Kyber key object.
* @param [out] m Message than was encapsulated.
* @param [out] m Message that was encapsulated.
* @param [in] c Cipher text.
* @return 0 on success.
* @return NOT_COMPILED_IN when key type is not supported.
Expand Down Expand Up @@ -1340,7 +1347,7 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
if (ret == 0) {
/* Step 1: c1 <- c[0 : 32.d_u.k] */
const byte* c1 = c;
/* Step 2: c2 <= c[32.d_u.k : 32(d_u.k + d_v)] */
/* Step 2: c2 <- c[32.d_u.k : 32(d_u.k + d_v)] */
const byte* c2 = c + compVecSz;

/* Assign allocated dynamic memory to pointers.
Expand All @@ -1350,25 +1357,25 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,

#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512)
if (k == WC_ML_KEM_512_K) {
/* Step 3: u' <= Decompress_d_u(ByteDecode_d_u(c1)) */
/* Step 3: u' <- Decompress_d_u(ByteDecode_d_u(c1)) */
mlkem_vec_decompress_10(u, c1, k);
/* Step 4: v' <= Decompress_d_v(ByteDecode_d_v(c2)) */
/* Step 4: v' <- Decompress_d_v(ByteDecode_d_v(c2)) */
mlkem_decompress_4(v, c2);
}
#endif
#if defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
if (k == WC_ML_KEM_768_K) {
/* Step 3: u' <= Decompress_d_u(ByteDecode_d_u(c1)) */
/* Step 3: u' <- Decompress_d_u(ByteDecode_d_u(c1)) */
mlkem_vec_decompress_10(u, c1, k);
/* Step 4: v' <= Decompress_d_v(ByteDecode_d_v(c2)) */
/* Step 4: v' <- Decompress_d_v(ByteDecode_d_v(c2)) */
mlkem_decompress_4(v, c2);
}
#endif
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
if (k == WC_ML_KEM_1024_K) {
/* Step 3: u' <= Decompress_d_u(ByteDecode_d_u(c1)) */
/* Step 3: u' <- Decompress_d_u(ByteDecode_d_u(c1)) */
mlkem_vec_decompress_11(u, c1);
/* Step 4: v' <= Decompress_d_v(ByteDecode_d_v(c2)) */
/* Step 4: v' <- Decompress_d_v(ByteDecode_d_v(c2)) */
mlkem_decompress_5(v, c2);
}
#endif
Expand Down Expand Up @@ -1408,19 +1415,19 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
* ...
* 1: dk_PKE <- dk[0 : 384k]
* > extract (from KEM decaps key) the PKE decryption key
* 2: ek_PKE <- dk[384k : 768l + 32]
* 2: ek_PKE <- dk[384k : 768k + 32]
* > extract PKE encryption key
* 3: h <- dk[768K + 32 : 768k + 64]
* 3: h <- dk[768k + 32 : 768k + 64]
* > extract hash of PKE encryption key
* 4: z <- dk[768K + 64 : 768k + 96]
* 4: z <- dk[768k + 64 : 768k + 96]
* > extract implicit rejection value
* 5: m' <- K-PKE.Decrypt(dk_PKE, c) > decrypt ciphertext
* 6: (K', r') <- G(m'||h)
* 7: K_bar <- J(z||c)
* 8: c' <- K-PKE.Encrypt(ek_PKE, m', r')
* > re-encrypt using the derived randomness r'
* 9: if c != c' then
* 10: K' <= K_bar
* 10: K' <- K_bar
* > if ciphertexts do not match, "implicitly reject"
* 11: end if
* 12: return K'
Expand All @@ -1430,7 +1437,7 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
* @param [in] ct Cipher text.
* @param [in] len Length of cipher text.
* @return 0 on success.
* @return BAD_FUNC_ARG when key, ss or cr are NULL.
* @return BAD_FUNC_ARG when key, ss or ct are NULL.
* @return NOT_COMPILED_IN when key type is not supported.
* @return BUFFER_E when len is not the length of cipher text for the key type.
* @return MEMORY_E when dynamic memory allocation failed.
Expand Down Expand Up @@ -1588,7 +1595,7 @@ int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss,
/**
* Get the public key and public seed from bytes.
*
* FIPS 203, Algorithm 14 K-PKE.Encrypt(ek_PKE, m, r)
* FIPS 203, Algorithm 14: K-PKE.Encrypt(ek_PKE, m, r)
* ...
* 2: t <- ByteDecode_12(ek_PKE[0 : 384k])
* 3: rho <- ek_PKE[384k : 384k + 32]
Expand Down Expand Up @@ -1624,16 +1631,16 @@ static void mlkemkey_decode_public(sword16* pub, byte* pubSeed, const byte* p,
* FIPS 203, Algorithm 18: ML-KEM.Decaps_internal(dk, c)
* 1: dk_PKE <- dk[0 : 384k]
* > extract (from KEM decaps key) the PKE decryption key
* 2: ek_PKE <- dk[384k : 768l + 32]
* 2: ek_PKE <- dk[384k : 768k + 32]
* > extract PKE encryption key
* 3: h <- dk[768K + 32 : 768k + 64]
* 3: h <- dk[768k + 32 : 768k + 64]
* > extract hash of PKE encryption key
* 4: z <- dk[768K + 64 : 768k + 96]
* 4: z <- dk[768k + 64 : 768k + 96]
* > extract implicit rejection value
*
* FIPS 203, Algorithm 15: K-PKE.Decrypt(dk_PKE, c)
* ...
* 5: s_hat <= ByteDecode_12(dk_PKE)
* 5: s_hat <- ByteDecode_12(dk_PKE)
* ...
*
* @param [in, out] key Kyber key object.
Expand Down Expand Up @@ -1729,14 +1736,21 @@ int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in,
mlkemkey_decode_public(key->pub, key->pubSeed, p, k);
/* Compute the hash of the public key. */
ret = MLKEM_HASH_H(&key->hash, p, pubLen, key->h);
p += pubLen;
if (ret != 0) {
ForceZero(key->priv, k * MLKEM_N);
}
}

if (ret == 0) {
p += pubLen;
/* Compare computed public key hash with stored hash */
if (XMEMCMP(key->h, p, WC_ML_KEM_SYM_SZ) != 0)
if (XMEMCMP(key->h, p, WC_ML_KEM_SYM_SZ) != 0) {
ForceZero(key->priv, k * MLKEM_N);
ret = MLKEM_PUB_HASH_E;
}
}

if (ret == 0) {
/* Copy the hash of the encoded public key that is after public key. */
XMEMCPY(key->h, p, sizeof(key->h));
p += WC_ML_KEM_SYM_SZ;
Expand Down
Loading
Loading