Skip to content

Commit

Permalink
Added InboundMiddlewares and initialize fx.Context
Browse files Browse the repository at this point in the history
  • Loading branch information
anuptalwalkar committed Dec 15, 2016
1 parent 2c8f6be commit 4e55169
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 17 deletions.
21 changes: 21 additions & 0 deletions modules/rpc/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"go.uber.org/thriftrw/wire"
"go.uber.org/yarpc"
"go.uber.org/yarpc/encoding/thrift"
"go.uber.org/yarpc/transport"
)

// UnaryHandler is a wrapper for YARPC thrift.UnaryHandler
Expand Down Expand Up @@ -57,6 +58,7 @@ func (f OnewayHandlerFunc) HandleOneway(ctx fx.Context, reqMeta yarpc.ReqMeta, b
}

// WrapUnary wraps the unary handler and returns implementation of thrift.UnaryHandler for yarpc calls
// TODO(anup): fix wrapUnary signature to remove host once update yarpc plugin is imported in the repo
func WrapUnary(host service.Host, unaryHandlerFunc UnaryHandlerFunc) thrift.UnaryHandler {
return &unaryHandlerWrapper{
Host: host,
Expand All @@ -76,6 +78,7 @@ func (hw *unaryHandlerWrapper) Handle(ctx context.Context, reqMeta yarpc.ReqMeta
}

// WrapOneway wraps the oneway handler and returns implementation of thrift.OnewayHandler for yarpc calls
// TODO(anup): fix wrapOneway signature to remove host once update yarpc plugin is imported in the repo
func WrapOneway(host service.Host, onewayHandlerFunc OnewayHandlerFunc) thrift.OnewayHandler {
return &onewayHandlerWrapper{
Host: host,
Expand All @@ -93,3 +96,21 @@ func (hw *onewayHandlerWrapper) HandleOneway(ctx context.Context, reqMeta yarpc.
fxctx := fx.NewContext(ctx, hw.Host)
return hw.OnewayHandlerFunc.HandleOneway(fxctx, reqMeta, body)
}

type fxContextUnaryInboundMiddleware struct {
service.Host
}

func (f fxContextUnaryInboundMiddleware) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter, handler transport.UnaryHandler) error {
fxctx := fx.NewContext(ctx, f.Host)
return handler.Handle(fxctx, req, resw)
}

type fxContextOnewayInboundMiddleware struct {
service.Host
}

func (f fxContextOnewayInboundMiddleware) HandleOneway(ctx context.Context, req *transport.Request, handler transport.OnewayHandler) error {
fxctx := fx.NewContext(ctx, f.Host)
return handler.HandleOneway(fxctx, req)
}
31 changes: 31 additions & 0 deletions modules/rpc/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"go.uber.org/thriftrw/wire"
"go.uber.org/yarpc"
"go.uber.org/yarpc/encoding/thrift"
"go.uber.org/yarpc/transport"
)

type fakeEnveloper struct {
Expand Down Expand Up @@ -88,3 +89,33 @@ func TestWrapOneway_error(t *testing.T) {
err := handlerFunc.HandleOneway(context.Background(), nil, wire.Value{})
assert.Error(t, err)
}

func TestUnaryInboundMiddleware_fxContext(t *testing.T) {
u := fxContextUnaryInboundMiddleware{
Host: service.NullHost(),
}
err := u.Handle(context.Background(), &transport.Request{}, nil, &fakeUnaryHandler{})
assert.Equal(t, "dummy", err.Error())
}

type fakeUnaryHandler struct{}

func (_m *fakeUnaryHandler) Handle(ctx context.Context, _param1 *transport.Request, _param2 transport.ResponseWriter) error {
// TODO(anup): improve type assertion and context upgrading to fx.Context
return errors.New(ctx.(fx.Context).Name())
}

func TestOnewayInboundMiddleware_fxContext(t *testing.T) {
u := fxContextOnewayInboundMiddleware{
Host: service.NullHost(),
}
err := u.HandleOneway(context.Background(), &transport.Request{}, &fakeOnewayHandler{})
assert.Equal(t, "dummy", err.Error())
}

type fakeOnewayHandler struct{}

func (_m *fakeOnewayHandler) HandleOneway(ctx context.Context, p *transport.Request) error {
// TODO(anup): improve type assertion and context upgrading to fx.Context
return errors.New(ctx.(fx.Context).Name())
}
3 changes: 2 additions & 1 deletion modules/rpc/thrift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ func testInitRunModule(t *testing.T, mod service.Module, mci service.ModuleCreat

func mch() service.ModuleCreateInfo {
return service.ModuleCreateInfo{
Host: service.NullHost(),
Items: make(map[string]interface{}),
Host: service.NullHost(),
}
}

Expand Down
30 changes: 21 additions & 9 deletions modules/rpc/yarpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ import (
// YarpcModule is an implementation of a core module using YARPC
type YarpcModule struct {
modules.ModuleBase
rpc yarpc.Dispatcher
register registerServiceFunc
config yarpcConfig
log ulog.Log
stateMu sync.RWMutex
inboundMiddlewares []transport.UnaryInboundMiddleware
rpc yarpc.Dispatcher
register registerServiceFunc
config yarpcConfig
log ulog.Log
stateMu sync.RWMutex
unaryInboundMiddlewares []transport.UnaryInboundMiddleware
onewayInboundMiddlewares []transport.OnewayInboundMiddleware
}

var (
Expand Down Expand Up @@ -85,6 +86,14 @@ func newYarpcModule(
config: *cfg,
}

options = append(options,
WithUnaryInboundMiddleware(fxContextUnaryInboundMiddleware{
Host: mi.Host,
}),
WithOnewayInboundMiddleware(fxContextOnewayInboundMiddleware{
Host: mi.Host,
}))

module.log = ulog.Logger().With("moduleName", name)
for _, opt := range options {
if err := opt(&mi); err != nil {
Expand All @@ -97,7 +106,8 @@ func newYarpcModule(
// found values, update module
module.config = *cfg

module.inboundMiddlewares = inboundMiddlewaresFromCreateInfo(mi)
module.unaryInboundMiddlewares = unaryInboundMiddlewaresFromCreateInfo(mi)
module.onewayInboundMiddlewares = onewayInboundMiddlewaresFromCreateInfo(mi)

return module, err
}
Expand All @@ -119,15 +129,17 @@ func (m *YarpcModule) Start(readyCh chan<- struct{}) <-chan error {
return ret
}

interceptor := yarpc.UnaryInboundMiddleware(m.inboundMiddlewares...)
unaryInterceptor := yarpc.UnaryInboundMiddleware(m.unaryInboundMiddlewares...)
onewayInterceptor := yarpc.OnewayInboundMiddleware(m.onewayInboundMiddlewares...)

m.rpc, err = _dispatcherFn(m.Host(), yarpc.Config{
Name: m.config.AdvertiseName,
Inbounds: []transport.Inbound{
tch.NewInbound(channel, tch.ListenAddr(m.config.Bind)),
},
InboundMiddleware: yarpc.InboundMiddleware{
Unary: interceptor,
Unary: unaryInterceptor,
Oneway: onewayInterceptor,
},
Tracer: m.Tracer(),
})
Expand Down
31 changes: 26 additions & 5 deletions modules/rpc/yarpc_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,47 @@ import (
)

const (
_interceptorKey = "yarpcUnaryInboundMiddleware"
_unaryInterceptorKey = "yarpcUnaryInboundMiddleware"
_onewayInterceptorKey = "yarpcOnewayInboundMiddleware"
)

// WithUnaryInboundMiddleware adds custom YARPC inboundMiddlewares to the module
func WithUnaryInboundMiddleware(i ...transport.UnaryInboundMiddleware) modules.Option {
return func(mci *service.ModuleCreateInfo) error {
inboundMiddlewares := inboundMiddlewaresFromCreateInfo(*mci)
inboundMiddlewares := unaryInboundMiddlewaresFromCreateInfo(*mci)
inboundMiddlewares = append(inboundMiddlewares, i...)
mci.Items[_interceptorKey] = inboundMiddlewares
mci.Items[_unaryInterceptorKey] = inboundMiddlewares

return nil
}
}

func inboundMiddlewaresFromCreateInfo(mci service.ModuleCreateInfo) []transport.UnaryInboundMiddleware {
items, ok := mci.Items[_interceptorKey]
// WithOnewayInboundMiddleware adds custom YARPC inboundMid dlewares to the module
func WithOnewayInboundMiddleware(i ...transport.OnewayInboundMiddleware) modules.Option {
return func(mci *service.ModuleCreateInfo) error {
inboundMiddlewares := onewayInboundMiddlewaresFromCreateInfo(*mci)
inboundMiddlewares = append(inboundMiddlewares, i...)
mci.Items[_onewayInterceptorKey] = inboundMiddlewares
return nil
}
}

func unaryInboundMiddlewaresFromCreateInfo(mci service.ModuleCreateInfo) []transport.UnaryInboundMiddleware {
items, ok := mci.Items[_unaryInterceptorKey]
if !ok {
return nil
}

// Intentionally panic if programmer adds non-interceptor slice to the data
return items.([]transport.UnaryInboundMiddleware)
}

func onewayInboundMiddlewaresFromCreateInfo(mci service.ModuleCreateInfo) []transport.OnewayInboundMiddleware {
items, ok := mci.Items[_onewayInterceptorKey]
if !ok {
return nil
}

// Intentionally panic if programmer adds non-interceptor slice to the data
return items.([]transport.OnewayInboundMiddleware)
}
24 changes: 22 additions & 2 deletions modules/rpc/yarpc_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,37 @@ func TestWithUnaryInboundMiddleware_OK(t *testing.T) {
}

require.NoError(t, opt(mc))
assert.Equal(t, 1, len(inboundMiddlewaresFromCreateInfo(*mc)))
assert.Equal(t, 1, len(unaryInboundMiddlewaresFromCreateInfo(*mc)))
}

func TestWithOnewayInboundMiddleware_OK(t *testing.T) {
opt := WithOnewayInboundMiddleware(transport.NopOnewayInboundMiddleware)
mc := &service.ModuleCreateInfo{
Items: make(map[string]interface{}),
}
require.NoError(t, opt(mc))
assert.Equal(t, 1, len(onewayInboundMiddlewaresFromCreateInfo(*mc)))
}

func TestWithUnaryInboundMiddleware_PanicsBadData(t *testing.T) {
opt := WithUnaryInboundMiddleware(transport.NopUnaryInboundMiddleware)
mc := &service.ModuleCreateInfo{
Items: map[string]interface{}{
_interceptorKey: "foo",
_unaryInterceptorKey: "foo",
},
}
assert.Panics(t, func() {
opt(mc)
})
}

func TestWithOnewayInboundMiddleware_PanicsBadData(t *testing.T) {
opt := WithOnewayInboundMiddleware(transport.NopOnewayInboundMiddleware)
mc := &service.ModuleCreateInfo{
Items: map[string]interface{}{
_onewayInterceptorKey: "foo",
},
}
assert.Panics(t, func() {
opt(mc)
})
Expand Down

0 comments on commit 4e55169

Please sign in to comment.