Skip to content

Commit

Permalink
adding CeleryTask interface in the flow
Browse files Browse the repository at this point in the history
  • Loading branch information
shicky committed Sep 11, 2016
1 parent 4745b0b commit 847801b
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 12 deletions.
6 changes: 3 additions & 3 deletions backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestGetResult(t *testing.T) {

// value must be float64 for testing due to json limitation
value := reflect.ValueOf(rand.Float64())
resultMessage := getResultMessage(&value)
resultMessage := getReflectionResultMessage(&value)
defer releaseResultMessage(resultMessage)
messageBytes, err := json.Marshal(resultMessage)
if err != nil {
Expand All @@ -51,7 +51,7 @@ func TestSetResult(t *testing.T) {
backend := NewRedisCeleryBackend("localhost:6379", "")
taskID := uuid.NewV4().String()
value := reflect.ValueOf(rand.Float64())
resultMessage := getResultMessage(&value)
resultMessage := getReflectionResultMessage(&value)
releaseResultMessage(resultMessage)
// set result
err := backend.SetResult(taskID, resultMessage)
Expand Down Expand Up @@ -82,7 +82,7 @@ func TestSetGetResult(t *testing.T) {
for _, backend := range getBackends() {
taskID := uuid.NewV4().String()
value := reflect.ValueOf(rand.Float64())
resultMessage := getResultMessage(&value)
resultMessage := getReflectionResultMessage(&value)
defer releaseResultMessage(resultMessage)
// set result
err := backend.SetResult(taskID, resultMessage)
Expand Down
8 changes: 8 additions & 0 deletions gocelery.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ func (cc *CeleryClient) Delay(task string, args ...interface{}) (*AsyncResult, e
}, nil
}

// CeleryTask is an interface that represents actual task
// Passing CeleryTask interface instead of function pointer
// avoids reflection and may have performance gain.
// ResultMessage must be obtained using GetResultMessage()
type CeleryTask interface {
RunTask() (*ResultMessage, error)
}

// AsyncResult is pending result
type AsyncResult struct {
taskID string
Expand Down
6 changes: 5 additions & 1 deletion message.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ func allocResultMessage() interface{} {
}
}

func getResultMessage(val *reflect.Value) *ResultMessage {
func GetResultMessage() *ResultMessage {
return resultMessagePool.Get().(*ResultMessage)
}

func getReflectionResultMessage(val *reflect.Value) *ResultMessage {
msg := resultMessagePool.Get().(*ResultMessage)
msg.Result = GetRealValue(val)
return msg
Expand Down
26 changes: 20 additions & 6 deletions worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,15 @@ func (w *CeleryWorker) StartWorker() {
log.Printf("WORKER %s task message received: %v\n", workerID, taskMessage)

// run task
val, err := w.RunTask(taskMessage)
resultMsg, err := w.RunTask(taskMessage)
if err != nil {
log.Println(err)
continue
}
defer releaseResultMessage(resultMsg)

// push result to backend
resultMessage := getResultMessage(val)
defer releaseResultMessage(resultMessage)
err = w.backend.SetResult(taskMessage.ID, resultMessage)
err = w.backend.SetResult(taskMessage.ID, resultMsg)
if err != nil {
log.Println(err)
continue
Expand Down Expand Up @@ -100,12 +99,26 @@ func (w *CeleryWorker) GetTask(name string) interface{} {
}

// RunTask runs celery task
func (w *CeleryWorker) RunTask(message *TaskMessage) (*reflect.Value, error) {
func (w *CeleryWorker) RunTask(message *TaskMessage) (*ResultMessage, error) {

// get task
task := w.GetTask(message.Task)
if task == nil {
return nil, fmt.Errorf("task %s is not registered", message.Task)
}

// convert to task interface
taskInterface, ok := task.(CeleryTask)
if ok {
return taskInterface.RunTask()
}

// use reflection to execute function ptr
taskFunc := reflect.ValueOf(task)
return runTaskFunc(&taskFunc, message)
}

func runTaskFunc(taskFunc *reflect.Value, message *TaskMessage) (*ResultMessage, error) {

// check number of arguments
numArgs := taskFunc.Type().NumIn()
Expand All @@ -132,5 +145,6 @@ func (w *CeleryWorker) RunTask(message *TaskMessage) (*reflect.Value, error) {
if len(res) == 0 {
return nil, nil
}
return &res[0], nil
//defer releaseResultMessage(resultMessage)
return getReflectionResultMessage(&res[0]), nil
}
5 changes: 3 additions & 2 deletions worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,13 @@ func TestRunTask(t *testing.T) {
Retries: 1,
ETA: "",
}
reflectVal, err := celeryWorker.RunTask(taskMessage)
resultMsg, err := celeryWorker.RunTask(taskMessage)
if err != nil {
t.Errorf("failed to run celery task %v: %v", taskMessage, err)
}

reflectRes := reflectVal.Int()
reflectRes := resultMsg.Result.(int64)

// check result
if int64(res) != reflectRes {
t.Errorf("reflect result %v is different from normal result %v", reflectRes, res)
Expand Down

0 comments on commit 847801b

Please sign in to comment.