Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrating PR 153 to V2 #234

Merged
merged 1 commit into from
Aug 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions defaultpipelines/default_struct_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ type DefaultValidator struct {
// based on the struct tags the parameter has.
// This function has the pipeline.Handler signature so
// it is possible to use it as a pipeline function
func (v *DefaultValidator) Validate(ctx context.Context, in interface{}) (interface{}, error) {
func (v *DefaultValidator) Validate(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
if in == nil {
return in, nil
return ctx, in, nil
}

v.lazyinit()
if err := v.validate.Struct(in); err != nil {
return nil, err
return ctx, nil, err
}

return in, nil
return ctx, in, nil
}

func (v *DefaultValidator) lazyinit() {
Expand Down
4 changes: 2 additions & 2 deletions defaultpipelines/default_struct_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ func TestDefaultValidator(t *testing.T) {
t.Run(tname, func(t *testing.T) {
var err error
if tbl.s == nil {
_, err = validator.Validate(context.Background(), nil)
_, _, err = validator.Validate(context.Background(), nil)
} else {
_, err = validator.Validate(context.Background(), tbl.s)
_, _, err = validator.Validate(context.Background(), tbl.s)
}

if tbl.shouldFail {
Expand Down
2 changes: 1 addition & 1 deletion defaultpipelines/struct_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
//
// The default struct validator used by pitaya is https://github.com/go-playground/validator.
type StructValidator interface {
Validate(context.Context, interface{}) (interface{}, error)
Validate(context.Context, interface{}) (context.Context, interface{}, error)
}

// StructValidatorInstance holds the default validator
Expand Down
10 changes: 5 additions & 5 deletions pipeline/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
type (
// HandlerTempl is a function that has the same signature as a handler and will
// be called before or after handler methods
HandlerTempl func(ctx context.Context, in interface{}) (out interface{}, err error)
HandlerTempl func(ctx context.Context, in interface{}) (c context.Context, out interface{}, err error)

// AfterHandlerTempl is a function for the after handler, receives both the handler response
// and the error returned
Expand Down Expand Up @@ -71,19 +71,19 @@ func NewAfterChannel() *AfterChannel {
}

// ExecuteBeforePipeline calls registered handlers
func (p *Channel) ExecuteBeforePipeline(ctx context.Context, data interface{}) (interface{}, error) {
func (p *Channel) ExecuteBeforePipeline(ctx context.Context, data interface{}) (context.Context, interface{}, error) {
var err error
res := data
if len(p.Handlers) > 0 {
for _, h := range p.Handlers {
res, err = h(ctx, res)
ctx, res, err = h(ctx, res)
if err != nil {
logger.Log.Debugf("pitaya/handler: broken pipeline: %s", err.Error())
return res, err
return ctx, res, err
}
}
}
return res, nil
return ctx, res, nil
}

// ExecuteAfterPipeline calls registered handlers
Expand Down
12 changes: 6 additions & 6 deletions pipeline/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ import (
)

var (
handler1 = func(ctx context.Context, in interface{}) (interface{}, error) {
return in, errors.New("ohno")
handler1 = func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
return ctx, in, errors.New("ohno")
}
handler2 = func(ctx context.Context, in interface{}) (interface{}, error) {
return nil, nil
handler2 = func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
return ctx, nil, nil
}
p = &Channel{}
)
Expand All @@ -43,7 +43,7 @@ func TestPushFront(t *testing.T) {
p.PushFront(handler2)
defer p.Clear()

_, err := p.Handlers[0](nil, nil)
_, _, err := p.Handlers[0](nil, nil)
assert.Nil(t, nil, err)
}

Expand All @@ -52,7 +52,7 @@ func TestPushBack(t *testing.T) {
p.PushBack(handler2)
defer p.Clear()

_, err := p.Handlers[0](nil, nil)
_, _, err := p.Handlers[0](nil, nil)
assert.EqualError(t, errors.New("ohno"), err.Error())
}

Expand Down
2 changes: 1 addition & 1 deletion service/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func (h *HandlerService) localProcess(ctx context.Context, a agent.Agent, route
func (h *HandlerService) DumpServices() {
handlers := h.handlerPool.GetHandlers()
for name := range handlers {
logger.Log.Infof("registered handler %s, isRawArg: %s", name, handlers[name].IsRawArg)
logger.Log.Infof("registered handler %s, isRawArg: %v", name, handlers[name].IsRawArg)
}
}

Expand Down
3 changes: 2 additions & 1 deletion service/handler_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ func (h *HandlerPool) ProcessHandlerMessage(
return nil, e.NewError(err, e.ErrBadRequestCode)
}

if arg, err = handlerHooks.BeforeHandler.ExecuteBeforePipeline(ctx, arg); err != nil {
ctx, arg, err = handlerHooks.BeforeHandler.ExecuteBeforePipeline(ctx, arg)
if err != nil {
return nil, err
}

Expand Down
4 changes: 2 additions & 2 deletions service/handler_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ func TestProcessHandlerMessageBrokenBeforePipeline(t *testing.T) {
handlerPool := NewHandlerPool()
handlerPool.handlers[rt.Short()] = &component.Handler{}
expected := errors.New("oh noes")
before := func(ctx context.Context, in interface{}) (interface{}, error) {
return nil, expected
before := func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
return ctx, nil, expected
}
beforeHandler := pipeline.NewChannel()
beforeHandler.PushFront(before)
Expand Down
3 changes: 2 additions & 1 deletion service/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ func processHandlerMessage(
return nil, e.NewError(err, e.ErrBadRequestCode)
}

if arg, err = handlerHooks.BeforeHandler.ExecuteBeforePipeline(ctx, arg); err != nil {
ctx, arg, err = handlerHooks.BeforeHandler.ExecuteBeforePipeline(ctx, arg)
if err != nil {
return nil, err
}

Expand Down
18 changes: 9 additions & 9 deletions service/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func TestGetMsgType(t *testing.T) {
func TestExecuteBeforePipelineEmpty(t *testing.T) {
expected := []byte("ok")
beforeHandler := pipeline.NewChannel()
res, err := beforeHandler.ExecuteBeforePipeline(nil, expected)
_, res, err := beforeHandler.ExecuteBeforePipeline(nil, expected)
assert.NoError(t, err)
assert.Equal(t, expected, res)
}
Expand All @@ -194,39 +194,39 @@ func TestExecuteBeforePipelineSuccess(t *testing.T) {
data := []byte("ok")
expected1 := []byte("oh noes 1")
expected2 := []byte("oh noes 2")
before1 := func(ctx context.Context, in interface{}) (interface{}, error) {
before1 := func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
assert.Equal(t, c, ctx)
assert.Equal(t, data, in)
return expected1, nil
return ctx, expected1, nil
}
before2 := func(ctx context.Context, in interface{}) (interface{}, error) {
before2 := func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
assert.Equal(t, c, ctx)
assert.Equal(t, expected1, in)
return expected2, nil
return ctx, expected2, nil
}

beforeHandler := pipeline.NewChannel()
beforeHandler.PushBack(before1)
beforeHandler.PushBack(before2)
defer beforeHandler.Clear()

res, err := beforeHandler.ExecuteBeforePipeline(c, data)
_, res, err := beforeHandler.ExecuteBeforePipeline(c, data)
assert.NoError(t, err)
assert.Equal(t, expected2, res)
}

func TestExecuteBeforePipelineError(t *testing.T) {
c := context.Background()
expected := errors.New("oh noes")
before := func(ctx context.Context, in interface{}) (interface{}, error) {
before := func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
assert.Equal(t, c, ctx)
return nil, expected
return ctx, nil, expected
}
beforeHandler := pipeline.NewChannel()
beforeHandler.PushFront(before)
defer beforeHandler.Clear()

_, err := beforeHandler.ExecuteBeforePipeline(c, []byte("ok"))
_, _, err := beforeHandler.ExecuteBeforePipeline(c, []byte("ok"))
assert.Equal(t, expected, err)
}

Expand Down