diff --git a/util/singleflight/singleflight.go b/util/singleflight/singleflight.go
index a1c9e47667c2b..b4ffac8745930 100644
--- a/util/singleflight/singleflight.go
+++ b/util/singleflight/singleflight.go
@@ -20,11 +20,13 @@ package singleflight // import "tailscale.com/util/singleflight"
 
 import (
 	"bytes"
+	"context"
 	"errors"
 	"fmt"
 	"runtime"
 	"runtime/debug"
 	"sync"
+	"sync/atomic"
 )
 
 // errGoexit indicates the runtime.Goexit was called in
@@ -69,6 +71,13 @@ type call[V any] struct {
 	// not written after the WaitGroup is done.
 	dups  int
 	chans []chan<- Result[V]
+
+	// cancel and ctxWaiters are used only by DoChanContext. cancel is set
+	// once, when DoChanContext creates the call, and is left nil otherwise.
+	// ctxWaiters counts the DoChanContext callers still waiting on the
+	// call; it is updated atomically for as long as the call is in flight.
+	cancel     context.CancelFunc
+	ctxWaiters atomic.Int64
 }
 
 // Group represents a class of work and forms a namespace in
@@ -143,6 +152,100 @@ func (g *Group[K, V]) DoChan(key K, fn func() (V, error)) <-chan Result[V] {
 	return ch
 }
 
+// DoChanContext is like DoChan, but supports context cancelation. The context
+// passed to the fn function is a context that is canceled only when there are
+// no callers waiting on a result (i.e. all callers have canceled their
+// contexts).
+//
+// The returned channel receives a single Result: the shared result of the
+// call, or an error Result carrying ctx.Err() if ctx is done first.
+func (g *Group[K, V]) DoChanContext(ctx context.Context, key K, fn func(context.Context) (V, error)) <-chan Result[V] {
+	ch := make(chan Result[V], 1)
+	g.mu.Lock()
+	if g.m == nil {
+		g.m = make(map[K]*call[V])
+	}
+	c, ok := g.m[key]
+	if ok {
+		// A call for key is already in flight; join it as one more waiter.
+		c.dups++
+		c.ctxWaiters.Add(1)
+		c.chans = append(c.chans, ch)
+		g.mu.Unlock()
+	} else {
+		// Create a context that is not canceled when the parent context is,
+		// but otherwise propagates all values.
+		callCtx, callCancel := context.WithCancel(context.WithoutCancel(ctx))
+		c = &call[V]{
+			chans:  []chan<- Result[V]{ch},
+			cancel: callCancel,
+		}
+		c.wg.Add(1)
+		c.ctxWaiters.Add(1) // one caller waiting
+		g.m[key] = c
+		g.mu.Unlock()
+
+		// Wrap our function to provide the context.
+		go g.doCall(c, key, func() (V, error) {
+			return fn(callCtx)
+		})
+	}
+
+	// Instead of returning the channel directly, we need to track
+	// when the call finishes so we can handle context cancelation.
+	// Do so by creating a final channel that gets the
+	// result and hooking that up to the wait function.
+	final := make(chan Result[V], 1)
+	go g.waitCtx(ctx, c, ch, final)
+	return final
+}
+
+// waitCtx will wait on the provided call to finish, or the context to be done.
+// If the context is done, and this is the last waiter, then the context
+// provided to the underlying function will be canceled.
+//
+// A single Result is sent on output: the value received from result, or an
+// error Result holding ctx.Err() if ctx was done first.
+func (g *Group[K, V]) waitCtx(ctx context.Context, c *call[V], result <-chan Result[V], output chan<- Result[V]) {
+	var (
+		res Result[V]
+		err error
+	)
+	select {
+	case <-ctx.Done():
+		err = ctx.Err()
+	case res = <-result:
+	}
+
+	// Decrement the caller count, and if we're the last one, cancel the
+	// context we created. Do this in all cases, error and otherwise, so we
+	// don't leak goroutines.
+	//
+	// Also wait on the call to finish, so we know that the call has
+	// finished executing after the last caller has returned.
+	//
+	// c.cancel is nil when the call was not created by DoChanContext and
+	// this waiter only joined it. There is then no context to cancel, and
+	// calling a nil func would panic, so skip both steps and leave the
+	// call running for whoever started it.
+	//
+	// NOTE(review): the call presumably stays in g.m until doCall removes
+	// it, so a caller joining after the count hit zero would share an
+	// already-canceled call - confirm whether that needs handling.
+	if c.ctxWaiters.Add(-1) == 0 && c.cancel != nil {
+		c.cancel()
+		c.wg.Wait()
+	}
+
+	// Send the result to the caller; if the error was non-nil, we send
+	// that, otherwise we send the result we got.
+	if err != nil {
+		output <- Result[V]{Err: err}
+		return
+	}
+	output <- res
+}
+
 // doCall handles the single call for a key.
