Skip to content

Commit

Permalink
feat: don't use pinned memory for cpu backend (PROOF-831) (#132)
Browse files Browse the repository at this point in the history
* refactor interface

* rework cpu_backend

* customize partition table accessor

* reformat
  • Loading branch information
rnburn committed May 23, 2024
1 parent 25788e0 commit c7e9b40
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 29 deletions.
6 changes: 2 additions & 4 deletions sxt/cbindings/backend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@ load(

sxt_cc_component(
name = "computational_backend",
impl_deps = [
"//sxt/cbindings/base:curve_id_utility",
"//sxt/multiexp/pippenger2:in_memory_partition_table_accessor_utility",
],
with_test = False,
deps = [
"//sxt/base/container:span",
Expand Down Expand Up @@ -51,6 +47,7 @@ sxt_cc_component(
"//sxt/ristretto/operation:compression",
"//sxt/multiexp/base:exponent_sequence",
"//sxt/multiexp/curve:multiexponentiation",
"//sxt/multiexp/pippenger2:in_memory_partition_table_accessor_utility",
"//sxt/multiexp/pippenger2:multiexponentiation",
"//sxt/seqcommit/generator:precomputed_generators",
"//sxt/proof/inner_product:proof_descriptor",
Expand Down Expand Up @@ -90,6 +87,7 @@ sxt_cc_component(
"//sxt/execution/async:future",
"//sxt/execution/schedule:scheduler",
"//sxt/memory/management:managed_array",
"//sxt/multiexp/pippenger2:in_memory_partition_table_accessor_utility",
"//sxt/multiexp/pippenger2:multiexponentiation",
"//sxt/ristretto/type:compressed_element",
"//sxt/ristretto/operation:compression",
Expand Down
20 changes: 0 additions & 20 deletions sxt/cbindings/backend/computational_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,3 @@
* limitations under the License.
*/
#include "sxt/cbindings/backend/computational_backend.h"

#include "sxt/cbindings/base/curve_id_utility.h"
#include "sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.h"

namespace sxt::cbnbck {
//--------------------------------------------------------------------------------------------------
// make_partition_table_accessor
//--------------------------------------------------------------------------------------------------
std::unique_ptr<mtxpp2::partition_table_accessor_base>
computational_backend::make_partition_table_accessor(cbnb::curve_id_t curve_id,
const void* generators,
unsigned n) const noexcept {
std::unique_ptr<mtxpp2::partition_table_accessor_base> res;
cbnb::switch_curve_type(curve_id, [&]<class T>(std::type_identity<T>) noexcept {
res = mtxpp2::make_in_memory_partition_table_accessor<T>(
basct::cspan<T>{static_cast<const T*>(generators), n});
});
return res;
}
} // namespace sxt::cbnbck
2 changes: 1 addition & 1 deletion sxt/cbindings/backend/computational_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class computational_backend {

virtual std::unique_ptr<mtxpp2::partition_table_accessor_base>
make_partition_table_accessor(cbnb::curve_id_t curve_id, const void* generators,
unsigned n) const noexcept;
unsigned n) const noexcept = 0;

virtual void fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
Expand Down
15 changes: 15 additions & 0 deletions sxt/cbindings/backend/cpu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "sxt/memory/management/managed_array.h"
#include "sxt/multiexp/base/exponent_sequence.h"
#include "sxt/multiexp/curve/multiexponentiation.h"
#include "sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.h"
#include "sxt/multiexp/pippenger2/multiexponentiation.h"
#include "sxt/proof/inner_product/cpu_driver.h"
#include "sxt/proof/inner_product/proof_computation.h"
Expand Down Expand Up @@ -121,6 +122,20 @@ bool cpu_backend::verify_inner_product(prft::transcript& transcript,
.value();
}

//--------------------------------------------------------------------------------------------------
// make_partition_table_accessor
//--------------------------------------------------------------------------------------------------
std::unique_ptr<mtxpp2::partition_table_accessor_base>
cpu_backend::make_partition_table_accessor(cbnb::curve_id_t curve_id, const void* generators,
unsigned n) const noexcept {
std::unique_ptr<mtxpp2::partition_table_accessor_base> res;
cbnb::switch_curve_type(curve_id, [&]<class T>(std::type_identity<T>) noexcept {
res = mtxpp2::make_in_memory_partition_table_accessor<T>(
basct::cspan<T>{static_cast<const T*>(generators), n}, basm::alloc_t{});
});
return res;
}

//--------------------------------------------------------------------------------------------------
// fixed_multiexponentiation
//--------------------------------------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions sxt/cbindings/backend/cpu_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class cpu_backend final : public computational_backend {
basct::cspan<rstt::compressed_element> r_vector,
const s25t::element& ap_value) const noexcept override;

std::unique_ptr<mtxpp2::partition_table_accessor_base>
make_partition_table_accessor(cbnb::curve_id_t curve_id, const void* generators,
unsigned n) const noexcept override;

void fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
unsigned element_num_bytes, unsigned num_outputs, unsigned n,
Expand Down
15 changes: 15 additions & 0 deletions sxt/cbindings/backend/gpu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "sxt/memory/management/managed_array.h"
#include "sxt/multiexp/base/exponent_sequence.h"
#include "sxt/multiexp/curve/multiexponentiation.h"
#include "sxt/multiexp/pippenger2/in_memory_partition_table_accessor_utility.h"
#include "sxt/multiexp/pippenger2/multiexponentiation.h"
#include "sxt/proof/inner_product/gpu_driver.h"
#include "sxt/proof/inner_product/proof_computation.h"
Expand Down Expand Up @@ -158,6 +159,20 @@ bool gpu_backend::verify_inner_product(prft::transcript& transcript,
return fut.value();
}

//--------------------------------------------------------------------------------------------------
// make_partition_table_accessor
//--------------------------------------------------------------------------------------------------
std::unique_ptr<mtxpp2::partition_table_accessor_base>
gpu_backend::make_partition_table_accessor(cbnb::curve_id_t curve_id, const void* generators,
unsigned n) const noexcept {
std::unique_ptr<mtxpp2::partition_table_accessor_base> res;
cbnb::switch_curve_type(curve_id, [&]<class T>(std::type_identity<T>) noexcept {
res = mtxpp2::make_in_memory_partition_table_accessor<T>(
basct::cspan<T>{static_cast<const T*>(generators), n});
});
return res;
}

//--------------------------------------------------------------------------------------------------
// fixed_multiexponentiation
//--------------------------------------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions sxt/cbindings/backend/gpu_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class gpu_backend final : public computational_backend {
basct::cspan<rstt::compressed_element> r_vector,
const s25t::element& ap_value) const noexcept override;

std::unique_ptr<mtxpp2::partition_table_accessor_base>
make_partition_table_accessor(cbnb::curve_id_t curve_id, const void* generators,
unsigned n) const noexcept override;

void fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
unsigned element_num_bytes, unsigned num_outputs, unsigned n,
Expand Down
1 change: 1 addition & 0 deletions sxt/multiexp/pippenger2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ sxt_cc_component(
":partition_table",
"//sxt/base/container:span",
"//sxt/base/curve:element",
"//sxt/base/memory:alloc",
"//sxt/base/num:divide_up",
"//sxt/memory/resource:pinned_resource",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "sxt/base/container/span.h"
#include "sxt/base/curve/element.h"
#include "sxt/base/memory/alloc.h"
#include "sxt/base/num/divide_up.h"
#include "sxt/memory/resource/pinned_resource.h"
#include "sxt/multiexp/pippenger2/in_memory_partition_table_accessor.h"
Expand All @@ -31,8 +32,8 @@ namespace sxt::mtxpp2 {
// make_in_memory_partition_table_accessor
//--------------------------------------------------------------------------------------------------
template <bascrv::element T>
std::unique_ptr<partition_table_accessor<T>>
make_in_memory_partition_table_accessor(basct::cspan<T> generators) noexcept {
std::unique_ptr<partition_table_accessor<T>> make_in_memory_partition_table_accessor(
basct::cspan<T> generators, basm::alloc_t alloc = memr::get_pinned_resource()) noexcept {
auto n = generators.size();
std::vector<T> generators_data;
auto num_partitions = basn::divide_up(n, size_t{16});
Expand All @@ -43,8 +44,7 @@ make_in_memory_partition_table_accessor(basct::cspan<T> generators) noexcept {
std::fill(iter, generators_data.end(), T::identity());
generators = generators_data;
}
memmg::managed_array<T> sums{partition_table_size_v * num_partitions,
memr::get_pinned_resource()};
memmg::managed_array<T> sums{partition_table_size_v * num_partitions, alloc};
compute_partition_table<T>(sums, generators);
return std::make_unique<in_memory_partition_table_accessor<T>>(std::move(sums));
}
Expand Down

0 comments on commit c7e9b40

Please sign in to comment.