worker.go
package machinery
import (
"errors"
"fmt"
"net/url"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/opentracing/opentracing-go"
"github.com/wrhb123/machinery/backends/amqp"
"github.com/wrhb123/machinery/brokers/errs"
"github.com/wrhb123/machinery/log"
"github.com/wrhb123/machinery/retry"
"github.com/wrhb123/machinery/tasks"
"github.com/wrhb123/machinery/tracing"
)
// Worker represents a single worker process
type Worker struct {
server *Server
ConsumerTag string
Concurrency int
Queue string
errorHandler func(err error)
preTaskHandler func(*tasks.Signature)
postTaskHandler func(*tasks.Signature)
preConsumeHandler func(*Worker) bool
}
var (
// ErrWorkerQuitGracefully is returned when the worker quits gracefully
ErrWorkerQuitGracefully = errors.New("Worker quit gracefully")
// ErrWorkerQuitAbruptly is returned when the worker quits abruptly
ErrWorkerQuitAbruptly = errors.New("Worker quit abruptly")
)
// Launch starts a new worker process. The worker subscribes
// to the default queue and processes incoming registered tasks
func (worker *Worker) Launch() error {
errorsChan := make(chan error)
worker.LaunchAsync(errorsChan)
return <-errorsChan
}
// LaunchAsync is a non-blocking version of Launch
func (worker *Worker) LaunchAsync(errorsChan chan<- error) {
cnf := worker.server.GetConfig()
broker := worker.server.GetBroker()
// Log some useful information about worker configuration
log.INFO.Printf("Launching a worker with the following settings:")
log.INFO.Printf("- Broker: %s", RedactURL(cnf.Broker))
if worker.Queue == "" {
log.INFO.Printf("- DefaultQueue: %s", cnf.DefaultQueue)
} else {
log.INFO.Printf("- CustomQueue: %s", worker.Queue)
}
log.INFO.Printf("- ResultBackend: %s", RedactURL(cnf.ResultBackend))
if cnf.AMQP != nil {
log.INFO.Printf("- AMQP: %s", cnf.AMQP.Exchange)
log.INFO.Printf(" - Exchange: %s", cnf.AMQP.Exchange)
log.INFO.Printf(" - ExchangeType: %s", cnf.AMQP.ExchangeType)
log.INFO.Printf(" - BindingKey: %s", cnf.AMQP.BindingKey)
log.INFO.Printf(" - PrefetchCount: %d", cnf.AMQP.PrefetchCount)
}
var signalWG sync.WaitGroup
// Goroutine to start broker consumption and handle retries when broker connection dies
go func() {
for {
retry, err := broker.StartConsuming(worker.ConsumerTag, worker.Concurrency, worker)
if retry {
if worker.errorHandler != nil {
worker.errorHandler(err)
} else {
log.WARNING.Printf("Broker failed with error: %s", err)
}
} else {
signalWG.Wait()
errorsChan <- err // stop the goroutine
return
}
}
}()
if !cnf.NoUnixSignals {
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
var signalsReceived uint
// Goroutine to handle SIGINT and SIGTERM signals
go func() {
for s := range sig {
log.WARNING.Printf("Signal received: %v", s)
signalsReceived++
if signalsReceived < 2 {
// After the first Ctrl+C, start quitting the worker gracefully
log.WARNING.Print("Waiting for running tasks to finish before shutting down")
signalWG.Add(1)
go func() {
worker.Quit()
errorsChan <- ErrWorkerQuitGracefully
signalWG.Done()
}()
} else {
// Abort the program when the user hits Ctrl+C a second time in a row
errorsChan <- ErrWorkerQuitAbruptly
}
}
}()
}
}
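// exampleLaunchWorker is an illustrative sketch, not part of the original
// file. It shows how a worker is typically created from a configured *Server
// and launched; the consumer tag and concurrency value are arbitrary, and
// NewWorker is assumed to be provided by Server elsewhere in this package.
func exampleLaunchWorker(server *Server) error {
	worker := server.NewWorker("machinery_worker", 10)
	// Non-blocking variant: LaunchAsync reports the terminal error on a channel.
	// For the blocking variant, simply return worker.Launch() instead.
	errorsChan := make(chan error, 1)
	worker.LaunchAsync(errorsChan)
	return <-errorsChan
}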
// CustomQueue returns the custom queue of the running worker process
func (worker *Worker) CustomQueue() string {
return worker.Queue
}
// Quit tears down the running worker process
func (worker *Worker) Quit() {
worker.server.GetBroker().StopConsuming()
}
// Process handles received tasks and triggers success/error callbacks
func (worker *Worker) Process(signature *tasks.Signature) error {
// If the task is not registered with this worker, do not continue
// but only return nil as we do not want to restart the worker process
if !worker.server.IsTaskRegistered(signature.Name) {
return nil
}
taskFunc, err := worker.server.GetRegisteredTask(signature.Name)
if err != nil {
return nil
}
// Update task state to RECEIVED
if err = worker.server.GetBackend().SetStateReceived(signature); err != nil {
return fmt.Errorf("Set state to 'received' for task %s returned error: %s", signature.UUID, err)
}
// Prepare task for processing
task, err := tasks.NewWithSignature(taskFunc, signature)
// If this failed, it means the task is malformed (probably has an invalid
// signature), so go directly to taskFailed without checking whether to retry
if err != nil {
worker.taskFailed(signature, err)
return err
}
// try to extract trace span from headers and add it to the function context
// so it can be used inside the function if it has context.Context as the first
// argument. Start a new span if it isn't found.
taskSpan := tracing.StartSpanFromHeaders(signature.Headers, signature.Name)
tracing.AnnotateSpanWithSignatureInfo(taskSpan, signature)
task.Context = opentracing.ContextWithSpan(task.Context, taskSpan)
// Update task state to STARTED
if err = worker.server.GetBackend().SetStateStarted(signature); err != nil {
return fmt.Errorf("Set state to 'started' for task %s returned error: %s", signature.UUID, err)
}
// Run handler before the task is called
if worker.preTaskHandler != nil {
worker.preTaskHandler(signature)
}
// Defer the post-task handler so it runs after the task completes
if worker.postTaskHandler != nil {
defer worker.postTaskHandler(signature)
}
// Call the task
results, err := task.Call()
if err != nil {
// If a tasks.ErrRetryTaskLater was returned from the task,
// retry the task after specified duration
retriableErr, ok := interface{}(err).(tasks.ErrRetryTaskLater)
if ok {
return worker.retryTaskIn(signature, retriableErr.RetryIn())
}
// Otherwise, execute default retry logic based on signature.RetryCount
// and signature.RetryTimeout values
if signature.RetryCount > 0 {
return worker.taskRetry(signature)
}
return worker.taskFailed(signature, err)
}
return worker.taskSucceeded(signature, results)
}
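// exampleRegisterTask is an illustrative sketch, not part of the original
// file. Process only dispatches tasks that were registered with the server,
// so a task named "add" would typically be registered like this before the
// worker is launched (RegisterTask is assumed to be provided by Server
// elsewhere in this package; the task name and function are arbitrary).
func exampleRegisterTask(server *Server) error {
	return server.RegisterTask("add", func(a, b int64) (int64, error) {
		return a + b, nil
	})
}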
// taskRetry decrements the RetryCount counter and republishes the task to the queue
func (worker *Worker) taskRetry(signature *tasks.Signature) error {
// Update task state to RETRY
if err := worker.server.GetBackend().SetStateRetry(signature); err != nil {
return fmt.Errorf("Set state to 'retry' for task %s returned error: %s", signature.UUID, err)
}
// Decrement the retry counter; when it reaches 0, the task won't be retried again
signature.RetryCount--
// Increase retry timeout
signature.RetryTimeout = retry.FibonacciNext(signature.RetryTimeout)
// Delay task by signature.RetryTimeout seconds
eta := time.Now().UTC().Add(time.Second * time.Duration(signature.RetryTimeout))
signature.ETA = &eta
log.WARNING.Printf("Task %s failed. Going to retry in %d seconds.", signature.UUID, signature.RetryTimeout)
// Send the task back to the queue
_, err := worker.server.SendTask(signature)
return err
}
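// exampleRetryCount is an illustrative sketch, not part of the original file.
// Setting RetryCount on a signature enables the default retry path above:
// each failure decrements RetryCount and grows RetryTimeout along a Fibonacci
// sequence before the task is republished. The task name and arguments are
// arbitrary.
func exampleRetryCount(server *Server) error {
	sig := &tasks.Signature{
		Name:       "add",
		RetryCount: 3, // retry up to 3 times with increasing delays
		Args: []tasks.Arg{
			{Type: "int64", Value: 1},
			{Type: "int64", Value: 2},
		},
	}
	_, err := server.SendTask(sig)
	return err
}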
// retryTaskIn republishes the task to the queue with an ETA of now + retryIn
func (worker *Worker) retryTaskIn(signature *tasks.Signature, retryIn time.Duration) error {
// Update task state to RETRY
if err := worker.server.GetBackend().SetStateRetry(signature); err != nil {
return fmt.Errorf("Set state to 'retry' for task %s returned error: %s", signature.UUID, err)
}
// Delay task by retryIn duration
eta := time.Now().UTC().Add(retryIn)
signature.ETA = &eta
log.WARNING.Printf("Task %s failed. Going to retry in %.0f seconds.", signature.UUID, retryIn.Seconds())
// Send the task back to the queue
_, err := worker.server.SendTask(signature)
return err
}
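// exampleRetryLater is an illustrative sketch, not part of the original file.
// A registered task function can request the retryTaskIn path above by
// returning a tasks.ErrRetryTaskLater with an explicit delay instead of a
// plain error; NewErrRetryTaskLater is assumed to be provided by the tasks
// package, and the task body here is arbitrary.
func exampleRetryLater() (int64, error) {
	ready := false // placeholder for a real readiness check
	if !ready {
		return 0, tasks.NewErrRetryTaskLater("dependency not ready", 30*time.Second)
	}
	return 42, nil
}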
// taskSucceeded updates the task state and triggers success callbacks or a
// chord callback if this was the last task of a group with a chord callback
func (worker *Worker) taskSucceeded(signature *tasks.Signature, taskResults []*tasks.TaskResult) error {
// Update task state to SUCCESS
if err := worker.server.GetBackend().SetStateSuccess(signature, taskResults); err != nil {
return fmt.Errorf("Set state to 'success' for task %s returned error: %s", signature.UUID, err)
}
// Log human-readable results of the processed task
var debugResults = "[]"
results, err := tasks.ReflectTaskResults(taskResults)
if err != nil {
log.WARNING.Print(err)
} else {
debugResults = tasks.HumanReadableResults(results)
}
log.DEBUG.Printf("Processed task %s. Results = %s", signature.UUID, debugResults)
// Trigger success callbacks
for _, successTask := range signature.OnSuccess {
if signature.Immutable == false {
// Pass results of the task to success callbacks
for _, taskResult := range taskResults {
successTask.Args = append(successTask.Args, tasks.Arg{
Type: taskResult.Type,
Value: taskResult.Value,
})
}
}
worker.server.SendTask(successTask)
}
// If the task was not part of a group, just return
if signature.GroupUUID == "" {
return nil
}
// There is no chord callback, just return
if signature.ChordCallback == nil {
return nil
}
// Check if all tasks in the group have completed
groupCompleted, err := worker.server.GetBackend().GroupCompleted(
signature.GroupUUID,
signature.GroupTaskCount,
)
if err != nil {
return fmt.Errorf("Completed check for group %s returned error: %s", signature.GroupUUID, err)
}
// If the group has not yet completed, just return
if !groupCompleted {
return nil
}
// Defer purging of group meta queue if we are using AMQP backend
if worker.hasAMQPBackend() {
defer worker.server.GetBackend().PurgeGroupMeta(signature.GroupUUID)
}
// Trigger chord callback
shouldTrigger, err := worker.server.GetBackend().TriggerChord(signature.GroupUUID)
if err != nil {
return fmt.Errorf("Triggering chord for group %s returned error: %s", signature.GroupUUID, err)
}
// Chord has already been triggered
if !shouldTrigger {
return nil
}
// Get task states
taskStates, err := worker.server.GetBackend().GroupTaskStates(
signature.GroupUUID,
signature.GroupTaskCount,
)
if err != nil {
log.ERROR.Printf(
"Failed to get tasks states for group:[%s]. Task count:[%d]. The chord may not be triggered. Error:[%s]",
signature.GroupUUID,
signature.GroupTaskCount,
err,
)
return nil
}
// Append group tasks' return values to chord task if it's not immutable
for _, taskState := range taskStates {
if !taskState.IsSuccess() {
return nil
}
if signature.ChordCallback.Immutable == false {
// Pass results of the task to the chord callback
for _, taskResult := range taskState.Results {
signature.ChordCallback.Args = append(signature.ChordCallback.Args, tasks.Arg{
Type: taskResult.Type,
Value: taskResult.Value,
})
}
}
}
// Send the chord task
_, err = worker.server.SendTask(signature.ChordCallback)
if err != nil {
return err
}
return nil
}
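// exampleChord is an illustrative sketch, not part of the original file. The
// chord handling above only fires once every task in a group has succeeded;
// such a group with a chord callback is typically built as shown here
// (tasks.NewGroup, tasks.NewChord and server.SendChord are assumed APIs of
// the surrounding library).
func exampleChord(server *Server, sig1, sig2, callback *tasks.Signature) error {
	group, err := tasks.NewGroup(sig1, sig2)
	if err != nil {
		return err
	}
	chord, err := tasks.NewChord(group, callback)
	if err != nil {
		return err
	}
	// The callback receives the group tasks' results unless it is Immutable.
	_, err = server.SendChord(chord, 10)
	return err
}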
// taskFailed updates the task state and triggers error callbacks
func (worker *Worker) taskFailed(signature *tasks.Signature, taskErr error) error {
// Update task state to FAILURE
if err := worker.server.GetBackend().SetStateFailure(signature, taskErr.Error()); err != nil {
return fmt.Errorf("Set state to 'failure' for task %s returned error: %s", signature.UUID, err)
}
if worker.errorHandler != nil {
worker.errorHandler(taskErr)
} else {
log.ERROR.Printf("Failed processing task %s. Error = %v", signature.UUID, taskErr)
}
// Trigger error callbacks
for _, errorTask := range signature.OnError {
// Pass error as a first argument to error callbacks
args := append([]tasks.Arg{{
Type: "string",
Value: taskErr.Error(),
}}, errorTask.Args...)
errorTask.Args = args
worker.server.SendTask(errorTask)
}
if signature.StopTaskDeletionOnError {
return errs.ErrStopTaskDeletion
}
return nil
}
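// exampleOnError is an illustrative sketch, not part of the original file. It
// shows how an error callback is attached to a signature; taskFailed above
// prepends the error message as the first (string) argument of each OnError
// callback before publishing it. Names and arguments are arbitrary.
func exampleOnError(server *Server) error {
	sig := &tasks.Signature{
		Name: "add",
		Args: []tasks.Arg{{Type: "int64", Value: 1}, {Type: "int64", Value: 2}},
		OnError: []*tasks.Signature{
			{Name: "notify_failure"}, // receives the error string as its first argument
		},
	}
	_, err := server.SendTask(sig)
	return err
}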
// hasAMQPBackend returns true if the worker uses the AMQP backend
func (worker *Worker) hasAMQPBackend() bool {
_, ok := worker.server.GetBackend().(*amqp.Backend)
return ok
}
// SetErrorHandler sets a custom error handler for task errors.
// The default behavior is to log the error after all retry attempts have failed
func (worker *Worker) SetErrorHandler(handler func(err error)) {
worker.errorHandler = handler
}
// SetPreTaskHandler sets a custom handler func before a job is started
func (worker *Worker) SetPreTaskHandler(handler func(*tasks.Signature)) {
worker.preTaskHandler = handler
}
// SetPostTaskHandler sets a custom handler for the end of a job
func (worker *Worker) SetPostTaskHandler(handler func(*tasks.Signature)) {
worker.postTaskHandler = handler
}
// SetPreConsumeHandler sets a custom handler that is called before each task message is consumed
func (worker *Worker) SetPreConsumeHandler(handler func(*Worker) bool) {
worker.preConsumeHandler = handler
}
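// exampleHandlers is an illustrative sketch, not part of the original file,
// showing how the optional handlers above are typically wired up before the
// worker is launched. The handler bodies are arbitrary.
func exampleHandlers(worker *Worker) {
	worker.SetErrorHandler(func(err error) {
		log.ERROR.Printf("task error: %s", err)
	})
	worker.SetPreTaskHandler(func(sig *tasks.Signature) {
		log.INFO.Printf("starting task %s (%s)", sig.Name, sig.UUID)
	})
	worker.SetPostTaskHandler(func(sig *tasks.Signature) {
		log.INFO.Printf("finished task %s (%s)", sig.Name, sig.UUID)
	})
	// Brokers consult this before pulling the next message; returning false
	// is used to hold off consumption.
	worker.SetPreConsumeHandler(func(w *Worker) bool {
		return true
	})
}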
// GetServer returns the worker's server instance
func (worker *Worker) GetServer() *Server {
return worker.server
}
// PreConsumeHandler invokes the configured pre-consume handler, if any; it returns true when consumption should proceed
func (worker *Worker) PreConsumeHandler() bool {
if worker.preConsumeHandler == nil {
return true
}
return worker.preConsumeHandler(worker)
}
// RedactURL strips user credentials, path and query from a URL string, keeping only the scheme and host
func RedactURL(urlString string) string {
u, err := url.Parse(urlString)
if err != nil {
return urlString
}
return fmt.Sprintf("%s://%s", u.Scheme, u.Host)
}
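// exampleRedact is an illustrative sketch, not part of the original file,
// showing what RedactURL produces for a typical broker URL with credentials:
// everything except the scheme and host is dropped.
func exampleRedact() string {
	// Returns "amqp://localhost:5672" (credentials and vhost are removed).
	return RedactURL("amqp://guest:secret@localhost:5672/vhost")
}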