Skip to content

Commit

Permalink
ratelimit: Add ratelimit custom response headers (envoyproxy#4015)
Browse files Browse the repository at this point in the history
  - Ability to add custom response headers from ratelimit
    service/filter
  - For both (LimitStatus::OK and LimitStatus::OverLimit) custom
    headers are added if RLS service sends headers
  - For LimitStatus:OK, we temporarily store the headers and add
    them to the response (via Filter::encodeHeaders())

*Risk Level*: Low

*Testing*: unit and integration tests added. Verified with modified
 github.com/lyft/ratelimit service. Passes "bazel test //test/..." in
 Linux

Signed-off-by: Suresh Kumar <suresh@freshdesk.com>
  • Loading branch information
surki authored and Matt Klein committed Aug 16, 2018
1 parent 3062874 commit 71152b7
Show file tree
Hide file tree
Showing 19 changed files with 370 additions and 38 deletions.
2 changes: 2 additions & 0 deletions api/envoy/service/ratelimit/v2/BUILD
Expand Up @@ -7,6 +7,7 @@ api_proto_library_internal(
srcs = ["rls.proto"],
has_services = 1,
deps = [
"//envoy/api/v2/core:base",
"//envoy/api/v2/core:grpc_service",
"//envoy/api/v2/ratelimit",
],
Expand All @@ -16,6 +17,7 @@ api_go_grpc_library(
name = "rls",
proto = ":rls",
deps = [
"//envoy/api/v2/core:base_go_proto",
"//envoy/api/v2/core:grpc_service_go_proto",
"//envoy/api/v2/ratelimit:ratelimit_go_proto",
],
Expand Down
3 changes: 3 additions & 0 deletions api/envoy/service/ratelimit/v2/rls.proto
Expand Up @@ -3,6 +3,7 @@ syntax = "proto3";
package envoy.service.ratelimit.v2;
option go_package = "v2";

import "envoy/api/v2/core/base.proto";
import "envoy/api/v2/ratelimit/ratelimit.proto";

import "validate/validate.proto";
Expand Down Expand Up @@ -75,4 +76,6 @@ message RateLimitResponse {
// in the RateLimitRequest. This can be used by the caller to determine which individual
// descriptors failed and/or what the currently configured limits are for all of them.
repeated DescriptorStatus statuses = 2;
// A list of headers to add to the response
repeated envoy.api.v2.core.HeaderValue headers = 3;
}
5 changes: 3 additions & 2 deletions include/envoy/ratelimit/ratelimit.h
Expand Up @@ -33,9 +33,10 @@ class RequestCallbacks {
virtual ~RequestCallbacks() {}

/**
* Called when a limit request is complete. The resulting status is supplied.
* Called when a limit request is complete. The resulting status and
* response headers are supplied.
*/
virtual void complete(LimitStatus status) PURE;
virtual void complete(LimitStatus status, Http::HeaderMapPtr&& headers) PURE;
};

/**
Expand Down
14 changes: 14 additions & 0 deletions source/common/http/header_utility.cc
Expand Up @@ -2,6 +2,7 @@

#include "common/common/utility.h"
#include "common/config/rds_json.h"
#include "common/http/header_map_impl.h"
#include "common/protobuf/utility.h"

#include "absl/strings/match.h"
Expand Down Expand Up @@ -115,5 +116,18 @@ bool HeaderUtility::matchHeaders(const Http::HeaderMap& request_headers,
return match != header_data.invert_match_;
}

void HeaderUtility::addHeaders(Http::HeaderMap& headers, const Http::HeaderMap& headers_to_add) {
headers_to_add.iterate(
[](const Http::HeaderEntry& header, void* context) -> Http::HeaderMap::Iterate {
Http::HeaderString k;
k.setCopy(header.key().c_str(), header.key().size());
Http::HeaderString v;
v.setCopy(header.value().c_str(), header.value().size());
static_cast<Http::HeaderMapImpl*>(context)->addViaMove(std::move(k), std::move(v));
return Http::HeaderMap::Iterate::Continue;
},
&headers);
}

} // namespace Http
} // namespace Envoy
7 changes: 7 additions & 0 deletions source/common/http/header_utility.h
Expand Up @@ -44,6 +44,13 @@ class HeaderUtility {
const std::vector<HeaderData>& config_headers);

static bool matchHeaders(const Http::HeaderMap& request_headers, const HeaderData& config_header);

/**
* Add headers from one HeaderMap to another
* @param headers target where headers will be added
* @param headers_to_add supplies the headers to be added
*/
static void addHeaders(Http::HeaderMap& headers, const Http::HeaderMap& headers_to_add);
};
} // namespace Http
} // namespace Envoy
14 changes: 12 additions & 2 deletions source/common/ratelimit/ratelimit_impl.cc
Expand Up @@ -9,6 +9,7 @@
#include "envoy/stats/scope.h"

