diff --git a/store/cache/accessor_cache.go b/store/cache/accessor_cache.go
new file mode 100644
index 0000000000..90bda41db6
--- /dev/null
+++ b/store/cache/accessor_cache.go
@@ -0,0 +1,236 @@
+package cache
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	lru "github.com/hashicorp/golang-lru/v2"
+
+	eds "github.com/celestiaorg/celestia-node/share/new_eds"
+)
+
+const defaultCloseTimeout = time.Minute
+
+var _ Cache = (*AccessorCache)(nil)
+
+// AccessorCache implements the Cache interface using an LRU cache backend.
+type AccessorCache struct {
+	// The name is a prefix that will be used for cache metrics if they are enabled.
+	name string
+	// stripedLocks prevents simultaneous RW access to the accessor cache. Instead of using a
+	// single lock or one lock per height, we stripe the heights across 256 locks. 256 is chosen
+	// because 0-255 is the range of values of the last byte of the uint64 height.
+	stripedLocks [256]*sync.RWMutex
+	// Caches the accessor for a given height for accessor read affinity, i.e., further reads
+	// will likely be from the same accessor. Maps (height -> accessor).
+	cache *lru.Cache[uint64, *accessor]
+
+	metrics *metrics
+}
+
+// accessor is the value stored in Cache. It implements the eds.AccessorStreamer interface. It is
+// reference counted, so that it can be removed from the cache only when all references are
+// released.
+type accessor struct {
+	eds.AccessorStreamer
+
+	lock     sync.Mutex
+	done     chan struct{}
+	refs     atomic.Int32
+	isClosed bool
+}
+
+func NewAccessorCache(name string, cacheSize int) (*AccessorCache, error) {
+	bc := &AccessorCache{
+		name:         name,
+		stripedLocks: [256]*sync.RWMutex{},
+	}
+
+	for i := range bc.stripedLocks {
+		bc.stripedLocks[i] = &sync.RWMutex{}
+	}
+	// Instantiate the Accessor Cache.
+	bslru, err := lru.NewWithEvict[uint64, *accessor](cacheSize, bc.evictFn())
+	if err != nil {
+		return nil, fmt.Errorf("creating accessor cache %s: %w", name, err)
+	}
+	bc.cache = bslru
+	return bc, nil
+}
+
+// evictFn will be invoked when an item is evicted from the cache.
+func (bc *AccessorCache) evictFn() func(uint64, *accessor) {
+	return func(_ uint64, ac *accessor) {
+		// We don't want to block the cache on close, so the accessor is released from the cache
+		// early, while it is being closed in a parallel routine.
+		go func() {
+			err := ac.close()
+			if err != nil {
+				bc.metrics.observeEvicted(true)
+				log.Errorf("couldn't close accessor after cache eviction: %s", err)
+				return
+			}
+			bc.metrics.observeEvicted(false)
+		}()
+	}
+}
+
+// Get retrieves the accessor for a given height from the Cache. If the accessor is not in
+// the Cache, it returns ErrCacheMiss.
+func (bc *AccessorCache) Get(height uint64) (eds.AccessorStreamer, error) {
+	lk := bc.getLock(height)
+	lk.RLock()
+	defer lk.RUnlock()
+
+	ac, ok := bc.cache.Get(height)
+	if !ok {
+		bc.metrics.observeGet(false)
+		return nil, ErrCacheMiss
+	}
+
+	bc.metrics.observeGet(true)
+	return newRefCloser(ac)
+}
+
+// GetOrLoad attempts to get an item from the cache, and if not found, invokes
+// the provided loader function to load it.
+func (bc *AccessorCache) GetOrLoad(
+	ctx context.Context,
+	height uint64,
+	loader OpenAccessorFn,
+) (eds.AccessorStreamer, error) {
+	lk := bc.getLock(height)
+	lk.Lock()
+	defer lk.Unlock()
+
+	ac, ok := bc.cache.Get(height)
+	if ok {
+		// return accessor, only if it is not closed yet
+		accessorWithRef, err := newRefCloser(ac)
+		if err == nil {
+			bc.metrics.observeGet(true)
+			return accessorWithRef, nil
+		}
+	}
+
+	// accessor not found in cache or closed, so load new one using loader
+	f, err := loader(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("unable to load accessor: %w", err)
+	}
+
+	ac = &accessor{AccessorStreamer: f}
+	// Wrap the accessor to increment its reference count before adding it to the cache, so it
+	// cannot get closed by eviction from the inner lru cache before it is used.
+	rc, err := newRefCloser(ac)
+	if err != nil {
+		return nil, err
+	}
+	bc.cache.Add(height, ac)
+	return rc, nil
+}
+
+// Remove removes the Accessor for a given height from the cache.
+func (bc *AccessorCache) Remove(height uint64) error {
+	lk := bc.getLock(height)
+	lk.RLock()
+	ac, ok := bc.cache.Get(height)
+	lk.RUnlock()
+	if !ok {
+		// item is not in cache
+		return nil
+	}
+	if err := ac.close(); err != nil {
+		return err
+	}
+	// The cache will call evictFn on removal; the close there is a no-op, since the accessor is
+	// already closed.
+	bc.cache.Remove(height)
+	return nil
+}
+
+// EnableMetrics enables metrics for the cache.
+func (bc *AccessorCache) EnableMetrics() error {
+	var err error
+	bc.metrics, err = newMetrics(bc)
+	return err
+}
+
+func (s *accessor) addRef() error {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+	if s.isClosed {
+		// item is already closed and will be removed soon, after all refs are released
+		return ErrCacheMiss
+	}
+	if s.refs.Add(1) == 1 {
+		// there were no refs previously and the done channel was closed; reopen it by recreating
+		s.done = make(chan struct{})
+	}
+	return nil
+}
+
+func (s *accessor) removeRef() {
+	s.lock.Lock()
+	defer s.lock.Unlock()
+	if s.refs.Add(-1) <= 0 {
+		close(s.done)
+	}
+}
+
+// close closes the accessor if it is not closed yet. It will block until all references are
+// released or the timeout is reached.
+func (s *accessor) close() error {
+	s.lock.Lock()
+	if s.isClosed {
+		s.lock.Unlock()
+		// accessor will be closed by another goroutine
+		return nil
+	}
+	s.isClosed = true
+	done := s.done
+	s.lock.Unlock()
+
+	select {
+	case <-done:
+	case <-time.After(defaultCloseTimeout):
+		return fmt.Errorf("closing accessor, some readers didn't close the accessor within timeout,"+
+			" amount left: %v", s.refs.Load())
+	}
+	if err := s.AccessorStreamer.Close(); err != nil {
+		return fmt.Errorf("closing accessor: %w", err)
+	}
+	return nil
+}
+
+// refCloser exists for reference counting protection on the accessor. It ensures that a caller
+// can't decrement the reference count more than once.
+type refCloser struct {
+	*accessor
+	closeFn func()
+}
+
+// newRefCloser creates a new refCloser.
+func newRefCloser(abs *accessor) (*refCloser, error) {
+	if err := abs.addRef(); err != nil {
+		return nil, err
+	}
+
+	var closeOnce sync.Once
+	return &refCloser{
+		accessor: abs,
+		closeFn: func() {
+			closeOnce.Do(abs.removeRef)
+		},
+	}, nil
+}
+
+func (c *refCloser) Close() error {
+	c.closeFn()
+	return nil
+}
+
+func (bc *AccessorCache) getLock(k uint64) *sync.RWMutex {
+	return bc.stripedLocks[byte(k%256)]
+}
diff --git a/store/cache/accessor_cache_test.go b/store/cache/accessor_cache_test.go
new file mode 100644
index 0000000000..c3777f548b
--- /dev/null
+++ b/store/cache/accessor_cache_test.go
@@ -0,0 +1,341 @@
+package cache
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"io"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/celestiaorg/rsmt2d"
+
+	"github.com/celestiaorg/celestia-node/share"
+	eds "github.com/celestiaorg/celestia-node/share/new_eds"
+	"github.com/celestiaorg/celestia-node/share/shwap"
+)
+
+func TestAccessorCache(t *testing.T) {
+	t.Run("add / get item from cache", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		cache, err := NewAccessorCache("test", 1)
+		require.NoError(t, err)
+
+		// add accessor to the cache
+		height := uint64(1)
+		mock := &mockAccessor{
+			data: []byte("test_data"),
+		}
+		loaded, err := cache.GetOrLoad(ctx, height, func(ctx context.Context) (eds.AccessorStreamer, error) {
+			return mock, nil
+		})
+		require.NoError(t, err)
+		reader, err := loaded.Reader()
+		require.NoError(t, err)
+		data, err := io.ReadAll(reader)
+		require.NoError(t, err)
+		require.Equal(t, mock.data, data)
+		err = loaded.Close()
+		require.NoError(t, err)
+
+		// check if item exists
+		got, err := cache.Get(height)
+		require.NoError(t, err)
+		reader, err = got.Reader()
+		require.NoError(t, err)
+		data, err = io.ReadAll(reader)
+		require.NoError(t, err)
+		require.Equal(t, mock.data, data)
+		err = got.Close()
+		require.NoError(t, err)
+	})
+
+	t.Run("get reader from accessor", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		cache, err := NewAccessorCache("test", 1)
+		require.NoError(t, err)
+
+		// add accessor to the cache
+		height := uint64(1)
+		mock := &mockAccessor{}
+		accessor, err := cache.GetOrLoad(ctx, height, func(ctx context.Context) (eds.AccessorStreamer, error) {
+			return mock, nil
+		})
+		require.NoError(t, err)
+
+		// check if item exists
+		_, err = cache.Get(height)
+		require.NoError(t, err)
+
+		// try to get reader
+		_, err = accessor.Reader()
+		require.NoError(t, err)
+	})
+
+	t.Run("remove an item", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		cache, err := NewAccessorCache("test", 1)
+		require.NoError(t, err)
+
+		// add accessor to the cache
+		height := uint64(1)
+		mock := &mockAccessor{}
+		ac, err := cache.GetOrLoad(ctx, height, func(ctx context.Context) (eds.AccessorStreamer, error) {
+			return mock, nil
+		})
+		require.NoError(t, err)
+		err = ac.Close()
+		require.NoError(t, err)
+
+		err = cache.Remove(height)
+		require.NoError(t, err)
+
+		// accessor should be closed on removal
+		mock.checkClosed(t, true)
+
+		// check if item exists
+		_, err = cache.Get(height)
+		require.ErrorIs(t, err, ErrCacheMiss)
+	})
+
+	t.Run("successive reads should read the same data", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		cache, err := NewAccessorCache("test", 1)
+		require.NoError(t, err)
+
+		// add accessor to the cache
+		height := uint64(1)
+		mock := &mockAccessor{data: []byte("test")}
+		accessor, err := cache.GetOrLoad(ctx, height, func(ctx context.Context) (eds.AccessorStreamer, error) {
+			return mock, nil
+		})
+		require.NoError(t, err)
+
+		reader, err := accessor.Reader()
+		require.NoError(t, err)
+		data, err := io.ReadAll(reader)
+		require.NoError(t, err)
+		require.Equal(t, mock.data, data)
+
+		for i := 0; i < 2; i++ {
+			accessor, err = cache.Get(height)
+			require.NoError(t, err)
+			reader, err := accessor.Reader()
+			require.NoError(t, err)
+			data, err := io.ReadAll(reader)
+			require.NoError(t, err)
+			require.Equal(t, mock.data, data)
+		}
+	})
+
+	t.Run("removed by eviction", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		cache, err := NewAccessorCache("test", 1)
+		require.NoError(t, err)
+
+		// add accessor to the cache
+		height := uint64(1)
+		mock := &mockAccessor{}
+		ac1, err := cache.GetOrLoad(ctx, height, func(ctx context.Context) (eds.AccessorStreamer, error) {
+			return mock, nil
+		})
+		require.NoError(t, err)
+		err = ac1.Close()
+		require.NoError(t, err)
+
+		// add second item
+		height2 := uint64(2)
+		ac2, err := cache.GetOrLoad(ctx, height2, func(ctx context.Context) (eds.AccessorStreamer, error) {
+			return mock, nil
+		})
+		require.NoError(t, err)
+		err = ac2.Close()
+		require.NoError(t, err)
+
+		// accessor should be closed on removal by eviction
+		mock.checkClosed(t, true)
+
+		// first item should be evicted from cache
+		_, err = cache.Get(height)
+		require.ErrorIs(t, err, ErrCacheMiss)
+	})
+
+	t.Run("close on accessor is not closing underlying accessor", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		cache, err := NewAccessorCache("test", 1)
+		require.NoError(t, err)
+
+		// add accessor to the cache
+		height := uint64(1)
+		mock := &mockAccessor{}
+		_, err = cache.GetOrLoad(ctx, height, func(ctx context.Context) (eds.AccessorStreamer, error) {
+			return mock, nil
+		})
+		require.NoError(t, err)
+
+		// check if item exists
+		accessor, err := cache.Get(height)
+		require.NoError(t, err)
+		require.NotNil(t, accessor)
+
+		// close on returned accessor should not close inner accessor
+		err = accessor.Close()
+		require.NoError(t, err)
+
+		// check that close was not performed on inner accessor
+		mock.checkClosed(t, false)
+	})
+
+	t.Run("close on accessor should wait for all readers to finish", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		cache, err := NewAccessorCache("test", 1)
+		require.NoError(t, err)
+
+		// add accessor to the cache
+		height := uint64(1)
+		mock := &mockAccessor{}
+		accessor1, err := cache.GetOrLoad(ctx, height, func(ctx context.Context) (eds.AccessorStreamer, error) {
+			return mock, nil
+		})
+		require.NoError(t, err)
+
+		// create a second reader
+		accessor2, err := cache.Get(height)
+		require.NoError(t, err)
+
+		// initiate close
+		done := make(chan struct{})
+		go func() {
+			err := cache.Remove(height)
+			require.NoError(t, err)
+			close(done)
+		}()
+
+		// close on first reader and check that it is not enough to release the inner accessor
+		err = accessor1.Close()
+		require.NoError(t, err)
+		mock.checkClosed(t, false)
+
+		// second close from same reader should not release accessor either
+		err = accessor1.Close()
+		require.NoError(t, err)
+		mock.checkClosed(t, false)
+
+		// reads for item that is being evicted should result in ErrCacheMiss
+		_, err = cache.Get(height)
+		require.ErrorIs(t, err, ErrCacheMiss)
+
+		// close second reader and wait for accessor to be closed
+		err = accessor2.Close()
+		require.NoError(t, err)
+		// wait until close is performed on accessor
+		select {
+		case <-done:
+		case <-ctx.Done():
+			t.Fatal("timeout reached")
+		}
+
+		// item will be removed
+		mock.checkClosed(t, true)
+	})
+
+	t.Run("slow reader should not block eviction", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		cache, err := NewAccessorCache("test", 1)
+		require.NoError(t, err)
+
+		// add accessor to the cache
+		height1 := uint64(1)
+		mock1 := &mockAccessor{}
+		accessor1, err := cache.GetOrLoad(ctx, height1, func(ctx context.Context) (eds.AccessorStreamer, error) {
+			return mock1, nil
+		})
+		require.NoError(t, err)
+
+		// add second accessor, to trigger eviction of the first one
+		height2 := uint64(2)
+		mock2 := &mockAccessor{}
+		accessor2, err := cache.GetOrLoad(ctx, height2, func(ctx context.Context) (eds.AccessorStreamer, error) {
+			return mock2, nil
+		})
+		require.NoError(t, err)
+
+		// first accessor should be evicted from cache
+		_, err = cache.Get(height1)
+		require.ErrorIs(t, err, ErrCacheMiss)
+
+		// first accessor should not be closed before all refs are released by calling Close()
+		mock1.checkClosed(t, false)
+
+		// after Close() is called on first accessor, it is free to get closed
+		err = accessor1.Close()
+		require.NoError(t, err)
+		mock1.checkClosed(t, true)
+
+		// after Close is called on second accessor, it should stay in cache (not closed)
+		err = accessor2.Close()
+		require.NoError(t, err)
+		mock2.checkClosed(t, false)
+	})
+}
+
+type mockAccessor struct {
+	m        sync.Mutex
+	data     []byte
+	isClosed bool
+}
+
+func (m *mockAccessor) Size(context.Context) int {
+	panic("implement me")
+}
+
+func (m *mockAccessor) Sample(context.Context, int, int) (shwap.Sample, error) {
+	panic("implement me")
+}
+
+func (m *mockAccessor) AxisHalf(context.Context, rsmt2d.Axis, int) (eds.AxisHalf, error) {
+	panic("implement me")
+}
+
+func (m *mockAccessor) RowNamespaceData(context.Context, share.Namespace, int) (shwap.RowNamespaceData, error) {
+	panic("implement me")
+}
+
+func (m *mockAccessor) Shares(context.Context) ([]share.Share, error) {
+	panic("implement me")
+}
+
+func (m *mockAccessor) Reader() (io.Reader, error) {
+	m.m.Lock()
+	defer m.m.Unlock()
+	return bytes.NewBuffer(m.data), nil
+}
+
+func (m *mockAccessor) Close() error {
+	m.m.Lock()
+	defer m.m.Unlock()
+	if m.isClosed {
+		return errors.New("already closed")
+	}
+	m.isClosed = true
+	return nil
+}
+
+func (m *mockAccessor) checkClosed(t *testing.T, expected bool) {
+	// item will be removed async in background, give it some time to settle
+	time.Sleep(time.Millisecond * 100)
+	m.m.Lock()
+	defer m.m.Unlock()
+	require.Equal(t, expected, m.isClosed)
+}
diff --git a/store/cache/cache.go b/store/cache/cache.go
new file mode 100644
index 0000000000..7bdf247612
--- /dev/null
+++ b/store/cache/cache.go
@@ -0,0 +1,36 @@
+package cache
+
+import (
+	"context"
+	"errors"
+
+	logging "github.com/ipfs/go-log/v2"
+	"go.opentelemetry.io/otel"
+
+	eds "github.com/celestiaorg/celestia-node/share/new_eds"
+)
+
+var (
+	log   = logging.Logger("store/cache")
+	meter = otel.Meter("store_cache")
+)
+
+var ErrCacheMiss = errors.New("accessor not found in cache")
+
+// OpenAccessorFn is a loader function that opens and returns an accessor.
+type OpenAccessorFn func(context.Context) (eds.AccessorStreamer, error)
+
+// Cache is an interface that defines the basic Cache operations.
+type Cache interface {
+	// Get returns the eds.AccessorStreamer for the given height.
+	Get(height uint64) (eds.AccessorStreamer, error)
+
+	// GetOrLoad attempts to get an item from the Cache and, if not found, invokes
+	// the provided loader function to load it into the Cache.
+	GetOrLoad(ctx context.Context, height uint64, open OpenAccessorFn) (eds.AccessorStreamer, error)
+
+	// Remove removes an item from the Cache.
+	Remove(height uint64) error
+
+	// EnableMetrics enables metrics in the Cache.
+	EnableMetrics() error
+}
diff --git a/store/cache/doublecache.go b/store/cache/doublecache.go
new file mode 100644
index 0000000000..1e1230f703
--- /dev/null
+++ b/store/cache/doublecache.go
@@ -0,0 +1,51 @@
+package cache
+
+import (
+	"errors"
+
+	eds "github.com/celestiaorg/celestia-node/share/new_eds"
+)
+
+// DoubleCache represents a Cache that looks into multiple caches one by one.
+type DoubleCache struct {
+	first, second Cache
+}
+
+// NewDoubleCache creates a new DoubleCache with the provided caches.
+func NewDoubleCache(first, second Cache) *DoubleCache {
+	return &DoubleCache{
+		first:  first,
+		second: second,
+	}
+}
+
+// Get looks for an item in all the caches one by one and returns the first item found.
+func (mc *DoubleCache) Get(height uint64) (eds.AccessorStreamer, error) {
+	accessor, err := mc.first.Get(height)
+	if err == nil {
+		return accessor, nil
+	}
+	return mc.second.Get(height)
+}
+
+// Remove removes an item from all underlying caches.
+func (mc *DoubleCache) Remove(height uint64) error {
+	err1 := mc.first.Remove(height)
+	err2 := mc.second.Remove(height)
+	return errors.Join(err1, err2)
+}
+
+func (mc *DoubleCache) First() Cache {
+	return mc.first
+}
+
+func (mc *DoubleCache) Second() Cache {
+	return mc.second
+}
+
+func (mc *DoubleCache) EnableMetrics() error {
+	if err := mc.first.EnableMetrics(); err != nil {
+		return err
+	}
+	return mc.second.EnableMetrics()
+}
diff --git a/store/cache/metrics.go b/store/cache/metrics.go
new file mode 100644
index 0000000000..289ae80a96
--- /dev/null
+++ b/store/cache/metrics.go
@@ -0,0 +1,71 @@
+package cache
+
+import (
+	"context"
+
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/metric"
+)
+
+const (
+	cacheFoundKey = "found"
+	failedKey     = "failed"
+)
+
+type metrics struct {
+	getCounter     metric.Int64Counter
+	evictedCounter metric.Int64Counter
+	reg            metric.Registration
+}
+
+func newMetrics(bc *AccessorCache) (*metrics, error) {
+	metricsPrefix := "eds_cache" + bc.name
+
+	evictedCounter, err := meter.Int64Counter(metricsPrefix+"_evicted_counter",
+		metric.WithDescription("eds cache evicted event counter"))
+	if err != nil {
+		return nil, err
+	}
+
+	getCounter, err := meter.Int64Counter(metricsPrefix+"_get_counter",
+		metric.WithDescription("eds cache get event counter"))
+	if err != nil {
+		return nil, err
+	}
+
+	cacheSize, err := meter.Int64ObservableGauge(metricsPrefix+"_size",
+		metric.WithDescription("total amount of items in cache"),
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	callback := func(_ context.Context, observer metric.Observer) error {
+		observer.ObserveInt64(cacheSize, int64(bc.cache.Len()))
+		return nil
+	}
+	reg, err := meter.RegisterCallback(callback, cacheSize)
+
+	return &metrics{
+		getCounter:     getCounter,
+		evictedCounter: evictedCounter,
+		reg:            reg,
+	}, err
+}
+
+func (m *metrics) observeEvicted(failed bool) {
+	if m == nil {
+		return
+	}
+	m.evictedCounter.Add(context.Background(), 1,
+		metric.WithAttributes(
+			attribute.Bool(failedKey, failed)))
+}
+
+func (m *metrics) observeGet(found bool) {
+	if m == nil {
+		return
+	}
+	m.getCounter.Add(context.Background(), 1, metric.WithAttributes(
+		attribute.Bool(cacheFoundKey, found)))
+}
diff --git a/store/cache/noop.go b/store/cache/noop.go
new file mode 100644
index 0000000000..f1a2936cdb
--- /dev/null
+++ b/store/cache/noop.go
@@ -0,0 +1,72 @@
+package cache
+
+import (
+	"context"
+	"io"
+
+	"github.com/celestiaorg/rsmt2d"
+
+	"github.com/celestiaorg/celestia-node/share"
+	eds "github.com/celestiaorg/celestia-node/share/new_eds"
+	"github.com/celestiaorg/celestia-node/share/shwap"
+)
+
+var _ Cache = (*NoopCache)(nil)
+
+// NoopCache implements a no-op version of the Cache interface.
+type NoopCache struct{}
+
+func (n NoopCache) Get(uint64) (eds.AccessorStreamer, error) {
+	return nil, ErrCacheMiss
+}
+
+func (n NoopCache) GetOrLoad(ctx context.Context, _ uint64, loader OpenAccessorFn) (eds.AccessorStreamer, error) {
+	return loader(ctx)
+}
+
+func (n NoopCache) Remove(uint64) error {
+	return nil
+}
+
+func (n NoopCache) EnableMetrics() error {
+	return nil
+}
+
+var _ eds.AccessorStreamer = NoopFile{}
+
+// NoopFile implements a no-op version of the eds.AccessorStreamer interface.
+type NoopFile struct{}
+
+func (n NoopFile) Reader() (io.Reader, error) {
+	return noopReader{}, nil
+}
+
+func (n NoopFile) Size(context.Context) int {
+	return 0
+}
+
+func (n NoopFile) Sample(context.Context, int, int) (shwap.Sample, error) {
+	return shwap.Sample{}, nil
+}
+
+func (n NoopFile) AxisHalf(context.Context, rsmt2d.Axis, int) (eds.AxisHalf, error) {
+	return eds.AxisHalf{}, nil
+}
+
+func (n NoopFile) RowNamespaceData(context.Context, share.Namespace, int) (shwap.RowNamespaceData, error) {
+	return shwap.RowNamespaceData{}, nil
+}
+
+func (n NoopFile) Shares(context.Context) ([]share.Share, error) {
+	return []share.Share{}, nil
+}
+
+func (n NoopFile) Close() error {
+	return nil
+}
+
+type noopReader struct{}
+
+func (n noopReader) Read([]byte) (int, error) {
+	// Return io.EOF so consumers like io.ReadAll terminate; always returning (0, nil) would loop forever.
+	return 0, io.EOF
+}
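
For reviewers, a minimal sketch of the intended call pattern: GetOrLoad opens the accessor once and hands out a ref-counted wrapper, and each caller's Close releases only its own reference. The package main wiring, the import path of the cache package, and the openAccessor loader are assumptions made up for illustration (NoopFile stands in for a real accessor); only NewAccessorCache, GetOrLoad, Reader, and Close come from this diff.

package main

import (
	"context"
	"fmt"
	"io"

	eds "github.com/celestiaorg/celestia-node/share/new_eds"
	"github.com/celestiaorg/celestia-node/store/cache"
)

// openAccessor is a hypothetical loader standing in for real file-opening logic.
func openAccessor(context.Context, uint64) (eds.AccessorStreamer, error) {
	return cache.NoopFile{}, nil
}

func main() {
	ac, err := cache.NewAccessorCache("example", 128)
	if err != nil {
		panic(err)
	}

	ctx := context.Background()
	height := uint64(42)

	// The loader runs only on a cache miss; concurrent callers for the same
	// height are serialized by the striped lock.
	accessor, err := ac.GetOrLoad(ctx, height, func(ctx context.Context) (eds.AccessorStreamer, error) {
		return openAccessor(ctx, height)
	})
	if err != nil {
		panic(err)
	}
	// Close releases only this caller's reference; the underlying accessor is
	// closed once all references are released and the entry is evicted or removed.
	defer accessor.Close()

	reader, err := accessor.Reader()
	if err != nil {
		panic(err)
	}
	data, _ := io.ReadAll(reader)
	fmt.Printf("read %d bytes\n", len(data))
}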
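
Likewise, a sketch of composing two caches with DoubleCache, under the same import-path assumption. newReadCache and readAt are hypothetical helpers with illustrative names and sizes; note that DoubleCache deliberately exposes no GetOrLoad, so loading on a miss must target one specific underlying cache.

package main

import (
	"context"

	eds "github.com/celestiaorg/celestia-node/share/new_eds"
	"github.com/celestiaorg/celestia-node/store/cache"
)

// newReadCache wires a small cache of recently written heights in front of a
// larger general-purpose one.
func newReadCache() (*cache.DoubleCache, error) {
	recent, err := cache.NewAccessorCache("recent", 16)
	if err != nil {
		return nil, err
	}
	general, err := cache.NewAccessorCache("general", 128)
	if err != nil {
		return nil, err
	}
	// Get consults recent first and falls back to general;
	// Remove removes the height from both underlying caches.
	return cache.NewDoubleCache(recent, general), nil
}

// readAt reads through the combined cache; on a miss, it loads the accessor
// into the first underlying cache, since DoubleCache itself has no GetOrLoad.
func readAt(
	ctx context.Context,
	combined *cache.DoubleCache,
	height uint64,
	load cache.OpenAccessorFn,
) (eds.AccessorStreamer, error) {
	if accessor, err := combined.Get(height); err == nil {
		return accessor, nil
	}
	return combined.First().GetOrLoad(ctx, height, load)
}

func main() {}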