From 09b1e06885b93e00bcfa74d3a62a578a6c0979c7 Mon Sep 17 00:00:00 2001 From: Victor Farazdagi Date: Thu, 28 Jan 2021 07:55:02 -0800 Subject: [PATCH] Max-cover: optimized implementation based on Bitlist64 (#8352) * Max-cover: optimized implementation based on Bitlist64 * gazelle * re-arrange overlaps check * minor comments * add Bitlists64WithMultipleBitSet * update benchmarks * gazelle Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com> --- shared/aggregation/BUILD.bazel | 6 +- shared/aggregation/maxcover.go | 89 ++++++++++++ shared/aggregation/maxcover_bench_test.go | 106 +++++++++++++++ shared/aggregation/maxcover_test.go | 150 +++++++++++++++++++++ shared/aggregation/testing/bitlistutils.go | 28 ++++ 5 files changed, 378 insertions(+), 1 deletion(-) create mode 100644 shared/aggregation/maxcover_bench_test.go diff --git a/shared/aggregation/BUILD.bazel b/shared/aggregation/BUILD.bazel index 4dbe290a63f..9a4ab2e425d 100644 --- a/shared/aggregation/BUILD.bazel +++ b/shared/aggregation/BUILD.bazel @@ -18,10 +18,14 @@ go_library( go_test( name = "go_default_test", - srcs = ["maxcover_test.go"], + srcs = [ + "maxcover_bench_test.go", + "maxcover_test.go", + ], embed = [":go_default_library"], deps = [ "//shared/aggregation/testing:go_default_library", + "//shared/params:go_default_library", "//shared/testutil/assert:go_default_library", "@com_github_prysmaticlabs_go_bitfield//:go_default_library", ], diff --git a/shared/aggregation/maxcover.go b/shared/aggregation/maxcover.go index 6ea9cae4aa7..2b0a92e8927 100644 --- a/shared/aggregation/maxcover.go +++ b/shared/aggregation/maxcover.go @@ -95,6 +95,82 @@ func (mc *MaxCoverProblem) Cover(k int, allowOverlaps bool) (*Aggregation, error return solution, nil } +// MaxCover finds the k-cover of Maximum Coverage problem. +func MaxCover(candidates []*bitfield.Bitlist64, k int, allowOverlaps bool) (selected, coverage *bitfield.Bitlist64, err error) { + if len(candidates) == 0 { + return nil, nil, errors.Wrap(ErrInvalidMaxCoverProblem, "cannot calculate set coverage") + } + if len(candidates) < k { + k = len(candidates) + } + + // Track usable candidates, and candidates selected for coverage as two bitlists. + selectedCandidates := bitfield.NewBitlist64(uint64(len(candidates))) + usableCandidates := bitfield.NewBitlist64(uint64(len(candidates))).Not() + + // Track bits covered so far as a bitlist. + coveredBits := bitfield.NewBitlist64(candidates[0].Len()) + remainingBits := union(candidates) + if remainingBits == nil { + return nil, nil, errors.Wrap(ErrInvalidMaxCoverProblem, "empty bitlists") + } + + attempts := 0 + tmpBitlist := bitfield.NewBitlist64(candidates[0].Len()) // Used as return param for NoAlloc*() methods. + indices := make([]int, usableCandidates.Count()) + for selectedCandidates.Count() < uint64(k) && usableCandidates.Count() > 0 { + // Safe-guard, each iteration should come with at least one candidate selected. + if attempts > k { + break + } + attempts += 1 + + // Greedy select the next best candidate (from usable ones) to cover the remaining bits maximally. + maxScore := uint64(0) + bestIdx := uint64(0) + indices = indices[0:usableCandidates.Count()] + usableCandidates.NoAllocBitIndices(indices) + for _, idx := range indices { + // Score is calculated by taking into account uncovered bits only. + score := uint64(0) + if candidates[idx].Len() == remainingBits.Len() { + score = candidates[idx].AndCount(remainingBits) + } + + // Filter out zero-score candidates. + if score == 0 { + usableCandidates.SetBitAt(uint64(idx), false) + continue + } + + // Filter out overlapping candidates (if overlapping is not allowed). + wrongLen := coveredBits.Len() != candidates[idx].Len() + overlaps := func(idx int) bool { + return !allowOverlaps && coveredBits.Overlaps(candidates[idx]) + } + if wrongLen || overlaps(idx) { + usableCandidates.SetBitAt(uint64(idx), false) + continue + } + + // Track the candidate with the best score. + if score > maxScore { + maxScore = score + bestIdx = uint64(idx) + } + } + // Process greedy selected candidate. + if maxScore > 0 { + coveredBits.NoAllocOr(candidates[bestIdx], coveredBits) + selectedCandidates.SetBitAt(bestIdx, true) + candidates[bestIdx].NoAllocNot(tmpBitlist) + remainingBits.NoAllocAnd(tmpBitlist, remainingBits) + usableCandidates.SetBitAt(bestIdx, false) + } + } + return selectedCandidates, coveredBits, nil +} + // score updates scores of candidates, taking into account the uncovered elements only. func (cl *MaxCoverCandidates) score(uncovered bitfield.Bitlist) *MaxCoverCandidates { for i := 0; i < len(*cl); i++ { @@ -152,6 +228,19 @@ func (cl *MaxCoverCandidates) union() bitfield.Bitlist { return ret } +func union(candidates []*bitfield.Bitlist64) *bitfield.Bitlist64 { + if len(candidates) == 0 || candidates[0].Len() == 0 { + return nil + } + ret := bitfield.NewBitlist64(candidates[0].Len()) + for _, bl := range candidates { + if ret.Len() == bl.Len() { + ret.NoAllocOr(bl, ret) + } + } + return ret +} + // String provides string representation of a candidate. func (c *MaxCoverCandidate) String() string { return fmt.Sprintf("{%v, %#b, s%d, %t}", c.key, c.bits, c.score, c.processed) diff --git a/shared/aggregation/maxcover_bench_test.go b/shared/aggregation/maxcover_bench_test.go new file mode 100644 index 00000000000..229f94c16be --- /dev/null +++ b/shared/aggregation/maxcover_bench_test.go @@ -0,0 +1,106 @@ +package aggregation + +import ( + "fmt" + "testing" + + "github.com/prysmaticlabs/go-bitfield" + aggtesting "github.com/prysmaticlabs/prysm/shared/aggregation/testing" + "github.com/prysmaticlabs/prysm/shared/params" +) + +func BenchmarkMaxCoverProblem_MaxCover(b *testing.B) { + bitlistLen := params.BeaconConfig().MaxValidatorsPerCommittee + tests := []struct { + numCandidates uint64 + numMarkedBits uint64 + allowOverlaps bool + }{ + { + numCandidates: 32, + numMarkedBits: 1, + }, + { + numCandidates: 128, + numMarkedBits: 1, + }, + { + numCandidates: 256, + numMarkedBits: 1, + }, + { + numCandidates: 32, + numMarkedBits: 8, + }, + { + numCandidates: 1024, + numMarkedBits: 8, + }, + { + numCandidates: 2048, + numMarkedBits: 8, + }, + { + numCandidates: 1024, + numMarkedBits: 32, + }, + { + numCandidates: 2048, + numMarkedBits: 32, + }, + { + numCandidates: 1024, + numMarkedBits: 128, + }, + { + numCandidates: 2048, + numMarkedBits: 128, + }, + { + numCandidates: 1024, + numMarkedBits: 512, + }, + { + numCandidates: 2048, + numMarkedBits: 512, + }, + } + for _, tt := range tests { + name := fmt.Sprintf("%d_attestations_with_%d_bit(s)_set", tt.numCandidates, tt.numMarkedBits) + b.Run(fmt.Sprintf("cur_%s", name), func(b *testing.B) { + b.StopTimer() + var bitlists []bitfield.Bitlist + if tt.numMarkedBits == 1 { + bitlists = aggtesting.BitlistsWithSingleBitSet(tt.numCandidates, bitlistLen) + } else { + bitlists = aggtesting.BitlistsWithMultipleBitSet(b, tt.numCandidates, bitlistLen, tt.numMarkedBits) + + } + b.StartTimer() + for i := 0; i < b.N; i++ { + candidates := make([]*MaxCoverCandidate, len(bitlists)) + for i := 0; i < len(bitlists); i++ { + candidates[i] = NewMaxCoverCandidate(i, &bitlists[i]) + } + mc := &MaxCoverProblem{Candidates: candidates} + _, err := mc.Cover(len(bitlists), tt.allowOverlaps) + _ = err + } + }) + b.Run(fmt.Sprintf("new_%s", name), func(b *testing.B) { + b.StopTimer() + var bitlists []*bitfield.Bitlist64 + if tt.numMarkedBits == 1 { + bitlists = aggtesting.Bitlists64WithSingleBitSet(tt.numCandidates, bitlistLen) + } else { + bitlists = aggtesting.Bitlists64WithMultipleBitSet(b, tt.numCandidates, bitlistLen, tt.numMarkedBits) + + } + b.StartTimer() + for i := 0; i < b.N; i++ { + _, _, err := MaxCover(bitlists, len(bitlists), tt.allowOverlaps) + _ = err + } + }) + } +} diff --git a/shared/aggregation/maxcover_test.go b/shared/aggregation/maxcover_test.go index d13c19f341c..4f2e0d5b41b 100644 --- a/shared/aggregation/maxcover_test.go +++ b/shared/aggregation/maxcover_test.go @@ -491,3 +491,153 @@ func TestMaxCover_MaxCoverProblem_Cover(t *testing.T) { }) } } + +func TestMaxCover_MaxCover(t *testing.T) { + problemSet := func() []*bitfield.Bitlist64 { + return []*bitfield.Bitlist64{ + bitfield.NewBitlist64From([]uint64{0b00000100}), + bitfield.NewBitlist64From([]uint64{0b00011011}), + bitfield.NewBitlist64From([]uint64{0b00011011}), + bitfield.NewBitlist64From([]uint64{0b00000001}), + bitfield.NewBitlist64From([]uint64{0b00011010}), + } + } + type args struct { + k int + candidates []*bitfield.Bitlist64 + allowOverlaps bool + } + type BitSetAggregation struct { + Coverage *bitfield.Bitlist64 + Keys []int + } + tests := []struct { + name string + args args + want *BitSetAggregation + wantedErr string + }{ + { + name: "nil problem", + args: args{}, + wantedErr: ErrInvalidMaxCoverProblem.Error(), + }, + { + name: "different bitlengths (pick first, combine with third)", + args: args{k: 3, candidates: []*bitfield.Bitlist64{ + bitfield.NewBitlist64From([]uint64{0b00000001, 0b11100000, 0b10000000}), + bitfield.NewBitlist64From([]uint64{0b00000000, 0b00011111}), + bitfield.NewBitlist64From([]uint64{0b00000110, 0b00000000, 0b01000000}), + }}, + want: &BitSetAggregation{ + Coverage: bitfield.NewBitlist64From([]uint64{0b00000111, 0b11100000, 0b11000000}), + Keys: []int{0, 2}, + }, + }, + { + name: "different bitlengths (pick first, no other combination)", + args: args{k: 3, candidates: []*bitfield.Bitlist64{ + bitfield.NewBitlist64From([]uint64{0b00000000, 0b00011111}), + bitfield.NewBitlist64From([]uint64{0b00000001, 0b11100000, 0b1}), + bitfield.NewBitlist64From([]uint64{0b00000110, 0b00000000, 0b1}), + }}, + want: &BitSetAggregation{ + Coverage: bitfield.NewBitlist64From([]uint64{0b00000000, 0b00011111}), + Keys: []int{0}, + }, + }, + { + name: "k=0", + args: args{k: 0, candidates: problemSet()}, + want: &BitSetAggregation{ + Coverage: bitfield.NewBitlist64From([]uint64{0b0}), + Keys: []int{}, + }, + }, + { + name: "k=1", + args: args{k: 1, candidates: problemSet()}, + want: &BitSetAggregation{ + Coverage: bitfield.NewBitlist64From([]uint64{0b0011011}), + Keys: []int{1}, + }, + }, + { + name: "k=2", + args: args{k: 2, candidates: problemSet()}, + want: &BitSetAggregation{ + Coverage: bitfield.NewBitlist64From([]uint64{0b0011111}), + Keys: []int{0, 1}, + }, + }, + { + name: "k=3", + args: args{k: 3, candidates: problemSet()}, + want: &BitSetAggregation{ + Coverage: bitfield.NewBitlist64From([]uint64{0b0011111}), + Keys: []int{0, 1}, + }, + }, + { + name: "k=5", + args: args{k: 5, candidates: problemSet()}, + want: &BitSetAggregation{ + Coverage: bitfield.NewBitlist64From([]uint64{0b0011111}), + Keys: []int{0, 1}, + }, + }, + { + name: "k=50", + args: args{k: 50, candidates: problemSet()}, + want: &BitSetAggregation{ + Coverage: bitfield.NewBitlist64From([]uint64{0b0011111}), + Keys: []int{0, 1}, + }, + }, + { + name: "suboptimal", // Greedy algorithm selects: 0, 2, 3, while 1,4,5 is optimal. + args: args{k: 3, candidates: []*bitfield.Bitlist64{ + bitfield.NewBitlist64From([]uint64{0b00000000, 0b00011111}), + bitfield.NewBitlist64From([]uint64{0b00000001, 0b11100000}), + bitfield.NewBitlist64From([]uint64{0b00000110, 0b00000000}), + bitfield.NewBitlist64From([]uint64{0b00110000, 0b01110000}), + bitfield.NewBitlist64From([]uint64{0b00000110, 0b10001100}), + bitfield.NewBitlist64From([]uint64{0b01001001, 0b00000011}), + }}, + want: &BitSetAggregation{ + Coverage: bitfield.NewBitlist64From([]uint64{0b00000111, 0b11111111}), + Keys: []int{0, 1, 2}, + }, + }, + { + name: "allow overlaps", + args: args{k: 5, allowOverlaps: true, candidates: []*bitfield.Bitlist64{ + bitfield.NewBitlist64From([]uint64{0b00000000, 0b00000001, 0b11111110}), + bitfield.NewBitlist64From([]uint64{0b00000000, 0b00001110, 0b00001110}), + bitfield.NewBitlist64From([]uint64{0b00000000, 0b01110000, 0b01110000}), + bitfield.NewBitlist64From([]uint64{0b00000111, 0b10000001, 0b10000000}), + bitfield.NewBitlist64From([]uint64{0b00000000, 0b00000110, 0b00000110}), + bitfield.NewBitlist64From([]uint64{0b00000000, 0b00000001, 0b01100010}), + bitfield.NewBitlist64From([]uint64{0b00001000, 0b00001000, 0b10000010}), + }}, + want: &BitSetAggregation{ + Coverage: bitfield.NewBitlist64From([]uint64{0b00001111, 0xff, 0b11111110}), + Keys: []int{0, 1, 2, 3, 6}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selectedCandidates, coverage, err := MaxCover(tt.args.candidates, tt.args.k, tt.args.allowOverlaps) + if tt.wantedErr != "" { + assert.ErrorContains(t, tt.wantedErr, err) + } else { + assert.NoError(t, err) + assert.DeepEqual(t, tt.want.Coverage, coverage) + selectedKeys := make([]int, selectedCandidates.Count()) + selectedCandidates.NoAllocBitIndices(selectedKeys) + assert.DeepEqual(t, tt.want.Keys, selectedKeys) + } + }) + } +} diff --git a/shared/aggregation/testing/bitlistutils.go b/shared/aggregation/testing/bitlistutils.go index fc8c2b79740..ac49ab92ce4 100644 --- a/shared/aggregation/testing/bitlistutils.go +++ b/shared/aggregation/testing/bitlistutils.go @@ -30,6 +30,17 @@ func BitlistsWithSingleBitSet(n, length uint64) []bitfield.Bitlist { return lists } +// Bitlists64WithSingleBitSet creates list of bitlists with a single bit set in each. +func Bitlists64WithSingleBitSet(n, length uint64) []*bitfield.Bitlist64 { + lists := make([]*bitfield.Bitlist64, n) + for i := uint64(0); i < n; i++ { + b := bitfield.NewBitlist64(length) + b.SetBitAt(i%length, true) + lists[i] = b + } + return lists +} + // BitlistsWithMultipleBitSet creates list of bitlists with random n bits set. func BitlistsWithMultipleBitSet(t testing.TB, n, length, count uint64) []bitfield.Bitlist { seed := timeutils.Now().UnixNano() @@ -47,6 +58,23 @@ func BitlistsWithMultipleBitSet(t testing.TB, n, length, count uint64) []bitfiel return lists } +// Bitlists64WithMultipleBitSet creates list of bitlists with random n bits set. +func Bitlists64WithMultipleBitSet(t testing.TB, n, length, count uint64) []*bitfield.Bitlist64 { + seed := timeutils.Now().UnixNano() + t.Logf("Bitlists64WithMultipleBitSet random seed: %v", seed) + rand.Seed(seed) + lists := make([]*bitfield.Bitlist64, n) + for i := uint64(0); i < n; i++ { + b := bitfield.NewBitlist64(length) + keys := rand.Perm(int(length)) + for _, key := range keys[:count] { + b.SetBitAt(uint64(key), true) + } + lists[i] = b + } + return lists +} + // MakeAttestationsFromBitlists creates list of bitlists from list of attestations. func MakeAttestationsFromBitlists(bl []bitfield.Bitlist) []*ethpb.Attestation { atts := make([]*ethpb.Attestation, len(bl))