Skip to content

Commit

Permalink
Merge pull request #3961 from randombit/jack/mp-shift-refactor
Browse files Browse the repository at this point in the history
Various word/mp related cleanup
  • Loading branch information
randombit committed Apr 1, 2024
2 parents 63e47f8 + b5b2ba3 commit aa12651
Show file tree
Hide file tree
Showing 13 changed files with 75 additions and 74 deletions.
16 changes: 4 additions & 12 deletions src/lib/math/bigint/big_ops2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,17 +258,12 @@ word BigInt::operator%=(word mod) {
* Left Shift Operator
*/
BigInt& BigInt::operator<<=(size_t shift) {
const size_t shift_words = shift / BOTAN_MP_WORD_BITS;
const size_t shift_bits = shift % BOTAN_MP_WORD_BITS;
const size_t size = sig_words();

const size_t bits_free = top_bits_free();

const size_t new_size = size + shift_words + (bits_free < shift_bits);
const size_t sw = sig_words();
const size_t new_size = sw + (shift + BOTAN_MP_WORD_BITS - 1) / BOTAN_MP_WORD_BITS;

m_data.grow_to(new_size);

bigint_shl1(m_data.mutable_data(), new_size, size, shift_words, shift_bits);
bigint_shl1(m_data.mutable_data(), new_size, sw, shift);

return (*this);
}
Expand All @@ -277,10 +272,7 @@ BigInt& BigInt::operator<<=(size_t shift) {
* Right Shift Operator
*/
BigInt& BigInt::operator>>=(size_t shift) {
const size_t shift_words = shift / BOTAN_MP_WORD_BITS;
const size_t shift_bits = shift % BOTAN_MP_WORD_BITS;

bigint_shr1(m_data.mutable_data(), m_data.size(), shift_words, shift_bits);
bigint_shr1(m_data.mutable_data(), m_data.size(), shift);

if(is_negative() && is_zero()) {
set_sign(Positive);
Expand Down
10 changes: 4 additions & 6 deletions src/lib/math/bigint/big_ops3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,11 @@ word operator%(const BigInt& n, word mod) {
* Left Shift Operator
*/
BigInt operator<<(const BigInt& x, size_t shift) {
const size_t shift_words = shift / BOTAN_MP_WORD_BITS, shift_bits = shift % BOTAN_MP_WORD_BITS;

const size_t x_sw = x.sig_words();

BigInt y = BigInt::with_capacity(x_sw + shift_words + (shift_bits ? 1 : 0));
bigint_shl2(y.mutable_data(), x.data(), x_sw, shift_words, shift_bits);
const size_t new_size = x_sw + (shift + BOTAN_MP_WORD_BITS - 1) / BOTAN_MP_WORD_BITS;
BigInt y = BigInt::with_capacity(new_size);
bigint_shl2(y.mutable_data(), x.data(), x_sw, shift);
y.set_sign(x.sign());
return y;
}
Expand All @@ -187,15 +186,14 @@ BigInt operator<<(const BigInt& x, size_t shift) {
*/
BigInt operator>>(const BigInt& x, size_t shift) {
const size_t shift_words = shift / BOTAN_MP_WORD_BITS;
const size_t shift_bits = shift % BOTAN_MP_WORD_BITS;
const size_t x_sw = x.sig_words();

if(shift_words >= x_sw) {
return BigInt::zero();
}

BigInt y = BigInt::with_capacity(x_sw - shift_words);
bigint_shr2(y.mutable_data(), x.data(), x_sw, shift_words, shift_bits);
bigint_shr2(y.mutable_data(), x.data(), x_sw, shift);

if(x.is_negative() && y.is_zero()) {
y.set_sign(BigInt::Positive);
Expand Down
4 changes: 2 additions & 2 deletions src/lib/math/bigint/bigint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,14 +433,14 @@ void BigInt::ct_cond_add(bool predicate, const BigInt& value) {
void BigInt::ct_shift_left(size_t shift) {
auto shl_bit = [](const BigInt& a, BigInt& result) {
BOTAN_DEBUG_ASSERT(a.size() + 1 == result.size());
bigint_shl2(result.mutable_data(), a.data(), a.size(), 0, 1);
bigint_shl2(result.mutable_data(), a.data(), a.size(), 1);
// shl2 may have shifted a bit into the next word, which must be dropped
clear_mem(result.mutable_data() + result.size() - 1, 1);
};

auto shl_word = [](const BigInt& a, BigInt& result) {
// the most significant word is not copied, aka. shifted out
bigint_shl2(result.mutable_data(), a.data(), a.size() - 1 /* ignore msw */, 1, 0);
bigint_shl2(result.mutable_data(), a.data(), a.size() - 1 /* ignore msw */, BOTAN_MP_WORD_BITS);
// we left-shifted by a full word, the least significant word must be zero'ed
clear_mem(result.mutable_data(), 1);
};
Expand Down
2 changes: 1 addition & 1 deletion src/lib/math/bigint/divide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ void vartime_divide(const BigInt& x, const BigInt& y_arg, BigInt& q_out, BigInt&

word qjt = bigint_divop_vartime(x_j0, x_j1, y_t0);

qjt = CT::Mask<word>::is_equal(x_j0, y_t0).select(MP_WORD_MAX, qjt);
qjt = CT::Mask<word>::is_equal(x_j0, y_t0).select(WordInfo<word>::max, qjt);

// Per HAC 14.23, this operation is required at most twice
qjt -= division_check(qjt, y_t0, y_t1, x_j0, x_j1, x_j2);
Expand Down
28 changes: 17 additions & 11 deletions src/lib/math/mp/mp_asmi.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,30 @@ template <typename T>
concept WordType = (std::same_as<T, uint32_t> || std::same_as<T, uint64_t>);

template <WordType W>
struct DwordType {};
struct WordInfo {};

template <>
struct DwordType<uint32_t> {
struct WordInfo<uint32_t> {
public:
typedef uint64_t type;
static const bool is_native = true;
static const constexpr size_t bits = 32;
static const constexpr uint32_t max = 0xFFFFFFFF;

typedef uint64_t dword;
static const constexpr bool dword_is_native = true;
};

template <>
struct DwordType<uint64_t> {
struct WordInfo<uint64_t> {
public:
static const constexpr size_t bits = 64;
static const constexpr uint64_t max = 0xFFFFFFFFFFFFFFFF;

#if defined(BOTAN_TARGET_HAS_NATIVE_UINT128)
typedef uint128_t type;
static const bool is_native = true;
typedef uint128_t dword;
static const constexpr bool dword_is_native = true;
#else
typedef donna128 type;
static const bool is_native = false;
typedef donna128 dword;
static const constexpr bool dword_is_native = false;
#endif
};

Expand Down Expand Up @@ -90,7 +96,7 @@ inline constexpr auto word_madd2(W a, W b, W* c) -> W {
}
#endif

typedef typename DwordType<W>::type dword;
typedef typename WordInfo<W>::dword dword;
const dword s = dword(a) * b + *c;
*c = static_cast<W>(s >> (sizeof(W) * 8));
return static_cast<W>(s);
Expand Down Expand Up @@ -139,7 +145,7 @@ inline constexpr auto word_madd3(W a, W b, W c, W* d) -> W {
}
#endif

typedef typename DwordType<W>::type dword;
typedef typename WordInfo<W>::dword dword;
const dword s = dword(a) * b + c + *d;
*d = static_cast<W>(s >> (sizeof(W) * 8));
return static_cast<W>(s);
Expand Down
43 changes: 25 additions & 18 deletions src/lib/math/mp/mp_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

namespace Botan {

const word MP_WORD_MAX = ~static_cast<word>(0);

/*
* If cond == 0, does nothing.
* If cond > 0, swaps x[0:size] with y[0:size]
Expand Down Expand Up @@ -396,12 +394,15 @@ inline constexpr auto bigint_sub_abs(W z[], const W x[], const W y[], size_t N,
* Shift Operations
*/
template <WordType W>
inline constexpr void bigint_shl1(W x[], size_t x_size, size_t x_words, size_t word_shift, size_t bit_shift) {
inline constexpr void bigint_shl1(W x[], size_t x_size, size_t x_words, size_t shift) {
const size_t word_shift = shift / WordInfo<W>::bits;
const size_t bit_shift = shift % WordInfo<W>::bits;

copy_mem(x + word_shift, x, x_words);
clear_mem(x, word_shift);

const auto carry_mask = CT::Mask<W>::expand(bit_shift);
const W carry_shift = carry_mask.if_set_return(sizeof(W) * 8 - bit_shift);
const W carry_shift = carry_mask.if_set_return(WordInfo<W>::bits - bit_shift);

W carry = 0;
for(size_t i = word_shift; i != x_size; ++i) {
Expand All @@ -412,7 +413,10 @@ inline constexpr void bigint_shl1(W x[], size_t x_size, size_t x_words, size_t w
}

template <WordType W>
inline constexpr void bigint_shr1(W x[], size_t x_size, size_t word_shift, size_t bit_shift) {
inline constexpr void bigint_shr1(W x[], size_t x_size, size_t shift) {
const size_t word_shift = shift / WordInfo<W>::bits;
const size_t bit_shift = shift % WordInfo<W>::bits;

const size_t top = x_size >= word_shift ? (x_size - word_shift) : 0;

if(top > 0) {
Expand All @@ -421,7 +425,7 @@ inline constexpr void bigint_shr1(W x[], size_t x_size, size_t word_shift, size_
clear_mem(x + top, std::min(word_shift, x_size));

const auto carry_mask = CT::Mask<W>::expand(bit_shift);
const W carry_shift = carry_mask.if_set_return(sizeof(W) * 8 - bit_shift);
const W carry_shift = carry_mask.if_set_return(WordInfo<W>::bits - bit_shift);

W carry = 0;

Expand All @@ -433,11 +437,14 @@ inline constexpr void bigint_shr1(W x[], size_t x_size, size_t word_shift, size_
}

template <WordType W>
inline constexpr void bigint_shl2(W y[], const W x[], size_t x_size, size_t word_shift, size_t bit_shift) {
inline constexpr void bigint_shl2(W y[], const W x[], size_t x_size, size_t shift) {
const size_t word_shift = shift / WordInfo<W>::bits;
const size_t bit_shift = shift % WordInfo<W>::bits;

copy_mem(y + word_shift, x, x_size);

const auto carry_mask = CT::Mask<W>::expand(bit_shift);
const W carry_shift = carry_mask.if_set_return(sizeof(W) * 8 - bit_shift);
const W carry_shift = carry_mask.if_set_return(WordInfo<W>::bits - bit_shift);

W carry = 0;
for(size_t i = word_shift; i != x_size + word_shift + 1; ++i) {
Expand All @@ -448,15 +455,17 @@ inline constexpr void bigint_shl2(W y[], const W x[], size_t x_size, size_t word
}

template <WordType W>
inline constexpr void bigint_shr2(W y[], const W x[], size_t x_size, size_t word_shift, size_t bit_shift) {
inline constexpr void bigint_shr2(W y[], const W x[], size_t x_size, size_t shift) {
const size_t word_shift = shift / WordInfo<W>::bits;
const size_t bit_shift = shift % WordInfo<W>::bits;
const size_t new_size = x_size < word_shift ? 0 : (x_size - word_shift);

if(new_size > 0) {
copy_mem(y, x + word_shift, new_size);
}

const auto carry_mask = CT::Mask<W>::expand(bit_shift);
const W carry_shift = carry_mask.if_set_return(sizeof(W) * 8 - bit_shift);
const W carry_shift = carry_mask.if_set_return(WordInfo<W>::bits - bit_shift);

W carry = 0;
for(size_t i = new_size; i > 0; --i) {
Expand Down Expand Up @@ -701,22 +710,20 @@ inline constexpr auto bigint_divop_vartime(W n1, W n0, W d) -> W {
throw Invalid_Argument("bigint_divop_vartime divide by zero");
}

constexpr const size_t W_bits = sizeof(W) * 8;

if constexpr(DwordType<W>::is_native) {
typename DwordType<W>::type n = n1;
n <<= W_bits;
if constexpr(WordInfo<W>::dword_is_native) {
typename WordInfo<W>::dword n = n1;
n <<= WordInfo<W>::bits;
n |= n0;
return static_cast<W>(n / d);
} else {
W high = n1 % d;
W quotient = 0;

for(size_t i = 0; i != W_bits; ++i) {
const W high_top_bit = high >> (W_bits - 1);
for(size_t i = 0; i != WordInfo<W>::bits; ++i) {
const W high_top_bit = high >> (WordInfo<W>::bits - 1);

high <<= 1;
high |= (n0 >> (W_bits - 1 - i)) & 1;
high |= (n0 >> (WordInfo<W>::bits - 1 - i)) & 1;
quotient <<= 1;

if(high_top_bit || high >= d) {
Expand Down
6 changes: 3 additions & 3 deletions src/lib/math/numbertheory/mod_inv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ BigInt inverse_mod_odd_modulus(const BigInt& n, const BigInt& mod) {
// compute (mod + 1) / 2 which [because mod is odd] is equal to
// (mod / 2) + 1
copy_mem(mp1o2, mod.data(), std::min(mod.size(), mod_words));
bigint_shr1(mp1o2, mod_words, 0, 1);
bigint_shr1(mp1o2, mod_words, 1);
word carry = bigint_add2_nc(mp1o2, mod_words, u_w, 1);
BOTAN_ASSERT_NOMSG(carry == 0);

Expand All @@ -81,7 +81,7 @@ BigInt inverse_mod_odd_modulus(const BigInt& n, const BigInt& mod) {
bigint_cnd_swap(underflow, u_w, v_w, mod_words);

// a >>= 1
bigint_shr1(a_w, mod_words, 0, 1);
bigint_shr1(a_w, mod_words, 1);

//if(odd_a) u -= v;
word borrow = bigint_cnd_sub(odd_a, u_w, v_w, mod_words);
Expand All @@ -92,7 +92,7 @@ BigInt inverse_mod_odd_modulus(const BigInt& n, const BigInt& mod) {
const word odd_u = u_w[0] & 1;

// u >>= 1
bigint_shr1(u_w, mod_words, 0, 1);
bigint_shr1(u_w, mod_words, 1);

//if(odd_u) u += mp1o2;
bigint_cnd_add(odd_u, u_w, mp1o2, mod_words);
Expand Down
6 changes: 3 additions & 3 deletions src/lib/math/numbertheory/monty.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ word monty_inverse(word a) {
word b = 1;
word r = 0;

for(size_t i = 0; i != BOTAN_MP_WORD_BITS; ++i) {
for(size_t i = 0; i != WordInfo<word>::bits; ++i) {
const word bi = b % 2;
r >>= 1;
r += bi << (BOTAN_MP_WORD_BITS - 1);
r += bi << (WordInfo<word>::bits - 1);

b -= a * bi;
b >>= 1;
}

// Now invert in addition space
r = (MP_WORD_MAX - r) + 1;
r = (WordInfo<word>::max - r) + 1;

return r;
}
Expand Down
7 changes: 4 additions & 3 deletions src/lib/math/numbertheory/nistp_redc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void redc_p521(BigInt& x, secure_vector<word>& ws) {
}

clear_mem(ws.data(), ws.size());
bigint_shr2(ws.data(), x.data(), std::min(x.size(), 2 * p_words), p_full_words, p_top_bits);
bigint_shr2(ws.data(), x.data(), std::min(x.size(), 2 * p_words), 521);

x.mask_bits(521);
x.grow_to(p_words);
Expand All @@ -81,11 +81,12 @@ void redc_p521(BigInt& x, secure_vector<word>& ws) {
*/
const auto bit_522_set = CT::Mask<word>::expand(top_word >> p_top_bits);

word and_512 = MP_WORD_MAX;
const word max = WordInfo<word>::max;
word and_512 = max;
for(size_t i = 0; i != p_full_words; ++i) {
and_512 &= x.word_at(i);
}
const auto all_512_low_bits_set = CT::Mask<word>::is_equal(and_512, MP_WORD_MAX);
const auto all_512_low_bits_set = CT::Mask<word>::is_equal(and_512, max);
const auto has_p521_top_word = CT::Mask<word>::is_equal(top_word, 0x1FF);
const auto is_p521 = all_512_low_bits_set & has_p521_top_word;

Expand Down
4 changes: 2 additions & 2 deletions src/lib/math/numbertheory/numthry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,11 @@ BigInt gcd(const BigInt& a, const BigInt& b) {
factors_of_two += (u_is_even & v_is_even).if_set_return(1);

// remove one factor of 2, if u is even
bigint_shr2(tmp.mutable_data(), u.data(), sz, 0, 1);
bigint_shr2(tmp.mutable_data(), u.data(), sz, 1);
u.ct_cond_assign(u_is_even.as_bool(), tmp);

// remove one factor of 2, if v is even
bigint_shr2(tmp.mutable_data(), v.data(), sz, 0, 1);
bigint_shr2(tmp.mutable_data(), v.data(), sz, 1);
v.ct_cond_assign(v_is_even.as_bool(), tmp);
}

Expand Down
5 changes: 1 addition & 4 deletions src/lib/pubkey/curve448/curve448_scalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,8 @@ auto div_mod_2_446(std::span<const word, S> x) {
// Clear the two most significant bits
r[Scalar448::WORDS - 1] &= ~(word(0b11) << (sizeof(word) * 8 - 2));

constexpr size_t word_shift = 446 / (sizeof(word) * 8);
constexpr size_t bit_shift = 446 % (sizeof(word) * 8);

std::array<word, S - Scalar448::WORDS + 1> q;
bigint_shr2(q.data(), x.data(), x.size(), word_shift, bit_shift);
bigint_shr2(q.data(), x.data(), x.size(), 446);

return std::make_pair(q, r);
}
Expand Down
6 changes: 2 additions & 4 deletions src/lib/utils/mem_pool/mem_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,12 @@ class BitMap final {
private:
#if defined(BOTAN_ENABLE_DEBUG_ASSERTS)
using bitmask_type = uint8_t;

enum { BITMASK_BITS = 8 };
#else
using bitmask_type = word;

enum { BITMASK_BITS = BOTAN_MP_WORD_BITS };
#endif

static const size_t BITMASK_BITS = sizeof(bitmask_type) * 8;

size_t m_len;
bitmask_type m_main_mask;
bitmask_type m_last_mask;
Expand Down
Loading

0 comments on commit aa12651

Please sign in to comment.