Skip to content

Commit

Permalink
Manage transport lifecycle from Dispatcher (#529)
Browse files Browse the repository at this point in the history
  • Loading branch information
kriskowal committed Dec 7, 2016
1 parent 5efac87 commit c91f0f0
Show file tree
Hide file tree
Showing 13 changed files with 223 additions and 17 deletions.
105 changes: 89 additions & 16 deletions dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ func NewDispatcher(cfg Config) Dispatcher {
Registrar: transport.NewMapRegistry(cfg.Name),
inbounds: cfg.Inbounds,
outbounds: convertOutbounds(cfg.Outbounds, cfg.OutboundMiddleware),
transports: collectTransports(cfg.Inbounds, cfg.Outbounds),
InboundMiddleware: cfg.InboundMiddleware,
}
}
Expand Down Expand Up @@ -141,6 +142,37 @@ func convertOutbounds(outbounds Outbounds, middleware OutboundMiddleware) Outbou
return convertedOutbounds
}

// collectTransports iterates over all inbounds and outbounds and collects all
// of their unique underlying transports. Multiple inbounds and outbounds may
// share a transport, and we only want the dispatcher to manage their lifecycle
// once.
func collectTransports(inbounds Inbounds, outbounds Outbounds) []transport.Transport {
// Collect all unique transports from inbounds and outbounds.
transports := make(map[transport.Transport]struct{})
for _, inbound := range inbounds {
for _, transport := range inbound.Transports() {
transports[transport] = struct{}{}
}
}
for _, outbound := range outbounds {
if unary := outbound.Unary; unary != nil {
for _, transport := range unary.Transports() {
transports[transport] = struct{}{}
}
}
if oneway := outbound.Oneway; oneway != nil {
for _, transport := range oneway.Transports() {
transports[transport] = struct{}{}
}
}
}
keys := make([]transport.Transport, 0, len(transports))
for key := range transports {
keys = append(keys, key)
}
return keys
}

