Skip to content

Commit

Permalink
delay completion of the receive stream until the reset error was read (
Browse files Browse the repository at this point in the history
…#4460)

* delay completion of the receive stream until the reset error was read

* fix handling of CancelRead after receiving a RESET_STREAM
  • Loading branch information
marten-seemann committed Apr 26, 2024
1 parent bff131e commit 4b87539
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 74 deletions.
136 changes: 81 additions & 55 deletions receive_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,14 @@ type receiveStream struct {
readPosInFrame int
currentFrameIsLast bool // is the currentFrame the last frame on this stream

finRead bool // set once we read a frame with a Fin
// Set once we read the io.EOF or the cancellation error.
// Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet.
errorRead bool
completed bool // set once we've called streamSender.onStreamCompleted
cancelledRemotely bool
cancelledLocally bool
cancelErr *StreamError
closeForShutdownErr error
cancelReadErr error
resetRemotelyErr *StreamError

readChan chan struct{}
readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
Expand Down Expand Up @@ -83,7 +87,8 @@ func (s *receiveStream) Read(p []byte) (int, error) {
defer func() { <-s.readOnce }()

s.mutex.Lock()
completed, n, err := s.readImpl(p)
n, err := s.readImpl(p)
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
Expand All @@ -92,18 +97,38 @@ func (s *receiveStream) Read(p []byte) (int, error) {
return n, err
}

func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, error) {
if s.finRead {
return false, 0, io.EOF
func (s *receiveStream) isNewlyCompleted() bool {
if s.completed {
return false
}
// We need to know the final offset (either via FIN or RESET_STREAM) for flow control accounting.
if s.finalOffset == protocol.MaxByteCount {
return false
}
// We're done with the stream if it was cancelled locally...
if s.cancelledLocally {
s.completed = true
return true
}
if s.cancelReadErr != nil {
return false, 0, s.cancelReadErr
// ... or if the error (either io.EOF or the reset error) was read
if s.errorRead {
s.completed = true
return true
}
return false
}

func (s *receiveStream) readImpl(p []byte) (int, error) {
if s.currentFrameIsLast && s.currentFrame == nil {
s.errorRead = true
return 0, io.EOF
}
if s.resetRemotelyErr != nil {
return false, 0, s.resetRemotelyErr
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
}
if s.closeForShutdownErr != nil {
return false, 0, s.closeForShutdownErr
return 0, s.closeForShutdownErr
}

var bytesRead int
Expand All @@ -113,25 +138,23 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
s.dequeueNextFrame()
}
if s.currentFrame == nil && bytesRead > 0 {
return false, bytesRead, s.closeForShutdownErr
return bytesRead, s.closeForShutdownErr
}

for {
// Stop waiting on errors
if s.closeForShutdownErr != nil {
return false, bytesRead, s.closeForShutdownErr
return bytesRead, s.closeForShutdownErr
}
if s.cancelReadErr != nil {
return false, bytesRead, s.cancelReadErr
}
if s.resetRemotelyErr != nil {
return false, bytesRead, s.resetRemotelyErr
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
}

deadline := s.deadline
if !deadline.IsZero() {
if !time.Now().Before(deadline) {
return false, bytesRead, errDeadline
return bytesRead, errDeadline
}
if deadlineTimer == nil {
deadlineTimer = utils.NewTimer()
Expand Down Expand Up @@ -161,10 +184,10 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
}

if bytesRead > len(p) {
return false, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
}
if s.readPosInFrame > len(s.currentFrame) {
return false, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
}

m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:])
Expand All @@ -173,20 +196,20 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err

// when a RESET_STREAM was received, the flow controller was already
// informed about the final byteOffset for this stream
if s.resetRemotelyErr == nil {
if !s.cancelledRemotely {
s.flowController.AddBytesRead(protocol.ByteCount(m))
}

if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
s.finRead = true
s.currentFrame = nil
if s.currentFrameDone != nil {
s.currentFrameDone()
}
return true, bytesRead, io.EOF
s.errorRead = true
return bytesRead, io.EOF
}
}
return false, bytesRead, nil
return bytesRead, nil
}

func (s *receiveStream) dequeueNextFrame() {
Expand All @@ -202,7 +225,8 @@ func (s *receiveStream) dequeueNextFrame() {

func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
s.mutex.Lock()
completed := s.cancelReadImpl(errorCode)
s.cancelReadImpl(errorCode)
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
Expand All @@ -211,23 +235,26 @@ func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
}
}

func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* completed */ {
if s.finRead || s.cancelReadErr != nil || s.resetRemotelyErr != nil {
return false
func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) {
if s.cancelledLocally { // duplicate call to CancelRead
return
}
s.cancelReadErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.cancelledLocally = true
if s.errorRead || s.cancelledRemotely {
return
}
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.signalRead()
s.sender.queueControlFrame(&wire.StopSendingFrame{
StreamID: s.streamID,
ErrorCode: errorCode,
})
// We're done with this stream if the final offset was already received.
return s.finalOffset != protocol.MaxByteCount
}

func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
s.mutex.Lock()
completed, err := s.handleStreamFrameImpl(frame)
err := s.handleStreamFrameImpl(frame)
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
Expand All @@ -237,59 +264,58 @@ func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
return err
}

func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /* completed */, error) {
func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) error {
maxOffset := frame.Offset + frame.DataLen()
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin); err != nil {
return false, err
return err
}
var newlyRcvdFinalOffset bool
if frame.Fin {
newlyRcvdFinalOffset = s.finalOffset == protocol.MaxByteCount
s.finalOffset = maxOffset
}
if s.cancelReadErr != nil {
return newlyRcvdFinalOffset, nil
if s.cancelledLocally {
return nil
}
if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil {
return false, err
return err
}
s.signalRead()
return false, nil
return nil
}

func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error {
s.mutex.Lock()
completed, err := s.handleResetStreamFrameImpl(frame)
err := s.handleResetStreamFrameImpl(frame)
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
s.flowController.Abandon()
s.sender.onStreamCompleted(s.streamID)
}
return err
}

func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) (bool /*completed */, error) {
func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) error {
if s.closeForShutdownErr != nil {
return false, nil
return nil
}
if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil {
return false, err
return err
}
newlyRcvdFinalOffset := s.finalOffset == protocol.MaxByteCount
s.finalOffset = frame.FinalSize

// ignore duplicate RESET_STREAM frames for this stream (after checking their final offset)
if s.resetRemotelyErr != nil {
return false, nil
if s.cancelledRemotely {
return nil
}
s.resetRemotelyErr = &StreamError{
StreamID: s.streamID,
ErrorCode: frame.ErrorCode,
Remote: true,
s.flowController.Abandon()
// don't save the error if the RESET_STREAM frames was received after CancelRead was called
if s.cancelledLocally {
return nil
}
s.cancelledRemotely = true
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: frame.ErrorCode, Remote: true}
s.signalRead()
return newlyRcvdFinalOffset, nil
return nil
}

func (s *receiveStream) SetReadDeadline(t time.Time) error {
Expand Down
65 changes: 46 additions & 19 deletions receive_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ var _ = Describe("Receive Stream", func() {

It("returns an error when Read is called after the deadline", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), false).AnyTimes()
f := &wire.StreamFrame{Data: []byte("foobar")}
err := str.handleStreamFrame(f)
Expect(err).ToNot(HaveOccurred())
Expect(str.handleStreamFrame(&wire.StreamFrame{Data: []byte("foobar")})).To(Succeed())
str.SetReadDeadline(time.Now().Add(-time.Second))
b := make([]byte, 6)
n, err := strWithTimeout.Read(b)
Expand Down Expand Up @@ -534,34 +532,46 @@ var _ = Describe("Receive Stream", func() {
Fin: true,
})).To(Succeed())
mockSender.EXPECT().onStreamCompleted(streamID)
_, err := strWithTimeout.Read(make([]byte, 100))
n, err := strWithTimeout.Read(make([]byte, 100))
Expect(err).To(MatchError(io.EOF))
Expect(n).To(Equal(6))
str.CancelRead(1234)
})

