From 65e9e02a98c96e53b8c08d7ee403fbadd9574611 Mon Sep 17 00:00:00 2001 From: Daniel Ottiger Date: Wed, 26 May 2021 13:24:37 +0200 Subject: [PATCH] content provider: add result callback to content provider to learn if the server did successfully complete a client-request or not. --- httplib.h | 21 ++++++++++++++++----- test/test.cc | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/httplib.h b/httplib.h index a8d273a7ff..e8a279a31f 100644 --- a/httplib.h +++ b/httplib.h @@ -446,11 +446,13 @@ struct Response { void set_content_provider( size_t length, const char *content_type, ContentProvider provider, - const std::function &resource_releaser = nullptr); + const std::function &resource_releaser = nullptr, + const std::function& result_callback = nullptr); void set_content_provider( const char *content_type, ContentProviderWithoutLength provider, - const std::function &resource_releaser = nullptr); + const std::function &resource_releaser = nullptr, + const std::function &result_callback = nullptr); void set_chunked_content_provider( const char *content_type, ContentProviderWithoutLength provider, @@ -471,6 +473,7 @@ struct Response { size_t content_length_ = 0; ContentProvider content_provider_; std::function content_provider_resource_releaser_; + std::function content_provider_result_callback_; bool is_chunked_content_provider_ = false; }; @@ -4033,23 +4036,27 @@ inline void Response::set_content(const std::string &s, inline void Response::set_content_provider(size_t in_length, const char *content_type, ContentProvider provider, - const std::function &resource_releaser) { + const std::function &resource_releaser, + const std::function& result_callback) { assert(in_length > 0); set_header("Content-Type", content_type); content_length_ = in_length; content_provider_ = std::move(provider); content_provider_resource_releaser_ = resource_releaser; + content_provider_result_callback_ = result_callback; is_chunked_content_provider_ = false; } inline void Response::set_content_provider(const char *content_type, ContentProviderWithoutLength provider, - const std::function &resource_releaser) { + const std::function &resource_releaser, + const std::function &result_callback) { set_header("Content-Type", content_type); content_length_ = 0; content_provider_ = detail::ContentProviderAdapter(std::move(provider)); content_provider_resource_releaser_ = resource_releaser; + content_provider_result_callback_ = result_callback; is_chunked_content_provider_ = false; } @@ -5232,7 +5239,11 @@ Server::process_request(Stream &strm, bool close_connection, if (routed) { if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } - return write_response_with_content(strm, close_connection, req, res); + bool success = write_response_with_content(strm, close_connection, req, res); + if (res.content_provider_result_callback_) { + res.content_provider_result_callback_(success); + } + return success; } else { if (res.status == -1) { res.status = 404; } return write_response(strm, close_connection, req, res); diff --git a/test/test.cc b/test/test.cc index 2492de3fb8..41dfc7abdd 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1389,6 +1389,27 @@ class ServerTest : public ::testing::Test { return true; }); }) + .Get("/streamed-result-success", + [&](const Request & /*req*/, Response &res) { + res.set_content_provider( + 6, "text/plain", + [](size_t offset, size_t /*length*/, DataSink &sink) { + sink.os << (offset < 3 ? "a" : "b"); + return true; + }, + []() {}, + [](bool success) { EXPECT_TRUE(success); } + ); + }) + .Get("/streamed-result-failure", + [&](const Request & /*req*/, Response &res) { + res.set_content_provider( + 6, "text/plain", + [](size_t offset, size_t /*length*/, DataSink &sink) { + return false; + }, + []() {}, [](bool success) { EXPECT_FALSE(success); }); + }) .Get("/streamed-with-range", [&](const Request & /*req*/, Response &res) { auto data = new std::string("abcdefg"); @@ -2324,6 +2345,19 @@ TEST_F(ServerTest, GetStreamed) { EXPECT_EQ(std::string("aaabbb"), res->body); } +TEST_F(ServerTest, GetStreamedResultSuccess) { + auto res = cli_.Get("/streamed-result-success"); + ASSERT_TRUE(res); + EXPECT_EQ(200, res->status); + EXPECT_EQ("6", res->get_header_value("Content-Length")); + EXPECT_EQ(std::string("aaabbb"), res->body); +} + +TEST_F(ServerTest, GetStreamedResultFailure) { + auto res = cli_.Get("/streamed-result-failure"); + ASSERT_FALSE(res); +} + TEST_F(ServerTest, GetStreamedWithRange1) { auto res = cli_.Get("/streamed-with-range", {{make_range_header({{3, 5}})}}); ASSERT_TRUE(res);