Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mldsa/mldsa_native.S
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@
#undef crypto_sign_keypair
#undef crypto_sign_keypair_internal
#undef crypto_sign_open
#undef crypto_sign_pk_from_sk
#undef crypto_sign_signature
#undef crypto_sign_signature_extmu
#undef crypto_sign_signature_internal
Expand Down
1 change: 1 addition & 0 deletions mldsa/mldsa_native.c
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@
#undef crypto_sign_keypair
#undef crypto_sign_keypair_internal
#undef crypto_sign_open
#undef crypto_sign_pk_from_sk
#undef crypto_sign_signature
#undef crypto_sign_signature_extmu
#undef crypto_sign_signature_internal
Expand Down
19 changes: 19 additions & 0 deletions mldsa/mldsa_native.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,25 @@ size_t MLD_API_NAMESPACE(prepare_domain_separation_prefix)(
uint8_t prefix[MLD_DOMAIN_SEPARATION_MAX_BYTES], const uint8_t *ph,
size_t phlen, const uint8_t *ctx, size_t ctxlen, int hashalg);

/*************************************************
* Name: crypto_sign_pk_from_sk
*
* Description: Derives public key from secret key with validation.
* Checks that t0 and tr stored in sk match recomputed values.
*
* Arguments:
* - uint8_t pk[MLDSA_PUBLICKEYBYTES(MLD_CONFIG_API_PARAMETER_SET)]:
* output public key
* - const uint8_t sk[MLDSA_SECRETKEYBYTES(MLD_CONFIG_API_PARAMETER_SET)]:
* input secret key
*
* Returns 0 on success, -1 if validation fails (invalid secret key)
**************************************************/
MLD_API_MUST_CHECK_RETURN_VALUE
int MLD_API_NAMESPACE(pk_from_sk)(
uint8_t pk[MLDSA_PUBLICKEYBYTES(MLD_CONFIG_API_PARAMETER_SET)],
const uint8_t sk[MLDSA_SECRETKEYBYTES(MLD_CONFIG_API_PARAMETER_SET)]);

/****************************** SUPERCOP API *********************************/

