-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
allows authorized clients to bypass rate limiter #599
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,10 +4,12 @@ import ( | |
"fmt" | ||
"net" | ||
"net/http" | ||
"strconv" | ||
"strings" | ||
"time" | ||
|
||
"github.com/gorilla/mux" | ||
"github.com/sethvargo/go-limiter" | ||
"github.com/sethvargo/go-limiter/httplimit" | ||
"github.com/sethvargo/go-limiter/memorystore" | ||
) | ||
|
@@ -22,6 +24,7 @@ type RateLimiterConfig struct { | |
type RateLimiterRouteConfig struct { | ||
MaxRPI uint64 | ||
Interval time.Duration | ||
APIKey string | ||
} | ||
|
||
// RateLimitController creates a new middleware to rate limit requests. | ||
|
@@ -47,19 +50,20 @@ func RateLimitController(cfg RateLimiterConfig) (mux.MiddlewareFunc, error) { | |
}, nil | ||
} | ||
|
||
func createRateLimiter(cfg RateLimiterRouteConfig, kf httplimit.KeyFunc) (*httplimit.Middleware, error) { | ||
func createRateLimiter(cfg RateLimiterRouteConfig, kf httplimit.KeyFunc) (*middleware, error) { | ||
defaultStore, err := memorystore.New(&memorystore.Config{ | ||
Tokens: cfg.MaxRPI, | ||
Interval: cfg.Interval, | ||
}) | ||
if err != nil { | ||
return nil, fmt.Errorf("creating default memory: %s", err) | ||
} | ||
m, err := httplimit.NewMiddleware(defaultStore, kf) | ||
if err != nil { | ||
return nil, fmt.Errorf("creating default httplimiter: %s", err) | ||
} | ||
return m, nil | ||
|
||
return &middleware{ | ||
store: defaultStore, | ||
keyFunc: kf, | ||
apiKey: cfg.APIKey, | ||
}, nil | ||
} | ||
|
||
func extractClientIP(r *http.Request) (string, error) { | ||
|
@@ -77,3 +81,62 @@ func extractClientIP(r *http.Request) (string, error) { | |
} | ||
return ip, nil | ||
} | ||
|
||
type middleware struct { | ||
store limiter.Store | ||
keyFunc httplimit.KeyFunc | ||
|
||
// clients with key are not affected by rate limiter | ||
apiKey string | ||
} | ||
|
||
// Handle returns the HTTP handler as a middleware. This handler calls Take() on | ||
// the store and sets the common rate limiting headers. If the take is | ||
// successful, the remaining middleware is called. If take is unsuccessful, the | ||
// middleware chain is halted and the function renders a 429 to the caller with | ||
// metadata about when it's safe to retry. | ||
func (m *middleware) Handle(next http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
ctx := r.Context() | ||
|
||
// Call the key function - if this fails, it's an internal server error. | ||
key, err := m.keyFunc(r) | ||
if err != nil { | ||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) | ||
return | ||
} | ||
|
||
// skip rate limiting checks if secret key is provided | ||
if key := r.Header.Get("Secret-Key"); key != "" && m.apiKey != "" { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we name the header "Api-Key" instead of "Secret-Key" to make it consistent with the code? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm maybe i get your point now. this is the special "API key" that can bypass the checks. so it makes sense. |
||
if strings.EqualFold(key, m.apiKey) { | ||
next.ServeHTTP(w, r) | ||
return | ||
} | ||
} | ||
|
||
// Take from the store. | ||
limit, remaining, reset, ok, err := m.store.Take(ctx, key) | ||
if err != nil { | ||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) | ||
return | ||
} | ||
|
||
resetTime := time.Unix(0, int64(reset)).UTC().Format(time.RFC1123) | ||
|
||
// Set headers (we do this regardless of whether the request is permitted). | ||
w.Header().Set("X-RateLimit-Limit", strconv.FormatUint(limit, 10)) | ||
w.Header().Set("X-RateLimit-Remaining", strconv.FormatUint(remaining, 10)) | ||
w.Header().Set("X-RateLimit-Reset", resetTime) | ||
|
||
// Fail if there were no tokens remaining. | ||
if !ok { | ||
w.Header().Set("Retry-After", resetTime) | ||
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) | ||
return | ||
} | ||
|
||
// If we got this far, we're allowed to continue, so call the next middleware | ||
// in the stack to continue processing. | ||
next.ServeHTTP(w, r) | ||
}) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ func TestLimit1IP(t *testing.T) { | |
callRPS int | ||
limitRPS int | ||
forwardedFor bool | ||
allow bool | ||
} | ||
|
||
tests := []testCase{ | ||
|
@@ -28,6 +29,9 @@ func TestLimit1IP(t *testing.T) { | |
|
||
{name: "success", callRPS: 100, limitRPS: 500, forwardedFor: false}, | ||
{name: "block-me", callRPS: 1000, limitRPS: 500, forwardedFor: false}, | ||
|
||
{name: "allow-me", callRPS: 1000, limitRPS: 500, forwardedFor: false, allow: true}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. adds two more test cases where rps is greater than limit but it never gets 429 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks good |
||
{name: "forwarded-allow-me", callRPS: 1000, limitRPS: 500, forwardedFor: true, allow: true}, | ||
} | ||
|
||
for _, tc := range tests { | ||
|
@@ -41,28 +45,35 @@ func TestLimit1IP(t *testing.T) { | |
Interval: time.Second, | ||
}, | ||
} | ||
rlcm, err := RateLimitController(cfg) | ||
require.NoError(t, err) | ||
rlc := rlcm(dummyHandler{}) | ||
|
||
ctx := context.Background() | ||
r, err := http.NewRequestWithContext(ctx, "", "", nil) | ||
require.NoError(t, err) | ||
|
||
ip := uuid.NewString() | ||
if tc.forwardedFor { | ||
r.Header.Set("X-Forwarded-For", uuid.NewString()) | ||
r.Header.Set("X-Forwarded-For", ip) | ||
} else { | ||
r.RemoteAddr = uuid.NewString() + ":1234" | ||
r.RemoteAddr = ip + ":1234" | ||
} | ||
|
||
if tc.allow { | ||
r.Header.Set("Secret-Key", "MYSECRETKEY") | ||
cfg.Default.APIKey = "MYSECRETKEY" | ||
} | ||
|
||
rlcm, err := RateLimitController(cfg) | ||
require.NoError(t, err) | ||
rlc := rlcm(dummyHandler{}) | ||
|
||
res := httptest.NewRecorder() | ||
|
||
// Verify that after some seconds making requests with the configured | ||
// callRPS with the limitRPS, we are getting the expected output: | ||
// - If callRPS < limitRPS, we never get a 429. | ||
// - If callRPS > limitRPS, we eventually should see a 429. | ||
assertFunc := require.Eventually | ||
if tc.callRPS < tc.limitRPS { | ||
if tc.callRPS < tc.limitRPS || tc.allow { | ||
assertFunc = require.Never | ||
} | ||
assertFunc(t, func() bool { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
implementation borrowed from https://github.com/sethvargo/go-limiter/blob/main/httplimit/middleware.go with an extra tweak