/
cachemap.go
103 lines (84 loc) · 2.02 KB
/
cachemap.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package dagql
import (
"context"
"fmt"
"sync"
"github.com/opencontainers/go-digest"
)
type (
	// cacheMap stores the results of computations of T keyed by K and lets
	// concurrent callers for the same key share a single in-flight computation.
	// Every read or write of calls happens while l is held.
	cacheMap[K comparable, T any] struct {
		l     sync.Mutex
		calls map[K]*cache[T]
	}

	// cache is a single entry of a cacheMap. Readers must wg.Wait() before
	// looking at val and err; the writer fills them in and then releases wg.
	// An entry whose wg was never Add-ed counts as already complete.
	cache[T any] struct {
		wg  sync.WaitGroup
		val T
		err error
	}
)
// NewCache returns an empty cache keyed by digest.Digest and holding Typed
// values. A single instance may be assigned to one Server or shared by
// several Servers.
func NewCache() Cache {
	var c Cache = newCacheMap[digest.Digest, Typed]()
	return c
}
// newCacheMap allocates a cacheMap whose calls map is already initialized,
// so entries can be stored in it immediately.
func newCacheMap[K comparable, T any]() *cacheMap[K, T] {
	m := new(cacheMap[K, T])
	m.calls = make(map[K]*cache[T])
	return m
}
// cacheMapContextKey is stored in the context handed to an initializer to
// record that key is being initialized on map m. The lookup methods check for
// it so that a nested call for the same key on the same map can be detected.
type cacheMapContextKey[K comparable, T any] struct {
	key K
	m   *cacheMap[K, T]
}

// ErrCacheMapRecursiveCall is returned by the cacheMap lookup methods when
// ctx shows that the caller is already inside the initializer for the same
// key on the same map. Waiting on that in-flight entry would never finish.
var ErrCacheMapRecursiveCall = fmt.Errorf("recursive call detected")
// Set stores val under key as a completed, error-free entry, replacing any
// entry that was there. Callers already waiting on a replaced entry keep their
// reference to it and receive that entry's result.
func (m *cacheMap[K, T]) Set(key K, val T) {
	entry := &cache[T]{val: val}

	m.l.Lock()
	defer m.l.Unlock()
	m.calls[key] = entry
}
// GetOrInitialize returns the value stored under key, running fn to produce
// it when no entry exists. It is GetOrInitializeOnHit with no hit callback.
func (m *cacheMap[K, T]) GetOrInitialize(ctx context.Context, key K, fn func(ctx context.Context) (T, error)) (T, error) {
	// A nil onHit is skipped by GetOrInitializeOnHit, which is equivalent to
	// passing a callback that does nothing.
	return m.GetOrInitializeOnHit(ctx, key, fn, nil)
}
// GetOrInitializeOnHit returns the value stored under key, running fn to
// produce it when no entry exists yet.
//
// Concurrent callers for the same key share one computation. The first caller
// runs fn; later callers block until it finishes and then get the same value
// and error. When an existing entry is found, whether finished or still in
// flight, onHit is called (if non-nil) with that entry's value and error once
// it is complete.
//
// fn receives a context tagged with this (map, key) pair. A nested lookup for
// the same key through that context therefore returns
// ErrCacheMapRecursiveCall instead of waiting on its own unfinished entry.
//
// Failures are not cached. If fn returns a non-nil error or panics, the entry
// is removed so a later call can retry. Callers already waiting on it still
// get its error. If fn panics, waiters get an error and the panic continues
// to unwind in the caller that ran fn.
func (m *cacheMap[K, T]) GetOrInitializeOnHit(ctx context.Context, key K, fn func(ctx context.Context) (T, error), onHit func(T, error)) (T, error) {
	ctxKey := cacheMapContextKey[K, T]{key: key, m: m}
	if v := ctx.Value(ctxKey); v != nil {
		var zero T
		return zero, ErrCacheMapRecursiveCall
	}

	m.l.Lock()
	if c, ok := m.calls[key]; ok {
		m.l.Unlock()
		// Wait outside the lock so lookups of other keys are not blocked
		// while this entry is being computed.
		c.wg.Wait()
		if onHit != nil {
			onHit(c.val, c.err)
		}
		return c.val, c.err
	}
	c := &cache[T]{}
	c.wg.Add(1)
	m.calls[key] = c
	m.l.Unlock()

	// Finalize in a defer so waiters are always released and a failed entry
	// is always evicted, even if fn panics. Without it, a panic would leave
	// wg un-Done, and every current and future caller for key would block
	// forever.
	completed := false
	defer func() {
		if !completed {
			// fn panicked. Give waiters an error rather than a zero value
			// that looks like success.
			c.err = fmt.Errorf("cache initializer for key %v panicked", key)
		}
		if c.err != nil {
			m.l.Lock()
			// Evict only our own entry. Set may have replaced it while fn
			// was running, and that value must be kept.
			if m.calls[key] == c {
				delete(m.calls, key)
			}
			m.l.Unlock()
		}
		c.wg.Done()
	}()

	c.val, c.err = fn(context.WithValue(ctx, ctxKey, struct{}{}))
	completed = true
	return c.val, c.err
}
// Get returns the result stored under key without ever running an
// initializer. If the entry is still being computed, Get blocks until it is
// complete. It returns ErrCacheMapRecursiveCall when ctx shows the caller is
// already initializing key on this map, and a "key not found" error when no
// entry exists.
func (m *cacheMap[K, T]) Get(ctx context.Context, key K) (T, error) {
	var zero T
	if ctx.Value(cacheMapContextKey[K, T]{key: key, m: m}) != nil {
		return zero, ErrCacheMapRecursiveCall
	}

	// Look the entry up under the lock, but release the lock before waiting
	// so an in-flight initializer can still reach the map.
	m.l.Lock()
	entry, found := m.calls[key]
	m.l.Unlock()

	if !found {
		return zero, fmt.Errorf("key not found")
	}
	entry.wg.Wait()
	return entry.val, entry.err
}