Skip to content

Commit

Permalink
Ensure WaitGroup.Done() is always called
Browse files Browse the repository at this point in the history
  • Loading branch information
bsdelf authored and traefiker committed Aug 26, 2019
1 parent 6fed76a commit a8c73f7
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 4 deletions.
19 changes: 15 additions & 4 deletions pkg/safe/routine.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,23 @@ func (p *Pool) GoCtx(goroutine routineCtx) {
p.routinesCtx = append(p.routinesCtx, goroutine)
p.waitGroup.Add(1)
Go(func() {
defer p.waitGroup.Done()
goroutine(p.ctx)
p.waitGroup.Done()
})
p.lock.Unlock()
}

// addGo adds a recoverable goroutine, and can be stopped with stop chan
func (p *Pool) addGo(goroutine func(stop chan bool)) {
p.lock.Lock()
newRoutine := routine{
goroutine: goroutine,
stop: make(chan bool, 1),
}
p.routines = append(p.routines, newRoutine)
p.lock.Unlock()
}

// Go starts a recoverable goroutine, and can be stopped with stop chan
func (p *Pool) Go(goroutine func(stop chan bool)) {
p.lock.Lock()
Expand All @@ -75,8 +86,8 @@ func (p *Pool) Go(goroutine func(stop chan bool)) {
p.routines = append(p.routines, newRoutine)
p.waitGroup.Add(1)
Go(func() {
defer p.waitGroup.Done()
goroutine(newRoutine.stop)
p.waitGroup.Done()
})
p.lock.Unlock()
}
Expand Down Expand Up @@ -112,16 +123,16 @@ func (p *Pool) Start() {
p.waitGroup.Add(1)
p.routines[i].stop = make(chan bool, 1)
Go(func() {
defer p.waitGroup.Done()
p.routines[i].goroutine(p.routines[i].stop)
p.waitGroup.Done()
})
}

for _, routine := range p.routinesCtx {
p.waitGroup.Add(1)
Go(func() {
defer p.waitGroup.Done()
routine(p.ctx)
p.waitGroup.Done()
})
}
}
Expand Down
67 changes: 67 additions & 0 deletions pkg/safe/routine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,73 @@ func TestPoolStartWithStopChan(t *testing.T) {
}
}

func TestPoolCleanupWithGoPanicking(t *testing.T) {
testRoutine := func(stop chan bool) {
panic("BOOM")
}

testCtxRoutine := func(ctx context.Context) {
panic("BOOM")
}

testCases := []struct {
desc string
fn func(*Pool)
}{
{
desc: "Go()",
fn: func(p *Pool) {
p.Go(testRoutine)
},
},
{
desc: "addGo() and Start()",
fn: func(p *Pool) {
p.addGo(testRoutine)
p.Start()
},
},
{
desc: "GoCtx()",
fn: func(p *Pool) {
p.GoCtx(testCtxRoutine)
},
},
{
desc: "AddGoCtx() and Start()",
fn: func(p *Pool) {
p.AddGoCtx(testCtxRoutine)
p.Start()
},
},
}

for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
p := NewPool(context.Background())

timer := time.NewTimer(500 * time.Millisecond)
defer timer.Stop()

test.fn(p)

testDone := make(chan bool, 1)
go func() {
p.Cleanup()
testDone <- true
}()

select {
case <-timer.C:
t.Fatalf("Pool.Cleanup() did not complete in time with a panicking goroutine")
case <-testDone:
return
}
})
}
}

func TestGoroutineRecover(t *testing.T) {
// if recover fails the test will panic
Go(func() {
Expand Down

0 comments on commit a8c73f7

Please sign in to comment.