Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid retries when any data was written to the backend #3285

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
78 changes: 39 additions & 39 deletions middlewares/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package middlewares

import (
"bufio"
"context"
"io/ioutil"
"net"
"net/http"
"net/http/httptrace"

"github.com/containous/traefik/log"
)
Expand Down Expand Up @@ -40,11 +40,24 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) {

attempts := 1
for {
netErrorOccurred := false
// We pass in a pointer to netErrorOccurred so that we can set it to true on network errors
// when proxying the HTTP requests to the backends. This happens in the custom RecordingErrorHandler.
newCtx := context.WithValue(r.Context(), defaultNetErrCtxKey, &netErrorOccurred)
retryResponseWriter := newRetryResponseWriter(rw, attempts >= retry.attempts, &netErrorOccurred)
attemptsExhausted := attempts >= retry.attempts
// Websocket requests can't be retried at this point in time.
// This is due to the fact that gorilla/websocket doesn't use the request
// context and so we don't get httptrace information.
// Websocket clients should however retry on their own anyway.
shouldRetry := !attemptsExhausted && !isWebsocketRequest(r)
retryResponseWriter := newRetryResponseWriter(rw, shouldRetry)

// Disable retries when the backend already received request data
trace := &httptrace.ClientTrace{
WroteHeaders: func() {
retryResponseWriter.DisableRetries()
},
WroteRequest: func(httptrace.WroteRequestInfo) {
retryResponseWriter.DisableRetries()
},
}
newCtx := httptrace.WithClientTrace(r.Context(), trace)

retry.next.ServeHTTP(retryResponseWriter, r.WithContext(newCtx))
if !retryResponseWriter.ShouldRetry() {
Expand All @@ -57,31 +70,6 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}
}

// netErrorCtxKey is a dedicated string type used for context keys, so that the
// key defined below cannot collide with plain string keys set by other packages.
type netErrorCtxKey string

// defaultNetErrCtxKey is the context key under which a *bool is stored; that
// bool is set to true to record that a network error occurred.
var defaultNetErrCtxKey netErrorCtxKey = "NetErrCtxKey"

// NetErrorRecorder is an interface to record network errors.
type NetErrorRecorder interface {
// Record can be used to signal to the retry middleware that a network error
// happened and that the request should therefore be retried. The signal is
// carried through the given request context.
Record(ctx context.Context)
}

// DefaultNetErrorRecorder is the stock NetErrorRecorder; it carries no state of its own.
type DefaultNetErrorRecorder struct{}

// Record flags a network error on ctx. It looks up the value stored under
// defaultNetErrCtxKey and, when that value is a *bool, sets the pointed-to
// bool to true. Any other value (including a missing one) is left untouched.
func (DefaultNetErrorRecorder) Record(ctx context.Context) {
	flag, ok := ctx.Value(defaultNetErrCtxKey).(*bool)
	if !ok {
		// Nothing usable was registered for this request; there is nothing to flag.
		return
	}
	*flag = true
}

// RetryListener is used to inform about retry attempts.
type RetryListener interface {
// Retried will be called when a retry happens, with the request attempt passed to it.
Expand All @@ -104,13 +92,13 @@ type retryResponseWriter interface {
http.ResponseWriter
http.Flusher
ShouldRetry() bool
DisableRetries()
}

func newRetryResponseWriter(rw http.ResponseWriter, attemptsExhausted bool, netErrorOccured *bool) retryResponseWriter {
func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter {
responseWriter := &retryResponseWriterWithoutCloseNotify{
responseWriter: rw,
attemptsExhausted: attemptsExhausted,
netErrorOccured: netErrorOccured,
responseWriter: rw,
shouldRetry: shouldRetry,
}
if _, ok := rw.(http.CloseNotifier); ok {
return &retryResponseWriterWithCloseNotify{responseWriter}
Expand All @@ -119,13 +107,16 @@ func newRetryResponseWriter(rw http.ResponseWriter, attemptsExhausted bool, netE
}

type retryResponseWriterWithoutCloseNotify struct {
responseWriter http.ResponseWriter
attemptsExhausted bool
netErrorOccured *bool
responseWriter http.ResponseWriter
shouldRetry bool
}

func (rr *retryResponseWriterWithoutCloseNotify) ShouldRetry() bool {
return *rr.netErrorOccured && !rr.attemptsExhausted
return rr.shouldRetry
}

// DisableRetries marks this response writer as non-retryable by clearing its
// shouldRetry flag. Within this file it is invoked from the httptrace hooks
// installed in Retry.ServeHTTP (WroteHeaders / WroteRequest) and from
// WriteHeader when a 503 status is written.
func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() {
rr.shouldRetry = false
}

func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header {
Expand All @@ -143,6 +134,15 @@ func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error)
}

func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) {
if rr.ShouldRetry() && code == http.StatusServiceUnavailable {
// We get a 503 HTTP Status Code when there is no backend server in the pool
// to which the request could be sent. Also, note that rr.ShouldRetry()
// will never return true in case there was a connection established to
// the backend server and so we can be sure that the 503 was produced
// inside Traefik already and we don't have to retry in this case.
rr.DisableRetries()
}

if rr.ShouldRetry() {
return
}
Expand Down
190 changes: 116 additions & 74 deletions middlewares/retry_test.go
Original file line number Diff line number Diff line change
@@ -1,91 +1,155 @@
package middlewares

import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"

"github.com/containous/traefik/testhelpers"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/roundrobin"
)

func TestRetry(t *testing.T) {
testCases := []struct {
failAtCalls []int
attempts int
responseStatus int
listener *countingRetryListener
retriedCount int
desc string
maxRequestAttempts int
wantRetryAttempts int
wantResponseStatus int
amountFaultyEndpoints int
isWebsocketHandshakeRequest bool
}{
{
failAtCalls: []int{1, 2},
attempts: 3,
responseStatus: http.StatusOK,
listener: &countingRetryListener{},
retriedCount: 2,
desc: "no retry on success",
maxRequestAttempts: 1,
wantRetryAttempts: 0,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 0,
},
{
desc: "no retry when max request attempts is one",
maxRequestAttempts: 1,
wantRetryAttempts: 0,
wantResponseStatus: http.StatusInternalServerError,
amountFaultyEndpoints: 1,
},
{
desc: "one retry when one server is faulty",
maxRequestAttempts: 2,
wantRetryAttempts: 1,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 1,
},
{
desc: "two retries when two servers are faulty",
maxRequestAttempts: 3,
wantRetryAttempts: 2,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 2,
},
{
desc: "max attempts exhausted delivers the 5xx response",
maxRequestAttempts: 3,
wantRetryAttempts: 2,
wantResponseStatus: http.StatusInternalServerError,
amountFaultyEndpoints: 3,
},
{
failAtCalls: []int{1, 2},
attempts: 2,
responseStatus: http.StatusBadGateway,
listener: &countingRetryListener{},
retriedCount: 1,
desc: "websocket request should not be retried",
maxRequestAttempts: 3,
wantRetryAttempts: 0,
wantResponseStatus: http.StatusBadGateway,
amountFaultyEndpoints: 1,
isWebsocketHandshakeRequest: true,
},
}

backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("OK"))
}))

