/
each.go
67 lines (54 loc) · 1.17 KB
/
each.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
package itertools
import (
"errors"
"reflect"
"sync"
)
// Sentinel errors reported by validateEachFunction when the callback handed
// to Each or CEach is unusable.
var (
// NotFuncError reports that the callback's reflected kind is not reflect.Func.
NotFuncError = errors.New("not a func type")
// InvalidArgError reports that the callback declares more than two parameters.
InvalidArgError = errors.New("invalid argument count")
)
// CEach invokes f for every Pair received from Iterate(iter), running each
// invocation in its own goroutine, and returns only after all of them have
// finished.
//
// f is checked with validateEachFunction first; if that check fails CEach
// returns without iterating and without reporting the error. Arguments are
// bound to f by runEach. Because the calls run concurrently, their relative
// order is unspecified and f must be safe to call from multiple goroutines.
func CEach(iter, f interface{}) {
	if err := validateEachFunction(f); err != nil {
		return
	}

	var wg sync.WaitGroup
	for p := range Iterate(iter) {
		wg.Add(1)
		// The pair is handed to the goroutine as an argument and the body
		// uses that argument (pp), never the loop variable. Before Go 1.22
		// the loop variable is shared across iterations, so reading it from
		// the goroutine would race with the loop and could observe a later
		// pair.
		go func(pp Pair) {
			defer wg.Done()
			runEach(f, pp)
		}(p)
	}
	wg.Wait()
}
// Each invokes f once for every Pair received from Iterate(iter), one call
// at a time in iteration order.
//
// The callback is vetted by validateEachFunction before anything else; when
// it is rejected, Each returns immediately without iterating and the
// validation error is discarded. Argument binding is delegated to runEach.
func Each(iter, f interface{}) {
	if validateEachFunction(f) != nil {
		return
	}
	for pair := range Iterate(iter) {
		runEach(f, pair)
	}
}
// runEach calls f via reflection with arguments taken from p, chosen by the
// number of parameters f declares:
//
//   - 0 parameters: f is called with no arguments.
//   - 1 parameter:  f receives p.Second.
//   - otherwise:    f receives p.First and p.Second, in that order.
//
// Each argument is converted to the corresponding parameter type with
// reflect.Value.Convert, which panics when the value cannot be converted.
// f is expected to have passed validateEachFunction already; runEach does
// not re-check that it is a function. Any results returned by f are dropped.
func runEach(f interface{}, p Pair) {
	fn := reflect.ValueOf(f)
	fnType := fn.Type()

	var args []reflect.Value
	switch fnType.NumIn() {
	case 0:
		// validateEachFunction accepts zero-parameter callbacks, so they must
		// be callable here; asking for In(0) on such a type would panic.
	case 1:
		args = []reflect.Value{
			reflect.ValueOf(p.Second).Convert(fnType.In(0)),
		}
	default:
		args = []reflect.Value{
			reflect.ValueOf(p.First).Convert(fnType.In(0)),
			reflect.ValueOf(p.Second).Convert(fnType.In(1)),
		}
	}
	fn.Call(args)
}
// validateEachFunction reports whether f is usable as an Each/CEach callback.
//
// It returns NotFuncError when f is nil or is not a function, InvalidArgError
// when f declares more than two parameters, and nil otherwise.
func validateEachFunction(f interface{}) error {
	function := reflect.TypeOf(f)
	// reflect.TypeOf returns a nil Type for a nil interface value, and NumIn
	// panics on any non-func type, so both cases must be rejected before the
	// parameter count is inspected.
	if function == nil || function.Kind() != reflect.Func {
		return NotFuncError
	}
	if function.NumIn() > 2 {
		return InvalidArgError
	}
	return nil
}