From 65a23b00590446943af2f5cd185cff645bec0fc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Thu, 20 Jun 2024 08:47:43 +0000 Subject: [PATCH] ATX handler V2 supports marriages (#6010) ## Motivation Closes #6007 --- activation/handler_v2.go | 137 +++++++++++++- activation/handler_v2_test.go | 234 ++++++++++++++++++++++-- activation/wire/wire_v2.go | 10 +- activation/wire/wire_v2_scale.go | 4 +- malfeasance/handler.go | 1 + malfeasance/wire/malfeasance.go | 9 +- malfeasance/wire/malfeasance_scale.go | 8 + signing/signer.go | 1 + sql/identities/identities.go | 71 ++++++- sql/identities/identities_test.go | 142 ++++++++++++++ sql/migrations/state/0019_marriages.sql | 1 + 11 files changed, 591 insertions(+), 27 deletions(-) create mode 100644 sql/migrations/state/0019_marriages.sql diff --git a/activation/handler_v2.go b/activation/handler_v2.go index c9ec247618..a369c50aaf 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math" + "slices" "time" "github.com/spacemeshos/post/shared" @@ -103,6 +104,11 @@ func (h *HandlerV2) processATX( return nil, fmt.Errorf("validating positioning atx: %w", err) } + marrying, err := h.validateMarriages(watx) + if err != nil { + return nil, fmt.Errorf("validating marriages: %w", err) + } + parts, proof, err := h.syntacticallyValidateDeps(ctx, watx) if err != nil { return nil, fmt.Errorf("atx %s syntactically invalid based on deps: %w", watx.ID(), err) @@ -135,7 +141,7 @@ func (h *HandlerV2) processATX( atx.SetID(watx.ID()) atx.SetReceived(received) - proof, err = h.storeAtx(ctx, atx, watx) + proof, err = h.storeAtx(ctx, atx, watx, marrying) if err != nil { return nil, fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err) } @@ -147,7 +153,6 @@ func (h *HandlerV2) processATX( // Syntactically validate an ATX. // TODOs: -// 1. support marriages // 2. support merged ATXs. func (h *HandlerV2) syntacticallyValidate(ctx context.Context, atx *wire.ActivationTxV2) error { if !h.edVerifier.Verify(signing.ATX, atx.SmesherID, atx.SignedBytes(), atx.Signature) { @@ -156,9 +161,15 @@ func (h *HandlerV2) syntacticallyValidate(ctx context.Context, atx *wire.Activat if atx.PositioningATX == types.EmptyATXID { return errors.New("empty positioning atx") } - // TODO: support marriages if len(atx.Marriages) != 0 { - return errors.New("marriages are not supported") + // Marriage ATX must contain a self-signed certificate. + // It's identified by having ReferenceAtx == EmptyATXID. + idx := slices.IndexFunc(atx.Marriages, func(cert wire.MarriageCertificate) bool { + return cert.ReferenceAtx == types.EmptyATXID + }) + if idx == -1 { + return errors.New("signer must marry itself") + } } current := h.clock.CurrentLayer().GetEpoch() @@ -216,6 +227,9 @@ func (h *HandlerV2) syntacticallyValidate(ctx context.Context, atx *wire.Activat switch { case atx.MarriageATX != nil: // Merged ATX + if len(atx.Marriages) != 0 { + return errors.New("merged atx cannot have marriages") + } // TODO: support merged ATXs return errors.New("atx merge is not supported") default: @@ -278,6 +292,9 @@ func (h *HandlerV2) collectAtxDeps(atx *wire.ActivationTxV2) ([]types.Hash32, [] if atx.MarriageATX != nil { ids = append(ids, *atx.MarriageATX) } + for _, cert := range atx.Marriages { + ids = append(ids, cert.ReferenceAtx) + } filtered := make(map[types.ATXID]struct{}) for _, id := range ids { @@ -395,6 +412,34 @@ func (h *HandlerV2) validatePositioningAtx(publish types.EpochID, golden, positi return posAtx.TickHeight(), nil } +// Validate marriages and return married IDs. +// Note: The order of returned IDs is important and must match the order of the marriage certificates. +// The MarriageIndex in PoST proof matches the index in this marriage slice. +func (h *HandlerV2) validateMarriages(atx *wire.ActivationTxV2) ([]types.NodeID, error) { + if len(atx.Marriages) == 0 { + return nil, nil + } + var marryingIDs []types.NodeID + for i, m := range atx.Marriages { + var id types.NodeID + if m.ReferenceAtx == types.EmptyATXID { + id = atx.SmesherID + } else { + atx, err := atxs.Get(h.cdb, m.ReferenceAtx) + if err != nil { + return nil, fmt.Errorf("getting marriage reference atx: %w", err) + } + id = atx.SmesherID + } + + if !h.edVerifier.Verify(signing.MARRIAGE, id, atx.SmesherID.Bytes(), m.Signature) { + return nil, fmt.Errorf("invalid marriage[%d] signature", i) + } + marryingIDs = append(marryingIDs, id) + } + return marryingIDs, nil +} + type atxParts struct { leaves uint64 effectiveUnits uint32 @@ -529,6 +574,7 @@ func (h *HandlerV2) checkMalicious( ctx context.Context, tx *sql.Tx, watx *wire.ActivationTxV2, + marrying []types.NodeID, ) (bool, *mwire.MalfeasanceProof, error) { malicious, err := identities.IsMalicious(tx, watx.SmesherID) if err != nil { @@ -538,6 +584,14 @@ func (h *HandlerV2) checkMalicious( return true, nil, nil } + proof, err := h.checkDoubleMarry(tx, watx, marrying) + if err != nil { + return false, nil, fmt.Errorf("checking double marry: %w", err) + } + if proof != nil { + return true, proof, nil + } + // TODO: contextual validation: // 1. check double-publish // 2. check previous ATX @@ -548,12 +602,36 @@ func (h *HandlerV2) checkMalicious( return false, nil, nil } +func (h *HandlerV2) checkDoubleMarry( + tx *sql.Tx, + watx *wire.ActivationTxV2, + marrying []types.NodeID, +) (*mwire.MalfeasanceProof, error) { + for _, id := range marrying { + married, err := identities.Married(tx, id) + if err != nil { + return nil, fmt.Errorf("checking if ID is married: %w", err) + } + if married { + proof := &mwire.MalfeasanceProof{ + Proof: mwire.Proof{ + Type: mwire.DoubleMarry, + Data: &mwire.DoubleMarryProof{}, + }, + } + return proof, nil + } + } + return nil, nil +} + // Store an ATX in the DB. // TODO: detect malfeasance and create proofs. func (h *HandlerV2) storeAtx( ctx context.Context, atx *types.ActivationTx, watx *wire.ActivationTxV2, + marrying []types.NodeID, ) (*mwire.MalfeasanceProof, error) { var ( malicious bool @@ -561,11 +639,26 @@ func (h *HandlerV2) storeAtx( ) if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { var err error - malicious, proof, err = h.checkMalicious(ctx, tx, watx) + malicious, proof, err = h.checkMalicious(ctx, tx, watx, marrying) if err != nil { return fmt.Errorf("check malicious: %w", err) } + if len(marrying) != 0 { + for _, id := range marrying { + if err := identities.SetMarriage(tx, id, atx.ID()); err != nil { + return err + } + } + if !malicious && proof == nil { + // We check for malfeasance again becase the marriage increased the equivocation set. + malicious, err = identities.IsMalicious(tx, atx.SmesherID) + if err != nil { + return fmt.Errorf("re-checking if smesherID is malicious: %w", err) + } + } + } + err = atxs.Add(tx, atx) if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("add atx to db: %w", err) @@ -576,9 +669,39 @@ func (h *HandlerV2) storeAtx( } atxs.AtxAdded(h.cdb, atx) + + var allMalicious map[types.NodeID]struct{} + if malicious || proof != nil { + // Combine IDs from the present equivocation set for atx.SmesherID and IDs in atx.Marriages. + allMalicious = make(map[types.NodeID]struct{}) + + set, err := identities.EquivocationSet(h.cdb, atx.SmesherID) + if err != nil { + return nil, fmt.Errorf("getting equivocation set: %w", err) + } + for _, id := range set { + allMalicious[id] = struct{}{} + } + for _, id := range marrying { + allMalicious[id] = struct{}{} + } + } if proof != nil { - h.cdb.CacheMalfeasanceProof(atx.SmesherID, proof) - h.tortoise.OnMalfeasance(atx.SmesherID) + encoded, err := codec.Encode(proof) + if err != nil { + return nil, fmt.Errorf("encoding malfeasance proof: %w", err) + } + + for id := range allMalicious { + if err := identities.SetMalicious(h.cdb, id, encoded, atx.Received()); err != nil { + return nil, fmt.Errorf("setting malfeasance proof: %w", err) + } + h.cdb.CacheMalfeasanceProof(id, proof) + } + } + + for id := range allMalicious { + h.tortoise.OnMalfeasance(id) } h.beacon.OnAtx(atx) diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 83caf5220c..3de4a113eb 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -18,9 +18,11 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" + mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/identities" ) type v2TestHandler struct { @@ -123,6 +125,21 @@ func (h *handlerMocks) expectAtxV2(atx *wire.ActivationTxV2) { h.expectStoreAtxV2(atx) } +func (h *v2TestHandler) createAndProcessInitial(t *testing.T, sig *signing.EdSigner) *wire.ActivationTxV2 { + t.Helper() + atx := newInitialATXv2(t, h.handlerMocks.goldenATXID) + atx.Sign(sig) + p, err := h.processInitial(atx) + require.NoError(t, err) + require.Nil(t, p) + return atx +} + +func (h *v2TestHandler) processInitial(atx *wire.ActivationTxV2) (*mwire.MalfeasanceProof, error) { + h.expectInitialAtxV2(atx) + return h.processATX(context.Background(), peer.ID("peer"), atx, codec.MustEncode(atx), time.Now()) +} + func TestHandlerV2_SyntacticallyValidate(t *testing.T) { t.Parallel() golden := types.RandomATXID() @@ -157,16 +174,6 @@ func TestHandlerV2_SyntacticallyValidate(t *testing.T) { err := atxHandler.syntacticallyValidate(context.Background(), atx) require.ErrorContains(t, err, "empty positioning atx") }) - t.Run("marriages are not supported (yet)", func(t *testing.T) { - t.Parallel() - atx := newInitialATXv2(t, golden) - atx.Marriages = []wire.MarriageCertificate{{}} - atx.Sign(sig) - - atxHandler := newV2TestHandler(t, golden) - err := atxHandler.syntacticallyValidate(context.Background(), atx) - require.ErrorContains(t, err, "marriages are not supported") - }) t.Run("reject golden previous ATX", func(t *testing.T) { t.Parallel() atx := newSoloATXv2(t, 0, golden, golden) @@ -587,6 +594,8 @@ func TestCollectDeps_AtxV2(t *testing.T) { positioning := types.RandomATXID() commitment := types.RandomATXID() marriage := types.RandomATXID() + ref0 := types.RandomATXID() + ref1 := types.RandomATXID() poetA := types.RandomHash() poetB := types.RandomHash() @@ -603,10 +612,15 @@ func TestCollectDeps_AtxV2(t *testing.T) { {Challenge: poetA}, {Challenge: poetB}, }, + Marriages: []wire.MarriageCertificate{ + {ReferenceAtx: types.EmptyATXID}, + {ReferenceAtx: ref0}, + {ReferenceAtx: ref1}, + }, } poetDeps, atxIDs := atxHandler.collectAtxDeps(&atx) require.ElementsMatch(t, []types.Hash32{poetA, poetB}, poetDeps) - require.ElementsMatch(t, []types.ATXID{prev0, prev1, positioning, commitment, marriage}, atxIDs) + require.ElementsMatch(t, []types.ATXID{prev0, prev1, positioning, commitment, marriage, ref0, ref1}, atxIDs) }) t.Run("eliminates duplicates", func(t *testing.T) { t.Parallel() @@ -1066,6 +1080,204 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { }) } +func Test_Marriages(t *testing.T) { + t.Parallel() + golden := types.RandomATXID() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + t.Run("invalid marriage signature", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + + atx := newInitialATXv2(t, golden) + atx.Marriages = []wire.MarriageCertificate{{ + Signature: types.RandomEdSignature(), + }} + + _, err = atxHandler.validateMarriages(atx) + require.ErrorContains(t, err, "invalid marriage[0] signature") + }) + t.Run("valid marriage", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + othersAtx := atxHandler.createAndProcessInitial(t, otherSig) + + atx := newInitialATXv2(t, golden) + atx.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: othersAtx.ID(), + Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + + atx.Sign(sig) + + p, err := atxHandler.processInitial(atx) + require.NoError(t, err) + require.Nil(t, p) + + married, err := identities.Married(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + require.True(t, married) + + married, err = identities.Married(atxHandler.cdb, otherSig.NodeID()) + require.NoError(t, err) + require.True(t, married) + + set, err := identities.EquivocationSet(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID()}, set) + }) + t.Run("can't marry twice", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + othersAtx := atxHandler.createAndProcessInitial(t, otherSig) + + atx := newInitialATXv2(t, golden) + atx.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: othersAtx.ID(), + Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + atx.Sign(sig) + + atxHandler.expectInitialAtxV2(atx) + _, err = atxHandler.processATX(context.Background(), "", atx, codec.MustEncode(atx), time.Now()) + require.NoError(t, err) + + // otherSig2 cannot marry sig, trying to extend its set. + otherSig2, err := signing.NewEdSigner() + require.NoError(t, err) + others2Atx := atxHandler.createAndProcessInitial(t, otherSig2) + atx2 := newSoloATXv2(t, atx.PublishEpoch+1, atx.ID(), atx.ID()) + atx2.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: others2Atx.ID(), + Signature: otherSig2.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + atx2.Sign(sig) + atxHandler.expectAtxV2(atx2) + ids := []types.NodeID{sig.NodeID(), otherSig.NodeID(), otherSig2.NodeID()} + for _, id := range ids { + atxHandler.mtortoise.EXPECT().OnMalfeasance(id) + } + proof, err := atxHandler.processATX(context.Background(), "", atx2, codec.MustEncode(atx2), time.Now()) + require.NoError(t, err) + // TODO: check the proof contents once its implemented + require.NotNil(t, proof) + + // All 3 IDs are marked as malicious + for _, id := range ids { + malicious, err := identities.IsMalicious(atxHandler.cdb, id) + require.NoError(t, err) + require.True(t, malicious) + } + + // The equivocation set of sig and otherSig didn't grow + equiv, err := identities.EquivocationSet(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID()}, equiv) + }) + t.Run("signer must marry self", func(t *testing.T) { + t.Parallel() + atxHandler := newV2TestHandler(t, golden) + + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + othersAtx := atxHandler.createAndProcessInitial(t, otherSig) + + atx := newInitialATXv2(t, golden) + atx.Marriages = []wire.MarriageCertificate{ + { + ReferenceAtx: othersAtx.ID(), + Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + atx.Sign(sig) + + atxHandler.mclock.EXPECT().CurrentLayer().AnyTimes() + _, err = atxHandler.processATX(context.Background(), "", atx, codec.MustEncode(atx), time.Now()) + require.ErrorContains(t, err, "signer must marry itself") + }) +} + +func Test_MarryingMalicious(t *testing.T) { + t.Parallel() + golden := types.RandomATXID() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + + tt := []struct { + name string + malicious types.NodeID + }{ + { + name: "owner is malicious", + malicious: sig.NodeID(), + }, { + name: "other is malicious", + malicious: otherSig.NodeID(), + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + atxHandler := newV2TestHandler(t, golden) + + othersAtx := atxHandler.createAndProcessInitial(t, otherSig) + + atx := newInitialATXv2(t, golden) + atx.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, { + ReferenceAtx: othersAtx.ID(), + Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + atx.Sign(sig) + + require.NoError(t, identities.SetMalicious(atxHandler.cdb, tc.malicious, []byte("proof"), time.Now())) + + atxHandler.expectInitialAtxV2(atx) + atxHandler.mtortoise.EXPECT().OnMalfeasance(sig.NodeID()) + atxHandler.mtortoise.EXPECT().OnMalfeasance(otherSig.NodeID()) + + _, err = atxHandler.processATX(context.Background(), "", atx, codec.MustEncode(atx), time.Now()) + require.NoError(t, err) + + equiv, err := identities.EquivocationSet(atxHandler.cdb, sig.NodeID()) + require.NoError(t, err) + require.ElementsMatch(t, []types.NodeID{sig.NodeID(), otherSig.NodeID()}, equiv) + + for _, id := range []types.NodeID{sig.NodeID(), otherSig.NodeID()} { + m, err := identities.IsMalicious(atxHandler.cdb, id) + require.NoError(t, err) + require.True(t, m) + } + }) + } +} + func newInitialATXv2(t testing.TB, golden types.ATXID) *wire.ActivationTxV2 { t.Helper() atx := &wire.ActivationTxV2{ diff --git a/activation/wire/wire_v2.go b/activation/wire/wire_v2.go index 354ee7c75c..db3bf9d022 100644 --- a/activation/wire/wire_v2.go +++ b/activation/wire/wire_v2.go @@ -31,6 +31,7 @@ type ActivationTxV2 struct { // A marriage is permanent and cannot be revoked or repeated. // All new IDs that are married to this ID are added to the equivocation set // that this ID belongs to. + // It must contain a self-marriage certificate (needed for malfeasance proofs). Marriages []MarriageCertificate `scale:"max=256"` // The ID of the ATX containing marriage for the included IDs. @@ -170,7 +171,10 @@ func (i *InitialAtxPartsV2) Root() []byte { // A marriage allows for publishing a merged ATX, which can contain PoST for all married IDs. // Any ID from the marriage can publish a merged ATX on behalf of all married IDs. type MarriageCertificate struct { - ID types.NodeID + // An ATX of the ID that marries. It proves that the ID exists. + // Note: the reference ATX does not need to be from the previous epoch. + // It only needs to prove the existence of the ID. + ReferenceAtx types.ATXID // Signature over the other ID that this ID marries with // If Alice marries Bob, then Alice signs Bob's ID // and Bob includes this certificate in his ATX. @@ -184,7 +188,7 @@ func (mc *MarriageCertificate) Root() []byte { if err != nil { panic(err) } - tree.AddLeaf(mc.ID.Bytes()) + tree.AddLeaf(mc.ReferenceAtx.Bytes()) tree.AddLeaf(mc.Signature.Bytes()) return tree.Root() } @@ -312,7 +316,7 @@ func (marriage *MarriageCertificate) MarshalLogObject(encoder zapcore.ObjectEnco if marriage == nil { return nil } - encoder.AddString("ID", marriage.ID.String()) + encoder.AddString("ReferenceATX", marriage.ReferenceAtx.String()) encoder.AddString("Signature", marriage.Signature.String()) return nil } diff --git a/activation/wire/wire_v2_scale.go b/activation/wire/wire_v2_scale.go index 5024ba3202..286a200428 100644 --- a/activation/wire/wire_v2_scale.go +++ b/activation/wire/wire_v2_scale.go @@ -215,7 +215,7 @@ func (t *InitialAtxPartsV2) DecodeScale(dec *scale.Decoder) (total int, err erro func (t *MarriageCertificate) EncodeScale(enc *scale.Encoder) (total int, err error) { { - n, err := scale.EncodeByteArray(enc, t.ID[:]) + n, err := scale.EncodeByteArray(enc, t.ReferenceAtx[:]) if err != nil { return total, err } @@ -233,7 +233,7 @@ func (t *MarriageCertificate) EncodeScale(enc *scale.Encoder) (total int, err er func (t *MarriageCertificate) DecodeScale(dec *scale.Decoder) (total int, err error) { { - n, err := scale.DecodeByteArray(dec, t.ID[:]) + n, err := scale.DecodeByteArray(dec, t.ReferenceAtx[:]) if err != nil { return total, err } diff --git a/malfeasance/handler.go b/malfeasance/handler.go index ff6590ecbb..0476dc1464 100644 --- a/malfeasance/handler.go +++ b/malfeasance/handler.go @@ -44,6 +44,7 @@ const ( InvalidActivation MalfeasanceType = iota + 10 InvalidBallot InvalidHareMsg + DoubleMarry = MalfeasanceType(wire.DoubleMarry) ) // Handler processes MalfeasanceProof from gossip and, if deems it valid, propagates it to peers. diff --git a/malfeasance/wire/malfeasance.go b/malfeasance/wire/malfeasance.go index 0ffd8e4228..b4132ea568 100644 --- a/malfeasance/wire/malfeasance.go +++ b/malfeasance/wire/malfeasance.go @@ -15,7 +15,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" ) -//go:generate scalegen -types MalfeasanceProof,MalfeasanceGossip,AtxProof,BallotProof,HareProof,AtxProofMsg,BallotProofMsg,HareProofMsg,HareMetadata,InvalidPostIndexProof,InvalidPrevATXProof +//go:generate scalegen -types MalfeasanceProof,MalfeasanceGossip,AtxProof,BallotProof,HareProof,AtxProofMsg,BallotProofMsg,HareProofMsg,HareMetadata,InvalidPostIndexProof,InvalidPrevATXProof,DoubleMarryProof const ( MultipleATXs byte = iota + 1 @@ -23,6 +23,7 @@ const ( HareEquivocation InvalidPostIndex InvalidPrevATX + DoubleMarry ) type MalfeasanceProof struct { @@ -324,6 +325,12 @@ type InvalidPrevATXProof struct { func (p *InvalidPrevATXProof) isProof() {} +type DoubleMarryProof struct { + // TODO: implement +} + +func (p *DoubleMarryProof) isProof() {} + func MalfeasanceInfo(smesher types.NodeID, mp *MalfeasanceProof) string { var b strings.Builder b.WriteString(fmt.Sprintf("generate layer: %v\n", mp.Layer)) diff --git a/malfeasance/wire/malfeasance_scale.go b/malfeasance/wire/malfeasance_scale.go index 3ec88a1acc..6e23fd2175 100644 --- a/malfeasance/wire/malfeasance_scale.go +++ b/malfeasance/wire/malfeasance_scale.go @@ -422,3 +422,11 @@ func (t *InvalidPrevATXProof) DecodeScale(dec *scale.Decoder) (total int, err er } return total, nil } + +func (t *DoubleMarryProof) EncodeScale(enc *scale.Encoder) (total int, err error) { + return total, nil +} + +func (t *DoubleMarryProof) DecodeScale(dec *scale.Decoder) (total int, err error) { + return total, nil +} diff --git a/signing/signer.go b/signing/signer.go index f3b258c1be..c07f8589a1 100644 --- a/signing/signer.go +++ b/signing/signer.go @@ -24,6 +24,7 @@ const ( BALLOT = 2 HARE = 3 POET = 4 + MARRIAGE = 5 BEACON_FIRST_MSG = 10 BEACON_FOLLOWUP_MSG = 11 diff --git a/sql/identities/identities.go b/sql/identities/identities.go index 48ef4b36d8..613e19326f 100644 --- a/sql/identities/identities.go +++ b/sql/identities/identities.go @@ -15,7 +15,8 @@ import ( func SetMalicious(db sql.Executor, nodeID types.NodeID, proof []byte, received time.Time) error { _, err := db.Exec(`insert into identities (pubkey, proof, received) values (?1, ?2, ?3) - on conflict do nothing;`, + ON CONFLICT(pubkey) DO UPDATE SET proof = excluded.proof + WHERE proof IS NULL;`, func(stmt *sql.Statement) { stmt.BindBytes(1, nodeID.Bytes()) stmt.BindBytes(2, proof) @@ -30,7 +31,12 @@ func SetMalicious(db sql.Executor, nodeID types.NodeID, proof []byte, received t // IsMalicious returns true if identity is known to be malicious. func IsMalicious(db sql.Executor, nodeID types.NodeID) (bool, error) { - rows, err := db.Exec("select 1 from identities where pubkey = ?1;", + rows, err := db.Exec(` + SELECT 1 FROM identities + WHERE (marriage_atx = ( + SELECT marriage_atx FROM identities WHERE pubkey = ?1 AND marriage_atx IS NOT NULL) AND proof IS NOT NULL + ) + OR (pubkey = ?1 AND marriage_atx IS NULL AND proof IS NOT NULL);`, func(stmt *sql.Statement) { stmt.BindBytes(1, nodeID.Bytes()) }, nil) @@ -46,7 +52,7 @@ func GetMalfeasanceProof(db sql.Executor, nodeID types.NodeID) (*wire.Malfeasanc data []byte received time.Time ) - rows, err := db.Exec("select proof, received from identities where pubkey = ?1;", + rows, err := db.Exec("select proof, received from identities where pubkey = ?1 AND proof IS NOT NULL;", func(stmt *sql.Statement) { stmt.BindBytes(1, nodeID.Bytes()) }, func(stmt *sql.Statement) bool { @@ -123,3 +129,62 @@ func GetMalicious(db sql.Executor) (nids []types.NodeID, err error) { } return nids, nil } + +// Married checks if id is married. +// ID is married if it has non-null marriage_atx column. +func Married(db sql.Executor, id types.NodeID) (bool, error) { + rows, err := db.Exec("select 1 from identities where pubkey = ?1 and marriage_atx is not null;", + func(stmt *sql.Statement) { + stmt.BindBytes(1, id.Bytes()) + }, nil) + if err != nil { + return false, fmt.Errorf("married %v: %w", id, err) + } + return rows > 0, nil +} + +// Set marriage inserts marriage ATX for given identity. +// If identitty doesn't exist - create it. +func SetMarriage(db sql.Executor, id types.NodeID, atx types.ATXID) error { + _, err := db.Exec(` + INSERT INTO identities (pubkey, marriage_atx) + values (?1, ?2) + ON CONFLICT(pubkey) DO UPDATE SET marriage_atx = excluded.marriage_atx + WHERE marriage_atx IS NULL;`, + func(stmt *sql.Statement) { + stmt.BindBytes(1, id.Bytes()) + stmt.BindBytes(2, atx.Bytes()) + }, nil, + ) + if err != nil { + return fmt.Errorf("setting marriage %v: %w", id, err) + } + return nil +} + +// EquivocationSet returns all node IDs that are married to the given node ID +// including itself. +func EquivocationSet(db sql.Executor, id types.NodeID) ([]types.NodeID, error) { + var ids []types.NodeID + + rows, err := db.Exec(` + SELECT pubkey FROM identities + WHERE marriage_atx = (SELECT marriage_atx FROM identities WHERE pubkey = ?1) AND marriage_atx IS NOT NULL;`, + func(stmt *sql.Statement) { + stmt.BindBytes(1, id.Bytes()) + }, + func(stmt *sql.Statement) bool { + var nid types.NodeID + stmt.ColumnBytes(0, nid[:]) + ids = append(ids, nid) + return true + }) + if err != nil { + return nil, fmt.Errorf("getting marriage for %v: %w", id, err) + } + if rows == 0 { + return []types.NodeID{id}, nil + } + + return ids, nil +} diff --git a/sql/identities/identities_test.go b/sql/identities/identities_test.go index 58f5b4df4d..0feaeb018f 100644 --- a/sql/identities/identities_test.go +++ b/sql/identities/identities_test.go @@ -48,6 +48,10 @@ func TestMalicious(t *testing.T) { require.NoError(t, err) require.True(t, mal) + mal, err = IsMalicious(db, types.RandomNodeID()) + require.NoError(t, err) + require.False(t, mal) + got, err := GetMalfeasanceProof(db, nodeID) require.NoError(t, err) require.Equal(t, now.UTC(), got.Received().UTC()) @@ -114,3 +118,141 @@ func TestLoadMalfeasanceBlob(t *testing.T) { require.NoError(t, err) require.Equal(t, []int{len(blob1.Bytes), -1, len(blob2.Bytes)}, blobSizes) } + +func TestMarried(t *testing.T) { + t.Parallel() + t.Run("identity not in DB", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + + id := types.RandomNodeID() + married, err := Married(db, id) + require.NoError(t, err) + require.False(t, married) + + atx := types.RandomATXID() + require.NoError(t, SetMarriage(db, id, atx)) + + married, err = Married(db, id) + require.NoError(t, err) + require.True(t, married) + }) + t.Run("identity in DB", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + + id := types.RandomNodeID() + // add ID in the DB + SetMalicious(db, id, types.RandomBytes(11), time.Now()) + + married, err := Married(db, id) + require.NoError(t, err) + require.False(t, married) + + require.NoError(t, SetMarriage(db, id, types.RandomATXID())) + + married, err = Married(db, id) + require.NoError(t, err) + require.True(t, married) + }) +} + +func TestEquivocationSet(t *testing.T) { + t.Parallel() + t.Run("equivocation set of married IDs", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + + atx := types.RandomATXID() + ids := []types.NodeID{ + types.RandomNodeID(), + types.RandomNodeID(), + types.RandomNodeID(), + } + for _, id := range ids { + require.NoError(t, SetMarriage(db, id, atx)) + } + + for _, id := range ids { + married, err := Married(db, id) + require.NoError(t, err) + require.True(t, married) + set, err := EquivocationSet(db, id) + require.NoError(t, err) + require.ElementsMatch(t, ids, set) + } + }) + t.Run("equivocation set for unmarried ID contains itself only", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + id := types.RandomNodeID() + set, err := EquivocationSet(db, id) + require.NoError(t, err) + require.Equal(t, []types.NodeID{id}, set) + }) + t.Run("can't escape the marriage", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + atx := types.RandomATXID() + ids := []types.NodeID{ + types.RandomNodeID(), + types.RandomNodeID(), + } + for _, id := range ids { + require.NoError(t, SetMarriage(db, id, atx)) + } + + for _, id := range ids { + set, err := EquivocationSet(db, id) + require.NoError(t, err) + require.ElementsMatch(t, ids, set) + } + + // try to marry via another random ATX + // the set should remain intact + require.NoError(t, SetMarriage(db, ids[0], types.RandomATXID())) + for _, id := range ids { + set, err := EquivocationSet(db, id) + require.NoError(t, err) + require.ElementsMatch(t, ids, set) + } + }) + t.Run("married doesn't become malicious immediately", func(t *testing.T) { + db := sql.InMemory() + atx := types.RandomATXID() + id := types.RandomNodeID() + require.NoError(t, SetMarriage(db, id, atx)) + + malicious, err := IsMalicious(db, id) + require.NoError(t, err) + require.False(t, malicious) + + proof, err := GetMalfeasanceProof(db, id) + require.ErrorIs(t, err, sql.ErrNotFound) + require.Nil(t, proof) + + ids, err := GetMalicious(db) + require.NoError(t, err) + require.Empty(t, ids) + }) + t.Run("all IDs in equivocation set are malicious if one is", func(t *testing.T) { + t.Parallel() + db := sql.InMemory() + atx := types.RandomATXID() + ids := []types.NodeID{ + types.RandomNodeID(), + types.RandomNodeID(), + } + for _, id := range ids { + require.NoError(t, SetMarriage(db, id, atx)) + } + + require.NoError(t, SetMalicious(db, ids[0], []byte("proof"), time.Now())) + + for _, id := range ids { + malicious, err := IsMalicious(db, id) + require.NoError(t, err) + require.True(t, malicious) + } + }) +} diff --git a/sql/migrations/state/0019_marriages.sql b/sql/migrations/state/0019_marriages.sql new file mode 100644 index 0000000000..66bc7d1128 --- /dev/null +++ b/sql/migrations/state/0019_marriages.sql @@ -0,0 +1 @@ +ALTER TABLE identities ADD COLUMN marriage_atx CHAR(32);