Skip to content

Commit

Permalink
ATX handler V2 supports marriages (#6010)
Browse files Browse the repository at this point in the history
## Motivation

Closes #6007
  • Loading branch information
poszu committed Jun 20, 2024
1 parent e90a964 commit 65a23b0
Show file tree
Hide file tree
Showing 11 changed files with 591 additions and 27 deletions.
137 changes: 130 additions & 7 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"math"
"slices"
"time"

"github.com/spacemeshos/post/shared"
Expand Down Expand Up @@ -103,6 +104,11 @@ func (h *HandlerV2) processATX(
return nil, fmt.Errorf("validating positioning atx: %w", err)
}

marrying, err := h.validateMarriages(watx)
if err != nil {
return nil, fmt.Errorf("validating marriages: %w", err)
}

parts, proof, err := h.syntacticallyValidateDeps(ctx, watx)
if err != nil {
return nil, fmt.Errorf("atx %s syntactically invalid based on deps: %w", watx.ID(), err)
Expand Down Expand Up @@ -135,7 +141,7 @@ func (h *HandlerV2) processATX(
atx.SetID(watx.ID())
atx.SetReceived(received)

proof, err = h.storeAtx(ctx, atx, watx)
proof, err = h.storeAtx(ctx, atx, watx, marrying)
if err != nil {
return nil, fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err)
}
Expand All @@ -147,7 +153,6 @@ func (h *HandlerV2) processATX(

// Syntactically validate an ATX.
// TODOs:
// 1. support marriages
// 2. support merged ATXs.
func (h *HandlerV2) syntacticallyValidate(ctx context.Context, atx *wire.ActivationTxV2) error {
if !h.edVerifier.Verify(signing.ATX, atx.SmesherID, atx.SignedBytes(), atx.Signature) {
Expand All @@ -156,9 +161,15 @@ func (h *HandlerV2) syntacticallyValidate(ctx context.Context, atx *wire.Activat
if atx.PositioningATX == types.EmptyATXID {
return errors.New("empty positioning atx")
}
// TODO: support marriages
if len(atx.Marriages) != 0 {
return errors.New("marriages are not supported")
// Marriage ATX must contain a self-signed certificate.
// It's identified by having ReferenceAtx == EmptyATXID.
idx := slices.IndexFunc(atx.Marriages, func(cert wire.MarriageCertificate) bool {
return cert.ReferenceAtx == types.EmptyATXID
})
if idx == -1 {
return errors.New("signer must marry itself")
}
}

current := h.clock.CurrentLayer().GetEpoch()
Expand Down Expand Up @@ -216,6 +227,9 @@ func (h *HandlerV2) syntacticallyValidate(ctx context.Context, atx *wire.Activat
switch {
case atx.MarriageATX != nil:
// Merged ATX
if len(atx.Marriages) != 0 {
return errors.New("merged atx cannot have marriages")
}
// TODO: support merged ATXs
return errors.New("atx merge is not supported")
default:
Expand Down Expand Up @@ -278,6 +292,9 @@ func (h *HandlerV2) collectAtxDeps(atx *wire.ActivationTxV2) ([]types.Hash32, []
if atx.MarriageATX != nil {
ids = append(ids, *atx.MarriageATX)
}
for _, cert := range atx.Marriages {
ids = append(ids, cert.ReferenceAtx)
}

filtered := make(map[types.ATXID]struct{})
for _, id := range ids {
Expand Down Expand Up @@ -395,6 +412,34 @@ func (h *HandlerV2) validatePositioningAtx(publish types.EpochID, golden, positi
return posAtx.TickHeight(), nil
}

// Validate marriages and return married IDs.
// Note: The order of returned IDs is important and must match the order of the marriage certificates.
// The MarriageIndex in PoST proof matches the index in this marriage slice.
func (h *HandlerV2) validateMarriages(atx *wire.ActivationTxV2) ([]types.NodeID, error) {
if len(atx.Marriages) == 0 {
return nil, nil
}
var marryingIDs []types.NodeID
for i, m := range atx.Marriages {
var id types.NodeID
if m.ReferenceAtx == types.EmptyATXID {
id = atx.SmesherID
} else {
atx, err := atxs.Get(h.cdb, m.ReferenceAtx)
if err != nil {
return nil, fmt.Errorf("getting marriage reference atx: %w", err)
}
id = atx.SmesherID
}

if !h.edVerifier.Verify(signing.MARRIAGE, id, atx.SmesherID.Bytes(), m.Signature) {
return nil, fmt.Errorf("invalid marriage[%d] signature", i)
}
marryingIDs = append(marryingIDs, id)
}
return marryingIDs, nil
}

type atxParts struct {
leaves uint64
effectiveUnits uint32
Expand Down Expand Up @@ -529,6 +574,7 @@ func (h *HandlerV2) checkMalicious(
ctx context.Context,
tx *sql.Tx,
watx *wire.ActivationTxV2,
marrying []types.NodeID,
) (bool, *mwire.MalfeasanceProof, error) {
malicious, err := identities.IsMalicious(tx, watx.SmesherID)
if err != nil {
Expand All @@ -538,6 +584,14 @@ func (h *HandlerV2) checkMalicious(
return true, nil, nil
}

proof, err := h.checkDoubleMarry(tx, watx, marrying)
if err != nil {
return false, nil, fmt.Errorf("checking double marry: %w", err)
}
if proof != nil {
return true, proof, nil
}

// TODO: contextual validation:
// 1. check double-publish
// 2. check previous ATX
Expand All @@ -548,24 +602,63 @@ func (h *HandlerV2) checkMalicious(
return false, nil, nil
}

func (h *HandlerV2) checkDoubleMarry(
tx *sql.Tx,
watx *wire.ActivationTxV2,
marrying []types.NodeID,
) (*mwire.MalfeasanceProof, error) {
for _, id := range marrying {
married, err := identities.Married(tx, id)
if err != nil {
return nil, fmt.Errorf("checking if ID is married: %w", err)
}
if married {
proof := &mwire.MalfeasanceProof{
Proof: mwire.Proof{
Type: mwire.DoubleMarry,
Data: &mwire.DoubleMarryProof{},
},
}
return proof, nil
}
}
return nil, nil
}

// Store an ATX in the DB.
// TODO: detect malfeasance and create proofs.
func (h *HandlerV2) storeAtx(
ctx context.Context,
atx *types.ActivationTx,
watx *wire.ActivationTxV2,
marrying []types.NodeID,
) (*mwire.MalfeasanceProof, error) {
var (
malicious bool
proof *mwire.MalfeasanceProof
)
if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error {
var err error
malicious, proof, err = h.checkMalicious(ctx, tx, watx)
malicious, proof, err = h.checkMalicious(ctx, tx, watx, marrying)
if err != nil {
return fmt.Errorf("check malicious: %w", err)
}

if len(marrying) != 0 {
for _, id := range marrying {
if err := identities.SetMarriage(tx, id, atx.ID()); err != nil {
return err
}
}
if !malicious && proof == nil {
// We check for malfeasance again becase the marriage increased the equivocation set.
malicious, err = identities.IsMalicious(tx, atx.SmesherID)
if err != nil {
return fmt.Errorf("re-checking if smesherID is malicious: %w", err)
}
}
}

err = atxs.Add(tx, atx)
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
Expand All @@ -576,9 +669,39 @@ func (h *HandlerV2) storeAtx(
}

atxs.AtxAdded(h.cdb, atx)

var allMalicious map[types.NodeID]struct{}
if malicious || proof != nil {
// Combine IDs from the present equivocation set for atx.SmesherID and IDs in atx.Marriages.
allMalicious = make(map[types.NodeID]struct{})

set, err := identities.EquivocationSet(h.cdb, atx.SmesherID)
if err != nil {
return nil, fmt.Errorf("getting equivocation set: %w", err)
}
for _, id := range set {
allMalicious[id] = struct{}{}
}
for _, id := range marrying {
allMalicious[id] = struct{}{}
}
}
if proof != nil {
h.cdb.CacheMalfeasanceProof(atx.SmesherID, proof)
h.tortoise.OnMalfeasance(atx.SmesherID)
encoded, err := codec.Encode(proof)
if err != nil {
return nil, fmt.Errorf("encoding malfeasance proof: %w", err)
}

for id := range allMalicious {
if err := identities.SetMalicious(h.cdb, id, encoded, atx.Received()); err != nil {
return nil, fmt.Errorf("setting malfeasance proof: %w", err)
}
h.cdb.CacheMalfeasanceProof(id, proof)
}
}

for id := range allMalicious {
h.tortoise.OnMalfeasance(id)
}

h.beacon.OnAtx(atx)
Expand Down

0 comments on commit 65a23b0

Please sign in to comment.