Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support before/after hooks on user remote RPCs #351

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions service/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,6 @@ func processRemoteMessage(ctx context.Context, req *protos.Request, r *RemoteSer
}

func (r *RemoteService) handleRPCUser(ctx context.Context, req *protos.Request, rt *route.Route) *protos.Response {
response := &protos.Response{}

remote, ok := r.remotes[rt.Short()]
if !ok {
logger.Log.Warnf("pitaya/remote: %s not found", rt.Short())
Expand All @@ -350,9 +348,13 @@ func (r *RemoteService) handleRPCUser(ctx context.Context, req *protos.Request,
}
return response
}
params := []reflect.Value{remote.Receiver, reflect.ValueOf(ctx)}

var ret interface{}
var arg interface{}
var err error

if remote.HasArgs {
arg, err := unmarshalRemoteArg(remote, req.GetMsg().GetData())
arg, err = unmarshalRemoteArg(remote, req.GetMsg().GetData())
if err != nil {
response := &protos.Response{
Error: &protos.Error{
Expand All @@ -362,10 +364,26 @@ func (r *RemoteService) handleRPCUser(ctx context.Context, req *protos.Request,
}
return response
}
}

ctx, arg, err = r.handlerHooks.BeforeHandler.ExecuteBeforePipeline(ctx, arg)
if err != nil {
response := &protos.Response{
Error: &protos.Error{
Code: e.ErrInternalCode,
Msg: err.Error(),
},
}
return response
}

params := []reflect.Value{remote.Receiver, reflect.ValueOf(ctx)}
if remote.HasArgs {
params = append(params, reflect.ValueOf(arg))
}

ret, err := util.Pcall(remote.Method, params)
ret, err = util.Pcall(remote.Method, params)
ret, err = r.handlerHooks.AfterHandler.ExecuteAfterPipeline(ctx, ret, err)
if err != nil {
response := &protos.Response{
Error: &protos.Error{
Expand Down Expand Up @@ -405,6 +423,7 @@ func (r *RemoteService) handleRPCUser(ctx context.Context, req *protos.Request,
}
}

response := &protos.Response{}
response.Data = b
return response
}
Expand Down
167 changes: 167 additions & 0 deletions service/remote_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package service
import (
"context"
"errors"
"fmt"
"math/rand"
"reflect"
"testing"
Expand Down Expand Up @@ -50,6 +51,8 @@ import (
sessionmocks "github.com/topfreegames/pitaya/v2/session/mocks"
)

const ctxModifiedResponse = "response"

func (m *MyComp) Remote1(ctx context.Context, ss *test.SomeStruct) (*test.SomeStruct, error) {
return &test.SomeStruct{B: "ack"}, nil
}
Expand All @@ -59,6 +62,10 @@ func (m *MyComp) Remote2(ctx context.Context) (*test.SomeStruct, error) {
}

func (m *MyComp) RemoteRes(ctx context.Context, b *test.SomeStruct) (*test.SomeStruct, error) {
ctxRes := ctx.Value(ctxModifiedResponse) // used in hook tests
if ctxRes != nil {
return ctxRes.(*test.SomeStruct), nil
}
return b, nil
}

Expand Down Expand Up @@ -368,6 +375,166 @@ func TestRemoteServiceHandleRPCUser(t *testing.T) {
}
}

func TestRemoteServiceHandleRPCUserWithHooks(t *testing.T) {
handlerPool := NewHandlerPool()

tObj := &MyComp{}
m, ok := reflect.TypeOf(tObj).MethodByName("Remote1")
assert.True(t, ok)
assert.NotNil(t, m)
rt := route.NewRoute("", uuid.New().String(), uuid.New().String())
comp := &component.Remote{Receiver: reflect.ValueOf(tObj), Method: m, HasArgs: m.Type.NumIn() > 2}

m, ok = reflect.TypeOf(tObj).MethodByName("RemoteErr")
assert.True(t, ok)
assert.NotNil(t, m)
rtErr := route.NewRoute("", uuid.New().String(), uuid.New().String())
compErr := &component.Remote{Receiver: reflect.ValueOf(tObj), Method: m, HasArgs: m.Type.NumIn() > 2}

m, ok = reflect.TypeOf(tObj).MethodByName("Remote2")
assert.True(t, ok)
assert.NotNil(t, m)
rtStr := route.NewRoute("", uuid.New().String(), uuid.New().String())
compStr := &component.Remote{Receiver: reflect.ValueOf(tObj), Method: m, HasArgs: m.Type.NumIn() > 2}

m, ok = reflect.TypeOf(tObj).MethodByName("RemoteRes")
assert.True(t, ok)
assert.NotNil(t, m)
rtRes := route.NewRoute("", uuid.New().String(), uuid.New().String())
compRes := &component.Remote{Receiver: reflect.ValueOf(tObj), Method: m, HasArgs: m.Type.NumIn() > 2, Type: reflect.TypeOf(&test.SomeStruct{B: "aa"})}

b, err := proto.Marshal(&test.SomeStruct{B: "aa"})
assert.NoError(t, err)

modifiedInput := &test.SomeStruct{B: "cc"}

modifiedResponse, err := proto.Marshal(modifiedInput)
assert.NoError(t, err)

modifiedCtx := context.WithValue(context.Background(), ctxModifiedResponse, modifiedInput)
tables := []struct {
name string
req *protos.Request
rt *route.Route
expectedOutput []byte
errSubstring string
shouldRunBeforeHook bool
shouldRunAfterHook bool
modifiedInput interface{}
modifiedCtx context.Context
modifiedInputError error
modifiedOutput interface{}
modifiedOutputError error
}{
{"remote_not_found", &protos.Request{Msg: &protos.Msg{}}, route.NewRoute("bla", "bla", "bla"), nil, "route not found", false, false, nil, nil, nil, nil, nil},
{"failed_unmarshal", &protos.Request{Msg: &protos.Msg{Data: []byte("dd")}}, rt, nil, "reflect: Call using zero Value argument", true, true, nil, nil, nil, nil, nil},
{"failed_pcall", &protos.Request{Msg: &protos.Msg{}}, rtErr, nil, "remote err", true, true, nil, nil, nil, nil, nil},
{"failed_before_hook", &protos.Request{Msg: &protos.Msg{}}, rtErr, nil, "before hook err", true, false, nil, nil, fmt.Errorf("before hook err"), nil, nil},
{"failed_pcall_modified_err", &protos.Request{Msg: &protos.Msg{}}, rtErr, nil, "remote err modified output", true, true, nil, nil, nil, nil, fmt.Errorf("remote err modified output")},
{"success_nil_response", &protos.Request{Msg: &protos.Msg{}}, rtStr, nil, "", true, true, nil, nil, nil, nil, nil},
{"success_response", &protos.Request{Msg: &protos.Msg{Data: b}}, rtRes, b, "", true, true, nil, nil, nil, nil, nil},
{"success_response_modified_ctx", &protos.Request{Msg: &protos.Msg{Data: b}}, rtRes, modifiedResponse, "", true, true, nil, modifiedCtx, nil, nil, nil},
{"success_response_modified_input", &protos.Request{Msg: &protos.Msg{Data: b}}, rtRes, modifiedResponse, "", true, true, modifiedInput, nil, nil, nil, nil},
{"success_response_modified_input_ctx", &protos.Request{Msg: &protos.Msg{Data: b}}, rtRes, modifiedResponse, "", true, true, modifiedInput, modifiedCtx, nil, nil, nil},
{"success_response_modified_output", &protos.Request{Msg: &protos.Msg{Data: b}}, rtRes, modifiedResponse, "", true, true, nil, nil, nil, modifiedInput, nil},
{"failed_after_hook", &protos.Request{Msg: &protos.Msg{Data: b}}, rtRes, nil, "after hook err", true, true, nil, nil, nil, nil, fmt.Errorf("after hook err")},
}

for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
packetEncoder := codec.NewPomeloPacketEncoder()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockSerializer := serializemocks.NewMockSerializer(ctrl)
mockSD := clustermocks.NewMockServiceDiscovery(ctrl)
mockRPCClient := clustermocks.NewMockRPCClient(ctrl)
mockRPCServer := clustermocks.NewMockRPCServer(ctrl)
messageEncoder := message.NewMessagesEncoder(false)
router := router.New()
sessionPool := session.NewSessionPool()

beforeHookInvoked := false
afterHookInvoked := false

handlerHooks := pipeline.NewHandlerHooks()
handlerHooks.BeforeHandler.PushFront(func(ctx context.Context, in interface{}) (context.Context, interface{}, error) {
if beforeHookInvoked {
assert.FailNow(t, "BeforeHandler hook invoked twice")
}
if afterHookInvoked {
assert.FailNow(t, "BeforeHandler and AfterHandler hooks running out of order")
}

var err error
if table.modifiedInput != nil {
in = table.modifiedInput
}
if table.modifiedCtx != nil {
ctx = table.modifiedCtx
}
if table.modifiedInputError != nil {
err = table.modifiedInputError
}

beforeHookInvoked = true
return ctx, in, err
})
handlerHooks.AfterHandler.PushFront(func(ctx context.Context, out interface{}, err error) (interface{}, error) {
if afterHookInvoked {
assert.FailNow(t, "AfterHandler hook invoked twice")
}
if !beforeHookInvoked {
assert.FailNow(t, "BeforeHandler and AfterHandler hooks running out of order")
}

if table.modifiedOutput != nil {
out = table.modifiedOutput
}
if table.modifiedOutputError != nil {
err = table.modifiedOutputError
}

afterHookInvoked = true
return out, err
})

svc := NewRemoteService(mockRPCClient, mockRPCServer, mockSD, packetEncoder, mockSerializer, router, messageEncoder, &cluster.Server{}, sessionPool, handlerHooks, handlerPool)

svc.remotes[rt.Short()] = comp
svc.remotes[rtErr.Short()] = compErr
svc.remotes[rtStr.Short()] = compStr
svc.remotes[rtRes.Short()] = compRes

assert.NotNil(t, svc)

assert.False(t, beforeHookInvoked, "Before hook invoked before RPC")
assert.False(t, afterHookInvoked, "After hook invoked before RPC")

res := svc.handleRPCUser(context.Background(), table.req, table.rt)

if table.shouldRunBeforeHook {
assert.True(t, beforeHookInvoked, "After hook was never invoked")
} else {
assert.False(t, beforeHookInvoked, "After hook should not have run")
}
if table.shouldRunAfterHook {
assert.True(t, afterHookInvoked, "After hook was never invoked")
} else {
assert.False(t, afterHookInvoked, "After hook should not have run")
}

assert.NoError(t, err)
if table.errSubstring != "" {
assert.Contains(t, res.Error.Msg, table.errSubstring)
} else if table.req.Msg.Data != nil {
assert.NotNil(t, res.Data)
}

assert.Equal(t, res.Data, table.expectedOutput)
})
}
}

func TestRemoteServiceHandleRPCSys(t *testing.T) {
tObj := &TestType{}
m, ok := reflect.TypeOf(tObj).MethodByName("HandlerPointerRaw")
Expand Down