Skip to content

Drain HTTP/3 response after trailers #116319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
Expand Down Expand Up @@ -46,6 +47,9 @@ internal sealed class Http3RequestStream : IHttpStreamHeadersHandler, IAsyncDisp
/// <summary>Any trailing headers.</summary>
private List<(HeaderDescriptor name, string value)>? _trailingHeaders;

/// <summary>Response drain task after receiving trailers.</summary>
private Task? _responseDrainTask;

// When reading response content, keep track of the number of bytes left in the current data frame.
private long _responseDataPayloadRemaining;

Expand Down Expand Up @@ -84,54 +88,87 @@ public Http3RequestStream(HttpRequestMessage request, Http3Connection connection

public void Dispose()
{
if (!_disposed)
ValueTask disposeTask = DisposeAsync();

// DisposeAsync() will fire-and-forget the underlying QuicStream.DisposeAsync() Task in most cases.
// Since QuicStream.Dispose() is implemented in a sync-over-async manner, there is no point maintaining
// a separate synchronous disposal path for the cases when the QuicStream disposal needs to be awaited.
if (!disposeTask.IsCompleted)
{
_disposed = true;
AbortStream();
if (_stream.WritesClosed.IsCompleted)
{
_connection.LogExceptions(_stream.DisposeAsync().AsTask());
}
else
{
_stream.Dispose();
}
DisposeSyncHelper();
disposeTask.AsTask().GetAwaiter().GetResult();
}
}

private void RemoveFromConnectionIfDone()
public async ValueTask DisposeAsync()
{
if (_responseRecvCompleted && _requestSendCompleted)
if (_disposed)
{
_connection.RemoveStream(_stream);
return;
}
}

public async ValueTask DisposeAsync()
{
if (!_disposed)
_disposed = true;
ValueTask? disposeTask = default;
bool writesClosed = _stream.WritesClosed.IsCompleted;

if (_responseDrainTask is not null)
{
#pragma warning disable CA2012 // The ValueTask is only consumed once.
disposeTask = WaitForDrainCompletionAndDisposeAsync();
}
else
{
_disposed = true;
AbortStream();
if (_stream.WritesClosed.IsCompleted)
if (writesClosed)
{
_connection.LogExceptions(_stream.DisposeAsync().AsTask());
disposeTask = _stream.DisposeAsync();
#pragma warning restore CA2012
}
else
}

if (writesClosed)
{
Debug.Assert(disposeTask.HasValue);

// The peer has confirmed receipt of the full request, so there's no need to wait for QuicStream disposal -- it's fine to fire-and-forget the task.
if (!disposeTask.Value.IsCompleted && NetEventSource.Log.IsEnabled())
{
await _stream.DisposeAsync().ConfigureAwait(false);
_connection.LogExceptions(disposeTask.Value.AsTask());
}
DisposeSyncHelper();
}
}
else if (disposeTask.HasValue)
{
await disposeTask.Value.ConfigureAwait(false);
}
else
{
await _stream.DisposeAsync().ConfigureAwait(false);
}

private void DisposeSyncHelper()
{
_connection.RemoveStream(_stream);

_sendBuffer.Dispose();
_recvBuffer.Dispose();

if (_responseDrainTask is null || _responseDrainTask.IsCompleted)
{
// If response drain is in progress it might be still using _recvBuffer -- let WaitForDrainCompletionAndDisposeAsync() dispose it.
_recvBuffer.Dispose();
}

async ValueTask WaitForDrainCompletionAndDisposeAsync()
{
Debug.Assert(_responseDrainTask is not null);
await _responseDrainTask.ConfigureAwait(false);
AbortStream();
await _stream.DisposeAsync().ConfigureAwait(false);
_recvBuffer.Dispose();
}
}

private void RemoveFromConnectionIfDone()
{
if (_responseRecvCompleted && _requestSendCompleted)
{
_connection.RemoveStream(_stream);
}
}

public void GoAway()
Expand Down Expand Up @@ -564,10 +601,8 @@ private async ValueTask DrainContentLength0Frames(CancellationToken cancellation
_trailingHeaders = new List<(HeaderDescriptor name, string value)>();
await ReadHeadersAsync(payloadLength, cancellationToken).ConfigureAwait(false);

// Stop looping after a trailing header.
// There may be extra frames after this one, but they would all be unknown extension
// frames that can be safely ignored. Just stop reading here.
// Note: this does leave us open to a bad server sending us an out of order DATA frame.
// We do not expect more DATA frames after the trailers, but we haven't reached EOS yet.
await CheckForEndOfStreamAsync(cancellationToken).ConfigureAwait(false);
goto case null;
case null:
// Done receiving: copy over trailing headers.
Expand Down Expand Up @@ -1335,6 +1370,70 @@ private void HandleReadResponseContentException(Exception ex, CancellationToken
throw new HttpIOException(HttpRequestError.Unknown, SR.net_http_client_execution_error, new HttpRequestException(SR.net_http_client_execution_error, ex));
}

/// <summary>
/// Check for EOS and start draining the response stream if needed. This method is expected to be called after receiving trailers.
/// </summary>
private async ValueTask CheckForEndOfStreamAsync(CancellationToken cancellationToken)
{
// In most cases, we expect to read an EOS at this point.
_recvBuffer.EnsureAvailableSpace(1);
int bytesRead = await _stream.ReadAsync(_recvBuffer.AvailableMemory, cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
{
return;
}
_recvBuffer.Commit(bytesRead);
_recvBuffer.Discard(bytesRead);

// According to https://datatracker.ietf.org/doc/html/rfc9114#name-http-message-framing the server may send us frames of uknown types after the trailers.
// Start draining the stream without trying to interpret the data. Note: this does leave us open to a bad server sending us out of order frames.
HttpConnectionSettings settings = _connection.Pool.Settings;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could short-circuit this with checking ReadsClosed. Also I'd prefer this to finish synchronously in the most common scenario.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ReadsClosed.IsCompleted is never true at the moment we finish reading the trailers. See my #116319 (comment) above.

TimeSpan drainTime = settings._maxResponseDrainTime;
int remaining = settings._maxResponseDrainSize - bytesRead;
Debug.Assert(remaining >= 0);
if (drainTime == TimeSpan.Zero || remaining <= 0)
{
return;
}

_responseDrainTask = DrainResponseAsync();

async Task DrainResponseAsync()
{
using CancellationTokenSource cts = new CancellationTokenSource(settings._maxResponseDrainTime);

try
{
// If there is more data than MaxResponseDrainSize, we will silently stop draining and let Dispose(Async) abort the reads.
while (remaining > 0)
{
_recvBuffer.EnsureAvailableSpace(1);
Memory<byte> buffer = remaining >= _recvBuffer.AvailableMemory.Length ? _recvBuffer.AvailableMemory : _recvBuffer.AvailableMemory.Slice(0, remaining);
int bytesRead = await _stream.ReadAsync(buffer, cts.Token).ConfigureAwait(false);
if (bytesRead == 0)
{
// Reached EOS.
return;
}
remaining -= bytesRead;
_recvBuffer.Commit(bytesRead);
_recvBuffer.Discard(bytesRead);
}
}
catch (Exception ex)
{
// Eat exceptions and stop draining to unblock QuicStream disposal waiting for response drain.
if (NetEventSource.Log.IsEnabled())
{
string message = ex is OperationCanceledException oce && oce.CancellationToken == cts.Token ? "Response drain timed out." : $"Response drain failed with exception: {ex}";
Trace(message);
}

return;
}
}
}

private async ValueTask<bool> ReadNextDataFrameAsync(HttpResponseMessage response, CancellationToken cancellationToken)
{
if (_responseDataPayloadRemaining == -1)
Expand Down Expand Up @@ -1365,11 +1464,8 @@ private async ValueTask<bool> ReadNextDataFrameAsync(HttpResponseMessage respons
_trailingHeaders = new List<(HeaderDescriptor name, string value)>();
await ReadHeadersAsync(payloadLength, cancellationToken).ConfigureAwait(false);

// There may be more frames after this one, but they would all be unknown extension
// frames that we are allowed to skip. Just close the stream early.

// Note: if a server sends additional HEADERS or DATA frames at this point, it
// would be a connection error -- not draining the stream means we won't catch this.
// We do not expect more DATA frames after the trailers, but we haven't reached EOS yet.
await CheckForEndOfStreamAsync(cancellationToken).ConfigureAwait(false);
goto case null;
case null:
// End of stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,143 @@ protected override async Task AcceptConnectionAndSendResponseAsync(
await stream.SendResponseHeadersAsync(statusCode: null, headers: trailers);
stream.Stream.CompleteWrites();
}

// This is a regression test for https://github.com/dotnet/runtime/issues/60118.
[Theory]
[InlineData(false, false, HttpCompletionOption.ResponseContentRead)]
[InlineData(false, true, HttpCompletionOption.ResponseContentRead)]
[InlineData(false, false, HttpCompletionOption.ResponseHeadersRead)]
[InlineData(true, false, HttpCompletionOption.ResponseContentRead)]
[InlineData(true, true, HttpCompletionOption.ResponseContentRead)]
public async Task GetAsync_TrailersWithoutServerStreamClosure_Success(bool sendBytesAfterTrailers, bool emptyResponse, HttpCompletionOption httpCompletionOption)
{
SemaphoreSlim serverCompleted = new SemaphoreSlim(0);

await LoopbackServerFactory.CreateClientAndServerAsync(async uri =>
{
HttpClientHandler handler = CreateHttpClientHandler();

// Avoid drain timeout if CI is slow.
GetUnderlyingSocketsHttpHandler(handler).ResponseDrainTimeout = TimeSpan.FromSeconds(10);
using HttpClient client = CreateHttpClient(handler);

using (HttpResponseMessage response = await client.GetAsync(uri, httpCompletionOption))
{
if (httpCompletionOption == HttpCompletionOption.ResponseHeadersRead)
{
using Stream stream = await response.Content.ReadAsStreamAsync();
byte[] buffer = new byte[512];
// Consume the stream
while ((await stream.ReadAsync(buffer)) > 0) ;
}

Assert.Equal(TrailingHeaders.Count, response.TrailingHeaders.Count());
}

await serverCompleted.WaitAsync();
},
async server =>
{
try
{
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
_ = await stream.ReadRequestDataAsync();

HttpHeaderData[] headers = emptyResponse ? [new HttpHeaderData("Content-Length", "0")] : null;

await stream.SendResponseHeadersAsync(statusCode: HttpStatusCode.OK, headers);
if (!emptyResponse)
{
await stream.SendResponseBodyAsync(new byte[16384], isFinal: false);
}

await stream.SendResponseHeadersAsync(statusCode: null, headers: TrailingHeaders);
if (sendBytesAfterTrailers)
{
// https://datatracker.ietf.org/doc/html/rfc9114#section-7.2.8
// Frame types of the format 0x1f * N + 0x21 for non - negative integer values of N are reserved to exercise the requirement that unknown types be ignored.
await stream.SendFrameAsync(0x1f * 7 + 0x21, new byte[16384]);
}

// Small delay to make sure we do test if the client is waiting for EOS.
await Task.Delay(15);

await stream.DisposeAsync();
await stream.Stream.WritesClosed;
}
finally
{
serverCompleted.Release();
}
}).WaitAsync(TimeSpan.FromSeconds(30));
}

[Theory]
[InlineData(0)] // MaxResponseDrainSize = 0
[InlineData(1)] // ResponseDrainTimeout = TimeSpan.Zero
public async Task GetAsync_TrailersWithoutServerStreamClosure_ResponseDrainDisabled_ShutsDownClientReads(int drainDisableMode)
{
SemaphoreSlim allDataSent = new SemaphoreSlim(0);
SemaphoreSlim responseConsumed = new SemaphoreSlim(0);
SemaphoreSlim serverCompleted = new SemaphoreSlim(0);

await LoopbackServerFactory.CreateClientAndServerAsync(async uri =>
{
HttpClientHandler handler = CreateHttpClientHandler();

if (drainDisableMode == 0)
{
GetUnderlyingSocketsHttpHandler(handler).MaxResponseDrainSize = 0;
}
else
{
GetUnderlyingSocketsHttpHandler(handler).ResponseDrainTimeout = TimeSpan.Zero;
}

using HttpClient client = CreateHttpClient(handler);
using (HttpResponseMessage response = await client.GetAsync(uri, HttpCompletionOption.ResponseHeadersRead))
{
using Stream stream = await response.Content.ReadAsStreamAsync();
byte[] buffer = new byte[512];
// Consume the stream
while ((await stream.ReadAsync(buffer)) > 0) ;
Assert.Equal(TrailingHeaders.Count, response.TrailingHeaders.Count());

// Defer stream disposal until server finishes sending.
await allDataSent.WaitAsync();
}

responseConsumed.Release();
await serverCompleted.WaitAsync();
},
async server =>
{
try
{
await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
await using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
_ = await stream.ReadRequestDataAsync();
await stream.SendResponseHeadersAsync(statusCode: HttpStatusCode.OK);
await stream.SendResponseBodyAsync(new byte[4096], isFinal: false);
await stream.SendResponseHeadersAsync(statusCode: null, headers: TrailingHeaders);

// https://datatracker.ietf.org/doc/html/rfc9114#section-7.2.8
// Frame types of the format 0x1f * N + 0x21 for non - negative integer values of N are reserved to exercise the requirement that unknown types be ignored.
await stream.SendFrameAsync(0x1f * 7 + 0x21, new byte[1024]);

allDataSent.Release();
await responseConsumed.WaitAsync();

await stream.DisposeAsync();
await Assert.ThrowsAsync<QuicException>(() => stream.Stream.WritesClosed);
}
finally
{
serverCompleted.Release();
}
}).WaitAsync(TimeSpan.FromSeconds(30));
}
}

public sealed class SocketsHttpHandler_HttpClientHandlerTest : HttpClientHandlerTest
Expand Down
Loading