-
Notifications
You must be signed in to change notification settings - Fork 3
/
custom_transport.go
149 lines (126 loc) · 3.75 KB
/
custom_transport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package datadog
// Credit to https://github.com/turbot/steampipe-plugin-datadog/blob/add-tables/internal/transport/custom_transport.go
import (
"bytes"
"context"
"io"
"math"
"net/http"
"strconv"
"time"
)
// Retry defaults used by CustomTransport.
var (
	// rateLimitResetHeader is the response header read on HTTP 429 responses;
	// its value is parsed as a whole number of seconds to wait before retrying.
	rateLimitResetHeader = "X-Ratelimit-Reset"

	// defaultHTTPRetryDuration is the base delay between retries. It is
	// multiplied by the retry count when the response carries no usable wait
	// hint.
	defaultHTTPRetryDuration = 5 * time.Second

	// defaultHTTPRetryTimeout bounds the retry loop when the request context
	// has no deadline and no Timeout option was supplied.
	defaultHTTPRetryTimeout = time.Minute
)
// CustomTransport wraps another http.RoundTripper and adds retry handling for
// rate-limited (HTTP 429) and server-error (HTTP 5xx) responses.
type CustomTransport struct {
	// defaultTransport performs the actual HTTP exchange for every attempt.
	defaultTransport http.RoundTripper
	// httpRetryDuration is the base delay, scaled by the retry count, used
	// when the response does not say how long to wait.
	httpRetryDuration time.Duration
	// httpRetryTimeout bounds the retry loop for requests whose context has
	// no deadline of its own.
	httpRetryTimeout time.Duration
}
// CustomTransportOptions carries the optional settings accepted by
// NewCustomTransport.
type CustomTransportOptions struct {
	// Timeout, when non-nil, replaces the default retry timeout.
	Timeout *time.Duration
}
// roundTripBodyErr is an io.Reader that always fails with err. RoundTrip
// appends it to a buffered response body so that a read failure seen while
// buffering is reported to the caller instead of being silently dropped.
type roundTripBodyErr struct{ err error }

// Read implements io.Reader by returning the stored error.
func (r roundTripBodyErr) Read([]byte) (int, error) { return 0, r.err }

