Skip to content

Commit

Permalink
Fix bug in Poly1305 bigval_final_reduce().
Browse files Browse the repository at this point in the history
Mark Wooding pointed out that my comment in make1305.py was completely
wrong, and that the stated strategy for reducing a value mod 2^130-5
would not in fact completely reduce all inputs in the range - for the
most obvious reason, namely that the numbers between 2^130-5 and 2^130
would never have anything subtracted at all (for those inputs n >> 130
is zero, so the computed multiple of the modulus is also zero).

Implemented a replacement strategy which my tests suggest will do the
right thing for all numbers in the expected range that are anywhere
near an integer multiple of the modulus.
  • Loading branch information
sgtatham committed Apr 8, 2017
1 parent 61f668a commit d2653e7
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 91 deletions.
20 changes: 12 additions & 8 deletions contrib/make1305.py
Expand Up @@ -338,16 +338,20 @@ def gen_mul(target):
\n""" % target.text()

def gen_final_reduce(target):
# We take our input number n, and compute k = n + 5*(n >> 130).
# Then k >> 130 is precisely the multiple of p that needs to be
# subtracted from n to reduce it to strictly less than p.
# Given our input number n, n >> 130 is usually precisely the
# multiple of p that needs to be subtracted from n to reduce it to
# strictly less than p, but it might be too low by 1 (but not more
# than 1, given the range of our input is nowhere near the square
# of the modulus). So we add another 5, which will push a carry
# into the 130th bit if and only if that has happened, and then
# use that to decide whether to subtract one more copy of p.

a = target.bigval_input("n", 133)
a1 = a.extract_bits(130, 130)
k = a + target.const(5) * a1
q = k.extract_bits(130)
adjusted = a + target.const(5) * q
ret = adjusted.extract_bits(0, 130)
q = a.extract_bits(130)
adjusted = a.extract_bits(0, 130) + target.const(5) * q
final_subtract = (adjusted + target.const(5)).extract_bits(130)
adjusted2 = adjusted + target.const(5) * final_subtract
ret = adjusted2.extract_bits(0, 130)
target.write_bigval("n", ret)
return """\
static void bigval_final_reduce(bigval *n)
Expand Down
188 changes: 105 additions & 83 deletions sshccp.c
Expand Up @@ -440,9 +440,10 @@ static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b)

static void bigval_final_reduce(bigval *n)
{
BignumInt v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v12, v13, v14, v15;
BignumInt v16, v17, v18, v19, v20, v21, v22, v24, v25, v26, v27, v28, v29;
BignumInt v30, v31, v32, v33;
BignumInt v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v13, v14, v15;
BignumInt v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28;
BignumInt v29, v30, v31, v32, v34, v35, v36, v37, v38, v39, v40, v41, v42;
BignumInt v43;
BignumCarry carry;

v0 = n->w[0];
Expand All @@ -455,45 +456,55 @@ static void bigval_final_reduce(bigval *n)
v7 = n->w[7];
v8 = n->w[8];
v9 = (v8) >> 2;
v10 = 5 * v9;
BignumADC(v12, carry, v0, v10, 0);
(void)v12;
BignumADC(v13, carry, v1, 0, carry);
(void)v13;
BignumADC(v14, carry, v2, 0, carry);
(void)v14;
BignumADC(v15, carry, v3, 0, carry);
(void)v15;
BignumADC(v16, carry, v4, 0, carry);
(void)v16;
BignumADC(v17, carry, v5, 0, carry);
(void)v17;
BignumADC(v18, carry, v6, 0, carry);
(void)v18;
BignumADC(v19, carry, v7, 0, carry);
(void)v19;
v20 = v8 + 0 + carry;
v21 = (v20) >> 2;
v22 = 5 * v21;
BignumADC(v24, carry, v0, v22, 0);
BignumADC(v25, carry, v1, 0, carry);
BignumADC(v26, carry, v2, 0, carry);
BignumADC(v27, carry, v3, 0, carry);
BignumADC(v28, carry, v4, 0, carry);
BignumADC(v29, carry, v5, 0, carry);
BignumADC(v30, carry, v6, 0, carry);
BignumADC(v31, carry, v7, 0, carry);
v32 = v8 + 0 + carry;
v33 = (v32) & ((((BignumInt)1) << 2)-1);
n->w[0] = v24;
n->w[1] = v25;
n->w[2] = v26;
n->w[3] = v27;
n->w[4] = v28;
n->w[5] = v29;
n->w[6] = v30;
n->w[7] = v31;
n->w[8] = v33;
v10 = (v8) & ((((BignumInt)1) << 2)-1);
v11 = 5 * v9;
BignumADC(v13, carry, v0, v11, 0);
BignumADC(v14, carry, v1, 0, carry);
BignumADC(v15, carry, v2, 0, carry);
BignumADC(v16, carry, v3, 0, carry);
BignumADC(v17, carry, v4, 0, carry);
BignumADC(v18, carry, v5, 0, carry);
BignumADC(v19, carry, v6, 0, carry);
BignumADC(v20, carry, v7, 0, carry);
v21 = v10 + 0 + carry;
BignumADC(v22, carry, v13, 5, 0);
(void)v22;
BignumADC(v23, carry, v14, 0, carry);
(void)v23;
BignumADC(v24, carry, v15, 0, carry);
(void)v24;
BignumADC(v25, carry, v16, 0, carry);
(void)v25;
BignumADC(v26, carry, v17, 0, carry);
(void)v26;
BignumADC(v27, carry, v18, 0, carry);
(void)v27;
BignumADC(v28, carry, v19, 0, carry);
(void)v28;
BignumADC(v29, carry, v20, 0, carry);
(void)v29;
v30 = v21 + 0 + carry;
v31 = (v30) >> 2;
v32 = 5 * v31;
BignumADC(v34, carry, v13, v32, 0);
BignumADC(v35, carry, v14, 0, carry);
BignumADC(v36, carry, v15, 0, carry);
BignumADC(v37, carry, v16, 0, carry);
BignumADC(v38, carry, v17, 0, carry);
BignumADC(v39, carry, v18, 0, carry);
BignumADC(v40, carry, v19, 0, carry);
BignumADC(v41, carry, v20, 0, carry);
v42 = v21 + 0 + carry;
v43 = (v42) & ((((BignumInt)1) << 2)-1);
n->w[0] = v34;
n->w[1] = v35;
n->w[2] = v36;
n->w[3] = v37;
n->w[4] = v38;
n->w[5] = v39;
n->w[6] = v40;
n->w[7] = v41;
n->w[8] = v43;
}

