Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: new timeout writer implementation #1584

Merged
merged 4 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
r.UseBypass(recoverer)

if globalConfig.API.MaxRequestDuration > 0 {
r.UseBypass(api.timeoutMiddleware(globalConfig.API.MaxRequestDuration))
r.UseBypass(timeoutMiddleware(globalConfig.API.MaxRequestDuration))
}

// request tracing should be added only when tracing or metrics is enabled
Expand Down
138 changes: 83 additions & 55 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/security"
Expand Down Expand Up @@ -263,95 +264,122 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H
}
}

// timeoutResponseWriter is a http.ResponseWriter that prevents subsequent
// writes after the context contained in it has exceeded the deadline. If a
// partial write occurs before the deadline is exceeded, but the writing is not
// complete it will allow further writes.
// timeoutResponseWriter is a http.ResponseWriter that queues up a response
// body to be sent if the serving completes before the context has exceeded its
// deadline.
type timeoutResponseWriter struct {
ctx context.Context
w http.ResponseWriter
wrote int32
mu sync.Mutex
sync.Mutex

header http.Header
wroteHeader bool
snapHeader http.Header // snapshot of the header at the time WriteHeader was called
statusCode int
buf bytes.Buffer
}

func (t *timeoutResponseWriter) Header() http.Header {
t.mu.Lock()
defer t.mu.Unlock()
return t.w.Header()
t.Lock()
defer t.Unlock()

return t.header
}

func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.ctx.Err() == context.DeadlineExceeded {
if atomic.LoadInt32(&t.wrote) == 0 {
return 0, context.DeadlineExceeded
}
t.Lock()
defer t.Unlock()

// writing started before the deadline exceeded, but the
// deadline came in the middle, so letting the writes go
// through
if !t.wroteHeader {
t.WriteHeader(http.StatusOK)
}

t.wrote = 1

return t.w.Write(bytes)
return t.buf.Write(bytes)
}

func (t *timeoutResponseWriter) WriteHeader(statusCode int) {
t.mu.Lock()
defer t.mu.Unlock()
if t.ctx.Err() == context.DeadlineExceeded {
if atomic.LoadInt32(&t.wrote) == 0 {
return
}
t.Lock()
defer t.Unlock()

if t.wroteHeader {
// ignore multiple calls to WriteHeader
// once WriteHeader has been called once, a snapshot of the header map is taken
// and saved in snapHeader to be used in finallyWrite
return
}
t.statusCode = statusCode
t.wroteHeader = true
t.snapHeader = t.header.Clone()
}

func (t *timeoutResponseWriter) finallyWrite(w http.ResponseWriter) {
t.Lock()
defer t.Unlock()

// writing started before the deadline exceeded, but the
// deadline came in the middle, so letting the writes go
// through
dst := w.Header()
for k, vv := range t.snapHeader {
dst[k] = vv
}

t.wrote = 1
if !t.wroteHeader {
t.statusCode = http.StatusOK
}

t.w.WriteHeader(statusCode)
w.WriteHeader(t.statusCode)
if _, err := w.Write(t.buf.Bytes()); err != nil {
logrus.WithError(err).Warn("Write failed")
}
}

func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
J0 marked this conversation as resolved.
Show resolved Hide resolved
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()

timeoutWriter := &timeoutResponseWriter{
w: w,
ctx: ctx,
header: make(http.Header),
}

panicChan := make(chan any, 1)
serverDone := make(chan struct{})
go func() {
<-ctx.Done()
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()

next.ServeHTTP(timeoutWriter, r.WithContext(ctx))
close(serverDone)
}()

select {
case p := <-panicChan:
panic(p)

case <-serverDone:
timeoutWriter.finallyWrite(w)

case <-ctx.Done():
err := ctx.Err()

if err == context.DeadlineExceeded {
timeoutWriter.mu.Lock()
defer timeoutWriter.mu.Unlock()
if timeoutWriter.wrote == 0 {
// writer wasn't written to, so we're sending the error payload

httpError := &HTTPError{
HTTPStatus: http.StatusGatewayTimeout,
ErrorCode: ErrorCodeRequestTimeout,
Message: "Processing this request timed out, please retry after a moment.",
}
httpError := &HTTPError{
HTTPStatus: http.StatusGatewayTimeout,
ErrorCode: ErrorCodeRequestTimeout,
Message: "Processing this request timed out, please retry after a moment.",
}

httpError = httpError.WithInternalError(err)
httpError = httpError.WithInternalError(err)

HandleResponseError(httpError, w, r)
}
}
}()
HandleResponseError(httpError, w, r)
} else {
// unrecognized context error, so we should wait for the server to finish
// and write out the response
<-serverDone

next.ServeHTTP(timeoutWriter, r.WithContext(ctx))
timeoutWriter.finallyWrite(w)
}
}
})
}
}
23 changes: 22 additions & 1 deletion internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
w := httptest.NewRecorder()

