diff --git a/dot/state/service_test.go b/dot/state/service_test.go index 00e7258dcf..bb734d35e5 100644 --- a/dot/state/service_test.go +++ b/dot/state/service_test.go @@ -236,7 +236,6 @@ func TestService_PruneStorage(t *testing.T) { } var toFinalize common.Hash - for i := 0; i < 3; i++ { block, trieState := generateBlockWithRandomTrie(t, serv, nil, int64(i+1)) block.Header.Digest = types.Digest{ diff --git a/go.mod b/go.mod index 01a0602622..dbd3ea5312 100644 --- a/go.mod +++ b/go.mod @@ -52,6 +52,7 @@ require ( github.com/urfave/cli v1.20.0 github.com/wasmerio/go-ext-wasm v0.3.2-0.20200326095750-0a32be6068ec golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40 // indirect golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 google.golang.org/appengine v1.6.5 // indirect diff --git a/go.sum b/go.sum index cbd5a83445..e0ef8a76d9 100644 --- a/go.sum +++ b/go.sum @@ -156,7 +156,6 @@ github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200j github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= -github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0= diff --git a/lib/trie/encode.go b/lib/trie/encode.go index 5338475d99..09f1eec99d 100644 --- a/lib/trie/encode.go +++ b/lib/trie/encode.go @@ -34,7 +34,9 @@ func encodeRecursive(n node, enc []byte) ([]byte, error) { return []byte{}, nil } - nenc, err := n.encode() + hasher := NewHasher(false) + defer hasher.returnToPool() + nenc, err := hasher.encode(n) if err != nil { return enc, err } diff --git a/lib/trie/hash.go b/lib/trie/hash.go index e299e07486..8b9d8374b8 100644 --- a/lib/trie/hash.go +++ b/lib/trie/hash.go @@ -17,31 +17,55 @@ package trie import ( + "bytes" + "context" "hash" + "sync" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/scale" "golang.org/x/crypto/blake2b" + "golang.org/x/sync/errgroup" ) // Hasher is a wrapper around a hash function type Hasher struct { - hash hash.Hash + hash hash.Hash + tmp bytes.Buffer + parallel bool // Whether to use parallel threads when hashing +} + +// hasherPool creates a pool of Hasher. +var hasherPool = sync.Pool{ + New: func() interface{} { + h, _ := blake2b.New256(nil) + var buf bytes.Buffer + // This allocation will be helpful for encoding keys. This is the min buffer size. + buf.Grow(700) + + return &Hasher{ + tmp: buf, + hash: h, + } + }, } // NewHasher create new Hasher instance -func NewHasher() (*Hasher, error) { - h, err := blake2b.New256(nil) - if err != nil { - return nil, err - } +func NewHasher(parallel bool) *Hasher { + h := hasherPool.Get().(*Hasher) + h.parallel = parallel + return h +} - return &Hasher{ - hash: h, - }, nil +func (h *Hasher) returnToPool() { + h.tmp.Reset() + h.hash.Reset() + hasherPool.Put(h) } // Hash encodes the node and then hashes it if its encoded length is > 32 bytes func (h *Hasher) Hash(n node) (res []byte, err error) { - encNode, err := n.encode() + encNode, err := h.encode(n) if err != nil { return nil, err } @@ -51,6 +75,7 @@ func (h *Hasher) Hash(n node) (res []byte, err error) { return encNode, nil } + h.hash.Reset() // otherwise, hash encoded node _, err = h.hash.Write(encNode) if err == nil { @@ -59,3 +84,134 @@ func (h *Hasher) Hash(n node) (res []byte, err error) { return res, err } + +// encode is the high-level function wrapping the encoding for different node types +// encoding has the following format: +// NodeHeader | Extra partial key length | Partial Key | Value +func (h *Hasher) encode(n node) ([]byte, error) { + switch n := n.(type) { + case *branch: + return h.encodeBranch(n) + case *leaf: + return h.encodeLeaf(n) + case nil: + return []byte{0}, nil + } + + return nil, nil +} + +func encodeAndHash(n node) ([]byte, error) { + h := NewHasher(false) + defer h.returnToPool() + + encChild, err := h.Hash(n) + if err != nil { + return nil, err + } + + scEncChild, err := scale.Encode(encChild) + if err != nil { + return nil, err + } + return scEncChild, nil +} + +// encodeBranch encodes a branch with the encoding specified at the top of this package +func (h *Hasher) encodeBranch(b *branch) ([]byte, error) { + if !b.dirty && b.encoding != nil { + return b.encoding, nil + } + h.tmp.Reset() + + encoding, err := b.header() + h.tmp.Write(encoding) + if err != nil { + return nil, err + } + + h.tmp.Write(nibblesToKeyLE(b.key)) + h.tmp.Write(common.Uint16ToBytes(b.childrenBitmap())) + + if b.value != nil { + buffer := bytes.Buffer{} + se := scale.Encoder{Writer: &buffer} + _, err = se.Encode(b.value) + if err != nil { + return nil, err + } + h.tmp.Write(buffer.Bytes()) + } + + if h.parallel { + wg, _ := errgroup.WithContext(context.Background()) + resBuff := make([][]byte, 16) + for i := 0; i < 16; i++ { + func(i int) { + wg.Go(func() error { + child := b.children[i] + if child == nil { + return nil + } + + var err error + resBuff[i], err = encodeAndHash(child) + if err != nil { + return err + } + return nil + }) + }(i) + } + if err := wg.Wait(); err != nil { + return nil, err + } + + for _, v := range resBuff { + if v != nil { + h.tmp.Write(v) + } + } + } else { + for i := 0; i < 16; i++ { + if child := b.children[i]; child != nil { + scEncChild, err := encodeAndHash(child) + if err != nil { + return nil, err + } + h.tmp.Write(scEncChild) + } + } + } + + return h.tmp.Bytes(), nil +} + +// encodeLeaf encodes a leaf with the encoding specified at the top of this package +func (h *Hasher) encodeLeaf(l *leaf) ([]byte, error) { + if !l.dirty && l.encoding != nil { + return l.encoding, nil + } + + h.tmp.Reset() + + encoding, err := l.header() + h.tmp.Write(encoding) + if err != nil { + return nil, err + } + + h.tmp.Write(nibblesToKeyLE(l.key)) + + buffer := bytes.Buffer{} + se := scale.Encoder{Writer: &buffer} + + _, err = se.Encode(l.value) + if err != nil { + return nil, err + } + + h.tmp.Write(buffer.Bytes()) + l.encoding = h.tmp.Bytes() + return h.tmp.Bytes(), nil +} diff --git a/lib/trie/hash_test.go b/lib/trie/hash_test.go index 7803fc7d39..94c43e1adf 100644 --- a/lib/trie/hash_test.go +++ b/lib/trie/hash_test.go @@ -41,14 +41,10 @@ func generateRand(size int) [][]byte { } func TestNewHasher(t *testing.T) { - hasher, err := NewHasher() - if err != nil { - t.Fatalf("error creating new hasher: %s", err) - } else if hasher == nil { - t.Fatal("did not create new hasher") - } + hasher := NewHasher(false) + defer hasher.returnToPool() - _, err = hasher.hash.Write([]byte("noot")) + _, err := hasher.hash.Write([]byte("noot")) if err != nil { t.Error(err) } @@ -62,10 +58,8 @@ func TestNewHasher(t *testing.T) { } func TestHashLeaf(t *testing.T) { - hasher, err := NewHasher() - if err != nil { - t.Fatal(err) - } + hasher := NewHasher(false) + defer hasher.returnToPool() n := &leaf{key: generateRandBytes(380), value: generateRandBytes(64)} h, err := hasher.Hash(n) @@ -77,10 +71,8 @@ func TestHashLeaf(t *testing.T) { } func TestHashBranch(t *testing.T) { - hasher, err := NewHasher() - if err != nil { - t.Fatal(err) - } + hasher := NewHasher(false) + defer hasher.returnToPool() n := &branch{key: generateRandBytes(380), value: generateRandBytes(380)} n.children[3] = &leaf{key: generateRandBytes(380), value: generateRandBytes(380)} @@ -93,13 +85,11 @@ func TestHashBranch(t *testing.T) { } func TestHashShort(t *testing.T) { - hasher, err := NewHasher() - if err != nil { - t.Fatal(err) - } + hasher := NewHasher(false) + defer hasher.returnToPool() n := &leaf{key: generateRandBytes(2), value: generateRandBytes(3)} - expected, err := n.encode() + expected, err := hasher.encode(n) if err != nil { t.Fatal(err) } diff --git a/lib/trie/node.go b/lib/trie/node.go index eb7383e939..41bfb0cccf 100644 --- a/lib/trie/node.go +++ b/lib/trie/node.go @@ -53,7 +53,6 @@ import ( // node is the interface for trie methods type node interface { encodeAndHash() ([]byte, []byte, error) - encode() ([]byte, error) decode(r io.Reader, h byte) error isDirty() bool setDirty(dirty bool) @@ -220,28 +219,13 @@ func (b *branch) setKey(key []byte) { b.key = key } -// Encode is the high-level function wrapping the encoding for different node types -// encoding has the following format: -// NodeHeader | Extra partial key length | Partial Key | Value -func encode(n node) ([]byte, error) { - switch n := n.(type) { - case *branch: - return n.encode() - case *leaf: - return n.encode() - case nil: - return []byte{0}, nil - } - - return nil, nil -} - func (b *branch) encodeAndHash() ([]byte, []byte, error) { if !b.dirty && b.encoding != nil && b.hash != nil { return b.encoding, b.hash, nil } - enc, err := b.encode() + hasher := NewHasher(false) + enc, err := hasher.encodeBranch(b) if err != nil { return nil, nil, err } @@ -262,59 +246,13 @@ func (b *branch) encodeAndHash() ([]byte, []byte, error) { return enc, hash[:], nil } -// Encode encodes a branch with the encoding specified at the top of this package -func (b *branch) encode() ([]byte, error) { - if !b.dirty && b.encoding != nil { - return b.encoding, nil - } - - encoding, err := b.header() - if err != nil { - return nil, err - } - - encoding = append(encoding, nibblesToKeyLE(b.key)...) - encoding = append(encoding, common.Uint16ToBytes(b.childrenBitmap())...) - - if b.value != nil { - buffer := bytes.Buffer{} - se := scale.Encoder{Writer: &buffer} - _, err = se.Encode(b.value) - if err != nil { - return encoding, err - } - encoding = append(encoding, buffer.Bytes()...) - } - - for _, child := range b.children { - if child != nil { - hasher, err := NewHasher() - if err != nil { - return nil, err - } - - encChild, err := hasher.Hash(child) - if err != nil { - return encoding, err - } - - scEncChild, err := scale.Encode(encChild) - if err != nil { - return encoding, err - } - encoding = append(encoding, scEncChild...) - } - } - - return encoding, nil -} - func (l *leaf) encodeAndHash() ([]byte, []byte, error) { if !l.isDirty() && l.encoding != nil && l.hash != nil { return l.encoding, l.hash, nil } + hasher := NewHasher(false) + enc, err := hasher.encodeLeaf(l) - enc, err := l.encode() if err != nil { return nil, nil, err } @@ -335,30 +273,6 @@ func (l *leaf) encodeAndHash() ([]byte, []byte, error) { return enc, hash[:], nil } -// Encode encodes a leaf with the encoding specified at the top of this package -func (l *leaf) encode() ([]byte, error) { - if !l.dirty && l.encoding != nil { - return l.encoding, nil - } - - encoding, err := l.header() - if err != nil { - return nil, err - } - - encoding = append(encoding, nibblesToKeyLE(l.key)...) - - buffer := bytes.Buffer{} - se := scale.Encoder{Writer: &buffer} - _, err = se.Encode(l.value) - if err != nil { - return encoding, err - } - encoding = append(encoding, buffer.Bytes()...) - l.encoding = encoding - return encoding, nil -} - func decodeBytes(in []byte) (node, error) { r := &bytes.Buffer{} _, err := r.Write(in) diff --git a/lib/trie/node_test.go b/lib/trie/node_test.go index 36ff01e470..b894bde450 100644 --- a/lib/trie/node_test.go +++ b/lib/trie/node_test.go @@ -171,10 +171,8 @@ func TestBranchEncode(t *testing.T) { for _, child := range b.children { if child != nil { - hasher, e := NewHasher() - if e != nil { - t.Fatal(e) - } + hasher := NewHasher(false) + defer hasher.returnToPool() encChild, er := hasher.Hash(child) if er != nil { t.Errorf("Fail when encoding branch child: %s", er) @@ -183,7 +181,9 @@ func TestBranchEncode(t *testing.T) { } } - res, err := b.encode() + hasher := NewHasher(false) + defer hasher.returnToPool() + res, err := hasher.encodeBranch(b) if !bytes.Equal(res, expected) { t.Errorf("Fail when encoding node: got %x expected %x", res, expected) } else if err != nil { @@ -216,7 +216,9 @@ func TestLeafEncode(t *testing.T) { expected = append(expected, buf.Bytes()...) - res, err := l.encode() + hasher := NewHasher(false) + defer hasher.returnToPool() + res, err := hasher.encodeLeaf(l) if !bytes.Equal(res, expected) { t.Errorf("Fail when encoding node: got %x expected %x", res, expected) } else if err != nil { @@ -238,7 +240,9 @@ func TestEncodeRoot(t *testing.T) { t.Errorf("Fail to get key %x with value %x: got %x", test.key, test.value, val) } - _, err := encode(trie.root) + hasher := NewHasher(false) + defer hasher.returnToPool() + _, err := hasher.encode(trie.root) if err != nil { t.Errorf("Fail to encode trie root: %s", err) } @@ -263,8 +267,10 @@ func TestBranchDecode(t *testing.T) { {key: byteArray(573), children: [16]node{}, value: []byte{0x01}}, } + hasher := NewHasher(false) + defer hasher.returnToPool() for _, test := range tests { - enc, err := test.encode() + enc, err := hasher.encodeBranch(test) require.NoError(t, err) res := new(branch) @@ -292,8 +298,10 @@ func TestLeafDecode(t *testing.T) { {key: byteArray(573), value: []byte{0x01}, dirty: true}, } + hasher := NewHasher(false) + defer hasher.returnToPool() for _, test := range tests { - enc, err := test.encode() + enc, err := hasher.encodeLeaf(test) require.NoError(t, err) res := new(leaf) @@ -329,8 +337,10 @@ func TestDecode(t *testing.T) { &leaf{key: byteArray(573), value: []byte{0x01}}, } + hasher := NewHasher(false) + defer hasher.returnToPool() for _, test := range tests { - enc, err := test.encode() + enc, err := hasher.encode(test) require.NoError(t, err) r := &bytes.Buffer{} diff --git a/lib/trie/print.go b/lib/trie/print.go index 9311395539..73df501d43 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -38,12 +38,14 @@ func (t *Trie) String() string { func (t *Trie) string(tree gotree.Tree, curr node, idx int) { switch c := curr.(type) { case *branch: - c.encoding, _ = c.encode() + hasher := NewHasher(false) + defer hasher.returnToPool() + c.encoding, _ = hasher.encode(c) var bstr string if len(c.encoding) > 1024 { bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", idx, c.String(), common.MustBlake2bHash(c.encoding), c.generation) } else { - bstr = fmt.Sprintf("idx=%d %s enc=%x gen=%d", idx, c.String(), c.encoding, c.generation) + bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.encoding, c.generation) } sub := tree.Add(bstr) for i, child := range c.children { @@ -52,12 +54,14 @@ func (t *Trie) string(tree gotree.Tree, curr node, idx int) { } } case *leaf: - c.encoding, _ = c.encode() + hasher := NewHasher(false) + defer hasher.returnToPool() + c.encoding, _ = hasher.encode(c) var bstr string if len(c.encoding) > 1024 { bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", idx, c.String(), common.MustBlake2bHash(c.encoding), c.generation) } else { - bstr = fmt.Sprintf("idx=%d %s enc=%x gen=%d", idx, c.String(), c.encoding, c.generation) + bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.encoding, c.generation) } tree.Add(bstr) default: diff --git a/lib/trie/test_utils.go b/lib/trie/test_utils.go index bde2effb08..80e262244f 100644 --- a/lib/trie/test_utils.go +++ b/lib/trie/test_utils.go @@ -25,7 +25,7 @@ func (t *Test) Value() []byte { } // GenerateRandomTests returns an array of random Tests -func GenerateRandomTests(t *testing.T, size int) []Test { +func GenerateRandomTests(t testing.TB, size int) []Test { rt := make([]Test, size) kv := make(map[string][]byte) @@ -38,7 +38,7 @@ func GenerateRandomTests(t *testing.T, size int) []Test { return rt } -func generateRandomTest(t *testing.T, kv map[string][]byte) Test { +func generateRandomTest(t testing.TB, kv map[string][]byte) Test { r := *rand.New(rand.NewSource(rand.Int63())) //nolint test := Test{} diff --git a/lib/trie/trie.go b/lib/trie/trie.go index e163234da8..fb3963486d 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -33,6 +33,7 @@ type Trie struct { root node childTries map[common.Hash]*Trie // Used to store the child tries. deletedKeys []common.Hash + parallel bool } // NewEmptyTrie creates a trie with a nil root @@ -47,6 +48,7 @@ func NewTrie(root node) *Trie { childTries: make(map[common.Hash]*Trie), generation: 0, // Initially zero but increases after every snapshot. deletedKeys: make([]common.Hash, 0), + parallel: true, } } @@ -107,7 +109,9 @@ func (t *Trie) RootNode() node { //nolint // EncodeRoot returns the encoded root of the trie func (t *Trie) EncodeRoot() ([]byte, error) { - return encode(t.RootNode()) + h := NewHasher(t.parallel) + defer h.returnToPool() + return h.encode(t.RootNode()) } // MustHash returns the hashed root of the trie. It panics if it fails to hash the root node. diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index 17838f9094..53f4d4c8b7 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -24,6 +24,7 @@ import ( "math/rand" "os" "path/filepath" + "runtime" "sort" "strconv" "strings" @@ -1137,3 +1138,49 @@ func TestNextKey_Random(t *testing.T) { } } } + +func TestRootHashNonParallel(t *testing.T) { + rt := GenerateRandomTests(t, 1000000) + trie := NewEmptyTrie() + for i := range rt { + test := &rt[i] + trie.Put(test.key, test.value) + } + + t.Run("Non Parallel Hash", func(t *testing.T) { + trie.parallel = false + _, err := trie.Hash() + require.NoError(t, err) + PrintMemUsage() + }) +} + +func TestRootHashParallel(t *testing.T) { + rt := GenerateRandomTests(t, 1000000) + trie := NewEmptyTrie() + for i := range rt { + test := &rt[i] + trie.Put(test.key, test.value) + } + + t.Run("Parallel Hash", func(t *testing.T) { + trie.parallel = true + _, err := trie.Hash() + require.NoError(t, err) + PrintMemUsage() + }) +} + +func PrintMemUsage() { + var m runtime.MemStats + runtime.ReadMemStats(&m) + // For info on each, see: https://golang.org/pkg/runtime/#MemStats + fmt.Printf("Alloc = %v MiB", bToMb(m.Alloc)) + fmt.Printf("\tTotalAlloc = %v MiB", bToMb(m.TotalAlloc)) + fmt.Printf("\tSys = %v MiB", bToMb(m.Sys)) + fmt.Printf("\tNumGC = %v\n", m.NumGC) +} + +func bToMb(b uint64) uint64 { + return b / 1024 / 1024 +}