From 0a25e16561ce69a71722d28e6ecc171c0d734b6b Mon Sep 17 00:00:00 2001 From: Jeff Lindsay Date: Tue, 13 Apr 2021 16:03:17 -0500 Subject: [PATCH 1/2] golang/session: initial implementation and basic test --- .gitignore | 1 + golang/api.go | 44 ++++++ golang/session/channel.go | 251 ++++++++++++++++++++++++++++++++ golang/session/session.go | 215 +++++++++++++++++++++++++++ golang/session/session_test.go | 82 +++++++++++ golang/session/util.go | 8 + golang/session/util_buffer.go | 93 ++++++++++++ golang/session/util_chanlist.go | 61 ++++++++ golang/session/util_window.go | 78 ++++++++++ 9 files changed, 833 insertions(+) create mode 100644 .gitignore create mode 100644 golang/api.go create mode 100644 golang/session/channel.go create mode 100644 golang/session/session.go create mode 100644 golang/session/session_test.go create mode 100644 golang/session/util.go create mode 100644 golang/session/util_buffer.go create mode 100644 golang/session/util_chanlist.go create mode 100644 golang/session/util_window.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..30404ce --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +TODO \ No newline at end of file diff --git a/golang/api.go b/golang/api.go new file mode 100644 index 0000000..0a098f1 --- /dev/null +++ b/golang/api.go @@ -0,0 +1,44 @@ +package mux + +import ( + "context" + "net" +) + +type Session interface { + Context() context.Context + Close() error + Open() (Channel, error) + Accept() (Channel, error) + LocalAddr() net.Addr + RemoteAddr() net.Addr + Wait() error +} + +// A Channel is an ordered, reliable, flow-controlled, duplex stream +// that is multiplexed over a qmux connection. +type Channel interface { + Context() context.Context + + // Read reads up to len(data) bytes from the channel. + Read(data []byte) (int, error) + + // Write writes len(data) bytes to the channel. + Write(data []byte) (int, error) + + // Close signals end of channel use. No data may be sent after this + // call. + Close() error + + // CloseWrite signals the end of sending in-band + // data. The other side may still send data + CloseWrite() error + + ID() uint32 +} + +type Listener interface { + Close() error + Addr() net.Addr + Accept() (Session, error) +} diff --git a/golang/session/channel.go b/golang/session/channel.go new file mode 100644 index 0000000..32dceb7 --- /dev/null +++ b/golang/session/channel.go @@ -0,0 +1,251 @@ +package session + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + + "github.com/progrium/qmux/golang/codec" +) + +type channelDirection uint8 + +const ( + channelInbound channelDirection = iota + channelOutbound +) + +// channel is an implementation of the Channel interface that works +// with the session class. +type channel struct { + ctx context.Context + + // R/O after creation + localId, remoteId uint32 + + // maxIncomingPayload and maxRemotePayload are the maximum + // payload sizes of normal and extended data packets for + // receiving and sending, respectively. The wire packet will + // be 9 or 13 bytes larger (excluding encryption overhead). + maxIncomingPayload uint32 + maxRemotePayload uint32 + + session *session + + // direction contains either channelOutbound, for channels created + // locally, or channelInbound, for channels created by the peer. + direction channelDirection + + // Pending internal channel messages. + msg chan codec.Message + + sentEOF bool + + // thread-safe data + remoteWin window + pending *buffer + + // windowMu protects myWindow, the flow-control window. + windowMu sync.Mutex + myWindow uint32 + + // writeMu serializes calls to session.conn.Write() and + // protects sentClose and packetPool. This mutex must be + // different from windowMu, as writePacket can block if there + // is a key exchange pending. + writeMu sync.Mutex + sentClose bool + + // packet buffer for writing + packetBuf []byte +} + +func (ch *channel) ID() uint32 { + return ch.localId +} + +func (ch *channel) Context() context.Context { + return ch.ctx +} + +func (ch *channel) CloseWrite() error { + ch.sentEOF = true + return ch.send(codec.EOFMessage{ + ChannelID: ch.remoteId}) +} + +func (ch *channel) Close() error { + return ch.send(codec.CloseMessage{ + ChannelID: ch.remoteId}) +} + +// Write writes len(data) bytes to the channel. +func (ch *channel) Write(data []byte) (n int, err error) { + if ch.sentEOF { + return 0, io.EOF + } + + for len(data) > 0 { + space := min(ch.maxRemotePayload, len(data)) + if space, err = ch.remoteWin.reserve(space); err != nil { + return n, err + } + + toSend := data[:space] + + if err = ch.session.enc.Encode(codec.DataMessage{ + ChannelID: ch.remoteId, + Length: uint32(len(toSend)), + Data: toSend, + }); err != nil { + return n, err + } + + n += len(toSend) + data = data[len(toSend):] + } + + return n, err +} + +// Read reads up to len(data) bytes from the channel. +func (c *channel) Read(data []byte) (n int, err error) { + n, err = c.pending.Read(data) + + if n > 0 { + err = c.adjustWindow(uint32(n)) + // sendWindowAdjust can return io.EOF if the remote + // peer has closed the connection, however we want to + // defer forwarding io.EOF to the caller of Read until + // the buffer has been drained. + if n > 0 && err == io.EOF { + err = nil + } + } + return n, err +} + +// writePacket sends a packet. If the packet is a channel close, it updates +// sentClose. This method takes the lock c.writeMu. +func (ch *channel) send(msg interface{}) error { + ch.writeMu.Lock() + defer ch.writeMu.Unlock() + + if ch.sentClose { + return io.EOF + } + + if _, ok := msg.(codec.CloseMessage); ok { + ch.sentClose = true + } + + return ch.session.enc.Encode(msg) +} + +func (c *channel) adjustWindow(n uint32) error { + c.windowMu.Lock() + // Since myWindow is managed on our side, and can never exceed + // the initial window setting, we don't worry about overflow. + c.myWindow += uint32(n) + c.windowMu.Unlock() + return c.send(codec.WindowAdjustMessage{ + ChannelID: c.remoteId, + AdditionalBytes: uint32(n), + }) +} + +func (c *channel) close() { + c.pending.eof() + close(c.msg) + c.writeMu.Lock() + // This is not necessary for a normal channel teardown, but if + // there was another error, it is. + c.sentClose = true + c.writeMu.Unlock() + // Unblock writers. + c.remoteWin.close() +} + +// responseMessageReceived is called when a success or failure message is +// received on a channel to check that such a message is reasonable for the +// given channel. +func (ch *channel) responseMessageReceived() error { + if ch.direction == channelInbound { + return errors.New("qmux: channel response message received on inbound channel") + } + return nil +} + +func (ch *channel) handle(msg codec.Message) error { + switch m := msg.(type) { + case *codec.DataMessage: + return ch.handleData(m) + + case *codec.CloseMessage: + ch.send(codec.CloseMessage{ + ChannelID: ch.remoteId, + }) + ch.session.chanList.remove(ch.localId) + ch.close() + return nil + + case *codec.EOFMessage: + ch.pending.eof() + return nil + + case *codec.WindowAdjustMessage: + if !ch.remoteWin.add(m.AdditionalBytes) { + return fmt.Errorf("qmux: invalid window update for %d bytes", m.AdditionalBytes) + } + return nil + + case *codec.OpenConfirmMessage: + if err := ch.responseMessageReceived(); err != nil { + return err + } + if m.MaxPacketSize < minPacketLength || m.MaxPacketSize > maxPacketLength { + return fmt.Errorf("qmux: invalid MaxPacketSize %d from peer", m.MaxPacketSize) + } + ch.remoteId = m.SenderID + ch.maxRemotePayload = m.MaxPacketSize + ch.remoteWin.add(m.WindowSize) + ch.msg <- m + return nil + + case *codec.OpenFailureMessage: + if err := ch.responseMessageReceived(); err != nil { + return err + } + ch.session.chanList.remove(m.ChannelID) + ch.msg <- m + return nil + + default: + return fmt.Errorf("qmux: invalid channel message %v", msg) + } +} + +func (ch *channel) handleData(msg *codec.DataMessage) error { + if msg.Length > ch.maxIncomingPayload { + // TODO(hanwen): should send Disconnect? + return errors.New("qmux: incoming packet exceeds maximum payload size") + } + + if msg.Length != uint32(len(msg.Data)) { + return errors.New("qmux: wrong packet length") + } + + ch.windowMu.Lock() + if ch.myWindow < msg.Length { + ch.windowMu.Unlock() + // TODO(hanwen): should send Disconnect with reason? + return errors.New("qmux: remote side wrote too much") + } + ch.myWindow -= msg.Length + ch.windowMu.Unlock() + + ch.pending.write(msg.Data) + return nil +} diff --git a/golang/session/session.go b/golang/session/session.go new file mode 100644 index 0000000..01d9a7b --- /dev/null +++ b/golang/session/session.go @@ -0,0 +1,215 @@ +package session + +import ( + "context" + "fmt" + "io" + "net" + "sync" + + mux "github.com/progrium/qmux/golang" + "github.com/progrium/qmux/golang/codec" +) + +const ( + minPacketLength = 9 + maxPacketLength = 1 << 31 + + // channelMaxPacket contains the maximum number of bytes that will be + // sent in a single packet. As per RFC 4253, section 6.1, 32k is also + // the minimum. + channelMaxPacket = 1 << 15 + // We follow OpenSSH here. + channelWindowSize = 64 * channelMaxPacket + + // chanSize sets the amount of buffering qmux connections. This is + // primarily for testing: setting chanSize=0 uncovers deadlocks more + // quickly. + chanSize = 16 +) + +type session struct { + ctx context.Context + conn io.ReadWriteCloser + chanList chanList + + enc *codec.Encoder + dec *codec.Decoder + + incomingChannels chan mux.Channel + + errCond *sync.Cond + err error + closeCh chan bool +} + +// NewSession returns a session that runs over the given connection. +func New(ctx context.Context, rwc io.ReadWriteCloser) mux.Session { + if rwc == nil { + return nil + } + s := &session{ + ctx: ctx, + conn: rwc, + enc: codec.NewEncoder(rwc), + dec: codec.NewDecoder(rwc), + incomingChannels: make(chan mux.Channel, chanSize), + errCond: sync.NewCond(new(sync.Mutex)), + closeCh: make(chan bool, 1), + } + go s.loop() + return s +} + +func (s *session) Context() context.Context { + return s.ctx +} + +func (s *session) Close() error { + s.conn.Close() + return nil +} + +func (s *session) LocalAddr() net.Addr { + if conn, ok := s.conn.(net.Conn); ok { + return conn.LocalAddr() + } + return nil +} + +func (s *session) RemoteAddr() net.Addr { + if conn, ok := s.conn.(net.Conn); ok { + return conn.RemoteAddr() + } + return nil +} + +func (s *session) Wait() error { + s.errCond.L.Lock() + defer s.errCond.L.Unlock() + for s.err == nil { + s.errCond.Wait() + } + return s.err +} + +func (s *session) Accept() (mux.Channel, error) { + // TODO: context cancel + select { + case ch := <-s.incomingChannels: + return ch, nil + case <-s.closeCh: + return nil, io.EOF + } +} + +func (s *session) Open() (mux.Channel, error) { + ch := s.newChannel(channelOutbound) + ch.maxIncomingPayload = channelMaxPacket + + if err := s.enc.Encode(codec.OpenMessage{ + WindowSize: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + SenderID: ch.localId, + }); err != nil { + return nil, err + } + + // TODO: timeout? context cancel? + m := <-ch.msg + if m == nil { + return nil, fmt.Errorf("qmux: channel closed early during open") + } + switch msg := m.(type) { + case *codec.OpenConfirmMessage: + return ch, nil + + case *codec.OpenFailureMessage: + return nil, fmt.Errorf("qmux: channel open failed on remote side") + + default: + return nil, fmt.Errorf("qmux: unexpected packet in response to channel open: %v", msg) + } +} + +func (s *session) newChannel(direction channelDirection) *channel { + ch := &channel{ + ctx: s.ctx, + remoteWin: window{Cond: sync.NewCond(new(sync.Mutex))}, + myWindow: channelWindowSize, + pending: newBuffer(), + direction: direction, + msg: make(chan codec.Message, chanSize), + session: s, + packetBuf: make([]byte, 0), + } + ch.localId = s.chanList.add(ch) + return ch +} + +// loop runs the connection machine. It will process packets until an +// error is encountered. To synchronize on loop exit, use session.Wait. +func (s *session) loop() { + var err error + for err == nil { + err = s.onePacket() + } + + for _, ch := range s.chanList.dropAll() { + ch.close() + } + + s.conn.Close() + s.closeCh <- true + + s.errCond.L.Lock() + s.err = err + s.errCond.Broadcast() + s.errCond.L.Unlock() +} + +// onePacket reads and processes one packet. +func (s *session) onePacket() error { + var err error + var msg codec.Message + + msg, err = s.dec.Decode() + if err != nil { + return err + } + + id, isChan := msg.Channel() + if !isChan { + return s.handleOpen(msg.(*codec.OpenMessage)) + } + + ch := s.chanList.getChan(id) + if ch == nil { + return fmt.Errorf("qmux: invalid channel %d", id) + } + + return ch.handle(msg) +} + +// handleChannelOpen schedules a channel to be Accept()ed. +func (s *session) handleOpen(msg *codec.OpenMessage) error { + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > maxPacketLength { + return s.enc.Encode(codec.OpenFailureMessage{ + ChannelID: msg.SenderID, + }) + } + + c := s.newChannel(channelInbound) + c.remoteId = msg.SenderID + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.WindowSize) + c.maxIncomingPayload = channelMaxPacket + s.incomingChannels <- c + + return s.enc.Encode(codec.OpenConfirmMessage{ + ChannelID: c.remoteId, + SenderID: c.localId, + WindowSize: c.myWindow, + MaxPacketSize: c.maxIncomingPayload, + }) +} diff --git a/golang/session/session_test.go b/golang/session/session_test.go new file mode 100644 index 0000000..2777744 --- /dev/null +++ b/golang/session/session_test.go @@ -0,0 +1,82 @@ +package session + +import ( + "bytes" + "context" + "io/ioutil" + "net" + "testing" + + mux "github.com/progrium/qmux/golang" +) + +func fatal(err error, t *testing.T) { + if err != nil { + t.Fatal(err) + } +} + +func TestQmux(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + fatal(err, t) + defer l.Close() + + go func() { + conn, err := l.Accept() + fatal(err, t) + defer conn.Close() + + sess := New(context.Background(), conn) + + ch, err := sess.Open() + fatal(err, t) + b, err := ioutil.ReadAll(ch) + fatal(err, t) + ch.Close() // should already be closed by other end + + ch, err = sess.Accept() + _, err = ch.Write(b) + fatal(err, t) + err = ch.CloseWrite() + fatal(err, t) + + err = sess.Close() + fatal(err, t) + }() + + conn, err := net.Dial("tcp", l.Addr().String()) + fatal(err, t) + defer conn.Close() + + sess := New(context.Background(), conn) + + var ch mux.Channel + t.Run("session accept", func(t *testing.T) { + ch, err = sess.Accept() + fatal(err, t) + }) + + t.Run("channel write", func(t *testing.T) { + _, err = ch.Write([]byte("Hello world")) + fatal(err, t) + err = ch.Close() + fatal(err, t) + }) + + t.Run("session open", func(t *testing.T) { + ch, err = sess.Open() + fatal(err, t) + }) + + var b []byte + t.Run("channel read", func(t *testing.T) { + b, err = ioutil.ReadAll(ch) + fatal(err, t) + ch.Close() // should already be closed by other end + }) + + if !bytes.Equal(b, []byte("Hello world")) { + t.Fatalf("unexpected bytes: %s", b) + } + +} diff --git a/golang/session/util.go b/golang/session/util.go new file mode 100644 index 0000000..9fbfe56 --- /dev/null +++ b/golang/session/util.go @@ -0,0 +1,8 @@ +package session + +func min(a uint32, b int) uint32 { + if a < uint32(b) { + return a + } + return uint32(b) +} diff --git a/golang/session/util_buffer.go b/golang/session/util_buffer.go new file mode 100644 index 0000000..ae8b8e8 --- /dev/null +++ b/golang/session/util_buffer.go @@ -0,0 +1,93 @@ +package session + +import ( + "io" + "sync" +) + +// buffer provides a linked list buffer for data exchange +// between producer and consumer. Theoretically the buffer is +// of unlimited capacity as it does no allocation of its own. +type buffer struct { + // protects concurrent access to head, tail and closed + *sync.Cond + + head *element // the buffer that will be read first + tail *element // the buffer that will be read last + + closed bool +} + +// An element represents a single link in a linked list. +type element struct { + buf []byte + next *element +} + +// newBuffer returns an empty buffer that is not closed. +func newBuffer() *buffer { + e := new(element) + b := &buffer{ + Cond: sync.NewCond(new(sync.Mutex)), + head: e, + tail: e, + } + return b +} + +// write makes buf available for Read to receive. +// buf must not be modified after the call to write. +func (b *buffer) write(buf []byte) { + b.Cond.L.Lock() + e := &element{buf: buf} + b.tail.next = e + b.tail = e + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// eof closes the buffer. Reads from the buffer once all +// the data has been consumed will receive io.EOF. +func (b *buffer) eof() { + b.Cond.L.Lock() + b.closed = true + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// Read reads data from the internal buffer in buf. Reads will block +// if no data is available, or until the buffer is closed. +func (b *buffer) Read(buf []byte) (n int, err error) { + b.Cond.L.Lock() + defer b.Cond.L.Unlock() + + for len(buf) > 0 { + // if there is data in b.head, copy it + if len(b.head.buf) > 0 { + r := copy(buf, b.head.buf) + buf, b.head.buf = buf[r:], b.head.buf[r:] + n += r + continue + } + // if there is a next buffer, make it the head + if len(b.head.buf) == 0 && b.head != b.tail { + b.head = b.head.next + continue + } + + // if at least one byte has been copied, return + if n > 0 { + break + } + + // if nothing was read, and there is nothing outstanding + // check to see if the buffer is closed. + if b.closed { + err = io.EOF + break + } + // out of buffers, wait for producer + b.Cond.Wait() + } + return +} diff --git a/golang/session/util_chanlist.go b/golang/session/util_chanlist.go new file mode 100644 index 0000000..121e624 --- /dev/null +++ b/golang/session/util_chanlist.go @@ -0,0 +1,61 @@ +package session + +import "sync" + +// chanList is a thread safe channel list. +type chanList struct { + // protects concurrent access to chans + sync.Mutex + + // chans are indexed by the local id of the channel, which the + // other side should send in the PeersId field. + chans []*channel +} + +// Assigns a channel ID to the given channel. +func (c *chanList) add(ch *channel) uint32 { + c.Lock() + defer c.Unlock() + for i := range c.chans { + if c.chans[i] == nil { + c.chans[i] = ch + return uint32(i) + } + } + c.chans = append(c.chans, ch) + return uint32(len(c.chans) - 1) +} + +// getChan returns the channel for the given ID. +func (c *chanList) getChan(id uint32) *channel { + c.Lock() + defer c.Unlock() + if id < uint32(len(c.chans)) { + return c.chans[id] + } + return nil +} + +func (c *chanList) remove(id uint32) { + c.Lock() + if id < uint32(len(c.chans)) { + c.chans[id] = nil + } + c.Unlock() +} + +// dropAll forgets all channels it knows, returning them in a slice. +func (c *chanList) dropAll() []*channel { + c.Lock() + defer c.Unlock() + var r []*channel + + for _, ch := range c.chans { + if ch == nil { + continue + } + r = append(r, ch) + } + c.chans = nil + return r +} diff --git a/golang/session/util_window.go b/golang/session/util_window.go new file mode 100644 index 0000000..721cd6e --- /dev/null +++ b/golang/session/util_window.go @@ -0,0 +1,78 @@ +package session + +import ( + "io" + "sync" +) + +// window represents the buffer available to clients +// wishing to write to a channel. +type window struct { + *sync.Cond + win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 + writeWaiters int + closed bool +} + +// add adds win to the amount of window available +// for consumers. +func (w *window) add(win uint32) bool { + // a zero sized window adjust is a noop. + if win == 0 { + return true + } + w.L.Lock() + if w.win+win < win { + w.L.Unlock() + return false + } + w.win += win + // It is unusual that multiple goroutines would be attempting to reserve + // window space, but not guaranteed. Use broadcast to notify all waiters + // that additional window is available. + w.Broadcast() + w.L.Unlock() + return true +} + +// close sets the window to closed, so all reservations fail +// immediately. +func (w *window) close() { + w.L.Lock() + w.closed = true + w.Broadcast() + w.L.Unlock() +} + +// reserve reserves win from the available window capacity. +// If no capacity remains, reserve will block. reserve may +// return less than requested. +func (w *window) reserve(win uint32) (uint32, error) { + var err error + w.L.Lock() + w.writeWaiters++ + w.Broadcast() + for w.win == 0 && !w.closed { + w.Wait() + } + w.writeWaiters-- + if w.win < win { + win = w.win + } + w.win -= win + if w.closed { + err = io.EOF + } + w.L.Unlock() + return win, err +} + +// waitWriterBlocked waits until some goroutine is blocked for further +// writes. It is used in tests only. +func (w *window) waitWriterBlocked() { + w.Cond.L.Lock() + for w.writeWaiters == 0 { + w.Cond.Wait() + } + w.Cond.L.Unlock() +} From 249b077968f212d3e6b83ed2c4999bcb8e6d9e5d Mon Sep 17 00:00:00 2001 From: Jeff Lindsay Date: Sat, 17 Apr 2021 08:16:57 -0500 Subject: [PATCH 2/2] golang: remove unnecessary stateful contexts, but add context to Open for canceling --- golang/api.go | 22 ++++------------- golang/session/channel.go | 6 ----- golang/session/session.go | 43 ++++++++++------------------------ golang/session/session_test.go | 8 +++---- 4 files changed, 20 insertions(+), 59 deletions(-) diff --git a/golang/api.go b/golang/api.go index 0a098f1..bbb25b9 100644 --- a/golang/api.go +++ b/golang/api.go @@ -1,25 +1,17 @@ package mux -import ( - "context" - "net" -) +import "context" type Session interface { - Context() context.Context Close() error - Open() (Channel, error) + Open(ctx context.Context) (Channel, error) Accept() (Channel, error) - LocalAddr() net.Addr - RemoteAddr() net.Addr Wait() error } // A Channel is an ordered, reliable, flow-controlled, duplex stream // that is multiplexed over a qmux connection. type Channel interface { - Context() context.Context - // Read reads up to len(data) bytes from the channel. Read(data []byte) (int, error) @@ -30,15 +22,9 @@ type Channel interface { // call. Close() error - // CloseWrite signals the end of sending in-band - // data. The other side may still send data + // CloseWrite signals the end of sending data. + // The other side may still send data CloseWrite() error ID() uint32 } - -type Listener interface { - Close() error - Addr() net.Addr - Accept() (Session, error) -} diff --git a/golang/session/channel.go b/golang/session/channel.go index 32dceb7..6ef07b7 100644 --- a/golang/session/channel.go +++ b/golang/session/channel.go @@ -1,7 +1,6 @@ package session import ( - "context" "errors" "fmt" "io" @@ -20,7 +19,6 @@ const ( // channel is an implementation of the Channel interface that works // with the session class. type channel struct { - ctx context.Context // R/O after creation localId, remoteId uint32 @@ -66,10 +64,6 @@ func (ch *channel) ID() uint32 { return ch.localId } -func (ch *channel) Context() context.Context { - return ch.ctx -} - func (ch *channel) CloseWrite() error { ch.sentEOF = true return ch.send(codec.EOFMessage{ diff --git a/golang/session/session.go b/golang/session/session.go index 01d9a7b..24c4fe2 100644 --- a/golang/session/session.go +++ b/golang/session/session.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "net" "sync" mux "github.com/progrium/qmux/golang" @@ -29,7 +28,6 @@ const ( ) type session struct { - ctx context.Context conn io.ReadWriteCloser chanList chanList @@ -44,12 +42,11 @@ type session struct { } // NewSession returns a session that runs over the given connection. -func New(ctx context.Context, rwc io.ReadWriteCloser) mux.Session { +func New(rwc io.ReadWriteCloser) mux.Session { if rwc == nil { return nil } s := &session{ - ctx: ctx, conn: rwc, enc: codec.NewEncoder(rwc), dec: codec.NewDecoder(rwc), @@ -61,29 +58,11 @@ func New(ctx context.Context, rwc io.ReadWriteCloser) mux.Session { return s } -func (s *session) Context() context.Context { - return s.ctx -} - func (s *session) Close() error { s.conn.Close() return nil } -func (s *session) LocalAddr() net.Addr { - if conn, ok := s.conn.(net.Conn); ok { - return conn.LocalAddr() - } - return nil -} - -func (s *session) RemoteAddr() net.Addr { - if conn, ok := s.conn.(net.Conn); ok { - return conn.RemoteAddr() - } - return nil -} - func (s *session) Wait() error { s.errCond.L.Lock() defer s.errCond.L.Unlock() @@ -94,7 +73,6 @@ func (s *session) Wait() error { } func (s *session) Accept() (mux.Channel, error) { - // TODO: context cancel select { case ch := <-s.incomingChannels: return ch, nil @@ -103,7 +81,7 @@ func (s *session) Accept() (mux.Channel, error) { } } -func (s *session) Open() (mux.Channel, error) { +func (s *session) Open(ctx context.Context) (mux.Channel, error) { ch := s.newChannel(channelOutbound) ch.maxIncomingPayload = channelMaxPacket @@ -115,18 +93,22 @@ func (s *session) Open() (mux.Channel, error) { return nil, err } - // TODO: timeout? context cancel? - m := <-ch.msg - if m == nil { - return nil, fmt.Errorf("qmux: channel closed early during open") + var m codec.Message + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case m = <-ch.msg: + if m == nil { + return nil, fmt.Errorf("qmux: channel closed early during open") + } } + switch msg := m.(type) { case *codec.OpenConfirmMessage: return ch, nil - case *codec.OpenFailureMessage: return nil, fmt.Errorf("qmux: channel open failed on remote side") - default: return nil, fmt.Errorf("qmux: unexpected packet in response to channel open: %v", msg) } @@ -134,7 +116,6 @@ func (s *session) Open() (mux.Channel, error) { func (s *session) newChannel(direction channelDirection) *channel { ch := &channel{ - ctx: s.ctx, remoteWin: window{Cond: sync.NewCond(new(sync.Mutex))}, myWindow: channelWindowSize, pending: newBuffer(), diff --git a/golang/session/session_test.go b/golang/session/session_test.go index 2777744..b2445c0 100644 --- a/golang/session/session_test.go +++ b/golang/session/session_test.go @@ -26,9 +26,9 @@ func TestQmux(t *testing.T) { fatal(err, t) defer conn.Close() - sess := New(context.Background(), conn) + sess := New(conn) - ch, err := sess.Open() + ch, err := sess.Open(context.Background()) fatal(err, t) b, err := ioutil.ReadAll(ch) fatal(err, t) @@ -48,7 +48,7 @@ func TestQmux(t *testing.T) { fatal(err, t) defer conn.Close() - sess := New(context.Background(), conn) + sess := New(conn) var ch mux.Channel t.Run("session accept", func(t *testing.T) { @@ -64,7 +64,7 @@ func TestQmux(t *testing.T) { }) t.Run("session open", func(t *testing.T) { - ch, err = sess.Open() + ch, err = sess.Open(context.Background()) fatal(err, t) })