diff --git a/util/singleflight/singleflight.go b/util/singleflight/singleflight.go index a1c9e47667c2b..9df47448b70ab 100644 --- a/util/singleflight/singleflight.go +++ b/util/singleflight/singleflight.go @@ -20,11 +20,13 @@ package singleflight // import "tailscale.com/util/singleflight" import ( "bytes" + "context" "errors" "fmt" "runtime" "runtime/debug" "sync" + "sync/atomic" ) // errGoexit indicates the runtime.Goexit was called in @@ -69,6 +71,11 @@ type call[V any] struct { // not written after the WaitGroup is done. dups int chans []chan<- Result[V] + + // These fields are only written when the call is being created, and + // only in the DoChanContext method. + cancel context.CancelFunc + ctxWaiters atomic.Int64 } // Group represents a class of work and forms a namespace in @@ -143,6 +150,93 @@ func (g *Group[K, V]) DoChan(key K, fn func() (V, error)) <-chan Result[V] { return ch } +// DoChanContext is like [Group.DoChan], but supports context cancelation. The +// context passed to the fn function is a context that is canceled only when +// there are no callers waiting on a result (i.e. all callers have canceled +// their contexts). +// +// The context that is passed to the fn function is not derived from any of the +// input contexts, so context values will not be propagated. If context values +// are needed, they must be propagated explicitly. +// +// The returned channel will not be closed. The Result.Err field is set to the +// context error if the context is canceled. +func (g *Group[K, V]) DoChanContext(ctx context.Context, key K, fn func(context.Context) (V, error)) <-chan Result[V] { + ch := make(chan Result[V], 1) + g.mu.Lock() + if g.m == nil { + g.m = make(map[K]*call[V]) + } + c, ok := g.m[key] + if ok { + // Call already in progress; add to the waiters list and then + // release the mutex. + c.dups++ + c.ctxWaiters.Add(1) + c.chans = append(c.chans, ch) + g.mu.Unlock() + } else { + // The call hasn't been started yet; we need to start it. + // + // Create a context that is not canceled when the parent context is, + // but otherwise propagates all values. + callCtx, callCancel := context.WithCancel(context.Background()) + c = &call[V]{ + chans: []chan<- Result[V]{ch}, + cancel: callCancel, + } + c.wg.Add(1) + c.ctxWaiters.Add(1) // one caller waiting + g.m[key] = c + g.mu.Unlock() + + // Wrap our function to provide the context. + go g.doCall(c, key, func() (V, error) { + return fn(callCtx) + }) + } + + // Instead of returning the channel directly, we need to track + // when the call finishes so we can handle context cancelation. + // Do so by creating an final channel that gets the + // result and hooking that up to the wait function. + final := make(chan Result[V], 1) + go g.waitCtx(ctx, c, ch, final) + return final +} + +// waitCtx will wait on the provided call to finish, or the context to be done. +// If the context is done, and this is the last waiter, then the context +// provided to the underlying function will be canceled. +func (g *Group[K, V]) waitCtx(ctx context.Context, c *call[V], result <-chan Result[V], output chan<- Result[V]) { + var res Result[V] + select { + case <-ctx.Done(): + case res = <-result: + } + + // Decrement the caller count, and if we're the last one, cancel the + // context we created. Do this in all cases, error and otherwise, so we + // don't leak goroutines. + // + // Also wait on the call to finish, so we know that the call has + // finished executing after the last caller has returned. + if c.ctxWaiters.Add(-1) == 0 { + c.cancel() + c.wg.Wait() + } + + // Ensure that context cancelation takes precedence over a value being + // available by checking ctx.Err() before sending the result to the + // caller. The select above will nondeterministically pick a case if a + // result is available and the ctx.Done channel is closed, so we check + // again here. + if err := ctx.Err(); err != nil { + res = Result[V]{Err: err} + } + output <- res +} + // doCall handles the single call for a key. func (g *Group[K, V]) doCall(c *call[V], key K, fn func() (V, error)) { normalReturn := false diff --git a/util/singleflight/singleflight_test.go b/util/singleflight/singleflight_test.go index b98fae8850d9e..031922736fab6 100644 --- a/util/singleflight/singleflight_test.go +++ b/util/singleflight/singleflight_test.go @@ -9,6 +9,7 @@ package singleflight import ( "bytes" + "context" "errors" "fmt" "os" @@ -321,3 +322,155 @@ func TestPanicDoSharedByDoChan(t *testing.T) { t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do") } } + +func TestDoChanContext(t *testing.T) { + t.Run("Basic", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var g Group[string, int] + ch := g.DoChanContext(ctx, "key", func(_ context.Context) (int, error) { + return 1, nil + }) + ret := <-ch + assertOKResult(t, ret, 1) + }) + + t.Run("DoesNotPropagateValues", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + key := new(int) + const value = "hello world" + + ctx = context.WithValue(ctx, key, value) + + var g Group[string, int] + ch := g.DoChanContext(ctx, "foobar", func(ctx context.Context) (int, error) { + if _, ok := ctx.Value(key).(string); ok { + t.Error("expected no value, but was present in context") + } + return 1, nil + }) + ret := <-ch + assertOKResult(t, ret, 1) + }) + + t.Run("NoCancelWhenWaiters", func(t *testing.T) { + testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer testCancel() + + trigger := make(chan struct{}) + + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + fn := func(ctx context.Context) (int, error) { + select { + case <-ctx.Done(): + return 0, ctx.Err() + case <-trigger: + return 1234, nil + } + } + + // Create two waiters, then cancel the first before we trigger + // the function to return a value. This shouldn't result in a + // context canceled error. + var g Group[string, int] + ch1 := g.DoChanContext(ctx1, "key", fn) + ch2 := g.DoChanContext(ctx2, "key", fn) + + cancel1() + + // The first channel, now that it's canceled, should return a + // context canceled error. + select { + case res := <-ch1: + if !errors.Is(res.Err, context.Canceled) { + t.Errorf("unexpected error; got %v, want context.Canceled", res.Err) + } + case <-testCtx.Done(): + t.Fatal("test timed out") + } + + // Actually return + close(trigger) + res := <-ch2 + assertOKResult(t, res, 1234) + }) + + t.Run("AllCancel", func(t *testing.T) { + for _, n := range []int{1, 2, 10, 20} { + t.Run(fmt.Sprintf("NumWaiters=%d", n), func(t *testing.T) { + testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer testCancel() + + trigger := make(chan struct{}) + defer close(trigger) + + fn := func(ctx context.Context) (int, error) { + select { + case <-ctx.Done(): + return 0, ctx.Err() + case <-trigger: + t.Error("unexpected trigger; want all callers to cancel") + return 0, errors.New("unexpected trigger") + } + } + + // Launch N goroutines that all wait on the same key. + var ( + g Group[string, int] + chs []<-chan Result[int] + cancels []context.CancelFunc + ) + for i := range n { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cancels = append(cancels, cancel) + + ch := g.DoChanContext(ctx, "key", fn) + chs = append(chs, ch) + + // Every third goroutine should cancel + // immediately, which better tests the + // cancel logic. + if i%3 == 0 { + cancel() + } + } + + // Now that everything is waiting, cancel all the contexts. + for _, cancel := range cancels { + cancel() + } + + // Wait for a result from each channel. They + // should all return an error showing a context + // cancel. + for _, ch := range chs { + select { + case res := <-ch: + if !errors.Is(res.Err, context.Canceled) { + t.Errorf("unexpected error; got %v, want context.Canceled", res.Err) + } + case <-testCtx.Done(): + t.Fatal("test timed out") + } + } + }) + } + }) +} + +func assertOKResult[V comparable](t testing.TB, res Result[V], want V) { + if res.Err != nil { + t.Fatalf("unexpected error: %v", res.Err) + } + if res.Val != want { + t.Fatalf("unexpected value; got %v, want %v", res.Val, want) + } +}