diff --git a/Makefile b/Makefile index 4198474b829..d94176b7820 100644 --- a/Makefile +++ b/Makefile @@ -102,7 +102,7 @@ clear-test-cache: .PHONY: clear-test-cache test: get-libs - @$(ULIMIT) CGO_LDFLAGS="$(CGO_TEST_LDFLAGS)" gotestsum -- -race -timeout 5m -p 1 $(UNIT_TESTS) + @$(ULIMIT) CGO_LDFLAGS="$(CGO_TEST_LDFLAGS)" gotestsum -- -race -timeout 8m -p 1 $(UNIT_TESTS) .PHONY: test generate: get-libs diff --git a/activation/activation.go b/activation/activation.go index 8d9838c2176..ca9437da16b 100644 --- a/activation/activation.go +++ b/activation/activation.go @@ -776,14 +776,14 @@ func (b *Builder) createAtx( NiPosts: []wire.NiPostsV2{ { Membership: wire.MerkleProofV2{ - Nodes: nipostState.Membership.Nodes, - LeafIndices: []uint64{nipostState.Membership.LeafIndex}, + Nodes: nipostState.Membership.Nodes, }, Challenge: types.Hash32(nipostState.NIPost.PostMetadata.Challenge), Posts: []wire.SubPostV2{ { - Post: *wire.PostToWireV1(nipostState.Post), - NumUnits: nipostState.NumUnits, + Post: *wire.PostToWireV1(nipostState.Post), + NumUnits: nipostState.NumUnits, + MembershipLeafIndex: nipostState.Membership.LeafIndex, }, }, }, diff --git a/activation/e2e/atx_merge_test.go b/activation/e2e/atx_merge_test.go new file mode 100644 index 00000000000..a0cf900b48b --- /dev/null +++ b/activation/e2e/atx_merge_test.go @@ -0,0 +1,554 @@ +package activation_test + +import ( + "context" + "encoding/hex" + "fmt" + "net/url" + "slices" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/spacemeshos/merkle-tree" + "github.com/spacemeshos/poet/registration" + "github.com/spacemeshos/poet/shared" + "github.com/spacemeshos/post/verifying" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "go.uber.org/zap" + "go.uber.org/zap/zaptest" + "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/activation" + "github.com/spacemeshos/go-spacemesh/activation/wire" + "github.com/spacemeshos/go-spacemesh/api/grpcserver" + "github.com/spacemeshos/go-spacemesh/atxsdata" + "github.com/spacemeshos/go-spacemesh/codec" + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/datastore" + "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/system" + smocks "github.com/spacemeshos/go-spacemesh/system/mocks" + "github.com/spacemeshos/go-spacemesh/timesync" +) + +func constructMerkleProof(t *testing.T, members []types.Hash32, ids map[uint64]bool) wire.MerkleProofV2 { + t.Helper() + + tree, err := merkle.NewTreeBuilder(). + WithLeavesToProve(ids). + WithHashFunc(shared.HashMembershipTreeNode). + Build() + require.NoError(t, err) + for _, member := range members { + require.NoError(t, tree.AddLeaf(member[:])) + } + nodes := tree.Proof() + nodesH32 := make([]types.Hash32, 0, len(nodes)) + for _, n := range nodes { + nodesH32 = append(nodesH32, types.BytesToHash(n)) + } + return wire.MerkleProofV2{Nodes: nodesH32} +} + +type nipostData struct { + previous types.ATXID + *nipost.NIPostState +} + +func buildNipost( + nb *activation.NIPostBuilder, + sig *signing.EdSigner, + publish types.EpochID, + previous, positioning types.ATXID, +) (nipostData, error) { + challenge := wire.NIPostChallengeV2{ + PublishEpoch: publish, + PrevATXID: previous, + PositioningATXID: positioning, + } + nipost, err := nb.BuildNIPost(context.Background(), sig, challenge.PublishEpoch, challenge.Hash()) + nb.ResetState(sig.NodeID()) + return nipostData{previous, nipost}, err +} + +func createInitialAtx( + publish types.EpochID, + commitment, pos types.ATXID, + nipost *nipost.NIPostState, + initial *types.Post, +) *wire.ActivationTxV2 { + return &wire.ActivationTxV2{ + PublishEpoch: publish, + PositioningATX: pos, + Initial: &wire.InitialAtxPartsV2{ + CommitmentATX: commitment, + Post: *wire.PostToWireV1(initial), + }, + VRFNonce: uint64(nipost.VRFNonce), + NiPosts: []wire.NiPostsV2{ + { + Membership: wire.MerkleProofV2{ + Nodes: nipost.Membership.Nodes, + }, + Challenge: types.Hash32(nipost.PostMetadata.Challenge), + Posts: []wire.SubPostV2{ + { + Post: *wire.PostToWireV1(nipost.Post), + NumUnits: nipost.NumUnits, + MembershipLeafIndex: nipost.Membership.LeafIndex, + }, + }, + }, + }, + } +} + +func createSoloAtx(publish types.EpochID, prev, pos types.ATXID, nipost *nipost.NIPostState) *wire.ActivationTxV2 { + return &wire.ActivationTxV2{ + PublishEpoch: publish, + PreviousATXs: []types.ATXID{prev}, + PositioningATX: pos, + VRFNonce: uint64(nipost.VRFNonce), + NiPosts: []wire.NiPostsV2{ + { + Membership: wire.MerkleProofV2{ + Nodes: nipost.Membership.Nodes, + }, + Challenge: types.Hash32(nipost.PostMetadata.Challenge), + Posts: []wire.SubPostV2{ + { + Post: *wire.PostToWireV1(nipost.Post), + NumUnits: nipost.NumUnits, + MembershipLeafIndex: nipost.Membership.LeafIndex, + }, + }, + }, + }, + } +} + +func createMerged( + niposts []nipostData, + publish types.EpochID, + marriage, positioning types.ATXID, + previous []types.ATXID, + membership wire.MerkleProofV2, +) *wire.ActivationTxV2 { + atx := &wire.ActivationTxV2{ + PublishEpoch: publish, + PreviousATXs: previous, + MarriageATX: &marriage, + PositioningATX: positioning, + NiPosts: []wire.NiPostsV2{ + { + Membership: membership, + Challenge: types.Hash32(niposts[0].PostMetadata.Challenge), + }, + }, + } + // Append PoSTs for all IDs + for i, nipost := range niposts { + idx := slices.IndexFunc(previous, func(a types.ATXID) bool { return a == nipost.previous }) + if idx == -1 { + panic(fmt.Sprintf("previous ATX %s not found in %s", nipost.previous, previous)) + } + atx.NiPosts[0].Posts = append(atx.NiPosts[0].Posts, wire.SubPostV2{ + MarriageIndex: uint32(i), + PrevATXIndex: uint32(idx), + MembershipLeafIndex: nipost.Membership.LeafIndex, + Post: *wire.PostToWireV1(nipost.Post), + NumUnits: nipost.NumUnits, + }) + } + return atx +} + +func signers(t *testing.T, keysHex []string) []*signing.EdSigner { + t.Helper() + + var keys [][]byte + for _, k := range keysHex { + key, err := hex.DecodeString(k) + require.NoError(t, err) + keys = append(keys, key) + } + + signers := []*signing.EdSigner{} + for _, key := range keys { + sig, err := signing.NewEdSigner(signing.WithPrivateKey(key)) + require.NoError(t, err) + signers = append(signers, sig) + } + return signers +} + +var units = [2]uint32{2, 3} + +// Keys were preselected to give IDs whose VRF nonces satisfy the combined storage requirement for the above `units`. +// +//nolint:lll +var singerKeys = [2]string{ + "1f2b77052ecc193038156d5c32f08d449742e7dda81fa172f8ac90839d34c76935a5d9365d1317c3002838126409e138321c57a5651d758485336c1e7e5af101", + "6f385445a53d8af57874acd2dd98023858df7aa62f0b6e91ffdd51198036e2c331d2a7c55ba1e29312ac71dd419b4edc019b6406960cfc8ffb3d7550dde2ca1b", +} + +func Test_MarryAndMerge(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + signers := signers(t, singerKeys[:]) + + var totalNumUnits uint32 + var nonces [2]uint64 + + logger := zaptest.NewLogger(t) + goldenATX := types.ATXID{2, 3, 4} + coinbase := types.Address{1, 2, 3, 4, 5, 6, 7} + cfg := activation.DefaultPostConfig() + db := sql.InMemory() + cdb := datastore.NewCachedDB(db, logger) + localDB := localsql.InMemory() + + syncer := activation.NewMocksyncer(ctrl) + syncer.EXPECT().RegisterForATXSynced().DoAndReturn(func() <-chan struct{} { + synced := make(chan struct{}) + close(synced) + return synced + }).AnyTimes() + + svc := grpcserver.NewPostService(logger) + svc.AllowConnections(true) + grpcCfg, cleanup := launchServer(t, svc) + t.Cleanup(cleanup) + + opts := testPostSetupOpts(t) + verifyingOpts := activation.DefaultTestPostVerifyingOpts() + verifier, err := activation.NewPostVerifier(cfg, logger, activation.WithVerifyingOpts(verifyingOpts)) + require.NoError(t, err) + t.Cleanup(func() { assert.NoError(t, verifier.Close()) }) + poetDb := activation.NewPoetDb(db, logger.Named("poetDb")) + validator := activation.NewValidator(db, poetDb, cfg, opts.Scrypt, verifier) + + var eg errgroup.Group + for i, sig := range signers { + opts := opts + opts.DataDir = t.TempDir() + opts.NumUnits = units[i] + totalNumUnits += units[i] + + eg.Go(func() error { + mgr, err := activation.NewPostSetupManager(cfg, logger, db, atxsdata.New(), goldenATX, syncer, validator) + require.NoError(t, err) + + t.Cleanup(launchPostSupervisor(t, zap.NewNop(), mgr, sig, grpcCfg, opts)) + + require.Eventually(t, func() bool { + _, err := svc.Client(sig.NodeID()) + return err == nil + }, 10*time.Second, 100*time.Millisecond, "timed out waiting for connection") + return nil + }) + } + require.NoError(t, eg.Wait()) + + // ensure that genesis aligns with layer timings + genesis := time.Now().Add(layerDuration).Round(layerDuration) + layerDuration := 2 * time.Second + epoch := layersPerEpoch * layerDuration + poetCfg := activation.PoetConfig{ + PhaseShift: epoch, + CycleGap: epoch / 2, + GracePeriod: epoch / 5, + RequestTimeout: epoch / 5, + RequestRetryDelay: epoch / 50, + MaxRequestRetries: 10, + } + + pubkey, address := spawnTestCertifier(t, cfg, nil, verifying.WithLabelScryptParams(opts.Scrypt)) + certClient := activation.NewCertifierClient(db, localDB, logger.Named("certifier")) + certifier := activation.NewCertifier(localDB, logger, certClient) + poet := spawnPoet( + t, + WithGenesis(genesis), + WithEpochDuration(epoch), + WithPhaseShift(poetCfg.PhaseShift), + WithCycleGap(poetCfg.CycleGap), + WithCertifier(®istration.CertifierConfig{ + URL: (&url.URL{Scheme: "http", Host: address.String()}).String(), + PubKey: registration.Base64Enc(pubkey), + }), + ) + poetClient, err := poet.Client(poetDb, poetCfg, logger, activation.WithCertifier(certifier)) + require.NoError(t, err) + + clock, err := timesync.NewClock( + timesync.WithGenesisTime(genesis), + timesync.WithLayerDuration(layerDuration), + timesync.WithTickInterval(100*time.Millisecond), + timesync.WithLogger(zap.NewNop()), + ) + require.NoError(t, err) + t.Cleanup(clock.Close) + + nb, err := activation.NewNIPostBuilder( + localDB, + svc, + logger.Named("nipostBuilder"), + poetCfg, + clock, + activation.WithPoetClients(poetClient), + ) + require.NoError(t, err) + + mpub := mocks.NewMockPublisher(ctrl) + mFetch := smocks.NewMockFetcher(ctrl) + mBeacon := activation.NewMockAtxReceiver(ctrl) + mTortoise := smocks.NewMockTortoise(ctrl) + + tickSize := uint64(3) + atxHdlr := activation.NewHandler( + "local", + cdb, + atxsdata.New(), + signing.NewEdVerifier(), + clock, + mpub, + mFetch, + goldenATX, + validator, + mBeacon, + mTortoise, + logger, + activation.WithAtxVersions(activation.AtxVersions{0: types.AtxV2}), + activation.WithTickSize(tickSize), + ) + + // Step 1. Marry + publish := types.EpochID(1) + var niposts [2]nipostData + var initialPosts [2]*types.Post + eg = errgroup.Group{} + for i, signer := range signers { + eg.Go(func() error { + post, postInfo, err := nb.Proof(context.Background(), signer.NodeID(), types.EmptyHash32[:]) + if err != nil { + return err + } + + challenge := wire.NIPostChallengeV2{ + PublishEpoch: publish, + PositioningATXID: goldenATX, + InitialPost: wire.PostToWireV1(post), + } + nipost, err := nb.BuildNIPost(context.Background(), signer, challenge.PublishEpoch, challenge.Hash()) + if err != nil { + return err + } + nb.ResetState(signer.NodeID()) + + initialPosts[i] = post + nonces[i] = uint64(*postInfo.Nonce) + niposts[i] = nipostData{types.EmptyATXID, nipost} + return nil + }) + } + require.NoError(t, eg.Wait()) + + // mainID will create marriage ATX + mainID, mergedID := signers[0], signers[1] + + mergedIdAtx := createInitialAtx(publish, goldenATX, goldenATX, niposts[1].NIPostState, initialPosts[1]) + mergedIdAtx.Sign(mergedID) + + marriageATX := createInitialAtx(publish, goldenATX, goldenATX, niposts[0].NIPostState, initialPosts[0]) + marriageATX.Marriages = []wire.MarriageCertificate{ + { + Signature: mainID.Sign(signing.MARRIAGE, mainID.NodeID().Bytes()), + }, + { + ReferenceAtx: mergedIdAtx.ID(), + Signature: mergedID.Sign(signing.MARRIAGE, mainID.NodeID().Bytes()), + }, + } + marriageATX.Sign(mainID) + logger.Info("publishing marriage ATX", zap.Inline(marriageATX)) + + mFetch.EXPECT().RegisterPeerHashes(peer.ID(""), gomock.Any()) + mFetch.EXPECT().GetPoetProof(gomock.Any(), gomock.Any()) + mFetch.EXPECT().GetAtxs(gomock.Any(), []types.ATXID{mergedIdAtx.ID()}, gomock.Any()). + DoAndReturn(func(_ context.Context, _ []types.ATXID, _ ...system.GetAtxOpt) error { + // Provide the referenced ATX for the married ID + mFetch.EXPECT().RegisterPeerHashes(peer.ID(""), gomock.Any()) + mFetch.EXPECT().GetPoetProof(gomock.Any(), gomock.Any()) + mBeacon.EXPECT().OnAtx(gomock.Any()) + mTortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + return atxHdlr.HandleGossipAtx(context.Background(), "", codec.MustEncode(mergedIdAtx)) + }) + mBeacon.EXPECT().OnAtx(gomock.Any()) + mTortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + err = atxHdlr.HandleGossipAtx(context.Background(), "", codec.MustEncode(marriageATX)) + require.NoError(t, err) + + // Verify marriage + for i, signer := range signers { + marriage, idx, err := identities.MarriageInfo(db, signer.NodeID()) + require.NoError(t, err) + require.NotNil(t, marriage) + require.Equal(t, marriageATX.ID(), *marriage) + require.Equal(t, i, idx) + } + + // Step 2. Publish merged ATX together + publish = marriageATX.PublishEpoch + 2 + eg = errgroup.Group{} + + // 2.1. NiPOST for main ID (the publisher) + eg.Go(func() error { + n, err := buildNipost(nb, mainID, publish, marriageATX.ID(), marriageATX.ID()) + logger.Info("built NiPoST", zap.Any("post", niposts[0])) + niposts[0] = n + return err + }) + + // 2.2. NiPOST for merged ID + prevATXID, err := atxs.GetLastIDByNodeID(db, mergedID.NodeID()) + require.NoError(t, err) + eg.Go(func() error { + n, err := buildNipost(nb, mergedID, publish, prevATXID, marriageATX.ID()) + logger.Info("built NiPoST", zap.Any("post", n)) + niposts[1] = n + return err + }) + require.NoError(t, eg.Wait()) + + // 2.3 Construct a multi-ID poet membership merkle proof for both IDs + poetProof, members, err := poetClient.Proof(context.Background(), "1") + require.NoError(t, err) + membershipProof := constructMerkleProof(t, members, map[uint64]bool{0: true, 1: true}) + + mergedATX := createMerged( + niposts[:], + publish, + marriageATX.ID(), + marriageATX.ID(), + []types.ATXID{marriageATX.ID(), prevATXID}, + membershipProof, + ) + mergedATX.Coinbase = coinbase + mergedATX.VRFNonce = nonces[0] + mergedATX.Sign(mainID) + + // 2.4 Publish + logger.Info("publishing merged ATX", zap.Inline(mergedATX)) + + mFetch.EXPECT().RegisterPeerHashes(peer.ID(""), gomock.Any()) + mFetch.EXPECT().GetPoetProof(gomock.Any(), gomock.Any()) + mFetch.EXPECT().GetAtxs(gomock.Any(), gomock.Any(), gomock.Any()) + mBeacon.EXPECT().OnAtx(gomock.Any()) + mTortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + err = atxHdlr.HandleGossipAtx(context.Background(), "", codec.MustEncode(mergedATX)) + require.NoError(t, err) + + // Step 3. verify the merged ATX + atx, err := atxs.Get(db, mergedATX.ID()) + require.NoError(t, err) + require.Equal(t, totalNumUnits, atx.NumUnits) + require.Equal(t, mainID.NodeID(), atx.SmesherID) + require.Equal(t, poetProof.LeafCount/tickSize, atx.TickCount) + require.Equal(t, uint64(totalNumUnits)*atx.TickCount, atx.Weight) + + posATX, err := atxs.Get(db, marriageATX.ID()) + require.NoError(t, err) + require.Equal(t, posATX.TickHeight(), atx.BaseTickHeight) + + // Step 4. Publish merged using the same previous now + // Publish by the other signer this time. + publish = mergedATX.PublishEpoch + 1 + eg = errgroup.Group{} + for i, sig := range signers { + eg.Go(func() error { + n, err := buildNipost(nb, sig, publish, mergedATX.ID(), mergedATX.ID()) + logger.Info("built NiPoST", zap.Any("post", n)) + niposts[i] = n + return err + }) + } + require.NoError(t, eg.Wait()) + poetProof, members, err = poetClient.Proof(context.Background(), "2") + require.NoError(t, err) + membershipProof = constructMerkleProof(t, members, map[uint64]bool{0: true}) + + mergedATX2 := createMerged( + niposts[:], + publish, + marriageATX.ID(), + mergedATX.ID(), + []types.ATXID{mergedATX.ID()}, + membershipProof, + ) + mergedATX2.Coinbase = coinbase + mergedATX2.VRFNonce = nonces[1] + mergedATX2.Sign(signers[1]) + + logger.Info("publishing second merged ATX", zap.Inline(mergedATX2)) + mFetch.EXPECT().RegisterPeerHashes(peer.ID(""), gomock.Any()) + mFetch.EXPECT().GetPoetProof(gomock.Any(), gomock.Any()) + mFetch.EXPECT().GetAtxs(gomock.Any(), gomock.Any(), gomock.Any()) + mBeacon.EXPECT().OnAtx(gomock.Any()) + mTortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + err = atxHdlr.HandleGossipAtx(context.Background(), "", codec.MustEncode(mergedATX2)) + require.NoError(t, err) + + atx, err = atxs.Get(db, mergedATX2.ID()) + require.NoError(t, err) + require.Equal(t, totalNumUnits, atx.NumUnits) + require.Equal(t, signers[1].NodeID(), atx.SmesherID) + require.Equal(t, poetProof.LeafCount/tickSize, atx.TickCount) + require.Equal(t, uint64(totalNumUnits)*atx.TickCount, atx.Weight) + + posATX, err = atxs.Get(db, mergedATX.ID()) + require.NoError(t, err) + require.Equal(t, posATX.TickHeight(), atx.BaseTickHeight) + + // Step 5. Make an emergency split and publish separately + publish = mergedATX2.PublishEpoch + 1 + eg = errgroup.Group{} + for i, sig := range signers { + eg.Go(func() error { + n, err := buildNipost(nb, sig, publish, mergedATX2.ID(), mergedATX2.ID()) + logger.Info("built NiPoST", zap.Any("post", n)) + niposts[i] = n + return err + }) + } + require.NoError(t, eg.Wait()) + + for i, signer := range signers { + atx := createSoloAtx(publish, mergedATX2.ID(), mergedATX2.ID(), niposts[i].NIPostState) + atx.Sign(signer) + logger.Info("publishing split ATX", zap.Inline(atx)) + + mFetch.EXPECT().RegisterPeerHashes(peer.ID(""), gomock.Any()) + mFetch.EXPECT().GetPoetProof(gomock.Any(), gomock.Any()) + mFetch.EXPECT().GetAtxs(gomock.Any(), gomock.Any(), gomock.Any()) + mBeacon.EXPECT().OnAtx(gomock.Any()) + mTortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) + err = atxHdlr.HandleGossipAtx(context.Background(), "", codec.MustEncode(atx)) + require.NoError(t, err) + + atxFromDb, err := atxs.Get(db, atx.ID()) + require.NoError(t, err) + require.Equal(t, units[i], atxFromDb.NumUnits) + require.Equal(t, signer.NodeID(), atxFromDb.SmesherID) + require.Equal(t, publish, atxFromDb.PublishEpoch) + require.Equal(t, mergedATX2.ID(), atxFromDb.PrevATXID) + } +} diff --git a/activation/e2e/builds_atx_v2_test.go b/activation/e2e/builds_atx_v2_test.go index 004656d301a..fa104d7ea01 100644 --- a/activation/e2e/builds_atx_v2_test.go +++ b/activation/e2e/builds_atx_v2_test.go @@ -215,7 +215,7 @@ func TestBuilder_SwitchesToBuildV2(t *testing.T) { require.NotZero(t, atx.BaseTickHeight) require.NotZero(t, atx.TickCount) - require.NotZero(t, atx.GetWeight()) + require.NotZero(t, atx.Weight) require.NotZero(t, atx.TickHeight()) require.Equal(t, opts.NumUnits, atx.NumUnits) previous = atx diff --git a/activation/e2e/certifier_client_test.go b/activation/e2e/certifier_client_test.go index 26da2c7df4a..ea313e0cd75 100644 --- a/activation/e2e/certifier_client_test.go +++ b/activation/e2e/certifier_client_test.go @@ -197,6 +197,7 @@ func spawnTestCertifier( postVerifier, err := activation.NewPostVerifier( cfg, zaptest.NewLogger(t), + activation.WithVerifyingOpts(activation.DefaultTestPostVerifyingOpts()), ) require.NoError(t, err) var eg errgroup.Group diff --git a/activation/handler_test.go b/activation/handler_test.go index d5b03b9be09..30c0681dc23 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -642,7 +642,7 @@ func TestHandler_AtxWeight(t *testing.T) { require.Equal(t, uint64(0), stored1.BaseTickHeight) require.Equal(t, leaves/tickSize, stored1.TickCount) require.Equal(t, leaves/tickSize, stored1.TickHeight()) - require.Equal(t, (leaves/tickSize)*units, stored1.GetWeight()) + require.Equal(t, (leaves/tickSize)*units, stored1.Weight) atx2 := newChainedActivationTxV1(t, atx1, atx1.ID()) atx2.Sign(sig) @@ -657,7 +657,7 @@ func TestHandler_AtxWeight(t *testing.T) { require.Equal(t, stored1.TickHeight(), stored2.BaseTickHeight) require.Equal(t, leaves/tickSize, stored2.TickCount) require.Equal(t, stored1.TickHeight()+leaves/tickSize, stored2.TickHeight()) - require.Equal(t, int(leaves/tickSize)*units, int(stored2.GetWeight())) + require.Equal(t, int(leaves/tickSize)*units, int(stored2.Weight)) } func TestHandler_WrongHash(t *testing.T) { diff --git a/activation/handler_v1.go b/activation/handler_v1.go index cefff79e1bb..0126ee98326 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math/bits" "sync" "time" @@ -683,6 +684,11 @@ func (h *HandlerV1) processATX( atx.NumUnits = effectiveNumUnits atx.BaseTickHeight = baseTickHeight atx.TickCount = leaves / h.tickSize + hi, weight := bits.Mul64(uint64(atx.NumUnits), atx.TickCount) + if hi != 0 { + return nil, errors.New("atx weight would overflow uint64") + } + atx.Weight = weight proof, err = h.storeAtx(ctx, atx, watx) if err != nil { diff --git a/activation/handler_v2.go b/activation/handler_v2.go index a369c50aaf7..20851b3f5a3 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -1,10 +1,11 @@ package activation import ( + "cmp" "context" "errors" "fmt" - "math" + "math/bits" "slices" "time" @@ -121,9 +122,10 @@ func (h *HandlerV2) processATX( atx := &types.ActivationTx{ PublishEpoch: watx.PublishEpoch, Coinbase: watx.Coinbase, - NumUnits: parts.effectiveUnits, BaseTickHeight: baseTickHeight, - TickCount: parts.leaves / h.tickSize, + NumUnits: parts.effectiveUnits, + TickCount: parts.ticks, + Weight: parts.weight, VRFNonce: types.VRFPostIndex(watx.VRFNonce), SmesherID: watx.SmesherID, AtxBlob: types.AtxBlob{Blob: blob, Version: types.AtxV2}, @@ -152,8 +154,6 @@ func (h *HandlerV2) processATX( } // Syntactically validate an ATX. -// TODOs: -// 2. support merged ATXs. func (h *HandlerV2) syntacticallyValidate(ctx context.Context, atx *wire.ActivationTxV2) error { if !h.edVerifier.Verify(signing.ATX, atx.SmesherID, atx.SignedBytes(), atx.Signature) { return fmt.Errorf("invalid atx signature: %w", errMalformedData) @@ -230,8 +230,9 @@ func (h *HandlerV2) syntacticallyValidate(ctx context.Context, atx *wire.Activat if len(atx.Marriages) != 0 { return errors.New("merged atx cannot have marriages") } - // TODO: support merged ATXs - return errors.New("atx merge is not supported") + if err := h.verifyIncludedIDsUniqueness(atx); err != nil { + return err + } default: // Solo chained (non-initial) ATX if len(atx.PreviousATXs) != 1 { @@ -347,7 +348,7 @@ func (h *HandlerV2) previous(ctx context.Context, id types.ATXID) (opaqueAtx, er // Validate the previous ATX for the given PoST and return the effective numunits. func (h *HandlerV2) validatePreviousAtx(id types.NodeID, post *wire.SubPostV2, prevAtxs []opaqueAtx) (uint32, error) { - if post.PrevATXIndex > uint32(len(prevAtxs)) { + if post.PrevATXIndex >= uint32(len(prevAtxs)) { return 0, fmt.Errorf("prevATXIndex out of bounds: %d > %d", post.PrevATXIndex, len(prevAtxs)) } prev := prevAtxs[post.PrevATXIndex] @@ -367,13 +368,28 @@ func (h *HandlerV2) validatePreviousAtx(id types.NodeID, post *wire.SubPostV2, p } return min(prev.NumUnits, post.NumUnits), nil case *wire.ActivationTxV2: - // TODO: support previous merged-ATX - - // previous is solo ATX - if prev.SmesherID == id { - return min(prev.NiPosts[0].Posts[0].NumUnits, post.NumUnits), nil + if prev.MarriageATX != nil { + // Previous is a merged ATX + // need to find out if the given ID was present in the previous ATX + _, idx, err := identities.MarriageInfo(h.cdb, id) + if err != nil { + return 0, fmt.Errorf("fetching marriage info for ID %s: %w", id, err) + } + for _, nipost := range prev.NiPosts { + for _, post := range nipost.Posts { + if post.MarriageIndex == uint32(idx) { + return min(post.NumUnits, post.NumUnits), nil + } + } + } + } else { + // Previous is a solo ATX + if prev.SmesherID == id { + return min(prev.NiPosts[0].Posts[0].NumUnits, post.NumUnits), nil + } } - return 0, fmt.Errorf("previous solo ATX V2 has different owner: %s (expected %s)", prev.SmesherID, id) + + return 0, fmt.Errorf("previous ATX V2 doesn't contain %s", id) } return 0, fmt.Errorf("unexpected previous ATX type: %T", prev) } @@ -440,11 +456,96 @@ func (h *HandlerV2) validateMarriages(atx *wire.ActivationTxV2) ([]types.NodeID, return marryingIDs, nil } +// Validate marriage ATX and return the full equivocation set. +func (h *HandlerV2) equivocationSet(atx *wire.ActivationTxV2) ([]types.NodeID, error) { + if atx.MarriageATX == nil { + return []types.NodeID{atx.SmesherID}, nil + } + marriageAtxID, _, err := identities.MarriageInfo(h.cdb, atx.SmesherID) + switch { + case errors.Is(err, sql.ErrNotFound) || marriageAtxID == nil: + return nil, errors.New("smesher is not married") + case err != nil: + return nil, fmt.Errorf("fetching smesher's marriage atx ID: %w", err) + } + + if *atx.MarriageATX != *marriageAtxID { + return nil, fmt.Errorf("smesher's marriage ATX ID mismatch: %s != %s", *atx.MarriageATX, *marriageAtxID) + } + + marriageAtx, err := atxs.Get(h.cdb, *atx.MarriageATX) + if err != nil { + return nil, fmt.Errorf("fetching marriage atx: %w", err) + } + if !(marriageAtx.PublishEpoch <= atx.PublishEpoch-2) { + return nil, fmt.Errorf( + "marriage atx must be published at least 2 epochs before %v (is %v)", + atx.PublishEpoch, + marriageAtx.PublishEpoch, + ) + } + + return identities.EquivocationSetByMarriageATX(h.cdb, *atx.MarriageATX) +} + type atxParts struct { - leaves uint64 + ticks uint64 + weight uint64 effectiveUnits uint32 } +type nipostSize struct { + units uint32 + ticks uint64 +} + +func (n *nipostSize) addUnits(units uint32) error { + sum, carry := bits.Add32(n.units, units, 0) + if carry != 0 { + return errors.New("units overflow") + } + n.units = sum + return nil +} + +type nipostSizes []*nipostSize + +func (n nipostSizes) minTicks() uint64 { + return slices.MinFunc(n, func(a, b *nipostSize) int { return cmp.Compare(a.ticks, b.ticks) }).ticks +} + +func (n nipostSizes) sumUp() (units uint32, weight uint64, err error) { + var totalEffectiveNumUnits uint32 + var totalWeight uint64 + for _, ns := range n { + sum, carry := bits.Add32(totalEffectiveNumUnits, ns.units, 0) + if carry != 0 { + return 0, 0, fmt.Errorf("total units overflow (%d + %d)", totalEffectiveNumUnits, ns.units) + } + totalEffectiveNumUnits = sum + + hi, weight := bits.Mul64(uint64(ns.units), ns.ticks) + if hi != 0 { + return 0, 0, fmt.Errorf("weight overflow (%d * %d)", ns.units, ns.ticks) + } + totalWeight += weight + } + return totalEffectiveNumUnits, totalWeight, nil +} + +func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error { + seen := make(map[uint32]struct{}) + for _, niposts := range atx.NiPosts { + for _, post := range niposts.Posts { + if _, ok := seen[post.MarriageIndex]; ok { + return fmt.Errorf("ID present twice (duplicated marriage index): %d", post.MarriageIndex) + } + seen[post.MarriageIndex] = struct{}{} + } + } + return nil +} + // Syntactically validate the ATX with its dependencies. func (h *HandlerV2) syntacticallyValidateDeps( ctx context.Context, @@ -469,33 +570,86 @@ func (h *HandlerV2) syntacticallyValidateDeps( previousAtxs[i] = prevAtx } - // validate all niposts - // TODO: support merged ATXs - // For a merged ATX we need to fetch the equivocation this smesher is part of. - equivocationSet := []types.NodeID{atx.SmesherID} - var totalEffectiveNumUnits uint32 - var minLeaves uint64 = math.MaxUint64 - var smesherCommitment *types.ATXID - for _, niposts := range atx.NiPosts { - // verify PoET memberships in a single go - var poetChallenges [][]byte + equivocationSet, err := h.equivocationSet(atx) + if err != nil { + return nil, nil, fmt.Errorf("validating marriages: %w", err) + } + // validate previous ATXs + nipostSizes := make(nipostSizes, len(atx.NiPosts)) + for i, niposts := range atx.NiPosts { + nipostSizes[i] = new(nipostSize) for _, post := range niposts.Posts { if post.MarriageIndex >= uint32(len(equivocationSet)) { err := fmt.Errorf("marriage index out of bounds: %d > %d", post.MarriageIndex, len(equivocationSet)-1) return nil, nil, err } + id := equivocationSet[post.MarriageIndex] effectiveNumUnits := post.NumUnits if atx.Initial == nil { var err error effectiveNumUnits, err = h.validatePreviousAtx(id, &post, previousAtxs) if err != nil { - return nil, nil, fmt.Errorf("validating previous atx for ID %s: %w", id, err) + return nil, nil, fmt.Errorf("validating previous atx: %w", err) } } - totalEffectiveNumUnits += effectiveNumUnits + nipostSizes[i].addUnits(effectiveNumUnits) + } + } + + // validate poet membership proofs + for i, niposts := range atx.NiPosts { + // verify PoET memberships in a single go + indexedChallenges := make(map[uint64][]byte) + + for _, post := range niposts.Posts { + nipostChallenge := wire.NIPostChallengeV2{ + PublishEpoch: atx.PublishEpoch, + PositioningATXID: atx.PositioningATX, + } + if atx.Initial != nil { + nipostChallenge.InitialPost = &atx.Initial.Post + } else { + nipostChallenge.PrevATXID = atx.PreviousATXs[post.PrevATXIndex] + } + if _, ok := indexedChallenges[post.MembershipLeafIndex]; !ok { + indexedChallenges[post.MembershipLeafIndex] = nipostChallenge.Hash().Bytes() + } + } + + leafIndicies := make([]uint64, 0, len(indexedChallenges)) + for i := range indexedChallenges { + leafIndicies = append(leafIndicies, i) + } + slices.Sort(leafIndicies) + poetChallenges := make([][]byte, 0, len(indexedChallenges)) + for _, i := range leafIndicies { + poetChallenges = append(poetChallenges, indexedChallenges[i]) + } + + membership := types.MultiMerkleProof{ + Nodes: niposts.Membership.Nodes, + LeafIndices: leafIndicies, + } + leaves, err := h.nipostValidator.PoetMembership(ctx, &membership, niposts.Challenge, poetChallenges) + if err != nil { + return nil, nil, fmt.Errorf("invalid poet membership: %w", err) + } + nipostSizes[i].ticks = leaves / h.tickSize + } + + totalEffectiveNumUnits, totalWeight, err := nipostSizes.sumUp() + if err != nil { + return nil, nil, err + } + + // validate all niposts + var smesherCommitment *types.ATXID + for _, niposts := range atx.NiPosts { + for _, post := range niposts.Posts { + id := equivocationSet[post.MarriageIndex] var commitment types.ATXID if atx.Initial != nil { commitment = atx.Initial.CommitmentATX @@ -505,7 +659,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( if err != nil { return nil, nil, fmt.Errorf("commitment atx not found for ID %s: %w", id, err) } - if smesherCommitment == nil { + if id == atx.SmesherID { smesherCommitment = &commitment } } @@ -531,36 +685,19 @@ func (h *HandlerV2) syntacticallyValidateDeps( if err != nil { return nil, nil, fmt.Errorf("invalid post for ID %s: %w", id, err) } - - nipostChallenge := wire.NIPostChallengeV2{ - PublishEpoch: atx.PublishEpoch, - PositioningATXID: atx.PositioningATX, - } - if atx.Initial != nil { - nipostChallenge.InitialPost = &atx.Initial.Post - } else { - nipostChallenge.PrevATXID = atx.PreviousATXs[post.PrevATXIndex] - } - - poetChallenges = append(poetChallenges, nipostChallenge.Hash().Bytes()) - } - membership := types.MultiMerkleProof{ - Nodes: niposts.Membership.Nodes, - LeafIndices: niposts.Membership.LeafIndices, } - leaves, err := h.nipostValidator.PoetMembership(ctx, &membership, niposts.Challenge, poetChallenges) - if err != nil { - return nil, nil, fmt.Errorf("invalid poet membership: %w", err) - } - minLeaves = min(leaves, minLeaves) } parts := &atxParts{ - leaves: minLeaves, + ticks: nipostSizes.minTicks(), effectiveUnits: totalEffectiveNumUnits, + weight: totalWeight, } if atx.Initial == nil { + if smesherCommitment == nil { + return nil, nil, errors.New("ATX signer not present in merged ATX") + } err := h.nipostValidator.VRFNonceV2(atx.SmesherID, *smesherCommitment, atx.VRFNonce, atx.TotalNumUnits()) if err != nil { return nil, nil, fmt.Errorf("validating VRF nonce: %w", err) @@ -645,8 +782,8 @@ func (h *HandlerV2) storeAtx( } if len(marrying) != 0 { - for _, id := range marrying { - if err := identities.SetMarriage(tx, id, atx.ID()); err != nil { + for i, id := range marrying { + if err := identities.SetMarriage(tx, id, atx.ID(), i); err != nil { return err } } diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 3de4a113eb0..ffa53f3fbdc 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -3,6 +3,9 @@ package activation import ( "context" "errors" + "fmt" + "math" + "slices" "testing" "time" @@ -31,7 +34,15 @@ type v2TestHandler struct { handlerMocks } -const poetLeaves = 200 +type marriedId struct { + signer *signing.EdSigner + refAtx *wire.ActivationTxV2 +} + +const ( + tickSize = 20 + poetLeaves = 200 +) func newV2TestHandler(tb testing.TB, golden types.ATXID) *v2TestHandler { lg := zaptest.NewLogger(tb) @@ -44,7 +55,7 @@ func newV2TestHandler(tb testing.TB, golden types.ATXID) *v2TestHandler { atxsdata: atxsdata.New(), edVerifier: signing.NewEdVerifier(), clock: mocks.mclock, - tickSize: 1, + tickSize: tickSize, goldenATXID: golden, nipostValidator: mocks.mValidator, logger: lg, @@ -83,6 +94,32 @@ func (h *handlerMocks) expectVerifyNIPoST(atx *wire.ActivationTxV2) { ).Return(poetLeaves, nil) } +func (h *handlerMocks) expectVerifyNIPoSTs( + atx *wire.ActivationTxV2, + equivocationSet []types.NodeID, + poetLeaves []uint64, +) { + for i, nipost := range atx.NiPosts { + for _, post := range nipost.Posts { + h.mValidator.EXPECT().PostV2( + gomock.Any(), + equivocationSet[post.MarriageIndex], + gomock.Any(), + wire.PostFromWireV1(&post.Post), + nipost.Challenge.Bytes(), + post.NumUnits, + gomock.Any(), + ) + } + h.mValidator.EXPECT().PoetMembership( + gomock.Any(), + gomock.Any(), + nipost.Challenge, + gomock.Any(), + ).Return(poetLeaves[i], nil) + } +} + func (h *handlerMocks) expectStoreAtxV2(atx *wire.ActivationTxV2) { h.mbeacon.EXPECT().OnAtx(gomock.Any()) h.mtortoise.EXPECT().OnAtx(gomock.Any(), gomock.Any(), gomock.Any()) @@ -125,6 +162,23 @@ func (h *handlerMocks) expectAtxV2(atx *wire.ActivationTxV2) { h.expectStoreAtxV2(atx) } +func (h *handlerMocks) expectMergedAtxV2( + atx *wire.ActivationTxV2, + equivocationSet []types.NodeID, + poetLeaves []uint64, +) { + h.mclock.EXPECT().CurrentLayer().Return(postGenesisEpoch.FirstLayer()) + h.expectFetchDeps(atx) + h.mValidator.EXPECT().VRFNonceV2( + atx.SmesherID, + gomock.Any(), + atx.VRFNonce, + atx.TotalNumUnits(), + ) + h.expectVerifyNIPoSTs(atx, equivocationSet, poetLeaves) + h.expectStoreAtxV2(atx) +} + func (h *v2TestHandler) createAndProcessInitial(t *testing.T, sig *signing.EdSigner) *wire.ActivationTxV2 { t.Helper() atx := newInitialATXv2(t, h.handlerMocks.goldenATXID) @@ -374,15 +428,19 @@ func TestHandlerV2_SyntacticallyValidate_MergedAtx(t *testing.T) { sig, err := signing.NewEdSigner() require.NoError(t, err) - t.Run("merged ATXs are not supported yet", func(t *testing.T) { + t.Run("cannot have marriage", func(t *testing.T) { t.Parallel() + atx := newSoloATXv2(t, 0, types.RandomATXID(), types.RandomATXID()) atx.MarriageATX = &golden + atx.Marriages = []wire.MarriageCertificate{{ + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }} atx.Sign(sig) atxHandler.mclock.EXPECT().CurrentLayer() - err := atxHandler.syntacticallyValidate(context.Background(), atx) - require.ErrorContains(t, err, "atx merge is not supported") + err = atxHandler.syntacticallyValidate(context.Background(), atx) + require.ErrorContains(t, err, "merged atx cannot have marriages") }) } @@ -400,6 +458,7 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { blob := codec.MustEncode(atx) atxHandler := newV2TestHandler(t, golden) + atxHandler.tickSize = tickSize atxHandler.expectInitialAtxV2(atx) proof, err := atxHandler.processATX(context.Background(), peer, atx, blob, time.Now()) @@ -411,9 +470,10 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { require.NotNil(t, atx) require.Equal(t, atx.ID(), atxFromDb.ID()) require.Equal(t, atx.Coinbase, atxFromDb.Coinbase) - require.EqualValues(t, poetLeaves, atxFromDb.TickCount) - require.EqualValues(t, poetLeaves, atxFromDb.TickHeight()) + require.EqualValues(t, poetLeaves/tickSize, atxFromDb.TickCount) + require.EqualValues(t, 0+atxFromDb.TickCount, atxFromDb.TickHeight()) // positioning is golden require.Equal(t, atx.NiPosts[0].Posts[0].NumUnits, atxFromDb.NumUnits) + require.EqualValues(t, atx.NiPosts[0].Posts[0].NumUnits*poetLeaves/tickSize, atxFromDb.Weight) // processing ATX for the second time should skip checks proof, err = atxHandler.processATX(context.Background(), peer, atx, blob, time.Now()) @@ -426,9 +486,10 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { prev := newInitialATXv1(t, golden) prev.Sign(sig) - atxs.Add(atxHandler.cdb, toAtx(t, prev)) + prevAtx := toAtx(t, prev) + atxs.Add(atxHandler.cdb, prevAtx) - atx := newSoloATXv2(t, prev.PublishEpoch+1, prev.ID(), golden) + atx := newSoloATXv2(t, prev.PublishEpoch+1, prev.ID(), prevAtx.ID()) atx.Sign(sig) blob := codec.MustEncode(atx) atxHandler.expectAtxV2(atx) @@ -441,34 +502,39 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { require.NoError(t, err) require.Nil(t, atxFromDb.CommitmentATX) - // copies coinbase and VRF nonce from the previous ATX - require.Equal(t, prev.Coinbase, atxFromDb.Coinbase) - require.EqualValues(t, *prev.VRFNonce, atxFromDb.VRFNonce) + + require.Equal(t, atx.Coinbase, atxFromDb.Coinbase) + require.EqualValues(t, atx.VRFNonce, atxFromDb.VRFNonce) + require.EqualValues(t, poetLeaves/tickSize, atxFromDb.TickCount) + require.EqualValues(t, prevAtx.TickHeight(), atxFromDb.BaseTickHeight) + require.EqualValues(t, prevAtx.TickHeight()+atxFromDb.TickCount, atxFromDb.TickHeight()) + require.Equal(t, atx.NiPosts[0].Posts[0].NumUnits, atxFromDb.NumUnits) + require.EqualValues(t, atx.NiPosts[0].Posts[0].NumUnits*poetLeaves/tickSize, atxFromDb.Weight) }) t.Run("second ATX, previous V2", func(t *testing.T) { t.Parallel() atxHandler := newV2TestHandler(t, golden) - prev := newInitialATXv2(t, golden) - prev.Sign(sig) - blob := codec.MustEncode(prev) + prev := atxHandler.createAndProcessInitial(t, sig) - atxHandler.expectInitialAtxV2(prev) - proof, err := atxHandler.processATX(context.Background(), peer, prev, blob, time.Now()) - require.NoError(t, err) - require.Nil(t, proof) - - atx := newSoloATXv2(t, prev.PublishEpoch+1, prev.ID(), golden) + atx := newSoloATXv2(t, prev.PublishEpoch+1, prev.ID(), prev.ID()) atx.Sign(sig) - blob = codec.MustEncode(atx) - atxHandler.expectAtxV2(atx) + blob := codec.MustEncode(atx) - proof, err = atxHandler.processATX(context.Background(), peer, atx, blob, time.Now()) + atxHandler.expectAtxV2(atx) + proof, err := atxHandler.processATX(context.Background(), peer, atx, blob, time.Now()) require.NoError(t, err) require.Nil(t, proof) - _, err = atxs.Get(atxHandler.cdb, atx.ID()) + prevAtx, err := atxs.Get(atxHandler.cdb, prev.ID()) + require.NoError(t, err) + atxFromDb, err := atxs.Get(atxHandler.cdb, atx.ID()) require.NoError(t, err) + require.EqualValues(t, poetLeaves/tickSize, atxFromDb.TickCount) + require.EqualValues(t, prevAtx.TickHeight(), atxFromDb.BaseTickHeight) + require.EqualValues(t, prevAtx.TickHeight()+atxFromDb.TickCount, atxFromDb.TickHeight()) + require.Equal(t, atx.NiPosts[0].Posts[0].NumUnits, atxFromDb.NumUnits) + require.EqualValues(t, atx.NiPosts[0].Posts[0].NumUnits*poetLeaves/tickSize, atxFromDb.Weight) }) t.Run("second ATX, previous checkpointed", func(t *testing.T) { t.Parallel() @@ -587,6 +653,265 @@ func TestHandlerV2_ProcessSoloATX(t *testing.T) { }) } +func marryIDs( + t *testing.T, + atxHandler *v2TestHandler, + sig *signing.EdSigner, + golden types.ATXID, + num int, +) (marriage *wire.ActivationTxV2, other []*wire.ActivationTxV2) { + mATX := newInitialATXv2(t, golden) + mATX.Marriages = []wire.MarriageCertificate{{ + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }} + + for range num { + signer, err := signing.NewEdSigner() + require.NoError(t, err) + atx := atxHandler.createAndProcessInitial(t, signer) + other = append(other, atx) + mATX.Marriages = append(mATX.Marriages, wire.MarriageCertificate{ + ReferenceAtx: atx.ID(), + Signature: signer.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }) + } + + mATX.Sign(sig) + atxHandler.expectInitialAtxV2(mATX) + p, err := atxHandler.processATX(context.Background(), "", mATX, codec.MustEncode(mATX), time.Now()) + require.NoError(t, err) + require.Nil(t, p) + + return mATX, other +} + +func TestHandlerV2_ProcessMergedATX(t *testing.T) { + t.Parallel() + golden := types.RandomATXID() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + t.Run("happy case", func(t *testing.T) { + atxHandler := newV2TestHandler(t, golden) + + // Marry IDs + mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 2) + previousATXs := []types.ATXID{mATX.ID()} + equivocationSet := []types.NodeID{sig.NodeID()} + for _, atx := range otherATXs { + previousATXs = append(previousATXs, atx.ID()) + equivocationSet = append(equivocationSet, atx.SmesherID) + } + + // Process a merged ATX + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + totalNumUnits := merged.NiPosts[0].Posts[0].NumUnits + for i, atx := range otherATXs { + post := wire.SubPostV2{ + MarriageIndex: uint32(i + 1), + NumUnits: atx.TotalNumUnits(), + PrevATXIndex: uint32(i + 1), + } + totalNumUnits += post.NumUnits + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + } + mATXID := mATX.ID() + merged.MarriageATX = &mATXID + + merged.PreviousATXs = previousATXs + merged.Sign(sig) + + atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{poetLeaves}) + p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + require.NoError(t, err) + require.Nil(t, p) + + atx, err := atxs.Get(atxHandler.cdb, merged.ID()) + require.NoError(t, err) + require.Equal(t, totalNumUnits, atx.NumUnits) + require.Equal(t, sig.NodeID(), atx.SmesherID) + require.EqualValues(t, totalNumUnits*poetLeaves/tickSize, atx.Weight) + }) + t.Run("merged IDs on 2 poets", func(t *testing.T) { + const tickSize = 33 + atxHandler := newV2TestHandler(t, golden) + atxHandler.tickSize = tickSize + + // Marry IDs + mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 4) + previousATXs := []types.ATXID{mATX.ID()} + equivocationSet := []types.NodeID{sig.NodeID()} + for _, atx := range otherATXs { + previousATXs = append(previousATXs, atx.ID()) + equivocationSet = append(equivocationSet, atx.SmesherID) + } + + // Process a merged ATX + merged := &wire.ActivationTxV2{ + PublishEpoch: mATX.PublishEpoch + 2, + PreviousATXs: previousATXs, + PositioningATX: mATX.ID(), + Coinbase: types.GenerateAddress([]byte("aaaa")), + VRFNonce: uint64(999), + NiPosts: make([]wire.NiPostsV2, 2), + } + atxsPerPoet := [][]*wire.ActivationTxV2{ + append([]*wire.ActivationTxV2{mATX}, otherATXs[:2]...), + otherATXs[2:], + } + var totalNumUnits uint32 + unitsPerPoet := make([]uint32, 2) + var idx uint32 + for nipostId := range 2 { + for _, atx := range atxsPerPoet[nipostId] { + post := wire.SubPostV2{ + MarriageIndex: idx, + NumUnits: atx.TotalNumUnits(), + PrevATXIndex: idx, + } + unitsPerPoet[nipostId] += post.NumUnits + totalNumUnits += post.NumUnits + merged.NiPosts[nipostId].Posts = append(merged.NiPosts[nipostId].Posts, post) + idx++ + } + } + + mATXID := mATX.ID() + merged.MarriageATX = &mATXID + + merged.PreviousATXs = previousATXs + merged.Sign(sig) + + poetLeaves := []uint64{100, 500} + minPoetLeaves := slices.Min(poetLeaves) + + atxHandler.expectMergedAtxV2(merged, equivocationSet, poetLeaves) + p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + require.NoError(t, err) + require.Nil(t, p) + + marriageATX, err := atxs.Get(atxHandler.cdb, mATX.ID()) + require.NoError(t, err) + atx, err := atxs.Get(atxHandler.cdb, merged.ID()) + require.NoError(t, err) + require.Equal(t, totalNumUnits, atx.NumUnits) + require.Equal(t, sig.NodeID(), atx.SmesherID) + require.Equal(t, minPoetLeaves/tickSize, atx.TickCount) + require.Equal(t, marriageATX.TickHeight()+atx.TickCount, atx.TickHeight()) + // the total weight is summed weight on each poet + var weight uint64 + for i := range unitsPerPoet { + ticks := poetLeaves[i] / tickSize + weight += uint64(unitsPerPoet[i]) * ticks + } + require.EqualValues(t, weight, atx.Weight) + }) + t.Run("signer must be included merged ATX", func(t *testing.T) { + atxHandler := newV2TestHandler(t, golden) + + // Marry IDs + mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 2) + previousATXs := []types.ATXID{} + equivocationSet := []types.NodeID{sig.NodeID()} + for _, atx := range otherATXs { + previousATXs = append(previousATXs, atx.ID()) + equivocationSet = append(equivocationSet, atx.SmesherID) + } + + // Process a merged ATX + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + merged.NiPosts[0].Posts = []wire.SubPostV2{} // remove signer's PoST + for i, atx := range otherATXs { + post := wire.SubPostV2{ + MarriageIndex: uint32(i + 1), + NumUnits: atx.TotalNumUnits(), + PrevATXIndex: uint32(i), + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + } + mATXID := mATX.ID() + merged.MarriageATX = &mATXID + + merged.PreviousATXs = previousATXs + merged.Sign(sig) + + atxHandler.mclock.EXPECT().CurrentLayer().Return(postGenesisEpoch.FirstLayer()) + atxHandler.expectFetchDeps(merged) + atxHandler.expectVerifyNIPoSTs(merged, equivocationSet, []uint64{200}) + + p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + require.ErrorContains(t, err, "ATX signer not present in merged ATX") + require.Nil(t, p) + }) + t.Run("ID must be present max 1 times", func(t *testing.T) { + atxHandler := newV2TestHandler(t, golden) + + // Marry IDs + mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 1) + previousATXs := []types.ATXID{mATX.ID()} + equivocationSet := []types.NodeID{sig.NodeID()} + for _, atx := range otherATXs { + previousATXs = append(previousATXs, atx.ID()) + equivocationSet = append(equivocationSet, atx.SmesherID) + } + + // Process a merged ATX + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + // Insert the same ID twice + for range 2 { + post := wire.SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 1, + NumUnits: otherATXs[0].TotalNumUnits(), + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + } + mATXID := mATX.ID() + merged.MarriageATX = &mATXID + + merged.PreviousATXs = previousATXs + merged.Sign(sig) + + atxHandler.mclock.EXPECT().CurrentLayer().Return(postGenesisEpoch.FirstLayer()) + p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + require.ErrorContains(t, err, "ID present twice (duplicated marriage index)") + require.Nil(t, p) + }) + t.Run("ID must use previous ATX containing itself", func(t *testing.T) { + atxHandler := newV2TestHandler(t, golden) + + // Marry IDs + mATX, otherATXs := marryIDs(t, atxHandler, sig, golden, 1) + previousATXs := []types.ATXID{mATX.ID()} + equivocationSet := []types.NodeID{sig.NodeID()} + for _, atx := range otherATXs { + previousATXs = append(previousATXs, atx.ID()) + equivocationSet = append(equivocationSet, atx.SmesherID) + } + + // Process a merged ATX + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + post := wire.SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 0, // use wrong previous ATX + NumUnits: otherATXs[0].TotalNumUnits(), + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + + mATXID := mATX.ID() + merged.MarriageATX = &mATXID + + merged.PreviousATXs = previousATXs + merged.Sign(sig) + + atxHandler.mclock.EXPECT().CurrentLayer().Return(postGenesisEpoch.FirstLayer()) + atxHandler.expectFetchDeps(merged) + p, err := atxHandler.processATX(context.Background(), "", merged, codec.MustEncode(merged), time.Now()) + require.ErrorContains(t, err, fmt.Sprintf("previous ATX V2 doesn't contain %s", otherATXs[0].SmesherID)) + require.Nil(t, p) + }) +} + func TestCollectDeps_AtxV2(t *testing.T) { goldenATX := types.RandomATXID() prev0 := types.RandomATXID() @@ -788,6 +1113,143 @@ func Test_ValidatePositioningAtx(t *testing.T) { }) } +func Test_ValidateMarriages(t *testing.T) { + t.Parallel() + golden := types.RandomATXID() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + t.Run("marriage ATX not set (solo ATX)", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + atx := newInitialATXv2(t, golden) + atx.Sign(sig) + + set, err := atxHandler.equivocationSet(atx) + require.NoError(t, err) + require.Equal(t, []types.NodeID{atx.SmesherID}, set) + }) + t.Run("smesher is not married", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + atx := newSoloATXv2(t, 0, types.RandomATXID(), golden) + atx.MarriageATX = &golden + atx.Sign(sig) + + _, err := atxHandler.equivocationSet(atx) + require.ErrorContains(t, err, "smesher is not married") + }) + t.Run("marriage ATX must be published 2 epochs prior merging IDs", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + otherSigner, err := signing.NewEdSigner() + require.NoError(t, err) + otherAtx := atxHandler.createAndProcessInitial(t, otherSigner) + + marriage := newInitialATXv2(t, golden) + marriage.PublishEpoch = 1 + marriage.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: otherAtx.ID(), + Signature: otherSigner.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + marriage.Sign(sig) + + atxHandler.expectInitialAtxV2(marriage) + p, err := atxHandler.processATX(context.Background(), "", marriage, codec.MustEncode(marriage), time.Now()) + require.NoError(t, err) + require.Nil(t, p) + + atx := newSoloATXv2(t, marriage.PublishEpoch+1, types.RandomATXID(), golden) + marriageATXID := marriage.ID() + atx.MarriageATX = &marriageATXID + atx.Sign(sig) + + _, err = atxHandler.equivocationSet(atx) + require.ErrorContains(t, err, "marriage atx must be published at least 2 epochs before") + }) + t.Run("can't use somebody else's marriage ATX", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + + otherSigner, err := signing.NewEdSigner() + require.NoError(t, err) + otherAtx := atxHandler.createAndProcessInitial(t, otherSigner) + + marriage := newInitialATXv2(t, golden) + marriage.PublishEpoch = 1 + marriage.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: otherAtx.ID(), + Signature: otherSigner.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + marriage.Sign(sig) + + atxHandler.expectInitialAtxV2(marriage) + p, err := atxHandler.processATX(context.Background(), "", marriage, codec.MustEncode(marriage), time.Now()) + require.NoError(t, err) + require.Nil(t, p) + + atx := newSoloATXv2(t, marriage.PublishEpoch+1, types.RandomATXID(), golden) + marriageATXID := types.RandomATXID() + atx.MarriageATX = &marriageATXID + atx.Sign(sig) + + _, err = atxHandler.equivocationSet(atx) + require.ErrorContains(t, err, "smesher's marriage ATX ID mismatch") + }) + t.Run("smesher is married", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + marriage := newInitialATXv2(t, golden) + marriage.Marriages = []wire.MarriageCertificate{{ + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }} + + var otherIds []marriedId + for range 5 { + signer, err := signing.NewEdSigner() + require.NoError(t, err) + atx := atxHandler.createAndProcessInitial(t, signer) + otherIds = append(otherIds, marriedId{signer, atx}) + } + + expectedSet := []types.NodeID{sig.NodeID()} + + for _, id := range otherIds { + cert := wire.MarriageCertificate{ + ReferenceAtx: id.refAtx.ID(), + Signature: id.signer.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + } + marriage.Marriages = append(marriage.Marriages, cert) + expectedSet = append(expectedSet, id.signer.NodeID()) + } + marriage.Sign(sig) + + p, err := atxHandler.processInitial(marriage) + require.NoError(t, err) + require.Nil(t, p) + + atx := newSoloATXv2(t, 0, marriage.ID(), golden) + atx.PublishEpoch = marriage.PublishEpoch + 2 + marriageATXID := marriage.ID() + atx.MarriageATX = &marriageATXID + atx.Sign(sig) + + set, err := atxHandler.equivocationSet(atx) + require.NoError(t, err) + require.Equal(t, expectedSet, set) + }) +} + func Test_LoadPreviousATX(t *testing.T) { t.Parallel() t.Run("not found", func(t *testing.T) { @@ -909,7 +1371,7 @@ func Test_ValidatePreviousATX(t *testing.T) { prev := newInitialATXv2(t, golden) prev.SmesherID = types.RandomNodeID() _, err := atxHandler.validatePreviousAtx(types.RandomNodeID(), &wire.SubPostV2{}, []opaqueAtx{prev}) - require.ErrorContains(t, err, "previous solo ATX V2 has different owner") + require.ErrorContains(t, err, "previous ATX V2 doesn't contain") }) t.Run("previous golden, valid", func(t *testing.T) { t.Parallel() @@ -1019,6 +1481,7 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { atx := newInitialATXv2(t, golden) atx.Sign(sig) + atxHandler.mValidator.EXPECT().PoetMembership(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) atxHandler.mValidator.EXPECT(). PostV2( gomock.Any(), @@ -1062,15 +1525,6 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { atx := newInitialATXv2(t, golden) atx.Sign(sig) - atxHandler.mValidator.EXPECT().PostV2( - gomock.Any(), - sig.NodeID(), - golden, - wire.PostFromWireV1(&atx.NiPosts[0].Posts[0].Post), - atx.NiPosts[0].Challenge.Bytes(), - atx.TotalNumUnits(), - gomock.Any(), - ) atxHandler.mValidator.EXPECT(). PoetMembership(gomock.Any(), gomock.Any(), atx.NiPosts[0].Challenge, gomock.Any()). Return(0, errors.New("poet failure")) @@ -1262,7 +1716,7 @@ func Test_MarryingMalicious(t *testing.T) { atxHandler.mtortoise.EXPECT().OnMalfeasance(sig.NodeID()) atxHandler.mtortoise.EXPECT().OnMalfeasance(otherSig.NodeID()) - _, err = atxHandler.processATX(context.Background(), "", atx, codec.MustEncode(atx), time.Now()) + _, err := atxHandler.processATX(context.Background(), "", atx, codec.MustEncode(atx), time.Now()) require.NoError(t, err) equiv, err := identities.EquivocationSet(atxHandler.cdb, sig.NodeID()) @@ -1278,6 +1732,57 @@ func Test_MarryingMalicious(t *testing.T) { } } +func Test_CalculatingUnits(t *testing.T) { + t.Parallel() + t.Run("units on 1 nipost must not overflow", func(t *testing.T) { + t.Parallel() + ns := nipostSize{} + require.NoError(t, ns.addUnits(1)) + require.EqualValues(t, 1, ns.units) + require.Error(t, ns.addUnits(math.MaxUint32)) + }) + t.Run("total units on all niposts must not overflow", func(t *testing.T) { + t.Parallel() + ns := make(nipostSizes, 0) + ns = append(ns, &nipostSize{units: 11}, &nipostSize{units: math.MaxUint32 - 10}) + _, _, err := ns.sumUp() + require.Error(t, err) + }) + t.Run("units = sum of units on every nipost", func(t *testing.T) { + t.Parallel() + ns := make(nipostSizes, 0) + ns = append(ns, &nipostSize{units: 1}, &nipostSize{units: 10}) + u, _, err := ns.sumUp() + require.NoError(t, err) + require.EqualValues(t, 1+10, u) + }) +} + +func Test_CalculatingWeight(t *testing.T) { + t.Parallel() + t.Run("total weight must not overflow uint64", func(t *testing.T) { + t.Parallel() + ns := make(nipostSizes, 0) + ns = append(ns, &nipostSize{units: 1, ticks: 100}, &nipostSize{units: 10, ticks: math.MaxUint64}) + _, _, err := ns.sumUp() + require.Error(t, err) + }) + t.Run("weight = sum of weight on every nipost", func(t *testing.T) { + t.Parallel() + ns := make(nipostSizes, 0) + ns = append(ns, &nipostSize{units: 1, ticks: 100}, &nipostSize{units: 10, ticks: 1000}) + _, w, err := ns.sumUp() + require.NoError(t, err) + require.EqualValues(t, 1*100+10*1000, w) + }) +} + +func Test_CalculatingTicks(t *testing.T) { + ns := make(nipostSizes, 0) + ns = append(ns, &nipostSize{units: 1, ticks: 100}, &nipostSize{units: 10, ticks: 1000}) + require.EqualValues(t, 100, ns.minTicks()) +} + func newInitialATXv2(t testing.TB, golden types.ATXID) *wire.ActivationTxV2 { t.Helper() atx := &wire.ActivationTxV2{ diff --git a/activation/post.go b/activation/post.go index 08f083592a4..7e1664df898 100644 --- a/activation/post.go +++ b/activation/post.go @@ -110,6 +110,14 @@ func DefaultPostVerifyingOpts() PostProofVerifyingOpts { } } +func DefaultTestPostVerifyingOpts() PostProofVerifyingOpts { + return PostProofVerifyingOpts{ + MinWorkers: 1, + Workers: 1, + Flags: PostPowFlags(config.DefaultVerifyingPowFlags()), + } +} + // PostSetupStatus represents a status snapshot of the Post setup. type PostSetupStatus struct { State PostSetupState diff --git a/activation/post_test.go b/activation/post_test.go index c0273a63697..de51d599dfc 100644 --- a/activation/post_test.go +++ b/activation/post_test.go @@ -273,15 +273,15 @@ func TestPostSetupManager_findCommitmentAtx_UsesLatestAtx(t *testing.T) { signer, err := signing.NewEdSigner() require.NoError(t, err) - challenge := types.NIPostChallenge{ + atx := &types.ActivationTx{ PublishEpoch: 1, + NumUnits: 2, + Weight: 2, + SmesherID: signer.NodeID(), + TickCount: 1, } - atx := types.NewActivationTx(challenge, types.Address{}, 2) - atx.SmesherID = signer.NodeID() atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) - atx.TickCount = 1 - require.NoError(t, err) require.NoError(t, atxs.Add(mgr.db, atx)) mgr.atxsdata.AddFromAtx(atx, false) @@ -323,12 +323,16 @@ func TestPostSetupManager_getCommitmentAtx_getsCommitmentAtxFromInitialAtx(t *te // add an atx by the same node commitmentAtx := types.RandomATXID() - atx := types.NewActivationTx(types.NIPostChallenge{}, types.Address{}, 1) - atx.CommitmentATX = &commitmentAtx - atx.SmesherID = signer.NodeID() + atx := &types.ActivationTx{ + NumUnits: 1, + Weight: 1, + SmesherID: signer.NodeID(), + TickCount: 1, + CommitmentATX: &commitmentAtx, + } + atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) - atx.TickCount = 1 require.NoError(t, atxs.Add(mgr.cdb, atx)) atxid, err := mgr.commitmentAtx(context.Background(), mgr.opts.DataDir, signer.NodeID()) diff --git a/activation/wire/challenge_v2.go b/activation/wire/challenge_v2.go index 257f7093e12..198edbd556f 100644 --- a/activation/wire/challenge_v2.go +++ b/activation/wire/challenge_v2.go @@ -31,6 +31,7 @@ func (c *NIPostChallengeV2) MarshalLogObject(encoder zapcore.ObjectEncoder) erro if c == nil { return nil } + encoder.AddString("Hash", c.Hash().String()) encoder.AddUint32("PublishEpoch", c.PublishEpoch.Uint32()) encoder.AddString("PrevATXID", c.PrevATXID.String()) encoder.AddString("PositioningATX", c.PositioningATXID.String()) diff --git a/activation/wire/wire_v2.go b/activation/wire/wire_v2.go index db3bf9d022d..c844ae78093 100644 --- a/activation/wire/wire_v2.go +++ b/activation/wire/wire_v2.go @@ -196,8 +196,7 @@ func (mc *MarriageCertificate) Root() []byte { // MerkleProofV2 proves membership of multiple challenges in a PoET membership merkle tree. type MerkleProofV2 struct { // Nodes on path from leaf to root (not including leaf) - Nodes []types.Hash32 `scale:"max=32"` - LeafIndices []uint64 `scale:"max=256"` // support merging up to 256 IDs + Nodes []types.Hash32 `scale:"max=32"` } type SubPostV2 struct { @@ -206,8 +205,12 @@ type SubPostV2 struct { // Must be 0 for non-merged ATXs. MarriageIndex uint32 PrevATXIndex uint32 // Index of the previous ATX in the `InnerActivationTxV2.PreviousATXs` slice - Post PostV1 - NumUnits uint32 + // Index of the leaf for this ID's challenge in the poet membership tree. + // IDs might shared the same index if their nipost challenges are equal. + // This happens when the IDs are continuously merged (they share the previous ATX). + MembershipLeafIndex uint64 + Post PostV1 + NumUnits uint32 } func (sp *SubPostV2) Root(prevATXs []types.ATXID) []byte { @@ -225,6 +228,11 @@ func (sp *SubPostV2) Root(prevATXs []types.ATXID) []byte { return nil // invalid index, root cannot be generated } tree.AddLeaf(prevATXs[sp.PrevATXIndex].Bytes()) + + var leafIndex [8]byte + binary.LittleEndian.PutUint64(leafIndex[:], sp.MembershipLeafIndex) + tree.AddLeaf(leafIndex[:]) + tree.AddLeaf(sp.Post.Root()) numUnits := make([]byte, 4) @@ -235,7 +243,6 @@ func (sp *SubPostV2) Root(prevATXs []types.ATXID) []byte { type NiPostsV2 struct { // Single membership proof for all IDs in `Posts`. - // The index of ID in `Posts` is the index of the challenge in the proof (`LeafIndices`). Membership MerkleProofV2 // The root of the PoET proof, that serves as the challenge for PoSTs. Challenge types.Hash32 @@ -336,6 +343,7 @@ func (post *SubPostV2) MarshalLogObject(encoder zapcore.ObjectEncoder) error { } encoder.AddUint32("MarriageIndex", post.MarriageIndex) encoder.AddUint32("PrevATXIndex", post.PrevATXIndex) + encoder.AddUint64("MembershipLeafIndex", post.MembershipLeafIndex) encoder.AddObject("Post", &post.Post) encoder.AddUint32("NumUnits", post.NumUnits) return nil diff --git a/activation/wire/wire_v2_scale.go b/activation/wire/wire_v2_scale.go index 286a2004281..4c5404a34c7 100644 --- a/activation/wire/wire_v2_scale.go +++ b/activation/wire/wire_v2_scale.go @@ -257,13 +257,6 @@ func (t *MerkleProofV2) EncodeScale(enc *scale.Encoder) (total int, err error) { } total += n } - { - n, err := scale.EncodeUint64SliceWithLimit(enc, t.LeafIndices, 256) - if err != nil { - return total, err - } - total += n - } return total, nil } @@ -276,14 +269,6 @@ func (t *MerkleProofV2) DecodeScale(dec *scale.Decoder) (total int, err error) { total += n t.Nodes = field } - { - field, n, err := scale.DecodeUint64SliceWithLimit(dec, 256) - if err != nil { - return total, err - } - total += n - t.LeafIndices = field - } return total, nil } @@ -302,6 +287,13 @@ func (t *SubPostV2) EncodeScale(enc *scale.Encoder) (total int, err error) { } total += n } + { + n, err := scale.EncodeCompact64(enc, uint64(t.MembershipLeafIndex)) + if err != nil { + return total, err + } + total += n + } { n, err := t.Post.EncodeScale(enc) if err != nil { @@ -336,6 +328,14 @@ func (t *SubPostV2) DecodeScale(dec *scale.Decoder) (total int, err error) { total += n t.PrevATXIndex = uint32(field) } + { + field, n, err := scale.DecodeCompact64(dec) + if err != nil { + return total, err + } + total += n + t.MembershipLeafIndex = uint64(field) + } { n, err := t.Post.DecodeScale(dec) if err != nil { diff --git a/activation/wire/wire_v2_test.go b/activation/wire/wire_v2_test.go index f56ae7423e1..596be060913 100644 --- a/activation/wire/wire_v2_test.go +++ b/activation/wire/wire_v2_test.go @@ -37,16 +37,14 @@ func Benchmark_ATXv2ID_WorstScenario(b *testing.B) { NiPosts: []NiPostsV2{ { Membership: MerkleProofV2{ - Nodes: make([]types.Hash32, 32), - LeafIndices: make([]uint64, 256), + Nodes: make([]types.Hash32, 32), }, Challenge: types.RandomHash(), Posts: make([]SubPostV2, 256), }, { Membership: MerkleProofV2{ - Nodes: make([]types.Hash32, 32), - LeafIndices: make([]uint64, 256), + Nodes: make([]types.Hash32, 32), }, Challenge: types.RandomHash(), Posts: make([]SubPostV2, 256), // actually the sum of all posts in `NiPosts` should be 256 @@ -96,8 +94,7 @@ func Test_GenerateDoublePublishProof(t *testing.T) { NiPosts: []NiPostsV2{ { Membership: MerkleProofV2{ - Nodes: make([]types.Hash32, 32), - LeafIndices: make([]uint64, 256), + Nodes: make([]types.Hash32, 32), }, Challenge: types.RandomHash(), Posts: []SubPostV2{ diff --git a/api/grpcserver/grpcserver_test.go b/api/grpcserver/grpcserver_test.go index dfdd6c49775..3d42cf23ad5 100644 --- a/api/grpcserver/grpcserver_test.go +++ b/api/grpcserver/grpcserver_test.go @@ -85,8 +85,6 @@ var ( addr1 types.Address addr2 types.Address rewardSmesherID = types.RandomNodeID() - prevAtxID = types.ATXID(types.HexToHash32("44444")) - challenge = newChallenge(1, prevAtxID, prevAtxID, postGenesisEpoch) globalAtx *types.ActivationTx globalAtx2 *types.ActivationTx globalTx *types.Transaction @@ -162,12 +160,28 @@ func TestMain(m *testing.M) { addr1 = wallet.Address(signer1.PublicKey().Bytes()) addr2 = wallet.Address(signer2.PublicKey().Bytes()) - globalAtx = types.NewActivationTx(challenge, addr1, numUnits) + globalAtx = &types.ActivationTx{ + PublishEpoch: postGenesisEpoch, + Sequence: 1, + PrevATXID: types.ATXID{4, 4, 4, 4}, + Coinbase: addr1, + NumUnits: numUnits, + Weight: numUnits, + TickCount: 1, + SmesherID: signer.NodeID(), + } globalAtx.SetReceived(time.Now()) - globalAtx.SmesherID = signer.NodeID() - globalAtx.TickCount = 1 - globalAtx2 = types.NewActivationTx(challenge, addr2, numUnits) + globalAtx2 = &types.ActivationTx{ + PublishEpoch: postGenesisEpoch, + Sequence: 1, + PrevATXID: types.ATXID{5, 5, 5, 5}, + Coinbase: addr2, + NumUnits: numUnits, + Weight: numUnits, + TickCount: 1, + SmesherID: signer.NodeID(), + } globalAtx2.SetReceived(time.Now()) globalAtx2.SmesherID = signer.NodeID() globalAtx2.TickCount = 1 @@ -388,15 +402,6 @@ func NewTx(nonce uint64, recipient types.Address, signer *signing.EdSigner) *typ return &tx } -func newChallenge(sequence uint64, prevAtxID, posAtxID types.ATXID, epoch types.EpochID) types.NIPostChallenge { - return types.NIPostChallenge{ - Sequence: sequence, - PrevATXID: prevAtxID, - PublishEpoch: epoch, - PositioningATX: posAtxID, - } -} - func launchServer(tb testing.TB, services ...ServiceAPI) (Config, func()) { cfg := DefaultTestConfig() grpcService, err := NewWithServices(cfg.PublicListener, zaptest.NewLogger(tb).Named("grpc"), cfg, services) diff --git a/api/grpcserver/v2alpha1/activation.go b/api/grpcserver/v2alpha1/activation.go index 2025596ae29..100d86450b2 100644 --- a/api/grpcserver/v2alpha1/activation.go +++ b/api/grpcserver/v2alpha1/activation.go @@ -149,7 +149,7 @@ func toAtx(atx *types.ActivationTx) *spacemeshv2alpha1.Activation { PublishEpoch: atx.PublishEpoch.Uint32(), PreviousAtx: atx.PrevATXID[:], Coinbase: atx.Coinbase.String(), - Weight: atx.GetWeight(), + Weight: atx.Weight, Height: atx.TickHeight(), } } diff --git a/atxsdata/data.go b/atxsdata/data.go index f8eae4794c3..94c4cfc89d2 100644 --- a/atxsdata/data.go +++ b/atxsdata/data.go @@ -76,7 +76,7 @@ func (d *Data) AddFromAtx(atx *types.ActivationTx, malicious bool) *ATX { atx.SmesherID, atx.Coinbase, atx.ID(), - atx.GetWeight(), + atx.Weight, atx.BaseTickHeight, atx.TickHeight(), atx.VRFNonce, diff --git a/beacon/beacon.go b/beacon/beacon.go index 8160597c0f1..3a0effae276 100644 --- a/beacon/beacon.go +++ b/beacon/beacon.go @@ -604,7 +604,7 @@ func (pd *ProtocolDriver) initEpochStateIfNotPresent(logger *zap.Logger, target ) err := atxs.IterateAtxsWithMalfeasance(pd.cdb, target-1, func(atx *types.ActivationTx, malicious bool) bool { if !malicious { - epochWeight += atx.GetWeight() + epochWeight += atx.Weight } else { logger.Debug("malicious miner get 0 weight", zap.Stringer("smesher", atx.SmesherID)) } diff --git a/beacon/beacon_test.go b/beacon/beacon_test.go index c72776f17ad..bdc1c54fa79 100644 --- a/beacon/beacon_test.go +++ b/beacon/beacon_test.go @@ -114,22 +114,25 @@ func createATX( numUnits uint32, received time.Time, ) types.ATXID { - nonce := types.VRFPostIndex(1) - atx := types.NewActivationTx( - types.NIPostChallenge{PublishEpoch: lid.GetEpoch()}, - types.GenerateAddress(types.RandomBytes(types.AddressLength)), - numUnits, - ) - atx.VRFNonce = nonce + tb.Helper() + atx := types.ActivationTx{ + PublishEpoch: lid.GetEpoch(), + Coinbase: types.GenerateAddress(types.RandomBytes(types.AddressLength)), + NumUnits: numUnits, + VRFNonce: 1, + TickCount: 1, + Weight: uint64(numUnits), + SmesherID: sig.NodeID(), + } + atx.SetReceived(received) - atx.SmesherID = sig.NodeID() atx.SetID(types.RandomATXID()) - atx.TickCount = 1 - require.NoError(tb, atxs.Add(db, atx)) + require.NoError(tb, atxs.Add(db, &atx)) return atx.ID() } func createRandomATXs(tb testing.TB, db *datastore.CachedDB, lid types.LayerID, num int) { + tb.Helper() for i := 0; i < num; i++ { sig, err := signing.NewEdSigner() require.NoError(tb, err) @@ -187,12 +190,8 @@ func TestBeacon_MultipleNodes(t *testing.T) { require.NoError(t, err) require.Equal(t, bootstrap, got) } - for i, node := range testNodes { - if i == 0 { - // make the first node non-smeshing node - continue - } - + // make the first node non-smeshing node + for _, node := range testNodes[1:] { for _, db := range dbs { for _, s := range node.signers { createATX(t, db, atxPublishLid, s, 1, time.Now().Add(-1*time.Second)) diff --git a/beacon/handlers.go b/beacon/handlers.go index 7234572101e..89d838e985f 100644 --- a/beacon/handlers.go +++ b/beacon/handlers.go @@ -331,7 +331,7 @@ func (pd *ProtocolDriver) storeFirstVotes(m FirstVotingMessage, nodeID types.Nod } voteWeight := new(big.Int) if !malicious { - voteWeight.SetUint64(atx.GetWeight()) + voteWeight.SetUint64(atx.Weight) } else { pd.logger.Debug("malicious miner get 0 weight", zap.Stringer("smesher", nodeID)) } @@ -457,7 +457,7 @@ func (pd *ProtocolDriver) storeFollowingVotes(m FollowingVotingMessage, nodeID t } voteWeight := new(big.Int) if !malicious { - voteWeight.SetUint64(atx.GetWeight()) + voteWeight.SetUint64(atx.Weight) } else { pd.logger.Debug("malicious miner get 0 weight", zap.Stringer("smesher", nodeID)) } diff --git a/blocks/generator_test.go b/blocks/generator_test.go index 0d4d2064645..4145f3ff29b 100644 --- a/blocks/generator_test.go +++ b/blocks/generator_test.go @@ -154,14 +154,15 @@ func createModifiedATXs( signer, err := signing.NewEdSigner() require.NoError(tb, err) signers = append(signers, signer) - address := types.GenerateAddress(signer.PublicKey().Bytes()) - atx := types.NewActivationTx( - types.NIPostChallenge{PublishEpoch: lid.GetEpoch()}, - address, - numUnit, - ) + atx := &types.ActivationTx{ + PublishEpoch: lid.GetEpoch(), + Coinbase: types.GenerateAddress(signer.PublicKey().Bytes()), + NumUnits: numUnit, + SmesherID: signer.NodeID(), + TickCount: 1, + Weight: uint64(numUnit), + } atx.SetReceived(time.Now()) - atx.SmesherID = signer.NodeID() atx.SetID(types.RandomATXID()) onAtx(atx) data.AddFromAtx(atx, false) diff --git a/cmd/activeset/activeset.go b/cmd/activeset/activeset.go index 6c3acd6d0c3..1046916b024 100644 --- a/cmd/activeset/activeset.go +++ b/cmd/activeset/activeset.go @@ -39,7 +39,7 @@ Example: for _, id := range ids { atx, err := atxs.Get(db, id) must(err, "get id %v: %s\n", id, err) - weight += atx.GetWeight() + weight += atx.Weight } fmt.Printf("count = %d\nweight = %d\n", len(ids), weight) } diff --git a/common/types/activation.go b/common/types/activation.go index 39c3499ce86..d6cf0a63c5f 100644 --- a/common/types/activation.go +++ b/common/types/activation.go @@ -185,6 +185,13 @@ type ActivationTx struct { TickCount uint64 VRFNonce VRFPostIndex SmesherID NodeID + // Weight of the ATX. The total weight of the epoch is expected to fit in a uint64. + // The total ATX weight is sum(NumUnits * TickCount) for identity it holds. + // Space Units sizes are chosen such that NumUnits for all ATXs in an epoch is expected to be < 10^6. + // PoETs should produce ~10k ticks at genesis, but are expected due to technological advances + // to produce more over time. A uint64 should be large enough to hold the total weight of an epoch, + // for at least the first few years. + Weight uint64 AtxBlob @@ -194,25 +201,6 @@ type ActivationTx struct { validity Validity // whether the chain is fully verified and OK } -// NewActivationTx returns a new activation transaction. The ATXID is calculated and cached. -// NOTE: this function is deprecated and used in a few tests only. -// Create a new ActivationTx with ActivationTx{...}, setting the fields manually. -func NewActivationTx( - challenge NIPostChallenge, - coinbase Address, - numUnits uint32, -) *ActivationTx { - atx := &ActivationTx{ - PublishEpoch: challenge.PublishEpoch, - Sequence: challenge.Sequence, - PrevATXID: challenge.PrevATXID, - CommitmentATX: challenge.CommitmentATX, - Coinbase: coinbase, - NumUnits: numUnits, - } - return atx -} - // TargetEpoch returns the target epoch of the ATX. This is the epoch in which the miner is eligible // to participate thanks to the ATX. func (atx *ActivationTx) TargetEpoch() EpochID { @@ -238,16 +226,6 @@ func (atx *ActivationTx) SetGolden() { atx.golden = true } -// Weight of the ATX. The total weight of the epoch is expected to fit in a uint64 and is -// sum(atx.NumUnits * atx.TickCount for each ATX in a given epoch). -// Space Units sizes are chosen such that NumUnits for all ATXs in an epoch is expected to be < 10^6. -// PoETs should produce ~10k ticks at genesis, but are expected due to technological advances -// to produce more over time. A uint64 should be large enough to hold the total weight of an epoch, -// for at least the first few years. -func (atx *ActivationTx) GetWeight() uint64 { - return getWeight(uint64(atx.NumUnits), atx.TickCount) -} - // TickHeight returns a sum of base tick height and tick count. func (atx *ActivationTx) TickHeight() uint64 { return atx.BaseTickHeight + atx.TickCount @@ -270,7 +248,7 @@ func (atx *ActivationTx) MarshalLogObject(encoder log.ObjectEncoder) error { encoder.AddUint64("sequence_number", atx.Sequence) encoder.AddUint64("base_tick_height", atx.BaseTickHeight) encoder.AddUint64("tick_count", atx.TickCount) - encoder.AddUint64("weight", atx.GetWeight()) + encoder.AddUint64("weight", atx.Weight) encoder.AddUint64("height", atx.TickHeight()) return nil } @@ -400,15 +378,3 @@ type EpochActiveSet struct { } var MaxEpochActiveSetSize = scale.MustGetMaxElements[EpochActiveSet]("Set") - -func getWeight(numUnits, tickCount uint64) uint64 { - return safeMul(numUnits, tickCount) -} - -func safeMul(a, b uint64) uint64 { - c := a * b - if a > 1 && b > 1 && c/b != a { - panic("uint64 overflow") - } - return c -} diff --git a/fetch/mesh_data_test.go b/fetch/mesh_data_test.go index da6ef89ccfa..6c8c1821bf6 100644 --- a/fetch/mesh_data_test.go +++ b/fetch/mesh_data_test.go @@ -446,11 +446,11 @@ func genATXs(tb testing.TB, num uint32) []*types.ActivationTx { require.NoError(tb, err) atxs := make([]*types.ActivationTx, 0, num) for i := uint32(0); i < num; i++ { - atx := types.NewActivationTx( - types.NIPostChallenge{}, - types.Address{1, 2, 3}, - i, - ) + atx := &types.ActivationTx{ + Coinbase: types.Address{1, 2, 3}, + NumUnits: i, + Weight: uint64(i), + } atx.SmesherID = sig.NodeID() atx.SetID(types.RandomATXID()) atxs = append(atxs, atx) diff --git a/hare3/eligibility/oracle_test.go b/hare3/eligibility/oracle_test.go index 60652a04a7b..e299ea2f02c 100644 --- a/hare3/eligibility/oracle_test.go +++ b/hare3/eligibility/oracle_test.go @@ -143,8 +143,7 @@ func (t *testOracle) createActiveSet( miners = append(miners, nodeID) atx := &types.ActivationTx{ PublishEpoch: lid.GetEpoch(), - NumUnits: uint32(i + 1), - TickCount: 1, + Weight: uint64(i + 1), SmesherID: nodeID, } atx.SetID(id) @@ -371,8 +370,7 @@ func Test_VrfSignVerify(t *testing.T) { activeSet := types.RandomActiveSet(numMiners) atx1 := &types.ActivationTx{ PublishEpoch: prevEpoch, - NumUnits: 1 * 1024, - TickCount: 1, + Weight: 1 * 1024, SmesherID: signer.NodeID(), } atx1.SetID(activeSet[0]) @@ -384,9 +382,8 @@ func Test_VrfSignVerify(t *testing.T) { atx2 := &types.ActivationTx{ PublishEpoch: prevEpoch, - NumUnits: 9 * 1024, + Weight: 9 * 1024, SmesherID: signer2.NodeID(), - TickCount: 1, } atx2.SetID(activeSet[1]) atx2.SetReceived(time.Now()) diff --git a/hare3/hare_test.go b/hare3/hare_test.go index c53f78fbe59..acdaa7f398d 100644 --- a/hare3/hare_test.go +++ b/hare3/hare_test.go @@ -163,6 +163,7 @@ func (n *node) withAtx(min, max int) *node { } else { atx.NumUnits = uint32(min) } + atx.Weight = uint64(atx.NumUnits) * atx.TickCount id := types.ATXID{} n.t.rng.Read(id[:]) atx.SetID(id) diff --git a/malfeasance/wire/malfeasance_test.go b/malfeasance/wire/malfeasance_test.go index b367d24ee0b..df927e2145f 100644 --- a/malfeasance/wire/malfeasance_test.go +++ b/malfeasance/wire/malfeasance_test.go @@ -25,14 +25,11 @@ func TestMain(m *testing.M) { func TestCodec_MultipleATXs(t *testing.T) { epoch := types.EpochID(11) - a1 := types.NewActivationTx(types.NIPostChallenge{PublishEpoch: epoch}, types.Address{1, 2, 3}, 10) - a2 := types.NewActivationTx(types.NIPostChallenge{PublishEpoch: epoch}, types.Address{3, 2, 1}, 11) - var atxProof wire.AtxProof - for i, a := range []*types.ActivationTx{a1, a2} { + for i := range atxProof.Messages { atxProof.Messages[i] = wire.AtxProofMsg{ InnerMsg: types.ATXMetadata{ - PublishEpoch: a.PublishEpoch, + PublishEpoch: epoch, MsgHash: types.RandomHash(), }, SmesherID: types.RandomNodeID(), diff --git a/mesh/executor_test.go b/mesh/executor_test.go index 01330cfb6e3..01645d640fa 100644 --- a/mesh/executor_test.go +++ b/mesh/executor_test.go @@ -69,16 +69,17 @@ func makeResults(lid types.LayerID, txs ...types.Transaction) []types.Transactio func (t *testExecutor) createATX(epoch types.EpochID, cb types.Address) (types.ATXID, types.NodeID) { sig, err := signing.NewEdSigner() require.NoError(t.tb, err) - atx := types.NewActivationTx( - types.NIPostChallenge{PublishEpoch: epoch}, - cb, - 11, - ) - atx.VRFNonce = 1 + atx := &types.ActivationTx{ + PublishEpoch: epoch, + Coinbase: cb, + NumUnits: 11, + Weight: 11, + VRFNonce: 1, + TickCount: 1, + SmesherID: sig.NodeID(), + } atx.SetReceived(time.Now()) - atx.SmesherID = sig.NodeID() atx.SetID(types.RandomATXID()) - atx.TickCount = 1 require.NoError(t.tb, atxs.Add(t.db, atx)) t.atxsdata.AddFromAtx(atx, false) return atx.ID(), sig.NodeID() diff --git a/miner/proposal_builder_test.go b/miner/proposal_builder_test.go index 542927cc2c3..3fd3fa24578 100644 --- a/miner/proposal_builder_test.go +++ b/miner/proposal_builder_test.go @@ -75,6 +75,7 @@ func gatx( PublishEpoch: epoch, TickCount: ticks, SmesherID: smesher, + Weight: uint64(units) * ticks, } atx.SetID(id) atx.SetReceived(time.Time{}.Add(1)) diff --git a/proposals/eligibility_validator_test.go b/proposals/eligibility_validator_test.go index 6030327d4fd..acdcc9203c4 100644 --- a/proposals/eligibility_validator_test.go +++ b/proposals/eligibility_validator_test.go @@ -27,6 +27,7 @@ func gatx( VRFNonce: nonce, TickCount: 100, SmesherID: smesher, + Weight: uint64(units) * 100, } atx.SetID(id) atx.SetReceived(time.Time{}.Add(1)) diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 4f41ab4a689..423d28df65d 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -22,7 +22,7 @@ const ( // filters that refer to the id column. const fieldsQuery = `select atxs.id, atxs.nonce, atxs.base_tick_height, atxs.tick_count, atxs.pubkey, atxs.effective_num_units, -atxs.received, atxs.epoch, atxs.sequence, atxs.coinbase, atxs.validity, atxs.prev_id, atxs.commitment_atx` +atxs.received, atxs.epoch, atxs.sequence, atxs.coinbase, atxs.validity, atxs.prev_id, atxs.commitment_atx, atxs.weight` const fullQuery = fieldsQuery + ` from atxs` @@ -61,6 +61,7 @@ func decoder(fn decoderCallback) sql.Decoder { a.CommitmentATX = new(types.ATXID) stmt.ColumnBytes(12, a.CommitmentATX[:]) } + a.Weight = uint64(stmt.ColumnInt64(13)) return fn(&a) } @@ -440,13 +441,14 @@ func Add(db sql.Executor, atx *types.ActivationTx) error { } else { stmt.BindNull(13) } + stmt.BindInt64(14, int64(atx.Weight)) } _, err := db.Exec(` insert into atxs (id, epoch, effective_num_units, commitment_atx, nonce, pubkey, received, base_tick_height, tick_count, sequence, coinbase, - validity, prev_id) - values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)`, enc, nil) + validity, prev_id, weight) + values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)`, enc, nil) if err != nil { return fmt.Errorf("insert ATX ID %v: %w", atx.ID(), err) } @@ -776,7 +778,7 @@ func IterateAtxsWithMalfeasance( func(s *sql.Statement) { s.BindInt64(1, int64(publish)) }, func(s *sql.Statement) bool { return decoder(func(atx *types.ActivationTx) bool { - return fn(atx, s.ColumnInt(13) != 0) + return fn(atx, s.ColumnInt(14) != 0) })(s) }, ) diff --git a/sql/identities/identities.go b/sql/identities/identities.go index 613e19326f5..90327eb2038 100644 --- a/sql/identities/identities.go +++ b/sql/identities/identities.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + sqlite "github.com/go-llsqlite/crawshaw" + "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/malfeasance/wire" @@ -143,17 +145,45 @@ func Married(db sql.Executor, id types.NodeID) (bool, error) { return rows > 0, nil } +// MarriageInfo obtains the marriage ATX and index for given ID. +func MarriageInfo(db sql.Executor, id types.NodeID) (*types.ATXID, int, error) { + var ( + atx *types.ATXID + index int + ) + rows, err := db.Exec("select marriage_atx, marriage_idx from identities where pubkey = ?1;", + func(stmt *sql.Statement) { + stmt.BindBytes(1, id.Bytes()) + }, func(stmt *sql.Statement) bool { + if stmt.ColumnType(0) != sqlite.SQLITE_NULL { + atx = new(types.ATXID) + stmt.ColumnBytes(0, atx[:]) + + index = int(stmt.ColumnInt64(1)) + } + return false + }) + if err != nil { + return nil, 0, fmt.Errorf("getting marriage ATX for %v: %w", id, err) + } + if rows == 0 { + return nil, 0, sql.ErrNotFound + } + return atx, index, nil +} + // Set marriage inserts marriage ATX for given identity. // If identitty doesn't exist - create it. -func SetMarriage(db sql.Executor, id types.NodeID, atx types.ATXID) error { +func SetMarriage(db sql.Executor, id types.NodeID, atx types.ATXID, marriageIndex int) error { _, err := db.Exec(` - INSERT INTO identities (pubkey, marriage_atx) - values (?1, ?2) - ON CONFLICT(pubkey) DO UPDATE SET marriage_atx = excluded.marriage_atx + INSERT INTO identities (pubkey, marriage_atx, marriage_idx) + values (?1, ?2, ?3) + ON CONFLICT(pubkey) DO UPDATE SET marriage_atx = excluded.marriage_atx, marriage_idx = excluded.marriage_idx WHERE marriage_atx IS NULL;`, func(stmt *sql.Statement) { stmt.BindBytes(1, id.Bytes()) stmt.BindBytes(2, atx.Bytes()) + stmt.BindInt64(3, int64(marriageIndex)) }, nil, ) if err != nil { @@ -188,3 +218,24 @@ func EquivocationSet(db sql.Executor, id types.NodeID) ([]types.NodeID, error) { return ids, nil } + +func EquivocationSetByMarriageATX(db sql.Executor, atx types.ATXID) ([]types.NodeID, error) { + var ids []types.NodeID + + _, err := db.Exec(` + SELECT pubkey FROM identities WHERE marriage_atx = ?1 ORDER BY marriage_idx ASC;`, + func(stmt *sql.Statement) { + stmt.BindBytes(1, atx.Bytes()) + }, + func(stmt *sql.Statement) bool { + var nid types.NodeID + stmt.ColumnBytes(0, nid[:]) + ids = append(ids, nid) + return true + }) + if err != nil { + return nil, fmt.Errorf("getting equivocation set by ID %s: %w", atx, err) + } + + return ids, nil +} diff --git a/sql/identities/identities_test.go b/sql/identities/identities_test.go index 0feaeb018f9..a8b77cd71f6 100644 --- a/sql/identities/identities_test.go +++ b/sql/identities/identities_test.go @@ -131,7 +131,7 @@ func TestMarried(t *testing.T) { require.False(t, married) atx := types.RandomATXID() - require.NoError(t, SetMarriage(db, id, atx)) + require.NoError(t, SetMarriage(db, id, atx, 0)) married, err = Married(db, id) require.NoError(t, err) @@ -149,7 +149,7 @@ func TestMarried(t *testing.T) { require.NoError(t, err) require.False(t, married) - require.NoError(t, SetMarriage(db, id, types.RandomATXID())) + require.NoError(t, SetMarriage(db, id, types.RandomATXID(), 0)) married, err = Married(db, id) require.NoError(t, err) @@ -157,6 +157,30 @@ func TestMarried(t *testing.T) { }) } +func TestMarriageATX(t *testing.T) { + t.Parallel() + t.Run("not married", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + + id := types.RandomNodeID() + _, _, err := MarriageInfo(db, id) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("married", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + + id := types.RandomNodeID() + atx := types.RandomATXID() + require.NoError(t, SetMarriage(db, id, atx, 5)) + got, idx, err := MarriageInfo(db, id) + require.NoError(t, err) + require.Equal(t, atx, *got) + require.Equal(t, 5, idx) + }) +} + func TestEquivocationSet(t *testing.T) { t.Parallel() t.Run("equivocation set of married IDs", func(t *testing.T) { @@ -169,8 +193,8 @@ func TestEquivocationSet(t *testing.T) { types.RandomNodeID(), types.RandomNodeID(), } - for _, id := range ids { - require.NoError(t, SetMarriage(db, id, atx)) + for i, id := range ids { + require.NoError(t, SetMarriage(db, id, atx, i)) } for _, id := range ids { @@ -198,8 +222,8 @@ func TestEquivocationSet(t *testing.T) { types.RandomNodeID(), types.RandomNodeID(), } - for _, id := range ids { - require.NoError(t, SetMarriage(db, id, atx)) + for i, id := range ids { + require.NoError(t, SetMarriage(db, id, atx, i)) } for _, id := range ids { @@ -210,7 +234,7 @@ func TestEquivocationSet(t *testing.T) { // try to marry via another random ATX // the set should remain intact - require.NoError(t, SetMarriage(db, ids[0], types.RandomATXID())) + require.NoError(t, SetMarriage(db, ids[0], types.RandomATXID(), 0)) for _, id := range ids { set, err := EquivocationSet(db, id) require.NoError(t, err) @@ -221,7 +245,7 @@ func TestEquivocationSet(t *testing.T) { db := sql.InMemory() atx := types.RandomATXID() id := types.RandomNodeID() - require.NoError(t, SetMarriage(db, id, atx)) + require.NoError(t, SetMarriage(db, id, atx, 0)) malicious, err := IsMalicious(db, id) require.NoError(t, err) @@ -243,8 +267,8 @@ func TestEquivocationSet(t *testing.T) { types.RandomNodeID(), types.RandomNodeID(), } - for _, id := range ids { - require.NoError(t, SetMarriage(db, id, atx)) + for i, id := range ids { + require.NoError(t, SetMarriage(db, id, atx, i)) } require.NoError(t, SetMalicious(db, ids[0], []byte("proof"), time.Now())) @@ -256,3 +280,30 @@ func TestEquivocationSet(t *testing.T) { } }) } + +func TestEquivocationSetByMarriageATX(t *testing.T) { + t.Parallel() + + t.Run("married IDs", func(t *testing.T) { + db := sql.InMemory() + ids := []types.NodeID{ + types.RandomNodeID(), + types.RandomNodeID(), + types.RandomNodeID(), + types.RandomNodeID(), + } + atx := types.RandomATXID() + for i, id := range ids { + require.NoError(t, SetMarriage(db, id, atx, i)) + } + set, err := EquivocationSetByMarriageATX(db, atx) + require.NoError(t, err) + require.Equal(t, ids, set) + }) + t.Run("empty set", func(t *testing.T) { + db := sql.InMemory() + set, err := EquivocationSetByMarriageATX(db, types.RandomATXID()) + require.NoError(t, err) + require.Empty(t, set) + }) +} diff --git a/sql/migrations/state/0019_marriages.sql b/sql/migrations/state/0019_marriages.sql index 66bc7d11289..799f36c7d07 100644 --- a/sql/migrations/state/0019_marriages.sql +++ b/sql/migrations/state/0019_marriages.sql @@ -1 +1,2 @@ ALTER TABLE identities ADD COLUMN marriage_atx CHAR(32); +ALTER TABLE identities ADD COLUMN marriage_idx INTEGER; diff --git a/sql/migrations/state/0020_atx_weight.sql b/sql/migrations/state/0020_atx_weight.sql new file mode 100644 index 00000000000..4504bd43206 --- /dev/null +++ b/sql/migrations/state/0020_atx_weight.sql @@ -0,0 +1,2 @@ +ALTER TABLE atxs ADD COLUMN weight INTEGER; +INSERT INTO atxs (weight) SELECT effective_num_units * tick_count FROM atxs; diff --git a/tortoise/model/core.go b/tortoise/model/core.go index ce7022fa33f..04381a2aa1d 100644 --- a/tortoise/model/core.go +++ b/tortoise/model/core.go @@ -147,19 +147,20 @@ func (c *core) OnMessage(m Messenger, event Message) { return } - nipost := types.NIPostChallenge{ - PublishEpoch: ev.LayerID.GetEpoch(), + atx := &types.ActivationTx{ + PublishEpoch: ev.LayerID.GetEpoch(), + NumUnits: c.units, + Coinbase: types.GenerateAddress(c.signer.PublicKey().Bytes()), + SmesherID: c.signer.NodeID(), + BaseTickHeight: 1, + TickCount: 2, + Weight: uint64(c.units) * 2, } - addr := types.GenerateAddress(c.signer.PublicKey().Bytes()) - atx := types.NewActivationTx(nipost, addr, c.units) - atx.SmesherID = c.signer.NodeID() atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) - atx.BaseTickHeight = 1 - atx.TickCount = 2 c.refBallot = nil c.atx = atx.ID() - c.weight = atx.GetWeight() + c.weight = atx.Weight m.Send(MessageAtx{Atx: atx}) case MessageBlock: diff --git a/tortoise/sim/generator.go b/tortoise/sim/generator.go index 3ebc5a82c49..d89a3be918d 100644 --- a/tortoise/sim/generator.go +++ b/tortoise/sim/generator.go @@ -229,23 +229,24 @@ func (g *Generator) generateAtxs() { if err != nil { panic(err) } - address := types.GenerateAddress(sig.PublicKey().Bytes()) - nipost := types.NIPostChallenge{ - PublishEpoch: g.nextLayer.Sub(1).GetEpoch(), - } - atx := types.NewActivationTx(nipost, address, units) var ticks uint64 if g.ticks != nil { ticks = g.ticks[i] } else { ticks = uint64(intInRange(g.rng, g.ticksRange)) } - atx.SmesherID = sig.NodeID() + atx := &types.ActivationTx{ + PublishEpoch: g.nextLayer.Sub(1).GetEpoch(), + Coinbase: types.GenerateAddress(sig.PublicKey().Bytes()), + NumUnits: units, + SmesherID: sig.NodeID(), + BaseTickHeight: g.prevHeight[i], + TickCount: ticks, + Weight: uint64(units) * ticks, + } atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) - atx.BaseTickHeight = g.prevHeight[i] - atx.TickCount = ticks g.prevHeight[i] += ticks g.activations[i] = atx for _, state := range g.states { diff --git a/tortoise/sim/layer.go b/tortoise/sim/layer.go index 5a1ba8a7d6c..5bc6b74d2d8 100644 --- a/tortoise/sim/layer.go +++ b/tortoise/sim/layer.go @@ -159,7 +159,7 @@ func (g *Generator) genLayer(cfg nextConf) types.LayerID { } var total uint64 for _, atx := range g.activations { - total += atx.GetWeight() + total += atx.Weight } miners := make([]uint32, len(g.activations)) @@ -182,7 +182,7 @@ func (g *Generator) genLayer(cfg nextConf) types.LayerID { if err != nil { g.logger.Panic("failed to get a beacon", zap.Error(err)) } - n, err := util.GetNumEligibleSlots(atx.GetWeight(), 0, total, g.conf.LayerSize, g.conf.LayersPerEpoch) + n, err := util.GetNumEligibleSlots(atx.Weight, 0, total, g.conf.LayerSize, g.conf.LayersPerEpoch) if err != nil { g.logger.Panic("eligible slots", zap.Error(err)) } diff --git a/tortoise/tortoise_test.go b/tortoise/tortoise_test.go index d8b2de1b9f6..9d1ec119c03 100644 --- a/tortoise/tortoise_test.go +++ b/tortoise/tortoise_test.go @@ -475,8 +475,7 @@ func TestComputeExpectedWeight(t *testing.T) { eid := first + types.EpochID(i) atx := &types.ActivationTx{ PublishEpoch: eid - 1, - NumUnits: uint32(weight), - TickCount: 1, + Weight: weight, } atx.SetID(types.RandomATXID()) atx.SetReceived(time.Now()) @@ -500,7 +499,7 @@ func extractAtxsData(db sql.Executor, target types.EpochID) (uint64, uint64, err heights []uint64 ) if err := atxs.IterateAtxsOps(db, builder.FilterEpochOnly(target-1), func(atx *types.ActivationTx) bool { - weight += atx.GetWeight() + weight += atx.Weight heights = append(heights, atx.TickHeight()) return true }); err != nil {