From 04ef8a032ffb98b217def82e973d8811113dc65e Mon Sep 17 00:00:00 2001 From: Sebastien Binet Date: Tue, 21 Jan 2020 16:31:03 +0100 Subject: [PATCH] zmq4: first stab at a Proxy impl Fixes go-zeromq/zmq4#65. --- proxy.go | 144 ++++++++++++++++++++ proxy_test.go | 336 +++++++++++++++++++++++++++++++++++++++++++++++ security_test.go | 4 +- socket.go | 2 +- socket_test.go | 69 +++++----- zall_test.go | 14 ++ 6 files changed, 530 insertions(+), 39 deletions(-) create mode 100644 proxy.go create mode 100644 proxy_test.go create mode 100644 zall_test.go diff --git a/proxy.go b/proxy.go new file mode 100644 index 0000000..e952b7e --- /dev/null +++ b/proxy.go @@ -0,0 +1,144 @@ +// Copyright 2020 The go-zeromq Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zmq4 + +import ( + "context" + "log" + + "golang.org/x/sync/errgroup" + "golang.org/x/xerrors" +) + +type Proxy struct { + ctx context.Context // life-line of proxy + grp *errgroup.Group + cmds chan proxyCmd +} + +type proxyCmd byte + +const ( + proxyStats proxyCmd = iota + proxyPause + proxyResume + proxyKill +) + +// NewProxy creates a new proxy value. +// It proxies messages received on the frontend to the backend (and vice versa) +// If capture is not nil, messages proxied are also sent on that socket. +func NewProxy(ctx context.Context, front, back, capture Socket) *Proxy { + grp, ctx := errgroup.WithContext(ctx) + proxy := Proxy{ + ctx: ctx, + grp: grp, + cmds: make(chan proxyCmd), + } + proxy.run(front, back, capture) + return &proxy +} + +func (p *Proxy) Pause() { p.cmds <- proxyPause } +func (p *Proxy) Stats() { p.cmds <- proxyStats } +func (p *Proxy) Resume() { p.cmds <- proxyResume } +func (p *Proxy) Kill() { p.cmds <- proxyKill } + +// Run runs the proxy loop. +func (p *Proxy) Run() error { + return p.grp.Wait() +} + +func (p *Proxy) run(front, back, capture Socket) { + canRecv := func(sck Socket) bool { + switch sck.Type() { + case Push: + return false + default: + return true + } + } + + canSend := func(sck Socket) bool { + switch sck.Type() { + case Pull: + return false + default: + return true + } + } + + type Pipe struct { + name string + dst Socket + src Socket + } + + var ( + quit = make(chan struct{}) + pipes = []Pipe{ + { + name: "backend", + dst: back, + src: front, + }, + { + name: "frontend", + dst: front, + src: back, + }, + } + ) + + for i := range pipes { + pipe := pipes[i] + if pipe.src == nil || !canRecv(pipe.src) { + continue + } + p.grp.Go(func() error { + canSend := canSend(pipe.dst) + for { + msg, err := pipe.src.Recv() + select { + case <-p.ctx.Done(): + return p.ctx.Err() + case <-quit: + return nil + default: + if canSend { + err = pipe.dst.Send(msg) + if err != nil { + log.Printf("could not forward to %s: %+v", pipe.name, err) + continue + } + } + if err == nil && capture != nil && len(msg.Frames) != 0 { + _ = capture.Send(msg) + } + } + } + }) + } + + p.grp.Go(func() error { + for { + select { + case <-p.ctx.Done(): + return p.ctx.Err() + case cmd := <-p.cmds: + switch cmd { + case proxyPause, proxyResume, proxyStats: + // TODO + case proxyKill: + close(quit) + return nil + default: + // API error. panic. + panic(xerrors.Errorf("invalid control socket command: %v", cmd)) + } + } + } + }) +} diff --git a/proxy_test.go b/proxy_test.go new file mode 100644 index 0000000..c71d4c0 --- /dev/null +++ b/proxy_test.go @@ -0,0 +1,336 @@ +// Copyright 2020 The go-zeromq Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zmq4_test + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/go-zeromq/zmq4" + "golang.org/x/sync/errgroup" + "golang.org/x/xerrors" +) + +func TestProxy(t *testing.T) { + bkg := context.Background() + ctx, timeout := context.WithTimeout(bkg, 20*time.Second) + defer timeout() + + var ( + frontIn = zmq4.NewPush(ctx, zmq4.WithLogger(zmq4.Devnull)) + front = zmq4.NewPull(ctx, zmq4.WithLogger(zmq4.Devnull)) + back = zmq4.NewPush(ctx, zmq4.WithLogger(zmq4.Devnull)) + backOut = zmq4.NewPull(ctx, zmq4.WithLogger(zmq4.Devnull)) + capt = zmq4.NewPush(ctx, zmq4.WithLogger(zmq4.Devnull)) + captOut = zmq4.NewPull(ctx, zmq4.WithLogger(zmq4.Devnull)) + + proxy *zmq4.Proxy + + epFront = "ipc://proxy-front" + epBack = "ipc://proxy-back" + epCapt = "ipc://proxy-capt" + + wg1 sync.WaitGroup // all sockets ready + wg2 sync.WaitGroup // proxy setup + wg3 sync.WaitGroup // all messages received + wg4 sync.WaitGroup // all capture messages received + wg5 sync.WaitGroup // terminate sent + wg6 sync.WaitGroup // all sockets done + ) + + wg1.Add(6) // number of sockets + wg2.Add(1) // proxy ready + wg3.Add(1) // messages received at backout + wg4.Add(1) // capture messages received at capt-out + wg5.Add(1) // terminate + wg6.Add(6) // number of sockets + + cleanUp(epFront) + cleanUp(epBack) + cleanUp(epCapt) + + var ( + msgs = []zmq4.Msg{ + zmq4.NewMsgFrom([]byte("msg1")), + zmq4.NewMsgFrom([]byte("msg2")), + zmq4.NewMsgFrom([]byte("msg3")), + zmq4.NewMsgFrom([]byte("msg4")), + } + ) + + grp, ctx := errgroup.WithContext(ctx) + grp.Go(func() error { + defer frontIn.Close() + err := frontIn.Dial(epFront) + if err != nil { + return xerrors.Errorf("front-in could not dial %q: %w", epFront, err) + } + + wg1.Done() + t.Logf("front-in ready") + wg1.Wait() // sockets + wg2.Wait() // proxy + + for _, msg := range msgs { + t.Logf("front-in sending %v...", msg) + err = frontIn.Send(msg) + if err != nil { + return xerrors.Errorf("could not send front-in %q: %w", msg, err) + } + t.Logf("front-in sending %v... [done]", msg) + } + + wg3.Wait() // all messages received + wg4.Wait() // all capture messages received + t.Logf("front-in waiting for terminate signal") + wg5.Wait() // terminate + + wg6.Done() // all sockets done + wg6.Wait() + return nil + }) + + grp.Go(func() error { + defer front.Close() + err := front.Listen(epFront) + if err != nil { + return xerrors.Errorf("front could not listen %q: %w", epFront, err) + } + + wg1.Done() + t.Logf("front ready") + wg1.Wait() // sockets + wg2.Wait() // proxy + wg3.Wait() // all messages received + wg4.Wait() // all capture messages received + t.Logf("front waiting for terminate signal") + wg5.Wait() // terminate + + wg6.Done() // all sockets done + wg6.Wait() + return nil + }) + + grp.Go(func() error { + defer back.Close() + err := back.Listen(epBack) + if err != nil { + return xerrors.Errorf("back could not listen %q: %w", epBack, err) + } + + wg1.Done() + t.Logf("back ready") + wg1.Wait() // sockets + wg2.Wait() // proxy + wg3.Wait() // all messages received + wg4.Wait() // all capture messages received + t.Logf("back waiting for terminate signal") + wg5.Wait() // terminate + + wg6.Done() // all sockets done + wg6.Wait() + return nil + }) + + grp.Go(func() error { + defer backOut.Close() + err := backOut.Dial(epBack) + if err != nil { + return xerrors.Errorf("back-out could not dial %q: %w", epBack, err) + } + + wg1.Done() + t.Logf("back-out ready") + wg1.Wait() // sockets + wg2.Wait() // proxy + + for _, want := range msgs { + t.Logf("back-out recving %v...", want) + msg, err := backOut.Recv() + if err != nil { + return xerrors.Errorf("back-out could not recv: %w", err) + } + if msg.String() != want.String() { + return xerrors.Errorf("invalid message: got=%v, want=%v", msg, want) + } + t.Logf("back-out recving %v... [done]", msg) + } + + wg3.Done() // all messages received + wg3.Wait() // all messages received + wg4.Wait() // all capture messages received + t.Logf("back-out waiting for terminate signal") + wg5.Wait() // terminate + + wg6.Done() // all sockets done + wg6.Wait() + return nil + }) + + grp.Go(func() error { + defer captOut.Close() + err := captOut.Listen(epCapt) + if err != nil { + return xerrors.Errorf("capt-out could not listen %q: %w", epCapt, err) + } + + wg1.Done() + t.Logf("capt-out ready") + wg1.Wait() // sockets + wg2.Wait() // proxy + wg3.Wait() // all messages received + + for _, want := range msgs { + t.Logf("capt-out recving %v...", want) + msg, err := captOut.Recv() + if err != nil { + return xerrors.Errorf("capt-out could not recv msg: %w", err) + } + if msg.String() != want.String() { + return xerrors.Errorf("capt-out: invalid message: got=%v, want=%v", msg, want) + } + t.Logf("capt-out recving %v... [done]", msg) + } + + wg4.Done() // all capture messages received + wg4.Wait() // all capture messages received + t.Logf("capt-out waiting for terminate signal") + wg5.Wait() // terminate + + wg6.Done() // all sockets done + wg6.Wait() + return nil + }) + + grp.Go(func() error { + defer capt.Close() + err := capt.Dial(epCapt) + if err != nil { + return xerrors.Errorf("capt could not dial %q: %w", epCapt, err) + } + + wg1.Done() + t.Logf("capt ready") + wg1.Wait() // sockets + wg2.Wait() // proxy + wg3.Wait() // all messages received + wg4.Wait() // all capture messages received + t.Logf("capt waiting for terminate signal") + wg5.Wait() // terminate + + wg6.Done() // all sockets done + wg6.Wait() + return nil + }) + + grp.Go(func() error { + t.Logf("ctrl ready") + wg1.Wait() // sockets + wg2.Wait() // proxy + for _, cmd := range []struct { + name string + fct func() + }{ + {"pause", proxy.Pause}, + {"resume", proxy.Resume}, + {"stats", proxy.Stats}, + } { + t.Logf("ctrl sending %v...", cmd.name) + cmd.fct() + t.Logf("ctrl sending %v... [done]", cmd.name) + } + wg3.Wait() // all messages received + wg4.Wait() // all capture messages received + + t.Logf("ctrl sending kill...") + proxy.Kill() + t.Logf("ctrl sending kill... [done]") + + wg5.Done() + t.Logf("ctrl waiting for terminate signal") + wg5.Wait() // terminate + + wg6.Wait() + return nil + }) + + grp.Go(func() error { + wg1.Wait() // sockets ready + proxy = zmq4.NewProxy(ctx, front, back, capt) + t.Logf("proxy ready") + wg2.Done() + err := proxy.Run() + t.Logf("proxy done: err=%+v", err) + return err + }) + + if err := grp.Wait(); err != nil { + t.Fatalf("error: %+v", err) + } + + if err := ctx.Err(); err != nil && err != context.Canceled { + t.Fatalf("error: %+v", err) + } +} + +func TestProxyStop(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + var ( + epFront = "ipc://proxy-stop-front" + epBack = "ipc://proxy-stop-back" + + frontIn = zmq4.NewPush(ctx, zmq4.WithLogger(zmq4.Devnull)) + front = zmq4.NewPull(ctx, zmq4.WithLogger(zmq4.Devnull)) + back = zmq4.NewPush(ctx, zmq4.WithLogger(zmq4.Devnull)) + backOut = zmq4.NewPull(ctx, zmq4.WithLogger(zmq4.Devnull)) + ) + + cleanUp(epFront) + cleanUp(epBack) + + defer front.Close() + defer back.Close() + + if err := front.Listen(epFront); err != nil { + t.Fatalf("could not listen: %+v", err) + } + + if err := frontIn.Dial(epFront); err != nil { + t.Fatalf("could not dial: %+v", err) + } + + if err := back.Listen(epBack); err != nil { + t.Fatalf("could not listen: %+v", err) + } + + if err := backOut.Dial(epBack); err != nil { + t.Fatalf("could not dial: %+v", err) + } + + var errc = make(chan error) + go func() { + errc <- zmq4.NewProxy(ctx, front, back, nil).Run() + }() + + go func() { + _ = frontIn.Send(zmq4.NewMsgString("msg1")) + }() + go func() { + _, _ = backOut.Recv() + }() + cancel() + + err := <-errc + if err != context.Canceled { + t.Fatalf("error: %+v", err) + } + + if err := ctx.Err(); err != nil && err != context.Canceled { + t.Fatalf("error: %+v", err) + } +} diff --git a/security_test.go b/security_test.go index 8ec76af..581afe9 100644 --- a/security_test.go +++ b/security_test.go @@ -56,10 +56,10 @@ func TestNullHandshakeReqRep(t *testing.T) { ep := "ipc://ipc-req-rep-null-sec" cleanUp(ep) - req := NewReq(ctx, WithSecurity(sec)) + req := NewReq(ctx, WithSecurity(sec), WithLogger(Devnull)) defer req.Close() - rep := NewRep(ctx, WithSecurity(sec)) + rep := NewRep(ctx, WithSecurity(sec), WithLogger(Devnull)) defer rep.Close() grp, ctx := errgroup.WithContext(ctx) diff --git a/socket.go b/socket.go index 80a0b13..ccdb55b 100644 --- a/socket.go +++ b/socket.go @@ -189,7 +189,7 @@ func (sck *socket) accept() { conn, err := sck.listener.Accept() if err != nil { // FIXME(sbinet): maybe bubble up this error to application code? - sck.log.Printf("error accepting connection from %q: %+v", sck.ep, err) + //sck.log.Printf("error accepting connection from %q: %+v", sck.ep, err) continue } diff --git a/socket_test.go b/socket_test.go index 34e6b0a..bc2029c 100644 --- a/socket_test.go +++ b/socket_test.go @@ -7,8 +7,6 @@ package zmq4_test import ( "context" "io" - "io/ioutil" - "log" "net" "testing" "time" @@ -68,7 +66,6 @@ func TestConnPairs(t *testing.T) { t.Parallel() bkg := context.Background() - msg := log.New(ioutil.Discard, "zmq4: ", 0) for _, tc := range []struct { name string @@ -78,69 +75,69 @@ func TestConnPairs(t *testing.T) { }{ { name: "pair", - srv: zmq4.NewPair(bkg, zmq4.WithLogger(msg)), - wrong: zmq4.NewSub(bkg, zmq4.WithLogger(msg)), - cli: zmq4.NewPair(bkg, zmq4.WithLogger(msg)), + srv: zmq4.NewPair(bkg, zmq4.WithLogger(zmq4.Devnull)), + wrong: zmq4.NewSub(bkg, zmq4.WithLogger(zmq4.Devnull)), + cli: zmq4.NewPair(bkg, zmq4.WithLogger(zmq4.Devnull)), }, { name: "pub", - srv: zmq4.NewPub(bkg, zmq4.WithLogger(msg)), - wrong: zmq4.NewPair(bkg, zmq4.WithLogger(msg)), - cli: zmq4.NewSub(bkg, zmq4.WithLogger(msg)), + srv: zmq4.NewPub(bkg, zmq4.WithLogger(zmq4.Devnull)), + wrong: zmq4.NewPair(bkg, zmq4.WithLogger(zmq4.Devnull)), + cli: zmq4.NewSub(bkg, zmq4.WithLogger(zmq4.Devnull)), }, { name: "sub", - srv: zmq4.NewSub(bkg, zmq4.WithLogger(msg)), - wrong: zmq4.NewPair(bkg, zmq4.WithLogger(msg)), - cli: zmq4.NewPub(bkg, zmq4.WithLogger(msg)), + srv: zmq4.NewSub(bkg, zmq4.WithLogger(zmq4.Devnull)), + wrong: zmq4.NewPair(bkg, zmq4.WithLogger(zmq4.Devnull)), + cli: zmq4.NewPub(bkg, zmq4.WithLogger(zmq4.Devnull)), }, { name: "req", - srv: zmq4.NewReq(bkg, zmq4.WithLogger(msg)), - wrong: zmq4.NewPair(bkg, zmq4.WithLogger(msg)), - cli: zmq4.NewRep(bkg, zmq4.WithLogger(msg)), + srv: zmq4.NewReq(bkg, zmq4.WithLogger(zmq4.Devnull)), + wrong: zmq4.NewPair(bkg, zmq4.WithLogger(zmq4.Devnull)), + cli: zmq4.NewRep(bkg, zmq4.WithLogger(zmq4.Devnull)), }, { name: "rep", - srv: zmq4.NewRep(bkg, zmq4.WithLogger(msg)), - wrong: zmq4.NewPair(bkg, zmq4.WithLogger(msg)), - cli: zmq4.NewReq(bkg, zmq4.WithLogger(msg)), + srv: zmq4.NewRep(bkg, zmq4.WithLogger(zmq4.Devnull)), + wrong: zmq4.NewPair(bkg, zmq4.WithLogger(zmq4.Devnull)), + cli: zmq4.NewReq(bkg, zmq4.WithLogger(zmq4.Devnull)), }, { name: "dealer", - srv: zmq4.NewDealer(bkg, zmq4.WithLogger(msg)), - wrong: zmq4.NewPair(bkg, zmq4.WithLogger(msg)), - cli: zmq4.NewRouter(bkg, zmq4.WithLogger(msg)), + srv: zmq4.NewDealer(bkg, zmq4.WithLogger(zmq4.Devnull)), + wrong: zmq4.NewPair(bkg, zmq4.WithLogger(zmq4.Devnull)), + cli: zmq4.NewRouter(bkg, zmq4.WithLogger(zmq4.Devnull)), }, { name: "router", - srv: zmq4.NewRouter(bkg, zmq4.WithLogger(msg)), - wrong: zmq4.NewPair(bkg, zmq4.WithLogger(msg)), - cli: zmq4.NewDealer(bkg, zmq4.WithLogger(msg)), + srv: zmq4.NewRouter(bkg, zmq4.WithLogger(zmq4.Devnull)), + wrong: zmq4.NewPair(bkg, zmq4.WithLogger(zmq4.Devnull)), + cli: zmq4.NewDealer(bkg, zmq4.WithLogger(zmq4.Devnull)), }, { name: "pull", - srv: zmq4.NewPull(bkg, zmq4.WithLogger(msg)), - wrong: zmq4.NewPair(bkg, zmq4.WithLogger(msg)), - cli: zmq4.NewPush(bkg, zmq4.WithLogger(msg)), + srv: zmq4.NewPull(bkg, zmq4.WithLogger(zmq4.Devnull)), + wrong: zmq4.NewPair(bkg, zmq4.WithLogger(zmq4.Devnull)), + cli: zmq4.NewPush(bkg, zmq4.WithLogger(zmq4.Devnull)), }, { name: "push", - srv: zmq4.NewPush(bkg, zmq4.WithLogger(msg)), - wrong: zmq4.NewPair(bkg, zmq4.WithLogger(msg)), - cli: zmq4.NewPull(bkg, zmq4.WithLogger(msg)), + srv: zmq4.NewPush(bkg, zmq4.WithLogger(zmq4.Devnull)), + wrong: zmq4.NewPair(bkg, zmq4.WithLogger(zmq4.Devnull)), + cli: zmq4.NewPull(bkg, zmq4.WithLogger(zmq4.Devnull)), }, { name: "xpub", - srv: zmq4.NewXPub(bkg, zmq4.WithLogger(msg)), - wrong: zmq4.NewPair(bkg, zmq4.WithLogger(msg)), - cli: zmq4.NewXSub(bkg, zmq4.WithLogger(msg)), + srv: zmq4.NewXPub(bkg, zmq4.WithLogger(zmq4.Devnull)), + wrong: zmq4.NewPair(bkg, zmq4.WithLogger(zmq4.Devnull)), + cli: zmq4.NewXSub(bkg, zmq4.WithLogger(zmq4.Devnull)), }, { name: "xsub", - srv: zmq4.NewXSub(bkg, zmq4.WithLogger(msg)), - wrong: zmq4.NewPair(bkg, zmq4.WithLogger(msg)), - cli: zmq4.NewXPub(bkg, zmq4.WithLogger(msg)), + srv: zmq4.NewXSub(bkg, zmq4.WithLogger(zmq4.Devnull)), + wrong: zmq4.NewPair(bkg, zmq4.WithLogger(zmq4.Devnull)), + cli: zmq4.NewXPub(bkg, zmq4.WithLogger(zmq4.Devnull)), }, } { t.Run(tc.name, func(t *testing.T) { diff --git a/zall_test.go b/zall_test.go new file mode 100644 index 0000000..a322bdd --- /dev/null +++ b/zall_test.go @@ -0,0 +1,14 @@ +// Copyright 2020 The go-zeromq Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zmq4 + +import ( + "io/ioutil" + "log" +) + +var ( + Devnull = log.New(ioutil.Discard, "zmq4: ", 0) +)