Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CacheControl and refactoring response writers #2053

Merged
merged 1 commit into from
Apr 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
53 changes: 49 additions & 4 deletions src/RestSharp/Extensions/HttpResponseExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,60 @@
// limitations under the License.
//

using System.Text;

namespace RestSharp.Extensions;

public static class HttpResponseExtensions {
internal static Exception? MaybeException(this HttpResponseMessage httpResponse)
static class HttpResponseExtensions {
public static Exception? MaybeException(this HttpResponseMessage httpResponse)
=> httpResponse.IsSuccessStatusCode
? null
#if NETSTANDARD || NETFRAMEWORK
#if NET
: new HttpRequestException($"Request failed with status code {httpResponse.StatusCode}", null, httpResponse.StatusCode);
#else
: new HttpRequestException($"Request failed with status code {httpResponse.StatusCode}");
#endif

public static string GetResponseString(this HttpResponseMessage response, byte[] bytes, Encoding clientEncoding) {
var encodingString = response.Content.Headers.ContentType?.CharSet;
var encoding = encodingString != null ? TryGetEncoding(encodingString) : clientEncoding;

using var reader = new StreamReader(new MemoryStream(bytes), encoding);
return reader.ReadToEnd();

Encoding TryGetEncoding(string es) {
try {
return Encoding.GetEncoding(es);
}
catch {
return Encoding.Default;
}
}
}

public static Task<Stream?> ReadResponseStream(
this HttpResponseMessage httpResponse,
Func<Stream, Stream?>? writer,
CancellationToken cancellationToken = default
) {
var readTask = writer == null ? ReadResponse() : ReadAndConvertResponse();
return readTask;

Task<Stream?> ReadResponse() {
#if NET
return httpResponse.Content.ReadAsStreamAsync(cancellationToken)!;
# else
return httpResponse.Content.ReadAsStreamAsync();
#endif
}

async Task<Stream?> ReadAndConvertResponse() {
#if NET
await using var original = await ReadResponse().ConfigureAwait(false);
#else
: new HttpRequestException($"Request failed with status code {httpResponse.StatusCode}", null, httpResponse.StatusCode);
using var original = await ReadResponse().ConfigureAwait(false);
#endif
return writer!(original!);
}
}
}
6 changes: 3 additions & 3 deletions src/RestSharp/Extensions/StreamExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ static class StreamExtensions {
using var ms = new MemoryStream();

int read;
#if NETSTANDARD || NETFRAMEWORK
while ((read = await input.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false)) > 0)
#else
#if NET
while ((read = await input.ReadAsync(buffer, cancellationToken).ConfigureAwait(false)) > 0)
#else
while ((read = await input.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false)) > 0)
#endif
ms.Write(buffer, 0, read);

Expand Down
2 changes: 1 addition & 1 deletion src/RestSharp/Properties/IsExternalInit.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#if NETSTANDARD || NETFRAMEWORK
#if !NET
using System.ComponentModel;

// ReSharper disable once CheckNamespace
Expand Down
14 changes: 10 additions & 4 deletions src/RestSharp/Request/RestRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

using System.Net;
using System.Net.Http.Headers;
using RestSharp.Authenticators;
using RestSharp.Extensions;

