Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

util/singleflight: add DoChanContext #12003

Merged
merged 1 commit into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
94 changes: 94 additions & 0 deletions util/singleflight/singleflight.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ package singleflight // import "tailscale.com/util/singleflight"

import (
"bytes"
"context"
"errors"
"fmt"
"runtime"
"runtime/debug"
"sync"
"sync/atomic"
)

// errGoexit indicates the runtime.Goexit was called in
Expand Down Expand Up @@ -69,6 +71,11 @@ type call[V any] struct {
// not written after the WaitGroup is done.
dups int
chans []chan<- Result[V]

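// These fields are used only by the DoChanContext method and its waitCtx
// helper. cancel is written once, when the call is created; ctxWaiters is
// updated atomically as context-aware callers join and leave.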
cancel context.CancelFunc
ctxWaiters atomic.Int64
}

// Group represents a class of work and forms a namespace in
Expand Down Expand Up @@ -143,6 +150,93 @@ func (g *Group[K, V]) DoChan(key K, fn func() (V, error)) <-chan Result[V] {
return ch
}

// DoChanContext is like [Group.DoChan], but supports context cancelation. The
// context passed to the fn function is a context that is canceled only when
// there are no callers waiting on a result (i.e. all callers have canceled
// their contexts).
//
// The context that is passed to the fn function is not derived from any of the
// input contexts, so context values will not be propagated. If context values
// are needed, they must be propagated explicitly.
//
// The returned channel will not be closed. The Result.Err field is set to the
// context error if the context is canceled.
func (g *Group[K, V]) DoChanContext(ctx context.Context, key K, fn func(context.Context) (V, error)) <-chan Result[V] {
andrew-d marked this conversation as resolved.
Show resolved Hide resolved
ch := make(chan Result[V], 1)
g.mu.Lock()
if g.m == nil {
g.m = make(map[K]*call[V])
}
c, ok := g.m[key]
if ok {
// Call already in progress; add to the waiters list and then
// release the mutex.
c.dups++
c.ctxWaiters.Add(1)
c.chans = append(c.chans, ch)
g.mu.Unlock()
} else {
// The call hasn't been started yet; we need to start it.
//
// Create a context that is not canceled when the parent context is,
// but otherwise propagates all values.
callCtx, callCancel := context.WithCancel(context.Background())
c = &call[V]{
chans: []chan<- Result[V]{ch},
cancel: callCancel,
}
c.wg.Add(1)
c.ctxWaiters.Add(1) // one caller waiting
g.m[key] = c
g.mu.Unlock()

andrew-d marked this conversation as resolved.
Show resolved Hide resolved
// Wrap our function to provide the context.
go g.doCall(c, key, func() (V, error) {
return fn(callCtx)
})
}

// Instead of returning the channel directly, we need to track
// when the call finishes so we can handle context cancelation.
// Do so by creating an final channel that gets the
// result and hooking that up to the wait function.
final := make(chan Result[V], 1)
go g.waitCtx(ctx, c, ch, final)
return final
}

// waitCtx will wait on the provided call to finish, or the context to be done.
// If the context is done, and this is the last waiter, then the context
// provided to the underlying function will be canceled.
func (g *Group[K, V]) waitCtx(ctx context.Context, c *call[V], result <-chan Result[V], output chan<- Result[V]) {
var res Result[V]
select {
case <-ctx.Done():
case res = <-result:
}
Copy link
Member

@dsnet dsnet Jun 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the context was already canceled when a result becomes available, should one take precedence?

If we want cancelation to take precedence, we could do:

select {
case <-ctx.Done():
case res = <-result:
}

...

if err := ctx.Err(); err != nil {
	res = Result[V]{Err: err}
}
output <- res

Thus, context errors take precedence over results.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think preferring context cancelation makes sense; if the caller gets a value, it's likely that it'll just end up continuing until a later point and then checking cancelation and aborting, which is essentially wasted work.


// Decrement the caller count, and if we're the last one, cancel the
// context we created. Do this in all cases, error and otherwise, so we
// don't leak goroutines.
//
// Also wait on the call to finish, so we know that the call has
// finished executing after the last caller has returned.
if c.ctxWaiters.Add(-1) == 0 {
c.cancel()
c.wg.Wait()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have 10 concurrent waiters, I find it odd that 9 return immediately, but then the 10th must wait, but maybe that's the right semantic?

}

// Ensure that context cancelation takes precedence over a value being
// available by checking ctx.Err() before sending the result to the
// caller. The select above will nondeterministically pick a case if a
// result is available and the ctx.Done channel is closed, so we check
// again here.
if err := ctx.Err(); err != nil {
res = Result[V]{Err: err}
}
output <- res
}

// doCall handles the single call for a key.
func (g *Group[K, V]) doCall(c *call[V], key K, fn func() (V, error)) {
normalReturn := false
Expand Down
153 changes: 153 additions & 0 deletions util/singleflight/singleflight_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package singleflight

import (
"bytes"
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -321,3 +322,155 @@ func TestPanicDoSharedByDoChan(t *testing.T) {
t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do")
}
}

// TestDoChanContext exercises Group.DoChanContext: the basic success path,
// the fact that caller context values are not visible to the shared function,
// that canceling one of several waiters does not cancel the shared call, and
// that the shared call is canceled once every waiter has canceled.
func TestDoChanContext(t *testing.T) {
	t.Run("Basic", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		var g Group[string, int]
		ch := g.DoChanContext(ctx, "key", func(_ context.Context) (int, error) {
			return 1, nil
		})
		ret := <-ch
		assertOKResult(t, ret, 1)
	})

	t.Run("DoesNotPropagateValues", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		// A freshly allocated pointer is a context key that nothing else
		// can collide with.
		key := new(int)
		const value = "hello world"

		ctx = context.WithValue(ctx, key, value)

		var g Group[string, int]
		ch := g.DoChanContext(ctx, "foobar", func(ctx context.Context) (int, error) {
			if _, ok := ctx.Value(key).(string); ok {
				t.Error("expected no value, but was present in context")
			}
			return 1, nil
		})
		ret := <-ch
		assertOKResult(t, ret, 1)
	})

	t.Run("NoCancelWhenWaiters", func(t *testing.T) {
		testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer testCancel()

		trigger := make(chan struct{})

		ctx1, cancel1 := context.WithCancel(context.Background())
		defer cancel1()
		ctx2, cancel2 := context.WithCancel(context.Background())
		defer cancel2()

		// fn blocks until either its context is canceled or the test
		// releases it via trigger.
		fn := func(ctx context.Context) (int, error) {
			select {
			case <-ctx.Done():
				return 0, ctx.Err()
			case <-trigger:
				return 1234, nil
			}
		}

		// Create two waiters, then cancel the first before we trigger
		// the function to return a value. This shouldn't result in a
		// context canceled error.
		var g Group[string, int]
		ch1 := g.DoChanContext(ctx1, "key", fn)
		ch2 := g.DoChanContext(ctx2, "key", fn)

		cancel1()

		// The first channel, now that it's canceled, should return a
		// context canceled error.
		select {
		case res := <-ch1:
			if !errors.Is(res.Err, context.Canceled) {
				t.Errorf("unexpected error; got %v, want context.Canceled", res.Err)
			}
		case <-testCtx.Done():
			t.Fatal("test timed out")
		}

		// Actually return. Bound the receive with the test timeout, like
		// the receive above, so that a regression fails the test instead
		// of hanging it.
		close(trigger)
		select {
		case res := <-ch2:
			assertOKResult(t, res, 1234)
		case <-testCtx.Done():
			t.Fatal("test timed out")
		}
	})

	t.Run("AllCancel", func(t *testing.T) {
		for _, n := range []int{1, 2, 10, 20} {
			t.Run(fmt.Sprintf("NumWaiters=%d", n), func(t *testing.T) {
				testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second)
				defer testCancel()

				trigger := make(chan struct{})
				defer close(trigger)

				// fn should only ever return because its context
				// was canceled; trigger is closed at the end of
				// the subtest purely as a safety net.
				fn := func(ctx context.Context) (int, error) {
					select {
					case <-ctx.Done():
						return 0, ctx.Err()
					case <-trigger:
						t.Error("unexpected trigger; want all callers to cancel")
						return 0, errors.New("unexpected trigger")
					}
				}

				// Register N callers that all wait on the same key.
				var (
					g       Group[string, int]
					chs     []<-chan Result[int]
					cancels []context.CancelFunc
				)
				for i := range n {
					ctx, cancel := context.WithCancel(context.Background())
					defer cancel()
					cancels = append(cancels, cancel)

					ch := g.DoChanContext(ctx, "key", fn)
					chs = append(chs, ch)

					// Every third caller should cancel
					// immediately, which better tests the
					// cancel logic.
					if i%3 == 0 {
						cancel()
					}
				}

				// Now that everything is waiting, cancel all the contexts.
				for _, cancel := range cancels {
					cancel()
				}

				// Wait for a result from each channel. They
				// should all return an error showing a context
				// cancel.
				for _, ch := range chs {
					select {
					case res := <-ch:
						if !errors.Is(res.Err, context.Canceled) {
							t.Errorf("unexpected error; got %v, want context.Canceled", res.Err)
						}
					case <-testCtx.Done():
						t.Fatal("test timed out")
					}
				}
			})
		}
	})
}

// assertOKResult fails the test immediately unless res carries no error and
// its value equals want.
func assertOKResult[V comparable](t testing.TB, res Result[V], want V) {
	// Mark this function as a helper so that failures are reported at the
	// calling test's line rather than inside this function.
	t.Helper()
	if res.Err != nil {
		t.Fatalf("unexpected error: %v", res.Err)
	}
	if res.Val != want {
		t.Fatalf("unexpected value; got %v, want %v", res.Val, want)
	}
}