Skip to content

Commit

Permalink
pass metadata along with stream creation command
Browse files Browse the repository at this point in the history
  • Loading branch information
trueinsider committed Mar 2, 2019
1 parent f4f6ca3 commit 035af7b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 8 deletions.
10 changes: 6 additions & 4 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
}

// OpenStream is used to create a new stream
func (s *Session) OpenStream() (*Stream, error) {
func (s *Session) OpenStream(metadata ...byte) (*Stream, error) {
if s.IsClosed() {
return nil, errors.New(errBrokenPipe)
}
Expand All @@ -101,9 +101,11 @@ func (s *Session) OpenStream() (*Stream, error) {
}
s.nextStreamIDLock.Unlock()

stream := newStream(sid, s.config.MaxFrameSize, s)
stream := newStream(sid, metadata, s.config.MaxFrameSize, s)

if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
frame := newFrame(cmdSYN, sid)
frame.data = metadata
if _, err := s.writeFrame(frame); err != nil {
return nil, errors.Wrap(err, "writeFrame")
}

Expand Down Expand Up @@ -247,7 +249,7 @@ func (s *Session) recvLoop() {
case cmdSYN:
s.streamLock.Lock()
if _, ok := s.streams[f.sid]; !ok {
stream := newStream(f.sid, s.config.MaxFrameSize, s)
stream := newStream(f.sid, f.data, s.config.MaxFrameSize, s)
s.streams[f.sid] = stream
select {
case s.chAccepts <- stream:
Expand Down
22 changes: 19 additions & 3 deletions session_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package smux

import (
"bytes"
crand "crypto/rand"
"encoding/binary"
"fmt"
Expand All @@ -16,7 +17,7 @@ import (
// setupServer starts new server listening on a random localhost port and
// returns address of the server, function to stop the server, new client
// connection to this server or an error.
func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) {
func setupServer(tb testing.TB, metadata ...byte) (addr string, stopfunc func(), client net.Conn, err error) {
ln, err := net.Listen("tcp", "localhost:0")
if err != nil {
return "", nil, nil, err
Expand All @@ -27,7 +28,7 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn,
tb.Error(err)
return
}
go handleConnection(conn)
go handleConnection(tb, conn, metadata...)
}()
addr = ln.Addr().String()
conn, err := net.Dial("tcp", addr)
Expand All @@ -38,10 +39,13 @@ func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn,
return ln.Addr().String(), func() { ln.Close() }, conn, nil
}

func handleConnection(conn net.Conn) {
func handleConnection(tb testing.TB, conn net.Conn, metadata ...byte) {
session, _ := Server(conn, nil)
for {
if stream, err := session.AcceptStream(); err == nil {
if !bytes.Equal(metadata, stream.Metadata()) {
tb.Fatal("metadata mimatch")
}
go func(s io.ReadWriteCloser) {
buf := make([]byte, 65536)
for {
Expand All @@ -58,6 +62,18 @@ func handleConnection(conn net.Conn) {
}
}

func TestMetadata(t *testing.T) {
metadata := []byte("hello, world")
_, stop, cli, err := setupServer(t, metadata...)
if err != nil {
t.Fatal(err)
}
defer stop()
session, _ := Client(cli, nil)
session.OpenStream(metadata...)
session.Close()
}

func TestEcho(t *testing.T) {
_, stop, cli, err := setupServer(t)
if err != nil {
Expand Down
9 changes: 8 additions & 1 deletion stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
// Stream implements net.Conn
type Stream struct {
id uint32
metadata []byte
rstflag int32
sess *Session
buffer bytes.Buffer
Expand All @@ -27,9 +28,10 @@ type Stream struct {
}

// newStream initiates a Stream struct
func newStream(id uint32, frameSize int, sess *Session) *Stream {
func newStream(id uint32, metadata []byte, frameSize int, sess *Session) *Stream {
s := new(Stream)
s.id = id
s.metadata = metadata
s.chReadEvent = make(chan struct{}, 1)
s.frameSize = frameSize
s.sess = sess
Expand All @@ -42,6 +44,11 @@ func (s *Stream) ID() uint32 {
return s.id
}

// Metadata returns stream metadata which was provided when opening stream.
func (s *Stream) Metadata() []byte {
return s.metadata
}

// Read implements net.Conn
func (s *Stream) Read(b []byte) (n int, err error) {
if len(b) == 0 {
Expand Down

0 comments on commit 035af7b

Please sign in to comment.