From 0a6805b18476e016d7346a8b4876458e82b409d9 Mon Sep 17 00:00:00 2001 From: Nikolai Perevozchikov Date: Mon, 2 Sep 2019 11:34:10 +0900 Subject: [PATCH] pass metadata along with stream creation command --- session.go | 19 +++++++++++++++---- session_test.go | 22 +++++++++++++++++++--- stream.go | 13 ++++++++++--- 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/session.go b/session.go index 5f6bdb4..bb24e4a 100644 --- a/session.go +++ b/session.go @@ -102,7 +102,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { } // OpenStream is used to create a new stream -func (s *Session) OpenStream() (*Stream, error) { +func (s *Session) OpenStream(metadata ...byte) (*Stream, error) { if s.IsClosed() { return nil, errors.WithStack(io.ErrClosedPipe) } @@ -123,9 +123,11 @@ func (s *Session) OpenStream() (*Stream, error) { } s.nextStreamIDLock.Unlock() - stream := newStream(sid, s.config.MaxFrameSize, s) + stream := newStream(sid, metadata, s.config.MaxFrameSize, s) - if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil { + frame := newFrame(cmdSYN, sid) + frame.data = metadata + if _, err := s.writeFrame(frame); err != nil { return nil, errors.WithStack(err) } @@ -307,7 +309,16 @@ func (s *Session) recvLoop() { case cmdSYN: s.streamLock.Lock() if _, ok := s.streams[sid]; !ok { - stream := newStream(sid, s.config.MaxFrameSize, s) + var newbuf []byte + if hdr.Length() > 0 { + newbuf = defaultAllocator.Get(int(hdr.Length())) + if _, err := io.ReadFull(s.conn, newbuf); err != nil { + s.notifyReadError(errors.WithStack(err)) + s.streamLock.Unlock() + return + } + } + stream := newStream(sid, append([]byte(nil), newbuf...), s.config.MaxFrameSize, s) s.streams[sid] = stream select { case s.chAccepts <- stream: diff --git a/session_test.go b/session_test.go index dde6ac6..ea51d52 100644 --- a/session_test.go +++ b/session_test.go @@ -1,6 +1,7 @@ package smux import ( + "bytes" crand "crypto/rand" "encoding/binary" "fmt" @@ -25,7 +26,7 @@ func init() { // setupServer starts new server listening on a random localhost port and // returns address of the server, function to stop the server, new client // connection to this server or an error. -func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) { +func setupServer(tb testing.TB, metadata ...byte) (addr string, stopfunc func(), client net.Conn, err error) { ln, err := net.Listen("tcp", "localhost:0") if err != nil { return "", nil, nil, err @@ -35,7 +36,7 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, if err != nil { return } - go handleConnection(conn) + go handleConnection(tb, conn, metadata...) }() addr = ln.Addr().String() conn, err := net.Dial("tcp", addr) @@ -46,10 +47,13 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, return ln.Addr().String(), func() { ln.Close() }, conn, nil } -func handleConnection(conn net.Conn) { +func handleConnection(tb testing.TB, conn net.Conn, metadata ...byte) { session, _ := Server(conn, nil) for { if stream, err := session.AcceptStream(); err == nil { + if !bytes.Equal(metadata, stream.Metadata()) { + tb.Fatal("metadata mismatch") + } go func(s io.ReadWriteCloser) { buf := make([]byte, 65536) for { @@ -66,6 +70,18 @@ func handleConnection(conn net.Conn) { } } +func TestMetadata(t *testing.T) { + metadata := []byte("hello, world") + _, stop, cli, err := setupServer(t, metadata...) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + session.OpenStream(metadata...) + session.Close() +} + func TestEcho(t *testing.T) { _, stop, cli, err := setupServer(t) if err != nil { diff --git a/stream.go b/stream.go index 8d7bbfa..32a7cc4 100644 --- a/stream.go +++ b/stream.go @@ -12,8 +12,9 @@ import ( // Stream implements net.Conn type Stream struct { - id uint32 - sess *Session + id uint32 + metadata []byte + sess *Session buffers [][]byte heads [][]byte // slice heads kept for recycle @@ -38,9 +39,10 @@ type Stream struct { } // newStream initiates a Stream struct -func newStream(id uint32, frameSize int, sess *Session) *Stream { +func newStream(id uint32, metadata []byte, frameSize int, sess *Session) *Stream { s := new(Stream) s.id = id + s.metadata = metadata s.chReadEvent = make(chan struct{}, 1) s.frameSize = frameSize s.sess = sess @@ -54,6 +56,11 @@ func (s *Stream) ID() uint32 { return s.id } +// Metadata returns stream metadata which was provided when opening stream. +func (s *Stream) Metadata() []byte { + return s.metadata +} + // Read implements net.Conn func (s *Stream) Read(b []byte) (n int, err error) { if len(b) == 0 {