It("doesn't send a STOP_SENDING frame, if the stream was already reset", func() {
gomock.InOrder(
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true),
mockFC.EXPECT().Abandon(),
)
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
mockFC.EXPECT().Abandon().MinTimes(1)
Expect(str.handleResetStreamFrame(&wire.ResetStreamFrame{
ErrorCode: 1337,
StreamID: streamID,
FinalSize: 42,
})).To(Succeed())
mockSender.EXPECT().onStreamCompleted(gomock.Any())
str.CancelRead(1234)
// check that the error indicates a remote reset
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
var streamErr *StreamError
Expect(errors.As(err, &streamErr)).To(BeTrue())
Expect(streamErr.ErrorCode).To(BeEquivalentTo(1337))
Expect(streamErr.Remote).To(BeTrue())
})

It("sends a STOP_SENDING and completes the stream after receiving the final offset", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(1000), true)
It("sends a STOP_SENDING after receiving the final offset", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), true)
Expect(str.handleStreamFrame(&wire.StreamFrame{
Offset: 1000,
Fin: true,
Data: []byte("foobar"),
Fin: true,
})).To(Succeed())
mockFC.EXPECT().Abandon()
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(streamID)
str.CancelRead(1234)
// read the error
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
})

It("completes the stream when receiving the Fin after the stream was canceled", func() {
Expand Down Expand Up @@ -649,32 +659,49 @@ var _ = Describe("Receive Stream", func() {
})

It("ignores duplicate RESET_STREAM frames", func() {
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().Abandon()
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
mockFC.EXPECT().Abandon()
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
})

It("doesn't call onStreamCompleted again when the final offset was already received via Fin", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.CancelRead(1234)
mockSender.EXPECT().onStreamCompleted(streamID)
mockFC.EXPECT().Abandon()
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true).Times(2)
Expect(str.handleStreamFrame(&wire.StreamFrame{
StreamID: streamID,
Offset: rst.FinalSize,
Fin: true,
})).To(Succeed())
mockFC.EXPECT().Abandon().MinTimes(1)
mockSender.EXPECT().onStreamCompleted(streamID)
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
// now read the error
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
})

It("doesn't do anything when it was closed for shutdown", func() {
str.closeForShutdown(errors.New("shutdown"))
err := str.handleResetStreamFrame(rst)
Expect(err).ToNot(HaveOccurred())
})

It("handles RESET_STREAM after CancelRead", func() {
mockFC.EXPECT().Abandon()
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.CancelRead(1234)
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(42), true)
mockSender.EXPECT().onStreamCompleted(streamID)
Expect(str.handleResetStreamFrame(rst)).To(Succeed())
// check that the error indicates a local reset
n, err := str.Read([]byte{0})
Expect(err).To(HaveOccurred())
Expect(n).To(BeZero())
var streamErr *StreamError
Expect(errors.As(err, &streamErr)).To(BeTrue())
Expect(streamErr.Remote).To(BeFalse())
})
})
})

Expand Down

0 comments on commit 4b87539

Please sign in to comment.