forwarder, err := forward.New()
if err != nil {
t.Fatalf("Error creating forwarder: %s", err)
}

for _, tc := range testCases {
// bind tc locally
tc := tc
tcName := fmt.Sprintf("FailAtCalls(%v) RetryAttempts(%v)", tc.failAtCalls, tc.attempts)

t.Run(tcName, func(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()

var httpHandler http.Handler = &networkFailingHTTPHandler{failAtCalls: tc.failAtCalls, netErrorRecorder: &DefaultNetErrorRecorder{}}
httpHandler = NewRetry(tc.attempts, httpHandler, tc.listener)
loadBalancer, err := roundrobin.New(forwarder)
if err != nil {
t.Fatalf("Error creating load balancer: %s", err)
}

basePort := 33444
for i := 0; i < tc.amountFaultyEndpoints; i++ {
// 192.0.2.0 is a non-routable IP for testing purposes.
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
// We only use the port specification here because the URL is used as identifier
// in the load balancer and using the exact same URL would not add a new server.
loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
}

// add the functioning server to the end of the load balancer list
loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))

retryListener := &countingRetryListener{}
retry := NewRetry(tc.maxRequestAttempts, loadBalancer, retryListener)

recorder := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "http://localhost:3000/ok", ioutil.NopCloser(nil))
if err != nil {
t.Fatalf("could not create request: %+v", err)
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)

