From 3651a544f849eda19919dee1f18d4693f8f6b1c6 Mon Sep 17 00:00:00 2001 From: Noam Nelke Date: Wed, 27 Mar 2019 16:54:46 +0200 Subject: [PATCH] extract common cache with usability improvements (#6) --- cache/cache.go | 142 ++++++++++ cache/cache_test.go | 79 ++++++ cache/cachingpolicies.go | 19 ++ cache/cachingpolicies_test.go | 103 ++++++++ cache/layerfactories.go | 13 + cache/slice.go | 35 +++ iterators.go | 37 ++- merkle.go | 87 +----- merkle_test.go | 157 ++++------- utils.go => position.go | 26 +- position_test.go | 22 ++ proving.go | 181 ++++++++----- proving_test.go | 479 ++++++++++++++++++++++++---------- treebuilder.go | 63 +++++ treecache.go | 137 ---------- treecache_test.go | 212 --------------- validation.go | 16 +- validation_test.go | 34 ++- 18 files changed, 1082 insertions(+), 760 deletions(-) create mode 100644 cache/cache.go create mode 100644 cache/cache_test.go create mode 100644 cache/cachingpolicies.go create mode 100644 cache/cachingpolicies_test.go create mode 100644 cache/layerfactories.go create mode 100644 cache/slice.go rename utils.go => position.go (62%) create mode 100644 position_test.go create mode 100644 treebuilder.go delete mode 100644 treecache.go delete mode 100644 treecache_test.go diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..01248cc --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,142 @@ +package cache + +import ( + "errors" + "fmt" + "math" +) + +const NodeSize = 32 + +type Writer struct { + *cache +} + +func NewWriter(shouldCacheLayer CachingPolicy, generateLayer LayerFactory) *Writer { + return &Writer{ + cache: &cache{ + layers: make(map[uint]LayerReadWriter), + generateLayer: generateLayer, + shouldCacheLayer: shouldCacheLayer, + }, + } +} + +func (c *Writer) SetLayer(layerHeight uint, rw LayerReadWriter) { + c.layers[layerHeight] = rw +} + +func (c *Writer) GetLayerWriter(layerHeight uint) LayerWriter { + layerReadWriter, found := c.layers[layerHeight] + if !found && c.shouldCacheLayer(layerHeight) { + layerReadWriter = c.generateLayer(layerHeight) + c.layers[layerHeight] = layerReadWriter + } + return layerReadWriter +} + +func (c *Writer) SetHash(hashFunc func(lChild, rChild []byte) []byte) { + c.hash = hashFunc +} + +func (c *Writer) GetReader() (*Reader, error) { + err := c.validateStructure() + if err != nil { + return nil, err + } + return &Reader{c.cache}, nil +} + +type Reader struct { + *cache +} + +func (c *Reader) GetLayerReader(layerHeight uint) LayerReader { + return c.layers[layerHeight] +} + +func (c *Reader) GetHashFunc() func(lChild, rChild []byte) []byte { + return c.hash +} + +type cache struct { + layers map[uint]LayerReadWriter + hash func(lChild, rChild []byte) []byte + shouldCacheLayer CachingPolicy + generateLayer LayerFactory +} + +func (c *cache) validateStructure() error { + // Verify we got the base layer. + if _, found := c.layers[0]; !found { + return errors.New("reader for base layer must be included") + } + width := c.layers[0].Width() + if width == 0 { + return errors.New("base layer cannot be empty") + } + height := RootHeightFromWidth(width) + for i := uint(0); i < height; i++ { + if _, found := c.layers[i]; found && c.layers[i].Width() != width { + return fmt.Errorf("reader at layer %d has width %d instead of %d", i, c.layers[i].Width(), width) + } + width >>= 1 + } + return nil +} + +type CachingPolicy func(layerHeight uint) (shouldCacheLayer bool) + +type LayerFactory func(layerHeight uint) LayerReadWriter + +// LayerReadWriter is a combined reader-writer. Note that the Seek() method only belongs to the LayerReader interface +// and does not affect the LayerWriter. +type LayerReadWriter interface { + LayerReader + LayerWriter +} + +type LayerReader interface { + Seek(index uint64) error + ReadNext() ([]byte, error) + Width() uint64 +} + +type LayerWriter interface { + Append(p []byte) (n int, err error) +} + +func RootHeightFromWidth(width uint64) uint { + return uint(math.Ceil(math.Log2(float64(width)))) +} + +//func (c *cache) Print(bottom, top int) { +// for i := top; i >= bottom; i-- { +// print("| ") +// sliceReadWriter, ok := c.layers[uint(i)].(*SliceReadWriter) +// if !ok { +// println("-- layer is not a SliceReadWriter --") +// continue +// } +// for _, n := range sliceReadWriter.slice { +// printSpaces(numSpaces(i)) +// fmt.Print(hex.EncodeToString(n[:2])) +// printSpaces(numSpaces(i)) +// } +// println(" |") +// } +//} +// +//func numSpaces(n int) int { +// res := 1 +// for i := 0; i < n; i++ { +// res += 3 * (1 << uint(i)) +// } +// return res +//} +// +//func printSpaces(n int) { +// for i := 0; i < n; i++ { +// print(" ") +// } +//} diff --git a/cache/cache_test.go b/cache/cache_test.go new file mode 100644 index 0000000..c2cac5a --- /dev/null +++ b/cache/cache_test.go @@ -0,0 +1,79 @@ +package cache + +import ( + "errors" + "github.com/stretchr/testify/require" + "testing" +) + +var someError = errors.New("some error") + +type widthReader struct{ width uint64 } + +func (r widthReader) Seek(index uint64) error { return nil } +func (r widthReader) ReadNext() ([]byte, error) { return nil, someError } +func (r widthReader) Width() uint64 { return r.width } +func (r widthReader) Append(p []byte) (n int, err error) { panic("implement me") } + +func TestCache_ValidateStructure(t *testing.T) { + r := require.New(t) + var readers map[uint]LayerReadWriter + + treeCache := &cache{layers: readers} + err := treeCache.validateStructure() + + r.Error(err) + r.Equal("reader for base layer must be included", err.Error()) +} + +func TestCache_ValidateStructure2(t *testing.T) { + r := require.New(t) + readers := make(map[uint]LayerReadWriter) + + treeCache := &cache{layers: readers} + err := treeCache.validateStructure() + + r.Error(err) + r.Equal("reader for base layer must be included", err.Error()) +} + +func TestCache_ValidateStructureSuccess(t *testing.T) { + r := require.New(t) + readers := make(map[uint]LayerReadWriter) + + readers[0] = widthReader{width: 4} + readers[1] = widthReader{width: 2} + readers[2] = widthReader{width: 1} + treeCache := &cache{layers: readers} + err := treeCache.validateStructure() + + r.NoError(err) +} + +func TestCache_ValidateStructureFail(t *testing.T) { + r := require.New(t) + readers := make(map[uint]LayerReadWriter) + + readers[0] = widthReader{width: 3} + readers[1] = widthReader{width: 2} + readers[2] = widthReader{width: 1} + treeCache := &cache{layers: readers} + err := treeCache.validateStructure() + + r.Error(err) + r.Equal("reader at layer 1 has width 2 instead of 1", err.Error()) +} + +func TestCache_ValidateStructureFail2(t *testing.T) { + r := require.New(t) + readers := make(map[uint]LayerReadWriter) + + readers[0] = widthReader{width: 4} + readers[1] = widthReader{width: 1} + readers[2] = widthReader{width: 1} + treeCache := &cache{layers: readers} + err := treeCache.validateStructure() + + r.Error(err) + r.Equal("reader at layer 1 has width 1 instead of 2", err.Error()) +} diff --git a/cache/cachingpolicies.go b/cache/cachingpolicies.go new file mode 100644 index 0000000..e96b7e5 --- /dev/null +++ b/cache/cachingpolicies.go @@ -0,0 +1,19 @@ +package cache + +func MinHeightPolicy(minHeight uint) CachingPolicy { + return func(layerHeight uint) (shouldCacheLayer bool) { + return layerHeight >= minHeight + } +} + +func SpecificLayersPolicy(layersToCache map[uint]bool) CachingPolicy { + return func(layerHeight uint) (shouldCacheLayer bool) { + return layersToCache[layerHeight] + } +} + +func Combine(first, second CachingPolicy) CachingPolicy { + return func(layerHeight uint) (shouldCacheLayer bool) { + return first(layerHeight) || second(layerHeight) + } +} diff --git a/cache/cachingpolicies_test.go b/cache/cachingpolicies_test.go new file mode 100644 index 0000000..70397f1 --- /dev/null +++ b/cache/cachingpolicies_test.go @@ -0,0 +1,103 @@ +package cache + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestMakeMemoryReadWriterFactory(t *testing.T) { + r := require.New(t) + cacheWriter := NewWriter(MinHeightPolicy(2), MakeSliceReadWriterFactory()) + cacheWriter.SetLayer(0, widthReader{width: 1}) + + cacheReader, err := cacheWriter.GetReader() + r.NoError(err) + + reader := cacheReader.GetLayerReader(1) + r.Nil(reader) + reader = cacheReader.GetLayerReader(2) + r.Nil(reader) + reader = cacheReader.GetLayerReader(3) + r.Nil(reader) + + writer := cacheWriter.GetLayerWriter(1) + r.Nil(writer) + writer = cacheWriter.GetLayerWriter(2) + r.NotNil(writer) + writer = cacheWriter.GetLayerWriter(3) + r.NotNil(writer) + + cacheReader, err = cacheWriter.GetReader() + r.NoError(err) + + reader = cacheReader.GetLayerReader(1) + r.Nil(reader) + reader = cacheReader.GetLayerReader(2) + r.NotNil(reader) + reader = cacheReader.GetLayerReader(3) + r.NotNil(reader) +} + +func TestMakeMemoryReadWriterFactoryForLayers(t *testing.T) { + r := require.New(t) + cacheWriter := NewWriter(SpecificLayersPolicy(map[uint]bool{1: true, 3: true}), MakeSliceReadWriterFactory()) + cacheWriter.SetLayer(0, widthReader{width: 1}) + + cacheReader, err := cacheWriter.GetReader() + r.NoError(err) + + reader := cacheReader.GetLayerReader(1) + r.Nil(reader) + reader = cacheReader.GetLayerReader(2) + r.Nil(reader) + reader = cacheReader.GetLayerReader(3) + r.Nil(reader) + + writer := cacheWriter.GetLayerWriter(1) + r.NotNil(writer) + writer = cacheWriter.GetLayerWriter(2) + r.Nil(writer) + writer = cacheWriter.GetLayerWriter(3) + r.NotNil(writer) + + cacheReader, err = cacheWriter.GetReader() + r.NoError(err) + + reader = cacheReader.GetLayerReader(1) + r.NotNil(reader) + reader = cacheReader.GetLayerReader(2) + r.Nil(reader) + reader = cacheReader.GetLayerReader(3) + r.NotNil(reader) +} + +func TestMakeSpecificLayerFactory(t *testing.T) { + r := require.New(t) + readWriter := &SliceReadWriter{} + cacheWriter := NewWriter( + SpecificLayersPolicy(map[uint]bool{1: true}), + MakeSpecificLayersFactory(map[uint]LayerReadWriter{1: readWriter}), + ) + cacheWriter.SetLayer(0, widthReader{width: 1}) + + cacheReader, err := cacheWriter.GetReader() + r.NoError(err) + + reader := cacheReader.GetLayerReader(1) + r.Nil(reader) + reader = cacheReader.GetLayerReader(2) + r.Nil(reader) + + writer := cacheWriter.GetLayerWriter(1) + r.Equal(readWriter, writer) + writer = cacheWriter.GetLayerWriter(2) + r.Nil(writer) + + cacheReader, err = cacheWriter.GetReader() + r.NoError(err) + + reader = cacheReader.GetLayerReader(1) + r.Equal(readWriter, reader) + reader = cacheReader.GetLayerReader(2) + r.Nil(reader) +} diff --git a/cache/layerfactories.go b/cache/layerfactories.go new file mode 100644 index 0000000..10c6edc --- /dev/null +++ b/cache/layerfactories.go @@ -0,0 +1,13 @@ +package cache + +func MakeSliceReadWriterFactory() LayerFactory { + return func(layerHeight uint) LayerReadWriter { + return &SliceReadWriter{} + } +} + +func MakeSpecificLayersFactory(readWriters map[uint]LayerReadWriter) LayerFactory { + return func(layerHeight uint) LayerReadWriter { + return readWriters[layerHeight] + } +} diff --git a/cache/slice.go b/cache/slice.go new file mode 100644 index 0000000..0c32c1b --- /dev/null +++ b/cache/slice.go @@ -0,0 +1,35 @@ +package cache + +import "io" + +type SliceReadWriter struct { + slice [][]byte + position uint64 +} + +func (s *SliceReadWriter) Width() uint64 { + return uint64(len(s.slice)) +} + +func (s *SliceReadWriter) Seek(index uint64) error { + if index >= uint64(len(s.slice)) { + return io.EOF + } + s.position = index + return nil +} + +func (s *SliceReadWriter) ReadNext() ([]byte, error) { + if s.position >= uint64(len(s.slice)) { + return nil, io.EOF + } + value := make([]byte, NodeSize) + copy(value, s.slice[s.position]) + s.position++ + return value, nil +} + +func (s *SliceReadWriter) Append(p []byte) (n int, err error) { + s.slice = append(s.slice, p) + return len(p), nil +} diff --git a/iterators.go b/iterators.go index e62fa4c..7ceccb1 100644 --- a/iterators.go +++ b/iterators.go @@ -1,13 +1,42 @@ package merkle -import "errors" +import ( + "errors" + "sort" +) var noMoreItems = errors.New("no more items") +type set map[uint64]bool + +func (s set) asSortedSlice() []uint64 { + var ret []uint64 + for key, value := range s { + if value { + ret = append(ret, key) + } + } + sort.Slice(ret, func(i, j int) bool { return ret[i] < ret[j] }) + return ret +} + +func setOf(members ...uint64) set { + ret := make(set) + for _, member := range members { + ret[member] = true + } + return ret +} + type positionsIterator struct { s []uint64 } +func newPositionsIterator(positions set) *positionsIterator { + s := positions.asSortedSlice() + return &positionsIterator{s: s} +} + func (it *positionsIterator) peek() (pos position, found bool) { if len(it.s) == 0 { return position{}, false @@ -17,10 +46,10 @@ func (it *positionsIterator) peek() (pos position, found bool) { } // batchPop returns the indices of all positions up to endIndex. -func (it *positionsIterator) batchPop(endIndex uint64) []uint64 { - var res []uint64 +func (it *positionsIterator) batchPop(endIndex uint64) set { + res := make(set) for len(it.s) > 0 && it.s[0] < endIndex { - res = append(res, it.s[0]) + res[it.s[0]] = true it.s = it.s[1:] } return res diff --git a/merkle.go b/merkle.go index 6073588..7b78129 100644 --- a/merkle.go +++ b/merkle.go @@ -2,11 +2,14 @@ package merkle import ( "errors" + "github.com/spacemeshos/merkle-tree/cache" "github.com/spacemeshos/sha256-simd" - "io" - "sort" ) +const NodeSize = cache.NodeSize + +type HashFunc func(lChild, rChild []byte) []byte + var emptyNode node // PaddingValue is used for padding unbalanced trees. This value should not be permitted at the leaf layer to @@ -31,17 +34,17 @@ type layer struct { height uint parking node // This is where we park a node until its sibling is processed and we can calculate their parent. next *layer - cache io.Writer + cache cache.LayerWriter } // ensureNextLayerExists creates the next layer if it doesn't exist. -func (l *layer) ensureNextLayerExists(cache map[uint]io.Writer) { +func (l *layer) ensureNextLayerExists(cacheWriter *cache.Writer) { if l.next == nil { - l.next = newLayer(l.height+1, cache[(l.height + 1)]) + l.next = newLayer(l.height+1, cacheWriter.GetLayerWriter(l.height+1)) } } -func newLayer(height uint, cache io.Writer) *layer { +func newLayer(height uint, cache cache.LayerWriter) *layer { return &layer{height: height, cache: cache} } @@ -50,10 +53,8 @@ type sparseBoolStack struct { currentIndex uint64 } -func newSparseBoolStack(trueIndices []uint64) *sparseBoolStack { - sorted := make([]uint64, len(trueIndices)) - copy(sorted, trueIndices) - sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] }) +func newSparseBoolStack(trueIndices set) *sparseBoolStack { + sorted := trueIndices.asSortedSlice() return &sparseBoolStack{sortedTrueIndices: sorted} } @@ -69,8 +70,6 @@ func (s *sparseBoolStack) Pop() bool { return ret } -type HashFunc func(lChild, rChild []byte) []byte - // Tree calculates a merkle tree root. It can optionally calculate a proof, or partial tree, for leaves defined in // advance. Leaves are appended to the tree incrementally. It uses O(log(n)) memory to calculate the root and // O(k*log(n)) (k being the number of leaves to prove) memory to calculate proofs. @@ -81,7 +80,7 @@ type Tree struct { hash HashFunc proof [][]byte leavesToProve *sparseBoolStack - cache map[uint]io.Writer + cacheWriter *cache.Writer minHeight uint } @@ -100,7 +99,7 @@ func (t *Tree) AddLeaf(value []byte) error { for { // Writing the node to its layer cache, if applicable. if l.cache != nil { - _, err := l.cache.Write(n.value) + _, err := l.cache.Append(n.value) if err != nil { lastCachingError = errors.New("error while caching: " + err.Error()) } @@ -129,20 +128,13 @@ func (t *Tree) AddLeaf(value []byte) error { l.parking.value = nil n = parent - l.ensureNextLayerExists(t.cache) + l.ensureNextLayerExists(t.cacheWriter) l = l.next } } return lastCachingError } -func nextOrEmptyLayer(l *layer) *layer { - if l.next != nil { - return l.next - } - return &layer{height: l.height + 1} -} - // Root returns the root of the tree. // If the tree is unbalanced (num. of leaves is not a power of 2) it will perform padding on-the-fly. func (t *Tree) Root() []byte { @@ -230,57 +222,6 @@ func (t *Tree) calcParent(lChild, rChild node) node { } } -type TreeBuilder struct { - hash HashFunc - leavesToProves []uint64 - cache map[uint]io.Writer - minHeight uint -} - -func NewTreeBuilder(hash HashFunc) TreeBuilder { - return TreeBuilder{hash: hash} -} - -func (tb TreeBuilder) Build() *Tree { - if tb.cache == nil { - tb.cache = make(map[uint]io.Writer) - } - return &Tree{ - baseLayer: newLayer(0, tb.cache[0]), - hash: tb.hash, - leavesToProve: newSparseBoolStack(tb.leavesToProves), - cache: tb.cache, - minHeight: tb.minHeight, - } -} - -func (tb TreeBuilder) WithLeavesToProve(leavesToProves []uint64) TreeBuilder { - tb.leavesToProves = leavesToProves - return tb -} - -func (tb TreeBuilder) WithCache(cache map[uint]io.Writer) TreeBuilder { - tb.cache = cache - return tb -} - -func (tb TreeBuilder) WithMinHeight(minHeight uint) TreeBuilder { - tb.minHeight = minHeight - return tb -} - -func NewTree(hash HashFunc) *Tree { - return NewTreeBuilder(hash).Build() -} - -func NewProvingTree(hash HashFunc, leavesToProves []uint64) *Tree { - return NewTreeBuilder(hash).WithLeavesToProve(leavesToProves).Build() -} - -func NewCachingTree(hash HashFunc, cache map[uint]io.Writer) *Tree { - return NewTreeBuilder(hash).WithCache(cache).Build() -} - func GetSha256Parent(lChild, rChild []byte) []byte { res := sha256.Sum256(append(lChild, rChild...)) return res[:] diff --git a/merkle_test.go b/merkle_test.go index c143e56..7931a11 100644 --- a/merkle_test.go +++ b/merkle_test.go @@ -3,9 +3,8 @@ package merkle import ( "encoding/binary" "encoding/hex" - "fmt" + "github.com/spacemeshos/merkle-tree/cache" "github.com/stretchr/testify/require" - "io" "testing" "time" ) @@ -25,7 +24,7 @@ import ( func TestNewTree(t *testing.T) { r := require.New(t) - tree := NewTree(GetSha256Parent) + tree := NewTree() for i := uint64(0); i < 8; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) r.NoError(err) @@ -47,7 +46,7 @@ func concatLeaves(lChild, rChild []byte) []byte { func TestNewTreeWithMinHeightEqual(t *testing.T) { r := require.New(t) - tree := NewTreeBuilder(concatLeaves).WithMinHeight(3).Build() + tree := NewTreeBuilder().WithHashFunc(concatLeaves).WithMinHeight(3).Build() for i := uint64(0); i < 8; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) r.NoError(err) @@ -59,7 +58,7 @@ func TestNewTreeWithMinHeightEqual(t *testing.T) { func TestNewTreeWithMinHeightGreater(t *testing.T) { r := require.New(t) - tree := NewTreeBuilder(concatLeaves).WithMinHeight(4).Build() + tree := NewTreeBuilder().WithHashFunc(concatLeaves).WithMinHeight(4).Build() for i := uint64(0); i < 8; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) r.NoError(err) @@ -72,7 +71,7 @@ func TestNewTreeWithMinHeightGreater(t *testing.T) { func TestNewTreeWithMinHeightGreater2(t *testing.T) { r := require.New(t) - tree := NewTreeBuilder(concatLeaves).WithMinHeight(5).Build() + tree := NewTreeBuilder().WithHashFunc(concatLeaves).WithMinHeight(5).Build() for i := uint64(0); i < 8; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) r.NoError(err) @@ -85,7 +84,7 @@ func TestNewTreeWithMinHeightGreater2(t *testing.T) { func TestNewTreeUnbalanced(t *testing.T) { r := require.New(t) - tree := NewTree(GetSha256Parent) + tree := NewTree() for i := uint64(0); i < 9; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) r.NoError(err) @@ -97,7 +96,7 @@ func TestNewTreeUnbalanced(t *testing.T) { func TestNewTreeUnbalanced2(t *testing.T) { r := require.New(t) - tree := NewTree(GetSha256Parent) + tree := NewTree() for i := uint64(0); i < 10; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) r.NoError(err) @@ -109,7 +108,7 @@ func TestNewTreeUnbalanced2(t *testing.T) { func TestNewTreeUnbalanced3(t *testing.T) { r := require.New(t) - tree := NewTree(GetSha256Parent) + tree := NewTree() for i := uint64(0); i < 15; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) r.NoError(err) @@ -122,15 +121,13 @@ func TestNewTreeUnbalanced3(t *testing.T) { func TestNewTreeUnbalancedProof(t *testing.T) { r := require.New(t) - leavesToProve := []uint64{0, 4, 7} + leavesToProve := setOf(0, 4, 7) - sliceWriters := make(map[uint]*sliceReadWriter) - for i := uint(0); i < 5; i++ { - sliceWriters[i] = &sliceReadWriter{} - } - tree := NewTreeBuilder(GetSha256Parent). + cacheWriter := cache.NewWriter(cache.MinHeightPolicy(0), cache.MakeSliceReadWriterFactory()) + + tree := NewTreeBuilder(). WithLeavesToProve(leavesToProve). - WithCache(WritersFromSliceReadWriters(sliceWriters)). + WithCacheWriter(cacheWriter). Build() for i := uint64(0); i < 10; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) @@ -140,11 +137,17 @@ func TestNewTreeUnbalancedProof(t *testing.T) { root := tree.Root() r.Equal(expectedRoot, root) - r.Len(sliceWriters[0].slice, 10) - r.Len(sliceWriters[1].slice, 5) - r.Len(sliceWriters[2].slice, 2) - r.Len(sliceWriters[3].slice, 1) - r.NotEqual(sliceWriters[3].slice[0], expectedRoot) + cacheReader, err := cacheWriter.GetReader() + r.NoError(err) + + r.Equal(uint64(10), cacheReader.GetLayerReader(0).Width()) + r.Equal(uint64(5), cacheReader.GetLayerReader(1).Width()) + r.Equal(uint64(2), cacheReader.GetLayerReader(2).Width()) + r.Equal(uint64(1), cacheReader.GetLayerReader(3).Width()) + + cacheRoot, err := cacheReader.GetLayerReader(3).ReadNext() + r.NoError(err) + r.NotEqual(cacheRoot, expectedRoot) expectedProof := make([][]byte, 5) expectedProof[0], _ = NewNodeFromHex("0100000000000000000000000000000000000000000000000000000000000000") @@ -160,7 +163,7 @@ func TestNewTreeUnbalancedProof(t *testing.T) { func BenchmarkNewTree(b *testing.B) { var size uint64 = 1 << 28 - tree := NewTree(GetSha256Parent) + tree := NewTree() for i := uint64(0); i < size; i++ { _ = tree.AddLeaf(NewNodeFromUint64(i)) } @@ -176,7 +179,7 @@ func BenchmarkNewTree(b *testing.B) { func BenchmarkNewTreeSmall(b *testing.B) { var size uint64 = 1 << 23 start := time.Now() - tree := NewTree(GetSha256Parent) + tree := NewTree() for i := uint64(0); i < size; i++ { _ = tree.AddLeaf(NewNodeFromUint64(i)) } @@ -188,10 +191,7 @@ func BenchmarkNewTreeSmall(b *testing.B) { func BenchmarkNewTreeNoHashing(b *testing.B) { var size uint64 = 1 << 28 - tree := NewTree(func(leftChild, rightChild []byte) []byte { - arr := [32]byte{} - return arr[:] - }) + tree := NewTree() for i := uint64(0); i < size; i++ { _ = tree.AddLeaf(NewNodeFromUint64(i)) } @@ -216,7 +216,7 @@ func BenchmarkNewTreeNoHashing(b *testing.B) { func TestNewProvingTree(t *testing.T) { r := require.New(t) - tree := NewProvingTree(GetSha256Parent, []uint64{4}) + tree := NewProvingTree(setOf(4)) for i := uint64(0); i < 8; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) r.NoError(err) @@ -243,7 +243,7 @@ func TestNewProvingTree(t *testing.T) { func TestNewProvingTreeMultiProof(t *testing.T) { r := require.New(t) - tree := NewProvingTree(GetSha256Parent, []uint64{1, 4}) + tree := NewProvingTree(setOf(1, 4)) for i := uint64(0); i < 8; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) r.NoError(err) @@ -271,7 +271,7 @@ func TestNewProvingTreeMultiProof(t *testing.T) { func TestNewProvingTreeMultiProof2(t *testing.T) { r := require.New(t) - tree := NewProvingTree(GetSha256Parent, []uint64{0, 1, 4}) + tree := NewProvingTree(setOf(0, 1, 4)) for i := uint64(0); i < 8; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) r.NoError(err) @@ -308,45 +308,10 @@ func NewNodeFromHex(s string) ([]byte, error) { // Caching tests: -type sliceReadWriter struct { - slice [][]byte - position uint64 -} - -func (s *sliceReadWriter) Width() uint64 { - return uint64(len(s.slice)) -} - -func (s *sliceReadWriter) Seek(index uint64) error { - if index >= uint64(len(s.slice)) { - return io.EOF - } - s.position = index - return nil -} - -func (s *sliceReadWriter) ReadNext() ([]byte, error) { - if s.position >= uint64(len(s.slice)) { - return nil, io.EOF - } - value := make([]byte, NodeSize) - copy(value, s.slice[s.position]) - s.position++ - return value, nil -} - -func (s *sliceReadWriter) Write(p []byte) (n int, err error) { - s.slice = append(s.slice, p) - return len(p), nil -} - func TestNewCachingTree(t *testing.T) { r := require.New(t) - sliceWriters := make(map[uint]*sliceReadWriter) - for i := uint(0); i < 4; i++ { - sliceWriters[i] = &sliceReadWriter{} - } - tree := NewCachingTree(GetSha256Parent, WritersFromSliceReadWriters(sliceWriters)) + cacheWriter := cache.NewWriter(cache.MinHeightPolicy(0), cache.MakeSliceReadWriterFactory()) + tree := NewCachingTree(cacheWriter) for i := uint64(0); i < 8; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) r.NoError(err) @@ -355,49 +320,25 @@ func TestNewCachingTree(t *testing.T) { root := tree.Root() r.Equal(expectedRoot, root) - r.Len(sliceWriters[0].slice, 8) - r.Len(sliceWriters[1].slice, 4) - r.Len(sliceWriters[2].slice, 2) - r.Len(sliceWriters[3].slice, 1) - r.Equal(sliceWriters[3].slice[0], expectedRoot) + cacheReader, err := cacheWriter.GetReader() + r.NoError(err) - // printCache(0, 3, sliceWriters) -} + r.Equal(uint64(8), cacheReader.GetLayerReader(0).Width()) + r.Equal(uint64(4), cacheReader.GetLayerReader(1).Width()) + r.Equal(uint64(2), cacheReader.GetLayerReader(2).Width()) + r.Equal(uint64(1), cacheReader.GetLayerReader(3).Width()) + cacheRoot, err := cacheReader.GetLayerReader(3).ReadNext() + r.NoError(err) + r.Equal(cacheRoot, expectedRoot) -func printCache(bottom, top int, sliceWriters map[uint]*sliceReadWriter) { - for i := top; i >= bottom; i-- { - print("| ") - for _, n := range sliceWriters[uint(i)].slice { - printSpaces(numSpaces(i)) - fmt.Print(hex.EncodeToString(n[:2])) - printSpaces(numSpaces(i)) - } - println(" |") - } -} - -func numSpaces(n int) int { - res := 1 - for i := 0; i < n; i++ { - res += 3 * (1 << uint(i)) - } - return res -} - -func printSpaces(n int) { - for i := 0; i < n; i++ { - print(" ") - } + //cacheWriter.Print(0 , 3) } func BenchmarkNewCachingTreeSmall(b *testing.B) { var size uint64 = 1 << 23 - cache := make(map[uint]io.Writer) - for i := uint(7); i < 23; i++ { - cache[i] = &sliceReadWriter{} - } + cacheWriter := cache.NewWriter(cache.MinHeightPolicy(7), cache.MakeSliceReadWriterFactory()) start := time.Now() - tree := NewCachingTree(GetSha256Parent, cache) + tree := NewCachingTree(cacheWriter) for i := uint64(0); i < size; i++ { _ = tree.AddLeaf(NewNodeFromUint64(i)) } @@ -410,19 +351,19 @@ func BenchmarkNewCachingTreeSmall(b *testing.B) { func TestSparseBoolStack(t *testing.T) { r := require.New(t) - allFalse := newSparseBoolStack([]uint64{}) + allFalse := newSparseBoolStack(make(set)) for i := 0; i < 1000; i++ { r.False(allFalse.Pop()) } - allTrue := newSparseBoolStack([]uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + allTrue := newSparseBoolStack(setOf(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)) for i := 0; i < 10; i++ { r.True(allTrue.Pop()) } - rounds := make([]uint64, 0, 100) - for i := 0; i < 1000; i += 10 { - rounds = append(rounds, uint64(i)) + rounds := make(set) + for i := uint64(0); i < 1000; i += 10 { + rounds[i] = true } roundsTrue := newSparseBoolStack(rounds) for i := 0; i < 1000; i++ { diff --git a/utils.go b/position.go similarity index 62% rename from utils.go rename to position.go index f6e49a5..b79c589 100644 --- a/utils.go +++ b/position.go @@ -1,9 +1,6 @@ package merkle -import ( - "fmt" - "math" -) +import "fmt" type position struct { index uint64 @@ -46,6 +43,23 @@ func (p position) leftChild() position { } } -func rootHeightFromWidth(width uint64) uint { - return uint(math.Ceil(math.Log2(float64(width)))) +type positionsStack struct { + positions []position +} + +func (s *positionsStack) Push(v position) { + s.positions = append(s.positions, v) +} + +// Check the top of the stack for equality and pop the element if it's equal. +func (s *positionsStack) PopIfEqual(p position) bool { + l := len(s.positions) + if l == 0 { + return false + } + if s.positions[l-1] == p { + s.positions = s.positions[:l-1] + return true + } + return false } diff --git a/position_test.go b/position_test.go new file mode 100644 index 0000000..ded60b4 --- /dev/null +++ b/position_test.go @@ -0,0 +1,22 @@ +package merkle + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestPosition_isAncestorOf(t *testing.T) { + lower := position{ + index: 0, + height: 0, + } + + higher := position{ + index: 0, + height: 1, + } + + isAncestor := lower.isAncestorOf(higher) + + require.False(t, isAncestor) +} diff --git a/proving.go b/proving.go index b8cbe8f..642808f 100644 --- a/proving.go +++ b/proving.go @@ -2,33 +2,20 @@ package merkle import ( "errors" + "github.com/spacemeshos/merkle-tree/cache" "io" ) -const NodeSize = 32 - -type NodeReader interface { - Seek(index uint64) error - ReadNext() ([]byte, error) - Width() uint64 -} +var ErrMissingValueAtBaseLayer = errors.New("reader for base layer must be included") func GenerateProof( - provenLeafIndices []uint64, - readers map[uint]NodeReader, - hash HashFunc, -) ([][]byte, error) { - - var proof [][]byte + provenLeafIndices map[uint64]bool, + treeCache *cache.Reader, +) (sortedProvenLeafIndices []uint64, provenLeaves, proofNodes [][]byte, err error) { - provenLeafIndexIt := &positionsIterator{s: provenLeafIndices} + provenLeafIndexIt := newPositionsIterator(provenLeafIndices) skipPositions := &positionsStack{} - rootHeight := rootHeightFromWidth(readers[0].Width()) - - cache, err := NewTreeCache(readers, hash) - if err != nil { - return nil, err - } + rootHeight := cache.RootHeightFromWidth(treeCache.GetLayerReader(0).Width()) for { // Process proven leaves: @@ -40,18 +27,19 @@ func GenerateProof( } // Get indices for the bottom left corner of the subtree and its root, as well as the bottom layer's width. - currentPos, subtreeStart, width := cache.subtreeDefinition(nextProvenLeafPos) + currentPos, subtreeStart, width := subtreeDefinition(treeCache, nextProvenLeafPos) // Prepare list of leaves to prove in the subtree. leavesToProve := provenLeafIndexIt.batchPop(subtreeStart.index + width) - additionalProof, err := calcSubtreeProof(cache, hash, leavesToProve, subtreeStart, width) + additionalProof, additionalLeaves, err := calcSubtreeProof(treeCache, leavesToProve, subtreeStart, width) if err != nil { - return nil, err + return nil, nil, nil, err } - proof = append(proof, additionalProof...) + proofNodes = append(proofNodes, additionalProof...) + provenLeaves = append(provenLeaves, additionalLeaves...) - for ; currentPos.height < rootHeight; currentPos = currentPos.parent() { // Traverse cache: + for ; currentPos.height < rootHeight; currentPos = currentPos.parent() { // Traverse treeCache: // Check if we're revisiting a node. If we've descended into a subtree and just got back, we shouldn't add // the sibling to the proof and instead move on to the parent. @@ -67,49 +55,50 @@ func GenerateProof( break } - currentVal, err := cache.GetNode(currentPos.sibling()) + currentVal, err := GetNode(treeCache, currentPos.sibling()) if err != nil { - return nil, err + return nil, nil, nil, err } - proof = append(proof, currentVal) + proofNodes = append(proofNodes, currentVal) } } - return proof, nil + return set(provenLeafIndices).asSortedSlice(), provenLeaves, proofNodes, nil } -func calcSubtreeProof(cache *TreeCache, hash HashFunc, leavesToProve []uint64, subtreeStart position, width uint64) ( - [][]byte, error) { +func calcSubtreeProof(c *cache.Reader, leavesToProve set, subtreeStart position, width uint64) ( + additionalProof, additionalLeaves [][]byte, err error) { // By subtracting subtreeStart.index we get the index relative to the subtree. - relativeLeavesToProve := make([]uint64, len(leavesToProve)) - for i, leafIndex := range leavesToProve { - relativeLeavesToProve[i] = leafIndex - subtreeStart.index + relativeLeavesToProve := make(set) + for leafIndex, prove := range leavesToProve { + relativeLeavesToProve[leafIndex-subtreeStart.index] = prove } // Prepare leaf reader to read subtree leaves. - reader := cache.LeafReader() - err := reader.Seek(subtreeStart.index) + reader := c.GetLayerReader(0) + err = reader.Seek(subtreeStart.index) if err != nil { - return nil, errors.New("while preparing to traverse subtree: " + err.Error()) + return nil, nil, errors.New("while preparing to traverse subtree: " + err.Error()) } - _, additionalProof, err := traverseSubtree(reader, width, hash, relativeLeavesToProve, nil) + _, additionalProof, additionalLeaves, err = traverseSubtree(reader, width, c.GetHashFunc(), relativeLeavesToProve, nil) if err != nil { - return nil, errors.New("while traversing subtree: " + err.Error()) + return nil, nil, errors.New("while traversing subtree: " + err.Error()) } - return additionalProof, err + return additionalProof, additionalLeaves, err } -func traverseSubtree(leafReader NodeReader, width uint64, hash HashFunc, leavesToProve []uint64, - externalPadding []byte) (root []byte, proof [][]byte, err error) { +func traverseSubtree(leafReader cache.LayerReader, width uint64, hash HashFunc, leavesToProve set, + externalPadding []byte) (root []byte, proof, provenLeaves [][]byte, err error) { shouldUseExternalPadding := externalPadding != nil - t := NewTreeBuilder(hash). + t := NewTreeBuilder(). + WithHashFunc(hash). WithLeavesToProve(leavesToProve). - WithMinHeight(rootHeightFromWidth(width)). // This ensures the correct size tree, even if padding is needed. + WithMinHeight(cache.RootHeightFromWidth(width)). // This ensures the correct size tree, even if padding is needed. Build() for i := uint64(0); i < width; i++ { leaf, err := leafReader.ReadNext() @@ -121,34 +110,106 @@ func traverseSubtree(leafReader NodeReader, width uint64, hash HashFunc, leavesT leaf = externalPadding shouldUseExternalPadding = false } else if err != nil { - return nil, nil, errors.New("while reading a leaf: " + err.Error()) + return nil, nil, nil, errors.New("while reading a leaf: " + err.Error()) } err = t.AddLeaf(leaf) if err != nil { - return nil, nil, errors.New("while adding a leaf: " + err.Error()) + return nil, nil, nil, errors.New("while adding a leaf: " + err.Error()) + } + if leavesToProve[i] { + provenLeaves = append(provenLeaves, leaf) } } root, proof = t.RootAndProof() - return root, proof, nil + return root, proof, provenLeaves, nil } -type positionsStack struct { - positions []position +// GetNode reads the node at the requested position from the cache or calculates it if not available. +func GetNode(c *cache.Reader, nodePos position) ([]byte, error) { + // Get the cache reader for the requested node's layer. + reader := c.GetLayerReader(nodePos.height) + + // If the cache wasn't found, we calculate the minimal subtree that will get us the required node. + if reader == nil { + return calcNode(c, nodePos) + } + + err := reader.Seek(nodePos.index) + if err == io.EOF { + return calcNode(c, nodePos) + } + if err != nil { + return nil, errors.New("while seeking to position " + nodePos.String() + " in cache: " + err.Error()) + } + currentVal, err := reader.ReadNext() + if err != nil { + return nil, errors.New("while reading from cache: " + err.Error()) + } + return currentVal, nil } -func (s *positionsStack) Push(v position) { - s.positions = append(s.positions, v) +func calcNode(c *cache.Reader, nodePos position) ([]byte, error) { + var subtreeStart position + var reader cache.LayerReader + + if nodePos.height == 0 { + return nil, ErrMissingValueAtBaseLayer + } + + // Find the next cached layer below the current one. + for subtreeStart = nodePos; reader == nil; { + subtreeStart = subtreeStart.leftChild() + reader = c.GetLayerReader(subtreeStart.height) + } + + // Prepare the reader for traversing the subtree. + err := reader.Seek(subtreeStart.index) + if err == io.EOF { + return PaddingValue.value, nil + } + if err != nil { + return nil, errors.New("while seeking to position " + subtreeStart.String() + " in cache: " + err.Error()) + } + + var paddingValue []byte + width := uint64(1) << (nodePos.height - subtreeStart.height) + if reader.Width() < subtreeStart.index+width { + paddingPos := position{ + index: reader.Width(), + height: subtreeStart.height, + } + paddingValue, err = calcNode(c, paddingPos) + if err == ErrMissingValueAtBaseLayer { + paddingValue = PaddingValue.value + } else if err != nil { + return nil, errors.New("while calculating ephemeral node at position " + paddingPos.String() + ": " + err.Error()) + } + } + + // Traverse the subtree. + currentVal, _, _, err := traverseSubtree(reader, width, c.GetHashFunc(), nil, paddingValue) + if err != nil { + return nil, errors.New("while traversing subtree for root: " + err.Error()) + } + return currentVal, nil } -// Check the top of the stack for equality and pop the element if it's equal. -func (s *positionsStack) PopIfEqual(p position) bool { - l := len(s.positions) - if l == 0 { - return false +// subtreeDefinition returns the definition (firstLeaf and root positions, width) for the minimal subtree whose +// base layer includes p and where the root is on a cached layer. If no cached layer exists above the base layer, the +// subtree will reach the root of the original tree. +func subtreeDefinition(c *cache.Reader, p position) (root, firstLeaf position, width uint64) { + // maxRootHeight represents the max height of the tree, based on the width of base layer. This is used to prevent an + // infinite loop. + maxRootHeight := cache.RootHeightFromWidth(c.GetLayerReader(p.height).Width()) + for root = p.parent(); root.height < maxRootHeight; root = root.parent() { + if layer := c.GetLayerReader(root.height); layer != nil { + break + } } - if s.positions[l-1] == p { - s.positions = s.positions[:l-1] - return true + subtreeHeight := root.height - p.height + firstLeaf = position{ + index: root.index << subtreeHeight, + height: p.height, } - return false + return root, firstLeaf, 1 << subtreeHeight } diff --git a/proving_test.go b/proving_test.go index 5777019..868dbec 100644 --- a/proving_test.go +++ b/proving_test.go @@ -2,6 +2,8 @@ package merkle import ( "encoding/hex" + "errors" + "github.com/spacemeshos/merkle-tree/cache" "github.com/stretchr/testify/require" "io" "testing" @@ -23,14 +25,13 @@ import ( func TestGenerateProof(t *testing.T) { r := require.New(t) - sliceReadWriters := make(map[uint]*sliceReadWriter) - sliceReadWriters[0] = &sliceReadWriter{} - sliceReadWriters[1] = &sliceReadWriter{} - sliceReadWriters[2] = &sliceReadWriter{} - leavesToProve := []uint64{0, 4, 7} - - tree := NewTreeBuilder(GetSha256Parent). - WithCache(WritersFromSliceReadWriters(sliceReadWriters)). + + leavesToProve := setOf(0, 4, 7) + + cacheWriter := cache.NewWriter(cache.MinHeightPolicy(0), cache.MakeSliceReadWriterFactory()) + + tree := NewTreeBuilder(). + WithCacheWriter(cacheWriter). WithLeavesToProve(leavesToProve). Build() for i := uint64(0); i < 8; i++ { @@ -41,34 +42,46 @@ func TestGenerateProof(t *testing.T) { root := tree.Root() r.Equal(expectedRoot, root) - r.Len(sliceReadWriters[0].slice, 8) - r.Len(sliceReadWriters[1].slice, 4) - r.Len(sliceReadWriters[2].slice, 2) + cacheReader, err := cacheWriter.GetReader() + r.NoError(err) - var proof, expectedProof nodes - var err error - proof, err = GenerateProof(leavesToProve, NodeReadersFromSliceReadWriters(sliceReadWriters), GetSha256Parent) + r.Equal(uint64(8), cacheReader.GetLayerReader(0).Width()) + r.Equal(uint64(4), cacheReader.GetLayerReader(1).Width()) + r.Equal(uint64(2), cacheReader.GetLayerReader(2).Width()) + + var leaves, proof, expectedProof nodes + sortedIndices, leaves, proof, err := GenerateProof(leavesToProve, cacheReader) r.NoError(err) expectedProof = tree.Proof() - r.EqualValues(expectedProof, proof, "actual") + r.EqualValues(expectedProof, proof) + + var expectedLeaves nodes + for _, i := range leavesToProve.asSortedSlice() { + expectedLeaves = append(expectedLeaves, NewNodeFromUint64(i)) + } + r.EqualValues(expectedLeaves, leaves) + r.EqualValues([]uint64{0, 4, 7}, sortedIndices) } func BenchmarkGenerateProof(b *testing.B) { const treeHeight = 23 r := require.New(b) - sliceReadWriters := make(map[uint]*sliceReadWriter) - sliceReadWriters[0] = &sliceReadWriter{} - for i := 7; i < treeHeight; i++ { - sliceReadWriters[uint(i)] = &sliceReadWriter{} - } - var leavesToProve []uint64 + + leavesToProve := make(set) + + cacheWriter := cache.NewWriter( + cache.Combine( + cache.MinHeightPolicy(7), + cache.SpecificLayersPolicy(map[uint]bool{0: true})), + cache.MakeSliceReadWriterFactory()) + for i := 0; i < 20; i++ { - leavesToProve = append(leavesToProve, uint64(i)*400000) + leavesToProve[uint64(i)*400000] = true } - tree := NewTreeBuilder(GetSha256Parent). - WithCache(WritersFromSliceReadWriters(sliceReadWriters)). + tree := NewTreeBuilder(). + WithCacheWriter(cacheWriter). WithLeavesToProve(leavesToProve). Build() for i := uint64(0); i < 1< in cache: some error", err.Error()) + r.Nil(node) + } -func NodeReadersFromSliceReadWriters(sliceReadWriters map[uint]*sliceReadWriter) map[uint]NodeReader { - nodeReaders := make(map[uint]NodeReader, len(sliceReadWriters)) - for k, v := range sliceReadWriters { - nodeReaders[k] = v - } - return nodeReaders +func TestGetNode2(t *testing.T) { + r := require.New(t) + cacheWriter := cache.NewWriter(nil, nil) + cacheWriter.SetLayer(0, readErrorReader{}) + + cacheReader, err := cacheWriter.GetReader() + r.NoError(err) + nodePos := position{} + node, err := GetNode(cacheReader, nodePos) + + r.Error(err) + r.Equal("while reading from cache: some error", err.Error()) + r.Nil(node) +} + +func TestGetNode3(t *testing.T) { + r := require.New(t) + cacheWriter := cache.NewWriter(nil, nil) + cacheWriter.SetLayer(0, seekErrorReader{}) + cacheWriter.SetLayer(1, seekEOFReader{}) + + cacheReader, err := cacheWriter.GetReader() + r.NoError(err) + nodePos := position{height: 1} + node, err := GetNode(cacheReader, nodePos) + + r.Error(err) + r.Equal("while seeking to position in cache: some error", err.Error()) + r.Nil(node) +} + +func TestGetNode4(t *testing.T) { + r := require.New(t) + cacheWriter := cache.NewWriter(nil, nil) + cacheWriter.SetLayer(0, seekErrorReader{}) + cacheWriter.SetLayer(1, widthReader{width: 1}) + cacheWriter.SetLayer(2, seekEOFReader{}) + + cacheReader, err := cacheWriter.GetReader() + r.NoError(err) + nodePos := position{height: 2} + node, err := GetNode(cacheReader, nodePos) + + r.Error(err) + r.Equal("while calculating ephemeral node at position : while seeking to position in cache: some error", err.Error()) + r.Nil(node) +} + +func TestGetNode5(t *testing.T) { + r := require.New(t) + cacheWriter := cache.NewWriter(nil, nil) + cacheWriter.SetLayer(0, widthReader{width: 2}) + cacheWriter.SetLayer(1, seekEOFReader{}) + + cacheReader, err := cacheWriter.GetReader() + r.NoError(err) + nodePos := position{height: 1} + node, err := GetNode(cacheReader, nodePos) + + r.Error(err) + r.Equal("while traversing subtree for root: while reading a leaf: some error", err.Error()) + r.Nil(node) +} + +func TestCache_ValidateStructure(t *testing.T) { + r := require.New(t) + cacheWriter := cache.NewWriter(nil, nil) + cacheReader, err := cacheWriter.GetReader() + + r.Error(err) + r.Equal("reader for base layer must be included", err.Error()) + r.Nil(cacheReader) } diff --git a/treebuilder.go b/treebuilder.go new file mode 100644 index 0000000..6a69cd8 --- /dev/null +++ b/treebuilder.go @@ -0,0 +1,63 @@ +package merkle + +import "github.com/spacemeshos/merkle-tree/cache" + +type TreeBuilder struct { + hash HashFunc + leavesToProves set + cacheWriter *cache.Writer + minHeight uint +} + +func NewTreeBuilder() TreeBuilder { + return TreeBuilder{} +} + +func (tb TreeBuilder) Build() *Tree { + if tb.hash == nil { + tb.hash = GetSha256Parent + } + if tb.cacheWriter == nil { + tb.cacheWriter = cache.NewWriter(cache.SpecificLayersPolicy(map[uint]bool{}), nil) + } + tb.cacheWriter.SetHash(tb.hash) + return &Tree{ + baseLayer: newLayer(0, tb.cacheWriter.GetLayerWriter(0)), + hash: tb.hash, + leavesToProve: newSparseBoolStack(tb.leavesToProves), + cacheWriter: tb.cacheWriter, + minHeight: tb.minHeight, + } +} + +func (tb TreeBuilder) WithHashFunc(hash HashFunc) TreeBuilder { + tb.hash = hash + return tb +} + +func (tb TreeBuilder) WithLeavesToProve(leavesToProves map[uint64]bool) TreeBuilder { + tb.leavesToProves = leavesToProves + return tb +} + +func (tb TreeBuilder) WithCacheWriter(cacheWriter *cache.Writer) TreeBuilder { + tb.cacheWriter = cacheWriter + return tb +} + +func (tb TreeBuilder) WithMinHeight(minHeight uint) TreeBuilder { + tb.minHeight = minHeight + return tb +} + +func NewTree() *Tree { + return NewTreeBuilder().Build() +} + +func NewProvingTree(leavesToProves map[uint64]bool) *Tree { + return NewTreeBuilder().WithLeavesToProve(leavesToProves).Build() +} + +func NewCachingTree(cacheWriter *cache.Writer) *Tree { + return NewTreeBuilder().WithCacheWriter(cacheWriter).Build() +} diff --git a/treecache.go b/treecache.go deleted file mode 100644 index 27b4f57..0000000 --- a/treecache.go +++ /dev/null @@ -1,137 +0,0 @@ -package merkle - -import ( - "errors" - "fmt" - "io" -) - -var ErrMissingValueAtBaseLayer = errors.New("missing value at base layer, returned PaddingValue") - -type TreeCache struct { - readers map[uint]NodeReader - hash HashFunc -} - -func NewTreeCache(readers map[uint]NodeReader, hash HashFunc) (*TreeCache, error) { - err := validateCacheStructure(readers) - if err != nil { - return nil, err - } - - return &TreeCache{ - readers: readers, - hash: hash, - }, nil -} - -func validateCacheStructure(readers map[uint]NodeReader) error { - // Verify we got the base layer. - if _, found := readers[0]; !found { - return errors.New("reader for base layer must be included") - } - width := readers[0].Width() - height := rootHeightFromWidth(width) - for i := uint(0); i < height; i++ { - if _, found := readers[i]; found && readers[i].Width() != width { - return fmt.Errorf("reader at layer %d has width %d instead of %d", i, readers[i].Width(), width) - } - width >>= 1 - } - return nil -} - -// GetNode reads the node at the requested position from the cache or calculates it if not available. -func (c *TreeCache) GetNode(nodePos position) ([]byte, error) { - // Get the cache reader for the requested node's layer. - reader, found := c.readers[nodePos.height] - - // If the cache wasn't found, we calculate the minimal subtree that will get us the required node. - if !found { - return c.calcNode(nodePos) - } - - err := reader.Seek(nodePos.index) - if err == io.EOF { - return c.calcNode(nodePos) - } - if err != nil { - return nil, errors.New("while seeking to position " + nodePos.String() + " in cache: " + err.Error()) - } - currentVal, err := reader.ReadNext() - if err != nil { - return nil, errors.New("while reading from cache: " + err.Error()) - } - return currentVal, nil -} - -func (c *TreeCache) calcNode(nodePos position) ([]byte, error) { - var subtreeStart position - var found bool - var reader NodeReader - - if nodePos.height == 0 { - return nil, ErrMissingValueAtBaseLayer - } - - // Find the next cached layer below the current one. - for subtreeStart = nodePos; !found; { - subtreeStart = subtreeStart.leftChild() - reader, found = c.readers[subtreeStart.height] - } - - // Prepare the reader for traversing the subtree. - err := reader.Seek(subtreeStart.index) - if err == io.EOF { - return PaddingValue.value, nil - } - if err != nil { - return nil, errors.New("while seeking to position " + subtreeStart.String() + " in cache: " + err.Error()) - } - - var paddingValue []byte - width := uint64(1) << (nodePos.height - subtreeStart.height) - if reader.Width() < subtreeStart.index+width { - paddingPos := position{ - index: reader.Width(), - height: subtreeStart.height, - } - paddingValue, err = c.calcNode(paddingPos) - if err == ErrMissingValueAtBaseLayer { - paddingValue = PaddingValue.value - } else if err != nil { - return nil, errors.New("while calculating ephemeral node at position " + paddingPos.String() + ": " + err.Error()) - } - } - - // Traverse the subtree. - currentVal, _, err := traverseSubtree(reader, width, c.hash, nil, paddingValue) - if err != nil { - return nil, errors.New("while traversing subtree for root: " + err.Error()) - } - return currentVal, nil -} - -// subtreeDefinition returns the definition (firstLeaf and root positions, width) for the minimal subtree whose -// base layer includes p and where the root is on a cached layer. If no cached layer exists above the base layer, the -// subtree will reach the root of the original tree. -func (c *TreeCache) subtreeDefinition(p position) (root, firstLeaf position, width uint64) { - // maxRootHeight represents the max height of the tree, based on the width of base layer. This is used to prevent an - // infinite loop. - maxRootHeight := rootHeightFromWidth(c.readers[p.height].Width()) - for root = p.parent(); root.height < maxRootHeight; root = root.parent() { - if _, found := c.readers[root.height]; found { - break - } - } - subtreeHeight := root.height - p.height - firstLeaf = position{ - index: root.index << subtreeHeight, - height: p.height, - } - return root, firstLeaf, 1 << subtreeHeight -} - -func (c *TreeCache) LeafReader() NodeReader { - return c.readers[0] -} diff --git a/treecache_test.go b/treecache_test.go deleted file mode 100644 index 1e5c8e6..0000000 --- a/treecache_test.go +++ /dev/null @@ -1,212 +0,0 @@ -package merkle - -import ( - "errors" - "github.com/stretchr/testify/require" - "io" - "testing" -) - -func TestNewTreeCache(t *testing.T) { - -} - -var someError = errors.New("some error") - -type seekErrorReader struct{} - -func (seekErrorReader) Seek(index uint64) error { return someError } -func (seekErrorReader) ReadNext() ([]byte, error) { panic("implement me") } -func (seekErrorReader) Width() uint64 { return 3 } - -type readErrorReader struct{} - -func (readErrorReader) Seek(index uint64) error { return nil } -func (readErrorReader) ReadNext() ([]byte, error) { return nil, someError } -func (readErrorReader) Width() uint64 { return 8 } - -type seekEOFReader struct{} - -func (seekEOFReader) Seek(index uint64) error { return io.EOF } -func (seekEOFReader) ReadNext() ([]byte, error) { panic("implement me") } -func (seekEOFReader) Width() uint64 { return 1 } - -type widthReader struct{ width uint64 } - -func (r widthReader) Seek(index uint64) error { return nil } -func (r widthReader) ReadNext() ([]byte, error) { return nil, someError } -func (r widthReader) Width() uint64 { return r.width } - -func TestNewTreeCacheErrors(t *testing.T) { - r := require.New(t) - var readers map[uint]NodeReader - - treeCache, err := NewTreeCache(readers, GetSha256Parent) - - r.Error(err) - r.Equal("reader for base layer must be included", err.Error()) - r.Nil(treeCache) -} - -func TestNewTreeCache2(t *testing.T) { - r := require.New(t) - readers := make(map[uint]NodeReader) - - treeCache, err := NewTreeCache(readers, GetSha256Parent) - - r.Error(err) - r.Equal("reader for base layer must be included", err.Error()) - r.Nil(treeCache) -} - -func TestNewTreeCache3(t *testing.T) { - r := require.New(t) - readers := make(map[uint]NodeReader) - - readers[0] = seekErrorReader{} - treeCache, err := NewTreeCache(readers, GetSha256Parent) - - r.NoError(err) - r.NotNil(treeCache) - - nodePos := position{} - node, err := treeCache.GetNode(nodePos) - - r.Error(err) - r.Equal("while seeking to position in cache: some error", err.Error()) - r.Nil(node) - -} - -func TestNewTreeCache4(t *testing.T) { - r := require.New(t) - readers := make(map[uint]NodeReader) - - readers[0] = readErrorReader{} - treeCache, err := NewTreeCache(readers, GetSha256Parent) - - r.NoError(err) - r.NotNil(treeCache) - - nodePos := position{} - node, err := treeCache.GetNode(nodePos) - - r.Error(err) - r.Equal("while reading from cache: some error", err.Error()) - r.Nil(node) -} - -func TestNewTreeCache5(t *testing.T) { - r := require.New(t) - readers := make(map[uint]NodeReader) - - readers[0] = seekErrorReader{} - readers[1] = seekEOFReader{} - treeCache, err := NewTreeCache(readers, GetSha256Parent) - - r.NoError(err) - r.NotNil(treeCache) - - nodePos := position{height: 1} - node, err := treeCache.GetNode(nodePos) - - r.Error(err) - r.Equal("while seeking to position in cache: some error", err.Error()) - r.Nil(node) -} - -func TestNewTreeCache6(t *testing.T) { - r := require.New(t) - readers := make(map[uint]NodeReader) - - readers[0] = seekErrorReader{} - readers[1] = widthReader{width: 1} - readers[2] = seekEOFReader{} - treeCache, err := NewTreeCache(readers, GetSha256Parent) - - r.NoError(err) - r.NotNil(treeCache) - - nodePos := position{height: 2} - node, err := treeCache.GetNode(nodePos) - - r.Error(err) - r.Equal("while calculating ephemeral node at position : while seeking to position in cache: some error", err.Error()) - r.Nil(node) -} - -func TestNewTreeCache7(t *testing.T) { - r := require.New(t) - readers := make(map[uint]NodeReader) - - readers[0] = widthReader{width: 2} - readers[1] = seekEOFReader{} - treeCache, err := NewTreeCache(readers, GetSha256Parent) - - r.NoError(err) - r.NotNil(treeCache) - - nodePos := position{height: 1} - node, err := treeCache.GetNode(nodePos) - - r.Error(err) - r.Equal("while traversing subtree for root: while reading a leaf: some error", err.Error()) - r.Nil(node) -} - -func TestNewTreeCacheStructureSuccess(t *testing.T) { - r := require.New(t) - readers := make(map[uint]NodeReader) - - readers[0] = widthReader{width: 4} - readers[1] = widthReader{width: 2} - readers[2] = widthReader{width: 1} - treeCache, err := NewTreeCache(readers, GetSha256Parent) - - r.NoError(err) - r.NotNil(treeCache) -} - -func TestNewTreeCacheStructureFail(t *testing.T) { - r := require.New(t) - readers := make(map[uint]NodeReader) - - readers[0] = widthReader{width: 3} - readers[1] = widthReader{width: 2} - readers[2] = widthReader{width: 1} - treeCache, err := NewTreeCache(readers, GetSha256Parent) - - r.Error(err) - r.Equal("reader at layer 1 has width 2 instead of 1", err.Error()) - r.Nil(treeCache) -} - -func TestNewTreeCacheStructureFail2(t *testing.T) { - r := require.New(t) - readers := make(map[uint]NodeReader) - - readers[0] = widthReader{width: 4} - readers[1] = widthReader{width: 1} - readers[2] = widthReader{width: 1} - treeCache, err := NewTreeCache(readers, GetSha256Parent) - - r.Error(err) - r.Equal("reader at layer 1 has width 1 instead of 2", err.Error()) - r.Nil(treeCache) -} - -func TestPosition_isAncestorOf(t *testing.T) { - lower := position{ - index: 0, - height: 0, - } - - higher := position{ - index: 0, - height: 1, - } - - isAncestor := lower.isAncestorOf(higher) - - require.False(t, isAncestor) -} diff --git a/validation.go b/validation.go index 884be9d..4f5bd50 100644 --- a/validation.go +++ b/validation.go @@ -2,7 +2,9 @@ package merkle import ( "bytes" + "errors" "fmt" + "sort" ) const MaxUint = ^uint(0) @@ -19,18 +21,24 @@ func ValidatePartialTree(leafIndices []uint64, leaves, proof [][]byte, expectedR return bytes.Equal(root, expectedRoot), err } -func newValidator(leafIndices []uint64, leaves, proof [][]byte, hash HashFunc) (validator, error) { +func newValidator(leafIndices []uint64, leaves, proof [][]byte, hash HashFunc) (*validator, error) { if len(leafIndices) != len(leaves) { - return validator{}, fmt.Errorf("number of leaves (%d) must equal number of indices (%d)", len(leaves), + return nil, fmt.Errorf("number of leaves (%d) must equal number of indices (%d)", len(leaves), len(leafIndices)) } if len(leaves) == 0 { - return validator{}, fmt.Errorf("at least one leaf is required for validation") + return nil, errors.New("at least one leaf is required for validation") + } + if !sort.SliceIsSorted(leafIndices, func(i, j int) bool { return leafIndices[i] < leafIndices[j] }) { + return nil, errors.New("leafIndices are not sorted") + } + if len(setOf(leafIndices...)) != len(leafIndices) { + return nil, errors.New("leafIndices contain duplicates") } proofNodes := &proofIterator{proof} leafIt := &leafIterator{leafIndices, leaves} - return validator{leaves: leafIt, proofNodes: proofNodes, hash: hash}, nil + return &validator{leaves: leafIt, proofNodes: proofNodes, hash: hash}, nil } type validator struct { diff --git a/validation_test.go b/validation_test.go index 9bc11f6..1c33969 100644 --- a/validation_test.go +++ b/validation_test.go @@ -26,7 +26,7 @@ func TestValidatePartialTreeForRealz(t *testing.T) { leafIndices := []uint64{4} leaves := [][]byte{NewNodeFromUint64(4)} - tree := NewProvingTree(GetSha256Parent, leafIndices) + tree := NewProvingTree(setOf(leafIndices...)) for i := uint64(0); i < 8; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) req.NoError(err) @@ -53,7 +53,7 @@ func TestValidatePartialTreeMulti(t *testing.T) { NewNodeFromUint64(1), NewNodeFromUint64(4), } - tree := NewProvingTree(GetSha256Parent, leafIndices) + tree := NewProvingTree(setOf(leafIndices...)) for i := uint64(0); i < 8; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) req.NoError(err) @@ -81,7 +81,7 @@ func TestValidatePartialTreeMulti2(t *testing.T) { NewNodeFromUint64(1), NewNodeFromUint64(4), } - tree := NewProvingTree(GetSha256Parent, leafIndices) + tree := NewProvingTree(setOf(leafIndices...)) for i := uint64(0); i < 8; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) req.NoError(err) @@ -109,7 +109,7 @@ func TestValidatePartialTreeMultiUnbalanced(t *testing.T) { NewNodeFromUint64(4), NewNodeFromUint64(7), } - tree := NewProvingTree(GetSha256Parent, leafIndices) + tree := NewProvingTree(setOf(leafIndices...)) for i := uint64(0); i < 10; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) req.NoError(err) @@ -140,7 +140,7 @@ func TestValidatePartialTreeMultiUnbalanced2(t *testing.T) { NewNodeFromUint64(7), NewNodeFromUint64(9), } - tree := NewProvingTree(GetSha256Parent, leafIndices) + tree := NewProvingTree(setOf(leafIndices...)) for i := uint64(0); i < 10; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) req.NoError(err) @@ -168,7 +168,7 @@ func TestValidatePartialTreeUnbalanced(t *testing.T) { leaves := [][]byte{ NewNodeFromUint64(9), } - tree := NewProvingTree(GetSha256Parent, leafIndices) + tree := NewProvingTree(setOf(leafIndices...)) for i := uint64(0); i < 10; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) req.NoError(err) @@ -197,7 +197,7 @@ func BenchmarkValidatePartialTree(b *testing.B) { for _, i := range leafIndices { leaves = append(leaves, NewNodeFromUint64(i)) } - tree := NewProvingTree(GetSha256Parent, leafIndices) + tree := NewProvingTree(setOf(leafIndices...)) for i := uint64(0); i < 1<<23; i++ { err := tree.AddLeaf(NewNodeFromUint64(i)) req.NoError(err) @@ -231,12 +231,24 @@ func TestValidatePartialTreeErrors(t *testing.T) { } root, _ := NewNodeFromHex("2657509b700c67b205c5196ee9a231e0fe567f1dae4a15bb52c0de813d65677a") valid, err := ValidatePartialTree(leafIndices, leaves, proof, root, GetSha256Parent) - req.Error(err) - req.False(valid, "Proof should be valid, but isn't") + req.EqualError(err, "number of leaves (1) must equal number of indices (2)") + req.False(valid) valid, err = ValidatePartialTree([]uint64{}, [][]byte{}, proof, root, GetSha256Parent) - req.Error(err) - req.False(valid, "Proof should be valid, but isn't") + req.EqualError(err, "at least one leaf is required for validation") + req.False(valid) + + leafIndices = []uint64{5, 3} + leaves = [][]byte{NewNodeFromUint64(5), NewNodeFromUint64(3)} + valid, err = ValidatePartialTree(leafIndices, leaves, proof, root, GetSha256Parent) + req.EqualError(err, "leafIndices are not sorted") + req.False(valid) + + leafIndices = []uint64{3, 3} + leaves = [][]byte{NewNodeFromUint64(5), NewNodeFromUint64(3)} + valid, err = ValidatePartialTree(leafIndices, leaves, proof, root, GetSha256Parent) + req.EqualError(err, "leafIndices contain duplicates") + req.False(valid) } func TestValidator_calcRoot(t *testing.T) {