Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions merkle.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type node struct {
}

func (n node) IsEmpty() bool {
return n.value == nil
return len(n.value) == 0
}

// layer is a layer in the merkle tree.
Expand Down Expand Up @@ -98,6 +98,7 @@ type Tree struct {
leavesToProve *sparseBoolStack
cacheWriter CacheWriter
minHeight uint
parentBuf []byte
}

// AddLeaf incorporates a new leaf to the state of the tree. It updates the state required to eventually determine the
Expand All @@ -108,7 +109,6 @@ func (t *Tree) AddLeaf(value []byte) error {
OnProvenPath: t.leavesToProve.Pop(),
}
l := t.baseLayer
var parent, lChild, rChild node
var lastCachingError error

// Loop through the layers, starting from the base layer.
Expand All @@ -124,26 +124,29 @@ func (t *Tree) AddLeaf(value []byte) error {
// If no node is pending, then this node is a left sibling,
// pending for its right sibling before its parent can be calculated.
if l.parking.IsEmpty() {
l.parking = n
// Copy the byte slice as we will keep it for a while.
l.parking.value = append(l.parking.value[:0], n.value...)
l.parking.OnProvenPath = n.OnProvenPath
break
} else {
// This node is a right sibling.
lChild, rChild = l.parking, n
parent = t.calcParent(lChild, rChild)
lChild, rChild := l.parking, n

// A given node is required in the proof if and only if its parent is an ancestor
// of a leaf whose membership in the tree is being proven, but the given node isn't.
if parent.OnProvenPath {
if !lChild.OnProvenPath {
t.proof = append(t.proof, lChild.value)
}
if !rChild.OnProvenPath {
t.proof = append(t.proof, rChild.value)
}
if rChild.OnProvenPath && !lChild.OnProvenPath {
copy := append([]byte(nil), lChild.value...)
t.proof = append(t.proof, copy)
}
if lChild.OnProvenPath && !rChild.OnProvenPath {
copy := append([]byte(nil), rChild.value...)
t.proof = append(t.proof, copy)
}

n = t.calcParent(t.parentBuf[:0], lChild, rChild)
t.parentBuf = n.value

l.parking.value = nil
n = parent
l.parking.value = l.parking.value[:0]
err := l.ensureNextLayerExists(t.cacheWriter)
if err != nil {
return err
Expand Down Expand Up @@ -264,18 +267,21 @@ func (t *Tree) calcEphemeralParent(parking, ephemeralNode node) (parent, lChild,
default: // both are empty
return EmptyNode, EmptyNode, EmptyNode
}
return t.calcParent(lChild, rChild), lChild, rChild
return t.calcParent(nil, lChild, rChild), lChild, rChild
}

// calcParent returns the parent node of two child nodes.
func (t *Tree) calcParent(lChild, rChild node) node {
// calcParent calculates the parent node of two child nodes.
// The buf can be used to reuse memory for hashing.
func (t *Tree) calcParent(buf []byte, lChild, rChild node) node {
return node{
value: t.hash(lChild.value, rChild.value),
value: t.hash(buf, lChild.value, rChild.value),
OnProvenPath: lChild.OnProvenPath || rChild.OnProvenPath,
}
}

func GetSha256Parent(lChild, rChild []byte) []byte {
res := sha256.Sum256(append(lChild, rChild...))
return res[:]
func GetSha256Parent(buf, lChild, rChild []byte) []byte {
hasher := sha256.New()
hasher.Write(lChild)
hasher.Write(rChild)
return hasher.Sum(buf)
}
45 changes: 38 additions & 7 deletions merkle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func TestNewTree(t *testing.T) {
r.Equal(expectedRoot, root)
}

func concatLeaves(lChild, rChild []byte) []byte {
func concatLeaves(_, lChild, rChild []byte) []byte {
if len(lChild) == NodeSize {
lChild = lChild[:1]
}
Expand Down Expand Up @@ -192,8 +192,7 @@ func TestNewTreeUnbalancedProof(t *testing.T) {
expectedProof[3], _ = NewNodeFromHex("0600000000000000000000000000000000000000000000000000000000000000")
expectedProof[4], _ = NewNodeFromHex("bc68417a8495de6e22d95b980fca5a1183f29eff0e2a9b7ddde91ed5bcbea952")

var proof nodes
proof = tree.Proof()
proof := tree.Proof()
r.EqualValues(expectedProof, proof)
}

Expand Down Expand Up @@ -314,6 +313,38 @@ func TestNewProvingTreeMultiProof(t *testing.T) {
***************************************************/
}

// TestNewProvingTreeMultiProofReuseLeafBytes verifies if the user of Tree
// can safely reuse the memory passed into Tree::AddLeaf.
func TestNewProvingTreeMultiProofReuseLeafBytes(t *testing.T) {
r := require.New(t)
tree, err := NewProvingTree(setOf(1, 4))
r.NoError(err)
var leaf [32]byte
for i := uint64(0); i < 8; i++ {
binary.LittleEndian.PutUint64(leaf[:], i)
r.NoError(tree.AddLeaf(leaf[:]))
}
expectedRoot, _ := NewNodeFromHex("89a0f1577268cc19b0a39c7a69f804fd140640c699585eb635ebb03c06154cce")
root := tree.Root()
r.Equal(expectedRoot, root)

expectedProof := make([][]byte, 4)
expectedProof[0], _ = NewNodeFromHex("0000000000000000000000000000000000000000000000000000000000000000")
expectedProof[1], _ = NewNodeFromHex("0094579cfc7b716038d416a311465309bea202baa922b224a7b08f01599642fb")
expectedProof[2], _ = NewNodeFromHex("0500000000000000000000000000000000000000000000000000000000000000")
expectedProof[3], _ = NewNodeFromHex("fa670379e5c2212ed93ff09769622f81f98a91e1ec8fb114d607dd25220b9088")

proof := tree.Proof()
r.EqualValues(expectedProof, proof)

/***************************************************
| 89a0 |
| ba94 633b |
| cb59 .0094. bd50 .fa67. |
| .0000.=0100= 0200 0300 =0400=.0500. 0600 0700 |
***************************************************/
}

func TestNewProvingTreeMultiProof2(t *testing.T) {
r := require.New(t)
tree, err := NewProvingTree(setOf(0, 1, 4))
Expand Down Expand Up @@ -442,7 +473,7 @@ func TestTree_GetParkedNodes(t *testing.T) {

r.NoError(tree.AddLeaf([]byte{1}))
r.EqualValues(
[][]byte{nil, decode(r, "b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2")},
[][]byte{{}, decode(r, "b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2")},
tree.GetParkedNodes(nil))

r.NoError(tree.AddLeaf([]byte{2}))
Expand All @@ -452,7 +483,7 @@ func TestTree_GetParkedNodes(t *testing.T) {

r.NoError(tree.AddLeaf([]byte{3}))
r.EqualValues(
[][]byte{nil, nil, decode(r, "7699a4fdd6b8b6908a344f73b8f05c8e1400f7253f544602c442ff5c65504b24")},
[][]byte{{}, {}, decode(r, "7699a4fdd6b8b6908a344f73b8f05c8e1400f7253f544602c442ff5c65504b24")},
tree.GetParkedNodes(nil))
}

Expand All @@ -463,7 +494,7 @@ func TestTree_SetParkedNodes(t *testing.T) {
r.NoError(err)
r.NoError(tree.SetParkedNodes([][]byte{{0}}))
r.NoError(tree.AddLeaf([]byte{1}))
parkedNodes := [][]byte{nil, decode(r, "b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2")}
parkedNodes := [][]byte{{}, decode(r, "b413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2")}
r.EqualValues(parkedNodes, tree.GetParkedNodes(nil))

tree, err = NewTreeBuilder().Build()
Expand All @@ -477,7 +508,7 @@ func TestTree_SetParkedNodes(t *testing.T) {
r.NoError(err)
r.NoError(tree.SetParkedNodes(parkedNodes))
r.NoError(tree.AddLeaf([]byte{3}))
parkedNodes = [][]byte{nil, nil, decode(r, "7699a4fdd6b8b6908a344f73b8f05c8e1400f7253f544602c442ff5c65504b24")}
parkedNodes = [][]byte{{}, {}, decode(r, "7699a4fdd6b8b6908a344f73b8f05c8e1400f7253f544602c442ff5c65504b24")}
r.EqualValues(parkedNodes, tree.GetParkedNodes(nil))
}

Expand Down
2 changes: 1 addition & 1 deletion shared/types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package shared

type HashFunc func(lChild, rChild []byte) []byte
type HashFunc func(buf, lChild, rChild []byte) []byte

// LayerReadWriter is a combined reader-writer. Note that the Seek() method only belongs to the LayerReader interface
// and does not affect the LayerWriter.
Expand Down
2 changes: 1 addition & 1 deletion validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (v *Validator) CalcRoot(stopAtLayer uint) ([]byte, []ParkingSnapshot, error
subTreeSnapshots = nil
}
}
activeNode = v.Hash(lChild, rChild)
activeNode = v.Hash(nil, lChild, rChild)
activePos = activePos.parent()
}
return activeNode, parkingSnapshots, nil
Expand Down