timeoutHandler := ts.API.timeoutMiddleware(ts.Config.API.MaxRequestDuration)
timeoutHandler := timeoutMiddleware(ts.Config.API.MaxRequestDuration)

slowHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Sleep for 1 second to simulate a slow handler which should trigger the timeout
Expand All @@ -335,3 +335,24 @@ func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() {
require.Equal(ts.T(), float64(504), data["code"])
require.NotNil(ts.T(), data["msg"])
}

func TestTimeoutResponseWriter(t *testing.T) {
// timeoutResponseWriter should exhitbit a similar behavior as http.ResponseWriter
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
w1 := httptest.NewRecorder()
w2 := httptest.NewRecorder()

timeoutHandler := timeoutMiddleware(time.Second * 10)

redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// tries to redirect twice
http.Redirect(w, r, "http://localhost:3001/#message=first_message", http.StatusSeeOther)

// overwrites the first
http.Redirect(w, r, "http://localhost:3001/second", http.StatusSeeOther)
})
timeoutHandler(redirectHandler).ServeHTTP(w1, req)
redirectHandler.ServeHTTP(w2, req)

require.Equal(t, w1.Result(), w2.Result())
}
21 changes: 12 additions & 9 deletions internal/api/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
err error
token *AccessTokenResponse
authCode string
rurl string
)

grantParams.FillGrantParams(r)
Expand All @@ -138,6 +139,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
return err
}
}

err = db.Transaction(func(tx *storage.Connection) error {
var terr error
user, terr = a.verifyTokenHash(tx, params)
Expand All @@ -152,12 +154,11 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
case mail.EmailChangeVerification:
user, terr = a.emailChangeVerify(r, tx, params, user)
if user == nil && terr == nil {
// when double confirmation is required
rurl, err := a.prepRedirectURL(singleConfirmationAccepted, params.RedirectTo, flowType)
if err != nil {
return err
// only one OTP is confirmed at this point, so we return early and ask the user to confirm the second OTP
rurl, terr = a.prepRedirectURL(singleConfirmationAccepted, params.RedirectTo, flowType)
if terr != nil {
return terr
kangmingtay marked this conversation as resolved.
Show resolved Hide resolved
}
http.Redirect(w, r, rurl, http.StatusSeeOther)
return nil
}
default:
Expand Down Expand Up @@ -198,15 +199,17 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
if err != nil {
var herr *HTTPError
if errors.As(err, &herr) {
rurl, err := a.prepErrorRedirectURL(herr, r, params.RedirectTo, flowType)
rurl, err = a.prepErrorRedirectURL(herr, r, params.RedirectTo, flowType)
if err != nil {
return err
}
http.Redirect(w, r, rurl, http.StatusSeeOther)
return nil
}
}
rurl := params.RedirectTo
if rurl != "" {
http.Redirect(w, r, rurl, http.StatusSeeOther)
return nil
}
rurl = params.RedirectTo
if isImplicitFlow(flowType) && token != nil {
q := url.Values{}
q.Set("type", params.Type)
Expand Down
Loading
Loading