#include "common/common/assert.h"
#include "common/http/header_map_impl.h"
#include "common/http/headers.h"

namespace Envoy {
Expand Down Expand Up @@ -67,14 +68,23 @@ void GrpcClientImpl::onSuccess(
span.setTag(Constants::get().TraceStatus, Constants::get().TraceOk);
}

callbacks_->complete(status);
if (response->headers_size()) {
Http::HeaderMapPtr headers = std::make_unique<Http::HeaderMapImpl>();
for (const auto& h : response->headers()) {
headers->addCopy(Http::LowerCaseString(h.key()), h.value());
}
callbacks_->complete(status, std::move(headers));
} else {
callbacks_->complete(status, nullptr);
}

callbacks_ = nullptr;
}

void GrpcClientImpl::onFailure(Grpc::Status::GrpcStatus status, const std::string&,
Tracing::Span&) {
ASSERT(status != Grpc::Status::GrpcStatus::Ok);
callbacks_->complete(LimitStatus::Error);
callbacks_->complete(LimitStatus::Error, nullptr);
callbacks_ = nullptr;
}

Expand Down
2 changes: 1 addition & 1 deletion source/common/ratelimit/ratelimit_impl.h
Expand Up @@ -87,7 +87,7 @@ class NullClientImpl : public Client {
void cancel() override {}
void limit(RequestCallbacks& callbacks, const std::string&, const std::vector<Descriptor>&,
Tracing::Span&) override {
callbacks.complete(LimitStatus::OK);
callbacks.complete(LimitStatus::OK, nullptr);
}
};

