Skip to content

Commit

Permalink
Beacon cleanups (#5148)
Browse files Browse the repository at this point in the history
Some clean-ups in beacon code:
- removed redundant:
  -  nodeID (can be derived from edSigner)
  -  and metricsRegistry (unused
- use `codec.MustEncode` where an error lead to shutdown anyway
- avoid costly eager marshaling for log purposes

:bulb: It's recommended to review commit-by-commit as changes are separated.
  • Loading branch information
poszu committed Oct 12, 2023
1 parent c699075 commit 6338651
Show file tree
Hide file tree
Showing 11 changed files with 101 additions and 158 deletions.
110 changes: 24 additions & 86 deletions beacon/beacon.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@ package beacon

import (
"context"
"encoding/hex"
"errors"
"fmt"
"math/big"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/ALTree/bigfloat"
"github.com/prometheus/client_golang/prometheus"
"github.com/spacemeshos/fixed"
"golang.org/x/sync/errgroup"

Expand Down Expand Up @@ -105,7 +102,6 @@ func withNonceFetcher(nf nonceFetcher) Opt {

// New returns a new ProtocolDriver.
func New(
nodeID types.NodeID,
publisher pubsub.Publisher,
edSigner *signing.EdSigner,
edVerifier *signing.EdVerifier,
Expand All @@ -118,7 +114,6 @@ func New(
ctx: context.Background(),
logger: log.NewNop(),
config: DefaultConfig(),
nodeID: nodeID,
publisher: publisher,
edSigner: edSigner,
edVerifier: edVerifier,
Expand Down Expand Up @@ -166,7 +161,6 @@ type ProtocolDriver struct {
startOnce sync.Once

config Config
nodeID types.NodeID
sync system.SyncStateProvider
publisher pubsub.Publisher
edSigner *signing.EdSigner
Expand Down Expand Up @@ -207,7 +201,6 @@ type ProtocolDriver struct {

// metrics
metricsCollector *metrics.BeaconMetricsCollector
metricsRegistry *prometheus.Registry
}

// SetSyncState updates sync state provider. Must be executed only once.
Expand All @@ -218,13 +211,6 @@ func (pd *ProtocolDriver) SetSyncState(sync system.SyncStateProvider) {
pd.sync = sync
}

// for testing.
func (pd *ProtocolDriver) setMetricsRegistry(registry *prometheus.Registry) {
pd.mu.Lock()
defer pd.mu.Unlock()
pd.metricsRegistry = registry
}

// Start starts listening for layers and outputs.
func (pd *ProtocolDriver) Start(ctx context.Context) {
pd.startOnce.Do(func() {
Expand All @@ -237,7 +223,7 @@ func (pd *ProtocolDriver) Start(ctx context.Context) {
pd.logger.Info("beacon protocol disabled")
return
}
pd.logger.With().Info("starting beacon protocol", log.String("config", fmt.Sprintf("%+v", pd.config)))
pd.logger.With().Info("starting beacon protocol", log.Any("config", pd.config))
pd.setProposalTimeForNextEpoch()
pd.eg.Go(func() error {
pd.listenEpochs(ctx)
Expand Down Expand Up @@ -586,7 +572,7 @@ func (pd *ProtocolDriver) initEpochStateIfNotPresent(logger log.Log, epoch types
log.Bool("malicious", malicious),
log.Stringer("smesher", header.NodeID))
}
if header.NodeID == pd.nodeID {
if header.NodeID == pd.edSigner.NodeID() {
active = true
}
return nil
Expand All @@ -600,7 +586,7 @@ func (pd *ProtocolDriver) initEpochStateIfNotPresent(logger log.Log, epoch types
}

if active {
nnc, err := pd.nonceFetcher.VRFNonce(pd.nodeID, epoch)
nnc, err := pd.nonceFetcher.VRFNonce(pd.edSigner.NodeID(), epoch)
if err != nil {
logger.With().Error("failed to get own VRF nonce", log.Err(err))
return nil, fmt.Errorf("get own VRF nonce: %w", err)
Expand Down Expand Up @@ -740,7 +726,7 @@ func (pd *ProtocolDriver) onNewEpoch(ctx context.Context, epoch types.EpochID) e
func (pd *ProtocolDriver) runProtocol(ctx context.Context, epoch types.EpochID, st *state) {
ctx = log.WithNewSessionID(ctx)
targetEpoch := epoch + 1
logger := pd.logger.WithContext(ctx).WithFields(epoch, log.Uint32("target_epoch", uint32(targetEpoch)))
logger := pd.logger.WithContext(ctx).WithFields(epoch, log.FieldNamed("target_epoch", targetEpoch))

pd.setBeginProtocol(ctx)
defer pd.setEndProtocol(ctx)
Expand Down Expand Up @@ -776,17 +762,14 @@ func (pd *ProtocolDriver) runProtocol(ctx context.Context, epoch types.EpochID,

func calcBeacon(logger log.Log, set proposalSet) types.Beacon {
allProposals := set.sort()
allHexes := make([]string, len(allProposals))
for i, h := range allProposals {
allHexes[i] = hex.EncodeToString(h[:])
}

// Beacon should appear to have the same entropy as the initial proposals, hence cropping it
// to the same size as the proposal
beacon := types.BytesToBeacon(allProposals.hash().Bytes())
logger.With().Info("calculated beacon",
beacon,
log.Int("num_hashes", len(allProposals)),
log.String("proposals", strings.Join(allHexes, ", ")),
log.Array("proposals", allProposals),
)
return beacon
}
Expand All @@ -795,8 +778,7 @@ func (pd *ProtocolDriver) runProposalPhase(ctx context.Context, epoch types.Epoc
logger := pd.logger.WithContext(ctx).WithFields(epoch)
logger.Info("starting beacon proposal phase")

var cancel func()
ctx, cancel = context.WithTimeout(ctx, pd.config.ProposalDuration)
ctx, cancel := context.WithTimeout(ctx, pd.config.ProposalDuration)
defer cancel()

if st.nonce != nil {
Expand Down Expand Up @@ -835,28 +817,18 @@ func (pd *ProtocolDriver) sendProposal(ctx context.Context, epoch types.EpochID,
proposal := ProposalFromVrf(vrfSig)
m := ProposalMessage{
EpochID: epoch,
NodeID: pd.nodeID,
NodeID: pd.edSigner.NodeID(),
VRFSignature: vrfSig,
}

if invalid == pd.classifyProposal(logger, m, atx.Received, time.Now(), checker) {
logger.With().Debug("own proposal doesn't pass threshold",
log.String("proposal", hex.EncodeToString(proposal[:])),
)
logger.With().Debug("own proposal doesn't pass threshold", log.Inline(proposal))
return
}

logger.With().Debug("own proposal passes threshold",
log.String("proposal", hex.EncodeToString(proposal[:])),
)

serialized, err := codec.Encode(&m)
if err != nil {
logger.With().Fatal("failed to encode beacon proposal", log.Err(err))
}

pd.sendToGossip(ctx, pubsub.BeaconProposalProtocol, serialized)
logger.With().Info("beacon proposal sent", log.String("proposal", hex.EncodeToString(proposal[:])))
logger.With().Debug("own proposal passes threshold", log.Inline(proposal))
pd.sendToGossip(ctx, pubsub.BeaconProposalProtocol, codec.MustEncode(&m))
logger.With().Info("beacon proposal sent", log.Inline(proposal))
}

// runConsensusPhase runs K voting rounds and returns result from last weak coin round.
Expand Down Expand Up @@ -975,28 +947,17 @@ func (pd *ProtocolDriver) genFirstRoundMsgBody(epoch types.EpochID) (FirstVoting
func (pd *ProtocolDriver) sendFirstRoundVote(ctx context.Context, epoch types.EpochID) error {
mb, err := pd.genFirstRoundMsgBody(epoch)
if err != nil {
return err
}

encoded, err := codec.Encode(&mb)
if err != nil {
pd.logger.With().Fatal("failed to serialize message for signing", log.Err(err))
return fmt.Errorf("getting first round message: %w", err)
}
sig := pd.edSigner.Sign(signing.BEACON_FIRST_MSG, encoded)

m := FirstVotingMessage{
FirstVotingMessageBody: mb,
SmesherID: pd.edSigner.NodeID(),
Signature: sig,
Signature: pd.edSigner.Sign(signing.BEACON_FIRST_MSG, codec.MustEncode(&mb)),
}

pd.logger.WithContext(ctx).With().Debug("sending first round vote", epoch, types.FirstRound)
serialized, err := codec.Encode(&m)
if err != nil {
pd.logger.With().Fatal("failed to serialize message for gossip", log.Err(err))
}

pd.sendToGossip(ctx, pubsub.BeaconFirstVotesProtocol, serialized)
pd.sendToGossip(ctx, pubsub.BeaconFirstVotesProtocol, codec.MustEncode(&m))
return nil
}

Expand All @@ -1015,7 +976,7 @@ func (pd *ProtocolDriver) getFirstRoundVote(epoch types.EpochID, nodeID types.No
func (pd *ProtocolDriver) sendFollowingVote(ctx context.Context, epoch types.EpochID, round types.RoundID, ownCurrentRoundVotes allVotes) error {
firstRoundVotes, err := pd.getFirstRoundVote(epoch, pd.edSigner.NodeID())
if err != nil {
return fmt.Errorf("get own first round votes %v: %w", pd.edSigner.NodeID().String(), err)
return fmt.Errorf("get own first round votes %s: %w", pd.edSigner.NodeID(), err)
}

bitVector := encodeVotes(ownCurrentRoundVotes, firstRoundVotes)
Expand All @@ -1025,26 +986,14 @@ func (pd *ProtocolDriver) sendFollowingVote(ctx context.Context, epoch types.Epo
VotesBitVector: bitVector,
}

encoded, err := codec.Encode(&mb)
if err != nil {
pd.logger.With().Fatal("failed to serialize message for signing", log.Err(err))
}
sig := pd.edSigner.Sign(signing.BEACON_FOLLOWUP_MSG, encoded)

m := FollowingVotingMessage{
FollowingVotingMessageBody: mb,
SmesherID: pd.edSigner.NodeID(),
Signature: sig,
Signature: pd.edSigner.Sign(signing.BEACON_FOLLOWUP_MSG, codec.MustEncode(&mb)),
}

pd.logger.WithContext(ctx).With().Debug("sending following round vote", epoch, round)

serialized, err := codec.Encode(&m)
if err != nil {
pd.logger.With().Fatal("failed to serialize message for gossip", log.Err(err))
}

pd.sendToGossip(ctx, pubsub.BeaconFollowingVotesProtocol, serialized)
pd.sendToGossip(ctx, pubsub.BeaconFollowingVotesProtocol, codec.MustEncode(&m))
return nil
}

Expand All @@ -1064,8 +1013,9 @@ func createProposalChecker(logger log.Log, conf Config, numEarlyATXs, numATXs in
logger.With().Info("created proposal checker with ATX threshold",
log.Int("num_early_atxs", numEarlyATXs),
log.Int("num_atxs", numATXs),
log.String("threshold", high.String()),
log.String("threshold_strict", low.String()))
log.Stringer("threshold", high),
log.Stringer("threshold_strict", low),
)
return &proposalChecker{threshold: high, thresholdStrict: low}
}

Expand Down Expand Up @@ -1136,11 +1086,7 @@ func buildSignedProposal(ctx context.Context, logger log.Log, signer vrfSigner,
p := buildProposal(logger, epoch, nonce)
vrfSig := signer.Sign(p)
proposal := ProposalFromVrf(vrfSig)
logger.WithContext(ctx).With().Debug("calculated beacon proposal",
epoch,
nonce,
log.String("proposal", hex.EncodeToString(proposal[:])),
)
logger.WithContext(ctx).With().Debug("calculated beacon proposal", epoch, nonce, log.Inline(proposal))
return vrfSig
}

Expand All @@ -1150,22 +1096,14 @@ func buildProposal(logger log.Log, epoch types.EpochID, nonce types.VRFPostIndex
Nonce: nonce,
Epoch: epoch,
}

b, err := codec.Encode(message)
if err != nil {
logger.With().Fatal("failed to serialize proposal", log.Err(err))
}
return b
return codec.MustEncode(message)
}

func (pd *ProtocolDriver) sendToGossip(ctx context.Context, protocol string, serialized []byte) {
// NOTE(dshulyak) moved to goroutine because self-broadcast is applied synchronously
pd.eg.Go(func() error {
if err := pd.publisher.Publish(ctx, protocol, serialized); err != nil {
pd.logger.With().Error("failed to broadcast",
log.String("protocol", protocol),
log.Err(err),
)
pd.logger.With().Error("failed to broadcast", log.String("protocol", protocol), log.Err(err))
}
return nil
})
Expand Down
17 changes: 8 additions & 9 deletions beacon/beacon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,13 @@ func newTestDriver(tb testing.TB, cfg Config, p pubsub.Publisher) *testProtocolD
tpd.mNonceFetcher.EXPECT().VRFNonce(gomock.Any(), gomock.Any()).AnyTimes().Return(types.VRFPostIndex(1), nil)

tpd.cdb = datastore.NewCachedDB(sql.InMemory(), lg)
tpd.ProtocolDriver = New(minerID, p, edSgn, edVerify, tpd.mVerifier, tpd.cdb, tpd.mClock,
tpd.ProtocolDriver = New(p, edSgn, edVerify, tpd.mVerifier, tpd.cdb, tpd.mClock,
WithConfig(cfg),
WithLogger(lg),
withWeakCoin(coinValueMock(tb, true)),
withNonceFetcher(tpd.mNonceFetcher),
)
tpd.ProtocolDriver.SetSyncState(tpd.mSync)
tpd.ProtocolDriver.setMetricsRegistry(prometheus.NewPedanticRegistry())
return tpd
}

Expand Down Expand Up @@ -159,11 +158,11 @@ func TestBeacon_MultipleNodes(t *testing.T) {
for _, node := range testNodes {
switch protocol {
case pubsub.BeaconProposalProtocol:
require.NoError(t, node.HandleProposal(ctx, p2p.Peer(node.nodeID.ShortString()), data))
require.NoError(t, node.HandleProposal(ctx, p2p.Peer(node.edSigner.NodeID().ShortString()), data))
case pubsub.BeaconFirstVotesProtocol:
require.NoError(t, node.HandleFirstVotes(ctx, p2p.Peer(node.nodeID.ShortString()), data))
require.NoError(t, node.HandleFirstVotes(ctx, p2p.Peer(node.edSigner.NodeID().ShortString()), data))
case pubsub.BeaconFollowingVotesProtocol:
require.NoError(t, node.HandleFollowingVotes(ctx, p2p.Peer(node.nodeID.ShortString()), data))
require.NoError(t, node.HandleFollowingVotes(ctx, p2p.Peer(node.edSigner.NodeID().ShortString()), data))
case pubsub.BeaconWeakCoinProtocol:
}
}
Expand Down Expand Up @@ -228,11 +227,11 @@ func TestBeacon_MultipleNodes_OnlyOneHonest(t *testing.T) {
for _, node := range testNodes {
switch protocol {
case pubsub.BeaconProposalProtocol:
require.NoError(t, node.HandleProposal(ctx, p2p.Peer(node.nodeID.ShortString()), data))
require.NoError(t, node.HandleProposal(ctx, p2p.Peer(node.edSigner.NodeID().ShortString()), data))
case pubsub.BeaconFirstVotesProtocol:
require.NoError(t, node.HandleFirstVotes(ctx, p2p.Peer(node.nodeID.ShortString()), data))
require.NoError(t, node.HandleFirstVotes(ctx, p2p.Peer(node.edSigner.NodeID().ShortString()), data))
case pubsub.BeaconFollowingVotesProtocol:
require.NoError(t, node.HandleFollowingVotes(ctx, p2p.Peer(node.nodeID.ShortString()), data))
require.NoError(t, node.HandleFollowingVotes(ctx, p2p.Peer(node.edSigner.NodeID().ShortString()), data))
case pubsub.BeaconWeakCoinProtocol:
}
}
Expand Down Expand Up @@ -264,7 +263,7 @@ func TestBeacon_MultipleNodes_OnlyOneHonest(t *testing.T) {
for _, db := range dbs {
createATX(t, db, atxPublishLid, node.edSigner, 1, time.Now().Add(-1*time.Second))
if i != 0 {
require.NoError(t, identities.SetMalicious(db, node.nodeID, []byte("bad"), time.Now()))
require.NoError(t, identities.SetMalicious(db, node.edSigner.NodeID(), []byte("bad"), time.Now()))
}
}
}
Expand Down

0 comments on commit 6338651

Please sign in to comment.