func (g *Group[K, V]) doCall(c *call[V], key K, fn func() (V, error)) { normalReturn := false diff --git a/util/singleflight/singleflight_test.go b/util/singleflight/singleflight_test.go index b98fae8850d9e..e0476f9b06ec0 100644 --- a/util/singleflight/singleflight_test.go +++ b/util/singleflight/singleflight_test.go @@ -9,6 +9,7 @@ package singleflight import ( "bytes" + "context" "errors" "fmt" "os" @@ -321,3 +322,153 @@ func TestPanicDoSharedByDoChan(t *testing.T) { t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do") } } + +func TestDoChanContext(t *testing.T) { + t.Run("Basic", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var g Group[string, int] + ch := g.DoChanContext(ctx, "key", func(_ context.Context) (int, error) { + return 1, nil + }) + ret := <-ch + assertOKResult(t, ret, 1) + }) + + t.Run("PropagatesValues", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + key := new(int) + const value = "hello world" + + ctx = context.WithValue(ctx, key, value) + + var g Group[string, int] + ch := g.DoChanContext(ctx, "key", func(_ context.Context) (int, error) { + gotVal, ok := ctx.Value(key).(string) + if !ok { + t.Fatalf("expected value to be present in context") + } + if gotVal != value { + t.Fatalf("unexpected value; got %q, want %q", gotVal, value) + } + return 1, nil + }) + ret := <-ch + assertOKResult(t, ret, 1) + }) + + t.Run("NoCancelWhenWaiters", func(t *testing.T) { + testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer testCancel() + + trigger := make(chan struct{}) + + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + fn := func(ctx context.Context) (int, error) { + select { + case <-ctx.Done(): + return 0, ctx.Err() + case <-trigger: + return 1234, nil + } + } + + // Create two 
waiters, then cancel the first before we trigger + // the function to return a value. This shouldn't result in a + // context canceled error. + var g Group[string, int] + ch1 := g.DoChanContext(ctx1, "key", fn) + ch2 := g.DoChanContext(ctx2, "key", fn) + + cancel1() + + // The first channel, now that it's canceled, should return a + // context canceled error. + select { + case res := <-ch1: + if !errors.Is(res.Err, context.Canceled) { + t.Errorf("unexpected error; got %v, want context.Canceled", res.Err) + } + case <-testCtx.Done(): + t.Fatal("test timed out") + } + + // Actually return + close(trigger) + res := <-ch2 + assertOKResult(t, res, 1234) + }) + + t.Run("AllCancel", func(t *testing.T) { + for _, n := range []int{1, 2, 10} { + n := n + t.Run(fmt.Sprintf("NumWaiters=%d", n), func(t *testing.T) { + testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer testCancel() + + trigger := make(chan struct{}) + defer close(trigger) + + fn := func(ctx context.Context) (int, error) { + select { + case <-ctx.Done(): + return 0, ctx.Err() + case <-trigger: + t.Fatal("unexpected trigger; want all callers to cancel") + return 0, errors.New("unexpected trigger") + } + } + + // Launch N goroutines that all wait on the same key. + var ( + g Group[string, int] + chs []<-chan Result[int] + cancels []context.CancelFunc + ) + for i := 0; i < n; i++ { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cancels = append(cancels, cancel) + + ch := g.DoChanContext(ctx, "key", fn) + chs = append(chs, ch) + } + + // Now that everything is waiting, cancel all the contexts. + for _, cancel := range cancels { + cancel() + } + + // Wait for a result from each channel. They + // should all return an error showing a context + // cancel. 
+ for _, ch := range chs { + select { + case res := <-ch: + if !errors.Is(res.Err, context.Canceled) { + t.Errorf("unexpected error; got %v, want context.Canceled", res.Err) + } + case <-testCtx.Done(): + t.Fatal("test timed out") + } + } + }) + } + }) +} + +func assertOKResult[V comparable](t testing.TB, res Result[V], want V) { + if res.Err != nil { + t.Fatalf("unexpected error: %v", res.Err) + } + if res.Val != want { + t.Fatalf("unexpected value; got %v, want %v", res.Val, want) + } +}