Skip to content

Commit

Permalink
Return ATX version along with blob (#5922)
Browse files Browse the repository at this point in the history
## Motivation

Need to get the ATX version from the `atx_blobs` table to decode ATX w/o knowing epoch->version mapping everywhere in the code.
  • Loading branch information
poszu committed May 13, 2024
1 parent cbff8b0 commit 2452f59
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 82 deletions.
2 changes: 1 addition & 1 deletion activation/activation.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ func (b *Builder) Regossip(ctx context.Context, nodeID types.NodeID) error {
return err
}
var blob sql.Blob
if err := atxs.LoadBlob(ctx, b.db, atx.Bytes(), &blob); err != nil {
if _, err := atxs.LoadBlob(ctx, b.db, atx.Bytes(), &blob); err != nil {
return fmt.Errorf("get blob %s: %w", atx.ShortString(), err)
}
if len(blob.Bytes) == 0 {
Expand Down
4 changes: 3 additions & 1 deletion activation/activation_multi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ func TestRegossip(t *testing.T) {
}

var blob sql.Blob
require.NoError(t, atxs.LoadBlob(context.Background(), tab.db, refAtx.ID().Bytes(), &blob))
ver, err := atxs.LoadBlob(context.Background(), tab.db, refAtx.ID().Bytes(), &blob)
require.NoError(t, err)
require.Equal(t, types.AtxV1, ver)

// atx will be regossiped once (by the smesher)
tab.mclock.EXPECT().CurrentLayer().Return(layer)
Expand Down
17 changes: 11 additions & 6 deletions activation/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ func (h *Handler) handleAtx(
// Obtain the atxSignature of the given ATX.
func atxSignature(ctx context.Context, db sql.Executor, id types.ATXID) (types.EdSignature, error) {
var blob sql.Blob
if err := atxs.LoadBlob(ctx, db, id.Bytes(), &blob); err != nil {
v, err := atxs.LoadBlob(ctx, db, id.Bytes(), &blob)
if err != nil {
return types.EmptyEdSignature, err
}

Expand All @@ -306,10 +307,14 @@ func atxSignature(ctx context.Context, db sql.Executor, id types.ATXID) (types.E
return types.EmptyEdSignature, fmt.Errorf("can't get signature for a golden (checkpointed) ATX: %s", id)
}

// TODO: decide how to decode based on the `version` column.
var prev wire.ActivationTxV1
if err := codec.Decode(blob.Bytes, &prev); err != nil {
return types.EmptyEdSignature, fmt.Errorf("decoding previous atx: %w", err)
// TODO: implement for ATX V2
switch v {
case types.AtxV1:
var atx wire.ActivationTxV1
if err := codec.Decode(blob.Bytes, &atx); err != nil {
return types.EmptyEdSignature, fmt.Errorf("decoding atx v1: %w", err)
}
return atx.Signature, nil
}
return prev.Signature, nil
return types.EmptyEdSignature, fmt.Errorf("unsupported ATX version: %v", v)
}
6 changes: 5 additions & 1 deletion activation/handler_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ func (h *HandlerV1) commitment(ctx context.Context, atx *wire.ActivationTxV1) (t
// to use the effective num units.
func (h *HandlerV1) previous(ctx context.Context, atx *wire.ActivationTxV1) (*types.ActivationTx, error) {
var blob sql.Blob
if err := atxs.LoadBlob(ctx, h.cdb, atx.PrevATXID[:], &blob); err != nil {
v, err := atxs.LoadBlob(ctx, h.cdb, atx.PrevATXID[:], &blob)
if err != nil {
return nil, err
}

Expand All @@ -150,6 +151,9 @@ func (h *HandlerV1) previous(ctx context.Context, atx *wire.ActivationTxV1) (*ty
}
return atx, nil
}
if v != types.AtxV1 {
return nil, fmt.Errorf("previous atx %s is not of version 1", atx.PrevATXID)
}

var prev wire.ActivationTxV1
if err := codec.Decode(blob.Bytes, &prev); err != nil {
Expand Down
47 changes: 26 additions & 21 deletions activation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,34 +373,39 @@ type atxDeps struct {

func (v *Validator) getAtxDeps(ctx context.Context, db sql.Executor, id types.ATXID) (*atxDeps, error) {
var blob sql.Blob
if err := atxs.LoadBlob(ctx, v.db, id.Bytes(), &blob); err != nil {
version, err := atxs.LoadBlob(ctx, v.db, id.Bytes(), &blob)
if err != nil {
return nil, fmt.Errorf("getting blob for %s: %w", id, err)
}

// TODO: decide about version based on `version` column
var atx wire.ActivationTxV1
if err := codec.Decode(blob.Bytes, &atx); err != nil {
return nil, fmt.Errorf("decoding ATX blob: %w", err)
}
var commitment types.ATXID
if atx.CommitmentATXID != nil {
commitment = *atx.CommitmentATXID
} else {
catx, err := atxs.CommitmentATX(v.db, atx.SmesherID)
if err != nil {
return nil, fmt.Errorf("getting commitment ATX: %w", err)
// TODO: implement ATX V2
switch version {
case types.AtxV1:
var commitment types.ATXID
var atx wire.ActivationTxV1
if err := codec.Decode(blob.Bytes, &atx); err != nil {
return nil, fmt.Errorf("decoding ATX blob: %w", err)
}
if atx.CommitmentATXID != nil {
commitment = *atx.CommitmentATXID
} else {
catx, err := atxs.CommitmentATX(v.db, atx.SmesherID)
if err != nil {
return nil, fmt.Errorf("getting commitment ATX: %w", err)
}
commitment = catx
}
commitment = catx
}

deps := &atxDeps{
nipost: *wire.NiPostFromWireV1(atx.NIPost),
positioning: atx.PositioningATXID,
previous: atx.PrevATXID,
commitment: commitment,
deps := &atxDeps{
nipost: *wire.NiPostFromWireV1(atx.NIPost),
positioning: atx.PositioningATXID,
previous: atx.PrevATXID,
commitment: commitment,
}
return deps, nil
}

return deps, nil
return nil, fmt.Errorf("unsupported ATX version: %v", version)
}

func (v *Validator) verifyChainWithOpts(
Expand Down
2 changes: 1 addition & 1 deletion checkpoint/recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ func collect(
return err
}
var blob sql.Blob
err = atxs.LoadBlob(context.Background(), db, ref.Bytes(), &blob)
_, err = atxs.LoadBlob(context.Background(), db, ref.Bytes(), &blob)
if err != nil {
return fmt.Errorf("load atx blob %v: %w", ref, err)
}
Expand Down
37 changes: 23 additions & 14 deletions checkpoint/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,29 +168,38 @@ func backupOldDb(fs afero.Fs, srcDir, dbFile string) (string, error) {

func positioningATX(ctx context.Context, db sql.Executor, id types.ATXID) (types.ATXID, error) {
var blob sql.Blob
if err := atxs.LoadBlob(ctx, db, id.Bytes(), &blob); err != nil {
version, err := atxs.LoadBlob(ctx, db, id.Bytes(), &blob)
if err != nil {
return types.EmptyATXID, fmt.Errorf("get blob %s: %w", id, err)
}
// TODO: decide how to decode based on the `version` column
var atx wire.ActivationTxV1
if err := codec.Decode(blob.Bytes, &atx); err != nil {
return types.EmptyATXID, fmt.Errorf("decode %s: %w", id, err)
// TODO: implement for ATX V2
switch version {
case types.AtxV1:
var atx wire.ActivationTxV1
if err := codec.Decode(blob.Bytes, &atx); err != nil {
return types.EmptyATXID, fmt.Errorf("decode %s: %w", id, err)
}
return atx.PositioningATXID, nil
}

return atx.PositioningATXID, nil
return types.EmptyATXID, fmt.Errorf("unsupported ATX version: %v", version)
}

func poetProofRef(ctx context.Context, db sql.Executor, id types.ATXID) (types.PoetProofRef, error) {
var blob sql.Blob
if err := atxs.LoadBlob(ctx, db, id.Bytes(), &blob); err != nil {
version, err := atxs.LoadBlob(ctx, db, id.Bytes(), &blob)
if err != nil {
return types.PoetProofRef{}, fmt.Errorf("getting blob for %s: %w", id, err)
}

// TODO: decide about version based the `version` column in `atx_blobs`
var atx wire.ActivationTxV1
if err := codec.Decode(blob.Bytes, &atx); err != nil {
return types.PoetProofRef{}, fmt.Errorf("decoding ATX blob: %w", err)
}
// TODO: implement for ATX V2
switch version {
case types.AtxV1:
var atx wire.ActivationTxV1
if err := codec.Decode(blob.Bytes, &atx); err != nil {
return types.PoetProofRef{}, fmt.Errorf("decoding ATX blob: %w", err)
}

return types.PoetProofRef(atx.NIPost.PostMetadata.Challenge), nil
return types.PoetProofRef(atx.NIPost.PostMetadata.Challenge), nil
}
return types.PoetProofRef{}, fmt.Errorf("unsupported ATX version: %v", version)
}
5 changes: 4 additions & 1 deletion datastore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,10 @@ type (
)

var loadBlobDispatch = map[Hint]loadBlobFunc{
ATXDB: atxs.LoadBlob,
ATXDB: func(ctx context.Context, db sql.Executor, key []byte, blob *sql.Blob) error {
_, err := atxs.LoadBlob(ctx, db, key, blob)
return err
},
BallotDB: ballots.LoadBlob,
BlockDB: blocks.LoadBlob,
TXDB: transactions.LoadBlob,
Expand Down
70 changes: 45 additions & 25 deletions sql/atxs/atxs.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,37 +324,57 @@ func GetBlobSizes(db sql.Executor, ids [][]byte) (sizes []int, err error) {
}

// LoadBlob loads ATX as an encoded blob, ready to be sent over the wire.
func LoadBlob(ctx context.Context, db sql.Executor, id []byte, blob *sql.Blob) error {
//
// SAFETY: The contents of the returned blob MUST NOT be modified.
// They might point to the inner sql cache and modifying them would
// corrupt the cache.
func LoadBlob(ctx context.Context, db sql.Executor, id []byte, blob *sql.Blob) (types.AtxVersion, error) {
if sql.IsCached(db) {
b, err := getBlob(ctx, db, id)
type cachedBlob struct {
version types.AtxVersion
buf []byte
}
cacheKey := sql.QueryCacheKey(CacheKindATXBlob, string(id))
cached, err := sql.WithCachedValue(ctx, db, cacheKey, func(context.Context) (*cachedBlob, error) {
// We don't use the provided blob in this case to avoid
// caching references to the underlying slice (subsequent calls would modify it).
var blob sql.Blob
v, err := getBlob(ctx, db, id, &blob)
if err != nil {
return nil, err
}
return &cachedBlob{version: v, buf: blob.Bytes}, nil
})
if err != nil {
return err
return 0, err
}
blob.Bytes = b
return nil
// Here we return the cached slice, hence the safety warning.
blob.Bytes = cached.buf
return cached.version, nil
}
return sql.LoadBlob(db, "select atx from atx_blobs where id = ?1", id, blob)

return getBlob(ctx, db, id, blob)
}

func getBlob(ctx context.Context, db sql.Executor, id []byte) (buf []byte, err error) {
cacheKey := sql.QueryCacheKey(CacheKindATXBlob, string(id))
return sql.WithCachedValue(ctx, db, cacheKey, func(context.Context) ([]byte, error) {
if rows, err := db.Exec("select atx from atx_blobs where id = ?1",
func(stmt *sql.Statement) {
stmt.BindBytes(1, id)
}, func(stmt *sql.Statement) bool {
if stmt.ColumnLen(0) > 0 {
buf = make([]byte, stmt.ColumnLen(0))
stmt.ColumnBytes(0, buf)
}
return true
}); err != nil {
return nil, fmt.Errorf("get %s: %w", types.BytesToHash(id), err)
} else if rows == 0 {
return nil, fmt.Errorf("%w: atx %s", sql.ErrNotFound, types.BytesToHash(id))
}
return buf, nil
})
func getBlob(ctx context.Context, db sql.Executor, id []byte, blob *sql.Blob) (types.AtxVersion, error) {
var version types.AtxVersion
rows, err := db.Exec("select atx, version from atx_blobs where id = ?1",
func(stmt *sql.Statement) {
stmt.BindBytes(1, id)
}, func(stmt *sql.Statement) bool {
blob.FromColumn(stmt, 0)
version = types.AtxVersion(stmt.ColumnInt(1))
return true
},
)
if err != nil {
return 0, fmt.Errorf("get %v: %w", types.BytesToHash(id), err)
}
if rows == 0 {
return 0, fmt.Errorf("%w: atx %s", sql.ErrNotFound, types.BytesToHash(id))
}

return version, nil
}

// NonceByID retrieves VRFNonce corresponding to the specified ATX ID.
Expand Down
Loading

0 comments on commit 2452f59

Please sign in to comment.