diff --git a/merkle.go b/merkle.go index 4fa73cb..aca3822 100644 --- a/merkle.go +++ b/merkle.go @@ -37,7 +37,7 @@ type node struct { } func (n node) IsEmpty() bool { - return n.value == nil + return len(n.value) == 0 } // layer is a layer in the merkle tree. @@ -98,6 +98,7 @@ type Tree struct { leavesToProve *sparseBoolStack cacheWriter CacheWriter minHeight uint + parentBuf []byte } // AddLeaf incorporates a new leaf to the state of the tree. It updates the state required to eventually determine the @@ -108,7 +109,6 @@ func (t *Tree) AddLeaf(value []byte) error { OnProvenPath: t.leavesToProve.Pop(), } l := t.baseLayer - var parent, lChild, rChild node var lastCachingError error // Loop through the layers, starting from the base layer. @@ -124,26 +124,29 @@ func (t *Tree) AddLeaf(value []byte) error { // If no node is pending, then this node is a left sibling, // pending for its right sibling before its parent can be calculated. if l.parking.IsEmpty() { - l.parking = n + // Copy the byte slice as we will keep it for a while. + l.parking.value = append(l.parking.value[:0], n.value...) + l.parking.OnProvenPath = n.OnProvenPath break } else { // This node is a right sibling. - lChild, rChild = l.parking, n - parent = t.calcParent(lChild, rChild) + lChild, rChild := l.parking, n // A given node is required in the proof if and only if its parent is an ancestor // of a leaf whose membership in the tree is being proven, but the given node isn't. - if parent.OnProvenPath { - if !lChild.OnProvenPath { - t.proof = append(t.proof, lChild.value) - } - if !rChild.OnProvenPath { - t.proof = append(t.proof, rChild.value) - } + if rChild.OnProvenPath && !lChild.OnProvenPath { + copy := append([]byte(nil), lChild.value...) + t.proof = append(t.proof, copy) } + if lChild.OnProvenPath && !rChild.OnProvenPath { + copy := append([]byte(nil), rChild.value...) + t.proof = append(t.proof, copy) + } + + n = t.calcParent(t.parentBuf[:0], lChild, rChild) + t.parentBuf = n.value - l.parking.value = nil - n = parent + l.parking.value = l.parking.value[:0] err := l.ensureNextLayerExists(t.cacheWriter) if err != nil { return err @@ -264,18 +267,21 @@ func (t *Tree) calcEphemeralParent(parking, ephemeralNode node) (parent, lChild, default: // both are empty return EmptyNode, EmptyNode, EmptyNode } - return t.calcParent(lChild, rChild), lChild, rChild + return t.calcParent(nil, lChild, rChild), lChild, rChild } -// calcParent returns the parent node of two child nodes. -func (t *Tree) calcParent(lChild, rChild node) node { +// calcParent calculates the parent node of two child nodes. +// The buf can be used to reuse memory for hashing. +func (t *Tree) calcParent(buf []byte, lChild, rChild node) node { return node{ - value: t.hash(lChild.value, rChild.value), + value: t.hash(buf, lChild.value, rChild.value), OnProvenPath: lChild.OnProvenPath || rChild.OnProvenPath, } } -func GetSha256Parent(lChild, rChild []byte) []byte { - res := sha256.Sum256(append(lChild, rChild...)) - return res[:] +func GetSha256Parent(buf, lChild, rChild []byte) []byte { + hasher := sha256.New() + hasher.Write(lChild) + hasher.Write(rChild) + return hasher.Sum(buf) } diff --git a/merkle_test.go b/merkle_test.go index 4ed1e81..d50d2ef 100644 --- a/merkle_test.go +++ b/merkle_test.go @@ -63,7 +63,7 @@ func TestNewTree(t *testing.T) { r.Equal(expectedRoot, root) } -func concatLeaves(lChild, rChild []byte) []byte { +func concatLeaves(_, lChild, rChild []byte) []byte { if len(lChild) == NodeSize { lChild = lChild[:1] } @@ -192,8 +192,7 @@ func TestNewTreeUnbalancedProof(t *testing.T) { expectedProof[3], _ = NewNodeFromHex("0600000000000000000000000000000000000000000000000000000000000000") expectedProof[4], _ = NewNodeFromHex("bc68417a8495de6e22d95b980fca5a1183f29eff0e2a9b7ddde91ed5bcbea952") - var proof nodes - proof = tree.Proof() + proof := tree.Proof() r.EqualValues(expectedProof, proof) } @@ -314,6 +313,38 @@ func TestNewProvingTreeMultiProof(t *testing.T) { ***************************************************/ } +// TestNewProvingTreeMultiProofReuseLeafBytes verifies if the user of Tree +// can safely reuse the memory passed into Tree::AddLeaf. +func TestNewProvingTreeMultiProofReuseLeafBytes(t *testing.T) { + r := require.New(t) + tree, err := NewProvingTree(setOf(1, 4)) + r.NoError(err) + var leaf [32]byte + for i := uint64(0); i < 8; i++ { + binary.LittleEndian.PutUint64(leaf[:], i) + r.NoError(tree.AddLeaf(leaf[:])) + } + expectedRoot, _ := NewNodeFromHex("89a0f1577268cc19b0a39c7a69f804fd140640c699585eb635ebb03c06154cce") + root := tree.Root() + r.Equal(expectedRoot, root) + + expectedProof := make([][]byte, 4) + expectedProof[0], _ = NewNodeFromHex("0000000000000000000000000000000000000000000000000000000000000000") + expectedProof[1], _ = NewNodeFromHex("0094579cfc7b716038d416a311465309bea202baa922b224a7b08f01599642fb") + expectedProof[2], _ = NewNodeFromHex("0500000000000000000000000000000000000000000000000000000000000000") + expectedProof[3], _ = NewNodeFromHex("fa670379e5c2212ed93ff09769622f81f98a91e1ec8fb114d607dd25220b9088") + + proof := tree.Proof() + r.EqualValues(expectedProof, proof) + + /*************************************************** + | 89a0 | + | ba94 633b | + | cb59 .0094. bd50 .fa67. | + | .0000.=0100= 0200 0300 =0400=.0500. 0600 0700 | + ***************************************************/ +} + func TestNewProvingTreeMultiProof2(t *testing.T) { r := require.New(t) tree, err := NewProvingTree(setOf(0, 1, 4)) @@ -442,7 +473,7 @@ func TestTree_GetParkedNodes(t *testing.T) { r.NoError(tree.AddLeaf([]byte{1})) r.EqualValues( - [][]byte{nil, decode(r, "b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2")}, + [][]byte{{}, decode(r, "b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2")}, tree.GetParkedNodes(nil)) r.NoError(tree.AddLeaf([]byte{2})) @@ -452,7 +483,7 @@ func TestTree_GetParkedNodes(t *testing.T) { r.NoError(tree.AddLeaf([]byte{3})) r.EqualValues( - [][]byte{nil, nil, decode(r, "7699a4fdd6b8b6908a344f73b8f05c8e1400f7253f544602c442ff5c65504b24")}, + [][]byte{{}, {}, decode(r, "7699a4fdd6b8b6908a344f73b8f05c8e1400f7253f544602c442ff5c65504b24")}, tree.GetParkedNodes(nil)) } @@ -463,7 +494,7 @@ func TestTree_SetParkedNodes(t *testing.T) { r.NoError(err) r.NoError(tree.SetParkedNodes([][]byte{{0}})) r.NoError(tree.AddLeaf([]byte{1})) - parkedNodes := [][]byte{nil, decode(r, "b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2")} + parkedNodes := [][]byte{{}, decode(r, "b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2")} r.EqualValues(parkedNodes, tree.GetParkedNodes(nil)) tree, err = NewTreeBuilder().Build() @@ -477,7 +508,7 @@ func TestTree_SetParkedNodes(t *testing.T) { r.NoError(err) r.NoError(tree.SetParkedNodes(parkedNodes)) r.NoError(tree.AddLeaf([]byte{3})) - parkedNodes = [][]byte{nil, nil, decode(r, "7699a4fdd6b8b6908a344f73b8f05c8e1400f7253f544602c442ff5c65504b24")} + parkedNodes = [][]byte{{}, {}, decode(r, "7699a4fdd6b8b6908a344f73b8f05c8e1400f7253f544602c442ff5c65504b24")} r.EqualValues(parkedNodes, tree.GetParkedNodes(nil)) } diff --git a/shared/types.go b/shared/types.go index ff8d5f4..a259f9b 100644 --- a/shared/types.go +++ b/shared/types.go @@ -1,6 +1,6 @@ package shared -type HashFunc func(lChild, rChild []byte) []byte +type HashFunc func(buf, lChild, rChild []byte) []byte // LayerReadWriter is a combined reader-writer. Note that the Seek() method only belongs to the LayerReader interface // and does not affect the LayerWriter. diff --git a/validation.go b/validation.go index bb138f5..f940804 100644 --- a/validation.go +++ b/validation.go @@ -104,7 +104,7 @@ func (v *Validator) CalcRoot(stopAtLayer uint) ([]byte, []ParkingSnapshot, error subTreeSnapshots = nil } } - activeNode = v.Hash(lChild, rChild) + activeNode = v.Hash(nil, lChild, rChild) activePos = activePos.parent() } return activeNode, parkingSnapshots, nil