diff --git a/internal/protocol/stream_test.go b/internal/protocol/stream_test.go index a8d4654cbc4..4209f8a0c7f 100644 --- a/internal/protocol/stream_test.go +++ b/internal/protocol/stream_test.go @@ -56,5 +56,15 @@ var _ = Describe("Stream ID", func() { Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveClient)).To(Equal(StreamID(398))) Expect(StreamNum(100).StreamID(StreamTypeUni, PerspectiveServer)).To(Equal(StreamID(399))) }) + + It("has the right value for MaxStreamCount", func() { + const maxStreamID = StreamID(1<<62 - 1) + for _, dir := range []StreamType{StreamTypeUni, StreamTypeBidi} { + for _, pers := range []Perspective{PerspectiveClient, PerspectiveServer} { + Expect(MaxStreamCount.StreamID(dir, pers)).To(BeNumerically("<=", maxStreamID)) + Expect((MaxStreamCount + 1).StreamID(dir, pers)).To(BeNumerically(">", maxStreamID)) + } + } + }) }) }) diff --git a/internal/wire/max_streams_frame.go b/internal/wire/max_streams_frame.go index 63e506c4902..8157e77ce72 100644 --- a/internal/wire/max_streams_frame.go +++ b/internal/wire/max_streams_frame.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -31,6 +32,9 @@ func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStream return nil, err } f.MaxStreamNum = protocol.StreamNum(streamID) + if f.MaxStreamNum > protocol.MaxStreamCount { + return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum) + } return f, nil } diff --git a/internal/wire/max_streams_frame_test.go b/internal/wire/max_streams_frame_test.go index 35eb6e903f1..3f75dfbfa62 100644 --- a/internal/wire/max_streams_frame_test.go +++ b/internal/wire/max_streams_frame_test.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -43,6 +44,33 @@ var _ = Describe("MAX_STREAMS frame", func() { Expect(err).To(HaveOccurred()) } }) + + for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { + streamType := t + + It("accepts a frame containing the maximum stream count", func() { + f := &MaxStreamsFrame{ + Type: streamType, + MaxStreamNum: protocol.MaxStreamCount, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + frame, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("errors when receiving a too large stream count", func() { + f := &MaxStreamsFrame{ + Type: streamType, + MaxStreamNum: protocol.MaxStreamCount + 1, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + _, err := parseMaxStreamsFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) + }) + } }) Context("writing", func() { diff --git a/internal/wire/streams_blocked_frame.go b/internal/wire/streams_blocked_frame.go index c46e87b479f..42b455bfe30 100644 --- a/internal/wire/streams_blocked_frame.go +++ b/internal/wire/streams_blocked_frame.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "fmt" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -31,7 +32,9 @@ func parseStreamsBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*Strea return nil, err } f.StreamLimit = protocol.StreamNum(streamLimit) - + if f.StreamLimit > protocol.MaxStreamCount { + return nil, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit) + } return f, nil } diff --git a/internal/wire/streams_blocked_frame_test.go b/internal/wire/streams_blocked_frame_test.go index 9c942fd6989..97820a2f178 100644 --- a/internal/wire/streams_blocked_frame_test.go +++ b/internal/wire/streams_blocked_frame_test.go @@ -2,6 +2,7 @@ package wire import ( "bytes" + "fmt" "io" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -44,6 +45,33 @@ var _ = Describe("STREAMS_BLOCKED frame", func() { Expect(err).To(MatchError(io.EOF)) } }) + + for _, t := range []protocol.StreamType{protocol.StreamTypeUni, protocol.StreamTypeBidi} { + streamType := t + + It("accepts a frame containing the maximum stream count", func() { + f := &StreamsBlockedFrame{ + Type: streamType, + StreamLimit: protocol.MaxStreamCount, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + frame, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + Expect(frame).To(Equal(f)) + }) + + It("errors when receiving a too large stream count", func() { + f := &StreamsBlockedFrame{ + Type: streamType, + StreamLimit: protocol.MaxStreamCount + 1, + } + b := &bytes.Buffer{} + Expect(f.Write(b, protocol.VersionWhatever)).To(Succeed()) + _, err := parseStreamsBlockedFrame(bytes.NewReader(b.Bytes()), protocol.VersionWhatever) + Expect(err).To(MatchError(fmt.Sprintf("%d exceeds the maximum stream count", protocol.MaxStreamCount+1))) + }) + } }) Context("writing", func() { diff --git a/streams_map.go b/streams_map.go index 6559adae546..99d427c884a 100644 --- a/streams_map.go +++ b/streams_map.go @@ -214,9 +214,6 @@ func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, err } func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) error { - if f.MaxStreamNum > protocol.MaxStreamCount { - return qerr.StreamLimitError - } switch f.Type { case protocol.StreamTypeUni: m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum) diff --git a/streams_map_test.go b/streams_map_test.go index 9a468200f94..d37b2658f78 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -401,13 +401,6 @@ var _ = Describe("Streams Map", func() { _, err = m.OpenUniStream() expectTooManyStreamsError(err) }) - - It("rejects MAX_STREAMS frames with too large values", func() { - Expect(m.HandleMaxStreamsFrame(&wire.MaxStreamsFrame{ - Type: protocol.StreamTypeBidi, - MaxStreamNum: protocol.MaxStreamCount + 1, - })).To(MatchError(qerr.StreamLimitError)) - }) }) Context("sending MAX_STREAMS frames", func() {