Skip to content

Commit

Permalink
Merge pull request #346 from z0r0z/patch-1
Browse files Browse the repository at this point in the history
more efficient sqrt
  • Loading branch information
matthewlilley committed May 9, 2022
2 parents 0f85efd + 00a89a7 commit 8144e4a
Showing 1 changed file with 45 additions and 37 deletions.
82 changes: 45 additions & 37 deletions contracts/libraries/TridentMath.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,49 @@ library TridentMath {
/// @dev Modified from Solmate (https://github.com/Rari-Capital/solmate/blob/main/src/utils/FixedPointMathLib.sol)
function sqrt(uint256 x) internal pure returns (uint256 z) {
assembly {
// Start off with z at 1.
z := 1

// Used below to help find a nearby power of 2.
// This segment is to get a reasonable initial estimate for the Babylonian method.
// If the initial estimate is bad, the number of correct bits increases ~linearly
// each iteration instead of ~quadratically.
// The idea is to get z*z*y within a small factor of x.
// More iterations here gets y in a tighter range. Currently, we will have
// y in [256, 256*2^16). We ensure y>= 256 so that the relative difference
// between y and y+1 is small. If x < 256 this is not possible, but those cases
// are easy enough to verify exhaustively.
z := 181 // The 'correct' value is 1, but this saves a multiply later
let y := x

// Find the lowest power of 2 that is at least sqrt(x).
if iszero(lt(y, 0x100000000000000000000000000000000)) {
y := shr(128, y) // Like dividing by 2 ** 128.
z := shl(64, z) // Like multiplying by 2 ** 64.
}
if iszero(lt(y, 0x10000000000000000)) {
y := shr(64, y) // Like dividing by 2 ** 64.
z := shl(32, z) // Like multiplying by 2 ** 32.
}
if iszero(lt(y, 0x100000000)) {
y := shr(32, y) // Like dividing by 2 ** 32.
z := shl(16, z) // Like multiplying by 2 ** 16.
// Note that we check y>= 2^(k + 8) but shift right by k bits each branch,
// this is to ensure that if x >= 256, then y >= 256.
if iszero(lt(y, 0x10000000000000000000000000000000000)) {
y := shr(128, y)
z := shl(64, z)
}
if iszero(lt(y, 0x10000)) {
y := shr(16, y) // Like dividing by 2 ** 16.
z := shl(8, z) // Like multiplying by 2 ** 8.
if iszero(lt(y, 0x1000000000000000000)) {
y := shr(64, y)
z := shl(32, z)
}
if iszero(lt(y, 0x100)) {
y := shr(8, y) // Like dividing by 2 ** 8.
z := shl(4, z) // Like multiplying by 2 ** 4.
if iszero(lt(y, 0x10000000000)) {
y := shr(32, y)
z := shl(16, z)
}
if iszero(lt(y, 0x10)) {
y := shr(4, y) // Like dividing by 2 ** 4.
z := shl(2, z) // Like multiplying by 2 ** 2.
}
if iszero(lt(y, 0x8)) {
// Equivalent to 2 ** z.
z := shl(1, z)
if iszero(lt(y, 0x1000000)) {
y := shr(16, y)
z := shl(8, z)
}
// Now, z*z*y <= x < z*z*(y+1), and y <= 2^(16+8),
// and either y >= 256, or x < 256.
// Correctness can be checked exhaustively for x < 256, so we assume y >= 256.
// Then z*sqrt(y) is within sqrt(257)/sqrt(256) of sqrt(x), or about 20bps.

// For s in the range [1/256, 256], the estimate f(s) = (181/1024) * (s+1)
// is in the range (1/2.84 * sqrt(s), 2.84 * sqrt(s)), with largest error when s=1
// and when s = 256 or 1/256. Since y is in [256, 256*2^16), let a = y/65536, so
// that a is in [1/256, 256). Then we can estimate sqrt(y) as
// sqrt(65536) * 181/1024 * (a + 1) = 181/4 * (y + 65536)/65536 = 181 * (y + 65536)/2^18
// There is no overflow risk here since y < 2^136 after the first branch above.
z := shr(18, mul(z, add(y, 65536))) // A multiply is saved from the initial z := 181

// Shifting right by 1 is like dividing by 2.
// Given the worst case multiplicative error of 2.84 above, 7 iterations should be enough.
// Possibly with a quadratic/cubic polynomial above we could get 4-6.
z := shr(1, add(z, div(x, z)))
z := shr(1, add(z, div(x, z)))
z := shr(1, add(z, div(x, z)))
Expand All @@ -52,12 +58,14 @@ library TridentMath {
z := shr(1, add(z, div(x, z)))
z := shr(1, add(z, div(x, z)))

// Compute a rounded down version of z.
let zRoundDown := div(x, z)

// If zRoundDown is smaller, use it.
if lt(zRoundDown, z) {
z := zRoundDown
// See https://en.wikipedia.org/wiki/Integer_square_root#Using_only_integer_division.
// If x+1 is a perfect square, the Babylonian method cycles between
// floor(sqrt(x)) and ceil(sqrt(x)). This check ensures we return floor.
// Since this case is rare, we choose to save gas on the assignment and
// repeat division in the rare case.
// If you don't care whether floor or ceil is returned, you can skip this.
if lt(div(x, z), z) {
z := div(x, z)
}
}
}
Expand Down

0 comments on commit 8144e4a

Please sign in to comment.