From 25788e04d90c4af926c1b8771008738b159cc282 Mon Sep 17 00:00:00 2001 From: Ryan Burn Date: Wed, 22 May 2024 10:54:35 -0700 Subject: [PATCH] feat: add host support for multiexponentiations using precomputed partition method (PROOF-831) (#131) * add more assertions * work on host version of mx code * rename * work on host version * work on host version of partition product * add tests * fill in host reduction * work on host-side code * rename * add stub for host version * fill in host version of multiexponentiation * test host multiexponentiation * impl cpu_backend * test cpu backend * reformat * rename * fixes * rework partition table accessor * add more documentation * remove unused header * add constant * clean up --- benchmark/multi_exp_pip/benchmark.m.cc | 8 +-- cbindings/BUILD | 1 + cbindings/backend.cc | 6 +-- cbindings/fixed_pedersen.t.cc | 26 +++++++-- sxt/cbindings/backend/BUILD | 2 + sxt/cbindings/backend/cpu_backend.cc | 17 +++--- sxt/cbindings/backend/gpu_backend.cc | 2 +- sxt/multiexp/pippenger2/BUILD | 8 +++ sxt/multiexp/pippenger2/constants.cc | 17 ++++++ sxt/multiexp/pippenger2/constants.h | 24 +++++++++ .../in_memory_partition_table_accessor.h | 25 +++++---- .../in_memory_partition_table_accessor.t.cc | 23 ++++++-- ..._memory_partition_table_accessor_utility.h | 4 +- ...mory_partition_table_accessor_utility.t.cc | 4 +- sxt/multiexp/pippenger2/multiexponentiation.h | 40 +++++++++++--- .../pippenger2/multiexponentiation.t.cc | 26 ++++++--- sxt/multiexp/pippenger2/partition_product.h | 53 ++++++++++++++++--- .../pippenger2/partition_product.t.cc | 25 ++++++--- sxt/multiexp/pippenger2/partition_table.h | 6 +-- sxt/multiexp/pippenger2/partition_table.t.cc | 6 +-- .../pippenger2/partition_table_accessor.h | 30 ++++++++++- sxt/multiexp/pippenger2/reduce.h | 11 ++++ sxt/multiexp/pippenger2/reduce.t.cc | 10 ++++ 23 files changed, 294 insertions(+), 80 deletions(-) create mode 100644 sxt/multiexp/pippenger2/constants.cc create mode 100644 sxt/multiexp/pippenger2/constants.h diff --git a/benchmark/multi_exp_pip/benchmark.m.cc b/benchmark/multi_exp_pip/benchmark.m.cc index ef9945f8..fed5f8c6 100644 --- a/benchmark/multi_exp_pip/benchmark.m.cc +++ b/benchmark/multi_exp_pip/benchmark.m.cc @@ -129,8 +129,8 @@ int main(int argc, char* argv[]) { // discard initial run { - auto fut = - mtxpp2::multiexponentiate(res, *accessor, element_num_bytes, exponents); + auto fut = mtxpp2::async_multiexponentiate(res, *accessor, element_num_bytes, + exponents); xens::get_scheduler().run(); } @@ -138,8 +138,8 @@ int main(int argc, char* argv[]) { double times = 0; for (unsigned i = 0; i < num_samples; ++i) { auto t1 = std::chrono::steady_clock::now(); - auto fut = - mtxpp2::multiexponentiate(res, *accessor, element_num_bytes, exponents); + auto fut = mtxpp2::async_multiexponentiate(res, *accessor, element_num_bytes, + exponents); xens::get_scheduler().run(); auto t2 = std::chrono::steady_clock::now(); auto elapse = std::chrono::duration_cast(t2 - t1); diff --git a/cbindings/BUILD b/cbindings/BUILD index 7e8f3c1d..2243d780 100644 --- a/cbindings/BUILD +++ b/cbindings/BUILD @@ -155,6 +155,7 @@ sxt_cc_component( "//sxt/cbindings/base:multiexp_handle", ], test_deps = [ + ":backend", "//sxt/base/test:unit_test", "//sxt/curve21/operation:add", "//sxt/curve21/operation:double", diff --git a/cbindings/backend.cc b/cbindings/backend.cc index 43997112..179f107a 100644 --- a/cbindings/backend.cc +++ b/cbindings/backend.cc @@ -77,11 +77,7 @@ cbnbck::computational_backend* get_backend() noexcept { //-------------------------------------------------------------------------------------------------- // reset_backend_for_testing //-------------------------------------------------------------------------------------------------- -void reset_backend_for_testing() noexcept { - SXT_RELEASE_ASSERT(backend != nullptr); - - backend = nullptr; -} +void reset_backend_for_testing() noexcept { backend = nullptr; } } // namespace sxt::cbn //-------------------------------------------------------------------------------------------------- diff --git a/cbindings/fixed_pedersen.t.cc b/cbindings/fixed_pedersen.t.cc index b4866bdf..4fabfc86 100644 --- a/cbindings/fixed_pedersen.t.cc +++ b/cbindings/fixed_pedersen.t.cc @@ -18,6 +18,7 @@ #include +#include "cbindings/backend.h" #include "sxt/base/test/unit_test.h" #include "sxt/curve21/operation/add.h" #include "sxt/curve21/operation/double.h" @@ -45,13 +46,28 @@ TEST_CASE("we can compute multi-exponentiations with a fixed set of generators") 0x456_c21, }; - const sxt_config config = {SXT_GPU_BACKEND, 0}; - REQUIRE(sxt_init(&config) == 0); + SECTION("we can compute a multiexponentiation with the gpu backend") { + cbn::reset_backend_for_testing(); + const sxt_config config = {SXT_GPU_BACKEND, 0}; + REQUIRE(sxt_init(&config) == 0); - wrapped_handle h{generators.data(), 2}; - REQUIRE(h.h != nullptr); + wrapped_handle h{generators.data(), 2}; + REQUIRE(h.h != nullptr); + + uint8_t scalars[] = {1, 0, 0, 2}; + c21t::element_p3 res; + sxt_fixed_multiexponentiation(&res, h.h, 2, 1, 2, scalars); + REQUIRE(res == generators[0] + 2 * 256 * generators[1]); + } + + SECTION("we can compute a multiexponentiation with the cpu backend") { + cbn::reset_backend_for_testing(); + const sxt_config config = {SXT_CPU_BACKEND, 0}; + REQUIRE(sxt_init(&config) == 0); + + wrapped_handle h{generators.data(), 2}; + REQUIRE(h.h != nullptr); - SECTION("we can compute a multiexponentiation") { uint8_t scalars[] = {1, 0, 0, 2}; c21t::element_p3 res; sxt_fixed_multiexponentiation(&res, h.h, 2, 1, 2, scalars); diff --git a/sxt/cbindings/backend/BUILD b/sxt/cbindings/backend/BUILD index 14ec1fa7..5967ed58 100644 --- a/sxt/cbindings/backend/BUILD +++ b/sxt/cbindings/backend/BUILD @@ -68,6 +68,7 @@ sxt_cc_component( name = "cpu_backend", impl_deps = [ "//sxt/base/error:panic", + "//sxt/cbindings/base:curve_id_utility", "//sxt/proof/transcript:transcript", "//sxt/scalar25/type:element", "//sxt/curve_bng1/operation:add", @@ -89,6 +90,7 @@ sxt_cc_component( "//sxt/execution/async:future", "//sxt/execution/schedule:scheduler", "//sxt/memory/management:managed_array", + "//sxt/multiexp/pippenger2:multiexponentiation", "//sxt/ristretto/type:compressed_element", "//sxt/ristretto/operation:compression", "//sxt/multiexp/base:exponent_sequence", diff --git a/sxt/cbindings/backend/cpu_backend.cc b/sxt/cbindings/backend/cpu_backend.cc index 0aff551c..a966ac12 100644 --- a/sxt/cbindings/backend/cpu_backend.cc +++ b/sxt/cbindings/backend/cpu_backend.cc @@ -21,6 +21,7 @@ #include "sxt/base/error/assert.h" #include "sxt/base/error/panic.h" +#include "sxt/cbindings/base/curve_id_utility.h" #include "sxt/curve21/operation/add.h" #include "sxt/curve21/operation/double.h" #include "sxt/curve21/operation/neg.h" @@ -40,6 +41,7 @@ #include "sxt/memory/management/managed_array.h" #include "sxt/multiexp/base/exponent_sequence.h" #include "sxt/multiexp/curve/multiexponentiation.h" +#include "sxt/multiexp/pippenger2/multiexponentiation.h" #include "sxt/proof/inner_product/cpu_driver.h" #include "sxt/proof/inner_product/proof_computation.h" #include "sxt/proof/inner_product/proof_descriptor.h" @@ -126,13 +128,14 @@ void cpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id const mtxpp2::partition_table_accessor_base& accessor, unsigned element_num_bytes, unsigned num_outputs, unsigned n, const uint8_t* scalars) const noexcept { - (void)res; - (void)curve_id; - (void)accessor; - (void)num_outputs; - (void)n; - (void)scalars; - baser::panic("not implemented yet"); + cbnb::switch_curve_type(curve_id, [&](std::type_identity) noexcept { + basct::span res_span{static_cast(res), num_outputs}; + basct::cspan scalars_span{scalars, element_num_bytes * num_outputs * n}; + mtxpp2::multiexponentiate(res_span, + static_cast&>(accessor), + element_num_bytes, scalars_span); + xens::get_scheduler().run(); + }); } //-------------------------------------------------------------------------------------------------- diff --git a/sxt/cbindings/backend/gpu_backend.cc b/sxt/cbindings/backend/gpu_backend.cc index d50b2e96..eb8ef447 100644 --- a/sxt/cbindings/backend/gpu_backend.cc +++ b/sxt/cbindings/backend/gpu_backend.cc @@ -168,7 +168,7 @@ void gpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id cbnb::switch_curve_type(curve_id, [&](std::type_identity) noexcept { basct::span res_span{static_cast(res), num_outputs}; basct::cspan scalars_span{scalars, element_num_bytes * num_outputs * n}; - auto fut = mtxpp2::multiexponentiate( + auto fut = mtxpp2::async_multiexponentiate( res_span, static_cast&>(accessor), element_num_bytes, scalars_span); xens::get_scheduler().run(); diff --git a/sxt/multiexp/pippenger2/BUILD b/sxt/multiexp/pippenger2/BUILD index fe13dffd..9411b56b 100644 --- a/sxt/multiexp/pippenger2/BUILD +++ b/sxt/multiexp/pippenger2/BUILD @@ -3,6 +3,11 @@ load( "sxt_cc_component", ) +sxt_cc_component( + name = "constants", + with_test = False, +) + sxt_cc_component( name = "combination", test_deps = [ @@ -41,6 +46,7 @@ sxt_cc_component( "//sxt/memory/resource:managed_device_resource", ], deps = [ + ":constants", "//sxt/algorithm/iteration:for_each", "//sxt/base/bit:iteration", "//sxt/base/bit:permutation", @@ -69,6 +75,7 @@ sxt_cc_component( "//sxt/memory/resource:managed_device_resource", ], deps = [ + ":constants", "//sxt/algorithm/iteration:for_each", "//sxt/base/bit:iteration", "//sxt/base/bit:permutation", @@ -108,6 +115,7 @@ sxt_cc_component( "//sxt/memory/resource:device_resource", ], deps = [ + ":constants", ":partition_table_accessor", "//sxt/base/container:span_utility", "//sxt/base/device:memory_utility", diff --git a/sxt/multiexp/pippenger2/constants.cc b/sxt/multiexp/pippenger2/constants.cc new file mode 100644 index 00000000..9a0f8895 --- /dev/null +++ b/sxt/multiexp/pippenger2/constants.cc @@ -0,0 +1,17 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2024-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sxt/multiexp/pippenger2/constants.h" diff --git a/sxt/multiexp/pippenger2/constants.h b/sxt/multiexp/pippenger2/constants.h new file mode 100644 index 00000000..99c9e246 --- /dev/null +++ b/sxt/multiexp/pippenger2/constants.h @@ -0,0 +1,24 @@ +/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. + * + * Copyright 2024-present Space and Time Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace sxt::mtxpp2 { +//-------------------------------------------------------------------------------------------------- +// partition_table_size_v +//-------------------------------------------------------------------------------------------------- +constexpr unsigned partition_table_size_v = 1u << 16u; +} // namespace sxt::mtxpp2 diff --git a/sxt/multiexp/pippenger2/in_memory_partition_table_accessor.h b/sxt/multiexp/pippenger2/in_memory_partition_table_accessor.h index f72eb2b1..1283ff8d 100644 --- a/sxt/multiexp/pippenger2/in_memory_partition_table_accessor.h +++ b/sxt/multiexp/pippenger2/in_memory_partition_table_accessor.h @@ -27,6 +27,7 @@ #include "sxt/base/error/panic.h" #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/pinned_resource.h" +#include "sxt/multiexp/pippenger2/constants.h" #include "sxt/multiexp/pippenger2/partition_table_accessor.h" namespace sxt::mtxpp2 { @@ -35,6 +36,7 @@ namespace sxt::mtxpp2 { //-------------------------------------------------------------------------------------------------- template class in_memory_partition_table_accessor final : public partition_table_accessor { + public: explicit in_memory_partition_table_accessor(std::string_view filename) noexcept : table_{memr::get_pinned_resource()} { @@ -54,17 +56,18 @@ class in_memory_partition_table_accessor final : public partition_table_accessor explicit in_memory_partition_table_accessor(memmg::managed_array&& table) noexcept : table_{std::move(table)} {} - void async_copy_precomputed_sums_to_device(basct::span dest, bast::raw_stream_t stream, - unsigned first) const noexcept override { - static unsigned num_entries = 1u << 16u; - SXT_DEBUG_ASSERT( - // clang-format off - table_.size() >= dest.size() + first * num_entries && - basdv::is_active_device_pointer(dest.data()) - // clang-format on - ); - basdv::async_copy_host_to_device(dest, basct::subspan(table_, first * num_entries, dest.size()), - stream); + void async_copy_to_device(basct::span dest, bast::raw_stream_t stream, + unsigned first) const noexcept override { + SXT_RELEASE_ASSERT(table_.size() >= dest.size() + first * partition_table_size_v); + SXT_DEBUG_ASSERT(basdv::is_active_device_pointer(dest.data())); + basdv::async_copy_host_to_device( + dest, basct::subspan(table_, first * partition_table_size_v, dest.size()), stream); + } + + basct::cspan host_view(std::pmr::polymorphic_allocator<> /*alloc*/, unsigned first, + unsigned size) const noexcept override { + SXT_RELEASE_ASSERT(table_.size() >= size + first * partition_table_size_v); + return basct::subspan(table_, first * partition_table_size_v, size); } void write_to_file(std::string_view filename) const noexcept override { diff --git a/sxt/multiexp/pippenger2/in_memory_partition_table_accessor.t.cc b/sxt/multiexp/pippenger2/in_memory_partition_table_accessor.t.cc index fe4d728b..8e0ad10e 100644 --- a/sxt/multiexp/pippenger2/in_memory_partition_table_accessor.t.cc +++ b/sxt/multiexp/pippenger2/in_memory_partition_table_accessor.t.cc @@ -16,6 +16,7 @@ */ #include "sxt/multiexp/pippenger2/in_memory_partition_table_accessor.h" +#include #include #include "sxt/base/curve/example_element.h" @@ -41,7 +42,7 @@ TEST_CASE("we can provide access to precomputed partition sums stored on disk") temp_file.stream().close(); in_memory_partition_table_accessor accessor{temp_file.name()}; memmg::managed_array v_dev{1, memr::get_device_resource()}; - accessor.async_copy_precomputed_sums_to_device(v_dev, stream, 0); + accessor.async_copy_to_device(v_dev, stream, 0); std::vector v(1); basdv::async_copy_device_to_host(v, v_dev, stream); basdv::synchronize_stream(stream); @@ -50,13 +51,13 @@ TEST_CASE("we can provide access to precomputed partition sums stored on disk") } SECTION("we can access a elements with offset") { - std::vector data((1u << 16u) * 2); + std::vector data(partition_table_size_v * 2); data[1u << 16u] = 12u; temp_file.stream().write(reinterpret_cast(data.data()), sizeof(E) * data.size()); temp_file.stream().close(); in_memory_partition_table_accessor accessor{temp_file.name()}; memmg::managed_array v_dev{1, memr::get_device_resource()}; - accessor.async_copy_precomputed_sums_to_device(v_dev, stream, 1); + accessor.async_copy_to_device(v_dev, stream, 1); std::vector v(1); basdv::async_copy_device_to_host(v, v_dev, stream); basdv::synchronize_stream(stream); @@ -64,8 +65,20 @@ TEST_CASE("we can provide access to precomputed partition sums stored on disk") REQUIRE(v == expected); } + SECTION("we can access elements from the host") { + std::vector data(partition_table_size_v * 2); + data[partition_table_size_v] = 12; + temp_file.stream().write(reinterpret_cast(data.data()), sizeof(E) * data.size()); + temp_file.stream().close(); + in_memory_partition_table_accessor accessor{temp_file.name()}; + std::pmr::monotonic_buffer_resource alloc; + auto v = accessor.host_view(&alloc, 1, 1); + REQUIRE(v.size() == 1); + REQUIRE(v[0] == 12); + } + SECTION("we can write an accessor to a file") { - memmg::managed_array data((1u << 16u) * 2); + memmg::managed_array data(partition_table_size_v * 2); unsigned cnt = 0; for (auto& val : data) { val = cnt++; @@ -75,7 +88,7 @@ TEST_CASE("we can provide access to precomputed partition sums stored on disk") accessor.write_to_file(temp_file.name()); in_memory_partition_table_accessor accessor_p{temp_file.name()}; memmg::managed_array data_dev{data.size(), memr::get_device_resource()}; - accessor_p.async_copy_precomputed_sums_to_device(data_dev, stream, 0); + accessor_p.async_copy_to_device(data_dev, stream, 0); memmg::managed_array data_p(data.size()); basdv::async_copy_device_to_host(data_p, data_dev, stream); basdv::synchronize_stream(stream); diff --git a/sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.h b/sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.h index b2d976fe..f7c10650 100644 --- a/sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.h +++ b/sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.h @@ -43,8 +43,8 @@ make_in_memory_partition_table_accessor(basct::cspan generators) noexcept { std::fill(iter, generators_data.end(), T::identity()); generators = generators_data; } - auto num_entries = 1u << 16u; - memmg::managed_array sums{num_entries * num_partitions, memr::get_pinned_resource()}; + memmg::managed_array sums{partition_table_size_v * num_partitions, + memr::get_pinned_resource()}; compute_partition_table(sums, generators); return std::make_unique>(std::move(sums)); } diff --git a/sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.t.cc b/sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.t.cc index 2663b265..6622cea9 100644 --- a/sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.t.cc +++ b/sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.t.cc @@ -44,7 +44,7 @@ TEST_CASE("we can create a partition table accessor from given generators") { SECTION("we can create an accessor from a single generators") { auto accessor = make_in_memory_partition_table_accessor(basct::subspan(generators, 0, 1)); - accessor->async_copy_precomputed_sums_to_device(table_dev, stream, 0); + accessor->async_copy_to_device(table_dev, stream, 0); basdv::async_copy_device_to_host(table, table_dev, stream); basdv::synchronize_stream(stream); REQUIRE(table[1] == generators[0]); @@ -53,7 +53,7 @@ TEST_CASE("we can create a partition table accessor from given generators") { SECTION("we can create an accessor from multiple generators") { auto accessor = make_in_memory_partition_table_accessor(generators); - accessor->async_copy_precomputed_sums_to_device(table_dev, stream, 0); + accessor->async_copy_to_device(table_dev, stream, 0); basdv::async_copy_device_to_host(table, table_dev, stream); basdv::synchronize_stream(stream); REQUIRE(table[1] == generators[0]); diff --git a/sxt/multiexp/pippenger2/multiexponentiation.h b/sxt/multiexp/pippenger2/multiexponentiation.h index b551d40b..f4f0d834 100644 --- a/sxt/multiexp/pippenger2/multiexponentiation.h +++ b/sxt/multiexp/pippenger2/multiexponentiation.h @@ -70,7 +70,7 @@ multiexponentiate_no_chunks(basct::span res, const partition_table_accessor products(num_products, memr::get_device_resource()); - co_await partition_product(products, accessor, scalars, 0); + co_await async_partition_product(products, accessor, scalars, 0); // reduce products basl::info("reducing {} products to {} outputs", num_products, num_products); @@ -155,7 +155,7 @@ xena::future<> multiexponentiate_impl(basct::span res, memmg::managed_array products_dev{num_products, memr::get_device_resource()}; auto scalars_slice = scalars.subspan(num_outputs * element_num_bytes * rng.a(), rng.size() * num_outputs * element_num_bytes); - co_await partition_product(products_dev, accessor, scalars_slice, rng.a()); + co_await async_partition_product(products_dev, accessor, scalars_slice, rng.a()); basdv::stream stream; basdv::async_copy_device_to_host( basct::subspan(products, num_products * chunk_index, num_products), products_dev, @@ -180,7 +180,7 @@ xena::future<> multiexponentiate_impl(basct::span res, } //-------------------------------------------------------------------------------------------------- -// multiexponentiate +// async_multiexponentiate //-------------------------------------------------------------------------------------------------- /** * Compute a multi-exponentiation using an accessor to precompute sums of partition groups. @@ -189,11 +189,39 @@ xena::future<> multiexponentiate_impl(basct::span res, * https://cacr.uwaterloo.ca/techreports/2010/cacr2010-26.pdf */ template -xena::future<> multiexponentiate(basct::span res, const partition_table_accessor& accessor, - unsigned element_num_bytes, - basct::cspan scalars) noexcept { +xena::future<> +async_multiexponentiate(basct::span res, const partition_table_accessor& accessor, + unsigned element_num_bytes, basct::cspan scalars) noexcept { multiexponentiate_options options; options.split_factor = static_cast(basdv::get_num_devices()); return multiexponentiate_impl(res, accessor, element_num_bytes, scalars, options); } + +//-------------------------------------------------------------------------------------------------- +// multiexponentiate +//-------------------------------------------------------------------------------------------------- +/** + * Host version of async_multiexponentiate. + */ +template +void multiexponentiate(basct::span res, const partition_table_accessor& accessor, + unsigned element_num_bytes, basct::cspan scalars) noexcept { + auto num_outputs = res.size(); + auto n = scalars.size() / (num_outputs * element_num_bytes); + auto num_products = num_outputs * element_num_bytes * 8u; + SXT_DEBUG_ASSERT( + // clang-format off + scalars.size() % (num_outputs * element_num_bytes) == 0 + // clang-format on + ); + + // compute bitwise products + basl::info("computing {} bitwise multiexponentiation products of length {}", num_products, n); + memmg::managed_array products(num_products); + partition_product(products, accessor, scalars, 0); + + // reduce products + basl::info("reducing {} products to {} outputs", num_products, num_products); + reduce_products(res, products); +} } // namespace sxt::mtxpp2 diff --git a/sxt/multiexp/pippenger2/multiexponentiation.t.cc b/sxt/multiexp/pippenger2/multiexponentiation.t.cc index 8efb9ae2..692e5006 100644 --- a/sxt/multiexp/pippenger2/multiexponentiation.t.cc +++ b/sxt/multiexp/pippenger2/multiexponentiation.t.cc @@ -50,7 +50,7 @@ TEST_CASE("we can compute multiexponentiations using a precomputed table of part std::vector res(1); SECTION("we can compute a multiexponentiation with a zero scalar") { - auto fut = multiexponentiate(res, *accessor, 1, scalars); + auto fut = async_multiexponentiate(res, *accessor, 1, scalars); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(res[0] == E::identity()); @@ -58,7 +58,7 @@ TEST_CASE("we can compute multiexponentiations using a precomputed table of part SECTION("we can compute a multiexponentiation multiexponentiation with a scalar of one") { scalars[0] = 1; - auto fut = multiexponentiate(res, *accessor, 1, scalars); + auto fut = async_multiexponentiate(res, *accessor, 1, scalars); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(res[0] == generators[0]); @@ -66,7 +66,7 @@ TEST_CASE("we can compute multiexponentiations using a precomputed table of part SECTION("we can compute a multiexponentiation with a scalar of two") { scalars[0] = 2; - auto fut = multiexponentiate(res, *accessor, 1, scalars); + auto fut = async_multiexponentiate(res, *accessor, 1, scalars); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(res[0] == 2u * generators[0].value); @@ -74,7 +74,7 @@ TEST_CASE("we can compute multiexponentiations using a precomputed table of part SECTION("we can compute a multiexponentiation with a scalar of three") { scalars[0] = 3; - auto fut = multiexponentiate(res, *accessor, 1, scalars); + auto fut = async_multiexponentiate(res, *accessor, 1, scalars); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(res[0] == 3u * generators[0].value); @@ -84,7 +84,7 @@ TEST_CASE("we can compute multiexponentiations using a precomputed table of part scalars.resize(2); scalars[0] = 1; scalars[1] = 1; - auto fut = multiexponentiate(res, *accessor, 1, scalars); + auto fut = async_multiexponentiate(res, *accessor, 1, scalars); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(res[0] == generators[0].value + generators[1].value); @@ -94,7 +94,7 @@ TEST_CASE("we can compute multiexponentiations using a precomputed table of part scalars.resize(17); scalars[0] = 1; scalars[16] = 1; - auto fut = multiexponentiate(res, *accessor, 1, scalars); + auto fut = async_multiexponentiate(res, *accessor, 1, scalars); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(res[0] == generators[0].value + generators[16].value); @@ -105,13 +105,23 @@ TEST_CASE("we can compute multiexponentiations using a precomputed table of part scalars.resize(2); scalars[0] = 1u; scalars[1] = 2u; - auto fut = multiexponentiate(res, *accessor, 1, scalars); + auto fut = async_multiexponentiate(res, *accessor, 1, scalars); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(res[0] == generators[0].value); REQUIRE(res[1] == 2u * generators[0].value); } + SECTION("we can compute a multiexponentiation on the host") { + res.resize(2); + scalars.resize(2); + scalars[0] = 1u; + scalars[1] = 2u; + multiexponentiate(res, *accessor, 1, scalars); + REQUIRE(res[0] == generators[0].value); + REQUIRE(res[1] == 2u * generators[0].value); + } + SECTION("we can split a multi-exponentiation") { multiexponentiate_options options{ .split_factor = 2, @@ -161,7 +171,7 @@ TEST_CASE("we can compute multiexponentiations with curve-21") { SECTION("we can compute a multiexponentiation multiexponentiation with a scalar of one") { scalars[0] = 1; - auto fut = multiexponentiate(res, *accessor, 1, scalars); + auto fut = async_multiexponentiate(res, *accessor, 1, scalars); xens::get_scheduler().run(); REQUIRE(fut.ready()); REQUIRE(res[0] == generators[0]); diff --git a/sxt/multiexp/pippenger2/partition_product.h b/sxt/multiexp/pippenger2/partition_product.h index 5152e7d5..58d71f21 100644 --- a/sxt/multiexp/pippenger2/partition_product.h +++ b/sxt/multiexp/pippenger2/partition_product.h @@ -18,6 +18,7 @@ #include #include +#include #include "sxt/algorithm/iteration/for_each.h" #include "sxt/base/container/span.h" @@ -32,6 +33,7 @@ #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/async_device_resource.h" #include "sxt/memory/resource/device_resource.h" +#include "sxt/multiexp/pippenger2/constants.h" #include "sxt/multiexp/pippenger2/partition_table_accessor.h" namespace sxt::mtxpp2 { @@ -90,21 +92,26 @@ partition_product_kernel(T* __restrict__ products, const T* __restrict__ partiti } //-------------------------------------------------------------------------------------------------- -// partition_product +// async_partition_product //-------------------------------------------------------------------------------------------------- /** * Compute the multiproduct for the bits of an array of scalars using an accessor to * precomputed sums for each group of generators. */ template -xena::future<> partition_product(basct::span products, - const partition_table_accessor& accessor, - basct::cspan scalars, unsigned offset) noexcept { +xena::future<> async_partition_product(basct::span products, + const partition_table_accessor& accessor, + basct::cspan scalars, unsigned offset) noexcept { auto num_products = products.size(); auto n = static_cast(scalars.size() * 8u / num_products); auto num_partitions = basn::divide_up(n, 16u); - auto num_table_entries = 1u << 16u; - SXT_DEBUG_ASSERT(offset % 16u == 0); + SXT_DEBUG_ASSERT( + // clang-format off + offset % 16u == 0 && + basdv::is_active_device_pointer(products.data()) && + basdv::is_host_pointer(scalars.data()) + // clang-format on + ); // scalars_dev memmg::managed_array scalars_dev{scalars.size(), memr::get_device_resource()}; @@ -117,8 +124,8 @@ xena::future<> partition_product(basct::span products, // partition_table basdv::stream stream; memr::async_device_resource resource{stream}; - memmg::managed_array partition_table{num_partitions * num_table_entries, &resource}; - accessor.async_copy_precomputed_sums_to_device(partition_table, stream, offset / 16u); + memmg::managed_array partition_table{num_partitions * partition_table_size_v, &resource}; + accessor.async_copy_to_device(partition_table, stream, offset / 16u); co_await std::move(scalars_fut); // product @@ -139,4 +146,34 @@ xena::future<> partition_product(basct::span products, algi::launch_for_each_kernel(stream, f, num_products); co_await xendv::await_stream(stream); } + +//-------------------------------------------------------------------------------------------------- +// partition_product +//-------------------------------------------------------------------------------------------------- +/** + * Compute the multiproduct for the bits of an array of scalars using an accessor to + * precomputed sums for each group of generators. + */ +template +void partition_product(basct::span products, const partition_table_accessor& accessor, + basct::cspan scalars, unsigned offset) noexcept { + auto num_products = products.size(); + auto n = static_cast(scalars.size() * 8u / num_products); + SXT_DEBUG_ASSERT( + // clang-format off + offset % 16u == 0 + // clang-format on + ); + std::pmr::monotonic_buffer_resource alloc; + + auto partition_table = + accessor.host_view(&alloc, offset, basn::divide_up(n, 16u) * partition_table_size_v); + + for (unsigned product_index = 0; product_index < num_products; ++product_index) { + auto byte_index = product_index / 8u; + auto bit_offset = product_index % 8u; + partition_product_kernel(products.data(), partition_table.data(), scalars.data(), byte_index, + bit_offset, num_products, n); + } +} } // namespace sxt::mtxpp2 diff --git a/sxt/multiexp/pippenger2/partition_product.t.cc b/sxt/multiexp/pippenger2/partition_product.t.cc index 044f9325..48503948 100644 --- a/sxt/multiexp/pippenger2/partition_product.t.cc +++ b/sxt/multiexp/pippenger2/partition_product.t.cc @@ -25,6 +25,7 @@ #include "sxt/execution/schedule/scheduler.h" #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/managed_device_resource.h" +#include "sxt/multiexp/pippenger2/constants.h" #include "sxt/multiexp/pippenger2/in_memory_partition_table_accessor.h" using namespace sxt; @@ -66,7 +67,6 @@ TEST_CASE("we can compute the index used to lookup the precomputed sum for a par } TEST_CASE("we can compute the product of partitions") { - constexpr auto num_entries = 1u << 16u; using E = bascrv::element97; memmg::managed_array products{8, memr::get_managed_device_resource()}; std::vector scalars(1); @@ -88,7 +88,7 @@ TEST_CASE("we can compute the product of partitions") { SECTION("we handle a product with a single scalar") { scalars[0] = 1; - auto fut = partition_product(products, accessor, scalars, 0); + auto fut = async_partition_product(products, accessor, scalars, 0); xens::get_scheduler().run(); REQUIRE(fut.ready()); basdv::synchronize_device(); @@ -98,17 +98,17 @@ TEST_CASE("we can compute the product of partitions") { SECTION("we handle a product with an offset") { scalars[0] = 1; - auto fut = partition_product(products, accessor, scalars, 16); + auto fut = async_partition_product(products, accessor, scalars, 16); xens::get_scheduler().run(); REQUIRE(fut.ready()); basdv::synchronize_device(); - expected[0] = partition_table[num_entries + 1]; + expected[0] = partition_table[partition_table_size_v + 1]; REQUIRE(products == expected); } SECTION("we handle a product with two scalars") { scalars = {1u, 3u}; - auto fut = partition_product(products, accessor, scalars, 0); + auto fut = async_partition_product(products, accessor, scalars, 0); xens::get_scheduler().run(); REQUIRE(fut.ready()); basdv::synchronize_device(); @@ -121,7 +121,7 @@ TEST_CASE("we can compute the product of partitions") { scalars.resize(16); scalars[0] = 1u; scalars[15] = 1u; - auto fut = partition_product(products, accessor, scalars, 0); + auto fut = async_partition_product(products, accessor, scalars, 0); xens::get_scheduler().run(); REQUIRE(fut.ready()); basdv::synchronize_device(); @@ -133,11 +133,20 @@ TEST_CASE("we can compute the product of partitions") { scalars.resize(32); scalars[0] = 1u; scalars[16] = 1u; - auto fut = partition_product(products, accessor, scalars, 0); + auto fut = async_partition_product(products, accessor, scalars, 0); xens::get_scheduler().run(); REQUIRE(fut.ready()); basdv::synchronize_device(); - expected[0] = partition_table[1].value + partition_table[num_entries + 1].value; + expected[0] = partition_table[1].value + partition_table[partition_table_size_v + 1].value; + REQUIRE(products == expected); + } + + SECTION("we can compute products on the host") { + scalars.resize(32); + scalars[0] = 1u; + scalars[16] = 1u; + partition_product(products, accessor, scalars, 0); + expected[0] = partition_table[1].value + partition_table[partition_table_size_v + 1].value; REQUIRE(products == expected); } } diff --git a/sxt/multiexp/pippenger2/partition_table.h b/sxt/multiexp/pippenger2/partition_table.h index cfa310a9..ddbb66e1 100644 --- a/sxt/multiexp/pippenger2/partition_table.h +++ b/sxt/multiexp/pippenger2/partition_table.h @@ -23,6 +23,7 @@ #include "sxt/base/curve/element.h" #include "sxt/base/error/assert.h" #include "sxt/base/macro/cuda_callable.h" +#include "sxt/multiexp/pippenger2/constants.h" namespace sxt::mtxpp2 { //-------------------------------------------------------------------------------------------------- @@ -70,16 +71,15 @@ CUDA_CALLABLE void compute_partition_table_slice(T* __restrict__ sums, */ template void compute_partition_table(basct::span sums, basct::cspan generators) noexcept { - auto num_entries = 1u << 16u; SXT_DEBUG_ASSERT( // clang-format off - sums.size() == num_entries * generators.size() / 16u && + sums.size() == partition_table_size_v * generators.size() / 16u && generators.size() % 16 == 0 // clang-format on ); auto n = generators.size() / 16u; for (unsigned i = 0; i < n; ++i) { - auto sums_slice = sums.subspan(i * num_entries, num_entries); + auto sums_slice = sums.subspan(i * partition_table_size_v, partition_table_size_v); auto generators_slice = generators.subspan(i * 16u, 16u); compute_partition_table_slice(sums_slice.data(), generators_slice.data()); } diff --git a/sxt/multiexp/pippenger2/partition_table.t.cc b/sxt/multiexp/pippenger2/partition_table.t.cc index d0303748..9c860179 100644 --- a/sxt/multiexp/pippenger2/partition_table.t.cc +++ b/sxt/multiexp/pippenger2/partition_table.t.cc @@ -21,6 +21,7 @@ #include "sxt/base/bit/iteration.h" #include "sxt/base/curve/example_element.h" #include "sxt/base/test/unit_test.h" +#include "sxt/multiexp/pippenger2/constants.h" using namespace sxt; using namespace sxt::mtxpp2; @@ -44,13 +45,12 @@ TEST_CASE("we can compute a slice of the partition table") { TEST_CASE("we can compute the full partition table") { using E = bascrv::element97; auto n = 2u; - auto num_entries = 1u << 16u; - std::vector sums(num_entries * n); + std::vector sums(partition_table_size_v * n); std::vector generators(16u * n); for (unsigned i = 0; i < generators.size(); ++i) { generators[i] = i + 1u; } compute_partition_table(sums, generators); REQUIRE(sums[1] == generators[0]); - REQUIRE(sums[num_entries + 1] == generators[16]); + REQUIRE(sums[partition_table_size_v + 1] == generators[16]); } diff --git a/sxt/multiexp/pippenger2/partition_table_accessor.h b/sxt/multiexp/pippenger2/partition_table_accessor.h index 645fd4eb..2bb5a20d 100644 --- a/sxt/multiexp/pippenger2/partition_table_accessor.h +++ b/sxt/multiexp/pippenger2/partition_table_accessor.h @@ -16,6 +16,7 @@ */ #pragma once +#include #include #include "sxt/base/container/span.h" @@ -27,10 +28,35 @@ namespace sxt::mtxpp2 { //-------------------------------------------------------------------------------------------------- // partition_table_accessor //-------------------------------------------------------------------------------------------------- +/** + * Support accessing precomputed sums for groups of 16 generators. + * + * For example, if there are 32 generators + * + * g0, ..., g15, g16, ..., g31 + * + * an accessor will contain two tables each of 2^16 entries with all the sums of + * generators g0 to g15 and all the sums of generators g16 to g31, respectively. + */ template class partition_table_accessor : public partition_table_accessor_base { public: - virtual void async_copy_precomputed_sums_to_device(basct::span dest, bast::raw_stream_t stream, - unsigned first) const noexcept = 0; + /** + * Asynchronously copy precomputed sums of partitions to device. + * + * `first` specifies the partition group offset to use. + */ + virtual void async_copy_to_device(basct::span dest, bast::raw_stream_t stream, + unsigned first) const noexcept = 0; + + /** + * Make a view into precomputed sums of partitions available to host memory. + * + * `first` specifies the partition group offset to use. If memory needs to be allocated + * to make the view available, it will be allocated using alloc. Make sure that alloc uses + * a resource that frees memory upon destruction (e.g. std::pmr::monotonic_buffer_resource). + */ + virtual basct::cspan host_view(std::pmr::polymorphic_allocator<> alloc, unsigned first, + unsigned size) const noexcept = 0; virtual void write_to_file(std::string_view filename) const noexcept = 0; }; diff --git a/sxt/multiexp/pippenger2/reduce.h b/sxt/multiexp/pippenger2/reduce.h index eccf2ed2..9ef73f25 100644 --- a/sxt/multiexp/pippenger2/reduce.h +++ b/sxt/multiexp/pippenger2/reduce.h @@ -69,4 +69,15 @@ void reduce_products(basct::span reductions, bast::raw_stream_t stream, }; algi::launch_for_each_kernel(stream, f, num_outputs); } + +template +void reduce_products(basct::span reductions, basct::cspan products) noexcept { + auto num_outputs = reductions.size(); + auto reduction_size = products.size() / reductions.size(); + SXT_DEBUG_ASSERT(products.size() == reduction_size * num_outputs); + for (unsigned output_index = 0; output_index < num_outputs; ++output_index) { + reduce_output(reductions.data() + output_index, products.data() + output_index * reduction_size, + reduction_size); + } +} } // namespace sxt::mtxpp2 diff --git a/sxt/multiexp/pippenger2/reduce.t.cc b/sxt/multiexp/pippenger2/reduce.t.cc index 81987f02..150f99ae 100644 --- a/sxt/multiexp/pippenger2/reduce.t.cc +++ b/sxt/multiexp/pippenger2/reduce.t.cc @@ -55,4 +55,14 @@ TEST_CASE("we can reduce products") { expected = {123u + 2u * 456u}; REQUIRE(outputs == expected); } + + SECTION("we can reduce products on the host") { + outputs.resize(1); + products.resize(2); + products[0] = 123u; + products[1] = 456u; + reduce_products(outputs, products); + expected = {123u + 2u * 456u}; + REQUIRE(outputs == expected); + } }