-
Notifications
You must be signed in to change notification settings - Fork 568
/
util.go
143 lines (133 loc) · 4.26 KB
/
util.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package task
import (
"context"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/anypb"
taskapi "github.com/pachyderm/pachyderm/v2/src/task"
"github.com/pachyderm/pachyderm/v2/src/internal/errors"
"github.com/pachyderm/pachyderm/v2/src/internal/log"
"github.com/pachyderm/pachyderm/v2/src/internal/taskchain"
)
// DoOrdered processes tasks in parallel, but returns outputs in order via the provided callback cb.
//
// Inputs are read from the inputs channel until it is closed or ctx is done.
// Each input is executed with DoOne inside a task chain that is created with a
// weighted semaphore sized by parallelism. For every successful task, cb is
// invoked with index -1, the task's output and a nil error; a failed task
// returns its error to the task chain instead of reaching cb.
//
// It returns the result of waiting on the task chain once inputs is closed,
// the error from creating a task, or the cancellation cause of ctx.
func DoOrdered(ctx context.Context, doer Doer, inputs chan *anypb.Any, parallelism int, cb CollectFunc) error {
	taskChain := taskchain.New(ctx, semaphore.NewWeighted(int64(parallelism)))
	for {
		select {
		case input, ok := <-inputs:
			if !ok {
				// No more inputs: block until every queued task has finished.
				return taskChain.Wait()
			}
			// Run the task with the context handed over by the task chain rather
			// than the outer ctx, so the task follows the chain's own lifetime.
			// NOTE(review): presumably the chain cancels this context once any
			// task fails — confirm against the taskchain package.
			if err := taskChain.CreateTask(func(taskCtx context.Context) (func() error, error) {
				result, err := DoOne(taskCtx, doer, input)
				if err != nil {
					return nil, errors.EnsureStack(err)
				}
				// The returned function delivers the output to the caller's callback.
				return func() error {
					return cb(-1, result, nil)
				}, nil
			}); err != nil {
				return errors.EnsureStack(err)
			}
		case <-ctx.Done():
			return errors.EnsureStack(context.Cause(ctx))
		}
	}
}
// DoOne runs a single task by submitting it as a one-element batch to DoBatch
// and returns the output that was handed to the collector.
// NOTE: This interface is much less performant than the stream / batch interfaces for many tasks.
// Only use this interface for development / a small number of tasks.
func DoOne(ctx context.Context, doer Doer, input *anypb.Any) (*anypb.Any, error) {
	var out *anypb.Any
	// collect remembers the output of a successful task and passes any task
	// error straight back to the batch.
	collect := func(_ int64, output *anypb.Any, err error) error {
		if err == nil {
			out = output
		}
		return err
	}
	batch := []*anypb.Any{input}
	if err := DoBatch(ctx, doer, batch, collect); err != nil {
		return nil, err
	}
	return out, nil
}
// DoBatch executes a batch of tasks.
//
// The inputs are sent one at a time over an unbuffered channel that is
// consumed by doer.Do, and cb is handed to doer.Do to collect the outputs. The
// channel is closed once every input has been sent.
//
// It returns the first error produced by doer.Do or by the sender (the
// cancellation cause of the context), wrapped with a stack, or nil when both
// finish successfully.
func DoBatch(ctx context.Context, doer Doer, inputs []*anypb.Any, cb CollectFunc) error {
	// Both goroutines share a context derived from the group so that an error
	// returned by doer.Do also unblocks the sender. With a plain errgroup.Group
	// the sender would stay blocked on the unbuffered channel, and Wait would
	// never return, until the caller's own context was cancelled.
	eg, egCtx := errgroup.WithContext(ctx)
	inputChan := make(chan *anypb.Any)
	eg.Go(func() error {
		return errors.EnsureStack(doer.Do(egCtx, inputChan, cb))
	})
	eg.Go(func() error {
		for _, input := range inputs {
			select {
			case inputChan <- input:
			case <-egCtx.Done():
				return errors.EnsureStack(context.Cause(egCtx))
			}
		}
		// Only signal end-of-input after every input was delivered.
		close(inputChan)
		return nil
	})
	return errors.EnsureStack(eg.Wait())
}
// translateTaskState maps this package's State to the corresponding
// taskapi.State. Any state other than RUNNING, SUCCESS or FAILURE maps to
// taskapi.State_UNKNOWN.
func translateTaskState(state State) taskapi.State {
	translated := taskapi.State_UNKNOWN
	switch state {
	case State_RUNNING:
		translated = taskapi.State_RUNNING
	case State_SUCCESS:
		translated = taskapi.State_SUCCESS
	case State_FAILURE:
		translated = taskapi.State_FAILURE
	}
	return translated
}
// List implements the functionality for an arbitrary service's ListTask gRPC.
//
// It lists the tasks of svc for the namespace and group named in req and calls
// send once per task with a taskapi.TaskInfo describing it. A claimed task is
// reported with state CLAIMED; otherwise the task's own state is translated.
// The task input is rendered as JSON when it can be unmarshalled; when it
// cannot, the failure is logged and InputData is left empty.
//
// It returns the stack-wrapped error from svc.List or from send.
func List(ctx context.Context, svc Service, req *taskapi.ListTaskRequest, send func(info *taskapi.TaskInfo) error) error {
	var marshaler protojson.MarshalOptions
	// Use the generated getters, which tolerate nil receivers, so a request
	// without a group lists with an empty namespace and group instead of
	// panicking.
	reqGroup := req.GetGroup()
	return errors.EnsureStack(svc.List(ctx, reqGroup.GetNamespace(), reqGroup.GetGroup(), func(namespace, group string, data *Task, claimed bool) error {
		state := translateTaskState(data.State)
		if claimed {
			state = taskapi.State_CLAIMED
		}
		// A task without an input makes UnmarshalNew fail, so the type URL must
		// be read through the nil-safe getter or the error path itself panics.
		inputType := data.GetInput().GetTypeUrl()
		var inputJSON []byte
		input, err := data.Input.UnmarshalNew()
		if err != nil {
			// unmarshalling might fail due to the input type not being registered,
			// don't let this interfere with fetching or counting tasks
			log.Error(ctx, "couldn't unmarshal task input", zap.Error(err), zap.String("taskType", inputType), zap.String("taskID", data.GetId()))
		} else {
			inputJSON, err = marshaler.Marshal(input)
			if err != nil {
				log.Error(ctx, "couldn't marshal task input", zap.Error(err), zap.String("taskType", inputType), zap.String("taskID", data.GetId()))
			}
		}
		info := &taskapi.TaskInfo{
			Id: data.Id,
			Group: &taskapi.Group{
				Namespace: namespace,
				Group:     group,
			},
			State:     state,
			Reason:    data.Reason,
			InputType: inputType,
			InputData: string(inputJSON),
		}
		return errors.EnsureStack(send(info))
	}))
}
// Count returns the number of tasks and claims in the given namespace and group (if nonempty).
// Every entry reported by service.List counts as a task, and entries flagged
// as claimed additionally count as claims. When listing fails, both counts are
// zero and the stack-wrapped error is returned.
func Count(ctx context.Context, service Service, namespace, group string) (tasks int64, claims int64, retErr error) {
	var total, claimedTotal int64
	// tally is invoked once per listed task and never fails.
	tally := func(_, _ string, _ *Task, claimed bool) error {
		total++
		if claimed {
			claimedTotal++
		}
		return nil
	}
	if err := errors.EnsureStack(service.List(ctx, namespace, group, tally)); err != nil {
		return 0, 0, err
	}
	return total, claimedTotal, nil
}