Skip to content

Commit

Permalink
pass metadata along with stream creation command
Browse files Browse the repository at this point in the history
  • Loading branch information
trueinsider committed Sep 2, 2019
1 parent 6aa95ef commit 0a6805b
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 10 deletions.
19 changes: 15 additions & 4 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
}

// OpenStream is used to create a new stream
func (s *Session) OpenStream() (*Stream, error) {
func (s *Session) OpenStream(metadata ...byte) (*Stream, error) {
if s.IsClosed() {
return nil, errors.WithStack(io.ErrClosedPipe)
}
Expand All @@ -123,9 +123,11 @@ func (s *Session) OpenStream() (*Stream, error) {
}
s.nextStreamIDLock.Unlock()

stream := newStream(sid, s.config.MaxFrameSize, s)
stream := newStream(sid, metadata, s.config.MaxFrameSize, s)

if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
frame := newFrame(cmdSYN, sid)
frame.data = metadata
if _, err := s.writeFrame(frame); err != nil {
return nil, errors.WithStack(err)
}

Expand Down Expand Up @@ -307,7 +309,16 @@ func (s *Session) recvLoop() {
case cmdSYN:
s.streamLock.Lock()
if _, ok := s.streams[sid]; !ok {
stream := newStream(sid, s.config.MaxFrameSize, s)
var newbuf []byte
if hdr.Length() > 0 {
newbuf = defaultAllocator.Get(int(hdr.Length()))
if _, err := io.ReadFull(s.conn, newbuf); err != nil {
s.notifyReadError(errors.WithStack(err))
s.streamLock.Unlock()
return
}
}
stream := newStream(sid, append([]byte(nil), newbuf...), s.config.MaxFrameSize, s)
s.streams[sid] = stream
select {
case s.chAccepts <- stream:
Expand Down
22 changes: 19 additions & 3 deletions session_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package smux

import (
"bytes"
crand "crypto/rand"
"encoding/binary"
"fmt"
Expand All @@ -25,7 +26,7 @@ func init() {
// setupServer starts new server listening on a random localhost port and
// returns address of the server, function to stop the server, new client
// connection to this server or an error.
func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) {
func setupServer(tb testing.TB, metadata ...byte) (addr string, stopfunc func(), client net.Conn, err error) {
ln, err := net.Listen("tcp", "localhost:0")
if err != nil {
return "", nil, nil, err
Expand All @@ -35,7 +36,7 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn,
if err != nil {
return
}
go handleConnection(conn)
go handleConnection(tb, conn, metadata...)
}()
addr = ln.Addr().String()
conn, err := net.Dial("tcp", addr)
Expand All @@ -46,10 +47,13 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn,
return ln.Addr().String(), func() { ln.Close() }, conn, nil
}

func handleConnection(conn net.Conn) {
func handleConnection(tb testing.TB, conn net.Conn, metadata ...byte) {
session, _ := Server(conn, nil)
for {
if stream, err := session.AcceptStream(); err == nil {
if !bytes.Equal(metadata, stream.Metadata()) {
tb.Fatal("metadata mismatch")
}
go func(s io.ReadWriteCloser) {
buf := make([]byte, 65536)
for {
Expand All @@ -66,6 +70,18 @@ func handleConnection(conn net.Conn) {
}
}

func TestMetadata(t *testing.T) {
metadata := []byte("hello, world")
_, stop, cli, err := setupServer(t, metadata...)
if err != nil {
t.Fatal(err)
}
defer stop()
session, _ := Client(cli, nil)
session.OpenStream(metadata...)
session.Close()
}

func TestEcho(t *testing.T) {
_, stop, cli, err := setupServer(t)
if err != nil {
Expand Down
13 changes: 10 additions & 3 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import (

// Stream implements net.Conn
type Stream struct {
id uint32
sess *Session
id uint32
metadata []byte
sess *Session

buffers [][]byte
heads [][]byte // slice heads kept for recycle
Expand All @@ -38,9 +39,10 @@ type Stream struct {
}

// newStream initiates a Stream struct
func newStream(id uint32, frameSize int, sess *Session) *Stream {
func newStream(id uint32, metadata []byte, frameSize int, sess *Session) *Stream {
s := new(Stream)
s.id = id
s.metadata = metadata
s.chReadEvent = make(chan struct{}, 1)
s.frameSize = frameSize
s.sess = sess
Expand All @@ -54,6 +56,11 @@ func (s *Stream) ID() uint32 {
return s.id
}

// Metadata returns stream metadata which was provided when opening stream.
func (s *Stream) Metadata() []byte {
return s.metadata
}

// Read implements net.Conn
func (s *Stream) Read(b []byte) (n int, err error) {
if len(b) == 0 {
Expand Down

0 comments on commit 0a6805b

Please sign in to comment.