Skip to content

Commit

Permalink
feat(lib/trie): Parallel hash trie. (ChainSafe#1657)
Browse files Browse the repository at this point in the history
* feat(lib/trie): Parallel hash trie.

* Fix race.

* Use bytes.Buffer in pool.
  • Loading branch information
arijitAD committed Jul 2, 2021
1 parent f5a4d3b commit 22827e7
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 140 deletions.
1 change: 0 additions & 1 deletion dot/state/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ func TestService_PruneStorage(t *testing.T) {
}

var toFinalize common.Hash

for i := 0; i < 3; i++ {
block, trieState := generateBlockWithRandomTrie(t, serv, nil, int64(i+1))
block.Header.Digest = types.Digest{
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ require (
github.com/urfave/cli v1.20.0
github.com/wasmerio/go-ext-wasm v0.3.2-0.20200326095750-0a32be6068ec
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40 // indirect
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1
google.golang.org/appengine v1.6.5 // indirect
Expand Down
1 change: 0 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200j
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc=
github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0=
Expand Down
4 changes: 3 additions & 1 deletion lib/trie/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ func encodeRecursive(n node, enc []byte) ([]byte, error) {
return []byte{}, nil
}

nenc, err := n.encode()
hasher := NewHasher(false)
defer hasher.returnToPool()
nenc, err := hasher.encode(n)
if err != nil {
return enc, err
}
Expand Down
176 changes: 166 additions & 10 deletions lib/trie/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,55 @@
package trie

import (
"bytes"
"context"
"hash"
"sync"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/scale"
"golang.org/x/crypto/blake2b"
"golang.org/x/sync/errgroup"
)

// Hasher is a wrapper around a hash function
type Hasher struct {
hash hash.Hash
hash hash.Hash
tmp bytes.Buffer
parallel bool // Whether to use parallel threads when hashing
}

// hasherPool creates a pool of Hasher.
var hasherPool = sync.Pool{
New: func() interface{} {
h, _ := blake2b.New256(nil)
var buf bytes.Buffer
// This allocation will be helpful for encoding keys. This is the min buffer size.
buf.Grow(700)

return &Hasher{
tmp: buf,
hash: h,
}
},
}

// NewHasher create new Hasher instance
func NewHasher() (*Hasher, error) {
h, err := blake2b.New256(nil)
if err != nil {
return nil, err
}
func NewHasher(parallel bool) *Hasher {
h := hasherPool.Get().(*Hasher)
h.parallel = parallel
return h
}

return &Hasher{
hash: h,
}, nil
func (h *Hasher) returnToPool() {
h.tmp.Reset()
h.hash.Reset()
hasherPool.Put(h)
}

// Hash encodes the node and then hashes it if its encoded length is > 32 bytes
func (h *Hasher) Hash(n node) (res []byte, err error) {
encNode, err := n.encode()
encNode, err := h.encode(n)
if err != nil {
return nil, err
}
Expand All @@ -51,6 +75,7 @@ func (h *Hasher) Hash(n node) (res []byte, err error) {
return encNode, nil
}

h.hash.Reset()
// otherwise, hash encoded node
_, err = h.hash.Write(encNode)
if err == nil {
Expand All @@ -59,3 +84,134 @@ func (h *Hasher) Hash(n node) (res []byte, err error) {

return res, err
}

// encode is the high-level function wrapping the encoding for different node types
// encoding has the following format:
// NodeHeader | Extra partial key length | Partial Key | Value
func (h *Hasher) encode(n node) ([]byte, error) {
switch n := n.(type) {
case *branch:
return h.encodeBranch(n)
case *leaf:
return h.encodeLeaf(n)
case nil:
return []byte{0}, nil
}

return nil, nil
}

func encodeAndHash(n node) ([]byte, error) {
h := NewHasher(false)
defer h.returnToPool()

encChild, err := h.Hash(n)
if err != nil {
return nil, err
}

scEncChild, err := scale.Encode(encChild)
if err != nil {
return nil, err
}
return scEncChild, nil
}

// encodeBranch encodes a branch with the encoding specified at the top of this package
func (h *Hasher) encodeBranch(b *branch) ([]byte, error) {
if !b.dirty && b.encoding != nil {
return b.encoding, nil
}
h.tmp.Reset()

encoding, err := b.header()
h.tmp.Write(encoding)
if err != nil {
return nil, err
}

h.tmp.Write(nibblesToKeyLE(b.key))
h.tmp.Write(common.Uint16ToBytes(b.childrenBitmap()))

if b.value != nil {
buffer := bytes.Buffer{}
se := scale.Encoder{Writer: &buffer}
_, err = se.Encode(b.value)
if err != nil {
return nil, err
}
h.tmp.Write(buffer.Bytes())
}

if h.parallel {
wg, _ := errgroup.WithContext(context.Background())
resBuff := make([][]byte, 16)
for i := 0; i < 16; i++ {
func(i int) {
wg.Go(func() error {
child := b.children[i]
if child == nil {
return nil
}

var err error
resBuff[i], err = encodeAndHash(child)
if err != nil {
return err
}
return nil
})
}(i)
}
if err := wg.Wait(); err != nil {
return nil, err
}

for _, v := range resBuff {
if v != nil {
h.tmp.Write(v)
}
}
} else {
for i := 0; i < 16; i++ {
if child := b.children[i]; child != nil {
scEncChild, err := encodeAndHash(child)
if err != nil {
return nil, err
}
h.tmp.Write(scEncChild)
}
}
}

return h.tmp.Bytes(), nil
}

// encodeLeaf encodes a leaf with the encoding specified at the top of this package
func (h *Hasher) encodeLeaf(l *leaf) ([]byte, error) {
if !l.dirty && l.encoding != nil {
return l.encoding, nil
}

h.tmp.Reset()

encoding, err := l.header()
h.tmp.Write(encoding)
if err != nil {
return nil, err
}

h.tmp.Write(nibblesToKeyLE(l.key))

buffer := bytes.Buffer{}
se := scale.Encoder{Writer: &buffer}

_, err = se.Encode(l.value)
if err != nil {
return nil, err
}

h.tmp.Write(buffer.Bytes())
l.encoding = h.tmp.Bytes()
return h.tmp.Bytes(), nil
}
30 changes: 10 additions & 20 deletions lib/trie/hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,10 @@ func generateRand(size int) [][]byte {
}

func TestNewHasher(t *testing.T) {
hasher, err := NewHasher()
if err != nil {
t.Fatalf("error creating new hasher: %s", err)
} else if hasher == nil {
t.Fatal("did not create new hasher")
}
hasher := NewHasher(false)
defer hasher.returnToPool()

_, err = hasher.hash.Write([]byte("noot"))
_, err := hasher.hash.Write([]byte("noot"))
if err != nil {
t.Error(err)
}
Expand All @@ -62,10 +58,8 @@ func TestNewHasher(t *testing.T) {
}

func TestHashLeaf(t *testing.T) {
hasher, err := NewHasher()
if err != nil {
t.Fatal(err)
}
hasher := NewHasher(false)
defer hasher.returnToPool()

n := &leaf{key: generateRandBytes(380), value: generateRandBytes(64)}
h, err := hasher.Hash(n)
Expand All @@ -77,10 +71,8 @@ func TestHashLeaf(t *testing.T) {
}

func TestHashBranch(t *testing.T) {
hasher, err := NewHasher()
if err != nil {
t.Fatal(err)
}
hasher := NewHasher(false)
defer hasher.returnToPool()

n := &branch{key: generateRandBytes(380), value: generateRandBytes(380)}
n.children[3] = &leaf{key: generateRandBytes(380), value: generateRandBytes(380)}
Expand All @@ -93,13 +85,11 @@ func TestHashBranch(t *testing.T) {
}

func TestHashShort(t *testing.T) {
hasher, err := NewHasher()
if err != nil {
t.Fatal(err)
}
hasher := NewHasher(false)
defer hasher.returnToPool()

n := &leaf{key: generateRandBytes(2), value: generateRandBytes(3)}
expected, err := n.encode()
expected, err := hasher.encode(n)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading

0 comments on commit 22827e7

Please sign in to comment.