diff --git a/crates/precompile/src/bn128.rs b/crates/precompile/src/bn128.rs index 49d6d3f654..ce08da7994 100644 --- a/crates/precompile/src/bn128.rs +++ b/crates/precompile/src/bn128.rs @@ -3,31 +3,22 @@ use crate::{ Address, Error, Precompile, PrecompileResult, PrecompileWithAddress, }; use bn::{AffineG1, AffineG2, Fq, Fq2, Group, Gt, G1, G2}; -use revm_primitives::Bytes; pub mod add { use super::*; const ADDRESS: Address = crate::u64_to_address(6); + pub const ISTANBUL_ADD_GAS_COST: u64 = 150; pub const ISTANBUL: PrecompileWithAddress = PrecompileWithAddress( ADDRESS, - Precompile::Standard(|input, gas_limit| { - if 150 > gas_limit { - return Err(Error::OutOfGas); - } - Ok((150, super::run_add(input)?)) - }), + Precompile::Standard(|input, gas_limit| run_add(input, ISTANBUL_ADD_GAS_COST, gas_limit)), ); + pub const BYZANTIUM_ADD_GAS_COST: u64 = 500; pub const BYZANTIUM: PrecompileWithAddress = PrecompileWithAddress( ADDRESS, - Precompile::Standard(|input, gas_limit| { - if 500 > gas_limit { - return Err(Error::OutOfGas); - } - Ok((500, super::run_add(input)?)) - }), + Precompile::Standard(|input, gas_limit| run_add(input, BYZANTIUM_ADD_GAS_COST, gas_limit)), ); } @@ -36,24 +27,16 @@ pub mod mul { const ADDRESS: Address = crate::u64_to_address(7); + pub const ISTANBUL_MUL_GAS_COST: u64 = 6_000; pub const ISTANBUL: PrecompileWithAddress = PrecompileWithAddress( ADDRESS, - Precompile::Standard(|input, gas_limit| { - if 6_000 > gas_limit { - return Err(Error::OutOfGas); - } - Ok((6_000, super::run_mul(input)?)) - }), + Precompile::Standard(|input, gas_limit| run_mul(input, ISTANBUL_MUL_GAS_COST, gas_limit)), ); + pub const BYZANTIUM_MUL_GAS_COST: u64 = 40_000; pub const BYZANTIUM: PrecompileWithAddress = PrecompileWithAddress( ADDRESS, - Precompile::Standard(|input, gas_limit| { - if 40_000 > gas_limit { - return Err(Error::OutOfGas); - } - Ok((40_000, super::run_mul(input)?)) - }), + Precompile::Standard(|input, gas_limit| run_mul(input, BYZANTIUM_MUL_GAS_COST, gas_limit)), ); } @@ -67,7 +50,7 @@ pub mod pair { pub const ISTANBUL: PrecompileWithAddress = PrecompileWithAddress( ADDRESS, Precompile::Standard(|input, gas_limit| { - super::run_pair( + run_pair( input, ISTANBUL_PAIR_PER_POINT, ISTANBUL_PAIR_BASE, @@ -81,7 +64,7 @@ pub mod pair { pub const BYZANTIUM: PrecompileWithAddress = PrecompileWithAddress( ADDRESS, Precompile::Standard(|input, gas_limit| { - super::run_pair( + run_pair( input, BYZANTIUM_PAIR_PER_POINT, BYZANTIUM_PAIR_BASE, @@ -137,7 +120,11 @@ pub fn new_g1_point(px: Fq, py: Fq) -> Result { } } -pub fn run_add(input: &[u8]) -> Result { +pub fn run_add(input: &[u8], gas_cost: u64, gas_limit: u64) -> PrecompileResult { + if gas_cost > gas_limit { + return Err(Error::OutOfGas); + } + let input = right_pad::(input); let p1 = read_point(&input[..64])?; @@ -148,10 +135,14 @@ pub fn run_add(input: &[u8]) -> Result { sum.x().to_big_endian(&mut output[..32]).unwrap(); sum.y().to_big_endian(&mut output[32..]).unwrap(); } - Ok(output.into()) + Ok((gas_cost, output.into())) } -pub fn run_mul(input: &[u8]) -> Result { +pub fn run_mul(input: &[u8], gas_cost: u64, gas_limit: u64) -> PrecompileResult { + if gas_cost > gas_limit { + return Err(Error::OutOfGas); + } + let input = right_pad::(input); let p = read_point(&input[..64])?; @@ -164,7 +155,7 @@ pub fn run_mul(input: &[u8]) -> Result { mul.x().to_big_endian(&mut output[..32]).unwrap(); mul.y().to_big_endian(&mut output[32..]).unwrap(); } - Ok(output.into()) + Ok((gas_cost, output.into())) } pub fn run_pair( @@ -223,10 +214,12 @@ pub fn run_pair( Ok((gas_used, bool_to_bytes32(success))) } -/* #[cfg(test)] mod tests { - use crate::test_utils::new_context; + use crate::bn128::add::BYZANTIUM_ADD_GAS_COST; + use crate::bn128::mul::BYZANTIUM_MUL_GAS_COST; + use crate::bn128::pair::{BYZANTIUM_PAIR_BASE, BYZANTIUM_PAIR_PER_POINT}; + use revm_primitives::hex; use super::*; @@ -247,9 +240,7 @@ mod tests { ) .unwrap(); - let res = Bn128Add::::run(&input, 500, &new_context(), false) - .unwrap() - .output; + let (_, res) = run_add(&input, BYZANTIUM_ADD_GAS_COST, 500).unwrap(); assert_eq!(res, expected); // zero sum test @@ -268,9 +259,7 @@ mod tests { ) .unwrap(); - let res = Bn128Add::::run(&input, 500, &new_context(), false) - .unwrap() - .output; + let (_, res) = run_add(&input, BYZANTIUM_ADD_GAS_COST, 500).unwrap(); assert_eq!(res, expected); // out of gas test @@ -282,8 +271,10 @@ mod tests { 0000000000000000000000000000000000000000000000000000000000000000", ) .unwrap(); - let res = Bn128Add::::run(&input, 499, &new_context(), false); - assert!(matches!(res, Err(Return::OutOfGas))); + + let res = run_add(&input, BYZANTIUM_ADD_GAS_COST, 499); + println!("{:?}", res); + assert!(matches!(res, Err(Error::OutOfGas))); // no input test let input = [0u8; 0]; @@ -294,9 +285,7 @@ mod tests { ) .unwrap(); - let res = Bn128Add::::run(&input, 500, &new_context(), false) - .unwrap() - .output; + let (_, res) = run_add(&input, BYZANTIUM_ADD_GAS_COST, 500).unwrap(); assert_eq!(res, expected); // point not on curve fail @@ -309,11 +298,8 @@ mod tests { ) .unwrap(); - let res = Bn128Add::::run(&input, 500, &new_context(), false); - assert!(matches!( - res, - Err(Return::Other(Cow::Borrowed("ERR_BN128_INVALID_POINT"))) - )); + let res = run_add(&input, BYZANTIUM_ADD_GAS_COST, 500); + assert!(matches!(res, Err(Error::Bn128AffineGFailedToCreate))); } #[test] @@ -332,9 +318,7 @@ mod tests { ) .unwrap(); - let res = Bn128Mul::::run(&input, 40_000, &new_context(), false) - .unwrap() - .output; + let (_, res) = run_mul(&input, BYZANTIUM_MUL_GAS_COST, 40_000).unwrap(); assert_eq!(res, expected); // out of gas test @@ -345,8 +329,9 @@ mod tests { 0200000000000000000000000000000000000000000000000000000000000000", ) .unwrap(); - let res = Bn128Mul::::run(&input, 39_999, &new_context(), false); - assert!(matches!(res, Err(Return::OutOfGas))); + + let res = run_mul(&input, BYZANTIUM_MUL_GAS_COST, 39_999); + assert!(matches!(res, Err(Error::OutOfGas))); // zero multiplication test let input = hex::decode( @@ -363,9 +348,7 @@ mod tests { ) .unwrap(); - let res = Bn128Mul::::run(&input, 40_000, &new_context(), false) - .unwrap() - .output; + let (_, res) = run_mul(&input, BYZANTIUM_MUL_GAS_COST, 40_000).unwrap(); assert_eq!(res, expected); // no input test @@ -377,9 +360,7 @@ mod tests { ) .unwrap(); - let res = Bn128Mul::::run(&input, 40_000, &new_context(), false) - .unwrap() - .output; + let (_, res) = run_mul(&input, BYZANTIUM_MUL_GAS_COST, 40_000).unwrap(); assert_eq!(res, expected); // point not on curve fail @@ -391,11 +372,8 @@ mod tests { ) .unwrap(); - let res = Bn128Mul::::run(&input, 40_000, &new_context(), false); - assert!(matches!( - res, - Err(Return::Other(Cow::Borrowed("ERR_BN128_INVALID_POINT"))) - )); + let res = run_mul(&input, BYZANTIUM_MUL_GAS_COST, 40_000); + assert!(matches!(res, Err(Error::Bn128AffineGFailedToCreate))); } #[test] @@ -420,9 +398,13 @@ mod tests { hex::decode("0000000000000000000000000000000000000000000000000000000000000001") .unwrap(); - let res = Bn128Pair::::run(&input, 260_000, &new_context(), false) - .unwrap() - .output; + let (_, res) = run_pair( + &input, + BYZANTIUM_PAIR_PER_POINT, + BYZANTIUM_PAIR_BASE, + 260_000, + ) + .unwrap(); assert_eq!(res, expected); // out of gas test @@ -442,8 +424,14 @@ mod tests { 12c85ea5db8c6deb4aab71808dcb408fe3d1e7690c43d37b4ce6cc0166fa7daa", ) .unwrap(); - let res = Bn128Pair::::run(&input, 259_999, &new_context(), false); - assert!(matches!(res, Err(Return::OutOfGas))); + + let res = run_pair( + &input, + BYZANTIUM_PAIR_PER_POINT, + BYZANTIUM_PAIR_BASE, + 259_999, + ); + assert!(matches!(res, Err(Error::OutOfGas))); // no input test let input = [0u8; 0]; @@ -451,9 +439,13 @@ mod tests { hex::decode("0000000000000000000000000000000000000000000000000000000000000001") .unwrap(); - let res = Bn128Pair::::run(&input, 260_000, &new_context(), false) - .unwrap() - .output; + let (_, res) = run_pair( + &input, + BYZANTIUM_PAIR_PER_POINT, + BYZANTIUM_PAIR_BASE, + 260_000, + ) + .unwrap(); assert_eq!(res, expected); // point not on curve fail @@ -468,11 +460,13 @@ mod tests { ) .unwrap(); - let res = Bn128Pair::::run(&input, 260_000, &new_context(), false); - assert!(matches!( - res, - Err(Return::Other(Cow::Borrowed("ERR_BN128_INVALID_A"))) - )); + let res = run_pair( + &input, + BYZANTIUM_PAIR_PER_POINT, + BYZANTIUM_PAIR_BASE, + 260_000, + ); + assert!(matches!(res, Err(Error::Bn128AffineGFailedToCreate))); // invalid input length let input = hex::decode( @@ -484,11 +478,12 @@ mod tests { ) .unwrap(); - let res = Bn128Pair::::run(&input, 260_000, &new_context(), false); - assert!(matches!( - res, - Err(Return::Other(Cow::Borrowed("ERR_BN128_INVALID_LEN",))) - )); + let res = run_pair( + &input, + BYZANTIUM_PAIR_PER_POINT, + BYZANTIUM_PAIR_BASE, + 260_000, + ); + assert!(matches!(res, Err(Error::Bn128PairLength))); } } -*/