Skip to content

Commit

Permalink
feat: add c bindings for multiexponentiations using partition method …
Browse files Browse the repository at this point in the history
…(PROOF-831) (#130)

* add stub for multiexponentiation c function

* fill in c multiexponentiation functions

* extend backend

* tweak api

* fill in mx

* fill in mx

* fill cbindings

* tweak api

* fill in gpu backend

* fill in test

* reformat

* document c api

* tweak comment format
  • Loading branch information
rnburn committed May 10, 2024
1 parent 92659f4 commit 8a789ed
Show file tree
Hide file tree
Showing 10 changed files with 119 additions and 5 deletions.
1 change: 1 addition & 0 deletions cbindings/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ sxt_cc_component(
"//sxt/curve21/operation:add",
"//sxt/curve21/operation:double",
"//sxt/curve21/operation:neg",
"//sxt/curve21/operation:overload",
"//sxt/curve21/type:element_p3",
"//sxt/curve21/type:literal",
],
Expand Down
32 changes: 32 additions & 0 deletions cbindings/blitzar_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,38 @@ struct sxt_multiexp_handle* sxt_multiexp_handle_new(unsigned curve_id, const voi
*/
void sxt_multiexp_handle_free(struct sxt_multiexp_handle* handle);

/**
* Compute a multiexponentiation using a handle to pre-specified generators.
*
* On completion `res` contains an array of size `num_outputs` for the multiexponentiation
* of the given `scalars` array.
*
* `scalars` specifies a contiguous multi-dimension `num_outputs` by `n` array laid out in
* column-major order. An entry in the array specifies the `element_num_bytes` bytes of a
* particular scalar.
*
* For example, if `g_1, g_2, ..., g_n` are the generators associated with `handle` and
*
* ```text
* s_11, s_12, ..., s_1n
* s_21, s_22, ..., s_2n
* ```
*
* is the scalar array (laid out in memory as `s_11, s_21, s_12, s_22, ..., s_1n, s_2n`), then `res`
* will contain the two values
*
* ```text
* res[0] = g1^s11 g2^s12 ... gn^s1n
* res[1] = g1^s21 g2^s22 ... gn^s2n
* ```
*
* Note: `res` must match the generator type of the curve. See `sxt_multiexp_handle_new` for
* the types.
*/
void sxt_fixed_multiexponentiation(void* res, const struct sxt_multiexp_handle* handle,
unsigned element_num_bytes, unsigned num_outputs, unsigned n,
const uint8_t* scalars);

#ifdef __cplusplus
} // extern "C"
#endif
12 changes: 12 additions & 0 deletions cbindings/fixed_pedersen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,15 @@ struct sxt_multiexp_handle* sxt_multiexp_handle_new(unsigned curve_id, const voi
void sxt_multiexp_handle_free(struct sxt_multiexp_handle* handle) {
delete reinterpret_cast<cbnb::multiexp_handle*>(handle);
}

//--------------------------------------------------------------------------------------------------
// sxt_fixed_multiexponentiation
//--------------------------------------------------------------------------------------------------
void sxt_fixed_multiexponentiation(void* res, const struct sxt_multiexp_handle* handle,
unsigned element_num_bytes, unsigned num_outputs, unsigned n,
const uint8_t* scalars) {
auto backend = cbn::get_backend();
auto h = reinterpret_cast<const cbnb::multiexp_handle*>(handle);
backend->fixed_multiexponentiation(res, h->curve_id, *h->partition_table_accessor,
element_num_bytes, num_outputs, n, scalars);
}
25 changes: 20 additions & 5 deletions cbindings/fixed_pedersen.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,39 @@
#include "sxt/curve21/operation/add.h"
#include "sxt/curve21/operation/double.h"
#include "sxt/curve21/operation/neg.h"
#include "sxt/curve21/operation/overload.h"
#include "sxt/curve21/type/element_p3.h"
#include "sxt/curve21/type/literal.h"

using namespace sxt;
using sxt::c21t::operator""_c21;

struct wrapped_handle {
wrapped_handle(const c21t::element_p3* generators, unsigned n) noexcept {
h = sxt_multiexp_handle_new(SXT_CURVE_RISTRETTO255, static_cast<const void*>(generators), n);
}

~wrapped_handle() noexcept { sxt_multiexp_handle_free(h); }

sxt_multiexp_handle* h;
};

TEST_CASE("we can compute multi-exponentiations with a fixed set of generators") {
std::vector<c21t::element_p3> generators = {
0x123_c21,
0x456_c21,
};

const sxt_config config = {SXT_GPU_BACKEND, 0};
REQUIRE(sxt_init(&config) == 0);

SECTION("we can create and free a handle") {
auto h =
sxt_multiexp_handle_new(SXT_CURVE_RISTRETTO255, static_cast<void*>(generators.data()), 1);
REQUIRE(h != nullptr);
sxt_multiexp_handle_free(h);
wrapped_handle h{generators.data(), 2};
REQUIRE(h.h != nullptr);

SECTION("we can compute a multiexponentiation") {
uint8_t scalars[] = {1, 0, 0, 2};
c21t::element_p3 res;
sxt_fixed_multiexponentiation(&res, h.h, 2, 1, 2, scalars);
REQUIRE(res == generators[0] + 2 * 256 * generators[1]);
}
}
3 changes: 3 additions & 0 deletions sxt/cbindings/backend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ sxt_cc_component(
"//sxt/base/error:assert",
"//sxt/proof/transcript:transcript",
"//sxt/scalar25/type:element",
"//sxt/cbindings/base:curve_id_utility",
"//sxt/curve_bng1/operation:add",
"//sxt/curve_bng1/operation:double",
"//sxt/curve_bng1/operation:neg",
Expand All @@ -50,6 +51,7 @@ sxt_cc_component(
"//sxt/ristretto/operation:compression",
"//sxt/multiexp/base:exponent_sequence",
"//sxt/multiexp/curve:multiexponentiation",
"//sxt/multiexp/pippenger2:multiexponentiation",
"//sxt/seqcommit/generator:precomputed_generators",
"//sxt/proof/inner_product:proof_descriptor",
"//sxt/proof/inner_product:proof_computation",
Expand All @@ -65,6 +67,7 @@ sxt_cc_component(
sxt_cc_component(
name = "cpu_backend",
impl_deps = [
"//sxt/base/error:panic",
"//sxt/proof/transcript:transcript",
"//sxt/scalar25/type:element",
"//sxt/curve_bng1/operation:add",
Expand Down
5 changes: 5 additions & 0 deletions sxt/cbindings/backend/computational_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,10 @@ class computational_backend {
virtual std::unique_ptr<mtxpp2::partition_table_accessor_base>
make_partition_table_accessor(cbnb::curve_id_t curve_id, const void* generators,
unsigned n) const noexcept;

virtual void fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
unsigned element_num_bytes, unsigned num_outputs,
unsigned n, const uint8_t* scalars) const noexcept = 0;
};
} // namespace sxt::cbnbck
17 changes: 17 additions & 0 deletions sxt/cbindings/backend/cpu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <vector>

#include "sxt/base/error/assert.h"
#include "sxt/base/error/panic.h"
#include "sxt/curve21/operation/add.h"
#include "sxt/curve21/operation/double.h"
#include "sxt/curve21/operation/neg.h"
Expand Down Expand Up @@ -118,6 +119,22 @@ bool cpu_backend::verify_inner_product(prft::transcript& transcript,
.value();
}

//--------------------------------------------------------------------------------------------------
// fixed_multiexponentiation
//--------------------------------------------------------------------------------------------------
void cpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
unsigned element_num_bytes, unsigned num_outputs,
unsigned n, const uint8_t* scalars) const noexcept {
(void)res;
(void)curve_id;
(void)accessor;
(void)num_outputs;
(void)n;
(void)scalars;
baser::panic("not implemented yet");
}

//--------------------------------------------------------------------------------------------------
// get_cpu_backend
//--------------------------------------------------------------------------------------------------
Expand Down
5 changes: 5 additions & 0 deletions sxt/cbindings/backend/cpu_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ class cpu_backend final : public computational_backend {
basct::cspan<rstt::compressed_element> l_vector,
basct::cspan<rstt::compressed_element> r_vector,
const s25t::element& ap_value) const noexcept override;

void fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
unsigned element_num_bytes, unsigned num_outputs, unsigned n,
const uint8_t* scalars) const noexcept override;
};

//--------------------------------------------------------------------------------------------------
Expand Down
19 changes: 19 additions & 0 deletions sxt/cbindings/backend/gpu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <vector>

#include "sxt/base/error/assert.h"
#include "sxt/cbindings/base/curve_id_utility.h"
#include "sxt/curve21/operation/add.h"
#include "sxt/curve21/operation/double.h"
#include "sxt/curve21/operation/neg.h"
Expand All @@ -40,6 +41,7 @@
#include "sxt/memory/management/managed_array.h"
#include "sxt/multiexp/base/exponent_sequence.h"
#include "sxt/multiexp/curve/multiexponentiation.h"
#include "sxt/multiexp/pippenger2/multiexponentiation.h"
#include "sxt/proof/inner_product/gpu_driver.h"
#include "sxt/proof/inner_product/proof_computation.h"
#include "sxt/proof/inner_product/proof_descriptor.h"
Expand Down Expand Up @@ -156,6 +158,23 @@ bool gpu_backend::verify_inner_product(prft::transcript& transcript,
return fut.value();
}

//--------------------------------------------------------------------------------------------------
// fixed_multiexponentiation
//--------------------------------------------------------------------------------------------------
void gpu_backend::fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
unsigned element_num_bytes, unsigned num_outputs,
unsigned n, const uint8_t* scalars) const noexcept {
cbnb::switch_curve_type(curve_id, [&]<class T>(std::type_identity<T>) noexcept {
basct::span<T> res_span{static_cast<T*>(res), num_outputs};
basct::cspan<uint8_t> scalars_span{scalars, element_num_bytes * num_outputs * n};
auto fut = mtxpp2::multiexponentiate<T>(
res_span, static_cast<const mtxpp2::partition_table_accessor<T>&>(accessor),
element_num_bytes, scalars_span);
xens::get_scheduler().run();
});
}

//--------------------------------------------------------------------------------------------------
// get_gpu_backend
//--------------------------------------------------------------------------------------------------
Expand Down
5 changes: 5 additions & 0 deletions sxt/cbindings/backend/gpu_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ class gpu_backend final : public computational_backend {
basct::cspan<rstt::compressed_element> l_vector,
basct::cspan<rstt::compressed_element> r_vector,
const s25t::element& ap_value) const noexcept override;

void fixed_multiexponentiation(void* res, cbnb::curve_id_t curve_id,
const mtxpp2::partition_table_accessor_base& accessor,
unsigned element_num_bytes, unsigned num_outputs, unsigned n,
const uint8_t* scalars) const noexcept override;
};

//--------------------------------------------------------------------------------------------------
Expand Down

0 comments on commit 8a789ed

Please sign in to comment.