Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use multiple GPUs for Pippenger's partition algorithm (PROOF-831) #127

Merged
merged 99 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
99 commits
Select commit Hold shift + click to select a range
e0d60a6
add accessor
rnburn Apr 3, 2024
6ab3857
add stub for mx
rnburn Apr 3, 2024
361bc0a
add stub for partition product
rnburn Apr 4, 2024
4b68d75
add stub for partition product function
rnburn Apr 4, 2024
232c778
fill in partition product
rnburn Apr 4, 2024
b377956
fill in partition product
rnburn Apr 4, 2024
9313c1a
fill in partition product
rnburn Apr 4, 2024
02610a1
fix typo
rnburn Apr 4, 2024
1377c06
small tweak
rnburn Apr 4, 2024
5383574
fill in partition product
rnburn Apr 4, 2024
00c5f84
fill in kernel
rnburn Apr 4, 2024
46ed865
Merge branch 'main' of github.com:spaceandtimelabs/blitzar into pip-part
rnburn Apr 6, 2024
db43ccb
reformat
rnburn Apr 6, 2024
4ba6953
fill in partition product
rnburn Apr 8, 2024
99f5cb3
fill in partition product
rnburn Apr 8, 2024
2220f45
fill in partition product
rnburn Apr 8, 2024
604db05
rework partition product kernel
rnburn Apr 8, 2024
365b317
fill in partition product
rnburn Apr 9, 2024
9f2464b
fill in partition product
rnburn Apr 9, 2024
23b2f2b
fill in partition product
rnburn Apr 9, 2024
05700d7
fill in partition product
rnburn Apr 9, 2024
14c35ee
fill in partition product
rnburn Apr 9, 2024
a19ace5
fill in partition product
rnburn Apr 9, 2024
bbf2c7f
fill in partition product
rnburn Apr 9, 2024
84e7a3b
fill in partition product tests
rnburn Apr 10, 2024
9072a65
test partition product
rnburn Apr 10, 2024
54f18ef
reformat
rnburn Apr 10, 2024
9caf526
comment
rnburn Apr 10, 2024
e6d04b3
Merge branch 'main' of github.com:spaceandtimelabs/blitzar into pip-part
rnburn Apr 11, 2024
1842e61
add reduce component
rnburn Apr 11, 2024
94c6702
fill in multiexponentiation
rnburn Apr 11, 2024
1c7f96d
add stub for partition table accessor
rnburn Apr 11, 2024
99e27ad
fill in partition table accessor
rnburn Apr 11, 2024
5c354c5
make in memory accessor
rnburn Apr 11, 2024
7484d51
add accessor utility
rnburn Apr 11, 2024
4f891fa
fill in partition table accessor utility
rnburn Apr 11, 2024
a054edf
fill in testing
rnburn Apr 12, 2024
8d16865
fill in mx testing
rnburn Apr 12, 2024
6f64615
fill in multiexponentiation testing
rnburn Apr 12, 2024
0352299
fill in mx testing
rnburn Apr 12, 2024
6d3f78d
fill in tests
rnburn Apr 12, 2024
95eafed
fill in testing
rnburn Apr 12, 2024
974300c
fill in testing
rnburn Apr 12, 2024
32b9597
fill in testing
rnburn Apr 12, 2024
5576520
add stub for new benchmark
rnburn Apr 12, 2024
b00f7f6
fill in benchmark
rnburn Apr 12, 2024
c088d72
fill in benchmark
rnburn Apr 13, 2024
362d635
reformat
rnburn Apr 13, 2024
68cbecc
fill in benchmark
rnburn Apr 13, 2024
1d6bc37
fill in benchmark
rnburn Apr 13, 2024
dbd60fb
fill in benchmark
rnburn Apr 15, 2024
a4d3f41
fill in benchmark
rnburn Apr 15, 2024
70414c7
work on benchmark
rnburn Apr 15, 2024
0a7355a
fix loop
rnburn Apr 16, 2024
a1d7ecc
benchmark
rnburn Apr 16, 2024
df53182
reformat
rnburn Apr 17, 2024
2028965
Merge branch 'main' of github.com:spaceandtimelabs/blitzar into pip-part
rnburn Apr 17, 2024
feaccd4
add comments
rnburn Apr 17, 2024
451b22c
fill in testing
rnburn Apr 17, 2024
05c5001
add test
rnburn Apr 17, 2024
7ea297b
fill in benchmark
rnburn Apr 18, 2024
c159a04
fill in benchmark
rnburn Apr 18, 2024
25248aa
fill in benchmark
rnburn Apr 18, 2024
421870a
reformat
rnburn Apr 18, 2024
284e89e
fix deps
rnburn Apr 19, 2024
71cae2a
work on chunking
rnburn Apr 19, 2024
8408e11
work on chunk multiple support
rnburn Apr 19, 2024
9ac74ef
work on chunk multiple
rnburn Apr 19, 2024
b18a407
chunk multiple support
rnburn Apr 19, 2024
8e2f80a
chunking
rnburn Apr 19, 2024
dd021ae
Merge branch 'main' of github.com:spaceandtimelabs/blitzar into pip-part
rnburn Apr 19, 2024
3bb8ffc
work on chunk support
rnburn Apr 20, 2024
c8201c5
chunk support
rnburn Apr 20, 2024
36540b2
chunking
rnburn Apr 20, 2024
92a74fc
add stub for combination step
rnburn Apr 22, 2024
bcf8768
fill in combination
rnburn Apr 23, 2024
c4cff16
fill in combination
rnburn Apr 23, 2024
f6a08c3
fill in testing
rnburn Apr 23, 2024
036ef75
fill in combination
rnburn Apr 23, 2024
5c98919
add stub for partial combine
rnburn Apr 24, 2024
742d2c5
fill in combination
rnburn Apr 24, 2024
a33cf0a
merge
rnburn Apr 25, 2024
8008032
Merge branch 'main' of github.com:spaceandtimelabs/blitzar into pip-part
rnburn Apr 25, 2024
15dd27b
partial combinations
rnburn Apr 25, 2024
bf42bd0
work on partial combinations
rnburn Apr 25, 2024
5c2b0d4
fill in combine testing
rnburn Apr 25, 2024
f870a44
multi-gpu support
rnburn Apr 25, 2024
055eae3
fill in mx
rnburn Apr 26, 2024
20ff7e7
work on multiple device support
rnburn Apr 26, 2024
14a4ebd
add test case
rnburn Apr 26, 2024
b5f0674
fill in testing
rnburn Apr 26, 2024
3a32bb4
set up chunking options
rnburn Apr 29, 2024
981e814
comment
rnburn Apr 29, 2024
840a731
add logging
rnburn Apr 29, 2024
90641a3
add logging
rnburn Apr 29, 2024
b97c845
add logging
rnburn Apr 29, 2024
26950d3
Merge branch 'main' of github.com:spaceandtimelabs/blitzar into pip-part
rnburn Apr 30, 2024
d8a5002
add logging
rnburn Apr 30, 2024
245c703
minor tweaks
rnburn Apr 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion sxt/multiexp/pippenger2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,27 @@ sxt_cc_component(
"//sxt/ristretto/random:element",
],
deps = [
":combination",
":partition_product",
":partition_table_accessor",
":reduce",
"//sxt/base/container:span",
"//sxt/base/container:span_utility",
"//sxt/base/curve:element",
"//sxt/execution/async:future",
"//sxt/base/device:memory_utility",
"//sxt/base/device:property",
"//sxt/base/device:state",
"//sxt/base/device:stream",
"//sxt/base/device:synchronization",
"//sxt/base/iterator:index_range_iterator",
"//sxt/base/iterator:index_range_utility",
"//sxt/base/log",
"//sxt/execution/async:coroutine",
"//sxt/execution/device:for_each",
"//sxt/memory/management:managed_array",
"//sxt/memory/resource:async_device_resource",
"//sxt/memory/resource:device_resource",
"//sxt/memory/resource:pinned_resource",
],
)