// RoundTrip implements http.RoundTripper. It sends a copy of req through the
// wrapped transport and repeats the attempt for as long as retryRequest
// reports the response as retryable.
//
// The wait between attempts is the duration returned by retryRequest, or,
// when that is nil, the retry count multiplied by t.httpRetryDuration (so the
// first such retry happens immediately). Waiting stops when the request
// context is done; if that context has no deadline, a timeout of
// t.httpRetryTimeout is applied to the waits. When waiting is cut short the
// most recent response is returned with a nil error.
//
// Every response body is read fully and replaced with an in-memory copy so
// the underlying connection can be reused. If reading the body fails, the
// bytes received so far are kept and the read error is replayed to the caller
// after them.
//
// A transport error from the wrapped RoundTripper is returned as-is and is
// never retried.
func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	ctx := req.Context()
	if _, set := ctx.Deadline(); !set {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, t.httpRetryTimeout)
		defer cancel()
	}

	for retryCount := 0; ; retryCount++ {
		resp, respErr := t.defaultTransport.RoundTrip(t.copyRequest(req))

		// Drain and close the body so the connection can be re-used, then
		// hand the caller a replayable in-memory copy.
		if resp != nil && resp.Body != nil {
			data, readErr := io.ReadAll(resp.Body)
			resp.Body.Close()

			var body io.Reader = bytes.NewReader(data)
			if readErr != nil {
				// Preserve the failure: the caller sees the partial data
				// followed by the original read error.
				body = io.MultiReader(body, roundTripBodyErr{err: readErr})
			}
			resp.Body = io.NopCloser(body)
		}
		if respErr != nil {
			return resp, respErr
		}

		// Check if the request should be retried and how long to wait.
		retryDuration, retry := t.retryRequest(resp)
		if !retry {
			return resp, nil
		}

		// No explicit wait was provided: scale the base delay by the number
		// of retries performed so far.
		if retryDuration == nil {
			wait := time.Duration(retryCount) * t.httpRetryDuration
			retryDuration = &wait
		}

		// Use an explicit timer so it can be released as soon as the context
		// ends the wait.
		timer := time.NewTimer(*retryDuration)
		select {
		case <-ctx.Done():
			timer.Stop()
			return resp, nil
		case <-timer.C:
		}
	}
}
// copyRequest returns a shallow copy of r for a single send attempt.
//
// When r carries a body and provides GetBody, the copy receives a fresh body
// from GetBody so each attempt can send it from the start. When GetBody is
// missing or fails, the copy keeps r's original Body rather than panicking or
// sending a nil body; in that case only the first attempt can transmit the
// body, because it is consumed by the send.
func (t *CustomTransport) copyRequest(r *http.Request) *http.Request {
	newRequest := *r
	if r.Body == nil || r.Body == http.NoBody {
		return &newRequest
	}

	// GetBody is optional on http.Request; calling it unconditionally would
	// dereference a nil function.
	if r.GetBody == nil {
		return &newRequest
	}
	body, err := r.GetBody()
	if err != nil || body == nil {
		// Best effort: fall back to the original body already in the copy.
		return &newRequest
	}
	newRequest.Body = body
	return &newRequest
}
// retryRequest decides whether response warrants another attempt.
//
// For an HTTP 429 that carries the rate-limit reset header, it asks for a
// retry and returns the header value as a wait in seconds; if the header is
// not an integer the wait is nil, leaving the delay to the caller. Any status
// of 500 or above is retried with a nil wait. Everything else is not retried.
func (t *CustomTransport) retryRequest(response *http.Response) (*time.Duration, bool) {
	reset := response.Header.Get(rateLimitResetHeader)
	if response.StatusCode == 429 && reset != "" {
		seconds, parseErr := strconv.ParseInt(reset, 10, 64)
		if parseErr != nil {
			return nil, true
		}
		wait := time.Duration(seconds) * time.Second
		return &wait, true
	}

	return nil, response.StatusCode >= 500
}
// NewCustomTransport builds a CustomTransport that delegates to t, falling
// back to http.DefaultTransport when t is nil. The retry timeout comes from
// opt.Timeout when set and from the package default otherwise; the base retry
// delay is always the package default.
func NewCustomTransport(t http.RoundTripper, opt CustomTransportOptions) *CustomTransport {
	base := t
	if base == nil {
		base = http.DefaultTransport
	}

	timeout := defaultHTTPRetryTimeout
	if opt.Timeout != nil {
		timeout = *opt.Timeout
	}

	return &CustomTransport{
		defaultTransport:  base,
		httpRetryDuration: defaultHTTPRetryDuration,
		httpRetryTimeout:  timeout,
	}
}
// DefaultBackoff returns how long to wait before retry number attemptNum.
//
// By default the wait grows exponentially: min multiplied by 2 to the power
// of attemptNum, capped at max (the cap is also used if the computation
// overflows time.Duration).
//
// It also tries to parse the Retry-After response header when resp has status
// http.StatusTooManyRequests (429) or http.StatusServiceUnavailable (503). If
// the header holds a non-negative whole number of seconds, that wait is
// returned instead, without applying max. A missing, empty, non-numeric or
// negative header falls back to the exponential wait.
func (t *CustomTransport) DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
	if resp != nil && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable) {
		// Header.Get is safe for a nil header map and for a key that is
		// present with no values, where indexing the slice would panic.
		if v := resp.Header.Get("Retry-After"); v != "" {
			if seconds, err := strconv.ParseInt(v, 10, 64); err == nil && seconds >= 0 {
				return time.Second * time.Duration(seconds)
			}
		}
	}

	mult := math.Pow(2, float64(attemptNum)) * float64(min)
	sleep := time.Duration(mult)
	// A round-trip mismatch means mult did not fit in a time.Duration.
	if float64(sleep) != mult || sleep > max {
		sleep = max
	}
	return sleep
}