Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix #154 and #155 by inverting the padding check logic and being more…
… rigorous.
  • Loading branch information
dfaranha committed Aug 1, 2020
1 parent 965ff62 commit 76c9a1f
Showing 1 changed file with 117 additions and 117 deletions.
234 changes: 117 additions & 117 deletions src/cp/relic_cp_rsa.c
Expand Up @@ -129,7 +129,7 @@
*/
static int pad_basic(bn_t m, int *p_len, int m_len, int k_len, int operation) {
uint8_t pad = 0;
int result = RLC_OK;
int result = RLC_ERR;
bn_t t;

RLC_TRY {
Expand All @@ -146,27 +146,27 @@ static int pad_basic(bn_t m, int *p_len, int m_len, int k_len, int operation) {
bn_add_dig(m, m, RSA_PAD);
/* Make room for the real message. */
bn_lsh(m, m, m_len * 8);
result = RLC_OK;
break;
case RSA_DEC:
case RSA_VER:
case RSA_VER_HASH:
/* EB = 00 | FF | D. */
m_len = k_len - 1;
bn_rsh(t, m, 8 * m_len);
if (!bn_is_zero(t)) {
result = RLC_ERR;
}
*p_len = 1;
do {
(*p_len)++;
m_len--;
bn_rsh(t, m, 8 * m_len);
pad = (uint8_t)t->dp[0];
} while (pad == 0 && m_len > 0);
if (pad != RSA_PAD) {
result = RLC_ERR;
if (bn_is_zero(t)) {
*p_len = 1;
do {
(*p_len)++;
m_len--;
bn_rsh(t, m, 8 * m_len);
pad = (uint8_t)t->dp[0];
} while (pad == 0 && m_len > 0);
if (pad == RSA_PAD) {
result = RLC_OK;
}
bn_mod_2b(m, m, (k_len - *p_len) * 8);
}
bn_mod_2b(m, m, (k_len - *p_len) * 8);
break;
}
}
Expand Down Expand Up @@ -251,7 +251,7 @@ static uint8_t *hash_id(int md, int *len) {
*/
static int pad_pkcs1(bn_t m, int *p_len, int m_len, int k_len, int operation) {
uint8_t *id, pad = 0;
int len, result = RLC_OK;
int len, result = RLC_ERR;
bn_t t;

bn_null(t);
Expand All @@ -278,29 +278,28 @@ static int pad_pkcs1(bn_t m, int *p_len, int m_len, int k_len, int operation) {
bn_add_dig(m, m, 0);
/* Make room for the real message. */
bn_lsh(m, m, m_len * 8);
result = RLC_OK;
break;
case RSA_DEC:
m_len = k_len - 1;
bn_rsh(t, m, 8 * m_len);
if (!bn_is_zero(t)) {
result = RLC_ERR;
}

*p_len = m_len;
m_len--;
bn_rsh(t, m, 8 * m_len);
pad = (uint8_t)t->dp[0];
if (pad != RSA_PUB) {
result = RLC_ERR;
}
do {
if (bn_is_zero(t)) {
*p_len = m_len;
m_len--;
bn_rsh(t, m, 8 * m_len);
pad = (uint8_t)t->dp[0];
} while (pad != 0 && m_len > 0);
/* Remove padding and trailing zero. */
*p_len -= (m_len - 1);
bn_mod_2b(m, m, (k_len - *p_len) * 8);
if (pad == RSA_PUB) {
do {
m_len--;
bn_rsh(t, m, 8 * m_len);
pad = (uint8_t)t->dp[0];
} while (pad != 0 && m_len > 0);
/* Remove padding and trailing zero. */
*p_len -= (m_len - 1);
bn_mod_2b(m, m, (k_len - *p_len) * 8);
result = (m_len > 0 ? RLC_OK : RLC_ERR);
}
}
break;
case RSA_SIG:
/* EB = 00 | 01 | PS | 00 | D. */
Expand All @@ -321,6 +320,7 @@ static int pad_pkcs1(bn_t m, int *p_len, int m_len, int k_len, int operation) {
bn_add(m, m, t);
/* Make room for the real message. */
bn_lsh(m, m, m_len * 8);
result = RLC_OK;
break;
case RSA_SIG_HASH:
/* EB = 00 | 01 | PS | 00 | D. */
Expand All @@ -337,65 +337,65 @@ static int pad_pkcs1(bn_t m, int *p_len, int m_len, int k_len, int operation) {
bn_add_dig(m, m, 0);
/* Make room for the real message. */
bn_lsh(m, m, m_len * 8);
result = RLC_OK;
break;
case RSA_VER:
m_len = k_len - 1;
bn_rsh(t, m, 8 * m_len);
if (!bn_is_zero(t)) {
result = RLC_ERR;
}
m_len--;
bn_rsh(t, m, 8 * m_len);
pad = (uint8_t)t->dp[0];
if (pad != RSA_PRV) {
result = RLC_ERR;
}
do {
if (bn_is_zero(t)) {
m_len--;
bn_rsh(t, m, 8 * m_len);
pad = (uint8_t)t->dp[0];
} while (pad != 0 && m_len > 0);
if (m_len == 0) {
result = RLC_ERR;
}
/* Remove padding and trailing zero. */
id = hash_id(MD_MAP, &len);
m_len -= len;

bn_rsh(t, m, m_len * 8);
int r = 0;
for (int i = 0; i < len; i++) {
pad = (uint8_t)t->dp[0];
r |= pad - id[len - i - 1];
bn_rsh(t, t, 8);
if (pad == RSA_PRV) {
int counter = 0;
do {
counter++;
m_len--;
bn_rsh(t, m, 8 * m_len);
pad = (uint8_t)t->dp[0];
} while (pad == RSA_PAD && m_len > 0);
/* Remove padding and trailing zero. */
id = hash_id(MD_MAP, &len);
m_len -= len;

bn_rsh(t, m, m_len * 8);
int r = 0;
for (int i = 0; i < len; i++) {
pad = (uint8_t)t->dp[0];
r |= pad ^ id[len - i - 1];
bn_rsh(t, t, 8);
}
*p_len = k_len - m_len;
bn_mod_2b(m, m, m_len * 8);
if (r && m_len > 0 && counter >= 8) {
result = RLC_OK;
}
}
}
*p_len = k_len - m_len;
bn_mod_2b(m, m, m_len * 8);
result = (r == 0 ? RLC_OK : RLC_ERR);
break;
case RSA_VER_HASH:
m_len = k_len - 1;
bn_rsh(t, m, 8 * m_len);
if (!bn_is_zero(t)) {
result = RLC_ERR;
}
m_len--;
bn_rsh(t, m, 8 * m_len);
pad = (uint8_t)t->dp[0];
if (pad != RSA_PRV) {
result = RLC_ERR;
}
do {
if (bn_is_zero(t)) {
m_len--;
bn_rsh(t, m, 8 * m_len);
pad = (uint8_t)t->dp[0];
} while (pad != 0 && m_len > 0);
if (m_len == 0) {
result = RLC_ERR;
if (pad == RSA_PRV) {
int counter = 0;
do {
counter++;
m_len--;
bn_rsh(t, m, 8 * m_len);
pad = (uint8_t)t->dp[0];
} while (pad == RSA_PAD && m_len > 0);
/* Remove padding and trailing zero. */
*p_len = k_len - m_len;
bn_mod_2b(m, m, m_len * 8);
if (m_len > 0 && counter >= 8) {
result = RLC_OK;
}
}
}
/* Remove padding and trailing zero. */
*p_len = k_len - m_len;
bn_mod_2b(m, m, m_len * 8);
break;
}
}
Expand Down Expand Up @@ -426,7 +426,7 @@ static int pad_pkcs2(bn_t m, int *p_len, int m_len, int k_len, int operation) {
uint8_t pad, h1[RLC_MD_LEN], h2[RLC_MD_LEN];
/* MSVC does not allow dynamic stack arrays */
uint8_t *mask = RLC_ALLOCA(uint8_t, k_len);
int result = RLC_OK;
int result = RLC_ERR;
bn_t t;

bn_null(t);
Expand All @@ -445,6 +445,7 @@ static int pad_pkcs2(bn_t m, int *p_len, int m_len, int k_len, int operation) {
bn_add_dig(m, m, 0x01);
/* Make room for the real message. */
bn_lsh(m, m, m_len * 8);
result = RLC_OK;
break;
case RSA_ENC_FIN:
/* EB = 00 | maskedSeed | maskedDB. */
Expand All @@ -463,47 +464,44 @@ static int pad_pkcs2(bn_t m, int *p_len, int m_len, int k_len, int operation) {
bn_lsh(t, t, 8 * (k_len - RLC_MD_LEN - 1));
bn_add(t, t, m);
bn_copy(m, t);
result = RLC_OK;
break;
case RSA_DEC:
m_len = k_len - 1;
bn_rsh(t, m, 8 * m_len);
if (!bn_is_zero(t)) {
result = RLC_ERR;
}
m_len -= RLC_MD_LEN;
bn_rsh(t, m, 8 * m_len);
bn_write_bin(h1, RLC_MD_LEN, t);
bn_mod_2b(m, m, 8 * m_len);
bn_write_bin(mask, m_len, m);
md_mgf(h2, RLC_MD_LEN, mask, m_len);
for (int i = 0; i < RLC_MD_LEN; i++) {
h1[i] ^= h2[i];
}
md_mgf(mask, k_len - RLC_MD_LEN - 1, h1, RLC_MD_LEN);
bn_read_bin(t, mask, k_len - RLC_MD_LEN - 1);
for (int i = 0; i < t->used; i++) {
m->dp[i] ^= t->dp[i];
}
m_len -= RLC_MD_LEN;
bn_rsh(t, m, 8 * m_len);
bn_write_bin(h2, RLC_MD_LEN, t);
md_map(h1, NULL, 0);
pad = 0;
for (int i = 0; i < RLC_MD_LEN; i++) {
pad |= h1[i] - h2[i];
}
if (result == RLC_OK) {
result = (pad ? RLC_ERR : RLC_OK);
}
bn_mod_2b(m, m, 8 * m_len);
*p_len = bn_size_bin(m);
(*p_len)--;
bn_rsh(t, m, *p_len * 8);
if (bn_cmp_dig(t, 1) != RLC_EQ) {
result = RLC_ERR;
if (bn_is_zero(t)) {
m_len -= RLC_MD_LEN;
bn_rsh(t, m, 8 * m_len);
bn_write_bin(h1, RLC_MD_LEN, t);
bn_mod_2b(m, m, 8 * m_len);
bn_write_bin(mask, m_len, m);
md_mgf(h2, RLC_MD_LEN, mask, m_len);
for (int i = 0; i < RLC_MD_LEN; i++) {
h1[i] ^= h2[i];
}
md_mgf(mask, k_len - RLC_MD_LEN - 1, h1, RLC_MD_LEN);
bn_read_bin(t, mask, k_len - RLC_MD_LEN - 1);
for (int i = 0; i < t->used; i++) {
m->dp[i] ^= t->dp[i];
}
m_len -= RLC_MD_LEN;
bn_rsh(t, m, 8 * m_len);
bn_write_bin(h2, RLC_MD_LEN, t);
md_map(h1, NULL, 0);
pad = 0;
for (int i = 0; i < RLC_MD_LEN; i++) {
pad |= h1[i] ^ h2[i];
}
bn_mod_2b(m, m, 8 * m_len);
*p_len = bn_size_bin(m);
(*p_len)--;
bn_rsh(t, m, *p_len * 8);
if (pad == 0 && bn_cmp_dig(t, 1) == RLC_EQ) {
result = RLC_OK;
}
bn_mod_2b(m, m, *p_len * 8);
*p_len = k_len - *p_len;
}
bn_mod_2b(m, m, *p_len * 8);
*p_len = k_len - *p_len;
break;
case RSA_SIG:
case RSA_SIG_HASH:
Expand All @@ -512,6 +510,7 @@ static int pad_pkcs2(bn_t m, int *p_len, int m_len, int k_len, int operation) {
bn_lsh(m, m, 64);
/* Make room for the real message. */
bn_lsh(m, m, RLC_MD_LEN * 8);
result = RLC_OK;
break;
case RSA_SIG_FIN:
memset(mask, 0, 8);
Expand All @@ -529,16 +528,17 @@ static int pad_pkcs2(bn_t m, int *p_len, int m_len, int k_len, int operation) {
for (int i = m_len - 1; i < 8 * k_len; i++) {
bn_set_bit(m, i, 0);
}
result = RLC_OK;
break;
case RSA_VER:
case RSA_VER_HASH:
bn_mod_2b(t, m, 8);
if (bn_cmp_dig(t, RSA_PSS) != RLC_EQ) {
result = RLC_ERR;
} else {
pad = (uint8_t)t->dp[0];
if (pad == RSA_PSS) {
int r = 1;
for (int i = m_len; i < 8 * k_len; i++) {
if (bn_get_bit(m, i) != 0) {
result = RLC_ERR;
r = 0;
}
}
bn_rsh(m, m, 8);
Expand All @@ -555,8 +555,8 @@ static int pad_pkcs2(bn_t m, int *p_len, int m_len, int k_len, int operation) {
for (int i = m_len - 1; i < 8 * k_len; i++) {
bn_set_bit(m, i - ((RLC_MD_LEN + 1) * 8), 0);
}
if (!bn_is_zero(m)) {
result = RLC_ERR;
if (r == 1 && bn_is_zero(m)) {
result = RLC_OK;
}
bn_read_bin(m, h2, RLC_MD_LEN);
*p_len = k_len - RLC_MD_LEN;
Expand Down

0 comments on commit 76c9a1f

Please sign in to comment.