/
contextwaitgroup.go
85 lines (79 loc) · 1.25 KB
/
contextwaitgroup.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
package util
import "context"
// ContextWaitGroup is a counter-based wait group in the spirit of
// sync.WaitGroup, except that waiters are also released when the context it
// was created with is cancelled.
//
// The counter and the list of pending waiters are owned by a single
// background goroutine started by NewContextWaitGroup; the methods talk to it
// over unbuffered channels. The zero value is not usable: always construct
// instances with NewContextWaitGroup.
type ContextWaitGroup struct {
	// count is the current counter value. Only loop reads or writes it.
	count int

	// ctx bounds the lifetime of the group: once it is done, loop exits,
	// every pending waiter is released and Add/Done become no-ops.
	ctx context.Context

	// nCh carries counter deltas from Add and Done to loop.
	nCh chan int

	// wCh carries waiter channels from WaitCh to loop.
	wCh chan chan struct{}

	// notify holds waiter channels to close when the counter next reaches
	// zero or ctx is done. Only loop reads or writes it.
	notify []chan struct{}
}

// NewContextWaitGroup returns a ContextWaitGroup bound to ctx and starts the
// goroutine that manages it. That goroutine only exits once ctx is done, so
// callers should eventually cancel ctx to avoid leaking it.
func NewContextWaitGroup(ctx context.Context) *ContextWaitGroup {
	wg := &ContextWaitGroup{
		ctx: ctx,
		nCh: make(chan int),
		wCh: make(chan chan struct{}),
	}
	go wg.loop()
	return wg
}

// loop is the single owner of count and notify. It applies counter deltas,
// parks waiter channels while the counter is non-zero, and closes all parked
// waiters whenever the counter returns to zero. When ctx is done it closes
// every remaining waiter and returns.
//
// A negative counter panics here, in the background goroutine, so the panic
// cannot be recovered by the goroutine that called Add or Done.
func (c *ContextWaitGroup) loop() {
	release := func() {
		for _, ch := range c.notify {
			close(ch)
		}
		c.notify = nil
	}
	// Release anyone still waiting when the context ends. No drain of nCh/wCh
	// is needed afterwards: every sender also selects on ctx.Done(), so none
	// of them can stay blocked once this goroutine is gone.
	defer release()

	for {
		select {
		case <-c.ctx.Done():
			return
		case ch := <-c.wCh:
			if c.count == 0 {
				close(ch)
			} else {
				c.notify = append(c.notify, ch)
			}
		case n := <-c.nCh:
			c.count += n
			switch {
			case c.count < 0:
				panic("Done() called too many times")
			case c.count == 0:
				release()
			}
		}
	}
}

// Add adds n, which may be negative, to the counter. When the counter drops
// to zero, all current waiters are released. Once ctx is done, Add is a
// no-op and never blocks.
func (c *ContextWaitGroup) Add(n int) {
	if c.ctx.Err() != nil {
		return
	}
	// ctx may be cancelled after the check above and loop may already have
	// exited; watching ctx.Done() keeps the send from blocking forever.
	select {
	case c.nCh <- n:
	case <-c.ctx.Done():
	}
}

// WaitCh returns a channel that is closed once the counter reaches zero or
// ctx is done, whichever comes first. If the counter is zero when the
// request is received, the channel is closed right away.
func (c *ContextWaitGroup) WaitCh() <-chan struct{} {
	ch := make(chan struct{})
	select {
	case c.wCh <- ch:
	case <-c.ctx.Done():
		// loop never received ch (and may have exited), so nobody else
		// will close it.
		close(ch)
	}
	return ch
}

// Done decrements the counter by one. It is a no-op once ctx is done.
func (c *ContextWaitGroup) Done() {
	c.Add(-1)
}

// Wait blocks until the counter reaches zero or ctx is done.
func (c *ContextWaitGroup) Wait() {
	<-c.WaitCh()
}