From 71152b710e3543732464fca57c8f07b7395de68d Mon Sep 17 00:00:00 2001 From: Suresh Kumar Date: Thu, 16 Aug 2018 20:16:09 +0530 Subject: [PATCH] ratelimit: Add ratelimit custom response headers (#4015) - Ability to add custom response headers from ratelimit service/filter - For both (LimitStatus::OK and LimitStatus::OverLimit) custom headers are added if RLS service sends headers - For LimitStatus:OK, we temporarily store the headers and add them to the response (via Filter::encodeHeaders()) *Risk Level*: Low *Testing*: unit and integration tests added. Verified with modified github.com/lyft/ratelimit service. Passes "bazel test //test/..." in Linux Signed-off-by: Suresh Kumar --- api/envoy/service/ratelimit/v2/BUILD | 2 + api/envoy/service/ratelimit/v2/rls.proto | 3 + include/envoy/ratelimit/ratelimit.h | 5 +- source/common/http/header_utility.cc | 14 ++ source/common/http/header_utility.h | 7 + source/common/ratelimit/ratelimit_impl.cc | 14 +- source/common/ratelimit/ratelimit_impl.h | 2 +- .../filters/http/ratelimit/config.cc | 2 +- .../filters/http/ratelimit/ratelimit.cc | 33 ++- .../filters/http/ratelimit/ratelimit.h | 13 +- .../filters/network/ratelimit/ratelimit.cc | 2 +- .../filters/network/ratelimit/ratelimit.h | 4 +- test/common/http/header_utility_test.cc | 16 ++ .../network/filter_manager_impl_test.cc | 2 +- test/common/ratelimit/ratelimit_impl_test.cc | 14 +- .../filters/http/ratelimit/config_test.cc | 6 +- .../filters/http/ratelimit/ratelimit_test.cc | 188 +++++++++++++++++- .../network/ratelimit/ratelimit_test.cc | 10 +- .../integration/ratelimit_integration_test.cc | 71 ++++++- 19 files changed, 370 insertions(+), 38 deletions(-) diff --git a/api/envoy/service/ratelimit/v2/BUILD b/api/envoy/service/ratelimit/v2/BUILD index 4ee72b651888..d0e114ebdbec 100644 --- a/api/envoy/service/ratelimit/v2/BUILD +++ b/api/envoy/service/ratelimit/v2/BUILD @@ -7,6 +7,7 @@ api_proto_library_internal( srcs = ["rls.proto"], has_services = 1, deps = [ + "//envoy/api/v2/core:base", "//envoy/api/v2/core:grpc_service", "//envoy/api/v2/ratelimit", ], @@ -16,6 +17,7 @@ api_go_grpc_library( name = "rls", proto = ":rls", deps = [ + "//envoy/api/v2/core:base_go_proto", "//envoy/api/v2/core:grpc_service_go_proto", "//envoy/api/v2/ratelimit:ratelimit_go_proto", ], diff --git a/api/envoy/service/ratelimit/v2/rls.proto b/api/envoy/service/ratelimit/v2/rls.proto index c1a416fcc842..ebaf54358083 100644 --- a/api/envoy/service/ratelimit/v2/rls.proto +++ b/api/envoy/service/ratelimit/v2/rls.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package envoy.service.ratelimit.v2; option go_package = "v2"; +import "envoy/api/v2/core/base.proto"; import "envoy/api/v2/ratelimit/ratelimit.proto"; import "validate/validate.proto"; @@ -75,4 +76,6 @@ message RateLimitResponse { // in the RateLimitRequest. This can be used by the caller to determine which individual // descriptors failed and/or what the currently configured limits are for all of them. repeated DescriptorStatus statuses = 2; + // A list of headers to add to the response + repeated envoy.api.v2.core.HeaderValue headers = 3; } diff --git a/include/envoy/ratelimit/ratelimit.h b/include/envoy/ratelimit/ratelimit.h index 228aa696dc18..bec66abcb473 100644 --- a/include/envoy/ratelimit/ratelimit.h +++ b/include/envoy/ratelimit/ratelimit.h @@ -33,9 +33,10 @@ class RequestCallbacks { virtual ~RequestCallbacks() {} /** - * Called when a limit request is complete. The resulting status is supplied. + * Called when a limit request is complete. The resulting status and + * response headers are supplied. */ - virtual void complete(LimitStatus status) PURE; + virtual void complete(LimitStatus status, Http::HeaderMapPtr&& headers) PURE; }; /** diff --git a/source/common/http/header_utility.cc b/source/common/http/header_utility.cc index 3618a5a08f48..b579ba10b2bd 100644 --- a/source/common/http/header_utility.cc +++ b/source/common/http/header_utility.cc @@ -2,6 +2,7 @@ #include "common/common/utility.h" #include "common/config/rds_json.h" +#include "common/http/header_map_impl.h" #include "common/protobuf/utility.h" #include "absl/strings/match.h" @@ -115,5 +116,18 @@ bool HeaderUtility::matchHeaders(const Http::HeaderMap& request_headers, return match != header_data.invert_match_; } +void HeaderUtility::addHeaders(Http::HeaderMap& headers, const Http::HeaderMap& headers_to_add) { + headers_to_add.iterate( + [](const Http::HeaderEntry& header, void* context) -> Http::HeaderMap::Iterate { + Http::HeaderString k; + k.setCopy(header.key().c_str(), header.key().size()); + Http::HeaderString v; + v.setCopy(header.value().c_str(), header.value().size()); + static_cast(context)->addViaMove(std::move(k), std::move(v)); + return Http::HeaderMap::Iterate::Continue; + }, + &headers); +} + } // namespace Http } // namespace Envoy diff --git a/source/common/http/header_utility.h b/source/common/http/header_utility.h index d159ac7b1941..ec061ee4fa76 100644 --- a/source/common/http/header_utility.h +++ b/source/common/http/header_utility.h @@ -44,6 +44,13 @@ class HeaderUtility { const std::vector& config_headers); static bool matchHeaders(const Http::HeaderMap& request_headers, const HeaderData& config_header); + + /** + * Add headers from one HeaderMap to another + * @param headers target where headers will be added + * @param headers_to_add supplies the headers to be added + */ + static void addHeaders(Http::HeaderMap& headers, const Http::HeaderMap& headers_to_add); }; } // namespace Http } // namespace Envoy diff --git a/source/common/ratelimit/ratelimit_impl.cc b/source/common/ratelimit/ratelimit_impl.cc index bddfd010c973..9403838bfdfb 100644 --- a/source/common/ratelimit/ratelimit_impl.cc +++ b/source/common/ratelimit/ratelimit_impl.cc @@ -9,6 +9,7 @@ #include "envoy/stats/scope.h" #include "common/common/assert.h" +#include "common/http/header_map_impl.h" #include "common/http/headers.h" namespace Envoy { @@ -67,14 +68,23 @@ void GrpcClientImpl::onSuccess( span.setTag(Constants::get().TraceStatus, Constants::get().TraceOk); } - callbacks_->complete(status); + if (response->headers_size()) { + Http::HeaderMapPtr headers = std::make_unique(); + for (const auto& h : response->headers()) { + headers->addCopy(Http::LowerCaseString(h.key()), h.value()); + } + callbacks_->complete(status, std::move(headers)); + } else { + callbacks_->complete(status, nullptr); + } + callbacks_ = nullptr; } void GrpcClientImpl::onFailure(Grpc::Status::GrpcStatus status, const std::string&, Tracing::Span&) { ASSERT(status != Grpc::Status::GrpcStatus::Ok); - callbacks_->complete(LimitStatus::Error); + callbacks_->complete(LimitStatus::Error, nullptr); callbacks_ = nullptr; } diff --git a/source/common/ratelimit/ratelimit_impl.h b/source/common/ratelimit/ratelimit_impl.h index 867b56db6d68..576b327746f8 100644 --- a/source/common/ratelimit/ratelimit_impl.h +++ b/source/common/ratelimit/ratelimit_impl.h @@ -87,7 +87,7 @@ class NullClientImpl : public Client { void cancel() override {} void limit(RequestCallbacks& callbacks, const std::string&, const std::vector&, Tracing::Span&) override { - callbacks.complete(LimitStatus::OK); + callbacks.complete(LimitStatus::OK, nullptr); } }; diff --git a/source/extensions/filters/http/ratelimit/config.cc b/source/extensions/filters/http/ratelimit/config.cc index 7baadccac5d6..e094fd4c5e3b 100644 --- a/source/extensions/filters/http/ratelimit/config.cc +++ b/source/extensions/filters/http/ratelimit/config.cc @@ -26,7 +26,7 @@ Http::FilterFactoryCb RateLimitFilterConfig::createFilterFactoryFromProtoTyped( const uint32_t timeout_ms = PROTOBUF_GET_MS_OR_DEFAULT(proto_config, timeout, 20); return [filter_config, timeout_ms, &context](Http::FilterChainFactoryCallbacks& callbacks) -> void { - callbacks.addStreamDecoderFilter(std::make_shared( + callbacks.addStreamFilter(std::make_shared( filter_config, context.rateLimitClient(std::chrono::milliseconds(timeout_ms)))); }; } diff --git a/source/extensions/filters/http/ratelimit/ratelimit.cc b/source/extensions/filters/http/ratelimit/ratelimit.cc index d63f0ca6cdd3..956bc56d210a 100644 --- a/source/extensions/filters/http/ratelimit/ratelimit.cc +++ b/source/extensions/filters/http/ratelimit/ratelimit.cc @@ -9,6 +9,7 @@ #include "common/common/enum_to_int.h" #include "common/common/fmt.h" #include "common/http/codes.h" +#include "common/http/header_utility.h" #include "common/router/config_impl.h" namespace Envoy { @@ -87,6 +88,25 @@ void Filter::setDecoderFilterCallbacks(Http::StreamDecoderFilterCallbacks& callb callbacks_ = &callbacks; } +Http::FilterHeadersStatus Filter::encode100ContinueHeaders(Http::HeaderMap&) { + return Http::FilterHeadersStatus::Continue; +} + +Http::FilterHeadersStatus Filter::encodeHeaders(Http::HeaderMap& headers, bool) { + addHeaders(headers); + return Http::FilterHeadersStatus::Continue; +} + +Http::FilterDataStatus Filter::encodeData(Buffer::Instance&, bool) { + return Http::FilterDataStatus::Continue; +} + +Http::FilterTrailersStatus Filter::encodeTrailers(Http::HeaderMap&) { + return Http::FilterTrailersStatus::Continue; +} + +void Filter::setEncoderFilterCallbacks(Http::StreamEncoderFilterCallbacks&) {} + void Filter::onDestroy() { if (state_ == State::Calling) { state_ = State::Complete; @@ -94,8 +114,9 @@ void Filter::onDestroy() { } } -void Filter::complete(RateLimit::LimitStatus status) { +void Filter::complete(RateLimit::LimitStatus status, Http::HeaderMapPtr&& headers) { state_ = State::Complete; + headers_to_add_ = std::move(headers); switch (status) { case RateLimit::LimitStatus::OK: @@ -123,7 +144,8 @@ void Filter::complete(RateLimit::LimitStatus status) { if (status == RateLimit::LimitStatus::OverLimit && config_->runtime().snapshot().featureEnabled("ratelimit.http_filter_enforcing", 100)) { state_ = State::Responded; - callbacks_->sendLocalReply(Http::Code::TooManyRequests, "", nullptr); + callbacks_->sendLocalReply(Http::Code::TooManyRequests, "", + [this](Http::HeaderMap& headers) { addHeaders(headers); }); callbacks_->requestInfo().setResponseFlag(RequestInfo::ResponseFlag::RateLimited); } else if (!initiating_call_) { callbacks_->continueDecoding(); @@ -147,6 +169,13 @@ void Filter::populateRateLimitDescriptors(const Router::RateLimitPolicy& rate_li } } +void Filter::addHeaders(Http::HeaderMap& headers) { + if (headers_to_add_) { + Http::HeaderUtility::addHeaders(headers, *headers_to_add_); + headers_to_add_ = nullptr; + } +} + } // namespace RateLimitFilter } // namespace HttpFilters } // namespace Extensions diff --git a/source/extensions/filters/http/ratelimit/ratelimit.h b/source/extensions/filters/http/ratelimit/ratelimit.h index 001358b70ad8..5d46592fd92f 100644 --- a/source/extensions/filters/http/ratelimit/ratelimit.h +++ b/source/extensions/filters/http/ratelimit/ratelimit.h @@ -74,7 +74,7 @@ typedef std::shared_ptr FilterConfigSharedPtr; * HTTP rate limit filter. Depending on the route configuration, this filter calls the global * rate limiting service before allowing further filter iteration. */ -class Filter : public Http::StreamDecoderFilter, public RateLimit::RequestCallbacks { +class Filter : public Http::StreamFilter, public RateLimit::RequestCallbacks { public: Filter(FilterConfigSharedPtr config, RateLimit::ClientPtr&& client) : config_(config), client_(std::move(client)) {} @@ -88,8 +88,15 @@ class Filter : public Http::StreamDecoderFilter, public RateLimit::RequestCallba Http::FilterTrailersStatus decodeTrailers(Http::HeaderMap& trailers) override; void setDecoderFilterCallbacks(Http::StreamDecoderFilterCallbacks& callbacks) override; + // Http::StreamEncoderFilter + Http::FilterHeadersStatus encode100ContinueHeaders(Http::HeaderMap& headers) override; + Http::FilterHeadersStatus encodeHeaders(Http::HeaderMap& headers, bool end_stream) override; + Http::FilterDataStatus encodeData(Buffer::Instance& data, bool end_stream) override; + Http::FilterTrailersStatus encodeTrailers(Http::HeaderMap& trailers) override; + void setEncoderFilterCallbacks(Http::StreamEncoderFilterCallbacks& callbacks) override; + // RateLimit::RequestCallbacks - void complete(RateLimit::LimitStatus status) override; + void complete(RateLimit::LimitStatus status, Http::HeaderMapPtr&& headers) override; private: void initiateCall(const Http::HeaderMap& headers); @@ -97,6 +104,7 @@ class Filter : public Http::StreamDecoderFilter, public RateLimit::RequestCallba std::vector& descriptors, const Router::RouteEntry* route_entry, const Http::HeaderMap& headers) const; + void addHeaders(Http::HeaderMap& headers); enum class State { NotStarted, Calling, Complete, Responded }; @@ -106,6 +114,7 @@ class Filter : public Http::StreamDecoderFilter, public RateLimit::RequestCallba State state_{State::NotStarted}; Upstream::ClusterInfoConstSharedPtr cluster_; bool initiating_call_{}; + Http::HeaderMapPtr headers_to_add_; }; } // namespace RateLimitFilter diff --git a/source/extensions/filters/network/ratelimit/ratelimit.cc b/source/extensions/filters/network/ratelimit/ratelimit.cc index 9ed1c54ddaef..dc33acdff773 100644 --- a/source/extensions/filters/network/ratelimit/ratelimit.cc +++ b/source/extensions/filters/network/ratelimit/ratelimit.cc @@ -69,7 +69,7 @@ void Filter::onEvent(Network::ConnectionEvent event) { } } -void Filter::complete(RateLimit::LimitStatus status) { +void Filter::complete(RateLimit::LimitStatus status, Http::HeaderMapPtr&&) { status_ = Status::Complete; config_->stats().active_.dec(); diff --git a/source/extensions/filters/network/ratelimit/ratelimit.h b/source/extensions/filters/network/ratelimit/ratelimit.h index 3e186b6da1ae..9ad7718a8185 100644 --- a/source/extensions/filters/network/ratelimit/ratelimit.h +++ b/source/extensions/filters/network/ratelimit/ratelimit.h @@ -88,7 +88,7 @@ class Filter : public Network::ReadFilter, void onBelowWriteBufferLowWatermark() override {} // RateLimit::RequestCallbacks - void complete(RateLimit::LimitStatus status) override; + void complete(RateLimit::LimitStatus status, Http::HeaderMapPtr&& headers) override; private: enum class Status { NotStarted, Calling, Complete }; @@ -99,7 +99,7 @@ class Filter : public Network::ReadFilter, Status status_{Status::NotStarted}; bool calling_limit_{}; }; -} +} // namespace RateLimitFilter } // namespace NetworkFilters } // namespace Extensions } // namespace Envoy diff --git a/test/common/http/header_utility_test.cc b/test/common/http/header_utility_test.cc index b94280d2d41a..0d8c82d13424 100644 --- a/test/common/http/header_utility_test.cc +++ b/test/common/http/header_utility_test.cc @@ -403,5 +403,21 @@ invert_match: true EXPECT_FALSE(HeaderUtility::matchHeaders(unmatching_headers, header_data)); } +TEST(HeaderAddTest, HeaderAdd) { + TestHeaderMapImpl headers{{"myheader1", "123value"}}; + TestHeaderMapImpl headers_to_add{{"myheader2", "456value"}}; + + HeaderUtility::addHeaders(headers, headers_to_add); + + headers_to_add.iterate( + [](const Http::HeaderEntry& entry, void* context) -> Http::HeaderMap::Iterate { + TestHeaderMapImpl* headers = static_cast(context); + Http::LowerCaseString lower_key{entry.key().c_str()}; + EXPECT_STREQ(entry.value().c_str(), headers->get(lower_key)->value().c_str()); + return Http::HeaderMap::Iterate::Continue; + }, + &headers); +} + } // namespace Http } // namespace Envoy diff --git a/test/common/network/filter_manager_impl_test.cc b/test/common/network/filter_manager_impl_test.cc index 9bcb366c4c42..7da939a2413d 100644 --- a/test/common/network/filter_manager_impl_test.cc +++ b/test/common/network/filter_manager_impl_test.cc @@ -206,7 +206,7 @@ TEST_F(NetworkFilterManagerTest, RateLimitAndTcpProxy) { EXPECT_CALL(factory_context.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) .WillOnce(Return(&conn_pool)); - request_callbacks->complete(RateLimit::LimitStatus::OK); + request_callbacks->complete(RateLimit::LimitStatus::OK, nullptr); conn_pool.poolReady(upstream_connection); diff --git a/test/common/ratelimit/ratelimit_impl_test.cc b/test/common/ratelimit/ratelimit_impl_test.cc index 83fed1adba9a..02e09fb4fccc 100644 --- a/test/common/ratelimit/ratelimit_impl_test.cc +++ b/test/common/ratelimit/ratelimit_impl_test.cc @@ -29,7 +29,11 @@ namespace RateLimit { class MockRequestCallbacks : public RequestCallbacks { public: - MOCK_METHOD1(complete, void(LimitStatus status)); + void complete(LimitStatus status, Http::HeaderMapPtr&& headers) { + complete_(status, headers.get()); + } + + MOCK_METHOD2(complete_, void(LimitStatus status, const Http::HeaderMap* headers)); }; // TODO(junr03): legacy rate limit is deprecated. Remove the boolean parameter after 1.8.0. @@ -91,7 +95,7 @@ TEST_P(RateLimitGrpcClientTest, Basic) { response.reset(new envoy::service::ratelimit::v2::RateLimitResponse()); response->set_overall_code(envoy::service::ratelimit::v2::RateLimitResponse_Code_OVER_LIMIT); EXPECT_CALL(span_, setTag("ratelimit_status", "over_limit")); - EXPECT_CALL(request_callbacks_, complete(LimitStatus::OverLimit)); + EXPECT_CALL(request_callbacks_, complete_(LimitStatus::OverLimit, _)); client_->onSuccess(std::move(response), span_); } @@ -110,7 +114,7 @@ TEST_P(RateLimitGrpcClientTest, Basic) { response.reset(new envoy::service::ratelimit::v2::RateLimitResponse()); response->set_overall_code(envoy::service::ratelimit::v2::RateLimitResponse_Code_OK); EXPECT_CALL(span_, setTag("ratelimit_status", "ok")); - EXPECT_CALL(request_callbacks_, complete(LimitStatus::OK)); + EXPECT_CALL(request_callbacks_, complete_(LimitStatus::OK, _)); client_->onSuccess(std::move(response), span_); } @@ -127,7 +131,7 @@ TEST_P(RateLimitGrpcClientTest, Basic) { Tracing::NullSpan::instance()); response.reset(new envoy::service::ratelimit::v2::RateLimitResponse()); - EXPECT_CALL(request_callbacks_, complete(LimitStatus::Error)); + EXPECT_CALL(request_callbacks_, complete_(LimitStatus::Error, _)); client_->onFailure(Grpc::Status::Unknown, "", span_); } } @@ -179,7 +183,7 @@ TEST(RateLimitNullFactoryTest, Basic) { NullFactoryImpl factory; ClientPtr client = factory.create(absl::optional()); MockRequestCallbacks request_callbacks; - EXPECT_CALL(request_callbacks, complete(LimitStatus::OK)); + EXPECT_CALL(request_callbacks, complete_(LimitStatus::OK, _)); client->limit(request_callbacks, "foo", {{{{"foo", "bar"}}}}, Tracing::NullSpan::instance()); client->cancel(); } diff --git a/test/extensions/filters/http/ratelimit/config_test.cc b/test/extensions/filters/http/ratelimit/config_test.cc index 6da40aab588a..21237385e2a9 100644 --- a/test/extensions/filters/http/ratelimit/config_test.cc +++ b/test/extensions/filters/http/ratelimit/config_test.cc @@ -36,7 +36,7 @@ TEST(RateLimitFilterConfigTest, RateLimitFilterCorrectJson) { RateLimitFilterConfig factory; Http::FilterFactoryCb cb = factory.createFilterFactory(*json_config, "stats", context); Http::MockFilterChainFactoryCallbacks filter_callback; - EXPECT_CALL(filter_callback, addStreamDecoderFilter(_)); + EXPECT_CALL(filter_callback, addStreamFilter(_)); cb(filter_callback); } @@ -56,7 +56,7 @@ TEST(RateLimitFilterConfigTest, RateLimitFilterCorrectProto) { RateLimitFilterConfig factory; Http::FilterFactoryCb cb = factory.createFilterFactoryFromProto(proto_config, "stats", context); Http::MockFilterChainFactoryCallbacks filter_callback; - EXPECT_CALL(filter_callback, addStreamDecoderFilter(_)); + EXPECT_CALL(filter_callback, addStreamFilter(_)); cb(filter_callback); } @@ -79,7 +79,7 @@ TEST(RateLimitFilterConfigTest, RateLimitFilterEmptyProto) { Http::FilterFactoryCb cb = factory.createFilterFactory(*json_config, "stats", context); Http::MockFilterChainFactoryCallbacks filter_callback; - EXPECT_CALL(filter_callback, addStreamDecoderFilter(_)); + EXPECT_CALL(filter_callback, addStreamFilter(_)); cb(filter_callback); } diff --git a/test/extensions/filters/http/ratelimit/ratelimit_test.cc b/test/extensions/filters/http/ratelimit/ratelimit_test.cc index cbcec14f2738..ba3ffd6e16f4 100644 --- a/test/extensions/filters/http/ratelimit/ratelimit_test.cc +++ b/test/extensions/filters/http/ratelimit/ratelimit_test.cc @@ -76,7 +76,9 @@ class HttpRateLimitFilterTest : public testing::Test { NiceMock filter_callbacks_; RateLimit::RequestCallbacks* request_callbacks_{}; Http::TestHeaderMapImpl request_headers_; + Http::TestHeaderMapImpl response_headers_; Buffer::OwnedImpl data_; + Buffer::OwnedImpl response_data_; Stats::IsolatedStoreImpl stats_store_; NiceMock runtime_; NiceMock cm_; @@ -108,6 +110,11 @@ TEST_F(HttpRateLimitFilterTest, NoRoute) { EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers_, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); } TEST_F(HttpRateLimitFilterTest, NoCluster) { @@ -118,6 +125,11 @@ TEST_F(HttpRateLimitFilterTest, NoCluster) { EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers_, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); } TEST_F(HttpRateLimitFilterTest, NoApplicableRateLimit) { @@ -128,6 +140,11 @@ TEST_F(HttpRateLimitFilterTest, NoApplicableRateLimit) { EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers_, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); } TEST_F(HttpRateLimitFilterTest, NoDescriptor) { @@ -139,6 +156,11 @@ TEST_F(HttpRateLimitFilterTest, NoDescriptor) { EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers_, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); } TEST_F(HttpRateLimitFilterTest, RuntimeDisabled) { @@ -149,6 +171,11 @@ TEST_F(HttpRateLimitFilterTest, RuntimeDisabled) { EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers_, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); } TEST_F(HttpRateLimitFilterTest, OkResponse) { @@ -178,12 +205,69 @@ TEST_F(HttpRateLimitFilterTest, OkResponse) { filter_->decodeHeaders(request_headers_, false)); EXPECT_EQ(Http::FilterDataStatus::StopIterationAndWatermark, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::StopIteration, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); EXPECT_CALL(filter_callbacks_, continueDecoding()); EXPECT_CALL(filter_callbacks_.request_info_, setResponseFlag(RequestInfo::ResponseFlag::RateLimited)) .Times(0); - request_callbacks_->complete(RateLimit::LimitStatus::OK); + request_callbacks_->complete(RateLimit::LimitStatus::OK, nullptr); + + EXPECT_EQ(1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.ok").value()); +} + +TEST_F(HttpRateLimitFilterTest, OkResponseWithHeaders) { + SetUpTest(filter_config_); + InSequence s; + + EXPECT_CALL(filter_callbacks_.route_->route_entry_.rate_limit_policy_, getApplicableRateLimit(0)) + .Times(1); + + EXPECT_CALL(route_rate_limit_, populateDescriptors(_, _, _, _, _)) + .WillOnce(SetArgReferee<1>(descriptor_)); + + EXPECT_CALL(filter_callbacks_.route_->route_entry_.virtual_host_.rate_limit_policy_, + getApplicableRateLimit(0)) + .Times(1); + + EXPECT_CALL(*client_, limit(_, "foo", + testing::ContainerEq(std::vector{ + {{{"descriptor_key", "descriptor_value"}}}}), + _)) + .WillOnce(WithArgs<0>(Invoke([&](RateLimit::RequestCallbacks& callbacks) -> void { + request_callbacks_ = &callbacks; + }))); + + request_headers_.addCopy(Http::Headers::get().RequestId, "requestid"); + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_->decodeHeaders(request_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::StopIterationAndWatermark, filter_->decodeData(data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::StopIteration, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); + + EXPECT_CALL(filter_callbacks_, continueDecoding()); + EXPECT_CALL(filter_callbacks_.request_info_, + setResponseFlag(RequestInfo::ResponseFlag::RateLimited)) + .Times(0); + + Http::HeaderMapPtr rl_headers{ + new Http::TestHeaderMapImpl{{"x-ratelimit-limit", "1000"}, {"x-ratelimit-remaining", "500"}}}; + request_callbacks_->complete(RateLimit::LimitStatus::OK, + Http::HeaderMapPtr{new Http::TestHeaderMapImpl(*rl_headers)}); + + Http::TestHeaderMapImpl expected_headers(*rl_headers); + Http::TestHeaderMapImpl response_headers; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers, false)); + EXPECT_EQ(true, (expected_headers == response_headers)); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.ok").value()); @@ -201,13 +285,18 @@ TEST_F(HttpRateLimitFilterTest, ImmediateOkResponse) { {{{"descriptor_key", "descriptor_value"}}}}), _)) .WillOnce(WithArgs<0>(Invoke([&](RateLimit::RequestCallbacks& callbacks) -> void { - callbacks.complete(RateLimit::LimitStatus::OK); + callbacks.complete(RateLimit::LimitStatus::OK, nullptr); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers_, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.ok").value()); @@ -228,7 +317,7 @@ TEST_F(HttpRateLimitFilterTest, ErrorResponse) { filter_->decodeHeaders(request_headers_, false)); EXPECT_CALL(filter_callbacks_, continueDecoding()); - request_callbacks_->complete(RateLimit::LimitStatus::Error); + request_callbacks_->complete(RateLimit::LimitStatus::Error, nullptr); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); @@ -261,7 +350,50 @@ TEST_F(HttpRateLimitFilterTest, LimitResponse) { EXPECT_CALL(filter_callbacks_.request_info_, setResponseFlag(RequestInfo::ResponseFlag::RateLimited)); - request_callbacks_->complete(RateLimit::LimitStatus::OverLimit); + request_callbacks_->complete(RateLimit::LimitStatus::OverLimit, nullptr); + + EXPECT_EQ(1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.over_limit") + .value()); + EXPECT_EQ( + 1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("upstream_rq_4xx").value()); + EXPECT_EQ( + 1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("upstream_rq_429").value()); +} + +TEST_F(HttpRateLimitFilterTest, LimitResponseWithHeaders) { + SetUpTest(filter_config_); + InSequence s; + + EXPECT_CALL(route_rate_limit_, populateDescriptors(_, _, _, _, _)) + .WillOnce(SetArgReferee<1>(descriptor_)); + EXPECT_CALL(*client_, limit(_, _, _, _)) + .WillOnce(WithArgs<0>(Invoke([&](RateLimit::RequestCallbacks& callbacks) -> void { + request_callbacks_ = &callbacks; + }))); + + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_->decodeHeaders(request_headers_, false)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); + + Http::HeaderMapPtr rl_headers{new Http::TestHeaderMapImpl{ + {"x-ratelimit-limit", "1000"}, {"x-ratelimit-remaining", "0"}, {"retry-after", "33"}}}; + Http::TestHeaderMapImpl expected_headers(*rl_headers); + expected_headers.addCopy(":status", "429"); + + EXPECT_CALL(filter_callbacks_, encodeHeaders_(HeaderMapEqualRef(&expected_headers), true)); + EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); + EXPECT_CALL(filter_callbacks_.request_info_, + setResponseFlag(RequestInfo::ResponseFlag::RateLimited)); + + Http::HeaderMapPtr h{new Http::TestHeaderMapImpl(*rl_headers)}; + request_callbacks_->complete(RateLimit::LimitStatus::OverLimit, std::move(h)); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.over_limit") @@ -291,10 +423,15 @@ TEST_F(HttpRateLimitFilterTest, LimitResponseRuntimeDisabled) { EXPECT_CALL(runtime_.snapshot_, featureEnabled("ratelimit.http_filter_enforcing", 100)) .WillOnce(Return(false)); EXPECT_CALL(filter_callbacks_, continueDecoding()); - request_callbacks_->complete(RateLimit::LimitStatus::OverLimit); + request_callbacks_->complete(RateLimit::LimitStatus::OverLimit, nullptr); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.over_limit") @@ -338,6 +475,11 @@ TEST_F(HttpRateLimitFilterTest, RouteRateLimitDisabledForRouteKey) { EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers_, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); } TEST_F(HttpRateLimitFilterTest, VirtualHostRateLimitDisabledForRouteKey) { @@ -353,6 +495,11 @@ TEST_F(HttpRateLimitFilterTest, VirtualHostRateLimitDisabledForRouteKey) { EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers_, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); } TEST_F(HttpRateLimitFilterTest, IncorrectRequestType) { @@ -370,6 +517,11 @@ TEST_F(HttpRateLimitFilterTest, IncorrectRequestType) { EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers_, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); std::string external_filter_config = R"EOF( { @@ -386,6 +538,11 @@ TEST_F(HttpRateLimitFilterTest, IncorrectRequestType) { EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); } TEST_F(HttpRateLimitFilterTest, InternalRequestType) { @@ -414,13 +571,18 @@ TEST_F(HttpRateLimitFilterTest, InternalRequestType) { {{{"descriptor_key", "descriptor_value"}}}}), _)) .WillOnce(WithArgs<0>(Invoke([&](RateLimit::RequestCallbacks& callbacks) -> void { - callbacks.complete(RateLimit::LimitStatus::OK); + callbacks.complete(RateLimit::LimitStatus::OK, nullptr); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.ok").value()); @@ -453,13 +615,18 @@ TEST_F(HttpRateLimitFilterTest, ExternalRequestType) { {{{"descriptor_key", "descriptor_value"}}}}), _)) .WillOnce(WithArgs<0>(Invoke([&](RateLimit::RequestCallbacks& callbacks) -> void { - callbacks.complete(RateLimit::LimitStatus::OK); + callbacks.complete(RateLimit::LimitStatus::OK, nullptr); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.ok").value()); @@ -489,13 +656,18 @@ TEST_F(HttpRateLimitFilterTest, ExcludeVirtualHost) { {{{"descriptor_key", "descriptor_value"}}}}), _)) .WillOnce(WithArgs<0>(Invoke([&](RateLimit::RequestCallbacks& callbacks) -> void { - callbacks.complete(RateLimit::LimitStatus::OK); + callbacks.complete(RateLimit::LimitStatus::OK, nullptr); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers_, false)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_->encode100ContinueHeaders(response_headers_)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(response_data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->encodeTrailers(response_headers_)); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ratelimit.ok").value()); diff --git a/test/extensions/filters/network/ratelimit/ratelimit_test.cc b/test/extensions/filters/network/ratelimit/ratelimit_test.cc index ceb163f0ccff..dfd3a306b8b2 100644 --- a/test/extensions/filters/network/ratelimit/ratelimit_test.cc +++ b/test/extensions/filters/network/ratelimit/ratelimit_test.cc @@ -112,7 +112,7 @@ TEST_F(RateLimitFilterTest, OK) { EXPECT_EQ(Network::FilterStatus::StopIteration, filter_->onData(data, false)); EXPECT_CALL(filter_callbacks_, continueReading()); - request_callbacks_->complete(RateLimit::LimitStatus::OK); + request_callbacks_->complete(RateLimit::LimitStatus::OK, nullptr); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onData(data, false)); @@ -137,7 +137,7 @@ TEST_F(RateLimitFilterTest, OverLimit) { EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); EXPECT_CALL(*client_, cancel()).Times(0); - request_callbacks_->complete(RateLimit::LimitStatus::OverLimit); + request_callbacks_->complete(RateLimit::LimitStatus::OverLimit, nullptr); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onData(data, false)); @@ -163,7 +163,7 @@ TEST_F(RateLimitFilterTest, OverLimitNotEnforcing) { EXPECT_CALL(filter_callbacks_.connection_, close(_)).Times(0); EXPECT_CALL(*client_, cancel()).Times(0); EXPECT_CALL(filter_callbacks_, continueReading()); - request_callbacks_->complete(RateLimit::LimitStatus::OverLimit); + request_callbacks_->complete(RateLimit::LimitStatus::OverLimit, nullptr); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onData(data, false)); @@ -185,7 +185,7 @@ TEST_F(RateLimitFilterTest, Error) { EXPECT_EQ(Network::FilterStatus::StopIteration, filter_->onData(data, false)); EXPECT_CALL(filter_callbacks_, continueReading()); - request_callbacks_->complete(RateLimit::LimitStatus::Error); + request_callbacks_->complete(RateLimit::LimitStatus::Error, nullptr); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onData(data, false)); @@ -220,7 +220,7 @@ TEST_F(RateLimitFilterTest, ImmediateOK) { EXPECT_CALL(filter_callbacks_, continueReading()).Times(0); EXPECT_CALL(*client_, limit(_, "foo", _, _)) .WillOnce(WithArgs<0>(Invoke([&](RateLimit::RequestCallbacks& callbacks) -> void { - callbacks.complete(RateLimit::LimitStatus::OK); + callbacks.complete(RateLimit::LimitStatus::OK, nullptr); }))); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onNewConnection()); diff --git a/test/integration/ratelimit_integration_test.cc b/test/integration/ratelimit_integration_test.cc index c45084e716fb..768510c3575e 100644 --- a/test/integration/ratelimit_integration_test.cc +++ b/test/integration/ratelimit_integration_test.cc @@ -130,10 +130,23 @@ class RatelimitIntegrationTest : public HttpIntegrationTest, response_->headers().Status()->value().c_str()); } - void sendRateLimitResponse(envoy::service::ratelimit::v2::RateLimitResponse_Code code) { + void sendRateLimitResponse(envoy::service::ratelimit::v2::RateLimitResponse_Code code, + const Http::HeaderMapImpl& headers) { ratelimit_request_->startGrpcStream(); envoy::service::ratelimit::v2::RateLimitResponse response_msg; response_msg.set_overall_code(code); + + headers.iterate( + [](const Http::HeaderEntry& h, void* context) -> Http::HeaderMap::Iterate { + auto header = static_cast(context) + ->mutable_headers() + ->Add(); + header->set_key(h.key().c_str()); + header->set_value(h.value().c_str()); + return Http::HeaderMap::Iterate::Continue; + }, + &response_msg); + ratelimit_request_->sendGrpcMessage(response_msg); ratelimit_request_->finishGrpcStream(Grpc::Status::Ok); } @@ -165,7 +178,8 @@ INSTANTIATE_TEST_CASE_P(IpVersionsClientType, RatelimitIntegrationTest, TEST_P(RatelimitIntegrationTest, Ok) { initiateClientConnection(); waitForRatelimitRequest(); - sendRateLimitResponse(envoy::service::ratelimit::v2::RateLimitResponse_Code_OK); + sendRateLimitResponse(envoy::service::ratelimit::v2::RateLimitResponse_Code_OK, + Http::HeaderMapImpl{}); waitForSuccessfulUpstreamResponse(); cleanup(); @@ -174,10 +188,36 @@ TEST_P(RatelimitIntegrationTest, Ok) { EXPECT_EQ(nullptr, test_server_->counter("cluster.cluster_0.ratelimit.error")); } +TEST_P(RatelimitIntegrationTest, OkWithHeaders) { + initiateClientConnection(); + waitForRatelimitRequest(); + Http::TestHeaderMapImpl ratelimit_headers{{"x-ratelimit-limit", "1000"}, + {"x-ratelimit-remaining", "500"}}; + sendRateLimitResponse(envoy::service::ratelimit::v2::RateLimitResponse_Code_OK, + ratelimit_headers); + waitForSuccessfulUpstreamResponse(); + + ratelimit_headers.iterate( + [](const Http::HeaderEntry& entry, void* context) -> Http::HeaderMap::Iterate { + IntegrationStreamDecoder* response = static_cast(context); + Http::LowerCaseString lower_key{entry.key().c_str()}; + EXPECT_STREQ(entry.value().c_str(), response->headers().get(lower_key)->value().c_str()); + return Http::HeaderMap::Iterate::Continue; + }, + response_.get()); + + cleanup(); + + EXPECT_EQ(1, test_server_->counter("cluster.cluster_0.ratelimit.ok")->value()); + EXPECT_EQ(nullptr, test_server_->counter("cluster.cluster_0.ratelimit.over_limit")); + EXPECT_EQ(nullptr, test_server_->counter("cluster.cluster_0.ratelimit.error")); +} + TEST_P(RatelimitIntegrationTest, OverLimit) { initiateClientConnection(); waitForRatelimitRequest(); - sendRateLimitResponse(envoy::service::ratelimit::v2::RateLimitResponse_Code_OVER_LIMIT); + sendRateLimitResponse(envoy::service::ratelimit::v2::RateLimitResponse_Code_OVER_LIMIT, + Http::HeaderMapImpl{}); waitForFailedUpstreamResponse(429); cleanup(); @@ -186,6 +226,31 @@ TEST_P(RatelimitIntegrationTest, OverLimit) { EXPECT_EQ(nullptr, test_server_->counter("cluster.cluster_0.ratelimit.error")); } +TEST_P(RatelimitIntegrationTest, OverLimitWithHeaders) { + initiateClientConnection(); + waitForRatelimitRequest(); + Http::TestHeaderMapImpl ratelimit_headers{ + {"x-ratelimit-limit", "1000"}, {"x-ratelimit-remaining", "0"}, {"retry-after", "33"}}; + sendRateLimitResponse(envoy::service::ratelimit::v2::RateLimitResponse_Code_OVER_LIMIT, + ratelimit_headers); + waitForFailedUpstreamResponse(429); + + ratelimit_headers.iterate( + [](const Http::HeaderEntry& entry, void* context) -> Http::HeaderMap::Iterate { + IntegrationStreamDecoder* response = static_cast(context); + Http::LowerCaseString lower_key{entry.key().c_str()}; + EXPECT_STREQ(entry.value().c_str(), response->headers().get(lower_key)->value().c_str()); + return Http::HeaderMap::Iterate::Continue; + }, + response_.get()); + + cleanup(); + + EXPECT_EQ(nullptr, test_server_->counter("cluster.cluster_0.ratelimit.ok")); + EXPECT_EQ(1, test_server_->counter("cluster.cluster_0.ratelimit.over_limit")->value()); + EXPECT_EQ(nullptr, test_server_->counter("cluster.cluster_0.ratelimit.error")); +} + TEST_P(RatelimitIntegrationTest, Error) { initiateClientConnection(); waitForRatelimitRequest();