diff --git a/sharding/collation.go b/sharding/collation.go index 19f1b8d05bd0..d93d795fe4a0 100644 --- a/sharding/collation.go +++ b/sharding/collation.go @@ -37,6 +37,12 @@ type collationHeaderData struct { ProposerSignature []byte // the proposer's signature for calculating collation hash. } +// NewCollation initializes a collation and leaves it up to clients to serialize, deserialize +// and provide the body and transactions upon creation. +func NewCollation(header *CollationHeader, body []byte, transactions []*types.Transaction) *Collation { + return &Collation{header, body, transactions} +} + // NewCollationHeader initializes a collation header struct. func NewCollationHeader(shardID *big.Int, chunkRoot *common.Hash, period *big.Int, proposerAddress *common.Address, proposerSignature []byte) *CollationHeader { data := collationHeaderData{ @@ -94,12 +100,6 @@ func (c *Collation) ProposerAddress() *common.Address { return c.header.data.ProposerAddress } -// AddTransaction adds to the collation's body of tx blobs. -func (c *Collation) AddTransaction(tx *types.Transaction) { - // TODO: Include blob serialization instead. - c.transactions = append(c.transactions, tx) -} - // CalculateChunkRoot updates the collation header's chunk root based on the body. func (c *Collation) CalculateChunkRoot() { // TODO: this needs to be based on blob serialization. diff --git a/sharding/collation_test.go b/sharding/collation_test.go index d1d8daa67da8..46bc5521a3f0 100644 --- a/sharding/collation_test.go +++ b/sharding/collation_test.go @@ -1,48 +1,80 @@ package sharding import ( + "math/big" "testing" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" ) -// TODO: this test needs to change as we will be serializing tx's into blobs -// within the collation body instead. - -func TestCollation_AddTransactions(t *testing.T) { - tests := []struct { - transactions []*types.Transaction - }{ - { - transactions: []*types.Transaction{ - makeTxWithGasLimit(0), - makeTxWithGasLimit(1), - makeTxWithGasLimit(2), - makeTxWithGasLimit(3), - }, - }, { - transactions: []*types.Transaction{}, - }, +func TestCollation_Transactions(t *testing.T) { + header := NewCollationHeader(big.NewInt(1), nil, big.NewInt(1), nil, []byte{}) + body := []byte{} + transactions := []*types.Transaction{ + makeTxWithGasLimit(0), + makeTxWithGasLimit(1), + makeTxWithGasLimit(2), + makeTxWithGasLimit(3), } - for _, tt := range tests { - c := &Collation{} - for _, tx := range tt.transactions { - c.AddTransaction(tx) - } - results := c.Transactions() - if len(results) != len(tt.transactions) { - t.Fatalf("Wrong number of transactions. want=%d. got=%d", len(tt.transactions), len(results)) - } - for i, tx := range tt.transactions { - if results[i] != tx { - t.Fatalf("Mismatched transactions. wanted=%+v. got=%+v", tt.transactions, results) - } + collation := NewCollation(header, body, transactions) + + for i, tx := range collation.Transactions() { + if tx.Hash().String() != transactions[i].Hash().String() { + t.Errorf("initialized collation struct does not contain correct transactions") } } } +func TestCollation_ProposerAddress(t *testing.T) { + proposerAddr := common.StringToAddress("proposer") + header := NewCollationHeader(big.NewInt(1), nil, big.NewInt(1), &proposerAddr, []byte{}) + body := []byte{} + + collation := NewCollation(header, body, nil) + + if collation.ProposerAddress().String() != proposerAddr.String() { + t.Errorf("initialized collation does not contain correct proposer address") + } +} + +// TODO: this test needs to change as we will be serializing tx's into blobs +// within the collation body instead. + +// func TestCollation_AddTransactions(t *testing.T) { +// tests := []struct { +// transactions []*types.Transaction +// }{ +// { +// transactions: []*types.Transaction{ +// makeTxWithGasLimit(0), +// makeTxWithGasLimit(1), +// makeTxWithGasLimit(2), +// makeTxWithGasLimit(3), +// }, +// }, { +// transactions: []*types.Transaction{}, +// }, +// } + +// for _, tt := range tests { +// c := &Collation{} +// for _, tx := range tt.transactions { +// c.AddTransaction(tx) +// } +// results := c.Transactions() +// if len(results) != len(tt.transactions) { +// t.Fatalf("Wrong number of transactions. want=%d. got=%d", len(tt.transactions), len(results)) +// } +// for i, tx := range tt.transactions { +// if results[i] != tx { +// t.Fatalf("Mismatched transactions. wanted=%+v. got=%+v", tt.transactions, results) +// } +// } +// } +// } + func makeTxWithGasLimit(gl uint64) *types.Transaction { return types.NewTransaction(0 /*nonce*/, common.HexToAddress("0x0") /*to*/, nil /*amount*/, gl, nil /*gasPrice*/, nil /*data*/) } diff --git a/sharding/database/inmemory.go b/sharding/database/inmemory.go index 0e2a5c44b522..02150e91c276 100644 --- a/sharding/database/inmemory.go +++ b/sharding/database/inmemory.go @@ -14,8 +14,8 @@ type ShardKV struct { kv map[common.Hash][]byte } -// MakeShardKV initializes a keyval store in memory. -func MakeShardKV() *ShardKV { +// NewShardKV initializes a keyval store in memory. +func NewShardKV() *ShardKV { return &ShardKV{kv: make(map[common.Hash][]byte)} } diff --git a/sharding/database/inmemory_test.go b/sharding/database/inmemory_test.go index d98a937113ad..ca4659b8ba4c 100644 --- a/sharding/database/inmemory_test.go +++ b/sharding/database/inmemory_test.go @@ -7,7 +7,7 @@ import ( ) func Test_ShardKVPut(t *testing.T) { - kv := MakeShardKV() + kv := NewShardKV() hash := common.StringToHash("ralph merkle") if err := kv.Put(hash, []byte{1, 2, 3}); err != nil { @@ -16,7 +16,7 @@ func Test_ShardKVPut(t *testing.T) { } func Test_ShardKVHas(t *testing.T) { - kv := MakeShardKV() + kv := NewShardKV() hash := common.StringToHash("ralph merkle") if err := kv.Put(hash, []byte{1, 2, 3}); err != nil { @@ -34,7 +34,7 @@ func Test_ShardKVHas(t *testing.T) { } func Test_ShardKVGet(t *testing.T) { - kv := MakeShardKV() + kv := NewShardKV() hash := common.StringToHash("ralph merkle") if err := kv.Put(hash, []byte{1, 2, 3}); err != nil { @@ -57,7 +57,7 @@ func Test_ShardKVGet(t *testing.T) { } func Test_ShardKVDelete(t *testing.T) { - kv := MakeShardKV() + kv := NewShardKV() hash := common.StringToHash("ralph merkle") if err := kv.Put(hash, []byte{1, 2, 3}); err != nil { diff --git a/sharding/shard.go b/sharding/shard.go index 0b00c4c4563e..809f6e493285 100644 --- a/sharding/shard.go +++ b/sharding/shard.go @@ -22,8 +22,8 @@ type Shard struct { shardID *big.Int } -// MakeShard creates an instance of a Shard struct given a shardID. -func MakeShard(shardID *big.Int, shardDB shardBackend) *Shard { +// NewShard creates an instance of a Shard struct given a shardID. +func NewShard(shardID *big.Int, shardDB shardBackend) *Shard { return &Shard{ shardID: shardID, shardDB: shardDB, @@ -66,15 +66,15 @@ func (s *Shard) CollationByHash(headerHash *common.Hash) (*Collation, error) { if err != nil { return nil, err } - if header == nil { - return nil, fmt.Errorf("header not found") - } body, err := s.BodyByChunkRoot(header.ChunkRoot()) if err != nil { return nil, err } - return &Collation{header: header, body: body}, nil + // TODO: deserializes the body into a tx's object instead of using + // nil as the third arg to MakeCollation. + col := NewCollation(header, body, nil) + return col, nil } // CanonicalHeaderHash gets a collation header hash that has been set as @@ -104,8 +104,12 @@ func (s *Shard) CanonicalHeaderHash(shardID *big.Int, period *big.Int) (*common. func (s *Shard) CanonicalCollation(shardID *big.Int, period *big.Int) (*Collation, error) { h, err := s.CanonicalHeaderHash(shardID, period) if err != nil { - return nil, fmt.Errorf("hash not found: %v", err) + return nil, fmt.Errorf("error while getting canoncial header hash: %v", err) } + if h == nil { + return nil, fmt.Errorf("header not found") + } + collation, err := s.CollationByHash(h) if err != nil { return nil, fmt.Errorf("no canonical collation found for hash: %v", err) diff --git a/sharding/shard_test.go b/sharding/shard_test.go index 55bb99eed59e..78c40fa3f692 100644 --- a/sharding/shard_test.go +++ b/sharding/shard_test.go @@ -22,15 +22,15 @@ func TestShard_ValidateShardID(t *testing.T) { emptyHash := common.StringToHash("") emptyAddr := common.StringToAddress("") header := NewCollationHeader(big.NewInt(1), &emptyHash, big.NewInt(1), &emptyAddr, []byte{}) - shardDB := database.MakeShardKV() - shard := MakeShard(big.NewInt(3), shardDB) + shardDB := database.NewShardKV() + shard := NewShard(big.NewInt(3), shardDB) if err := shard.ValidateShardID(header); err == nil { t.Errorf("ShardID validation incorrect. Function should throw error when ShardID's do not match. want=%d. got=%d", header.ShardID().Int64(), shard.ShardID().Int64()) } header2 := NewCollationHeader(big.NewInt(100), &emptyHash, big.NewInt(1), &emptyAddr, []byte{}) - shard2 := MakeShard(big.NewInt(100), shardDB) + shard2 := NewShard(big.NewInt(100), shardDB) if err := shard2.ValidateShardID(header2); err != nil { t.Errorf("ShardID validation incorrect. Function should not throw error when ShardID's match. want=%d. got=%d", header2.ShardID().Int64(), shard2.ShardID().Int64()) @@ -41,8 +41,8 @@ func TestShard_HeaderByHash(t *testing.T) { emptyHash := common.StringToHash("") emptyAddr := common.StringToAddress("") header := NewCollationHeader(big.NewInt(1), &emptyHash, big.NewInt(1), &emptyAddr, []byte{}) - shardDB := database.MakeShardKV() - shard := MakeShard(big.NewInt(1), shardDB) + shardDB := database.NewShardKV() + shard := NewShard(big.NewInt(1), shardDB) if err := shard.SaveHeader(header); err != nil { t.Fatalf("cannot save collation header: %v", err) @@ -70,16 +70,25 @@ func TestShard_CollationByHash(t *testing.T) { body: []byte{1, 2, 3}, } + // TODO: check if body by chunk root fails! + // We set the chunk root. collation.CalculateChunkRoot() - shardDB := database.MakeShardKV() - shard := MakeShard(big.NewInt(1), shardDB) + shardDB := database.NewShardKV() + shard := NewShard(big.NewInt(1), shardDB) + + hash := collation.Header().Hash() + + // should not be able to fetch collation without saving first. + _, err := shard.CollationByHash(&hash) + if err == nil { + t.Errorf("should not be able to fetch collation before saving first") + } if err := shard.SaveCollation(collation); err != nil { t.Fatalf("cannot save collation: %v", err) } - hash := collation.Header().Hash() dbCollation, err := shard.CollationByHash(&hash) if err != nil { @@ -100,8 +109,13 @@ func TestShard_CanonicalHeaderHash(t *testing.T) { emptyHash := common.StringToHash("") header := NewCollationHeader(shardID, &emptyHash, period, &proposerAddress, proposerSignature) - shardDB := database.MakeShardKV() - shard := MakeShard(shardID, shardDB) + shardDB := database.NewShardKV() + shard := NewShard(shardID, shardDB) + + // should not be able to set as canonical before saving the header. + if err := shard.SetCanonical(header); err == nil { + t.Errorf("cannot set as canonical before saving header first") + } if err := shard.SaveHeader(header); err != nil { t.Fatalf("failed to save header to shardDB: %v", err) @@ -131,8 +145,9 @@ func TestShard_CanonicalCollation(t *testing.T) { emptyHash := common.StringToHash("") header := NewCollationHeader(shardID, &emptyHash, period, &proposerAddress, proposerSignature) - shardDB := database.MakeShardKV() - shard := MakeShard(shardID, shardDB) + shardDB := database.NewShardKV() + shard := NewShard(shardID, shardDB) + otherShard := NewShard(big.NewInt(2), shardDB) collation := &Collation{ header: header, @@ -150,6 +165,11 @@ func TestShard_CanonicalCollation(t *testing.T) { t.Fatalf("failed to set header as canonical: %v", err) } + // should not be allowed to set as canonical in a different shard. + if err := otherShard.SetCanonical(header); err == nil { + t.Errorf("should not be able to set header with ShardID=%v as canonical in other shard=%v", header.ShardID(), big.NewInt(2)) + } + canonicalCollation, err := shard.CanonicalCollation(shardID, period) if err != nil { t.Fatalf("failed to get canonical collation from shardDB: %v", err) @@ -162,8 +182,8 @@ func TestShard_CanonicalCollation(t *testing.T) { func TestShard_BodyByChunkRoot(t *testing.T) { body := []byte{1, 2, 3, 4, 5} shardID := big.NewInt(1) - shardDB := database.MakeShardKV() - shard := MakeShard(shardID, shardDB) + shardDB := database.NewShardKV() + shard := NewShard(shardID, shardDB) if err := shard.SaveBody(body); err != nil { t.Fatalf("cannot save body: %v", err) @@ -192,8 +212,8 @@ func TestShard_CheckAvailability(t *testing.T) { emptyHash := common.StringToHash("") header := NewCollationHeader(shardID, &emptyHash, period, &proposerAddress, proposerSignature) - shardDB := database.MakeShardKV() - shard := MakeShard(shardID, shardDB) + shardDB := database.NewShardKV() + shard := NewShard(shardID, shardDB) collation := &Collation{ header: header, @@ -215,3 +235,27 @@ func TestShard_CheckAvailability(t *testing.T) { t.Errorf("collation body is not available: chunkRoot=%v, body=%v", header.ChunkRoot(), collation.body) } } + +func TestShard_SaveCollation(t *testing.T) { + headerShardID := big.NewInt(1) + period := big.NewInt(1) + proposerAddress := common.StringToAddress("") + proposerSignature := []byte{} + emptyHash := common.StringToHash("") + header := NewCollationHeader(headerShardID, &emptyHash, period, &proposerAddress, proposerSignature) + + shardDB := database.NewShardKV() + shard := NewShard(big.NewInt(2), shardDB) + + collation := &Collation{ + header: header, + body: []byte{1, 2, 3}, + } + + // We set the chunk root. + collation.CalculateChunkRoot() + + if err := shard.SaveCollation(collation); err == nil { + t.Errorf("cannot save collation in shard with wrong shardID") + } +}