diff --git a/proposals/handler.go b/proposals/handler.go index 956751c1b7..ab50a1de43 100644 --- a/proposals/handler.go +++ b/proposals/handler.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "sync" "time" lru "github.com/hashicorp/golang-lru/v2" @@ -48,16 +49,18 @@ type Handler struct { logger log.Log cfg Config - db *sql.Database - atxsdata *atxsdata.Data - activeSets *lru.Cache[types.Hash32, uint64] - edVerifier *signing.EdVerifier - publisher pubsub.Publisher - fetcher system.Fetcher - mesh meshProvider - validator eligibilityValidator - tortoise tortoiseProvider - clock layerClock + db *sql.Database + atxsdata *atxsdata.Data + activeSets *lru.Cache[types.Hash32, uint64] + edVerifier *signing.EdVerifier + publisher pubsub.Publisher + fetcher system.Fetcher + mesh meshProvider + validator eligibilityValidator + tortoise tortoiseProvider + weightCalcLock sync.Mutex + pendingWeightCalc map[types.Hash32][]chan uint64 + clock layerClock proposals proposalsConsumer } @@ -123,18 +126,19 @@ func NewHandler( panic(err) } b := &Handler{ - logger: log.NewNop(), - cfg: defaultConfig(), - db: db, - atxsdata: atxsdata, - proposals: proposals, - activeSets: activeSets, - edVerifier: edVerifier, - publisher: p, - fetcher: f, - mesh: m, - tortoise: tortoise, - clock: clock, + logger: log.NewNop(), + cfg: defaultConfig(), + db: db, + atxsdata: atxsdata, + proposals: proposals, + activeSets: activeSets, + edVerifier: edVerifier, + publisher: p, + fetcher: f, + mesh: m, + tortoise: tortoise, + pendingWeightCalc: make(map[types.Hash32][]chan uint64), + clock: clock, } for _, opt := range opts { opt(b) @@ -519,6 +523,88 @@ func (h *Handler) checkBallotSyntacticValidity( return decoded, nil } +func (h *Handler) getActiveSetWeight(ctx context.Context, id types.Hash32) (uint64, error) { + h.weightCalcLock.Lock() + totalWeight, exists := h.activeSets.Get(id) + if exists { + h.weightCalcLock.Unlock() + return totalWeight, nil + } + + var ch chan uint64 + chs, exists := h.pendingWeightCalc[id] + if exists { + // The calculation is running or the activeset is being fetched, + // subscribe. + // Avoid any blocking on the channel by making it buffered, also so that + // we don't have to wait on it in case the context is canceled + ch = make(chan uint64, 1) + h.pendingWeightCalc[id] = append(chs, ch) + } else { + // mark calculation as running + h.pendingWeightCalc[id] = nil + } + h.weightCalcLock.Unlock() + + if exists { + // need to wait for the calculation which is already running to finish + select { + case <-ctx.Done(): + return 0, ctx.Err() + case totalWeight, ok := <-ch: + if !ok { + // Channel closed, fetch / calculation failed. + // The actual error will be logged by the initiator of the + // initial fetch / calculation, let's not make an + // impression it happened multiple times and use a simpler + // message + return totalWeight, errors.New("error getting activeset weight") + } + return totalWeight, nil + } + } + + success := false + defer func() { + h.weightCalcLock.Lock() + // this is guaranteed not to block b/c each channel is buffered + for _, ch := range h.pendingWeightCalc[id] { + if success { + ch <- totalWeight + } + close(ch) + } + delete(h.pendingWeightCalc, id) + h.weightCalcLock.Unlock() + }() + + if err := h.fetcher.GetActiveSet(ctx, id); err != nil { + return 0, err + } + set, err := activesets.Get(h.db, id) + if err != nil { + return 0, err + } + if len(set.Set) == 0 { + return 0, fmt.Errorf("%w: empty active set", pubsub.ErrValidationReject) + } + + computed, used := h.atxsdata.WeightForSet(set.Epoch, set.Set) + for i := range used { + if !used[i] { + return 0, fmt.Errorf( + "missing atx %s in active set", + set.Set[i].ShortString(), + ) + } + } + totalWeight = computed + h.activeSets.Add(id, totalWeight) + success = true // totalWeight will be sent to the subscribers + + return totalWeight, nil +} + func (h *Handler) checkBallotDataIntegrity(ctx context.Context, b *types.Ballot) (uint64, error) { //nolint:nestif if b.RefBallot == types.EmptyBallotID { @@ -534,36 +620,9 @@ func (h *Handler) checkBallotDataIntegrity(ctx context.Context, b *types.Ballot) epoch-- // download activesets in the previous epoch too } if b.Layer.GetEpoch() >= epoch { - var exists bool - totalWeight, exists := h.activeSets.Get(b.EpochData.ActiveSetHash) - if !exists { - if err := h.fetcher.GetActiveSet(ctx, b.EpochData.ActiveSetHash); err != nil { - return 0, err - } - set, err := activesets.Get(h.db, b.EpochData.ActiveSetHash) - if err != nil { - return 0, err - } - if len(set.Set) == 0 { - return 0, fmt.Errorf( - "%w: empty active set ballot %s", - pubsub.ErrValidationReject, - b.ID().String(), - ) - } - - computed, used := h.atxsdata.WeightForSet(set.Epoch, set.Set) - for i := range used { - if !used[i] { - return 0, fmt.Errorf( - "missing atx %s in active set ballot %s", - set.Set[i].ShortString(), - b.ID().String(), - ) - } - } - totalWeight = computed - h.activeSets.Add(b.EpochData.ActiveSetHash, totalWeight) + totalWeight, err := h.getActiveSetWeight(ctx, b.EpochData.ActiveSetHash) + if err != nil { + return 0, fmt.Errorf("ballot %s: %w", b.ID().String(), err) } return totalWeight, nil } diff --git a/proposals/handler_test.go b/proposals/handler_test.go index 6b9792a202..36915ead62 100644 --- a/proposals/handler_test.go +++ b/proposals/handler_test.go @@ -13,6 +13,7 @@ import ( "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "golang.org/x/sync/errgroup" "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/codec" @@ -1388,9 +1389,9 @@ func TestHandleActiveSet(t *testing.T) { } } -func gproposal(signer *signing.EdSigner, atxid types.ATXID, +func gproposal(t *testing.T, signer *signing.EdSigner, atxid types.ATXID, layer types.LayerID, edata *types.EpochData, -) types.Proposal { +) *types.Proposal { p := types.Proposal{} p.Layer = layer p.AtxID = atxid @@ -1402,39 +1403,78 @@ func gproposal(signer *signing.EdSigner, atxid types.ATXID, if edata != nil { p.SetBeacon(edata.Beacon) } - return p + require.NoError(t, p.Initialize()) + return &p } func TestHandleSyncedProposalActiveSet(t *testing.T) { + // TBD: test concurrent fetches + // TBD: test failures + // TBD: test cancellations + signer, err := signing.NewEdSigner() require.NoError(t, err) - set := types.ATXIDList{{1}, {2}} lid := types.LayerID(20) - good := gproposal(signer, types.ATXID{1}, lid, &types.EpochData{ - ActiveSetHash: set.Hash(), - Beacon: types.Beacon{1}, - }) - require.NoError(t, good.Initialize()) + sets := []types.ATXIDList{ + {{1}, {2}, {3}}, + {{1}, {2}, {4}, {5}}, + {{2}, {4}, {5}, {6}, {7}}, + {{2}, {4}, {5}, {6}, {7}, {8}, {9}}, + } th := createTestHandler(t) + p := make([]*types.Proposal, 8) + for n := range p { + p[n] = gproposal(t, signer, types.ATXID{byte(n + 1)}, lid, &types.EpochData{ + ActiveSetHash: sets[n/2].Hash(), + Beacon: types.Beacon{1}, + }) + th.mconsumer.EXPECT().IsKnown(p[n].Layer, p[n].ID()).AnyTimes() + } pid := p2p.Peer("any") th.mclock.EXPECT().CurrentLayer().Return(lid).AnyTimes() th.mm.EXPECT().ProcessedLayer().Return(lid - 2).AnyTimes() - th.mclock.EXPECT().LayerToTime(gomock.Any()) + th.mclock.EXPECT().LayerToTime(gomock.Any()).AnyTimes() th.mf.EXPECT().RegisterPeerHashes(pid, gomock.Any()).AnyTimes() - th.mf.EXPECT().GetActiveSet(gomock.Any(), set.Hash()).DoAndReturn( - func(_ context.Context, got types.Hash32) error { - require.NoError(t, activesets.Add(th.db, got, &types.EpochActiveSet{ - Epoch: lid.GetEpoch(), - Set: set, - })) - for _, id := range set { - th.atxsdata.AddAtx(lid.GetEpoch(), id, &atxsdata.ATX{Node: types.NodeID{1}}) - } - return nil - }, - ) + type asReq struct { + id types.Hash32 + err error + } + startChs := make(map[types.Hash32]chan struct{}) + asCh := make(chan asReq, 1) // buffered, no wait + for _, set := range sets { + th.mf.EXPECT().GetActiveSet(gomock.Any(), set.Hash()).DoAndReturn( + func(ctx context.Context, got types.Hash32) error { + startCh, found := startChs[got] + if found { + select { + case <-ctx.Done(): + return ctx.Err() + case startCh <- struct{}{}: + } + } + var req asReq + select { + case <-ctx.Done(): + return ctx.Err() + case req = <-asCh: + } + require.Equal(t, got, req.id) + if req.err != nil { + return req.err + } + require.NoError(t, activesets.Add(th.db, got, &types.EpochActiveSet{ + Epoch: lid.GetEpoch(), + Set: set, + })) + for _, id := range set { + th.atxsdata.AddAtx(lid.GetEpoch(), id, &atxsdata.ATX{Node: types.NodeID{1}}) + } + return nil + }, + ).AnyTimes() + } th.mf.EXPECT().GetAtxs(gomock.Any(), gomock.Any()).AnyTimes() th.mf.EXPECT().GetBallots(gomock.Any(), gomock.Any()).AnyTimes() th.mockSet.decodeAnyBallots() @@ -1442,13 +1482,115 @@ func TestHandleSyncedProposalActiveSet(t *testing.T) { th.mm.EXPECT().AddBallot(gomock.Any(), gomock.Any()).AnyTimes() th.mm.EXPECT().AddTXsFromProposal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - th.mconsumer.EXPECT().IsKnown(good.Layer, good.ID()) - th.mconsumer.EXPECT().OnProposal(gomock.Eq(&good)) - err = th.HandleSyncedProposal(context.Background(), good.ID().AsHash32(), pid, codec.MustEncode(&good)) - require.NoError(t, err) + t.Run("non-concurrent fetching of ActiveSets", func(t *testing.T) { + asCh <- asReq{id: sets[0].Hash()} + th.mconsumer.EXPECT().OnProposal(gomock.Eq(p[0])) + err = th.HandleSyncedProposal(context.Background(), p[0].ID().AsHash32(), pid, codec.MustEncode(p[0])) + require.NoError(t, err) + + th.mconsumer.EXPECT().OnProposal(gomock.Eq(p[1])) + err = th.HandleSyncedProposal(context.Background(), p[1].ID().AsHash32(), pid, codec.MustEncode(p[1])) + require.NoError(t, err) + }) + + t.Run("concurrent fetching of ActiveSets", func(t *testing.T) { + startCh := make(chan struct{}) + startChs[sets[1].Hash()] = startCh + var eg errgroup.Group + th.mconsumer.EXPECT().OnProposal(gomock.Eq(p[2])) + eg.Go(func() error { + // blocks till we send smth on asCh + return th.HandleSyncedProposal(context.Background(), p[2].ID().AsHash32(), pid, codec.MustEncode(p[2])) + }) + <-startCh + // at this point, the fetcher for the activeset is started, but blocked + th.mconsumer.EXPECT().OnProposal(gomock.Eq(p[3])) + eg.Go(func() error { + // need to fetch the same ActiveSet + return th.HandleSyncedProposal(context.Background(), p[3].ID().AsHash32(), pid, codec.MustEncode(p[3])) + }) + + asCh <- asReq{id: sets[1].Hash()} // unblock fetching of the ActiveSet + require.NoError(t, eg.Wait()) + }) + + t.Run("ActiveSet fetch failure and refetch", func(t *testing.T) { + startCh := make(chan struct{}) + startChs[sets[2].Hash()] = startCh + var eg errgroup.Group + eg.Go(func() error { + require.Error(t, th.HandleSyncedProposal(context.Background(), + p[4].ID().AsHash32(), pid, codec.MustEncode(p[4]))) + return nil + }) + <-startCh + // at this point, the fetcher for the activeset is started, but blocked + eg.Go(func() error { + require.Error(t, th.HandleSyncedProposal(context.Background(), + p[5].ID().AsHash32(), pid, codec.MustEncode(p[5]))) + return nil + }) + + // wait till the 2nd request has subscribed + require.Eventually(t, func() bool { + th.weightCalcLock.Lock() + defer th.weightCalcLock.Unlock() + return len(th.pendingWeightCalc[sets[2].Hash()]) != 0 + }, 10*time.Second, 10*time.Millisecond) + asCh <- asReq{id: sets[2].Hash(), err: errors.New("foobar")} + require.NoError(t, eg.Wait()) + + // refetch after failure + th.mconsumer.EXPECT().OnProposal(gomock.Eq(p[5])) + var eg1 errgroup.Group + eg1.Go(func() error { + return th.HandleSyncedProposal(context.Background(), + p[5].ID().AsHash32(), pid, codec.MustEncode(p[5])) + }) + <-startCh + asCh <- asReq{id: sets[2].Hash()} + require.NoError(t, eg1.Wait()) + }) + + t.Run("ActiveSet fetch cancel and refetch", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + startCh := make(chan struct{}) + startChs[sets[3].Hash()] = startCh + var eg errgroup.Group + eg.Go(func() error { + require.Error(t, th.HandleSyncedProposal(ctx, p[6].ID().AsHash32(), pid, codec.MustEncode(p[6]))) + return nil + }) + <-startCh + // at this point, the fetcher for the activeset is started, but blocked + eg.Go(func() error { + require.Error(t, th.HandleSyncedProposal(ctx, p[7].ID().AsHash32(), pid, codec.MustEncode(p[7]))) + return nil + }) + + // wait till the 2nd request has subscribed + require.Eventually(t, func() bool { + th.weightCalcLock.Lock() + defer th.weightCalcLock.Unlock() + return len(th.pendingWeightCalc[sets[3].Hash()]) != 0 + }, 10*time.Second, 10*time.Millisecond) + cancel() + require.NoError(t, eg.Wait()) + + // refetch after cancel + th.mconsumer.EXPECT().OnProposal(gomock.Eq(p[7])) + var eg1 errgroup.Group + eg1.Go(func() error { + return th.HandleSyncedProposal(context.Background(), + p[7].ID().AsHash32(), pid, codec.MustEncode(p[7])) + }) + <-startCh + asCh <- asReq{id: sets[3].Hash()} + require.NoError(t, eg1.Wait()) + }) } -func TestHandler_SettingBallotBeacon(t *testing.T) { +func TestHandler_SettingBallotBeacon(t *testing.T) { // t.Run("non-refferential", func(t *testing.T) { t.Parallel() th := createTestHandler(t)