#elif BIGNUM_INT_BITS == 32
Expand Down Expand Up @@ -604,8 +615,8 @@ static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b)

static void bigval_final_reduce(bigval *n)
{
BignumInt v0, v1, v2, v3, v4, v5, v6, v8, v9, v10, v11, v12, v13, v14;
BignumInt v16, v17, v18, v19, v20, v21;
BignumInt v0, v1, v2, v3, v4, v5, v6, v7, v9, v10, v11, v12, v13, v14;
BignumInt v15, v16, v17, v18, v19, v20, v22, v23, v24, v25, v26, v27;
BignumCarry carry;

v0 = n->w[0];
Expand All @@ -614,29 +625,35 @@ static void bigval_final_reduce(bigval *n)
v3 = n->w[3];
v4 = n->w[4];
v5 = (v4) >> 2;
v6 = 5 * v5;
BignumADC(v8, carry, v0, v6, 0);
(void)v8;
BignumADC(v9, carry, v1, 0, carry);
(void)v9;
BignumADC(v10, carry, v2, 0, carry);
(void)v10;
BignumADC(v11, carry, v3, 0, carry);
(void)v11;
v12 = v4 + 0 + carry;
v13 = (v12) >> 2;
v14 = 5 * v13;
BignumADC(v16, carry, v0, v14, 0);
BignumADC(v17, carry, v1, 0, carry);
BignumADC(v18, carry, v2, 0, carry);
BignumADC(v19, carry, v3, 0, carry);
v20 = v4 + 0 + carry;
v21 = (v20) & ((((BignumInt)1) << 2)-1);
n->w[0] = v16;
n->w[1] = v17;
n->w[2] = v18;
n->w[3] = v19;
n->w[4] = v21;
v6 = (v4) & ((((BignumInt)1) << 2)-1);
v7 = 5 * v5;
BignumADC(v9, carry, v0, v7, 0);
BignumADC(v10, carry, v1, 0, carry);
BignumADC(v11, carry, v2, 0, carry);
BignumADC(v12, carry, v3, 0, carry);
v13 = v6 + 0 + carry;
BignumADC(v14, carry, v9, 5, 0);
(void)v14;
BignumADC(v15, carry, v10, 0, carry);
(void)v15;
BignumADC(v16, carry, v11, 0, carry);
(void)v16;
BignumADC(v17, carry, v12, 0, carry);
(void)v17;
v18 = v13 + 0 + carry;
v19 = (v18) >> 2;
v20 = 5 * v19;
BignumADC(v22, carry, v9, v20, 0);
BignumADC(v23, carry, v10, 0, carry);
BignumADC(v24, carry, v11, 0, carry);
BignumADC(v25, carry, v12, 0, carry);
v26 = v13 + 0 + carry;
v27 = (v26) & ((((BignumInt)1) << 2)-1);
n->w[0] = v22;
n->w[1] = v23;
n->w[2] = v24;
n->w[3] = v25;
n->w[4] = v27;
}

