Skip to content
Merged

sqrt #94

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
348 changes: 178 additions & 170 deletions .gas-snapshot

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions src/concrete/DecimalFloat.sol
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ contract DecimalFloat {
return a.pow(b, LibDecimalFloat.LOG_TABLES_ADDRESS);
}

/// Exposes `LibDecimalFloat.sqrt` for offchain use.
/// @param a The float to take the square root of.
/// @return The square root of the float.
function sqrt(Float a) external view returns (Float) {
return a.sqrt(LibDecimalFloat.LOG_TABLES_ADDRESS);
}

/// Exposes `LibDecimalFloat.min` for offchain use.
/// @param a The first float to compare.
/// @param b The second float to compare.
Expand Down
24 changes: 24 additions & 0 deletions src/lib/LibDecimalFloat.sol
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ library LibDecimalFloat {
/// A one valued float.
Float constant FLOAT_ONE = Float.wrap(bytes32(uint256(1)));

/// A half valued float.
// slither-disable-next-line too-many-digits
Float constant FLOAT_HALF =
Float.wrap(bytes32(uint256(0xffffffff00000000000000000000000000000000000000000000000000000005)));

/// A two valued float.
Float constant FLOAT_TWO = Float.wrap(bytes32(uint256(0x02)));

/// Largest possible positive value.
/// type(int224).max, type(int32).max
Float constant FLOAT_MAX_POSITIVE_VALUE =
Expand Down Expand Up @@ -691,6 +699,22 @@ library LibDecimalFloat {
return c;
}

/// sqrt a = a ^ 0.5
///
/// Due to the inaccuracies of log10 and power10, this is not perfectly
/// accurate, a round trip like sqrt(x)^2 will typically be within half a
/// percent or less of the original value, but this can vary depending on
/// the input values.
///
/// Doesn't lose precision due to the exponent, for a wide range of
/// exponents.
/// @param a The float to take the square root of.
/// @param tablesDataContract The address of the contract containing the
/// logarithm tables.
function sqrt(Float a, address tablesDataContract) internal view returns (Float) {
return pow(a, FLOAT_HALF, tablesDataContract);
}

/// Returns the minimum of two values.
/// Convenience for `a < b ? a : b`.
/// @param a The first float to compare.
Expand Down
24 changes: 12 additions & 12 deletions test/src/concrete/DecimalFloat.pow.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@ import {DecimalFloat} from "src/concrete/DecimalFloat.sol";
contract DecimalFloatPowTest is LogTest {
using LibDecimalFloat for Float;

function powExternal(Float a, Float b) external returns (Float) {
return a.pow(b, logTables());
function powExternal(Float a, Float b) external view returns (Float) {
return a.pow(b, LibDecimalFloat.LOG_TABLES_ADDRESS);
}

// function testPowDeployed(Float a, Float b) external {
// DecimalFloat deployed = new DecimalFloat();
function testPowDeployed(Float a, Float b) external {
DecimalFloat deployed = new DecimalFloat();

// try this.powExternal(a, b) returns (Float c) {
// Float deployedC = deployed.pow(a, b);
try this.powExternal(a, b) returns (Float c) {
Float deployedC = deployed.pow(a, b);

// assertEq(Float.unwrap(c), Float.unwrap(deployedC));
// } catch (bytes memory err) {
// vm.expectRevert(err);
// deployed.pow(a, b);
// }
// }
assertEq(Float.unwrap(c), Float.unwrap(deployedC));
} catch (bytes memory err) {
vm.expectRevert(err);
deployed.pow(a, b);
}
}
}
27 changes: 27 additions & 0 deletions test/src/concrete/DecimalFloat.sqrt.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-License-Identifier: CAL
pragma solidity =0.8.25;

import {LibDecimalFloat, Float} from "src/lib/LibDecimalFloat.sol";
import {LogTest} from "test/abstract/LogTest.sol";
import {DecimalFloat} from "src/concrete/DecimalFloat.sol";

contract DecimalFloatSqrtTest is LogTest {
using LibDecimalFloat for Float;

function sqrtExternal(Float a) external view returns (Float) {
return a.sqrt(LibDecimalFloat.LOG_TABLES_ADDRESS);
}

function testSqrtDeployed(Float a) external {
DecimalFloat deployed = new DecimalFloat();

try this.sqrtExternal(a) returns (Float c) {
Float deployedC = deployed.sqrt(a);

assertEq(Float.unwrap(c), Float.unwrap(deployedC));
} catch (bytes memory err) {
vm.expectRevert(err);
deployed.sqrt(a);
}
}
}
12 changes: 12 additions & 0 deletions test/src/lib/LibDecimalFloat.constants.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,16 @@ contract LibDecimalFloatConstantsTest is Test {
Float expected = LibDecimalFloat.packLossless(1, 0);
assertEq(Float.unwrap(one), Float.unwrap(expected));
}

function testFloatHalf() external pure {
Float half = LibDecimalFloat.FLOAT_HALF;
Float expected = LibDecimalFloat.packLossless(5, -1);
assertEq(Float.unwrap(half), Float.unwrap(expected));
}

function testFloatTwo() external pure {
Float two = LibDecimalFloat.FLOAT_TWO;
Float expected = LibDecimalFloat.packLossless(2, 0);
assertEq(Float.unwrap(two), Float.unwrap(expected));
}
}
2 changes: 1 addition & 1 deletion test/src/lib/LibDecimalFloat.pow.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ contract LibDecimalFloatPowTest is LogTest {

Float roundTrip = c.pow(b.inv(), tables);

Float diff = a.div(roundTrip).sub(LibDecimalFloat.packLossless(1, 0)).abs();
Float diff = a.div(roundTrip).sub(LibDecimalFloat.FLOAT_ONE).abs();

assertTrue(!diff.gt(diffLimit()), "diff");
}
Expand Down
92 changes: 92 additions & 0 deletions test/src/lib/LibDecimalFloat.sqrt.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// SPDX-License-Identifier: CAL
pragma solidity =0.8.25;

import {LogTest} from "../../abstract/LogTest.sol";

import {LibDecimalFloat, Float} from "src/lib/LibDecimalFloat.sol";
import {ZeroNegativePower, Log10Negative} from "src/error/ErrDecimalFloat.sol";
import {LibDecimalFloatImplementation} from "src/lib/implementation/LibDecimalFloatImplementation.sol";
import {console2} from "forge-std/Test.sol";

contract LibDecimalFloatSqrtTest is LogTest {
using LibDecimalFloat for Float;

function diffLimit() internal pure returns (Float) {
return LibDecimalFloat.packLossless(94, -3);
}

function sqrtExternal(Float a, address tables) external view returns (Float) {
return a.sqrt(tables);
}

function checkSqrt(
int256 signedCoefficient,
int256 exponent,
int256 expectedSignedCoefficient,
int256 expectedExponent
) internal {
Float a = LibDecimalFloat.packLossless(signedCoefficient, exponent);
address tables = logTables();
uint256 beforeGas = gasleft();
Float c = a.sqrt(tables);
uint256 afterGas = gasleft();
console2.log("Gas used:", beforeGas - afterGas);
console2.logInt(signedCoefficient);
console2.logInt(exponent);
(int256 actualSignedCoefficient, int256 actualExponent) = c.unpack();
assertEq(actualSignedCoefficient, expectedSignedCoefficient, "signedCoefficient");
assertEq(actualExponent, expectedExponent, "exponent");
}

function checkRoundTrip(int256 signedCoefficient, int256 exponent) internal {
Float a = LibDecimalFloat.packLossless(signedCoefficient, exponent);
address tables = logTables();
Float c = a.sqrt(tables);
Float roundTrip = c.pow(LibDecimalFloat.FLOAT_TWO, tables);

Float diff = a.div(roundTrip).sub(LibDecimalFloat.FLOAT_ONE).abs();

assertTrue(diff.lte(diffLimit()), "Round trip sqrt diff too high");
}

function testSqrt() external {
checkSqrt(0, 0, 0, 0);
checkSqrt(2, 0, 1415, -3);
checkSqrt(4, 0, 2e3, -3);
checkSqrt(16, 0, 399950000000000000000000000000000000000000, -41);
}

function testSqrtNegative(Float a) external {
// We can't simply minus 0 to get a negative base.
vm.assume(!a.isZero());

if (a.gt(LibDecimalFloat.FLOAT_ZERO)) {
a = a.minus();
}

address tables = logTables();

(int256 signedCoefficient, int256 exponent) = a.unpack();
(int256 signedCoefficientNormalized, int256 exponentNormalized) =
LibDecimalFloatImplementation.normalize(signedCoefficient, exponent);
vm.expectRevert(abi.encodeWithSelector(Log10Negative.selector, signedCoefficientNormalized, exponentNormalized));
this.sqrtExternal(a, tables);
}

function testSqrtRoundTrip() external {
checkRoundTrip(2, 0);
checkRoundTrip(4, 0);
checkRoundTrip(16, 0);
checkRoundTrip(25, 0);
checkRoundTrip(100, 0);
checkRoundTrip(10000, 0);
checkRoundTrip(1000000, 0);
checkRoundTrip(100000000, 0);
}

function testRoundTripFuzz(int224 signedCoefficient, int32 exponent) external {
signedCoefficient = int224(bound(signedCoefficient, 1, type(int224).max));
exponent = int32(bound(exponent, type(int16).min, type(int16).max));
checkRoundTrip(signedCoefficient, exponent);
}
}
Loading