Skip to content

Commit

Permalink
delay completion of the send stream until the reset error was deliver…
Browse files Browse the repository at this point in the history
…ed (#4445)

* delay completion of the send stream until the reset error was delivered

* mark the send stream completed on Close after receiving a STOP_SENDING

* fix handling of STOP_SENDING after Close
  • Loading branch information
marten-seemann committed Apr 26, 2024
1 parent 12aa638 commit bff131e
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 23 deletions.
78 changes: 59 additions & 19 deletions send_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ type sendStream struct {

finishedWriting bool // set once Close() is called
finSent bool // set when a STREAM_FRAME with FIN bit has been sent
completed bool // set when this stream has been reported to the streamSender as completed
// Set when the application knows about the cancellation.
// This can happen because the application called CancelWrite,
// or because Write returned the error (for remote cancellations).
cancellationFlagged bool
completed bool // set when this stream has been reported to the streamSender as completed

dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
nextFrame *wire.StreamFrame
Expand Down Expand Up @@ -87,23 +91,32 @@ func (s *sendStream) Write(p []byte) (int, error) {
s.writeOnce <- struct{}{}
defer func() { <-s.writeOnce }()

isNewlyCompleted, n, err := s.write(p)
if isNewlyCompleted {
s.sender.onStreamCompleted(s.streamID)
}
return n, err
}

func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error) {
s.mutex.Lock()
defer s.mutex.Unlock()

if s.finishedWriting {
return 0, fmt.Errorf("write on closed stream %d", s.streamID)
return false, 0, fmt.Errorf("write on closed stream %d", s.streamID)
}
if s.cancelWriteErr != nil {
return 0, s.cancelWriteErr
s.cancellationFlagged = true
return s.isNewlyCompleted(), 0, s.cancelWriteErr
}
if s.closeForShutdownErr != nil {
return 0, s.closeForShutdownErr
return false, 0, s.closeForShutdownErr
}
if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
return 0, errDeadline
return false, 0, errDeadline
}
if len(p) == 0 {
return 0, nil
return false, 0, nil
}

s.dataForWriting = p
Expand Down Expand Up @@ -144,7 +157,7 @@ func (s *sendStream) Write(p []byte) (int, error) {
if !deadline.IsZero() {
if !time.Now().Before(deadline) {
s.dataForWriting = nil
return bytesWritten, errDeadline
return false, bytesWritten, errDeadline
}
if deadlineTimer == nil {
deadlineTimer = utils.NewTimer()
Expand Down Expand Up @@ -179,14 +192,15 @@ func (s *sendStream) Write(p []byte) (int, error) {
}

if bytesWritten == len(p) {
return bytesWritten, nil
return false, bytesWritten, nil
}
if s.closeForShutdownErr != nil {
return bytesWritten, s.closeForShutdownErr
return false, bytesWritten, s.closeForShutdownErr
} else if s.cancelWriteErr != nil {
return bytesWritten, s.cancelWriteErr
s.cancellationFlagged = true
return s.isNewlyCompleted(), bytesWritten, s.cancelWriteErr
}
return bytesWritten, nil
return false, bytesWritten, nil
}

func (s *sendStream) canBufferStreamFrame() bool {
Expand Down Expand Up @@ -349,8 +363,24 @@ func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.By
}

func (s *sendStream) isNewlyCompleted() bool {
completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
if completed && !s.completed {
if s.completed {
return false
}
// We need to keep the stream around until all frames have been sent and acknowledged.
if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 {
return false
}
// The stream is completed if we sent the FIN.
if s.finSent {
s.completed = true
return true
}
// The stream is also completed if:
// 1. the application called CancelWrite, or
// 2. we received a STOP_SENDING, and
// * the application consumed the error via Write, or
// * the application called CLsoe
if s.cancelWriteErr != nil && (s.cancellationFlagged || s.finishedWriting) {
s.completed = true
return true
}
Expand All @@ -363,25 +393,35 @@ func (s *sendStream) Close() error {
s.mutex.Unlock()
return nil
}
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return fmt.Errorf("close called for canceled stream %d", s.streamID)
}
s.ctxCancel(nil)
s.finishedWriting = true
cancelWriteErr := s.cancelWriteErr
if cancelWriteErr != nil {
s.cancellationFlagged = true
}
completed := s.isNewlyCompleted()
s.mutex.Unlock()

if completed {
s.sender.onStreamCompleted(s.streamID)
}
if cancelWriteErr != nil {
return fmt.Errorf("close called for canceled stream %d", s.streamID)
}
s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex

s.ctxCancel(nil)
return nil
}

func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
s.cancelWriteImpl(errorCode, false)
}

// must be called after locking the mutex
func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
s.mutex.Lock()
if !remote {
s.cancellationFlagged = true
}
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
Expand Down
46 changes: 42 additions & 4 deletions send_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -908,8 +908,8 @@ var _ = Describe("Send Stream", func() {
StreamID: streamID,
ErrorCode: 101,
})
mockSender.EXPECT().onStreamCompleted(gomock.Any())

// Don't EXPECT calls to onStreamCompleted.
// The application needs to learn about the cancellation first.
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 101,
Expand All @@ -919,10 +919,10 @@ var _ = Describe("Send Stream", func() {
It("unblocks Write", func() {
mockSender.EXPECT().onHasStreamData(streamID)
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(gomock.Any())
done := make(chan struct{})
go func() {
defer GinkgoRecover()
mockSender.EXPECT().onStreamCompleted(gomock.Any())
_, err := str.Write(getData(5000))
Expect(err).To(Equal(&StreamError{
StreamID: streamID,
Expand All @@ -941,18 +941,56 @@ var _ = Describe("Send Stream", func() {

It("doesn't allow further calls to Write", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
mockSender.EXPECT().onStreamCompleted(gomock.Any())
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 123,
})
mockSender.EXPECT().onStreamCompleted(gomock.Any())
_, err := str.Write([]byte("foobar"))
Expect(err).To(Equal(&StreamError{
StreamID: streamID,
ErrorCode: 123,
Remote: true,
}))
})

It("handles Close after STOP_SENDING", func() {
mockSender.EXPECT().queueControlFrame(gomock.Any())
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 123,
})
mockSender.EXPECT().onStreamCompleted(gomock.Any())
str.Close()
})

It("handles STOP_SENDING after sending the FIN", func() {
mockSender.EXPECT().onHasStreamData(gomock.Any())
str.Close()
_, ok, _ := str.popStreamFrame(protocol.MaxByteCount, protocol.Version1)
Expect(ok).To(BeTrue())
gomock.InOrder(
mockSender.EXPECT().queueControlFrame(gomock.Any()),
mockSender.EXPECT().onStreamCompleted(gomock.Any()),
)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 123,
})
})

It("handles STOP_SENDING after Close, but before sending the FIN", func() {
mockSender.EXPECT().onHasStreamData(gomock.Any())
str.Close()
gomock.InOrder(
mockSender.EXPECT().queueControlFrame(gomock.Any()),
mockSender.EXPECT().onStreamCompleted(gomock.Any()),
)
str.handleStopSendingFrame(&wire.StopSendingFrame{
StreamID: streamID,
ErrorCode: 123,
})
})
})
})

Expand Down

0 comments on commit bff131e

Please sign in to comment.