Skip to content

Commit

Permalink
clean up and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yujieli-temporal committed Aug 1, 2023
1 parent 31f963d commit 3ddbfb2
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 28 deletions.
59 changes: 31 additions & 28 deletions common/cache/lru.go
Expand Up @@ -274,63 +274,53 @@ func (c *lru) putInternal(key interface{}, value interface{}, allowUpdate bool)
if c.maxSize == 0 {
return nil, nil
}
entrySize := getSize(value)
if entrySize > c.maxSize {
newEntrySize := getSize(value)
if newEntrySize > c.maxSize {
return nil, ErrCacheItemTooLarge
}

c.mut.Lock()
defer c.mut.Unlock()

c.currSize += entrySize
c.tryEvictUntilEnoughSpace()
// If there is still not enough space, remove the new entry size from the current size and return an error
if c.currSize > c.maxSize {
c.currSize -= entrySize
c.tryEvictUntilEnoughSpace(newEntrySize)
// after evicting, check if the new entry can fit in the cache
if c.currSize+newEntrySize > c.maxSize {
return nil, ErrCacheFull
}

elt := c.byKey[key]
// If the entry exists, check if it has expired or update the value
if elt != nil {
existingEntry := elt.Value.(*entryImpl)
if c.isEntryExpired(existingEntry, time.Now().UTC()) {
// Entry has expired
c.deleteInternal(elt)
} else {
if !c.isEntryExpired(existingEntry, time.Now().UTC()) {
existing := existingEntry.value
if allowUpdate {
c.currSize -= existingEntry.Size()
existingEntry.value = value
if c.ttl != 0 {
existingEntry.createTime = time.Now().UTC()
}
c.currSize += newEntrySize
c.updateEntryTTL(existingEntry)
}

c.updateEntryRefCount(existingEntry)
c.byAccess.MoveToFront(elt)
if c.pin {
existingEntry.refCount++
}
return existing, nil
}

// Entry has expired
c.deleteInternal(elt)
}

entry := &entryImpl{
key: key,
value: value,
size: entrySize,
}

if c.pin {
entry.refCount++
}

if c.ttl != 0 {
entry.createTime = c.timeSource.Now().UTC()
size: newEntrySize,
}

c.updateEntryTTL(entry)
c.updateEntryRefCount(entry)
element := c.byAccess.PushFront(entry)
c.byKey[key] = element
c.currSize += newEntrySize
return nil, nil
}

Expand All @@ -341,9 +331,10 @@ func (c *lru) deleteInternal(element *list.Element) {
}

// tryEvictUntilEnoughSpace try to evict entries until there is enough space for the new entry
func (c *lru) tryEvictUntilEnoughSpace() {
func (c *lru) tryEvictUntilEnoughSpace(newEntrySize int) {
element := c.byAccess.Back()
for c.currSize > c.maxSize && element != nil {
// currSize will be updated within deleteInternal
for c.currSize+newEntrySize > c.maxSize && element != nil {
entry := element.Value.(*entryImpl)
if entry.refCount == 0 {
c.deleteInternal(element)
Expand All @@ -358,3 +349,15 @@ func (c *lru) tryEvictUntilEnoughSpace() {
func (c *lru) isEntryExpired(entry *entryImpl, currentTime time.Time) bool {
return entry.refCount == 0 && !entry.createTime.IsZero() && currentTime.After(entry.createTime.Add(c.ttl))
}

func (c *lru) updateEntryTTL(entry *entryImpl) {
if c.ttl != 0 {
entry.createTime = c.timeSource.Now().UTC()
}
}

func (c *lru) updateEntryRefCount(entry *entryImpl) {
if c.pin {
entry.refCount++
}
}
23 changes: 23 additions & 0 deletions common/cache/lru_test.go
Expand Up @@ -71,19 +71,23 @@ func TestLRU(t *testing.T) {

cache.Put("A", "Foo2")
assert.Equal(t, "Foo2", cache.Get("A"))
assert.Equal(t, 4, cache.Size())

cache.Put("E", "Epsi")
assert.Equal(t, "Epsi", cache.Get("E"))
assert.Equal(t, "Foo2", cache.Get("A"))
assert.Nil(t, cache.Get("B")) // Oldest, should be evicted
assert.Equal(t, 4, cache.Size())

// Access C, D is now LRU
cache.Get("C")
cache.Put("F", "Felp")
assert.Nil(t, cache.Get("D"))
assert.Equal(t, 4, cache.Size())

cache.Delete("A")
assert.Nil(t, cache.Get("A"))
assert.Equal(t, 3, cache.Size())
}

func TestGenerics(t *testing.T) {
Expand Down Expand Up @@ -210,13 +214,16 @@ func TestTTLWithPin(t *testing.T) {
_, err := cache.PutIfNotExist("A", t)
assert.NoError(t, err)
assert.Equal(t, t, cache.Get("A"))
assert.Equal(t, 1, cache.Size())
timeSource.Advance(time.Millisecond * 100)
assert.Equal(t, t, cache.Get("A"))
assert.Equal(t, 1, cache.Size())
// release 3 time since put if not exist also increase the counter
cache.Release("A")
cache.Release("A")
cache.Release("A")
assert.Nil(t, cache.Get("A"))
assert.Equal(t, 0, cache.Size())
}

func TestMaxSizeWithPin_MidItem(t *testing.T) {
Expand All @@ -231,31 +238,38 @@ func TestMaxSizeWithPin_MidItem(t *testing.T) {

_, err := cache.PutIfNotExist("A", t)
assert.NoError(t, err)
assert.Equal(t, 1, cache.Size())

_, err = cache.PutIfNotExist("B", t)
assert.NoError(t, err)
assert.Equal(t, 2, cache.Size())

_, err = cache.PutIfNotExist("C", t)
assert.Error(t, err)
assert.Equal(t, 2, cache.Size())

assert.Equal(t, t, cache.Get("A"))
cache.Release("A") // get will also increase the ref count
assert.Equal(t, t, cache.Get("B"))
cache.Release("B") // get will also increase the ref count
assert.Equal(t, 2, cache.Size())

cache.Release("B") // B's ref count is 0
_, err = cache.PutIfNotExist("C", t)
assert.NoError(t, err)
assert.Equal(t, t, cache.Get("C"))
cache.Release("C") // get will also increase the ref count
assert.Equal(t, 2, cache.Size())

cache.Release("A") // A's ref count is 0
cache.Release("C") // C's ref count is 0
assert.Equal(t, 2, cache.Size())

timeSource.Advance(time.Millisecond * 100)
assert.Nil(t, cache.Get("A"))
assert.Nil(t, cache.Get("B"))
assert.Nil(t, cache.Get("C"))
assert.Equal(t, 0, cache.Size())
}

func TestMaxSizeWithPin_LastItem(t *testing.T) {
Expand All @@ -270,31 +284,38 @@ func TestMaxSizeWithPin_LastItem(t *testing.T) {

_, err := cache.PutIfNotExist("A", t)
assert.NoError(t, err)
assert.Equal(t, 1, cache.Size())

_, err = cache.PutIfNotExist("B", t)
assert.NoError(t, err)
assert.Equal(t, 2, cache.Size())

_, err = cache.PutIfNotExist("C", t)
assert.Error(t, err)
assert.Equal(t, 2, cache.Size())

assert.Equal(t, t, cache.Get("A"))
cache.Release("A") // get will also increase the ref count
assert.Equal(t, t, cache.Get("B"))
cache.Release("B") // get will also increase the ref count
assert.Equal(t, 2, cache.Size())

cache.Release("A") // A's ref count is 0
_, err = cache.PutIfNotExist("C", t)
assert.NoError(t, err)
assert.Equal(t, t, cache.Get("C"))
cache.Release("C") // get will also increase the ref count
assert.Equal(t, 2, cache.Size())

cache.Release("B") // B's ref count is 0
cache.Release("C") // C's ref count is 0
assert.Equal(t, 2, cache.Size())

timeSource.Advance(time.Millisecond * 100)
assert.Nil(t, cache.Get("A"))
assert.Nil(t, cache.Get("B"))
assert.Nil(t, cache.Get("C"))
assert.Equal(t, 0, cache.Size())
}

func TestIterator(t *testing.T) {
Expand Down Expand Up @@ -359,10 +380,12 @@ func TestCache_ItemSizeTooLarge(t *testing.T) {

res := cache.Put(uuid.New(), &testEntryWithCacheSize{maxTotalBytes})
assert.Equal(t, res, nil)
assert.Equal(t, 10, cache.Size())

res, err := cache.PutIfNotExist(uuid.New(), &testEntryWithCacheSize{maxTotalBytes + 1})
assert.Equal(t, err, ErrCacheItemTooLarge)
assert.Equal(t, res, nil)
assert.Equal(t, 10, cache.Size())

}

Expand Down

0 comments on commit 3ddbfb2

Please sign in to comment.