if tc.isWebsocketHandshakeRequest {
req.Header.Add("Connection", "Upgrade")
req.Header.Add("Upgrade", "websocket")
}

httpHandler.ServeHTTP(recorder, req)
retry.ServeHTTP(recorder, req)

if tc.responseStatus != recorder.Code {
t.Errorf("wrong status code %d, want %d", recorder.Code, tc.responseStatus)
if tc.wantResponseStatus != recorder.Code {
t.Errorf("got status code %d, want %d", recorder.Code, tc.wantResponseStatus)
}
if tc.retriedCount != tc.listener.timesCalled {
t.Errorf("RetryListener called %d times, want %d times", tc.listener.timesCalled, tc.retriedCount)
if tc.wantRetryAttempts != retryListener.timesCalled {
t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, tc.wantRetryAttempts)
}
})
}
}

func TestDefaultNetErrorRecorderSuccess(t *testing.T) {
boolNetErrorOccurred := false
recorder := DefaultNetErrorRecorder{}
recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &boolNetErrorOccurred))
if !boolNetErrorOccurred {
t.Errorf("got %v after recording net error, wanted %v", boolNetErrorOccurred, true)
func TestRetryEmptyServerList(t *testing.T) {
forwarder, err := forward.New()
if err != nil {
t.Fatalf("Error creating forwarder: %s", err)
}
}

func TestDefaultNetErrorRecorderInvalidValueType(t *testing.T) {
stringNetErrorOccured := "nonsense"
recorder := DefaultNetErrorRecorder{}
recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &stringNetErrorOccured))
if stringNetErrorOccured != "nonsense" {
t.Errorf("got %v after recording net error, wanted %v", stringNetErrorOccured, "nonsense")
loadBalancer, err := roundrobin.New(forwarder)
if err != nil {
t.Fatalf("Error creating load balancer: %s", err)
}
}

func TestDefaultNetErrorRecorderNilValue(t *testing.T) {
nilNetErrorOccured := interface{}(nil)
recorder := DefaultNetErrorRecorder{}
recorder.Record(context.WithValue(context.Background(), defaultNetErrCtxKey, &nilNetErrorOccured))
if nilNetErrorOccured != interface{}(nil) {
t.Errorf("got %v after recording net error, wanted %v", nilNetErrorOccured, interface{}(nil))
// The EmptyBackendHandler middleware ensures that there is a 503
// response status set when there is no backend server in the pool.
next := NewEmptyBackendHandler(loadBalancer)

retryListener := &countingRetryListener{}
retry := NewRetry(3, next, retryListener)

recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)

retry.ServeHTTP(recorder, req)

const wantResponseStatus = http.StatusServiceUnavailable
if wantResponseStatus != recorder.Code {
t.Errorf("got status code %d, want %d", recorder.Code, wantResponseStatus)
}
const wantRetryAttempts = 0
if wantRetryAttempts != retryListener.timesCalled {
t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, wantRetryAttempts)
}
}

Expand All @@ -99,33 +163,11 @@ func TestRetryListeners(t *testing.T) {
for _, retryListener := range retryListeners {
listener := retryListener.(*countingRetryListener)
if listener.timesCalled != 2 {
t.Errorf("retry listener was called %d times, want %d", listener.timesCalled, 2)
t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2)
}
}
}

// networkFailingHTTPHandler is a test http.Handler that simulates network
// failures on selected invocations so that retry behaviour can be exercised.
type networkFailingHTTPHandler struct {
	netErrorRecorder NetErrorRecorder // notified on every simulated failure
	failAtCalls      []int            // 1-based call numbers that should fail
	callNumber       int              // number of ServeHTTP calls seen so far
}

// ServeHTTP counts the call and, if its number is listed in failAtCalls,
// records a network error on the request context and answers with
// 502 Bad Gateway. Every other call is answered with 200 OK.
func (h *networkFailingHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	h.callNumber++
	current := h.callNumber

	status := http.StatusOK
	for i := range h.failAtCalls {
		if h.failAtCalls[i] != current {
			continue
		}
		// Record before writing the status, and only once per call.
		h.netErrorRecorder.Record(r.Context())
		status = http.StatusBadGateway
		break
	}

	w.WriteHeader(status)
}

// countingRetryListener is a RetryListener implementation to count the times the Retried fn is called.
type countingRetryListener struct {
timesCalled int
Expand Down