diff --git a/pipeline/pipeline.go b/pipeline/pipeline.go index 2849f3c9..7509ad87 100644 --- a/pipeline/pipeline.go +++ b/pipeline/pipeline.go @@ -51,3 +51,8 @@ func (p *pipelineChannel) PushFront(h Handler) { func (p *pipelineChannel) PushBack(h Handler) { p.Handlers = append(p.Handlers, h) } + +// Clear should not be used after pitaya running +func (p *pipelineChannel) Clear() { + p.Handlers = make([]Handler, 0) +} diff --git a/pipeline/pipeline_test.go b/pipeline/pipeline_test.go index 21310318..6a397ef8 100644 --- a/pipeline/pipeline_test.go +++ b/pipeline/pipeline_test.go @@ -41,6 +41,8 @@ var ( func TestPushFront(t *testing.T) { p.PushFront(handler1) p.PushFront(handler2) + defer p.Clear() + _, err := p.Handlers[0](nil, nil) assert.Nil(t, nil, err) } @@ -48,6 +50,16 @@ func TestPushFront(t *testing.T) { func TestPushBack(t *testing.T) { p.PushFront(handler1) p.PushBack(handler2) + defer p.Clear() + _, err := p.Handlers[0](nil, nil) assert.EqualError(t, errors.New("ohno"), err.Error()) } + +func TestClear(t *testing.T) { + p.PushFront(handler1) + p.PushBack(handler2) + assert.Len(t, p.Handlers, 2) + p.Clear() + assert.Len(t, p.Handlers, 0) +} diff --git a/service/util.go b/service/util.go index 7b62f4e7..c9422939 100644 --- a/service/util.go +++ b/service/util.go @@ -134,7 +134,6 @@ func serializeReturn(ser serialize.Serializer, ret interface{}) ([]byte, error) return res, nil } -// TODO: should this be here in utils? func processHandlerMessage( rt *route.Route, serializer serialize.Serializer, @@ -174,6 +173,7 @@ func processHandlerMessage( if arg != nil { args = append(args, reflect.ValueOf(arg)) } + resp, err := util.Pcall(h.Method, args) if err != nil { return nil, err diff --git a/service/util_test.go b/service/util_test.go index 10293b27..9d372584 100644 --- a/service/util_test.go +++ b/service/util_test.go @@ -59,6 +59,15 @@ type SomeStruct struct { func (t *TestType) HandlerNil(*session.Session) {} func (t *TestType) HandlerRaw(s *session.Session, msg []byte) {} func (t *TestType) HandlerPointer(s *session.Session, ss *SomeStruct) {} +func (t *TestType) HandlerPointerRaw(s *session.Session, ss *SomeStruct) ([]byte, error) { + return []byte("ok"), nil +} +func (t *TestType) HandlerPointerStruct(s *session.Session, ss *SomeStruct) (*SomeStruct, error) { + return &SomeStruct{A: 1, B: "ok"}, nil +} +func (t *TestType) HandlerPointerErr(s *session.Session, ss *SomeStruct) ([]byte, error) { + return nil, errors.New("HandlerPointerErr") +} func TestMain(m *testing.M) { setup() @@ -230,6 +239,7 @@ func TestExecuteBeforePipelineSuccess(t *testing.T) { } pipeline.BeforeHandler.PushBack(before1) pipeline.BeforeHandler.PushBack(before2) + defer pipeline.BeforeHandler.Clear() res, err := executeBeforePipeline(ss, data) assert.NoError(t, err) @@ -244,6 +254,7 @@ func TestExecuteBeforePipelineError(t *testing.T) { return nil, expected } pipeline.BeforeHandler.PushFront(before) + defer pipeline.BeforeHandler.Clear() _, err := executeBeforePipeline(ss, []byte("ok")) assert.Equal(t, expected, err) @@ -272,6 +283,7 @@ func TestExecuteAfterPipelineSuccess(t *testing.T) { } pipeline.AfterHandler.PushBack(after1) pipeline.AfterHandler.PushBack(after2) + defer pipeline.AfterHandler.Clear() res := executeAfterPipeline(ss, nil, []byte("ok")) assert.Equal(t, expected2, res) @@ -288,6 +300,7 @@ func TestExecuteAfterPipelineError(t *testing.T) { return nil, errors.New("oh noes") } pipeline.AfterHandler.PushFront(after) + defer pipeline.AfterHandler.Clear() expected := []byte("error") mockSerializer.EXPECT().Marshal(gomock.Any()).Return(expected, nil) @@ -329,3 +342,127 @@ func TestSerializeReturn(t *testing.T) { }) } } + +func TestProcessHandlerMessage(t *testing.T) { + tObj := &TestType{} + + m, ok := reflect.TypeOf(tObj).MethodByName("HandlerPointerRaw") + assert.True(t, ok) + assert.NotNil(t, m) + rt := route.NewRoute("", uuid.New().String(), uuid.New().String()) + handlers[rt.Short()] = &component.Handler{Receiver: reflect.ValueOf(tObj), Method: m, Type: m.Type.In(2)} + + m, ok = reflect.TypeOf(tObj).MethodByName("HandlerPointerErr") + assert.True(t, ok) + assert.NotNil(t, m) + rtErr := route.NewRoute("", uuid.New().String(), uuid.New().String()) + handlers[rtErr.Short()] = &component.Handler{Receiver: reflect.ValueOf(tObj), Method: m, Type: m.Type.In(2)} + + m, ok = reflect.TypeOf(tObj).MethodByName("HandlerPointerStruct") + assert.True(t, ok) + assert.NotNil(t, m) + rtSt := route.NewRoute("", uuid.New().String(), uuid.New().String()) + handlers[rtSt.Short()] = &component.Handler{Receiver: reflect.ValueOf(tObj), Method: m, Type: m.Type.In(2)} + defer func() { handlers = make(map[string]*component.Handler, 0) }() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ss := session.New(nil, false) + cs := reflect.ValueOf(ss) + + tables := []struct { + name string + route *route.Route + errSerReturn error + errSerialize error + outSerialize interface{} + handlerType message.Type + msgType interface{} + remote bool + out []byte + err error + }{ + {"invalid_route", route.NewRoute("", "no", "no"), nil, nil, nil, message.Request, nil, false, nil, errors.New("pitaya/handler: no.no not found")}, + {"invalid_msg_type", rt, nil, nil, nil, message.Request, nil, false, nil, errors.New("invalid message type provided")}, + {"request_on_notify", rt, nil, nil, nil, message.Notify, message.Request, false, nil, errors.New("tried to request a notify route")}, + {"failed_handle_args_unmarshal", rt, nil, errors.New("some error"), &SomeStruct{}, message.Request, message.Request, false, nil, errors.New("some error")}, + {"failed_pcall", rtErr, nil, nil, &SomeStruct{A: 1, B: "ok"}, message.Request, message.Request, false, nil, errors.New("HandlerPointerErr")}, + {"failed_serialize_return", rtSt, errors.New("ser ret error"), nil, &SomeStruct{A: 1, B: "ok"}, message.Request, message.Request, false, []byte("failed"), nil}, + {"ok", rt, nil, nil, &SomeStruct{}, message.Request, message.Request, false, []byte("ok"), nil}, + {"notify_on_request", rt, nil, nil, &SomeStruct{}, message.Request, message.Notify, false, []byte("ok"), nil}, + {"remote_notify", rt, nil, nil, &SomeStruct{}, message.Notify, message.Notify, true, []byte("ack"), nil}, + } + + for _, table := range tables { + t.Run(table.name, func(t *testing.T) { + handlers[rt.Short()].MessageType = table.handlerType + mockSerializer := mocks.NewMockSerializer(ctrl) + if table.outSerialize != nil { + mockSerializer.EXPECT().Unmarshal(gomock.Any(), gomock.Any()).Return(table.errSerialize).Do( + func(p []byte, arg interface{}) { + arg = table.outSerialize + }) + + if table.errSerReturn != nil { + mockSerializer.EXPECT().Marshal(gomock.Any()).Return(table.out, table.errSerReturn) + mockSerializer.EXPECT().Marshal(gomock.Any()).Return(table.out, nil) + } + } + out, err := processHandlerMessage(table.route, mockSerializer, cs, ss, nil, table.msgType, table.remote) + assert.Equal(t, table.out, out) + assert.Equal(t, table.err, err) + }) + } +} + +func TestProcessHandlerMessageBrokenBeforePipeline(t *testing.T) { + rt := route.NewRoute("", uuid.New().String(), uuid.New().String()) + handlers[rt.Short()] = &component.Handler{} + defer func() { delete(handlers, rt.Short()) }() + expected := errors.New("oh noes") + before := func(s *session.Session, in []byte) ([]byte, error) { + return nil, expected + } + pipeline.BeforeHandler.PushFront(before) + defer pipeline.BeforeHandler.Clear() + + ss := session.New(nil, false) + cs := reflect.ValueOf(ss) + out, err := processHandlerMessage(rt, nil, cs, ss, nil, message.Request, false) + assert.Nil(t, out) + assert.Equal(t, expected, err) +} + +func TestProcessHandlerMessageBrokenAfterPipeline(t *testing.T) { + tObj := &TestType{} + m, ok := reflect.TypeOf(tObj).MethodByName("HandlerPointerRaw") + assert.True(t, ok) + assert.NotNil(t, m) + rt := route.NewRoute("", uuid.New().String(), uuid.New().String()) + handlers[rt.Short()] = &component.Handler{Receiver: reflect.ValueOf(tObj), Method: m, Type: m.Type.In(2)} + defer func() { delete(handlers, rt.Short()) }() + + after := func(s *session.Session, in []byte) ([]byte, error) { + return nil, errors.New("oh noes") + } + pipeline.AfterHandler.PushFront(after) + defer pipeline.AfterHandler.Clear() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ss := session.New(nil, false) + cs := reflect.ValueOf(ss) + mockSerializer := mocks.NewMockSerializer(ctrl) + mockSerializer.EXPECT().Unmarshal(gomock.Any(), gomock.Any()).Return(nil).Do( + func(p []byte, arg interface{}) { + arg = &SomeStruct{} + }) + expected := []byte("oops") + mockSerializer.EXPECT().Marshal(gomock.Any()).Return(expected, nil) + + out, err := processHandlerMessage(rt, mockSerializer, cs, ss, nil, message.Request, false) + assert.Equal(t, expected, out) + assert.NoError(t, err) +}