Skip to content

Commit

Permalink
feat: add timeout middleware (#1529)
Browse files Browse the repository at this point in the history
A new middleware is introduced that enforces a strict timeout by using
`context.WithTimeout()`. When the timeout is reached, a 504 JSON error
with the `request_timeout` error code is sent. Anything that depends on
the context is cancelled.

---------

Co-authored-by: Kang Ming <kang.ming1996@gmail.com>
  • Loading branch information
J0 and kangmingtay committed Apr 25, 2024
1 parent bd8b5c4 commit f96ff31
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 13 deletions.
5 changes: 5 additions & 0 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
logger := observability.NewStructuredLogger(logrus.StandardLogger(), globalConfig)

r := newRouter()

if globalConfig.API.MaxRequestDuration > 0 {
r.UseBypass(api.timeoutMiddleware(globalConfig.API.MaxRequestDuration))
}

r.Use(addRequestID(globalConfig))

// request tracing should be added only when tracing or metrics is enabled
Expand Down
1 change: 1 addition & 0 deletions internal/api/errorcodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,5 @@ const (
ErrorCodeHookTimeoutAfterRetry ErrorCode = "hook_timeout_after_retry"
ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit"
ErrorCodeHookPayloadUnknownSize ErrorCode = "hook_payload_unknown_size"
ErrorCodeRequestTimeout ErrorCode = "request_timeout"
)
16 changes: 8 additions & 8 deletions internal/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
output.Message = e.Message
output.Payload.Reasons = e.Reasons

if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil {
if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil && jsonErr != context.DeadlineExceeded {
HandleResponseError(jsonErr, w, r)
}

Expand All @@ -224,7 +224,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
output.Message = e.Message
output.Payload.Reasons = e.Reasons

if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil {
if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil && jsonErr != context.DeadlineExceeded {
HandleResponseError(jsonErr, w, r)
}
}
Expand Down Expand Up @@ -252,7 +252,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
}
}

if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil {
if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded {
HandleResponseError(jsonErr, w, r)
}
} else {
Expand All @@ -266,20 +266,20 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {

// Provide better error messages for certain user-triggered Postgres errors.
if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil {
if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil {
if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil && jsonErr != context.DeadlineExceeded {
HandleResponseError(jsonErr, w, r)
}
return
}

if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil {
if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil && jsonErr != context.DeadlineExceeded {
HandleResponseError(jsonErr, w, r)
}
}

case *OAuthError:
log.WithError(e.Cause()).Info(e.Error())
if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil {
if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil && jsonErr != context.DeadlineExceeded {
HandleResponseError(jsonErr, w, r)
}

Expand All @@ -295,7 +295,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
Message: "Unexpected failure, please check server logs for more information",
}

if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil {
if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded {
HandleResponseError(jsonErr, w, r)
}
} else {
Expand All @@ -305,7 +305,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) {
Message: "Unexpected failure, please check server logs for more information",
}

if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil {
if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil && jsonErr != context.DeadlineExceeded {
HandleResponseError(jsonErr, w, r)
}
}
Expand Down
95 changes: 95 additions & 0 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/supabase/auth/internal/models"
Expand Down Expand Up @@ -260,3 +262,96 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H
})
}
}

// timeoutResponseWriter is a http.ResponseWriter that prevents subsequent
// writes after the context contained in it has exceeded the deadline. If a
// partial write occurs before the deadline is exceeded, but the writing is not
// complete it will allow further writes.
type timeoutResponseWriter struct {
ctx context.Context
w http.ResponseWriter
wrote int32
mu sync.Mutex
}

func (t *timeoutResponseWriter) Header() http.Header {
t.mu.Lock()
defer t.mu.Unlock()
return t.w.Header()
}

func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.ctx.Err() == context.DeadlineExceeded {
if atomic.LoadInt32(&t.wrote) == 0 {
return 0, context.DeadlineExceeded
}

// writing started before the deadline exceeded, but the
// deadline came in the middle, so letting the writes go
// through
}

t.wrote = 1

return t.w.Write(bytes)
}

func (t *timeoutResponseWriter) WriteHeader(statusCode int) {
t.mu.Lock()
defer t.mu.Unlock()
if t.ctx.Err() == context.DeadlineExceeded {
if atomic.LoadInt32(&t.wrote) == 0 {
return
}

// writing started before the deadline exceeded, but the
// deadline came in the middle, so letting the writes go
// through
}

t.wrote = 1

t.w.WriteHeader(statusCode)
}

func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()

timeoutWriter := &timeoutResponseWriter{
w: w,
ctx: ctx,
}

go func() {
<-ctx.Done()

err := ctx.Err()

if err == context.DeadlineExceeded {
timeoutWriter.mu.Lock()
defer timeoutWriter.mu.Unlock()
if timeoutWriter.wrote == 0 {
// writer wasn't written to, so we're sending the error payload

httpError := &HTTPError{
HTTPStatus: http.StatusGatewayTimeout,
ErrorCode: ErrorCodeRequestTimeout,
Message: "Processing this request timed out, please retry after a moment.",
}

httpError = httpError.WithInternalError(err)

HandleResponseError(httpError, w, r)
}
}
}()

next.ServeHTTP(timeoutWriter, r.WithContext(ctx))
})
}
}
23 changes: 23 additions & 0 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http/httptest"
"net/url"
"testing"
"time"

jwt "github.com/golang-jwt/jwt"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -312,3 +313,25 @@ func TestFunctionHooksUnmarshalJSON(t *testing.T) {
})
}
}

func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() {
ts.Config.API.MaxRequestDuration = 5 * time.Microsecond
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
w := httptest.NewRecorder()

timeoutHandler := ts.API.timeoutMiddleware(ts.Config.API.MaxRequestDuration)

slowHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Sleep for 1 second to simulate a slow handler which should trigger the timeout
time.Sleep(1 * time.Second)
ts.API.handler.ServeHTTP(w, r)
})
timeoutHandler(slowHandler).ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusGatewayTimeout, w.Code)

var data map[string]interface{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
require.Equal(ts.T(), ErrorCodeRequestTimeout, data["error_code"])
require.Equal(ts.T(), float64(504), data["code"])
require.NotNil(ts.T(), data["msg"])
}
11 changes: 6 additions & 5 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,12 @@ type MFAConfiguration struct {
}

type APIConfiguration struct {
Host string
Port string `envconfig:"PORT" default:"8081"`
Endpoint string
RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"`
ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL" required:"true"`
Host string
Port string `envconfig:"PORT" default:"8081"`
Endpoint string
RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"`
ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL" required:"true"`
MaxRequestDuration time.Duration `json:"max_request_duration" split_words:"true" default:"10s"`
}

func (a *APIConfiguration) Validate() error {
Expand Down

0 comments on commit f96ff31

Please sign in to comment.