Skip to content

Commit

Permalink
feat: add host support for multiexponentiations using precomputed par…
Browse files Browse the repository at this point in the history
…tition method (PROOF-831) (#131)

* add more assertions

* work on host version of mx code

* rename

* work on host version

* work on host version of partition product

* add tests

* fill in host reduction

* work on host-side code

* rename

* add stub for host version

* fill in host version of multiexponentiation

* test host multiexponentiation

* impl cpu_backend

* test cpu backend

* reformat

* rename

* fixes

* rework partition table accessor

* add more documentation

* remove unused header

* add constant

* clean up
  • Loading branch information
rnburn committed May 22, 2024
1 parent 8a789ed commit 25788e0
Show file tree
Hide file tree
Showing 23 changed files with 294 additions and 80 deletions.
8 changes: 4 additions & 4 deletions benchmark/multi_exp_pip/benchmark.m.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,17 +129,17 @@ int main(int argc, char* argv[]) {

// discard initial run
{
auto fut =
mtxpp2::multiexponentiate<c21t::element_p3>(res, *accessor, element_num_bytes, exponents);
auto fut = mtxpp2::async_multiexponentiate<c21t::element_p3>(res, *accessor, element_num_bytes,
exponents);
xens::get_scheduler().run();
}

// run benchmark
double times = 0;
for (unsigned i = 0; i < num_samples; ++i) {
auto t1 = std::chrono::steady_clock::now();
auto fut =
mtxpp2::multiexponentiate<c21t::element_p3>(res, *accessor, element_num_bytes, exponents);
auto fut = mtxpp2::async_multiexponentiate<c21t::element_p3>(res, *accessor, element_num_bytes,
exponents);
xens::get_scheduler().run();
auto t2 = std::chrono::steady_clock::now();
auto elapse = std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1);
Expand Down
1 change: 1 addition & 0 deletions cbindings/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ sxt_cc_component(
"//sxt/cbindings/base:multiexp_handle",
],
test_deps = [
":backend",
"//sxt/base/test:unit_test",
"//sxt/curve21/operation:add",
"//sxt/curve21/operation:double",
Expand Down
6 changes: 1 addition & 5 deletions cbindings/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,7 @@ cbnbck::computational_backend* get_backend() noexcept {
//--------------------------------------------------------------------------------------------------
// reset_backend_for_testing
//--------------------------------------------------------------------------------------------------
void reset_backend_for_testing() noexcept {
SXT_RELEASE_ASSERT(backend != nullptr);

backend = nullptr;
}
void reset_backend_for_testing() noexcept { backend = nullptr; }
} // namespace sxt::cbn

//--------------------------------------------------------------------------------------------------
Expand Down
26 changes: 21 additions & 5 deletions cbindings/fixed_pedersen.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <vector>

#include "cbindings/backend.h"
#include "sxt/base/test/unit_test.h"
#include "sxt/curve21/operation/add.h"
#include "sxt/curve21/operation/double.h"
Expand Down Expand Up @@ -45,13 +46,28 @@ TEST_CASE("we can compute multi-exponentiations with a fixed set of generators")
0x456_c21,
};

const sxt_config config = {SXT_GPU_BACKEND, 0};
REQUIRE(sxt_init(&config) == 0);
SECTION("we can compute a multiexponentiation with the gpu backend") {
cbn::reset_backend_for_testing();
const sxt_config config = {SXT_GPU_BACKEND, 0};
REQUIRE(sxt_init(&config) == 0);

wrapped_handle h{generators.data(), 2};
REQUIRE(h.h != nullptr);
wrapped_handle h{generators.data(), 2};
REQUIRE(h.h != nullptr);

uint8_t scalars[] = {1, 0, 0, 2};
c21t::element_p3 res;
sxt_fixed_multiexponentiation(&res, h.h, 2, 1, 2, scalars);
REQUIRE(res == generators[0] + 2 * 256 * generators[1]);
}

SECTION("we can compute a multiexponentiation with the cpu backend") {
cbn::reset_backend_for_testing();
const sxt_config config = {SXT_CPU_BACKEND, 0};
REQUIRE(sxt_init(&config) == 0);

wrapped_handle h{generators.data(), 2};
REQUIRE(h.h != nullptr);

SECTION("we can compute a multiexponentiation") {
uint8_t scalars[] = {1, 0, 0, 2};
c21t::element_p3 res;
sxt_fixed_multiexponentiation(&res, h.h, 2, 1, 2, scalars);
Expand Down
2 changes: 2 additions & 0 deletions sxt/cbindings/backend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ sxt_cc_component(
name = "cpu_backend",
impl_deps = [
"//sxt/base/error:panic",
"//sxt/cbindings/base:curve_id_utility",
"//sxt/proof/transcript:transcript",
"//sxt/scalar25/type:element",
"//sxt/curve_bng1/operation:add",
Expand All @@ -89,6 +90,7 @@ sxt_cc_component(
"//sxt/execution/async:future",
"//sxt/execution/schedule:scheduler",
"//sxt/memory/management:managed_array",
"//sxt/multiexp/pippenger2:multiexponentiation",
"//sxt/ristretto/type:compressed_element",
"//sxt/ristretto/operation:compression",
"//sxt/multiexp/base:exponent_sequence",
Expand Down
17 changes: 10 additions & 7 deletions sxt/cbindings/backend/cpu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "sxt/base/error/assert.h"
#include "sxt/base/error/panic.h"
#include "sxt/cbindings/base/curve_id_utility.h"
#include "sxt/curve21/operation/add.h"
#include "sxt/curve21/operation/double.h"
#include "sxt/curve21/operation/neg.h"
Expand All @@ -40,6 +41,7 @@
#include "sxt/memory/management/managed_array.h"
#include "sxt/multiexp/base/exponent_sequence.h"
#include "sxt/multiexp/curve/multiexponentiation.h"
#include "sxt/multiexp/pippenger2/multiexponentiation.h"
#include "sxt/proof/inner_product/cpu_driver.h"
#include "sxt/proof/inner_product/proof_computation.h"
#include "sxt/proof/inner_product/proof_descriptor.h"
Expand Down Expand Up @@ -126,13 +128,14 @@ void cpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id
const mtxpp2::partition_table_accessor_base& accessor,
unsigned element_num_bytes, unsigned num_outputs,
unsigned n, const uint8_t* scalars) const noexcept {
(void)res;
(void)curve_id;
(void)accessor;
(void)num_outputs;
(void)n;
(void)scalars;
baser::panic("not implemented yet");
cbnb::switch_curve_type(curve_id, [&]<class T>(std::type_identity<T>) noexcept {
basct::span<T> res_span{static_cast<T*>(res), num_outputs};
basct::cspan<uint8_t> scalars_span{scalars, element_num_bytes * num_outputs * n};
mtxpp2::multiexponentiate<T>(res_span,
static_cast<const mtxpp2::partition_table_accessor<T>&>(accessor),
element_num_bytes, scalars_span);
xens::get_scheduler().run();
});
}

//--------------------------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion sxt/cbindings/backend/gpu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ void gpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id
cbnb::switch_curve_type(curve_id, [&]<class T>(std::type_identity<T>) noexcept {
basct::span<T> res_span{static_cast<T*>(res), num_outputs};
basct::cspan<uint8_t> scalars_span{scalars, element_num_bytes * num_outputs * n};
auto fut = mtxpp2::multiexponentiate<T>(
auto fut = mtxpp2::async_multiexponentiate<T>(
res_span, static_cast<const mtxpp2::partition_table_accessor<T>&>(accessor),
element_num_bytes, scalars_span);
xens::get_scheduler().run();
Expand Down
8 changes: 8 additions & 0 deletions sxt/multiexp/pippenger2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ load(
"sxt_cc_component",
)

sxt_cc_component(
name = "constants",
with_test = False,
)

sxt_cc_component(
name = "combination",
test_deps = [
Expand Down Expand Up @@ -41,6 +46,7 @@ sxt_cc_component(
"//sxt/memory/resource:managed_device_resource",
],
deps = [
":constants",
"//sxt/algorithm/iteration:for_each",
"//sxt/base/bit:iteration",
"//sxt/base/bit:permutation",
Expand Down Expand Up @@ -69,6 +75,7 @@ sxt_cc_component(
"//sxt/memory/resource:managed_device_resource",
],
deps = [
":constants",
"//sxt/algorithm/iteration:for_each",
"//sxt/base/bit:iteration",
"//sxt/base/bit:permutation",
Expand Down Expand Up @@ -108,6 +115,7 @@ sxt_cc_component(
"//sxt/memory/resource:device_resource",
],
deps = [
":constants",
":partition_table_accessor",
"//sxt/base/container:span_utility",
"//sxt/base/device:memory_utility",
Expand Down
17 changes: 17 additions & 0 deletions sxt/multiexp/pippenger2/constants.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/multiexp/pippenger2/constants.h"
24 changes: 24 additions & 0 deletions sxt/multiexp/pippenger2/constants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2024-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

namespace sxt::mtxpp2 {
//--------------------------------------------------------------------------------------------------
// partition_table_size_v
//--------------------------------------------------------------------------------------------------
constexpr unsigned partition_table_size_v = 1u << 16u;
} // namespace sxt::mtxpp2
25 changes: 14 additions & 11 deletions sxt/multiexp/pippenger2/in_memory_partition_table_accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "sxt/base/error/panic.h"
#include "sxt/memory/management/managed_array.h"
#include "sxt/memory/resource/pinned_resource.h"
#include "sxt/multiexp/pippenger2/constants.h"
#include "sxt/multiexp/pippenger2/partition_table_accessor.h"

namespace sxt::mtxpp2 {
Expand All @@ -35,6 +36,7 @@ namespace sxt::mtxpp2 {
//--------------------------------------------------------------------------------------------------
template <bascrv::element T>
class in_memory_partition_table_accessor final : public partition_table_accessor<T> {

public:
explicit in_memory_partition_table_accessor(std::string_view filename) noexcept
: table_{memr::get_pinned_resource()} {
Expand All @@ -54,17 +56,18 @@ class in_memory_partition_table_accessor final : public partition_table_accessor
explicit in_memory_partition_table_accessor(memmg::managed_array<T>&& table) noexcept
: table_{std::move(table)} {}

void async_copy_precomputed_sums_to_device(basct::span<T> dest, bast::raw_stream_t stream,
unsigned first) const noexcept override {
static unsigned num_entries = 1u << 16u;
SXT_DEBUG_ASSERT(
// clang-format off
table_.size() >= dest.size() + first * num_entries &&
basdv::is_active_device_pointer(dest.data())
// clang-format on
);
basdv::async_copy_host_to_device(dest, basct::subspan(table_, first * num_entries, dest.size()),
stream);
void async_copy_to_device(basct::span<T> dest, bast::raw_stream_t stream,
unsigned first) const noexcept override {
SXT_RELEASE_ASSERT(table_.size() >= dest.size() + first * partition_table_size_v);
SXT_DEBUG_ASSERT(basdv::is_active_device_pointer(dest.data()));
basdv::async_copy_host_to_device(
dest, basct::subspan(table_, first * partition_table_size_v, dest.size()), stream);
}

basct::cspan<T> host_view(std::pmr::polymorphic_allocator<> /*alloc*/, unsigned first,
unsigned size) const noexcept override {
SXT_RELEASE_ASSERT(table_.size() >= size + first * partition_table_size_v);
return basct::subspan(table_, first * partition_table_size_v, size);
}

void write_to_file(std::string_view filename) const noexcept override {
Expand Down
23 changes: 18 additions & 5 deletions sxt/multiexp/pippenger2/in_memory_partition_table_accessor.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
#include "sxt/multiexp/pippenger2/in_memory_partition_table_accessor.h"

#include <memory_resource>
#include <vector>

#include "sxt/base/curve/example_element.h"
Expand All @@ -41,7 +42,7 @@ TEST_CASE("we can provide access to precomputed partition sums stored on disk")
temp_file.stream().close();
in_memory_partition_table_accessor<E> accessor{temp_file.name()};
memmg::managed_array<E> v_dev{1, memr::get_device_resource()};
accessor.async_copy_precomputed_sums_to_device(v_dev, stream, 0);
accessor.async_copy_to_device(v_dev, stream, 0);
std::vector<E> v(1);
basdv::async_copy_device_to_host(v, v_dev, stream);
basdv::synchronize_stream(stream);
Expand All @@ -50,22 +51,34 @@ TEST_CASE("we can provide access to precomputed partition sums stored on disk")
}

SECTION("we can access a elements with offset") {
std::vector<E> data((1u << 16u) * 2);
std::vector<E> data(partition_table_size_v * 2);
data[1u << 16u] = 12u;
temp_file.stream().write(reinterpret_cast<const char*>(data.data()), sizeof(E) * data.size());
temp_file.stream().close();
in_memory_partition_table_accessor<E> accessor{temp_file.name()};
memmg::managed_array<E> v_dev{1, memr::get_device_resource()};
accessor.async_copy_precomputed_sums_to_device(v_dev, stream, 1);
accessor.async_copy_to_device(v_dev, stream, 1);
std::vector<E> v(1);
basdv::async_copy_device_to_host(v, v_dev, stream);
basdv::synchronize_stream(stream);
std::vector<E> expected = {12u};
REQUIRE(v == expected);
}

SECTION("we can access elements from the host") {
std::vector<E> data(partition_table_size_v * 2);
data[partition_table_size_v] = 12;
temp_file.stream().write(reinterpret_cast<const char*>(data.data()), sizeof(E) * data.size());
temp_file.stream().close();
in_memory_partition_table_accessor<E> accessor{temp_file.name()};
std::pmr::monotonic_buffer_resource alloc;
auto v = accessor.host_view(&alloc, 1, 1);
REQUIRE(v.size() == 1);
REQUIRE(v[0] == 12);
}

SECTION("we can write an accessor to a file") {
memmg::managed_array<E> data((1u << 16u) * 2);
memmg::managed_array<E> data(partition_table_size_v * 2);
unsigned cnt = 0;
for (auto& val : data) {
val = cnt++;
Expand All @@ -75,7 +88,7 @@ TEST_CASE("we can provide access to precomputed partition sums stored on disk")
accessor.write_to_file(temp_file.name());
in_memory_partition_table_accessor<E> accessor_p{temp_file.name()};
memmg::managed_array<E> data_dev{data.size(), memr::get_device_resource()};
accessor_p.async_copy_precomputed_sums_to_device(data_dev, stream, 0);
accessor_p.async_copy_to_device(data_dev, stream, 0);
memmg::managed_array<E> data_p(data.size());
basdv::async_copy_device_to_host(data_p, data_dev, stream);
basdv::synchronize_stream(stream);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ make_in_memory_partition_table_accessor(basct::cspan<T> generators) noexcept {
std::fill(iter, generators_data.end(), T::identity());
generators = generators_data;
}
auto num_entries = 1u << 16u;
memmg::managed_array<T> sums{num_entries * num_partitions, memr::get_pinned_resource()};
memmg::managed_array<T> sums{partition_table_size_v * num_partitions,
memr::get_pinned_resource()};
compute_partition_table<T>(sums, generators);
return std::make_unique<in_memory_partition_table_accessor<T>>(std::move(sums));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ TEST_CASE("we can create a partition table accessor from given generators") {

SECTION("we can create an accessor from a single generators") {
auto accessor = make_in_memory_partition_table_accessor<E>(basct::subspan(generators, 0, 1));
accessor->async_copy_precomputed_sums_to_device(table_dev, stream, 0);
accessor->async_copy_to_device(table_dev, stream, 0);
basdv::async_copy_device_to_host(table, table_dev, stream);
basdv::synchronize_stream(stream);
REQUIRE(table[1] == generators[0]);
Expand All @@ -53,7 +53,7 @@ TEST_CASE("we can create a partition table accessor from given generators") {

SECTION("we can create an accessor from multiple generators") {
auto accessor = make_in_memory_partition_table_accessor<E>(generators);
accessor->async_copy_precomputed_sums_to_device(table_dev, stream, 0);
accessor->async_copy_to_device(table_dev, stream, 0);
basdv::async_copy_device_to_host(table, table_dev, stream);
basdv::synchronize_stream(stream);
REQUIRE(table[1] == generators[0]);
Expand Down

0 comments on commit 25788e0

Please sign in to comment.