-
Notifications
You must be signed in to change notification settings - Fork 28
/
middleware.go
123 lines (107 loc) · 4.47 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// Copyright (c) 2019 Palantir Technologies. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ratelimit
import (
"context"
"net/http"
"sync/atomic"
"github.com/palantir/witchcraft-go-logging/wlog/svclog/svc1log"
"github.com/palantir/witchcraft-go-server/conjure/witchcraft/api/health"
"github.com/palantir/witchcraft-go-server/status/reporter"
"github.com/palantir/witchcraft-go-server/witchcraft/refreshable"
"github.com/palantir/witchcraft-go-server/wrouter"
)
// MatchFunc decides whether a request is counted against the in-flight limit.
// It receives the same request and wrouter.RequestVals that the middleware is
// invoked with.
type MatchFunc func(req *http.Request, vals wrouter.RequestVals) bool

var (
	// MatchReadOnly matches requests whose method is GET, HEAD, or OPTIONS.
	MatchReadOnly MatchFunc = func(req *http.Request, vals wrouter.RequestVals) bool {
		switch req.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions:
			return true
		default:
			return false
		}
	}

	// MatchMutating matches requests whose method is POST, PUT, DELETE, or PATCH.
	MatchMutating MatchFunc = func(req *http.Request, vals wrouter.RequestVals) bool {
		switch req.Method {
		case http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch:
			return true
		default:
			return false
		}
	}
)
// NewInFlightRequestLimitMiddleware returns a wrouter middleware that caps how many
// matching requests may be in flight at the same time.
//
// Only requests for which matches returns true are counted; a nil matches counts
// every request. The limit is read from the refreshable on every counted request,
// and once the number of counted in-flight requests exceeds it, further counted
// requests are answered with StatusTooManyRequests (429) instead of being passed on.
// A negative limit behaves like 0, i.e. every counted request is rejected.
//
// When healthcheck is non-nil it is reported HEALTHY immediately, switched to
// REPAIRING while requests are being throttled, and returned to HEALTHY once the
// number of in-flight requests falls back below the limit.
//
// TODO: We should set the Retry-After header based on how many requests we're rejecting.
// Maybe enqueue requests in a channel for a few seconds in case other requests return quickly?
func NewInFlightRequestLimitMiddleware(limit refreshable.Int, matches MatchFunc, healthcheck reporter.HealthComponent) wrouter.RouteHandlerMiddleware {
	if healthcheck != nil {
		// Start out healthy; the limiter only changes the state once it throttles.
		healthcheck.Healthy()
	}
	return (&limiter{
		Limit:   limit,
		Matches: matches,
		Health:  healthcheck,
	}).ServeHTTP
}
// limiter holds the state behind the middleware returned by
// NewInFlightRequestLimitMiddleware.
type limiter struct {
	// current is the number of matched requests currently in flight, including
	// requests that are in the middle of being rejected. It is only accessed via
	// sync/atomic and MUST remain the first field: 64-bit atomic operations need
	// 8-byte alignment, which on 386, ARM and 32-bit MIPS is only guaranteed for
	// the first word of an allocated struct. Placed after the interface and func
	// fields it would land at offset 20 on those platforms and atomic.AddInt64
	// would panic.
	current int64

	// Limit is the maximum number of matched requests allowed in flight at once.
	Limit refreshable.Int

	// Matches selects which requests are counted; nil counts every request.
	Matches MatchFunc

	// Health, when non-nil, is updated to reflect whether requests are being throttled.
	Health reporter.HealthComponent
}

// inFlightThrottledMessage is used as the 429 response body, the REPAIRING health
// message, and the warning logged for each throttled request.
const inFlightThrottledMessage = "Throttling due to too many in-flight requests"
// ServeHTTP is the wrouter middleware body. Requests rejected by l.Matches are
// forwarded untouched. Every other request is counted for its whole lifetime,
// and is either rejected with 429 when it pushes the count over the limit or
// forwarded to next.
func (l *limiter) ServeHTTP(rw http.ResponseWriter, req *http.Request, reqVals wrouter.RequestVals, next wrouter.RouteRequestHandler) {
	if l.Matches != nil && !l.Matches(req, reqVals) {
		next(rw, req, reqVals)
		return
	}

	ctx := req.Context()
	rejected := l.increment(ctx)
	// Always undo the increment, whether the request is rejected, completes, or panics.
	defer l.decrement(ctx)

	if rejected {
		// Fail fast so clients fail over or back off exponentially.
		http.Error(rw, inFlightThrottledMessage, http.StatusTooManyRequests)
		return
	}
	next(rw, req, reqVals)
}
// increment records one more in-flight request and reports whether it pushed the
// counter past the limit, in which case the caller must reject the request.
// When throttling, l.Health (if non-nil) is moved to REPAIRING unless it is
// already there, and a warning with the current count and limit is logged.
// The counter is bumped even for throttled requests, so every call must be
// paired with a call to decrement.
func (l *limiter) increment(ctx context.Context) (throttled bool) {
	inFlight := atomic.AddInt64(&l.current, 1)
	ceiling := l.limit()
	throttled = inFlight > ceiling
	if !throttled {
		return throttled
	}

	if hc := l.Health; hc != nil && hc.Status() != health.HealthStateRepairing {
		// SetHealth wants a pointer, so copy the constant into an addressable local.
		reason := inFlightThrottledMessage
		hc.SetHealth(health.HealthStateRepairing, &reason, nil)
	}

	logger := svc1log.FromContext(ctx)
	logger.Warn(inFlightThrottledMessage,
		svc1log.SafeParam("current", inFlight),
		svc1log.SafeParam("limit", ceiling))
	return throttled
}
// decrement subtracts 1 from the current counter. If the new value is strictly
// below the limit and l.Health is non-nil, l.Health is set to HEALTHY (unless it
// is already in that state). A counter equal to the limit leaves the health
// state untouched.
//
// ctx is currently unused; it mirrors the signature of increment.
func (l *limiter) decrement(ctx context.Context) {
	current := atomic.AddInt64(&l.current, -1)
	// Test l.Health first so the refreshable limit is only read when there is a
	// health component to update.
	if l.Health == nil || current >= l.limit() {
		return
	}
	if l.Health.Status() != health.HealthStateHealthy {
		l.Health.Healthy()
	}
}
// limit returns the value currently configured in l.Limit as an int64, clamping
// negative values to zero so that a negative setting throttles every counted
// request.
func (l *limiter) limit() int64 {
	if configured := l.Limit.CurrentInt(); configured > 0 {
		return int64(configured)
	}
	return 0
}