Skip to content

Commit

Permalink
sharding: address review comments om #100, fix linter issues
Browse files Browse the repository at this point in the history
Former-commit-id: 5febe72a5a1936ce643488067e0990da810f1f5e [formerly 74c85fc]
Former-commit-id: 0cc6d45
  • Loading branch information
rauljordan committed May 9, 2018
1 parent a617eba commit d1aa843
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 89 deletions.
8 changes: 0 additions & 8 deletions sharding/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,6 @@ func (c *Collation) Header() *CollationHeader { return c.header }
// Body returns the collation's byte body.
func (c *Collation) Body() []byte { return c.body }

// Hash returns the hash of a collation's entire contents. Useful for tests.
func (c *Collation) Hash() (hash common.Hash) {
hw := sha3.NewKeccak256()
rlp.Encode(hw, c)
hw.Sum(hash[:0])
return hash
}

// Transactions returns an array of tx's in the collation.
func (c *Collation) Transactions() []*types.Transaction { return c.transactions }

Expand Down
15 changes: 3 additions & 12 deletions sharding/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
)

type shardKV struct {
// Shard state storage is a mapping of hashes to RLP encoded values.
kv map[common.Hash][]byte
}

Expand All @@ -16,31 +17,21 @@ func makeShardKV() *shardKV {

func (sb *shardKV) Get(k common.Hash) ([]byte, error) {
v, ok := sb.kv[k]
fmt.Printf("Map: %v\n", sb.kv)
fmt.Printf("Key: %v\n", k)
fmt.Printf("Val: %v\n", sb.kv[k])
fmt.Printf("Ok: %v\n", ok)
if !ok {
return nil, fmt.Errorf("Key Not Found")
return nil, fmt.Errorf("key not found: %v", k)
}
return v, nil
}

func (sb *shardKV) Has(k common.Hash) bool {
v := sb.kv[k]
if v == nil {
return false
}
return true
return v != nil
}

func (sb *shardKV) Put(k common.Hash, v []byte) {
sb.kv[k] = v
fmt.Printf("Put: %v\n", sb.kv[k])
return
}

func (sb *shardKV) Delete(k common.Hash) {
delete(sb.kv, k)
return
}
79 changes: 45 additions & 34 deletions sharding/shard.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,28 @@ package sharding

import (
"fmt"
"log"
"math/big"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
)

type shardBackend interface {
Get(k common.Hash) ([]byte, error)
Has(k common.Hash) bool
Put(k common.Hash, val []byte)
Delete(k common.Hash)
}

// Shard base struct.
type Shard struct {
shardDB *shardKV
shardDB shardBackend
shardID *big.Int
}

// MakeShard creates an instance of a Shard struct given a shardID.
func MakeShard(shardID *big.Int) *Shard {
// Swappable - can be makeShardLevelDB, makeShardSparseTrie, etc.
shardDB := makeShardKV()

func MakeShard(shardID *big.Int, shardDB shardBackend) *Shard {
return &Shard{
shardID: shardID,
shardDB: shardDB,
Expand All @@ -33,68 +38,74 @@ func (s *Shard) ShardID() *big.Int {
// ValidateShardID checks if header belongs to shard.
func (s *Shard) ValidateShardID(h *CollationHeader) error {
if s.ShardID().Cmp(h.ShardID()) != 0 {
return fmt.Errorf("Error: Collation Does Not Belong to Shard %d but Instead Has ShardID %d", h.ShardID(), s.ShardID())
return fmt.Errorf("collation does not belong to shard %d but instead has shardID %d", h.ShardID().Int64(), s.ShardID().Int64())
}
return nil
}

// GetHeaderByHash of collation.
func (s *Shard) GetHeaderByHash(hash *common.Hash) (*CollationHeader, error) {
// HeaderByHash of collation.
func (s *Shard) HeaderByHash(hash *common.Hash) (*CollationHeader, error) {
encoded, err := s.shardDB.Get(*hash)
if err != nil {
return nil, fmt.Errorf("Error: Header Not Found: %v", err)
return nil, fmt.Errorf("header not found: %v", err)
}
log.Printf("encoded header in func: %v", encoded)
var header CollationHeader
if err := rlp.DecodeBytes(encoded, &header); err != nil {
return nil, fmt.Errorf("Error: Problem Decoding Header: %v", err)
return nil, fmt.Errorf("could not decode header: %v", err)
}
log.Printf("decoded header in func: %v", header)
return &header, nil
}

// GetCollationByHash fetches full collation.
func (s *Shard) GetCollationByHash(headerHash *common.Hash) (*Collation, error) {
header, err := s.GetHeaderByHash(headerHash)
// CollationByHash fetches full collation.
func (s *Shard) CollationByHash(headerHash *common.Hash) (*Collation, error) {
header, err := s.HeaderByHash(headerHash)
if err != nil {
return nil, err
}
body, err := s.GetBodyByChunkRoot(header.ChunkRoot())
if header.ChunkRoot() == nil {
return nil, fmt.Errorf("invalid header fetched: %v", header)
}
body, err := s.BodyByChunkRoot(header.ChunkRoot())
if err != nil {
return nil, err
}
return &Collation{header: header, body: body}, nil
}

// GetCanonicalCollationHash gets a collation header hash that has been set as canonical for
// CanonicalCollationHash gets a collation header hash that has been set as canonical for
// shardID/period pair
func (s *Shard) GetCanonicalCollationHash(shardID *big.Int, period *big.Int) (*common.Hash, error) {
func (s *Shard) CanonicalCollationHash(shardID *big.Int, period *big.Int) (*common.Hash, error) {
key := canonicalCollationLookupKey(shardID, period)
hash := common.BytesToHash(key.Bytes())
collationHashBytes, err := s.shardDB.Get(hash)
if err != nil {
return nil, fmt.Errorf("Error: No Canonical Collation Set for Period/ShardID")
return nil, fmt.Errorf("no canonical collation set for period, shardID pair: %v", err)
}
collationHash := common.BytesToHash(collationHashBytes)
return &collationHash, nil
}

// GetCanonicalCollation fetches the collation set as canonical in the shardDB.
func (s *Shard) GetCanonicalCollation(shardID *big.Int, period *big.Int) (*Collation, error) {
h, err := s.GetCanonicalCollationHash(shardID, period)
// CanonicalCollation fetches the collation set as canonical in the shardDB.
func (s *Shard) CanonicalCollation(shardID *big.Int, period *big.Int) (*Collation, error) {
h, err := s.CanonicalCollationHash(shardID, period)
if err != nil {
return nil, fmt.Errorf("Error: No Hash Found")
return nil, fmt.Errorf("hash not found: %v", err)
}
collation, err := s.GetCollationByHash(h)
collation, err := s.CollationByHash(h)
if err != nil {
return nil, fmt.Errorf("Error: No Canonical Collation Found for Hash")
return nil, fmt.Errorf("no canonical collation found for hash: %v", err)
}
return collation, nil
}

// GetBodyByChunkRoot fetches a collation body.
func (s *Shard) GetBodyByChunkRoot(chunkRoot *common.Hash) ([]byte, error) {
// BodyByChunkRoot fetches a collation body.
func (s *Shard) BodyByChunkRoot(chunkRoot *common.Hash) ([]byte, error) {
log.Printf("Chunk Root: %v", chunkRoot)
body, err := s.shardDB.Get(*chunkRoot)
if err != nil {
return nil, fmt.Errorf("Error: No Corresponding Body With Chunk Root Found")
return nil, fmt.Errorf("no corresponding body with chunk root found: %v", err)
}
return body, nil
}
Expand All @@ -104,11 +115,11 @@ func (s *Shard) CheckAvailability(header *CollationHeader) (bool, error) {
key := dataAvailabilityLookupKey(header.ChunkRoot())
availabilityVal, err := s.shardDB.Get(key)
if err != nil {
return false, fmt.Errorf("Error: Key Not Found")
return false, fmt.Errorf("key not found: %v", key)
}
var availability int
if err := rlp.DecodeBytes(availabilityVal, &availability); err != nil {
return false, fmt.Errorf("Error: Cannot RLP Decode Availability: %v", err)
return false, fmt.Errorf("cannot RLP decode availability: %v", err)
}
if availability != 0 {
return true, nil
Expand All @@ -122,13 +133,13 @@ func (s *Shard) SetAvailability(chunkRoot *common.Hash, availability bool) error
if availability {
enc, err := rlp.EncodeToBytes(true)
if err != nil {
return fmt.Errorf("Cannot RLP encode availability: %v", err)
return fmt.Errorf("cannot RLP encode availability: %v", err)
}
s.shardDB.Put(key, enc)
} else {
enc, err := rlp.EncodeToBytes(false)
if err != nil {
return fmt.Errorf("Cannot RLP encode availability: %v", err)
return fmt.Errorf("cannot RLP encode availability: %v", err)
}
s.shardDB.Put(key, enc)
}
Expand All @@ -139,17 +150,17 @@ func (s *Shard) SetAvailability(chunkRoot *common.Hash, availability bool) error
func (s *Shard) SaveHeader(header *CollationHeader) error {
encoded, err := rlp.EncodeToBytes(header)
if err != nil {
return fmt.Errorf("Error: Cannot Encode Header")
return fmt.Errorf("cannot encode header: %v", err)
}
// Uses the hash of the header as the key.
hash := header.Hash()
fmt.Printf("In SaveHeader: %s\n", hash.String())
s.shardDB.Put(hash, encoded)
return nil
}

// SaveBody adds the collation body to the shardDB and sets availability.
func (s *Shard) SaveBody(body []byte) error {
// TODO: check if body is empty and throw error.
// TODO: dependent on blob serialization.
// chunkRoot := getChunkRoot(body) using the blob algorithm utils.
// right now we will just take the raw keccak256 of the body until #92 is merged.
Expand Down Expand Up @@ -178,14 +189,14 @@ func (s *Shard) SetCanonical(header *CollationHeader) error {
// the header needs to have been stored in the DB previously, so we
// fetch it from the shardDB.
hash := header.Hash()
dbHeader, err := s.GetHeaderByHash(&hash)
dbHeader, err := s.HeaderByHash(&hash)
if err != nil {
return err
}
key := canonicalCollationLookupKey(dbHeader.ShardID(), dbHeader.Period())
encoded, err := rlp.EncodeToBytes(dbHeader)
if err != nil {
return fmt.Errorf("Error: Cannot Encode Header")
return fmt.Errorf("cannot encode header: %v", err)
}
s.shardDB.Put(key, encoded)
return nil
Expand Down
84 changes: 49 additions & 35 deletions sharding/shard_test.go
Original file line number Diff line number Diff line change
@@ -1,68 +1,82 @@
package sharding

import (
"fmt"
"log"
"math/big"
"testing"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto/sha3"
"github.com/ethereum/go-ethereum/rlp"
)

// Hash returns the hash of a collation's entire contents. Useful for comparison tests.
func (c *Collation) Hash() (hash common.Hash) {
hw := sha3.NewKeccak256()
rlp.Encode(hw, c)
hw.Sum(hash[:0])
return hash
}
func TestShard_ValidateShardID(t *testing.T) {
header := &CollationHeader{shardID: big.NewInt(4)}
shard := MakeShard(big.NewInt(3))
shardDB := makeShardKV()
shard := MakeShard(big.NewInt(3), shardDB)

if err := shard.ValidateShardID(header); err == nil {
t.Fatalf("Shard ID validation incorrect. Function should throw error when shardID's do not match. want=%d. got=%d", header.ShardID().Int64(), shard.ShardID().Int64())
t.Errorf("ShardID validation incorrect. Function should throw error when ShardID's do not match. want=%d. got=%d", header.ShardID().Int64(), shard.ShardID().Int64())
}

header2 := &CollationHeader{shardID: big.NewInt(100)}
shard2 := MakeShard(big.NewInt(100))
shard2 := MakeShard(big.NewInt(100), shardDB)

if err := shard2.ValidateShardID(header2); err != nil {
t.Fatalf("Shard ID validation incorrect. Function should not throw error when shardID's match. want=%d. got=%d", header2.ShardID().Int64(), shard2.ShardID().Int64())
t.Errorf("ShardID validation incorrect. Function should not throw error when ShardID's match. want=%d. got=%d", header2.ShardID().Int64(), shard2.ShardID().Int64())
}
}

func TestShard_GetHeaderByHash(t *testing.T) {
header := &CollationHeader{shardID: big.NewInt(1)}
shard := MakeShard(big.NewInt(1))
func TestShard_HeaderByHash(t *testing.T) {
root := common.StringToHash("hi")
header := &CollationHeader{shardID: big.NewInt(1), chunkRoot: &root}
shardDB := makeShardKV()
shard := MakeShard(big.NewInt(1), shardDB)

if err := shard.SaveHeader(header); err != nil {
t.Fatal(err)
t.Fatalf("cannot save collation header: %v", err)
}
hash := header.Hash()
fmt.Printf("In Test: %s\n", hash.String())

// It's being saved, but the .Get func doesn't fetch the value...?
dbHeader, err := shard.GetHeaderByHash(&hash)
dbHeader, err := shard.HeaderByHash(&hash)
if err != nil {
t.Fatal(err)
t.Fatalf("could not fetch collation header by hash: %v", err)
}
log.Printf("header in first test: %v", header.ChunkRoot().String())
log.Printf("db header in first test: %v", dbHeader.ChunkRoot().String())
// Compare the hashes.
if header.Hash() != dbHeader.Hash() {
t.Fatalf("Headers do not match. want=%v. got=%v", header, dbHeader)
t.Errorf("headers do not match. want=%v. got=%v", header, dbHeader)
}
}

// func TestShard_GetCollationByHash(t *testing.T) {
// collation := &Collation{
// header: &CollationHeader{shardID: big.NewInt(1)},
// body: []byte{1, 2, 3},
// }
// shard := MakeShard(big.NewInt(1))
func TestShard_CollationByHash(t *testing.T) {
collation := &Collation{
header: &CollationHeader{shardID: big.NewInt(1)},
body: []byte{1, 2, 3},
}
shardDB := makeShardKV()
shard := MakeShard(big.NewInt(1), shardDB)

if err := shard.SaveCollation(collation); err != nil {
t.Fatalf("cannot save collation: %v", err)
}
hash := collation.Header().Hash()

// if err := shard.SaveCollation(collation); err != nil {
// t.Fatal(err)
// }
// hash := collation.Header().Hash()
// fmt.Printf("In Test: %s\n", hash.String())
dbCollation, err := shard.CollationByHash(&hash)
if err != nil {
t.Fatalf("could not fetch collation by hash: %v", err)
}

// // It's being saved, but the .Get func doesn't fetch the value...?
// dbCollation, err := shard.GetCollationByHash(&hash)
// if err != nil {
// t.Fatal(err)
// }
// // TODO: decode the RLP
// if collation.Hash() != dbCollation.Hash() {
// t.Fatalf("Collations do not match. want=%v. got=%v", collation, dbCollation)
// }
// }
// Compare the hashes.
if collation.Hash() != dbCollation.Hash() {
t.Errorf("collations do not match. want=%v. got=%v", collation, dbCollation)
}
}

0 comments on commit d1aa843

Please sign in to comment.