// dispatcher is the standard RPC implementation.
//
// It allows use of multiple Inbounds and Outbounds together.
Expand All @@ -149,8 +181,9 @@ type dispatcher struct {

Name string

inbounds Inbounds
outbounds Outbounds
inbounds Inbounds
outbounds Outbounds
transports []transport.Transport

InboundMiddleware InboundMiddleware
}
Expand All @@ -170,9 +203,10 @@ func (d dispatcher) ClientConfig(service string) transport.ClientConfig {

func (d dispatcher) Start() error {
var (
mu sync.Mutex
startedInbounds []transport.Inbound
startedOutbounds []transport.Outbound
mu sync.Mutex
startedTransports []transport.Transport
startedInbounds []transport.Inbound
startedOutbounds []transport.Outbound
)

startInbound := func(i transport.Inbound) func() error {
Expand Down Expand Up @@ -205,6 +239,41 @@ func (d dispatcher) Start() error {
}
}

startTransport := func(t transport.Transport) func() error {
return func() error {
if err := t.Start(); err != nil {
return err
}

mu.Lock()
startedTransports = append(startedTransports, t)
mu.Unlock()
return nil
}
}

abort := func(errs []error) error {
// Failed to start so stop everything that was started.
wait := intsync.ErrorWaiter{}
for _, i := range startedInbounds {
wait.Submit(i.Stop)
}
for _, o := range startedOutbounds {
wait.Submit(o.Stop)
}
for _, t := range startedTransports {
wait.Submit(t.Stop)
}

if newErrors := wait.Wait(); len(newErrors) > 0 {
errs = append(errs, newErrors...)
}

return errors.ErrorGroup(errs)
}

// Start inbounds and outbounds in parallel

var wait intsync.ErrorWaiter
for _, i := range d.inbounds {
i.SetRegistry(d)
Expand All @@ -217,25 +286,25 @@ func (d dispatcher) Start() error {
wait.Submit(startOutbound(o.Oneway))
}

// Synchronize
errs := wait.Wait()
if len(errs) == 0 {
return nil
if len(errs) != 0 {
return abort(errs)
}

// Failed to start so stop everything that was started.
// Start transports
wait = intsync.ErrorWaiter{}
for _, i := range startedInbounds {
wait.Submit(i.Stop)
}
for _, o := range startedOutbounds {
wait.Submit(o.Stop)
for _, t := range d.transports {
wait.Submit(startTransport(t))
}

if newErrors := wait.Wait(); len(newErrors) > 0 {
errs = append(errs, newErrors...)
// Synchronize
errs = wait.Wait()
if len(errs) != 0 {
return abort(errs)
}

return errors.ErrorGroup(errs)
return nil
}

func (d dispatcher) Register(rs []transport.Registrant) {
Expand Down Expand Up @@ -277,6 +346,10 @@ func (d dispatcher) Stop() error {
}
}

for _, t := range d.transports {
wait.Submit(t.Stop)
}

if errs := wait.Wait(); len(errs) > 0 {
return errors.ErrorGroup(errs)
}
Expand Down
10 changes: 10 additions & 0 deletions dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ func TestStartStopFailures(t *testing.T) {
inbounds := make(Inbounds, 10)
for i := range inbounds {
in := transporttest.NewMockInbound(mockCtrl)
in.EXPECT().Transports()
in.EXPECT().SetRegistry(gomock.Any())
in.EXPECT().Start().Return(nil)
in.EXPECT().Stop().Return(nil)
Expand All @@ -122,6 +123,7 @@ func TestStartStopFailures(t *testing.T) {
outbounds := make(Outbounds, 10)
for i := 0; i < 10; i++ {
out := transporttest.NewMockUnaryOutbound(mockCtrl)
out.EXPECT().Transports()
out.EXPECT().Start().Return(nil)
out.EXPECT().Stop().Return(nil)
outbounds[fmt.Sprintf("service-%v", i)] =
Expand All @@ -138,6 +140,7 @@ func TestStartStopFailures(t *testing.T) {
inbounds := make(Inbounds, 10)
for i := range inbounds {
in := transporttest.NewMockInbound(mockCtrl)
in.EXPECT().Transports()
in.EXPECT().SetRegistry(gomock.Any())
if i == 6 {
in.EXPECT().Start().Return(errors.New("great sadness"))
Expand All @@ -153,6 +156,7 @@ func TestStartStopFailures(t *testing.T) {
outbounds := make(Outbounds, 10)
for i := 0; i < 10; i++ {
out := transporttest.NewMockUnaryOutbound(mockCtrl)
out.EXPECT().Transports()
out.EXPECT().Start().Return(nil)
out.EXPECT().Stop().Return(nil)
outbounds[fmt.Sprintf("service-%v", i)] =
Expand All @@ -170,6 +174,7 @@ func TestStartStopFailures(t *testing.T) {
inbounds := make(Inbounds, 10)
for i := range inbounds {
in := transporttest.NewMockInbound(mockCtrl)
in.EXPECT().Transports()
in.EXPECT().SetRegistry(gomock.Any())
in.EXPECT().Start().Return(nil)
if i == 7 {
Expand All @@ -185,6 +190,7 @@ func TestStartStopFailures(t *testing.T) {
outbounds := make(Outbounds, 10)
for i := 0; i < 10; i++ {
out := transporttest.NewMockUnaryOutbound(mockCtrl)
out.EXPECT().Transports()
out.EXPECT().Start().Return(nil)
out.EXPECT().Stop().Return(nil)
outbounds[fmt.Sprintf("service-%v", i)] =
Expand All @@ -202,6 +208,7 @@ func TestStartStopFailures(t *testing.T) {
inbounds := make(Inbounds, 10)
for i := range inbounds {
in := transporttest.NewMockInbound(mockCtrl)
in.EXPECT().Transports()
in.EXPECT().SetRegistry(gomock.Any())
in.EXPECT().Start().Return(nil)
in.EXPECT().Stop().Return(nil)
Expand All @@ -213,6 +220,7 @@ func TestStartStopFailures(t *testing.T) {
outbounds := make(Outbounds, 10)
for i := 0; i < 10; i++ {
out := transporttest.NewMockUnaryOutbound(mockCtrl)
out.EXPECT().Transports()
if i == 5 {
out.EXPECT().Start().Return(errors.New("something went wrong"))
} else {
Expand All @@ -235,6 +243,7 @@ func TestStartStopFailures(t *testing.T) {
inbounds := make(Inbounds, 10)
for i := range inbounds {
in := transporttest.NewMockInbound(mockCtrl)
in.EXPECT().Transports()
in.EXPECT().SetRegistry(gomock.Any())
in.EXPECT().Start().Return(nil)
in.EXPECT().Stop().Return(nil)
Expand All @@ -246,6 +255,7 @@ func TestStartStopFailures(t *testing.T) {
outbounds := make(Outbounds, 10)
for i := 0; i < 10; i++ {
out := transporttest.NewMockUnaryOutbound(mockCtrl)
out.EXPECT().Transports()
out.EXPECT().Start().Return(nil)
if i == 7 {
out.EXPECT().Stop().Return(errors.New("something went wrong"))
Expand Down
8 changes: 8 additions & 0 deletions internal/outboundmiddleware/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ type unaryChainExec struct {
Final transport.UnaryOutbound
}

func (x unaryChainExec) Transports() []transport.Transport {
return x.Final.Transports()
}

func (x unaryChainExec) Start() error {
return x.Final.Start()
}
Expand Down Expand Up @@ -99,6 +103,10 @@ type onewayChainExec struct {
Final transport.OnewayOutbound
}

func (x onewayChainExec) Transports() []transport.Transport {
return x.Final.Transports()
}

func (x onewayChainExec) Start() error {
return x.Final.Start()
}
Expand Down
6 changes: 6 additions & 0 deletions transport/http/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ func (i *Inbound) SetRegistry(registry transport.Registry) {
i.registry = registry
}

// Transports returns the inbound's HTTP transport.
func (i *Inbound) Transports() []transport.Transport {
// TODO factor out transport and return it here.
return []transport.Transport{}
}

// Start starts the inbound with a given service detail, opening a listening
// socket.
func (i *Inbound) Start() error {
Expand Down
6 changes: 6 additions & 0 deletions transport/http/outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ func (o *Outbound) WithTracer(tracer opentracing.Tracer) *Outbound {
return o
}

// Transports returns the outbound's HTTP transport.
func (o *Outbound) Transports() []transport.Transport {
// TODO factor out transport and return it here.
return []transport.Transport{}
}

// Start the HTTP outbound
func (o *Outbound) Start() error {
o.started.Swap(true)
Expand Down
8 changes: 7 additions & 1 deletion transport/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@ package transport
// calls.
type Inbound interface {
// SetRegistry configures the inbound to dispatch requests through a
// registry.
// registry, typically called by a Dispatcher with its Registrar of handled
// procedures.
SetRegistry(Registry)

// Transport returns any transports that the inbound uses, so they can be
// collected for lifecycle management, typically by a Dispatcher.
// An inbound may submit zero or more transports.
Transports() []Transport

// Starts accepting new requests.
//
// The inbound must have a configured registry.
Expand Down
9 changes: 9 additions & 0 deletions transport/outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ import "context"

// Outbound is the common interface for all outbounds
type Outbound interface {
// Transports returns the transports that used by this outbound, so they
// can be collected for lifecycle management, typically by a Dispatcher.
//
// Though most outbounds only use a single transport, composite outbounds
// may use multiple transport protocols, particularly for shadowing traffic
// across multiple transport protocols during a transport protocol
// migration.
Transports() []Transport

// Sets up the outbound to start making calls.
//
// This MUST block until the outbound is ready to start sending requests.
Expand Down
8 changes: 8 additions & 0 deletions transport/outboundmiddleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ type unaryOutboundWithMiddleware struct {
f UnaryOutboundMiddleware
}

func (fo unaryOutboundWithMiddleware) Transports() []Transport {
return fo.o.Transports()
}

func (fo unaryOutboundWithMiddleware) Start() error {
return fo.o.Start()
}
Expand Down Expand Up @@ -135,6 +139,10 @@ type onewayOutboundWithMiddleware struct {
f OnewayOutboundMiddleware
}

func (fo onewayOutboundWithMiddleware) Transports() []Transport {
return fo.o.Transports()
}

func (fo onewayOutboundWithMiddleware) Start() error {
return fo.o.Start()
}
Expand Down
6 changes: 6 additions & 0 deletions transport/tchannel/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ func (i *Inbound) Channel() Channel {
return i.ch
}

// Transports returns the underlying Transport for this Inbound.
func (i *Inbound) Transports() []transport.Transport {
// TODO factor out transport and return it here.
return []transport.Transport{}
}

// Start starts the TChannel inbound transport. This immediately opens a listen
// socket.
func (i *Inbound) Start() error {
Expand Down
6 changes: 6 additions & 0 deletions transport/tchannel/outbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ func (o *Outbound) WithTracer(tracer opentracing.Tracer) *Outbound {
return o
}

// Transports returns the underlying TChannel Transport for this outbound.
func (o *Outbound) Transports() []transport.Transport {
// TODO factor out transport and return it here.
return []transport.Transport{}
}

// Start starts the TChannel outbound.
func (o *Outbound) Start() error {
// TODO: Should we create the connection to HostPort (if specified) here or
Expand Down
Loading

0 comments on commit c91f0f0

Please sign in to comment.