Expand Down
154 changes: 143 additions & 11 deletions sxt/multiexp/pippenger2/multiexponentiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,50 @@
*/
#pragma once

#include <iterator>

#include "sxt/base/container/span.h"
#include "sxt/base/container/span_utility.h"
#include "sxt/base/curve/element.h"
#include "sxt/base/device/memory_utility.h"
#include "sxt/base/device/property.h"
#include "sxt/base/device/state.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/error/assert.h"
#include "sxt/execution/async/future.h"
#include "sxt/base/iterator/index_range_iterator.h"
#include "sxt/base/iterator/index_range_utility.h"
#include "sxt/base/log/log.h"
#include "sxt/execution/async/coroutine.h"
#include "sxt/execution/device/for_each.h"
#include "sxt/execution/device/synchronization.h"
#include "sxt/memory/management/managed_array.h"
#include "sxt/memory/resource/async_device_resource.h"
#include "sxt/memory/resource/device_resource.h"
#include "sxt/memory/resource/pinned_resource.h"
#include "sxt/multiexp/pippenger2/combination.h"
#include "sxt/multiexp/pippenger2/partition_product.h"
#include "sxt/multiexp/pippenger2/partition_table_accessor.h"
#include "sxt/multiexp/pippenger2/reduce.h"

namespace sxt::mtxpp2 {
//--------------------------------------------------------------------------------------------------
// multiexponentiate
// multiexponentiate_options
//--------------------------------------------------------------------------------------------------
struct multiexponentiate_options {
unsigned split_factor = 1;
unsigned min_chunk_size = 64;
unsigned max_chunk_size = 1024;
};

//--------------------------------------------------------------------------------------------------
// multiexponentiate_no_chunks
//--------------------------------------------------------------------------------------------------
/**
* Compute a multi-exponentiation using an accessor to precompute sums of partition groups.
*
* This implements the partition part of Pipenger's algorithm. See Algorithm 7 of
* https://cacr.uwaterloo.ca/techreports/2010/cacr2010-26.pdf
*/
template <bascrv::element T>
xena::future<> multiexponentiate(basct::span<T> res, const partition_table_accessor<T>& accessor,
unsigned element_num_bytes,
basct::cspan<uint8_t> scalars) noexcept {
xena::future<>
multiexponentiate_no_chunks(basct::span<T> res, const partition_table_accessor<T>& accessor,
unsigned element_num_bytes, basct::cspan<uint8_t> scalars) noexcept {
auto num_outputs = res.size();
auto n = scalars.size() / (num_outputs * element_num_bytes);
auto num_products = num_outputs * element_num_bytes * 8u;
SXT_DEBUG_ASSERT(
// clang-format off
Expand All @@ -50,10 +68,12 @@ xena::future<> multiexponentiate(basct::span<T> res, const partition_table_acces
);

// compute bitwise products
basl::info("computing {} bitwise multiexponentiation products of length {}", num_products, n);
memmg::managed_array<T> products(num_products, memr::get_device_resource());
co_await partition_product<T>(products, accessor, scalars, 0);

// reduce products
basl::info("reducing {} products to {} outputs", num_products, num_products);
basdv::stream stream;
memr::async_device_resource resource{stream};
memmg::managed_array<T> res_dev{num_outputs, &resource};
Expand All @@ -64,4 +84,116 @@ xena::future<> multiexponentiate(basct::span<T> res, const partition_table_acces
basdv::async_copy_device_to_host(res, res_dev, stream);
co_await xendv::await_stream(stream);
}

//--------------------------------------------------------------------------------------------------
// complete_multiexponentiation
//--------------------------------------------------------------------------------------------------
template <bascrv::element T>
xena::future<> complete_multiexponentiation(basct::span<T> res, unsigned element_num_bytes,
basct::cspan<T> partial_products, unsigned num_products,
unsigned offset) noexcept {
auto num_outputs_slice = res.size();
auto num_products_slice = num_outputs_slice * element_num_bytes * 8u;

// combine the partial results
memmg::managed_array<T> products_slice{num_products_slice, memr::get_device_resource()};
co_await combine_partial<T>(products_slice, partial_products, num_products, offset);

// reduce the products
basdv::stream stream;
memr::async_device_resource resource{stream};
memmg::managed_array<T> res_dev{num_outputs_slice, &resource};
reduce_products<T>(res_dev, stream, products_slice);
products_slice.reset();

// copy result
basdv::async_copy_device_to_host(res, res_dev, stream);
co_await xendv::await_stream(stream);
}

//--------------------------------------------------------------------------------------------------
// multiexponentiate_impl
//--------------------------------------------------------------------------------------------------
template <bascrv::element T>
xena::future<> multiexponentiate_impl(basct::span<T> res,
const partition_table_accessor<T>& accessor,
unsigned element_num_bytes, basct::cspan<uint8_t> scalars,
const multiexponentiate_options& options) noexcept {
auto num_outputs = res.size();
auto n = scalars.size() / (num_outputs * element_num_bytes);
auto num_products = num_outputs * element_num_bytes * 8u;
SXT_DEBUG_ASSERT(
// clang-format off
scalars.size() % (num_outputs * element_num_bytes) == 0
// clang-format on
);

// compute bitwise products
//
// We split the work by groups of generators so that a single chunk will process
// all the outputs for those generators. This minimizes the amount of host->device
// copying we need to do for the table of precomputed sums.
auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, n}
.chunk_multiple(16)
.min_chunk_size(options.min_chunk_size)
.max_chunk_size(options.max_chunk_size),
options.split_factor);
auto num_chunks = std::distance(chunk_first, chunk_last);
if (num_chunks == 1) {
multiexponentiate_no_chunks(res, accessor, element_num_bytes, scalars);
co_return;
}

memmg::managed_array<T> products{num_products * num_chunks, memr::get_pinned_resource()};
size_t chunk_index = 0;
basl::info("computing {} bitwise multiexponentiation products of length {} using {} chunks",
num_products, n, num_chunks);
co_await xendv::concurrent_for_each(
chunk_first, chunk_last, [&](const basit::index_range& rng) noexcept -> xena::future<> {
basl::info("computing {} multiproducts for generators [{}, {}] on device {}", num_products,
rng.a(), rng.b(), basdv::get_device());
memmg::managed_array<T> products_dev{num_products, memr::get_device_resource()};
auto scalars_slice = scalars.subspan(num_outputs * element_num_bytes * rng.a(),
rng.size() * num_outputs * element_num_bytes);
co_await partition_product<T>(products_dev, accessor, scalars_slice, rng.a());
basdv::stream stream;
basdv::async_copy_device_to_host(
basct::subspan(products, num_products * chunk_index, num_products), products_dev,
stream);
++chunk_index;
co_await xendv::await_stream(stream);
});

// complete the multi-exponentiation by splitting the remaining work by output
auto [output_first, output_last] =
basit::split(basit::index_range{0, num_outputs}, options.split_factor);
basl::info("reducing products for {} outputs using {} chunks", num_outputs,
std::distance(output_first, output_last));
co_await xendv::concurrent_for_each(
output_first, output_last, [&](const basit::index_range& rng) noexcept -> xena::future<> {
basl::info("reducing products for outputs [{}, {}] on device {}", rng.a(), rng.b(),
basdv::get_device());
co_await complete_multiexponentiation<T>(res.subspan(rng.a(), rng.size()),
element_num_bytes, products, num_products,
rng.a() * element_num_bytes * 8u);
});
}

//--------------------------------------------------------------------------------------------------
// multiexponentiate
//--------------------------------------------------------------------------------------------------
/**
* Compute a multi-exponentiation using an accessor to precompute sums of partition groups.
*
* This implements the partition part of Pipenger's algorithm. See Algorithm 7 of
* https://cacr.uwaterloo.ca/techreports/2010/cacr2010-26.pdf
*/
template <bascrv::element T>
xena::future<> multiexponentiate(basct::span<T> res, const partition_table_accessor<T>& accessor,
unsigned element_num_bytes,
basct::cspan<uint8_t> scalars) noexcept {
multiexponentiate_options options;
options.split_factor = static_cast<unsigned>(basdv::get_num_devices());
return multiexponentiate_impl(res, accessor, element_num_bytes, scalars, options);
}
} // namespace sxt::mtxpp2
32 changes: 32 additions & 0 deletions sxt/multiexp/pippenger2/multiexponentiation.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,38 @@ TEST_CASE("we can compute multiexponentiations using a precomputed table of part
REQUIRE(res[0] == generators[0].value);
REQUIRE(res[1] == 2u * generators[0].value);
}

SECTION("we can split a multi-exponentiation") {
multiexponentiate_options options{
.split_factor = 2,
.min_chunk_size = 16u,
};
scalars.resize(32);
scalars[0] = 1;
scalars[16] = 1;
auto fut = multiexponentiate_impl<E>(res, *accessor, 1, scalars, options);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == generators[0].value + generators[16].value);
}

SECTION("we can split a multi-exponentiation with more than one output") {
multiexponentiate_options options{
.split_factor = 2,
.min_chunk_size = 16u,
};
scalars.resize(64);
scalars[0] = 1;
scalars[1] = 2;
scalars[32] = 3;
scalars[33] = 4;
res.resize(2);
auto fut = multiexponentiate_impl<E>(res, *accessor, 1, scalars, options);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == generators[0].value + 3 * generators[16].value);
REQUIRE(res[1] == 2 * generators[0].value + 4 * generators[16].value);
}
}

TEST_CASE("we can compute multiexponentiations with curve-21") {
Expand Down
Loading