#if !defined(MLD_CONFIG_API_NO_SUPERCOP)
Expand Down
42 changes: 42 additions & 0 deletions mldsa/src/ct.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,48 @@ __contract__(
#if !defined(__ASSEMBLER__)
#include <string.h>

/*************************************************
* Name: mld_ct_memcmp
*
* Description: Compare two arrays for equality in constant time.
*
* Arguments: const void *a: pointer to first byte array
* const void *b: pointer to second byte array
* size_t len: length of the byte arrays
*
* Returns 0 if the byte arrays are equal, a non-zero value otherwise
**************************************************/
static MLD_INLINE uint8_t mld_ct_memcmp(const void *a, const void *b,
const size_t len)
__contract__(
requires(len <= UINT16_MAX)
requires(memory_no_alias(a, len))
requires(memory_no_alias(b, len))
ensures((return_value == 0) == forall(i, 0, len, (((const uint8_t *)a)[i] == ((const uint8_t *)b)[i])))
)
{
const uint8_t *a_bytes = (const uint8_t *)a;
const uint8_t *b_bytes = (const uint8_t *)b;
uint8_t r = 0, s = 0;
unsigned i;

for (i = 0; i < len; i++)
__loop__(
invariant(i <= len)
invariant((r == 0) == (forall(k, 0, i, (a_bytes[k] == b_bytes[k])))))
{
r |= a_bytes[i] ^ b_bytes[i];
/* s is useless, but prevents the loop from being aborted once r=0xff. */
s ^= a_bytes[i] ^ b_bytes[i];
}

/*
* XOR twice with s, separated by a value barrier, to prevent the compile
* from dropping the s computation in the loop.
*/
return (uint8_t)((mld_value_barrier_u32((uint32_t)r) ^ s) ^ s);
}

/*************************************************
* Name: mld_zeroize
*
Expand Down
181 changes: 146 additions & 35 deletions mldsa/src/sign.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <string.h>

#include "cbmc.h"
#include "ct.h"
#include "debug.h"
#include "packing.h"
#include "poly.h"
Expand All @@ -48,6 +49,8 @@
#define mld_H MLD_ADD_PARAM_SET(mld_H)
#define mld_attempt_signature_generation \
MLD_ADD_PARAM_SET(mld_attempt_signature_generation)
#define mld_compute_t0_t1_tr_from_sk_components \
MLD_ADD_PARAM_SET(mld_compute_t0_t1_tr_from_sk_components)
/* End of parameter set namespacing */


Expand Down Expand Up @@ -174,6 +177,85 @@ __contract__(
#endif /* !MLD_CONFIG_SERIAL_FIPS202_ONLY */
}

/*************************************************
* Name: mld_compute_t0_t1_tr_from_sk_components
*
* Description: Computes t0, t1, tr, and pk from secret key components
* rho, s1, s2. This is the shared computation used by
* both keygen and generating the public key from the
* secret key.
*
* Arguments: - mld_polyveck *t0: output t0
* - mld_polyveck *t1: output t1
* - uint8_t tr[MLDSA_TRBYTES]: output tr
* - uint8_t pk[CRYPTO_PUBLICKEYBYTES]: output public key
* - const uint8_t rho[MLDSA_SEEDBYTES]: input rho
* - const mld_polyvecl *s1: input s1
* - const mld_polyveck *s2: input s2
**************************************************/
static void mld_compute_t0_t1_tr_from_sk_components(
mld_polyveck *t0, mld_polyveck *t1, uint8_t tr[MLDSA_TRBYTES],
uint8_t pk[CRYPTO_PUBLICKEYBYTES], const uint8_t rho[MLDSA_SEEDBYTES],
const mld_polyvecl *s1, const mld_polyveck *s2)
__contract__(
requires(memory_no_alias(t0, sizeof(mld_polyveck)))
requires(memory_no_alias(t1, sizeof(mld_polyveck)))
requires(memory_no_alias(tr, MLDSA_TRBYTES))
requires(memory_no_alias(pk, CRYPTO_PUBLICKEYBYTES))
requires(memory_no_alias(rho, MLDSA_SEEDBYTES))
requires(memory_no_alias(s1, sizeof(mld_polyvecl)))
requires(memory_no_alias(s2, sizeof(mld_polyveck)))
requires(forall(l0, 0, MLDSA_L, array_bound(s1->vec[l0].coeffs, 0, MLDSA_N, MLD_POLYETA_UNPACK_LOWER_BOUND, MLDSA_ETA + 1)))
requires(forall(k0, 0, MLDSA_K, array_bound(s2->vec[k0].coeffs, 0, MLDSA_N, MLD_POLYETA_UNPACK_LOWER_BOUND, MLDSA_ETA + 1)))
assigns(memory_slice(t0, sizeof(mld_polyveck)))
assigns(memory_slice(t1, sizeof(mld_polyveck)))
assigns(memory_slice(tr, MLDSA_TRBYTES))
assigns(memory_slice(pk, CRYPTO_PUBLICKEYBYTES))
ensures(forall(k1, 0, MLDSA_K, array_bound(t0->vec[k1].coeffs, 0, MLDSA_N, -(1<<(MLDSA_D-1)) + 1, (1<<(MLDSA_D-1)) + 1)))
ensures(forall(k2, 0, MLDSA_K, array_bound(t1->vec[k2].coeffs, 0, MLDSA_N, 0, 1 << 10)))
)
{
mld_polyvecl mat[MLDSA_K], s1hat;
mld_polyveck t;

/* Expand matrix */
mld_polyvec_matrix_expand(mat, rho);

/* Matrix-vector multiplication */
s1hat = *s1;
mld_polyvecl_ntt(&s1hat);
mld_polyvec_matrix_pointwise_montgomery(&t, mat, &s1hat);
mld_polyveck_reduce(&t);
mld_polyveck_invntt_tomont(&t);

/* Add error vector s2 */
mld_polyveck_add(&t, s2);

/* Reference: The following reduction is not present in the reference
* implementation. Omitting this reduction requires the output of
* the invntt to be small enough such that the addition of s2 does
* not result in absolute values >= MLDSA_Q. While our C, x86_64,
* and AArch64 invntt implementations produce small enough
* values for this to work out, it complicates the bounds
* reasoning. We instead add an additional reduction, and can
* consequently, relax the bounds requirements for the invntt.
*/
mld_polyveck_reduce(&t);

/* Decompose to get t1, t0 */
mld_polyveck_caddq(&t);
mld_polyveck_power2round(t1, t0, &t);

/* Pack public key and compute tr */
mld_pack_pk(pk, rho, t1);
mld_shake256(tr, MLDSA_TRBYTES, pk, CRYPTO_PUBLICKEYBYTES);

/* @[FIPS204, Section 3.6.3] Destruction of intermediate values. */
mld_zeroize(mat, sizeof(mat));
mld_zeroize(&s1hat, sizeof(s1hat));
mld_zeroize(&t, sizeof(t));
}

MLD_MUST_CHECK_RETURN_VALUE
MLD_EXTERNAL_API
int crypto_sign_keypair_internal(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
Expand All @@ -184,9 +266,8 @@ int crypto_sign_keypair_internal(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
MLD_ALIGN uint8_t inbuf[MLDSA_SEEDBYTES + 2];
MLD_ALIGN uint8_t tr[MLDSA_TRBYTES];
const uint8_t *rho, *rhoprime, *key;
mld_polyvecl mat[MLDSA_K];
mld_polyvecl s1, s1hat;
mld_polyveck s2, t2, t1, t0;
mld_polyvecl s1;
mld_polyveck s2, t1, t0;

/* Get randomness for rho, rhoprime and key */
mld_memcpy(inbuf, seed, MLDSA_SEEDBYTES);
Expand All @@ -200,50 +281,23 @@ int crypto_sign_keypair_internal(uint8_t pk[CRYPTO_PUBLICKEYBYTES],

/* Constant time: rho is part of the public key and, hence, public. */
MLD_CT_TESTING_DECLASSIFY(rho, MLDSA_SEEDBYTES);
/* Expand matrix */
mld_polyvec_matrix_expand(mat, rho);
mld_sample_s1_s2(&s1, &s2, rhoprime);

/* Matrix-vector multiplication */
s1hat = s1;
mld_polyvecl_ntt(&s1hat);
mld_polyvec_matrix_pointwise_montgomery(&t1, mat, &s1hat);
mld_polyveck_reduce(&t1);
mld_polyveck_invntt_tomont(&t1);

/* Add error vector s2 */
mld_polyveck_add(&t1, &s2);

/* Reference: The following reduction is not present in the reference
* implementation. Omitting this reduction requires the output of
* the invntt to be small enough such that the addition of s2 does
* not result in absolute values >= MLDSA_Q. While our C, x86_64,
* and AArch64 invntt implementations produce small enough
* values for this to work out, it complicates the bounds
* reasoning. We instead add an additional reduction, and can
* consequently, relax the bounds requirements for the invntt.
*/
mld_polyveck_reduce(&t1);
/* Sample s1 and s2 */
mld_sample_s1_s2(&s1, &s2, rhoprime);

/* Extract t1 and write public key */
mld_polyveck_caddq(&t1);
mld_polyveck_power2round(&t2, &t0, &t1);
mld_pack_pk(pk, rho, &t2);
/* Compute t0, t1, tr, and pk from rho, s1, s2 */
mld_compute_t0_t1_tr_from_sk_components(&t0, &t1, tr, pk, rho, &s1, &s2);

/* Compute H(rho, t1) and write secret key */
mld_shake256(tr, MLDSA_TRBYTES, pk, CRYPTO_PUBLICKEYBYTES);
/* Pack secret key */
mld_pack_sk(sk, rho, tr, key, &t0, &s1, &s2);

/* @[FIPS204, Section 3.6.3] Destruction of intermediate values. */
mld_zeroize(seedbuf, sizeof(seedbuf));
mld_zeroize(inbuf, sizeof(inbuf));
mld_zeroize(tr, sizeof(tr));
mld_zeroize(mat, sizeof(mat));
mld_zeroize(&s1, sizeof(s1));
mld_zeroize(&s1hat, sizeof(s1hat));
mld_zeroize(&s2, sizeof(s2));
mld_zeroize(&t1, sizeof(t1));
mld_zeroize(&t2, sizeof(t2));
mld_zeroize(&t0, sizeof(t0));

/* Constant time: pk is the public key, inherently public data */
Expand Down Expand Up @@ -1131,6 +1185,62 @@ size_t mld_prepare_domain_separation_prefix(
return 2 + ctxlen + MLD_PRE_HASH_OID_LEN + phlen;
}

MLD_EXTERNAL_API
int crypto_sign_pk_from_sk(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
const uint8_t sk[CRYPTO_SECRETKEYBYTES])
{
MLD_ALIGN uint8_t rho[MLDSA_SEEDBYTES];
MLD_ALIGN uint8_t tr[MLDSA_TRBYTES];
MLD_ALIGN uint8_t tr_computed[MLDSA_TRBYTES];
MLD_ALIGN uint8_t key[MLDSA_SEEDBYTES];
mld_polyvecl s1;
mld_polyveck s2, t0, t0_computed, t1;
int res;

/* Unpack secret key */
mld_unpack_sk(rho, tr, key, &t0, &s1, &s2, sk);

/* Recompute t0, t1, tr, and pk from rho, s1, s2 */
mld_compute_t0_t1_tr_from_sk_components(&t0_computed, &t1, tr_computed, pk,
rho, &s1, &s2);

/* Declassify public key */
MLD_CT_TESTING_DECLASSIFY(pk, CRYPTO_PUBLICKEYBYTES);
Comment on lines +1207 to +1208
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This declassification could (should?) be moved to after the validation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually... if the function fails, pk should probably be 0'ed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, what would you like here? Move the declassification, remove it, and 0 instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In cleanup, I'd suggest to zeroize pk if res is not 0, and then unconditionally declassify it.


/* Validate t0 using constant-time comparison */
res = mld_ct_memcmp(&t0, &t0_computed, sizeof(mld_polyveck));
/* Declassify comparison result */
MLD_CT_TESTING_DECLASSIFY(&res, sizeof(res));
if (res != 0)
{
res = -1;
goto cleanup;
}

/* Validate tr using constant-time comparison */
res = mld_ct_memcmp(tr, tr_computed, MLDSA_TRBYTES);
/* Declassify comparison result */
MLD_CT_TESTING_DECLASSIFY(&res, sizeof(res));
if (res != 0)
{
res = -1;
goto cleanup;
}
Comment on lines +1211 to +1228
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: This can be streamlined and further hardened

res0 = mld_ct_memcmp(&t0, &t0_computed, sizeof(mld_polyveck));
res1 = mld_ct_memcmp(tr, tr_computed, MLDSA_TRBYTES);
res = mld_value_barrier_u8(res0 | res1);

/* Declassify the final result of the validity check. */
MLD_CT_TESTING_DECLASSIFY(&res, sizeof(res));
if (res != 0)
{
  mld_zeroize(pk, CRYPTO_PUBLICKEYBYTES);
}

cleanup:
   ...

Also, it needs to be documented that the function leaks the validity of the sk.


res = 0;

cleanup:
/* @[FIPS204, Section 3.6.3] Destruction of intermediate values. */
mld_zeroize(&s1, sizeof(s1));
mld_zeroize(&s2, sizeof(s2));
mld_zeroize(&t0, sizeof(t0));
mld_zeroize(&t0_computed, sizeof(t0_computed));
mld_zeroize(key, sizeof(key));
mld_zeroize(tr_computed, sizeof(tr_computed));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tr is not zeroized here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And what about t1?


return res;
}

/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
* Don't modify by hand -- this is auto-generated by scripts/autogen. */
#undef mld_check_pct
Expand All @@ -1139,5 +1249,6 @@ size_t mld_prepare_domain_separation_prefix(
#undef mld_get_hash_oid
#undef mld_H
#undef mld_attempt_signature_generation
#undef mld_compute_t0_t1_tr_from_sk_components
#undef NONCE_UB
#undef MLD_PRE_HASH_OID_LEN
22 changes: 22 additions & 0 deletions mldsa/src/sign.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
MLD_NAMESPACE_KL(verify_pre_hash_shake256)
#define mld_prepare_domain_separation_prefix \
MLD_NAMESPACE_KL(prepare_domain_separation_prefix)
#define crypto_sign_pk_from_sk MLD_NAMESPACE_KL(pk_from_sk)

/*************************************************
* Hash algorithm constants for domain separation
Expand Down Expand Up @@ -686,4 +687,25 @@ __contract__(
ensures(return_value <= MLD_DOMAIN_SEPARATION_MAX_BYTES)
);

/*************************************************
* Name: crypto_sign_pk_from_sk
*
* Description: Derives public key from secret key with validation.
* Checks that t0 and tr stored in sk match recomputed values.
*
* Arguments: - uint8_t pk[CRYPTO_PUBLICKEYBYTES]: output public key
* - const uint8_t sk[CRYPTO_SECRETKEYBYTES]: input secret key
*
* Returns 0 on success, -1 if validation fails (corrupted secret key)
**************************************************/
MLD_MUST_CHECK_RETURN_VALUE
MLD_EXTERNAL_API
int crypto_sign_pk_from_sk(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
const uint8_t sk[CRYPTO_SECRETKEYBYTES])
__contract__(
requires(memory_no_alias(pk, CRYPTO_PUBLICKEYBYTES))
requires(memory_no_alias(sk, CRYPTO_SECRETKEYBYTES))
assigns(memory_slice(pk, CRYPTO_PUBLICKEYBYTES))
ensures(return_value == 0 || return_value == -1)
);
#endif /* !MLD_SIGN_H */
Loading