Skip to content

Commit

Permalink
limit only on request failure
Browse files Browse the repository at this point in the history
  • Loading branch information
WnP committed Dec 30, 2019
1 parent b7eca7c commit afd31ad
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 6 deletions.
36 changes: 33 additions & 3 deletions api/http/security/rate_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,28 @@ import (

"github.com/g07cha/defender"
httperror "github.com/portainer/libhttp/error"
"github.com/portainer/portainer/api"
portainer "github.com/portainer/portainer/api"
)

// RatedResponseWriter implement ResponseWriter interface but also expose HTTP status code
type RatedResponseWriter struct {
http.ResponseWriter
StatusCode int
}

// NewRatedResponseWriter build a RatedResponseWriter from a ResponseWriter
func NewRatedResponseWriter(w http.ResponseWriter) *RatedResponseWriter {
// WriteHeader(int) is not called if our response implicitly returns 200 OK, so
// we default to that status code.
return &RatedResponseWriter{w, http.StatusOK}
}

// WriteHeader overload ResponseWriter method in order to store the status code
func (rrw *RatedResponseWriter) WriteHeader(code int) {
rrw.StatusCode = code
rrw.ResponseWriter.WriteHeader(code)
}

// RateLimiter represents an entity that manages request rate limiting
type RateLimiter struct {
*defender.Defender
Expand All @@ -25,15 +44,26 @@ func NewRateLimiter(maxRequests int, duration time.Duration, banDuration time.Du
}
}

// IsBanned return true if given IP is banned
func (limiter *RateLimiter) IsBanned(ip interface{}) bool {
c, ok := limiter.Client(ip)
return ok && c.Banned() || ok
}

// LimitAccess wraps current request with check if remote address does not goes above the defined limits
func (limiter *RateLimiter) LimitAccess(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := StripAddrPort(r.RemoteAddr)
if banned := limiter.Inc(ip); banned == true {
if limiter.IsBanned(ip) {
httperror.WriteError(w, http.StatusForbidden, "Access denied", portainer.ErrResourceAccessDenied)
return
}
rrw := NewRatedResponseWriter(w)
next.ServeHTTP(rrw, r)
if rrw.StatusCode >= 400 && limiter.Inc(ip) {
httperror.WriteError(w, http.StatusForbidden, "Access denied", portainer.ErrResourceAccessDenied)
return
}
next.ServeHTTP(w, r)
})
}

Expand Down
43 changes: 40 additions & 3 deletions api/http/security/rate_limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import (
"time"
)

func TestLimitAccess(t *testing.T) {
func TestLimitAccessOnRegularQuery(t *testing.T) {
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

t.Run("Request below the limit", func(t *testing.T) {
t.Run("Request below the limit and OK", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
rateLimiter := NewRateLimiter(10, 1*time.Second, 1*time.Hour)
Expand All @@ -26,7 +26,7 @@ func TestLimitAccess(t *testing.T) {
}
})

t.Run("Request above the limit", func(t *testing.T) {
t.Run("Request above the limit but still OK", func(t *testing.T) {
rateLimiter := NewRateLimiter(1, 1*time.Second, 1*time.Hour)
handler := rateLimiter.LimitAccess(testHandler)

Expand All @@ -38,6 +38,43 @@ func TestLimitAccess(t *testing.T) {
t.Fatal(err)
}

if status := resp.StatusCode; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusForbidden)
}
})
}

func TestLimitAccessFailureQuery(t *testing.T) {
testFailHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
})
t.Run("Request below the limit but not OK", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
rateLimiter := NewRateLimiter(10, 1*time.Second, 1*time.Hour)
handler := rateLimiter.LimitAccess(testFailHandler)

handler.ServeHTTP(rr, req)

if status := rr.Code; status != http.StatusBadRequest {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
})

t.Run("Request above the limit but not OK", func(t *testing.T) {
rateLimiter := NewRateLimiter(1, 1*time.Second, 1*time.Hour)
handler := rateLimiter.LimitAccess(testFailHandler)

ts := httptest.NewServer(handler)
defer ts.Close()
http.Get(ts.URL)
resp, err := http.Get(ts.URL)
if err != nil {
t.Fatal(err)
}

if status := resp.StatusCode; status != http.StatusForbidden {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusForbidden)
Expand Down

0 comments on commit afd31ad

Please sign in to comment.