Skip to content

Commit

Permalink
Max-cover: optimized implementation based on Bitlist64 (#8352)
Browse files Browse the repository at this point in the history
* Max-cover: optimized implementation based on Bitlist64

* gazelle

* re-arrange overlaps check

* minor comments

* add Bitlists64WithMultipleBitSet

* update benchmarks

* gazelle

Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
  • Loading branch information
farazdagi and prylabs-bulldozer[bot] committed Jan 28, 2021
1 parent d9c451d commit 09b1e06
Show file tree
Hide file tree
Showing 5 changed files with 378 additions and 1 deletion.
6 changes: 5 additions & 1 deletion shared/aggregation/BUILD.bazel
Expand Up @@ -18,10 +18,14 @@ go_library(

go_test(
name = "go_default_test",
srcs = ["maxcover_test.go"],
srcs = [
"maxcover_bench_test.go",
"maxcover_test.go",
],
embed = [":go_default_library"],
deps = [
"//shared/aggregation/testing:go_default_library",
"//shared/params:go_default_library",
"//shared/testutil/assert:go_default_library",
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
],
Expand Down
89 changes: 89 additions & 0 deletions shared/aggregation/maxcover.go
Expand Up @@ -95,6 +95,82 @@ func (mc *MaxCoverProblem) Cover(k int, allowOverlaps bool) (*Aggregation, error
return solution, nil
}

// MaxCover greedily solves the Maximum Coverage problem over Bitlist64 candidates.
//
// Up to k candidates are picked (k is capped at len(candidates)); on every round the
// usable candidate covering the most still-uncovered bits wins, with ties going to
// the lowest index. When allowOverlaps is false, any candidate sharing a bit with
// the coverage accumulated so far is discarded. Candidates whose length differs from
// that of candidates[0] are never selected.
//
// Returns:
//   - selected: bitlist of len(candidates) bits, bit i set iff candidates[i] was picked.
//   - coverage: union of all picked candidates.
//   - err: wraps ErrInvalidMaxCoverProblem when candidates is empty or when union
//     reports nothing to cover (first candidate has zero length).
func MaxCover(candidates []*bitfield.Bitlist64, k int, allowOverlaps bool) (selected, coverage *bitfield.Bitlist64, err error) {
	total := len(candidates)
	if total == 0 {
		return nil, nil, errors.Wrap(ErrInvalidMaxCoverProblem, "cannot calculate set coverage")
	}
	if k > total {
		k = total
	}

	// Two bitlists indexed by candidate position: what has been picked,
	// and what is still eligible for picking (initially everything).
	selected = bitfield.NewBitlist64(uint64(total))
	usable := bitfield.NewBitlist64(uint64(total)).Not()

	// Bits covered so far, and bits that some candidate could still contribute.
	coverage = bitfield.NewBitlist64(candidates[0].Len())
	remaining := union(candidates)
	if remaining == nil {
		return nil, nil, errors.Wrap(ErrInvalidMaxCoverProblem, "empty bitlists")
	}

	// Scratch buffers reused across rounds so the NoAlloc*() calls stay allocation free.
	scratch := bitfield.NewBitlist64(candidates[0].Len())
	indices := make([]int, usable.Count())

	// The round counter is a safe-guard: every round is expected to select a candidate,
	// so more than k+1 rounds are never needed.
	for round := 0; round <= k; round++ {
		if selected.Count() >= uint64(k) || usable.Count() == 0 {
			break
		}

		// Greedy step: scan usable candidates for the one covering most remaining bits.
		var bestScore, bestIdx uint64
		indices = indices[:usable.Count()]
		usable.NoAllocBitIndices(indices)
		for _, idx := range indices {
			candidate := candidates[idx]

			// Only uncovered bits count towards the score; a length mismatch scores zero.
			var score uint64
			if candidate.Len() == remaining.Len() {
				score = candidate.AndCount(remaining)
			}

			// Drop candidates that add nothing, have an unexpected length, or (when
			// overlaps are disallowed) intersect the current coverage. The overlap
			// check is last so it is evaluated only when actually needed.
			discard := score == 0 ||
				coverage.Len() != candidate.Len() ||
				(!allowOverlaps && coverage.Overlaps(candidate))
			if discard {
				usable.SetBitAt(uint64(idx), false)
				continue
			}

			// Strict comparison keeps the earliest candidate on ties.
			if score > bestScore {
				bestScore, bestIdx = score, uint64(idx)
			}
		}
		if bestScore == 0 {
			// Nothing selectable this round.
			continue
		}

		// Commit the winner: extend coverage, mark it selected, remove its bits
		// from the remaining set, and retire it from the usable pool.
		coverage.NoAllocOr(candidates[bestIdx], coverage)
		selected.SetBitAt(bestIdx, true)
		candidates[bestIdx].NoAllocNot(scratch)
		remaining.NoAllocAnd(scratch, remaining)
		usable.SetBitAt(bestIdx, false)
	}
	return selected, coverage, nil
}

// score updates scores of candidates, taking into account the uncovered elements only.
func (cl *MaxCoverCandidates) score(uncovered bitfield.Bitlist) *MaxCoverCandidates {
for i := 0; i < len(*cl); i++ {
Expand Down Expand Up @@ -152,6 +228,19 @@ func (cl *MaxCoverCandidates) union() bitfield.Bitlist {
return ret
}

// union ORs together every candidate whose length matches the first candidate's
// length and returns the combined bitlist. Candidates of any other length are
// ignored. It returns nil when there are no candidates or when the first
// candidate has zero length.
func union(candidates []*bitfield.Bitlist64) *bitfield.Bitlist64 {
	if len(candidates) == 0 {
		return nil
	}
	size := candidates[0].Len()
	if size == 0 {
		return nil
	}

	merged := bitfield.NewBitlist64(size)
	for i := range candidates {
		if merged.Len() != candidates[i].Len() {
			continue
		}
		merged.NoAllocOr(candidates[i], merged)
	}
	return merged
}

// String provides string representation of a candidate.
func (c *MaxCoverCandidate) String() string {
return fmt.Sprintf("{%v, %#b, s%d, %t}", c.key, c.bits, c.score, c.processed)
Expand Down
106 changes: 106 additions & 0 deletions shared/aggregation/maxcover_bench_test.go
@@ -0,0 +1,106 @@
package aggregation

import (
"fmt"
"testing"

"github.com/prysmaticlabs/go-bitfield"
aggtesting "github.com/prysmaticlabs/prysm/shared/aggregation/testing"
"github.com/prysmaticlabs/prysm/shared/params"
)

// BenchmarkMaxCoverProblem_MaxCover runs, for every scenario, two sub-benchmarks
// side by side: "cur_*" exercises MaxCoverProblem.Cover on bitfield.Bitlist input,
// "new_*" exercises MaxCover on bitfield.Bitlist64 input. Input generation happens
// while the timer is stopped, so only the cover computation (plus, for "cur_*",
// the candidate wrapping) is measured.
func BenchmarkMaxCoverProblem_MaxCover(b *testing.B) {
	bitlistLen := params.BeaconConfig().MaxValidatorsPerCommittee
	scenarios := []struct {
		numCandidates uint64
		numMarkedBits uint64
		allowOverlaps bool
	}{
		{numCandidates: 32, numMarkedBits: 1},
		{numCandidates: 128, numMarkedBits: 1},
		{numCandidates: 256, numMarkedBits: 1},
		{numCandidates: 32, numMarkedBits: 8},
		{numCandidates: 1024, numMarkedBits: 8},
		{numCandidates: 2048, numMarkedBits: 8},
		{numCandidates: 1024, numMarkedBits: 32},
		{numCandidates: 2048, numMarkedBits: 32},
		{numCandidates: 1024, numMarkedBits: 128},
		{numCandidates: 2048, numMarkedBits: 128},
		{numCandidates: 1024, numMarkedBits: 512},
		{numCandidates: 2048, numMarkedBits: 512},
	}
	for _, sc := range scenarios {
		name := fmt.Sprintf("%d_attestations_with_%d_bit(s)_set", sc.numCandidates, sc.numMarkedBits)

		// Existing implementation.
		b.Run(fmt.Sprintf("cur_%s", name), func(b *testing.B) {
			b.StopTimer()
			var bitlists []bitfield.Bitlist
			if sc.numMarkedBits != 1 {
				bitlists = aggtesting.BitlistsWithMultipleBitSet(b, sc.numCandidates, bitlistLen, sc.numMarkedBits)
			} else {
				bitlists = aggtesting.BitlistsWithSingleBitSet(sc.numCandidates, bitlistLen)
			}
			b.StartTimer()
			for n := 0; n < b.N; n++ {
				candidates := make([]*MaxCoverCandidate, len(bitlists))
				for j := range bitlists {
					candidates[j] = NewMaxCoverCandidate(j, &bitlists[j])
				}
				problem := &MaxCoverProblem{Candidates: candidates}
				// Result and error are discarded: only the running time is of interest here.
				_, _ = problem.Cover(len(bitlists), sc.allowOverlaps)
			}
		})

		// Bitlist64-based implementation.
		b.Run(fmt.Sprintf("new_%s", name), func(b *testing.B) {
			b.StopTimer()
			var bitlists []*bitfield.Bitlist64
			if sc.numMarkedBits != 1 {
				bitlists = aggtesting.Bitlists64WithMultipleBitSet(b, sc.numCandidates, bitlistLen, sc.numMarkedBits)
			} else {
				bitlists = aggtesting.Bitlists64WithSingleBitSet(sc.numCandidates, bitlistLen)
			}
			b.StartTimer()
			for n := 0; n < b.N; n++ {
				// Result and error are discarded: only the running time is of interest here.
				_, _, _ = MaxCover(bitlists, len(bitlists), sc.allowOverlaps)
			}
		})
	}
}
150 changes: 150 additions & 0 deletions shared/aggregation/maxcover_test.go
Expand Up @@ -491,3 +491,153 @@ func TestMaxCover_MaxCoverProblem_Cover(t *testing.T) {
})
}
}

// TestMaxCover_MaxCover is a table-driven test of MaxCover. Each case supplies the
// candidates, k and the overlap policy, and states either the expected error text
// or the expected coverage plus the indices of the selected candidates.
func TestMaxCover_MaxCover(t *testing.T) {
	// sharedCandidates builds a fresh copy of the five-candidate problem reused by the k=N cases.
	sharedCandidates := func() []*bitfield.Bitlist64 {
		return []*bitfield.Bitlist64{
			bitfield.NewBitlist64From([]uint64{0b00000100}),
			bitfield.NewBitlist64From([]uint64{0b00011011}),
			bitfield.NewBitlist64From([]uint64{0b00011011}),
			bitfield.NewBitlist64From([]uint64{0b00000001}),
			bitfield.NewBitlist64From([]uint64{0b00011010}),
		}
	}
	// expectation is the outcome of a successful MaxCover call.
	type expectation struct {
		Coverage *bitfield.Bitlist64
		Keys     []int
	}
	cases := []struct {
		name          string
		k             int
		candidates    []*bitfield.Bitlist64
		allowOverlaps bool
		want          *expectation
		wantedErr     string
	}{
		{
			name:      "nil problem",
			wantedErr: ErrInvalidMaxCoverProblem.Error(),
		},
		{
			name: "different bitlengths (pick first, combine with third)",
			k:    3,
			candidates: []*bitfield.Bitlist64{
				bitfield.NewBitlist64From([]uint64{0b00000001, 0b11100000, 0b10000000}),
				bitfield.NewBitlist64From([]uint64{0b00000000, 0b00011111}),
				bitfield.NewBitlist64From([]uint64{0b00000110, 0b00000000, 0b01000000}),
			},
			want: &expectation{
				Coverage: bitfield.NewBitlist64From([]uint64{0b00000111, 0b11100000, 0b11000000}),
				Keys:     []int{0, 2},
			},
		},
		{
			name: "different bitlengths (pick first, no other combination)",
			k:    3,
			candidates: []*bitfield.Bitlist64{
				bitfield.NewBitlist64From([]uint64{0b00000000, 0b00011111}),
				bitfield.NewBitlist64From([]uint64{0b00000001, 0b11100000, 0b1}),
				bitfield.NewBitlist64From([]uint64{0b00000110, 0b00000000, 0b1}),
			},
			want: &expectation{
				Coverage: bitfield.NewBitlist64From([]uint64{0b00000000, 0b00011111}),
				Keys:     []int{0},
			},
		},
		{
			name:       "k=0",
			k:          0,
			candidates: sharedCandidates(),
			want: &expectation{
				Coverage: bitfield.NewBitlist64From([]uint64{0b0}),
				Keys:     []int{},
			},
		},
		{
			name:       "k=1",
			k:          1,
			candidates: sharedCandidates(),
			want: &expectation{
				Coverage: bitfield.NewBitlist64From([]uint64{0b0011011}),
				Keys:     []int{1},
			},
		},
		{
			name:       "k=2",
			k:          2,
			candidates: sharedCandidates(),
			want: &expectation{
				Coverage: bitfield.NewBitlist64From([]uint64{0b0011111}),
				Keys:     []int{0, 1},
			},
		},
		{
			name:       "k=3",
			k:          3,
			candidates: sharedCandidates(),
			want: &expectation{
				Coverage: bitfield.NewBitlist64From([]uint64{0b0011111}),
				Keys:     []int{0, 1},
			},
		},
		{
			name:       "k=5",
			k:          5,
			candidates: sharedCandidates(),
			want: &expectation{
				Coverage: bitfield.NewBitlist64From([]uint64{0b0011111}),
				Keys:     []int{0, 1},
			},
		},
		{
			name:       "k=50",
			k:          50,
			candidates: sharedCandidates(),
			want: &expectation{
				Coverage: bitfield.NewBitlist64From([]uint64{0b0011111}),
				Keys:     []int{0, 1},
			},
		},
		{
			// Greedy selection yields candidates 0, 1, 2 (11 bits covered), whereas the
			// mutually disjoint candidates 3, 4, 5 would cover 15 bits.
			name: "suboptimal",
			k:    3,
			candidates: []*bitfield.Bitlist64{
				bitfield.NewBitlist64From([]uint64{0b00000000, 0b00011111}),
				bitfield.NewBitlist64From([]uint64{0b00000001, 0b11100000}),
				bitfield.NewBitlist64From([]uint64{0b00000110, 0b00000000}),
				bitfield.NewBitlist64From([]uint64{0b00110000, 0b01110000}),
				bitfield.NewBitlist64From([]uint64{0b00000110, 0b10001100}),
				bitfield.NewBitlist64From([]uint64{0b01001001, 0b00000011}),
			},
			want: &expectation{
				Coverage: bitfield.NewBitlist64From([]uint64{0b00000111, 0b11111111}),
				Keys:     []int{0, 1, 2},
			},
		},
		{
			name:          "allow overlaps",
			k:             5,
			allowOverlaps: true,
			candidates: []*bitfield.Bitlist64{
				bitfield.NewBitlist64From([]uint64{0b00000000, 0b00000001, 0b11111110}),
				bitfield.NewBitlist64From([]uint64{0b00000000, 0b00001110, 0b00001110}),
				bitfield.NewBitlist64From([]uint64{0b00000000, 0b01110000, 0b01110000}),
				bitfield.NewBitlist64From([]uint64{0b00000111, 0b10000001, 0b10000000}),
				bitfield.NewBitlist64From([]uint64{0b00000000, 0b00000110, 0b00000110}),
				bitfield.NewBitlist64From([]uint64{0b00000000, 0b00000001, 0b01100010}),
				bitfield.NewBitlist64From([]uint64{0b00001000, 0b00001000, 0b10000010}),
			},
			want: &expectation{
				Coverage: bitfield.NewBitlist64From([]uint64{0b00001111, 0xff, 0b11111110}),
				Keys:     []int{0, 1, 2, 3, 6},
			},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			selected, coverage, err := MaxCover(tc.candidates, tc.k, tc.allowOverlaps)
			if tc.wantedErr != "" {
				assert.ErrorContains(t, tc.wantedErr, err)
				return
			}
			assert.NoError(t, err)
			assert.DeepEqual(t, tc.want.Coverage, coverage)

			// Translate the selection bitlist into candidate indices for comparison.
			keys := make([]int, selected.Count())
			selected.NoAllocBitIndices(keys)
			assert.DeepEqual(t, tc.want.Keys, keys)
		})
	}
}

0 comments on commit 09b1e06

Please sign in to comment.