Skip to content

Commit

Permalink
Migrating PR 153 to V2 (#234)
Browse files Browse the repository at this point in the history
Co-authored-by: Henrique Oelze <henrique.oelze@wildlifestudios.com>
  • Loading branch information
henriqueoelze and Henrique Oelze committed Aug 11, 2021
1 parent 56d2c1f commit c190e9b
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 32 deletions.
8 changes: 4 additions & 4 deletions defaultpipelines/default_struct_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ type DefaultValidator struct {
// based on the struct tags the parameter has.
// This function has the pipeline.Handler signature so
// it is possible to use it as a pipeline function
func (v *DefaultValidator) Validate(ctx context.Context, in interface{}) (interface{}, error) {
func (v *DefaultValidator) Validate(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
if in == nil {
return in, nil
return ctx, in, nil
}

v.lazyinit()
if err := v.validate.Struct(in); err != nil {
return nil, err
return ctx, nil, err
}

return in, nil
return ctx, in, nil
}

func (v *DefaultValidator) lazyinit() {
Expand Down
4 changes: 2 additions & 2 deletions defaultpipelines/default_struct_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ func TestDefaultValidator(t *testing.T) {
t.Run(tname, func(t *testing.T) {
var err error
if tbl.s == nil {
_, err = validator.Validate(context.Background(), nil)
_, _, err = validator.Validate(context.Background(), nil)
} else {
_, err = validator.Validate(context.Background(), tbl.s)
_, _, err = validator.Validate(context.Background(), tbl.s)
}

if tbl.shouldFail {
Expand Down
2 changes: 1 addition & 1 deletion defaultpipelines/struct_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
//
// The default struct validator used by pitaya is https://github.com/go-playground/validator.
type StructValidator interface {
Validate(context.Context, interface{}) (interface{}, error)
Validate(context.Context, interface{}) (context.Context, interface{}, error)
}

// StructValidatorInstance holds the default validator
Expand Down
10 changes: 5 additions & 5 deletions pipeline/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
type (
// HandlerTempl is a function that has the same signature as a handler and will
// be called before or after handler methods
HandlerTempl func(ctx context.Context, in interface{}) (out interface{}, err error)
HandlerTempl func(ctx context.Context, in interface{}) (c context.Context, out interface{}, err error)

// AfterHandlerTempl is a function for the after handler, receives both the handler response
// and the error returned
Expand Down Expand Up @@ -71,19 +71,19 @@ func NewAfterChannel() *AfterChannel {
}

// ExecuteBeforePipeline calls registered handlers
func (p *Channel) ExecuteBeforePipeline(ctx context.Context, data interface{}) (interface{}, error) {
func (p *Channel) ExecuteBeforePipeline(ctx context.Context, data interface{}) (context.Context, interface{}, error) {
var err error
res := data
if len(p.Handlers) > 0 {
for _, h := range p.Handlers {
res, err = h(ctx, res)
ctx, res, err = h(ctx, res)
if err != nil {
logger.Log.Debugf("pitaya/handler: broken pipeline: %s", err.Error())
return res, err
return ctx, res, err
}
}
}
return res, nil
return ctx, res, nil
}

// ExecuteAfterPipeline calls registered handlers
Expand Down
12 changes: 6 additions & 6 deletions pipeline/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ import (
)

var (
handler1 = func(ctx context.Context, in interface{}) (interface{}, error) {
return in, errors.New("ohno")
handler1 = func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
return ctx, in, errors.New("ohno")
}
handler2 = func(ctx context.Context, in interface{}) (interface{}, error) {
return nil, nil
handler2 = func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
return ctx, nil, nil
}
p = &Channel{}
)
Expand All @@ -43,7 +43,7 @@ func TestPushFront(t *testing.T) {
p.PushFront(handler2)
defer p.Clear()

_, err := p.Handlers[0](nil, nil)
_, _, err := p.Handlers[0](nil, nil)
assert.Nil(t, nil, err)
}

Expand All @@ -52,7 +52,7 @@ func TestPushBack(t *testing.T) {
p.PushBack(handler2)
defer p.Clear()

_, err := p.Handlers[0](nil, nil)
_, _, err := p.Handlers[0](nil, nil)
assert.EqualError(t, errors.New("ohno"), err.Error())
}

Expand Down
2 changes: 1 addition & 1 deletion service/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func (h *HandlerService) localProcess(ctx context.Context, a agent.Agent, route
func (h *HandlerService) DumpServices() {
handlers := h.handlerPool.GetHandlers()
for name := range handlers {
logger.Log.Infof("registered handler %s, isRawArg: %s", name, handlers[name].IsRawArg)
logger.Log.Infof("registered handler %s, isRawArg: %v", name, handlers[name].IsRawArg)
}
}

Expand Down
3 changes: 2 additions & 1 deletion service/handler_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ func (h *HandlerPool) ProcessHandlerMessage(
return nil, e.NewError(err, e.ErrBadRequestCode)
}

if arg, err = handlerHooks.BeforeHandler.ExecuteBeforePipeline(ctx, arg); err != nil {
ctx, arg, err = handlerHooks.BeforeHandler.ExecuteBeforePipeline(ctx, arg)
if err != nil {
return nil, err
}

Expand Down
4 changes: 2 additions & 2 deletions service/handler_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ func TestProcessHandlerMessageBrokenBeforePipeline(t *testing.T) {
handlerPool := NewHandlerPool()
handlerPool.handlers[rt.Short()] = &component.Handler{}
expected := errors.New("oh noes")
before := func(ctx context.Context, in interface{}) (interface{}, error) {
return nil, expected
before := func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
return ctx, nil, expected
}
beforeHandler := pipeline.NewChannel()
beforeHandler.PushFront(before)
Expand Down
3 changes: 2 additions & 1 deletion service/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ func processHandlerMessage(
return nil, e.NewError(err, e.ErrBadRequestCode)
}

if arg, err = handlerHooks.BeforeHandler.ExecuteBeforePipeline(ctx, arg); err != nil {
ctx, arg, err = handlerHooks.BeforeHandler.ExecuteBeforePipeline(ctx, arg)
if err != nil {
return nil, err
}

Expand Down
18 changes: 9 additions & 9 deletions service/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func TestGetMsgType(t *testing.T) {
func TestExecuteBeforePipelineEmpty(t *testing.T) {
expected := []byte("ok")
beforeHandler := pipeline.NewChannel()
res, err := beforeHandler.ExecuteBeforePipeline(nil, expected)
_, res, err := beforeHandler.ExecuteBeforePipeline(nil, expected)
assert.NoError(t, err)
assert.Equal(t, expected, res)
}
Expand All @@ -194,39 +194,39 @@ func TestExecuteBeforePipelineSuccess(t *testing.T) {
data := []byte("ok")
expected1 := []byte("oh noes 1")
expected2 := []byte("oh noes 2")
before1 := func(ctx context.Context, in interface{}) (interface{}, error) {
before1 := func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
assert.Equal(t, c, ctx)
assert.Equal(t, data, in)
return expected1, nil
return ctx, expected1, nil
}
before2 := func(ctx context.Context, in interface{}) (interface{}, error) {
before2 := func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
assert.Equal(t, c, ctx)
assert.Equal(t, expected1, in)
return expected2, nil
return ctx, expected2, nil
}

beforeHandler := pipeline.NewChannel()
beforeHandler.PushBack(before1)
beforeHandler.PushBack(before2)
defer beforeHandler.Clear()

res, err := beforeHandler.ExecuteBeforePipeline(c, data)
_, res, err := beforeHandler.ExecuteBeforePipeline(c, data)
assert.NoError(t, err)
assert.Equal(t, expected2, res)
}

func TestExecuteBeforePipelineError(t *testing.T) {
c := context.Background()
expected := errors.New("oh noes")
before := func(ctx context.Context, in interface{}) (interface{}, error) {
before := func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
assert.Equal(t, c, ctx)
return nil, expected
return ctx, nil, expected
}
beforeHandler := pipeline.NewChannel()
beforeHandler.PushFront(before)
defer beforeHandler.Clear()

_, err := beforeHandler.ExecuteBeforePipeline(c, []byte("ok"))
_, _, err := beforeHandler.ExecuteBeforePipeline(c, []byte("ok"))
assert.Equal(t, expected, err)
}

Expand Down

0 comments on commit c190e9b

Please sign in to comment.