Expand Down
2 changes: 1 addition & 1 deletion source/extensions/filters/http/ratelimit/config.cc
Expand Up @@ -26,7 +26,7 @@ Http::FilterFactoryCb RateLimitFilterConfig::createFilterFactoryFromProtoTyped(
const uint32_t timeout_ms = PROTOBUF_GET_MS_OR_DEFAULT(proto_config, timeout, 20);
return
[filter_config, timeout_ms, &context](Http::FilterChainFactoryCallbacks& callbacks) -> void {
callbacks.addStreamDecoderFilter(std::make_shared<Filter>(
callbacks.addStreamFilter(std::make_shared<Filter>(
filter_config, context.rateLimitClient(std::chrono::milliseconds(timeout_ms))));
};
}
Expand Down
33 changes: 31 additions & 2 deletions source/extensions/filters/http/ratelimit/ratelimit.cc
Expand Up @@ -9,6 +9,7 @@
#include "common/common/enum_to_int.h"
#include "common/common/fmt.h"
#include "common/http/codes.h"
#include "common/http/header_utility.h"
#include "common/router/config_impl.h"

namespace Envoy {
Expand Down Expand Up @@ -87,15 +88,35 @@ void Filter::setDecoderFilterCallbacks(Http::StreamDecoderFilterCallbacks& callb
callbacks_ = &callbacks;
}

Http::FilterHeadersStatus Filter::encode100ContinueHeaders(Http::HeaderMap&) {
return Http::FilterHeadersStatus::Continue;
}

Http::FilterHeadersStatus Filter::encodeHeaders(Http::HeaderMap& headers, bool) {
addHeaders(headers);
return Http::FilterHeadersStatus::Continue;
}

Http::FilterDataStatus Filter::encodeData(Buffer::Instance&, bool) {
return Http::FilterDataStatus::Continue;
}

Http::FilterTrailersStatus Filter::encodeTrailers(Http::HeaderMap&) {
return Http::FilterTrailersStatus::Continue;
}

void Filter::setEncoderFilterCallbacks(Http::StreamEncoderFilterCallbacks&) {}

void Filter::onDestroy() {
if (state_ == State::Calling) {
state_ = State::Complete;
client_->cancel();
}
}

void Filter::complete(RateLimit::LimitStatus status) {
void Filter::complete(RateLimit::LimitStatus status, Http::HeaderMapPtr&& headers) {
state_ = State::Complete;
headers_to_add_ = std::move(headers);

switch (status) {
case RateLimit::LimitStatus::OK:
Expand Down Expand Up @@ -123,7 +144,8 @@ void Filter::complete(RateLimit::LimitStatus status) {
if (status == RateLimit::LimitStatus::OverLimit &&
config_->runtime().snapshot().featureEnabled("ratelimit.http_filter_enforcing", 100)) {
state_ = State::Responded;
callbacks_->sendLocalReply(Http::Code::TooManyRequests, "", nullptr);
callbacks_->sendLocalReply(Http::Code::TooManyRequests, "",
[this](Http::HeaderMap& headers) { addHeaders(headers); });
callbacks_->requestInfo().setResponseFlag(RequestInfo::ResponseFlag::RateLimited);
} else if (!initiating_call_) {
callbacks_->continueDecoding();
Expand All @@ -147,6 +169,13 @@ void Filter::populateRateLimitDescriptors(const Router::RateLimitPolicy& rate_li
}
}

void Filter::addHeaders(Http::HeaderMap& headers) {
if (headers_to_add_) {
Http::HeaderUtility::addHeaders(headers, *headers_to_add_);
headers_to_add_ = nullptr;
}
}

} // namespace RateLimitFilter
} // namespace HttpFilters
} // namespace Extensions
Expand Down
13 changes: 11 additions & 2 deletions source/extensions/filters/http/ratelimit/ratelimit.h
Expand Up @@ -74,7 +74,7 @@ typedef std::shared_ptr<FilterConfig> FilterConfigSharedPtr;
* HTTP rate limit filter. Depending on the route configuration, this filter calls the global
* rate limiting service before allowing further filter iteration.
*/
class Filter : public Http::StreamDecoderFilter, public RateLimit::RequestCallbacks {
class Filter : public Http::StreamFilter, public RateLimit::RequestCallbacks {
public:
Filter(FilterConfigSharedPtr config, RateLimit::ClientPtr&& client)
: config_(config), client_(std::move(client)) {}
Expand All @@ -88,15 +88,23 @@ class Filter : public Http::StreamDecoderFilter, public RateLimit::RequestCallba
Http::FilterTrailersStatus decodeTrailers(Http::HeaderMap& trailers) override;
void setDecoderFilterCallbacks(Http::StreamDecoderFilterCallbacks& callbacks) override;

// Http::StreamEncoderFilter
Http::FilterHeadersStatus encode100ContinueHeaders(Http::HeaderMap& headers) override;
Http::FilterHeadersStatus encodeHeaders(Http::HeaderMap& headers, bool end_stream) override;
Http::FilterDataStatus encodeData(Buffer::Instance& data, bool end_stream) override;
Http::FilterTrailersStatus encodeTrailers(Http::HeaderMap& trailers) override;
void setEncoderFilterCallbacks(Http::StreamEncoderFilterCallbacks& callbacks) override;

// RateLimit::RequestCallbacks
void complete(RateLimit::LimitStatus status) override;
void complete(RateLimit::LimitStatus status, Http::HeaderMapPtr&& headers) override;

private:
void initiateCall(const Http::HeaderMap& headers);
void populateRateLimitDescriptors(const Router::RateLimitPolicy& rate_limit_policy,
std::vector<RateLimit::Descriptor>& descriptors,
const Router::RouteEntry* route_entry,
const Http::HeaderMap& headers) const;
void addHeaders(Http::HeaderMap& headers);

enum class State { NotStarted, Calling, Complete, Responded };

Expand All @@ -106,6 +114,7 @@ class Filter : public Http::StreamDecoderFilter, public RateLimit::RequestCallba
State state_{State::NotStarted};
Upstream::ClusterInfoConstSharedPtr cluster_;
bool initiating_call_{};
Http::HeaderMapPtr headers_to_add_;
};

} // namespace RateLimitFilter
Expand Down
2 changes: 1 addition & 1 deletion source/extensions/filters/network/ratelimit/ratelimit.cc
Expand Up @@ -69,7 +69,7 @@ void Filter::onEvent(Network::ConnectionEvent event) {
}
}

void Filter::complete(RateLimit::LimitStatus status) {
void Filter::complete(RateLimit::LimitStatus status, Http::HeaderMapPtr&&) {
status_ = Status::Complete;
config_->stats().active_.dec();

Expand Down
4 changes: 2 additions & 2 deletions source/extensions/filters/network/ratelimit/ratelimit.h
Expand Up @@ -88,7 +88,7 @@ class Filter : public Network::ReadFilter,
void onBelowWriteBufferLowWatermark() override {}

// RateLimit::RequestCallbacks
void complete(RateLimit::LimitStatus status) override;
void complete(RateLimit::LimitStatus status, Http::HeaderMapPtr&& headers) override;

private:
enum class Status { NotStarted, Calling, Complete };
Expand All @@ -99,7 +99,7 @@ class Filter : public Network::ReadFilter,
Status status_{Status::NotStarted};
bool calling_limit_{};
};
}
} // namespace RateLimitFilter
} // namespace NetworkFilters
} // namespace Extensions
} // namespace Envoy
16 changes: 16 additions & 0 deletions test/common/http/header_utility_test.cc
Expand Up @@ -403,5 +403,21 @@ invert_match: true
EXPECT_FALSE(HeaderUtility::matchHeaders(unmatching_headers, header_data));
}

