Navigation Menu

Skip to content

Commit

Permalink
extract common cache with usability improvements (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
noamnelke committed Mar 27, 2019
1 parent 1a334a4 commit 3651a54
Show file tree
Hide file tree
Showing 18 changed files with 1,082 additions and 760 deletions.
142 changes: 142 additions & 0 deletions cache/cache.go
@@ -0,0 +1,142 @@
package cache

import (
"errors"
"fmt"
"math"
)

const NodeSize = 32

type Writer struct {
*cache
}

func NewWriter(shouldCacheLayer CachingPolicy, generateLayer LayerFactory) *Writer {
return &Writer{
cache: &cache{
layers: make(map[uint]LayerReadWriter),
generateLayer: generateLayer,
shouldCacheLayer: shouldCacheLayer,
},
}
}

func (c *Writer) SetLayer(layerHeight uint, rw LayerReadWriter) {
c.layers[layerHeight] = rw
}

func (c *Writer) GetLayerWriter(layerHeight uint) LayerWriter {
layerReadWriter, found := c.layers[layerHeight]
if !found && c.shouldCacheLayer(layerHeight) {
layerReadWriter = c.generateLayer(layerHeight)
c.layers[layerHeight] = layerReadWriter
}
return layerReadWriter
}

func (c *Writer) SetHash(hashFunc func(lChild, rChild []byte) []byte) {
c.hash = hashFunc
}

func (c *Writer) GetReader() (*Reader, error) {
err := c.validateStructure()
if err != nil {
return nil, err
}
return &Reader{c.cache}, nil
}

type Reader struct {
*cache
}

func (c *Reader) GetLayerReader(layerHeight uint) LayerReader {
return c.layers[layerHeight]
}

func (c *Reader) GetHashFunc() func(lChild, rChild []byte) []byte {
return c.hash
}

type cache struct {
layers map[uint]LayerReadWriter
hash func(lChild, rChild []byte) []byte
shouldCacheLayer CachingPolicy
generateLayer LayerFactory
}

func (c *cache) validateStructure() error {
// Verify we got the base layer.
if _, found := c.layers[0]; !found {
return errors.New("reader for base layer must be included")
}
width := c.layers[0].Width()
if width == 0 {
return errors.New("base layer cannot be empty")
}
height := RootHeightFromWidth(width)
for i := uint(0); i < height; i++ {
if _, found := c.layers[i]; found && c.layers[i].Width() != width {
return fmt.Errorf("reader at layer %d has width %d instead of %d", i, c.layers[i].Width(), width)
}
width >>= 1
}
return nil
}

type CachingPolicy func(layerHeight uint) (shouldCacheLayer bool)

type LayerFactory func(layerHeight uint) LayerReadWriter

// LayerReadWriter is a combined reader-writer. Note that the Seek() method only belongs to the LayerReader interface
// and does not affect the LayerWriter.
type LayerReadWriter interface {
LayerReader
LayerWriter
}

type LayerReader interface {
Seek(index uint64) error
ReadNext() ([]byte, error)
Width() uint64
}

type LayerWriter interface {
Append(p []byte) (n int, err error)
}

func RootHeightFromWidth(width uint64) uint {
return uint(math.Ceil(math.Log2(float64(width))))
}

//func (c *cache) Print(bottom, top int) {
// for i := top; i >= bottom; i-- {
// print("| ")
// sliceReadWriter, ok := c.layers[uint(i)].(*SliceReadWriter)
// if !ok {
// println("-- layer is not a SliceReadWriter --")
// continue
// }
// for _, n := range sliceReadWriter.slice {
// printSpaces(numSpaces(i))
// fmt.Print(hex.EncodeToString(n[:2]))
// printSpaces(numSpaces(i))
// }
// println(" |")
// }
//}
//
//func numSpaces(n int) int {
// res := 1
// for i := 0; i < n; i++ {
// res += 3 * (1 << uint(i))
// }
// return res
//}
//
//func printSpaces(n int) {
// for i := 0; i < n; i++ {
// print(" ")
// }
//}
79 changes: 79 additions & 0 deletions cache/cache_test.go
@@ -0,0 +1,79 @@
package cache

import (
"errors"
"github.com/stretchr/testify/require"
"testing"
)

var someError = errors.New("some error")

type widthReader struct{ width uint64 }

func (r widthReader) Seek(index uint64) error { return nil }
func (r widthReader) ReadNext() ([]byte, error) { return nil, someError }
func (r widthReader) Width() uint64 { return r.width }
func (r widthReader) Append(p []byte) (n int, err error) { panic("implement me") }

func TestCache_ValidateStructure(t *testing.T) {
r := require.New(t)
var readers map[uint]LayerReadWriter

treeCache := &cache{layers: readers}
err := treeCache.validateStructure()

r.Error(err)
r.Equal("reader for base layer must be included", err.Error())
}

func TestCache_ValidateStructure2(t *testing.T) {
r := require.New(t)
readers := make(map[uint]LayerReadWriter)

treeCache := &cache{layers: readers}
err := treeCache.validateStructure()

r.Error(err)
r.Equal("reader for base layer must be included", err.Error())
}

func TestCache_ValidateStructureSuccess(t *testing.T) {
r := require.New(t)
readers := make(map[uint]LayerReadWriter)

readers[0] = widthReader{width: 4}
readers[1] = widthReader{width: 2}
readers[2] = widthReader{width: 1}
treeCache := &cache{layers: readers}
err := treeCache.validateStructure()

r.NoError(err)
}

func TestCache_ValidateStructureFail(t *testing.T) {
r := require.New(t)
readers := make(map[uint]LayerReadWriter)

readers[0] = widthReader{width: 3}
readers[1] = widthReader{width: 2}
readers[2] = widthReader{width: 1}
treeCache := &cache{layers: readers}
err := treeCache.validateStructure()

r.Error(err)
r.Equal("reader at layer 1 has width 2 instead of 1", err.Error())
}

func TestCache_ValidateStructureFail2(t *testing.T) {
r := require.New(t)
readers := make(map[uint]LayerReadWriter)

readers[0] = widthReader{width: 4}
readers[1] = widthReader{width: 1}
readers[2] = widthReader{width: 1}
treeCache := &cache{layers: readers}
err := treeCache.validateStructure()

r.Error(err)
r.Equal("reader at layer 1 has width 1 instead of 2", err.Error())
}
19 changes: 19 additions & 0 deletions cache/cachingpolicies.go
@@ -0,0 +1,19 @@
package cache

func MinHeightPolicy(minHeight uint) CachingPolicy {
return func(layerHeight uint) (shouldCacheLayer bool) {
return layerHeight >= minHeight
}
}

func SpecificLayersPolicy(layersToCache map[uint]bool) CachingPolicy {
return func(layerHeight uint) (shouldCacheLayer bool) {
return layersToCache[layerHeight]
}
}

func Combine(first, second CachingPolicy) CachingPolicy {
return func(layerHeight uint) (shouldCacheLayer bool) {
return first(layerHeight) || second(layerHeight)
}
}
103 changes: 103 additions & 0 deletions cache/cachingpolicies_test.go
@@ -0,0 +1,103 @@
package cache

import (
"github.com/stretchr/testify/require"
"testing"
)

func TestMakeMemoryReadWriterFactory(t *testing.T) {
r := require.New(t)
cacheWriter := NewWriter(MinHeightPolicy(2), MakeSliceReadWriterFactory())
cacheWriter.SetLayer(0, widthReader{width: 1})

cacheReader, err := cacheWriter.GetReader()
r.NoError(err)

reader := cacheReader.GetLayerReader(1)
r.Nil(reader)
reader = cacheReader.GetLayerReader(2)
r.Nil(reader)
reader = cacheReader.GetLayerReader(3)
r.Nil(reader)

writer := cacheWriter.GetLayerWriter(1)
r.Nil(writer)
writer = cacheWriter.GetLayerWriter(2)
r.NotNil(writer)
writer = cacheWriter.GetLayerWriter(3)
r.NotNil(writer)

cacheReader, err = cacheWriter.GetReader()
r.NoError(err)

reader = cacheReader.GetLayerReader(1)
r.Nil(reader)
reader = cacheReader.GetLayerReader(2)
r.NotNil(reader)
reader = cacheReader.GetLayerReader(3)
r.NotNil(reader)
}

func TestMakeMemoryReadWriterFactoryForLayers(t *testing.T) {
r := require.New(t)
cacheWriter := NewWriter(SpecificLayersPolicy(map[uint]bool{1: true, 3: true}), MakeSliceReadWriterFactory())
cacheWriter.SetLayer(0, widthReader{width: 1})

cacheReader, err := cacheWriter.GetReader()
r.NoError(err)

reader := cacheReader.GetLayerReader(1)
r.Nil(reader)
reader = cacheReader.GetLayerReader(2)
r.Nil(reader)
reader = cacheReader.GetLayerReader(3)
r.Nil(reader)

writer := cacheWriter.GetLayerWriter(1)
r.NotNil(writer)
writer = cacheWriter.GetLayerWriter(2)
r.Nil(writer)
writer = cacheWriter.GetLayerWriter(3)
r.NotNil(writer)

cacheReader, err = cacheWriter.GetReader()
r.NoError(err)

reader = cacheReader.GetLayerReader(1)
r.NotNil(reader)
reader = cacheReader.GetLayerReader(2)
r.Nil(reader)
reader = cacheReader.GetLayerReader(3)
r.NotNil(reader)
}

func TestMakeSpecificLayerFactory(t *testing.T) {
r := require.New(t)
readWriter := &SliceReadWriter{}
cacheWriter := NewWriter(
SpecificLayersPolicy(map[uint]bool{1: true}),
MakeSpecificLayersFactory(map[uint]LayerReadWriter{1: readWriter}),
)
cacheWriter.SetLayer(0, widthReader{width: 1})

cacheReader, err := cacheWriter.GetReader()
r.NoError(err)

reader := cacheReader.GetLayerReader(1)
r.Nil(reader)
reader = cacheReader.GetLayerReader(2)
r.Nil(reader)

writer := cacheWriter.GetLayerWriter(1)
r.Equal(readWriter, writer)
writer = cacheWriter.GetLayerWriter(2)
r.Nil(writer)

cacheReader, err = cacheWriter.GetReader()
r.NoError(err)

reader = cacheReader.GetLayerReader(1)
r.Equal(readWriter, reader)
reader = cacheReader.GetLayerReader(2)
r.Nil(reader)
}
13 changes: 13 additions & 0 deletions cache/layerfactories.go
@@ -0,0 +1,13 @@
package cache

func MakeSliceReadWriterFactory() LayerFactory {
return func(layerHeight uint) LayerReadWriter {
return &SliceReadWriter{}
}
}

func MakeSpecificLayersFactory(readWriters map[uint]LayerReadWriter) LayerFactory {
return func(layerHeight uint) LayerReadWriter {
return readWriters[layerHeight]
}
}
35 changes: 35 additions & 0 deletions cache/slice.go
@@ -0,0 +1,35 @@
package cache

import "io"

type SliceReadWriter struct {
slice [][]byte
position uint64
}

func (s *SliceReadWriter) Width() uint64 {
return uint64(len(s.slice))
}

func (s *SliceReadWriter) Seek(index uint64) error {
if index >= uint64(len(s.slice)) {
return io.EOF
}
s.position = index
return nil
}

func (s *SliceReadWriter) ReadNext() ([]byte, error) {
if s.position >= uint64(len(s.slice)) {
return nil, io.EOF
}
value := make([]byte, NodeSize)
copy(value, s.slice[s.position])
s.position++
return value, nil
}

func (s *SliceReadWriter) Append(p []byte) (n int, err error) {
s.slice = append(s.slice, p)
return len(p), nil
}

0 comments on commit 3651a54

Please sign in to comment.