#elif BIGNUM_INT_BITS == 64
Expand Down Expand Up @@ -705,28 +722,33 @@ static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b)

/*
 * bigval_final_reduce -- three-word variant (reads and writes n->w[0..2]).
 *
 * Final reduction step for Poly1305: replaces *n, in place, with its
 * reduced form modulo p = 2^130 - 5.
 *
 * NOTE(review): this listing comes from a rendered diff, and the body
 * below appears to contain BOTH the removed and the added lines of the
 * hunk (two declaration lists, two complete computations, two sets of
 * stores to n->w[]). It is not compilable as shown; confirm the real
 * post-commit text against the repository before using it.
 *
 * NOTE(review): BignumInt, BignumCarry and BignumADC are defined
 * elsewhere. BignumADC(out, carry_out, a, b, carry_in) is presumably an
 * add-with-carry primitive -- confirm against its definition.
 */
static void bigval_final_reduce(bigval *n)
{
/* NOTE(review): first declaration list -- appears to be the removed (pre-fix) line. */
BignumInt v0, v1, v2, v3, v4, v6, v7, v8, v9, v10, v12, v13, v14, v15;
/* NOTE(review): second declaration list -- appears to be the added (post-fix) lines. */
BignumInt v0, v1, v2, v3, v4, v5, v7, v8, v9, v10, v11, v12, v13, v14;
BignumInt v16, v17, v18, v19;
BignumCarry carry;

/* Load the three input words. */
v0 = n->w[0];
v1 = n->w[1];
v2 = n->w[2];
/*
 * v3 = the part of the top word above its low 2 bits. Assuming 64-bit
 * words this is n >> 130 (bit 130 = 2*64 + 2) -- TODO confirm word size.
 */
v3 = (v2) >> 2;
/*
 * ---- NOTE(review): apparently the removed (pre-fix) computation ----
 * Trial sum of the input words plus 5*v3. The low result words are
 * discarded (the (void) casts); only the carry into the top word v8
 * is kept.
 */
v4 = 5 * v3;
BignumADC(v6, carry, v0, v4, 0);
(void)v6;
BignumADC(v7, carry, v1, 0, carry);
(void)v7;
v8 = v2 + 0 + carry;
/* Multiplier taken from the trial sum's top word, again dropping its low 2 bits. */
v9 = (v8) >> 2;
v10 = 5 * v9;
/* Add 5*v9 to the ORIGINAL words and keep only the low 2 bits of the top word. */
BignumADC(v12, carry, v0, v10, 0);
BignumADC(v13, carry, v1, 0, carry);
v14 = v2 + 0 + carry;
v15 = (v14) & ((((BignumInt)1) << 2)-1);
n->w[0] = v12;
n->w[1] = v13;
n->w[2] = v15;
/*
 * ---- NOTE(review): apparently the added (post-fix) computation ----
 * v4 = low 2 bits of the top word, i.e. the top word with the v3 part
 * removed.
 */
v4 = (v2) & ((((BignumInt)1) << 2)-1);
/* First fold: (v7, v8, v9) = (v0, v1, v4) + 5*v3. */
v5 = 5 * v3;
BignumADC(v7, carry, v0, v5, 0);
BignumADC(v8, carry, v1, 0, carry);
v9 = v4 + 0 + carry;
/*
 * Trial addition of 5 to (v7, v8, v9). The low result words are
 * discarded; only the resulting top word v12 is used below.
 */
BignumADC(v10, carry, v7, 5, 0);
(void)v10;
BignumADC(v11, carry, v8, 0, carry);
(void)v11;
v12 = v9 + 0 + carry;
/*
 * v13 = the part of the trial top word above its low 2 bits; it scales
 * the final correction of 5 applied to the folded value.
 */
v13 = (v12) >> 2;
v14 = 5 * v13;
/* Apply the correction to (v7, v8, v9) and mask the top word to 2 bits. */
BignumADC(v16, carry, v7, v14, 0);
BignumADC(v17, carry, v8, 0, carry);
v18 = v9 + 0 + carry;
v19 = (v18) & ((((BignumInt)1) << 2)-1);
/* Write the result back over the input. */
n->w[0] = v16;
n->w[1] = v17;
n->w[2] = v19;
}

#else
Expand Down

0 comments on commit d2653e7

Please sign in to comment.