// Package singleflight implements a call sharing mechanism: concurrent
// callers that share a key share the result of a single function call.
//
// This file (singleflight.go) is forked from azazeal/singleflight.
package singleflight
import (
	"context"
	"errors"
	"sync"

	"golang.org/x/sync/semaphore"
)
// Caller deduplicates concurrent calls that share a key of type K: every
// participant receives the V and error produced by a single invocation.
//
// The zero value is ready for use; the internal map is allocated on the first
// Call. A Caller must not be copied after first use.
type Caller[K comparable, V any] struct {
	mu    sync.Mutex     // guards calls
	calls map[K]*call[V] // in-flight calls by key; nil until the first Call
}
// Semaphore weights used by Call. Each call's semaphore has capacity
// writerWeight. The caller that runs fn holds the whole capacity, so callers
// that join and acquire readerWeight block until fn's results are stored.
const (
	readerWeight = 1
	writerWeight = 1 << 30
)
// call holds the shared state of one in-flight invocation started by Call.
type call[V any] struct {
	// sem gates access to val and err. The starting caller creates it with
	// capacity writerWeight and holds all of it until the results are stored.
	// Joining callers acquire readerWeight and therefore wait until then.
	sem *semaphore.Weighted

	// val and err are the call's results. They are written before sem is
	// released and only read after acquiring it.
	val V
	err error
}
// errCallAborted is returned to callers that joined an in-flight call whose fn
// did not return normally (it panicked or called runtime.Goexit).
var errCallAborted = errors.New("singleflight: in-flight call aborted before returning (fn panicked or called runtime.Goexit)")

// Call calls fn and returns the results. Concurrent callers sharing a key will also share the results of the first
// call: while a call for key is in flight, further Calls for that key do not invoke fn. Instead they wait for, and
// return, the in-flight call's results.
//
// fn runs with the context of the caller that started the call. That caller's cancellation or deadline therefore
// affects the shared result seen by every caller waiting on it. A waiting caller whose own ctx is done before the
// result is available returns the zero V and ctx's error.
//
// If fn panics or calls runtime.Goexit, the panic or Goexit proceeds in the starting caller's goroutine as usual.
// The key is released, so a later Call starts a fresh invocation. Callers that were waiting on the aborted call
// receive the zero V and a non-nil error.
//
// fn must not call Call on the same Caller with the same key. The nested call would wait for the outer one and
// block until its ctx is done.
//
// fn may access the key passed to Call via KeyFromContext.
func (caller *Caller[K, V]) Call(ctx context.Context, key K, fn func(context.Context) (V, error)) (V, error) {
	caller.mu.Lock()
	if caller.calls == nil {
		caller.calls = make(map[K]*call[V])
	}

	// An in-flight call exists: attach to it as a reader. Acquiring
	// readerWeight blocks until the starting caller releases writerWeight,
	// i.e. until the call's results have been stored.
	if inflight, ok := caller.calls[key]; ok {
		caller.mu.Unlock()
		if err := inflight.sem.Acquire(ctx, readerWeight); err != nil {
			var zero V
			return zero, err
		}
		defer inflight.sem.Release(readerWeight)
		return inflight.val, inflight.err
	}

	// No in-flight call: start one. Holding the semaphore's full capacity
	// keeps every joining reader blocked until the results are stored.
	c := &call[V]{sem: semaphore.NewWeighted(writerWeight)}
	_ = c.sem.TryAcquire(writerWeight) // cannot fail: the semaphore is fresh and not yet shared
	caller.calls[key] = c
	caller.mu.Unlock()

	// Release waiters and forget the call even if fn panics or calls
	// runtime.Goexit. Otherwise the key would remain registered with its
	// semaphore held, blocking every current and future caller for it.
	completed := false
	defer func() {
		if !completed {
			c.err = errCallAborted
		}
		caller.mu.Lock()
		c.sem.Release(writerWeight)
		delete(caller.calls, key)
		caller.mu.Unlock()
	}()

	c.val, c.err = fn(context.WithValue(ctx, contextKeyType[K]{}, key))
	completed = true

	return c.val, c.err
}
// contextKeyType is the context.Value key under which Call stores the key it
// was given. It is an unexported zero-size type, so it cannot collide with keys
// defined in other packages. Instantiating it with K gives each key type its
// own distinct context key.
type contextKeyType[K comparable] struct{}
// KeyFromContext reports the key that Call attached to ctx before invoking fn.
// It panics if ctx carries no key of type K.
func (*Caller[K, V]) KeyFromContext(ctx context.Context) K {
	stored := ctx.Value(contextKeyType[K]{})
	return stored.(K)
}