Expand All @@ -25,8 +26,8 @@ namespace RestSharp;
/// Container for data used to make requests
/// </summary>
public class RestRequest {
readonly Func<HttpResponseMessage, RestRequest, RestResponse>? _advancedResponseHandler;
readonly Func<Stream, Stream?>? _responseWriter;
Func<HttpResponseMessage, RestRequest, RestResponse>? _advancedResponseHandler;
Func<Stream, Stream?>? _responseWriter;

/// <summary>
/// Default constructor
Expand Down Expand Up @@ -186,12 +187,17 @@ public RestRequest(Uri resource, Method method = Method.Get)
/// </summary>
public HttpCompletionOption CompletionOption { get; set; } = HttpCompletionOption.ResponseContentRead;

/// <summary>
/// Cache policy to be used for requests using <seealso cref="CacheControlHeaderValue"/>
/// </summary>
public CacheControlHeaderValue? CachePolicy { get; set; }

/// <summary>
/// Set this to write response to Stream rather than reading into memory.
/// </summary>
public Func<Stream, Stream?>? ResponseWriter {
get => _responseWriter;
init {
set {
if (AdvancedResponseWriter != null)
throw new ArgumentException(
"AdvancedResponseWriter is not null. Only one response writer can be used."
Expand All @@ -206,7 +212,7 @@ public RestRequest(Uri resource, Method method = Method.Get)
/// </summary>
public Func<HttpResponseMessage, RestRequest, RestResponse>? AdvancedResponseWriter {
get => _advancedResponseHandler;
init {
set {
if (ResponseWriter != null) throw new ArgumentException("ResponseWriter is not null. Only one response writer can be used.");

_advancedResponseHandler = value;
Expand Down
45 changes: 0 additions & 45 deletions src/RestSharp/Response/ResponseHandling.cs

This file was deleted.

20 changes: 4 additions & 16 deletions src/RestSharp/Response/RestResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,13 @@ CancellationToken cancellationToken
return request.AdvancedResponseWriter?.Invoke(httpResponse, request) ?? await GetDefaultResponse().ConfigureAwait(false);

async Task<RestResponse> GetDefaultResponse() {
var readTask = request.ResponseWriter == null ? ReadResponse() : ReadAndConvertResponse();
#if NETSTANDARD || NETFRAMEWORK
using var stream = await readTask.ConfigureAwait(false);
#if NET
await using var stream = await httpResponse.ReadResponseStream(request.ResponseWriter, cancellationToken).ConfigureAwait(false);
#else
await using var stream = await readTask.ConfigureAwait(false);
using var stream = await httpResponse.ReadResponseStream(request.ResponseWriter, cancellationToken).ConfigureAwait(false);
#endif

var bytes = request.ResponseWriter != null || stream == null ? null : await stream.ReadAsBytes(cancellationToken).ConfigureAwait(false);
var bytes = stream == null ? null : await stream.ReadAsBytes(cancellationToken).ConfigureAwait(false);
var content = bytes == null ? null : httpResponse.GetResponseString(bytes, encoding);

return new RestResponse(request) {
Expand All @@ -101,17 +100,6 @@ CancellationToken cancellationToken
Cookies = cookieCollection,
RootElement = request.RootElement
};

Task<Stream?> ReadResponse() => httpResponse.ReadResponse(cancellationToken);

async Task<Stream?> ReadAndConvertResponse() {
#if NETSTANDARD || NETFRAMEWORK
using var original = await ReadResponse().ConfigureAwait(false);
#else
await using var original = await ReadResponse().ConfigureAwait(false);
#endif
return request.ResponseWriter!(original!);
}
}
}

Expand Down
14 changes: 3 additions & 11 deletions src/RestSharp/RestClient.Async.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

using System.Net;
using System.Net.Http.Headers;
using RestSharp.Extensions;

namespace RestSharp;
Expand Down Expand Up @@ -52,16 +53,7 @@ public partial class RestClient {

if (response.ResponseMessage == null) return null;

if (request.ResponseWriter != null) {
#if NETSTANDARD || NETFRAMEWORK
using var stream = await response.ResponseMessage.ReadResponse(cancellationToken).ConfigureAwait(false);
#else
await using var stream = await response.ResponseMessage.ReadResponse(cancellationToken).ConfigureAwait(false);
#endif
return request.ResponseWriter(stream!);
}

return await response.ResponseMessage.ReadResponse(cancellationToken).ConfigureAwait(false);
return await response.ResponseMessage.ReadResponseStream(request.ResponseWriter, cancellationToken).ConfigureAwait(false);
}

static RestResponse GetErrorResponse(RestRequest request, Exception exception, CancellationToken timeoutToken) {
Expand Down Expand Up @@ -95,7 +87,7 @@ public partial class RestClient {
var url = this.BuildUri(request);
var message = new HttpRequestMessage(httpMethod, url) { Content = requestContent.BuildContent() };
message.Headers.Host = Options.BaseHost;
message.Headers.CacheControl = Options.CachePolicy;
message.Headers.CacheControl = request.CachePolicy ?? Options.CachePolicy;

using var timeoutCts = new CancellationTokenSource(request.Timeout > 0 ? request.Timeout : int.MaxValue);
using var cts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken);
Expand Down
6 changes: 3 additions & 3 deletions src/RestSharp/RestClient.Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,10 @@ public static RestResponse Post(this IRestClient client, RestRequest request)
/// <returns>The downloaded file.</returns>
[PublicAPI]
public static async Task<byte[]?> DownloadDataAsync(this IRestClient client, RestRequest request, CancellationToken cancellationToken = default) {
#if NETSTANDARD || NETFRAMEWORK
using var stream = await client.DownloadStreamAsync(request, cancellationToken).ConfigureAwait(false);
#else
#if NET
await using var stream = await client.DownloadStreamAsync(request, cancellationToken).ConfigureAwait(false);
#else
using var stream = await client.DownloadStreamAsync(request, cancellationToken).ConfigureAwait(false);
#endif
return stream == null ? null : await stream.ReadAsBytes(cancellationToken).ConfigureAwait(false);
}
Expand Down
31 changes: 17 additions & 14 deletions test/RestSharp.Tests.Integrated/DownloadFileTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ public sealed class DownloadFileTests : IDisposable {
public async Task AdvancedResponseWriter_without_ResponseWriter_reads_stream() {
var tag = string.Empty;

var rr = new RestRequest("Assets/Koala.jpg") {
AdvancedResponseWriter = (response, request) => {
var buf = new byte[16];
response.Content.ReadAsStream().Read(buf, 0, buf.Length);
tag = Encoding.ASCII.GetString(buf, 6, 4);
return new RestResponse(request);
}
// ReSharper disable once UseObjectOrCollectionInitializer
var rr = new RestRequest("Assets/Koala.jpg");

rr.AdvancedResponseWriter = (response, request) => {
var buf = new byte[16];
// ReSharper disable once MustUseReturnValue
response.Content.ReadAsStream().Read(buf, 0, buf.Length);
tag = Encoding.ASCII.GetString(buf, 6, 4);
return new RestResponse(request);
};

await _client.ExecuteAsync(rr);
Expand All @@ -50,7 +52,7 @@ public sealed class DownloadFileTests : IDisposable {
[Fact]
public async Task Handles_File_Download_Failure() {
var request = new RestRequest("Assets/Koala1.jpg");
var task = () => _client.DownloadDataAsync(request);
var task = () => _client.DownloadDataAsync(request);
await task.Should().ThrowAsync<HttpRequestException>().WithMessage("Request failed with status code NotFound");
}

Expand All @@ -67,13 +69,14 @@ public sealed class DownloadFileTests : IDisposable {
public async Task Writes_Response_To_Stream() {
var tempFile = Path.GetTempFileName();

var request = new RestRequest("Assets/Koala.jpg") {
ResponseWriter = responseStream => {
using var writer = File.OpenWrite(tempFile);
// ReSharper disable once UseObjectOrCollectionInitializer
var request = new RestRequest("Assets/Koala.jpg");

responseStream.CopyTo(writer);
return null;
}
request.ResponseWriter = responseStream => {
using var writer = File.OpenWrite(tempFile);

responseStream.CopyTo(writer);
return null;
};
var response = await _client.DownloadDataAsync(request);

Expand Down
46 changes: 46 additions & 0 deletions test/RestSharp.Tests.Integrated/RedirectTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) .NET Foundation and Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

using System.Net;
using RestSharp.Tests.Integrated.Server;

namespace RestSharp.Tests.Integrated;

[Collection(nameof(TestServerCollection))]
public class RedirectTests {
readonly RestClient _client;

public RedirectTests(TestServerFixture fixture, ITestOutputHelper output) {
var options = new RestClientOptions(fixture.Server.Url) {
FollowRedirects = true
};
_client = new RestClient(options);
}

[Fact]
public async Task Can_Perform_GET_Async_With_Redirect() {
const string val = "Works!";

var request = new RestRequest("redirect");

var response = await _client.ExecuteAsync<Response>(request);
response.StatusCode.Should().Be(HttpStatusCode.OK);
response.Data!.Message.Should().Be(val);
}

class Response {
public string? Message { get; set; }
}
}
1 change: 1 addition & 0 deletions test/RestSharp.Tests.Integrated/Server/TestServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public sealed class HttpServer {
// Cookies
_app.MapGet("get-cookies", CookieHandlers.HandleCookies);
_app.MapGet("set-cookies", CookieHandlers.HandleSetCookies);
_app.MapGet("redirect", () => Results.Redirect("/success", false, true));

// PUT
_app.MapPut(
Expand Down
16 changes: 16 additions & 0 deletions test/RestSharp.Tests/OptionsTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
namespace RestSharp.Tests;

public class OptionsTests {
[Fact]
public void Ensure_follow_redirect() {
var value = false;
var options = new RestClientOptions { FollowRedirects = true, ConfigureMessageHandler = Configure};
var _ = new RestClient(options);
value.Should().BeTrue();

HttpMessageHandler Configure(HttpMessageHandler handler) {
value = ((handler as HttpClientHandler)!).AllowAutoRedirect;
return handler;
}
}
}