TEST(HeaderAddTest, HeaderAdd) {
TestHeaderMapImpl headers{{"myheader1", "123value"}};
TestHeaderMapImpl headers_to_add{{"myheader2", "456value"}};

HeaderUtility::addHeaders(headers, headers_to_add);

headers_to_add.iterate(
[](const Http::HeaderEntry& entry, void* context) -> Http::HeaderMap::Iterate {
TestHeaderMapImpl* headers = static_cast<TestHeaderMapImpl*>(context);
Http::LowerCaseString lower_key{entry.key().c_str()};
EXPECT_STREQ(entry.value().c_str(), headers->get(lower_key)->value().c_str());
return Http::HeaderMap::Iterate::Continue;
},
&headers);
}

} // namespace Http
} // namespace Envoy
2 changes: 1 addition & 1 deletion test/common/network/filter_manager_impl_test.cc
Expand Up @@ -206,7 +206,7 @@ TEST_F(NetworkFilterManagerTest, RateLimitAndTcpProxy) {
EXPECT_CALL(factory_context.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _))
.WillOnce(Return(&conn_pool));

request_callbacks->complete(RateLimit::LimitStatus::OK);
request_callbacks->complete(RateLimit::LimitStatus::OK, nullptr);

conn_pool.poolReady(upstream_connection);

