Skip to content

Commit

Permalink
Change task.Group.Wait to accept context instead of deadline duration
Browse files Browse the repository at this point in the history
  • Loading branch information
perkon committed Oct 22, 2019
1 parent faf1c55 commit ac27d4e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 29 deletions.
28 changes: 6 additions & 22 deletions task/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,8 @@
package task

import (
"errors"
"context"
"sync"
"time"
)

var (
// ErrDeadline is returned when the group wait deadline is exceeded.
ErrDeadline = errors.New("deadline exceeded")
)

// Group is used to wait for a group of tasks to finish.
Expand Down Expand Up @@ -73,21 +67,15 @@ func (g *Group) Go(tasks ...TaskFunc) {

// Wait until all tasks are stopped.
// Returns the first encountered error if any.
// If deadline is exceeded all tasks are canceled and the returned error is the deadline error.
// IsDeadlineError func can be used to check if the tasks were canceled due a deadline.
// Passing 0 for deadline means there will be no deadline, and Wait is blocked until all of the
// tasks are finished.
func (g *Group) Wait(deadline time.Duration) error {
if deadline > 0 {
timer := time.NewTimer(deadline)
defer timer.Stop()

// If the context is done all tasks are canceled and the context error is returned.
func (g *Group) Wait(ctx context.Context) error {
if ctx != context.TODO() {
go func() {
select {
case <-g.cancelCh:
return
case <-timer.C:
g.cancelWithError(ErrDeadline)
case <-ctx.Done():
g.cancelWithError(ctx.Err())
}
}()
}
Expand Down Expand Up @@ -124,7 +112,3 @@ func (g *Group) Cancel() {

g.unsafeCancel()
}

func IsDeadlineError(err error) bool {
return err == ErrDeadline
}
17 changes: 10 additions & 7 deletions task/group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
package task_test

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/sumup-oss/go-pkgs/task"
)

Expand Down Expand Up @@ -64,7 +65,7 @@ func TestGroup_Go(t *testing.T) {
foo.RunUntil <- nil
bar.RunUntil <- nil

err := group.Wait(time.Hour)
err := group.Wait(context.Background())
assert.NoError(t, err)

assert.Equal(t, 1, foo.RunCount)
Expand All @@ -89,7 +90,7 @@ func TestGroup_Go(t *testing.T) {
foo.RunUntil <- assert.AnError
}()

err := group.Wait(time.Hour)
err := group.Wait(context.Background())
assert.EqualError(t, err, assert.AnError.Error())

assert.Equal(t, 1, foo.RunCount)
Expand All @@ -111,8 +112,10 @@ func TestGroup_Go(t *testing.T) {
<-foo.RunReady
<-bar.RunReady

err := group.Wait(1)
assert.True(t, task.IsDeadlineError(err))
ctx, cancel := context.WithTimeout(context.Background(), 1)
defer cancel()
err := group.Wait(ctx)
assert.Equal(t, context.DeadlineExceeded, err)

assert.Equal(t, 1, foo.RunCount)
assert.Equal(t, 1, bar.RunCount)
Expand All @@ -130,7 +133,7 @@ func TestGroup_Go(t *testing.T) {
group.Cancel()
group.Go(foo.Run, bar.Run)

err := group.Wait(time.Hour)
err := group.Wait(context.Background())
assert.NoError(t, err)

assert.Equal(t, 0, foo.RunCount)
Expand All @@ -157,7 +160,7 @@ func TestGroup_Cancel(t *testing.T) {
group.Cancel()
}()

err := group.Wait(0)
err := group.Wait(context.Background())
assert.NoError(t, err)

assert.Equal(t, 1, foo.RunCount)
Expand Down

0 comments on commit ac27d4e

Please sign in to comment.