diff --git a/crypto/merkle/hash.go b/crypto/merkle/hash.go index e5f731855f..fc3c6a2936 100644 --- a/crypto/merkle/hash.go +++ b/crypto/merkle/hash.go @@ -1,6 +1,8 @@ package merkle import ( + "hash" + "github.com/cometbft/cometbft/crypto/tmhash" ) @@ -20,7 +22,23 @@ func leafHash(leaf []byte) []byte { return tmhash.Sum(append(leafPrefix, leaf...)) } +// returns tmhash(0x00 || leaf) +func leafHashOpt(s hash.Hash, leaf []byte) []byte { + s.Reset() + s.Write(leafPrefix) + s.Write(leaf) + return s.Sum(nil) +} + // returns tmhash(0x01 || left || right) func innerHash(left []byte, right []byte) []byte { return tmhash.Sum(append(innerPrefix, append(left, right...)...)) } + +func innerHashOpt(s hash.Hash, left []byte, right []byte) []byte { + s.Reset() + s.Write(innerPrefix) + s.Write(left) + s.Write(right) + return s.Sum(nil) +} diff --git a/crypto/merkle/proof.go b/crypto/merkle/proof.go index 85b2db1e91..2f083499d0 100644 --- a/crypto/merkle/proof.go +++ b/crypto/merkle/proof.go @@ -50,9 +50,6 @@ func ProofsFromByteSlices(items [][]byte) (rootHash []byte, proofs []*Proof) { // Verify that the Proof proves the root hash. // Check sp.Index/sp.Total manually if needed func (sp *Proof) Verify(rootHash []byte, leaf []byte) error { - if rootHash == nil { - return fmt.Errorf("invalid root hash: cannot be nil") - } if sp.Total < 0 { return errors.New("proof total must be positive") } diff --git a/crypto/merkle/proof_key_path_test.go b/crypto/merkle/proof_key_path_test.go index 25a61af929..0d6d3354d3 100644 --- a/crypto/merkle/proof_key_path_test.go +++ b/crypto/merkle/proof_key_path_test.go @@ -36,6 +36,7 @@ func TestKeyPath(t *testing.T) { res, err := KeyPathToKeys(path.String()) require.Nil(t, err) + require.Equal(t, len(keys), len(res)) for i, key := range keys { require.Equal(t, key, res[i]) diff --git a/crypto/merkle/proof_test.go b/crypto/merkle/proof_test.go index 45d565e40c..f307380aad 100644 --- a/crypto/merkle/proof_test.go +++ b/crypto/merkle/proof_test.go @@ -173,12 +173,12 @@ func TestProofValidateBasic(t *testing.T) { } } func TestVoteProtobuf(t *testing.T) { - _, proofs := ProofsFromByteSlices([][]byte{ []byte("apple"), []byte("watermelon"), []byte("kiwi"), }) + testCases := []struct { testName string v1 *Proof diff --git a/crypto/merkle/tree.go b/crypto/merkle/tree.go index 089c2f82ee..896b67c595 100644 --- a/crypto/merkle/tree.go +++ b/crypto/merkle/tree.go @@ -1,22 +1,28 @@ package merkle import ( + "crypto/sha256" + "hash" "math/bits" ) // HashFromByteSlices computes a Merkle tree where the leaves are the byte slice, // in the provided order. It follows RFC-6962. func HashFromByteSlices(items [][]byte) []byte { + return hashFromByteSlices(sha256.New(), items) +} + +func hashFromByteSlices(sha hash.Hash, items [][]byte) []byte { switch len(items) { case 0: return emptyHash() case 1: - return leafHash(items[0]) + return leafHashOpt(sha, items[0]) default: k := getSplitPoint(int64(len(items))) - left := HashFromByteSlices(items[:k]) - right := HashFromByteSlices(items[k:]) - return innerHash(left, right) + left := hashFromByteSlices(sha, items[:k]) + right := hashFromByteSlices(sha, items[k:]) + return innerHashOpt(sha, left, right) } } @@ -61,7 +67,7 @@ func HashFromByteSlices(items [][]byte) []byte { // implementation for so little benefit. func HashFromByteSlicesIterative(input [][]byte) []byte { items := make([][]byte, len(input)) - + sha := sha256.New() for i, leaf := range input { items[i] = leafHash(leaf) } @@ -78,7 +84,7 @@ func HashFromByteSlicesIterative(input [][]byte) []byte { wp := 0 // write position for rp < size { if rp+1 < size { - items[wp] = innerHash(items[rp], items[rp+1]) + items[wp] = innerHashOpt(sha, items[rp], items[rp+1]) rp += 2 } else { items[wp] = items[rp]