Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Naive aggregation for
SyncCommitteeContribution
(#9114)
* Copy CopySyncCommitteeContribution * Update BUILD.bazel * Add naive aggregation for sync contribution * Gazelle * Update deps.bzl Co-authored-by: Nishant Das <nishdas93@gmail.com>
- Loading branch information
1 parent
0347b4b
commit cd3a2e8
Showing
11 changed files
with
383 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
load("@io_bazel_rules_go//go:def.bzl", "go_test") | ||
load("@prysm//tools/go:def.bzl", "go_library") | ||
|
||
go_library( | ||
name = "go_default_library", | ||
srcs = [ | ||
"contribution.go", | ||
"naive.go", | ||
], | ||
importpath = "github.com/prysmaticlabs/prysm/shared/aggregation/sync_contribution", | ||
visibility = ["//visibility:public"], | ||
deps = [ | ||
"//proto/prysm/v2:go_default_library", | ||
"//shared/aggregation:go_default_library", | ||
"//shared/bls:go_default_library", | ||
"//shared/copyutil:go_default_library", | ||
"@com_github_pkg_errors//:go_default_library", | ||
"@com_github_sirupsen_logrus//:go_default_library", | ||
], | ||
) | ||
|
||
go_test( | ||
name = "go_default_test", | ||
srcs = ["naive_test.go"], | ||
embed = [":go_default_library"], | ||
deps = [ | ||
"//proto/prysm/v2:go_default_library", | ||
"//shared/aggregation:go_default_library", | ||
"//shared/aggregation/testing:go_default_library", | ||
"//shared/bls:go_default_library", | ||
"//shared/featureconfig:go_default_library", | ||
"//shared/testutil/assert:go_default_library", | ||
"//shared/testutil/require:go_default_library", | ||
"@com_github_prysmaticlabs_go_bitfield//:go_default_library", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
package sync_contribution | ||
|
||
import ( | ||
"github.com/pkg/errors" | ||
v2 "github.com/prysmaticlabs/prysm/proto/prysm/v2" | ||
"github.com/prysmaticlabs/prysm/shared/aggregation" | ||
"github.com/sirupsen/logrus" | ||
) | ||
|
||
const ( | ||
// NaiveAggregation is an aggregation strategy without any optimizations. | ||
NaiveAggregation SyncContributionAggregationStrategy = "naive" | ||
|
||
// MaxCoverAggregation is a strategy based on Maximum Coverage greedy algorithm. | ||
MaxCoverAggregation SyncContributionAggregationStrategy = "max_cover" | ||
) | ||
|
||
// SyncContributionAggregationStrategy defines SyncContribution aggregation strategy. | ||
type SyncContributionAggregationStrategy string | ||
|
||
var _ = logrus.WithField("prefix", "aggregation.sync_contribution") | ||
|
||
// ErrInvalidSyncContributionCount is returned when insufficient number | ||
// of sync contributions is provided for aggregation. | ||
var ErrInvalidSyncContributionCount = errors.New("invalid number of sync contributions") | ||
|
||
// Aggregate aggregates sync contributions. The minimal number of sync contributions is returned. | ||
// Aggregation occurs in-place i.e. contents of input array will be modified. Should you need to | ||
// preserve input sync contributions, clone them before aggregating. | ||
func Aggregate(cs []*v2.SyncCommitteeContribution) ([]*v2.SyncCommitteeContribution, error) { | ||
strategy := NaiveAggregation | ||
switch strategy { | ||
case "", NaiveAggregation: | ||
return naiveSyncContributionAggregation(cs) | ||
case MaxCoverAggregation: | ||
// TODO: Implement max cover aggregation for sync contributions. | ||
return nil, errors.New("no implemented") | ||
default: | ||
return nil, errors.Wrapf(aggregation.ErrInvalidStrategy, "%q", strategy) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package sync_contribution | ||
|
||
import ( | ||
v2 "github.com/prysmaticlabs/prysm/proto/prysm/v2" | ||
"github.com/prysmaticlabs/prysm/shared/aggregation" | ||
"github.com/prysmaticlabs/prysm/shared/bls" | ||
"github.com/prysmaticlabs/prysm/shared/copyutil" | ||
) | ||
|
||
// naiveSyncContributionAggregation aggregates naively, without any complex algorithms or optimizations. | ||
// Note: this is currently a naive implementation to the order of O(mn^2). | ||
func naiveSyncContributionAggregation(contributions []*v2.SyncCommitteeContribution) ([]*v2.SyncCommitteeContribution, error) { | ||
if len(contributions) <= 1 { | ||
return contributions, nil | ||
} | ||
|
||
// Naive aggregation. O(n^2) time. | ||
for i, a := range contributions { | ||
if i >= len(contributions) { | ||
break | ||
} | ||
for j := i + 1; j < len(contributions); j++ { | ||
b := contributions[j] | ||
if a.AggregationBits.Len() == b.AggregationBits.Len() && !a.AggregationBits.Overlaps(b.AggregationBits) { | ||
var err error | ||
a, err = aggregate(a, b) | ||
if err != nil { | ||
return nil, err | ||
} | ||
// Delete b | ||
contributions = append(contributions[:j], contributions[j+1:]...) | ||
j-- | ||
contributions[i] = a | ||
} | ||
} | ||
} | ||
|
||
// Naive deduplication of identical contributions. O(n^2) time. | ||
for i, a := range contributions { | ||
for j := i + 1; j < len(contributions); j++ { | ||
b := contributions[j] | ||
|
||
if a.AggregationBits.Len() != b.AggregationBits.Len() { | ||
continue | ||
} | ||
|
||
if a.AggregationBits.Contains(b.AggregationBits) { | ||
// If b is fully contained in a, then b can be removed. | ||
contributions = append(contributions[:j], contributions[j+1:]...) | ||
j-- | ||
} else if b.AggregationBits.Contains(a.AggregationBits) { | ||
// if a is fully contained in b, then a can be removed. | ||
contributions = append(contributions[:i], contributions[i+1:]...) | ||
break // Stop the inner loop, advance a. | ||
} | ||
} | ||
} | ||
|
||
return contributions, nil | ||
} | ||
|
||
// aggregates pair of sync contributions c1 and c2 together. | ||
func aggregate(c1, c2 *v2.SyncCommitteeContribution) (*v2.SyncCommitteeContribution, error) { | ||
if c1.AggregationBits.Overlaps(c2.AggregationBits) { | ||
return nil, aggregation.ErrBitsOverlap | ||
} | ||
|
||
baseContribution := copyutil.CopySyncCommitteeContribution(c1) | ||
newContribution := copyutil.CopySyncCommitteeContribution(c2) | ||
if newContribution.AggregationBits.Count() > baseContribution.AggregationBits.Count() { | ||
baseContribution, newContribution = newContribution, baseContribution | ||
} | ||
|
||
if baseContribution.AggregationBits.Contains(newContribution.AggregationBits) { | ||
return baseContribution, nil | ||
} | ||
|
||
newBits := baseContribution.AggregationBits.Or(newContribution.AggregationBits) | ||
newSig, err := bls.SignatureFromBytes(newContribution.Signature) | ||
if err != nil { | ||
return nil, err | ||
} | ||
baseSig, err := bls.SignatureFromBytes(baseContribution.Signature) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
aggregatedSig := bls.AggregateSignatures([]bls.Signature{baseSig, newSig}) | ||
baseContribution.Signature = aggregatedSig.Marshal() | ||
baseContribution.AggregationBits = newBits | ||
|
||
return baseContribution, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
package sync_contribution | ||
|
||
import ( | ||
"fmt" | ||
"sort" | ||
"testing" | ||
|
||
"github.com/prysmaticlabs/go-bitfield" | ||
prysmv2 "github.com/prysmaticlabs/prysm/proto/prysm/v2" | ||
"github.com/prysmaticlabs/prysm/shared/aggregation" | ||
aggtesting "github.com/prysmaticlabs/prysm/shared/aggregation/testing" | ||
"github.com/prysmaticlabs/prysm/shared/bls" | ||
"github.com/prysmaticlabs/prysm/shared/featureconfig" | ||
"github.com/prysmaticlabs/prysm/shared/testutil/assert" | ||
"github.com/prysmaticlabs/prysm/shared/testutil/require" | ||
) | ||
|
||
func TestAggregateAttestations_aggregate(t *testing.T) { | ||
tests := []struct { | ||
a1 *prysmv2.SyncCommitteeContribution | ||
a2 *prysmv2.SyncCommitteeContribution | ||
want *prysmv2.SyncCommitteeContribution | ||
}{ | ||
{ | ||
a1: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x02}, Signature: bls.NewAggregateSignature().Marshal()}, | ||
a2: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x01}, Signature: bls.NewAggregateSignature().Marshal()}, | ||
want: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x03}}, | ||
}, | ||
{ | ||
a1: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x01}, Signature: bls.NewAggregateSignature().Marshal()}, | ||
a2: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x02}, Signature: bls.NewAggregateSignature().Marshal()}, | ||
want: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x03}}, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
got, err := aggregate(tt.a1, tt.a2) | ||
require.NoError(t, err) | ||
require.DeepSSZEqual(t, tt.want.AggregationBits, got.AggregationBits) | ||
} | ||
} | ||
|
||
func TestAggregateAttestations_aggregate_OverlapFails(t *testing.T) { | ||
tests := []struct { | ||
a1 *prysmv2.SyncCommitteeContribution | ||
a2 *prysmv2.SyncCommitteeContribution | ||
}{ | ||
{ | ||
a1: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x1F}}, | ||
a2: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x11}}, | ||
}, | ||
{ | ||
a1: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0xFF, 0x85}}, | ||
a2: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x13, 0x8F}}, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
_, err := aggregate(tt.a1, tt.a2) | ||
require.ErrorContains(t, aggregation.ErrBitsOverlap.Error(), err) | ||
} | ||
} | ||
|
||
func TestAggregateAttestations_Aggregate(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
inputs []bitfield.Bitvector128 | ||
want []bitfield.Bitvector128 | ||
}{ | ||
{ | ||
name: "empty list", | ||
inputs: []bitfield.Bitvector128{}, | ||
want: []bitfield.Bitvector128{}, | ||
}, | ||
{ | ||
name: "single attestation", | ||
inputs: []bitfield.Bitvector128{ | ||
{0b00000010}, | ||
}, | ||
want: []bitfield.Bitvector128{ | ||
{0b00000010}, | ||
}, | ||
}, | ||
{ | ||
name: "two attestations with no overlap", | ||
inputs: []bitfield.Bitvector128{ | ||
{0b00000001}, | ||
{0b00000010}, | ||
}, | ||
want: []bitfield.Bitvector128{ | ||
{0b00000011}, | ||
}, | ||
}, | ||
{ | ||
name: "two attestations with overlap", | ||
inputs: []bitfield.Bitvector128{ | ||
{0b00000101}, | ||
{0b00000110}, | ||
}, | ||
want: []bitfield.Bitvector128{ | ||
{0b00000101}, | ||
{0b00000110}, | ||
}, | ||
}, | ||
{ | ||
name: "some attestations overlap", | ||
inputs: []bitfield.Bitvector128{ | ||
{0b00001001}, | ||
{0b00010110}, | ||
{0b00001010}, | ||
{0b00110001}, | ||
}, | ||
want: []bitfield.Bitvector128{ | ||
{0b00111011}, | ||
{0b00011111}, | ||
}, | ||
}, | ||
{ | ||
name: "some attestations produce duplicates which are removed", | ||
inputs: []bitfield.Bitvector128{ | ||
{0b00000101}, | ||
{0b00000110}, | ||
{0b00001010}, | ||
{0b00001001}, | ||
}, | ||
want: []bitfield.Bitvector128{ | ||
{0b00001111}, // both 0&1 and 2&3 produce this bitlist | ||
}, | ||
}, | ||
{ | ||
name: "two attestations where one is fully contained within the other", | ||
inputs: []bitfield.Bitvector128{ | ||
{0b00000001}, | ||
{0b00000011}, | ||
}, | ||
want: []bitfield.Bitvector128{ | ||
{0b00000011}, | ||
}, | ||
}, | ||
{ | ||
name: "two attestations where one is fully contained within the other reversed", | ||
inputs: []bitfield.Bitvector128{ | ||
{0b00000011}, | ||
{0b00000001}, | ||
}, | ||
want: []bitfield.Bitvector128{ | ||
{0b00000011}, | ||
}, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
runner := func() { | ||
got, err := Aggregate(aggtesting.MakeSyncContributionsFromBitVector(tt.inputs)) | ||
require.NoError(t, err) | ||
sort.Slice(got, func(i, j int) bool { | ||
return got[i].AggregationBits.Bytes()[0] < got[j].AggregationBits.Bytes()[0] | ||
}) | ||
sort.Slice(tt.want, func(i, j int) bool { | ||
return tt.want[i].Bytes()[0] < tt.want[j].Bytes()[0] | ||
}) | ||
assert.Equal(t, len(tt.want), len(got)) | ||
for i, w := range tt.want { | ||
assert.DeepEqual(t, w.Bytes(), got[i].AggregationBits.Bytes()) | ||
} | ||
} | ||
t.Run(fmt.Sprintf("%s/%s", tt.name, NaiveAggregation), func(t *testing.T) { | ||
resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{ | ||
AttestationAggregationStrategy: string(NaiveAggregation), | ||
}) | ||
defer resetCfg() | ||
runner() | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.