Expand Down
14 changes: 9 additions & 5 deletions test/common/ratelimit/ratelimit_impl_test.cc
Expand Up @@ -29,7 +29,11 @@ namespace RateLimit {

class MockRequestCallbacks : public RequestCallbacks {
public:
MOCK_METHOD1(complete, void(LimitStatus status));
void complete(LimitStatus status, Http::HeaderMapPtr&& headers) {
complete_(status, headers.get());
}

MOCK_METHOD2(complete_, void(LimitStatus status, const Http::HeaderMap* headers));
};

// TODO(junr03): legacy rate limit is deprecated. Remove the boolean parameter after 1.8.0.
Expand Down Expand Up @@ -91,7 +95,7 @@ TEST_P(RateLimitGrpcClientTest, Basic) {
response.reset(new envoy::service::ratelimit::v2::RateLimitResponse());
response->set_overall_code(envoy::service::ratelimit::v2::RateLimitResponse_Code_OVER_LIMIT);
EXPECT_CALL(span_, setTag("ratelimit_status", "over_limit"));
EXPECT_CALL(request_callbacks_, complete(LimitStatus::OverLimit));
EXPECT_CALL(request_callbacks_, complete_(LimitStatus::OverLimit, _));
client_->onSuccess(std::move(response), span_);
}

Expand All @@ -110,7 +114,7 @@ TEST_P(RateLimitGrpcClientTest, Basic) {
response.reset(new envoy::service::ratelimit::v2::RateLimitResponse());
response->set_overall_code(envoy::service::ratelimit::v2::RateLimitResponse_Code_OK);
EXPECT_CALL(span_, setTag("ratelimit_status", "ok"));
EXPECT_CALL(request_callbacks_, complete(LimitStatus::OK));
EXPECT_CALL(request_callbacks_, complete_(LimitStatus::OK, _));
client_->onSuccess(std::move(response), span_);
}

Expand All @@ -127,7 +131,7 @@ TEST_P(RateLimitGrpcClientTest, Basic) {
Tracing::NullSpan::instance());

response.reset(new envoy::service::ratelimit::v2::RateLimitResponse());
EXPECT_CALL(request_callbacks_, complete(LimitStatus::Error));
EXPECT_CALL(request_callbacks_, complete_(LimitStatus::Error, _));
client_->onFailure(Grpc::Status::Unknown, "", span_);
}
}
Expand Down Expand Up @@ -179,7 +183,7 @@ TEST(RateLimitNullFactoryTest, Basic) {
NullFactoryImpl factory;
ClientPtr client = factory.create(absl::optional<std::chrono::milliseconds>());
MockRequestCallbacks request_callbacks;
EXPECT_CALL(request_callbacks, complete(LimitStatus::OK));
EXPECT_CALL(request_callbacks, complete_(LimitStatus::OK, _));
client->limit(request_callbacks, "foo", {{{{"foo", "bar"}}}}, Tracing::NullSpan::instance());
client->cancel();
}
Expand Down
6 changes: 3 additions & 3 deletions test/extensions/filters/http/ratelimit/config_test.cc
Expand Up @@ -36,7 +36,7 @@ TEST(RateLimitFilterConfigTest, RateLimitFilterCorrectJson) {
RateLimitFilterConfig factory;
Http::FilterFactoryCb cb = factory.createFilterFactory(*json_config, "stats", context);
Http::MockFilterChainFactoryCallbacks filter_callback;
EXPECT_CALL(filter_callback, addStreamDecoderFilter(_));
EXPECT_CALL(filter_callback, addStreamFilter(_));
cb(filter_callback);
}

Expand All @@ -56,7 +56,7 @@ TEST(RateLimitFilterConfigTest, RateLimitFilterCorrectProto) {
RateLimitFilterConfig factory;
Http::FilterFactoryCb cb = factory.createFilterFactoryFromProto(proto_config, "stats", context);
Http::MockFilterChainFactoryCallbacks filter_callback;
EXPECT_CALL(filter_callback, addStreamDecoderFilter(_));
EXPECT_CALL(filter_callback, addStreamFilter(_));
cb(filter_callback);
}

Expand All @@ -79,7 +79,7 @@ TEST(RateLimitFilterConfigTest, RateLimitFilterEmptyProto) {

Http::FilterFactoryCb cb = factory.createFilterFactory(*json_config, "stats", context);
Http::MockFilterChainFactoryCallbacks filter_callback;
EXPECT_CALL(filter_callback, addStreamDecoderFilter(_));
EXPECT_CALL(filter_callback, addStreamFilter(_));
cb(filter_callback);
}

Expand Down

0 comments on commit 71152b7

Please sign in to comment.