From f8a021f26fc32fc7e6cae56afd963ef7d45ce81d Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sun, 24 Mar 2024 08:34:20 +0400 Subject: [PATCH] hashsync: convert from chunked streams to normal streams --- hashsync/handler.go | 271 ++++++++++++++------------------------- hashsync/handler_test.go | 165 ++++++++---------------- hashsync/interface.go | 2 +- hashsync/rangesync.go | 9 +- 4 files changed, 151 insertions(+), 296 deletions(-) diff --git a/hashsync/handler.go b/hashsync/handler.go index 779cb83ac2..5e51632bf4 100644 --- a/hashsync/handler.go +++ b/hashsync/handler.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "time" "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" @@ -38,6 +37,7 @@ func (m *decodedItemBatchMessage) Keys() []Ordered { } return r } + func (m *decodedItemBatchMessage) Values() []any { r := make([]any, len(m.ContentValues)) for n, v := range m.ContentValues { @@ -80,136 +80,86 @@ func decodeItemBatchMessage(m *ItemBatchMessage, newValue NewValueFunc) (*decode type conduitState int type wireConduit struct { - i server.Interactor - pendingMsgs []SyncMessage - initReqBuf *bytes.Buffer - newValue NewValueFunc + stream io.ReadWriter + initReqBuf *bytes.Buffer + newValue NewValueFunc // rmmePrint bool } var _ Conduit = &wireConduit{} -func (c *wireConduit) reset() { - c.pendingMsgs = nil -} - -// receive receives a single frame from the Interactor and decodes one -// or more SyncMessages from it. The frames contain just one message -// except for the initial frame which may contain multiple messages -// b/c of the way Server handles the initial request -func (c *wireConduit) receive() (msgs []SyncMessage, err error) { - data, err := c.i.Receive() - if err != nil { - return nil, err - } - if len(data) == 0 { - return nil, errors.New("zero length sync message") +// NextMessage implements Conduit. +func (c *wireConduit) NextMessage() (SyncMessage, error) { + var b [1]byte + if _, err := io.ReadFull(c.stream, b[:]); err != nil { + if !errors.Is(err, io.EOF) { + return nil, err + } + return nil, nil } - b := bytes.NewBuffer(data) - for { - code, err := b.ReadByte() + mtype := MessageType(b[0]) + // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: receive message type %s\n", mtype) + switch mtype { + case MessageTypeDone: + return &DoneMessage{}, nil + case MessageTypeEndRound: + return &EndRoundMessage{}, nil + case MessageTypeItemBatch: + var m ItemBatchMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + dm, err := decodeItemBatchMessage(&m, c.newValue) if err != nil { - if !errors.Is(err, io.EOF) { - // this shouldn't really happen - return nil, err - } - // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: decoded msgs: %#v\n", msgs) - return msgs, nil + return nil, err } - mtype := MessageType(code) - // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: receive message type %s\n", mtype) - switch mtype { - case MessageTypeDone: - msgs = append(msgs, &DoneMessage{}) - case MessageTypeEndRound: - msgs = append(msgs, &EndRoundMessage{}) - case MessageTypeItemBatch: - var m ItemBatchMessage - if _, err := codec.DecodeFrom(b, &m); err != nil { - return nil, err - } - dm, err := decodeItemBatchMessage(&m, c.newValue) - if err != nil { - return nil, err - } - msgs = append(msgs, dm) - case MessageTypeEmptySet: - msgs = append(msgs, &EmptySetMessage{}) - case MessageTypeEmptyRange: - var m EmptyRangeMessage - if _, err := codec.DecodeFrom(b, &m); err != nil { - return nil, err - } - msgs = append(msgs, &m) - case MessageTypeFingerprint: - var m FingerprintMessage - if _, err := codec.DecodeFrom(b, &m); err != nil { - return nil, err - } - msgs = append(msgs, &m) - case MessageTypeRangeContents: - var m RangeContentsMessage - if _, err := codec.DecodeFrom(b, &m); err != nil { - return nil, err - } - msgs = append(msgs, &m) - case MessageTypeQuery: - var m QueryMessage - if _, err := codec.DecodeFrom(b, &m); err != nil { - return nil, err - } - msgs = append(msgs, &m) - default: - return nil, fmt.Errorf("invalid message code %02x", code) + return dm, nil + case MessageTypeEmptySet: + return &EmptySetMessage{}, nil + case MessageTypeEmptyRange: + var m EmptyRangeMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err } + return &m, nil + case MessageTypeFingerprint: + var m FingerprintMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + case MessageTypeRangeContents: + var m RangeContentsMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + case MessageTypeQuery: + var m QueryMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + default: + return nil, fmt.Errorf("invalid message code %02x", b[0]) } } func (c *wireConduit) send(m sendable) error { - // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: sending %s m %#v\n", m.Type(), m) - msg := []byte{byte(m.Type())} - // if c.rmmePrint { - // fmt.Fprintf(os.Stderr, "QQQQQ: send: %s\n", SyncMessageToString(m)) - // } - encoded, err := codec.Encode(m) - if err != nil { - return fmt.Errorf("error encoding %T: %w", m, err) - } - msg = append(msg, encoded...) + var stream io.Writer if c.initReqBuf != nil { - c.initReqBuf.Write(msg) + stream = c.initReqBuf + } else if c.stream == nil { + panic("BUG: wireConduit: no stream") } else { - if err := c.i.Send(msg); err != nil { - return err - } - } - return nil -} - -// NextMessage implements Conduit. -func (c *wireConduit) NextMessage() (SyncMessage, error) { - if len(c.pendingMsgs) != 0 { - m := c.pendingMsgs[0] - c.pendingMsgs = c.pendingMsgs[1:] - // if c.rmmePrint { - // fmt.Fprintf(os.Stderr, "QQQQQ: recv: %s\n", SyncMessageToString(m)) - // } - return m, nil + stream = c.stream } - - msgs, err := c.receive() - if err != nil { - return nil, err - } - if len(msgs) == 0 { - return nil, nil + b := []byte{byte(m.Type())} + if _, err := stream.Write(b); err != nil { + return err } - - c.pendingMsgs = msgs[1:] - // if c.rmmePrint { - // fmt.Fprintf(os.Stderr, "QQQQQ: recv: %s\n", SyncMessageToString(msgs[0])) - // } - return msgs[0], nil + _, err := codec.EncodeTo(stream, m) + return err } func (c *wireConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) error { @@ -289,37 +239,32 @@ func (c *wireConduit) withInitialRequest(toCall func(Conduit) error) ([]byte, er return c.initReqBuf.Bytes(), nil } -func makeHandler(rsr *RangeSetReconciler, c *wireConduit, done chan struct{}) server.InteractiveHandler { - return func(ctx context.Context, i server.Interactor) (time.Duration, error) { - defer func() { - if done != nil { - close(done) - } - }() - c.i = i - for { - c.reset() - // Process() will receive all items and messages from the peer - syncDone, err := rsr.Process(c) - if err != nil { - // do not close done if we're returning an - // error, as the channel will be closed in the - // error handler func - done = nil - return 0, err - } else if syncDone { - return 0, nil - } +func (c *wireConduit) handleStream(stream io.ReadWriter, rsr *RangeSetReconciler) error { + c.stream = stream + for { + // Process() will receive all items and messages from the peer + syncDone, err := rsr.Process(c) + if err != nil { + return err + } else if syncDone { + return nil } } } -func MakeServerHandler(is ItemStore, opts ...Option) server.InteractiveHandler { - return func(ctx context.Context, i server.Interactor) (time.Duration, error) { +func MakeServerHandler(is ItemStore, opts ...Option) server.StreamHandler { + return func(ctx context.Context, req []byte, stream io.ReadWriter) error { c := wireConduit{newValue: is.New} rsr := NewRangeSetReconciler(is, opts...) - h := makeHandler(rsr, &c, nil) - return h(ctx, i) + s := struct { + io.Reader + io.Writer + }{ + // prepend the received request to data being read + Reader: io.MultiReader(bytes.NewBuffer(req), stream), + Writer: stream, + } + return c.handleStream(s, rsr) } } @@ -349,22 +294,9 @@ func syncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, x, if err != nil { return err } - done := make(chan struct{}, 1) - h := makeHandler(rsr, &c, done) - var reqErr error - if err = r.InteractiveRequest(ctx, peer, initReq, h, func(err error) { - reqErr = err - close(done) - }); err != nil { - return err - } - select { - case <-ctx.Done(): - <-done - return ctx.Err() - case <-done: - return reqErr - } + return r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { + return c.handleStream(stream, rsr) + }) } func Probe(ctx context.Context, r requester, peer p2p.Peer, opts ...Option) (fp any, count int, err error) { @@ -394,33 +326,16 @@ func boundedProbe(ctx context.Context, r requester, peer p2p.Peer, x, y *types.H if err != nil { return nil, 0, err } - done := make(chan struct{}, 2) - h := func(ctx context.Context, i server.Interactor) (time.Duration, error) { - defer func() { - done <- struct{}{} - }() - c.i = i + err = r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { + c.stream = stream var err error fp, count, err = rsr.HandleProbeResponse(&c) - return 0, err - } - var reqErr error - if err = r.InteractiveRequest(ctx, peer, initReq, h, func(err error) { - reqErr = err - done <- struct{}{} - }); err != nil { + return err + }) + if err != nil { return nil, 0, err } - select { - case <-ctx.Done(): - <-done - return nil, 0, ctx.Err() - case <-done: - if reqErr != nil { - return nil, 0, reqErr - } - return fp, count, nil - } + return fp, count, nil } // TODO: request duration diff --git a/hashsync/handler_test.go b/hashsync/handler_test.go index 6cf55f37b4..2d98903425 100644 --- a/hashsync/handler_test.go +++ b/hashsync/handler_test.go @@ -1,10 +1,11 @@ package hashsync import ( + "bytes" "context" "fmt" + "io" "slices" - "sync/atomic" "testing" "time" @@ -18,67 +19,14 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/server" ) -type fakeMessage struct { - data []byte - error string -} - -type fakeInteractor struct { - fr *fakeRequester - ctx context.Context - sendCh chan fakeMessage - recvCh chan fakeMessage -} - -func (i *fakeInteractor) Send(data []byte) error { - // fmt.Fprintf(os.Stderr, "%p: send %q\n", i, data) - select { - case i.sendCh <- fakeMessage{data: data}: - atomic.AddUint32(&i.fr.bytesSent, uint32(len(data))) - return nil - case <-i.ctx.Done(): - return i.ctx.Err() - } -} - -func (i *fakeInteractor) SendError(err error) error { - // fmt.Fprintf(os.Stderr, "%p: send error %q\n", i, err) - select { - case i.sendCh <- fakeMessage{error: err.Error()}: - atomic.AddUint32(&i.fr.bytesSent, uint32(len(err.Error()))) - return nil - case <-i.ctx.Done(): - return i.ctx.Err() - } -} - -func (i *fakeInteractor) Receive() ([]byte, error) { - // fmt.Fprintf(os.Stderr, "%p: receive\n", i) - var m fakeMessage - select { - case m = <-i.recvCh: - case <-i.ctx.Done(): - return nil, i.ctx.Err() - } - // fmt.Fprintf(os.Stderr, "%p: received %#v\n", i, m) - if m.error != "" { - atomic.AddUint32(&i.fr.bytesReceived, uint32(len(m.error))) - return nil, fmt.Errorf("%w: %s", server.RemoteError, m.error) - } - atomic.AddUint32(&i.fr.bytesReceived, uint32(len(m.data))) - return m.data, nil -} - type incomingRequest struct { - sendCh chan fakeMessage - recvCh chan fakeMessage + initialRequest []byte + stream io.ReadWriter } -var _ server.Interactor = &fakeInteractor{} - type fakeRequester struct { id p2p.Peer - handler server.ServerHandler + handler server.StreamHandler peers map[p2p.Peer]*fakeRequester reqCh chan incomingRequest bytesSent uint32 @@ -87,7 +35,7 @@ type fakeRequester struct { var _ requester = &fakeRequester{} -func newFakeRequester(id p2p.Peer, handler server.ServerHandler, peers ...requester) *fakeRequester { +func newFakeRequester(id p2p.Peer, handler server.StreamHandler, peers ...requester) *fakeRequester { fr := &fakeRequester{ id: id, handler: handler, @@ -112,13 +60,9 @@ func (fr *fakeRequester) Run(ctx context.Context) error { return nil case req = <-fr.reqCh: } - i := &fakeInteractor{ - fr: fr, - ctx: ctx, - sendCh: req.sendCh, - recvCh: req.recvCh, + if err := fr.handler(ctx, req.initialRequest, req.stream); err != nil { + panic("handler error: " + err.Error()) } - fr.handler.Handle(ctx, i) } } @@ -126,45 +70,41 @@ func (fr *fakeRequester) request( ctx context.Context, pid p2p.Peer, initialRequest []byte, - handler server.InteractiveHandler, + callback server.StreamRequestCallback, ) error { p, found := fr.peers[pid] if !found { return fmt.Errorf("bad peer %q", pid) } - i := &fakeInteractor{ - fr: fr, - ctx: ctx, - sendCh: make(chan fakeMessage, 1), - recvCh: make(chan fakeMessage), + r, w := io.Pipe() + defer r.Close() + defer w.Close() + stream := struct { + io.Reader + io.Writer + }{ + Reader: r, + Writer: w, } - i.sendCh <- fakeMessage{data: initialRequest} select { case p.reqCh <- incomingRequest{ - sendCh: i.recvCh, - recvCh: i.sendCh, + initialRequest: initialRequest, + stream: stream, }: case <-ctx.Done(): return ctx.Err() } - _, err := handler(ctx, i) - return err + return callback(ctx, stream) } -func (fr *fakeRequester) InteractiveRequest( +func (fr *fakeRequester) StreamRequest( ctx context.Context, pid p2p.Peer, initialRequest []byte, - handler server.InteractiveHandler, - failure func(error), + callback server.StreamRequestCallback, + extraProtocols ...string, ) error { - go func() { - err := fr.request(ctx, pid, initialRequest, handler) - if err != nil { - failure(err) - } - }() - return nil + return fr.request(ctx, pid, initialRequest, callback) } type sliceIterator struct { @@ -235,18 +175,14 @@ type fakeRound struct { } func (r *fakeRound) handleMessages(t *testing.T, c Conduit) error { - // fmt.Fprintf(os.Stderr, "fakeRound %q: handleMessages\n", r.name) var msgs []SyncMessage for { msg, err := c.NextMessage() if err != nil { - // fmt.Fprintf(os.Stderr, "fakeRound %q: error getting message: %v\n", r.name, err) return fmt.Errorf("NextMessage(): %w", err) } else if msg == nil { - // fmt.Fprintf(os.Stderr, "fakeRound %q: consumed all messages\n", r.name) break } - // fmt.Fprintf(os.Stderr, "fakeRound %q: got message %#v\n", r.name, msg) msgs = append(msgs, msg) if msg.Type() == MessageTypeDone || msg.Type() == MessageTypeEndRound { break @@ -268,25 +204,35 @@ func (r *fakeRound) handleConversation(t *testing.T, c *wireConduit) error { return nil } -func makeTestHandler(t *testing.T, c *wireConduit, newValue NewValueFunc, done chan struct{}, rounds []fakeRound) server.InteractiveHandler { - return func(ctx context.Context, i server.Interactor) (time.Duration, error) { - defer func() { - if done != nil { - close(done) - } - }() +func makeTestStreamHandler(t *testing.T, c *wireConduit, newValue NewValueFunc, rounds []fakeRound) server.StreamHandler { + cbk := makeTestRequestCallback(t, c, newValue, rounds) + return func(ctx context.Context, initialRequest []byte, stream io.ReadWriter) error { + t.Logf("init request bytes: %d", len(initialRequest)) + s := struct { + io.Reader + io.Writer + }{ + // prepend the received request to data being read + Reader: io.MultiReader(bytes.NewBuffer(initialRequest), stream), + Writer: stream, + } + return cbk(ctx, s) + } +} + +func makeTestRequestCallback(t *testing.T, c *wireConduit, newValue NewValueFunc, rounds []fakeRound) server.StreamRequestCallback { + return func(ctx context.Context, stream io.ReadWriter) error { if c == nil { - c = &wireConduit{i: i, newValue: newValue} + c = &wireConduit{stream: stream, newValue: newValue} } else { - c.i = i + c.stream = stream } for _, round := range rounds { if err := round.handleConversation(t, c); err != nil { - done = nil - return 0, err + return err } } - return 0, nil + return nil } } @@ -296,7 +242,7 @@ func TestWireConduit(t *testing.T) { hs[n] = types.RandomHash() } fp := types.Hash12(hs[2][:12]) - srvHandler := makeTestHandler(t, nil, func() any { return new(fakeValue) }, nil, []fakeRound{ + srvHandler := makeTestStreamHandler(t, nil, func() any { return new(fakeValue) }, []fakeRound{ { name: "server got 1st request", expectMsgs: []SyncMessage{ @@ -369,8 +315,7 @@ func TestWireConduit(t *testing.T) { return c.SendEndRound() }) require.NoError(t, err) - done := make(chan struct{}) - clientHandler := makeTestHandler(t, &c, c.newValue, done, []fakeRound{ + clientCbk := makeTestRequestCallback(t, &c, c.newValue, []fakeRound{ { name: "client got 1st response", expectMsgs: []SyncMessage{ @@ -410,15 +355,11 @@ func TestWireConduit(t *testing.T) { }, }, }) - err = client.InteractiveRequest(context.Background(), "srv", initReq, clientHandler, func(err error) { - t.Errorf("fail handler called: %v", err) - close(done) - }) + err = client.StreamRequest(context.Background(), "srv", initReq, clientCbk) require.NoError(t, err) - <-done } -type getRequesterFunc func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) +type getRequesterFunc func(name string, handler server.StreamHandler, peers ...requester) (requester, p2p.Peer) func withClientServer( storeA, storeB ItemStore, @@ -443,7 +384,7 @@ func withClientServer( } func fakeRequesterGetter(t *testing.T) getRequesterFunc { - return func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) { + return func(name string, handler server.StreamHandler, peers ...requester) (requester, p2p.Peer) { pid := p2p.Peer(name) return newFakeRequester(pid, handler, peers...), pid } @@ -457,7 +398,7 @@ func p2pRequesterGetter(t *testing.T) getRequesterFunc { server.WithTimeout(10 * time.Second), server.WithLog(logtest.New(t)), } - return func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) { + return func(name string, handler server.StreamHandler, peers ...requester) (requester, p2p.Peer) { if len(peers) == 0 { return server.New(mesh.Hosts()[0], proto, handler, opts...), mesh.Hosts()[0].ID() } diff --git a/hashsync/interface.go b/hashsync/interface.go index 4b4c21d5f3..61de430f26 100644 --- a/hashsync/interface.go +++ b/hashsync/interface.go @@ -9,5 +9,5 @@ import ( type requester interface { Run(context.Context) error - InteractiveRequest(context.Context, p2p.Peer, []byte, server.InteractiveHandler, func(error)) error + StreamRequest(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error } diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index e9e7483ec5..9319904b40 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -80,11 +80,10 @@ type NewValueFunc func() any // Conduit handles receiving and sending peer messages type Conduit interface { - // NextMessage returns the next SyncMessage, or nil if there - // are no more SyncMessages. NextMessage is only called after - // a NextItem call indicates that there are no more items. - // NextMessage will not be called after any of Send...() - // methods is invoked + // NextMessage returns the next SyncMessage, or nil if there are no more + // SyncMessages for this session. NextMessage is only called after a NextItem call + // indicates that there are no more items. NextMessage should not be called after + // any of Send...() methods is invoked NextMessage() (SyncMessage, error) // SendFingerprint sends range fingerprint to the peer. // Count must be > 0