-
Notifications
You must be signed in to change notification settings - Fork 163
/
timeout.go
54 lines (45 loc) · 1.16 KB
/
timeout.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
// Copyright (C) 2017 ScyllaDB
package scyllaclient
import (
"context"
"io"
"net/http"
"time"
"github.com/pkg/errors"
"github.com/scylladb/scylla-operator/pkg/util/httpx"
)
// ErrTimeout is the sentinel error reported when a request exceeds its
// timeout.
var ErrTimeout = errors.New("timeout")
// body wraps a response body together with the cancel function of the
// context the request was made with, so that the context is cancelled only
// once the body has been closed.
type body struct {
	io.ReadCloser
	cancel context.CancelFunc
}

// Close closes the wrapped ReadCloser and then invokes cancel. It returns
// the error reported by the wrapped Close.
func (b body) Close() (err error) {
	// Deferred so that cancel runs after the wrapped Close, including when
	// that call panics.
	defer b.cancel()

	err = b.ReadCloser.Close()
	return err
}
// timeout returns a RoundTripper that runs every request passing through next
// under a context with a timeout.
//
// The duration applied is the one returned by hasCustomTimeout for the request
// context when its second result is true, and the timeout argument otherwise.
//
// When next returns a response, the context has to outlive RoundTrip because
// the response body is still to be read; the cancel function is therefore
// handed over to the returned body and is invoked when that body is closed.
// When next returns no response, there is no body to take over the cancel
// function, so the context is cancelled before returning.
//
// If the returned error's cause is context.DeadlineExceeded and the context
// created here has hit its deadline, the error is replaced by ErrTimeout
// wrapped with the duration that was applied.
func timeout(next http.RoundTripper, timeout time.Duration) http.RoundTripper {
	return httpx.RoundTripperFunc(func(req *http.Request) (resp *http.Response, err error) {
		d, ok := hasCustomTimeout(req.Context())
		if !ok {
			d = timeout
		}

		ctx, cancel := context.WithTimeout(req.Context(), d)
		defer func() {
			// Check the deadline before any cancel call below so that
			// ctx.Err() reflects the deadline and not our own cancellation.
			if errors.Cause(err) == context.DeadlineExceeded && ctx.Err() == context.DeadlineExceeded {
				err = errors.Wrapf(ErrTimeout, "after %s", d)
			}

			if resp == nil {
				// Nobody will ever close a body, so release the context
				// now instead of keeping it alive until the deadline fires.
				cancel()
				return
			}

			// Cancellation is deferred until the caller closes the body.
			resp.Body = body{
				ReadCloser: resp.Body,
				cancel:     cancel,
			}
		}()

		return next.RoundTrip(req.WithContext(ctx))
	})
}