forked from Versent/saml2aws
-
Notifications
You must be signed in to change notification settings - Fork 0
/
http.go
144 lines (114 loc) · 3.33 KB
/
http.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package provider
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"runtime"
"time"
"github.com/sirupsen/logrus"
"github.com/versent/saml2aws/pkg/cookiejar"
"github.com/versent/saml2aws/pkg/dump"
"github.com/briandowns/spinner"
"github.com/pkg/errors"
"golang.org/x/net/publicsuffix"
)
// HTTPClient is the saml2aws HTTP client. It embeds http.Client, so the
// standard client's fields and methods (Transport, Jar, CheckRedirect, ...)
// are promoted and usable directly, and it adds an optional response
// validator.
type HTTPClient struct {
	http.Client
	// CheckResponseStatus, when non-nil, is called by Do after a request
	// completes without a transport error. A non-nil return value is handed
	// back to the caller of Do together with the response.
	CheckResponseStatus func(*http.Request, *http.Response) error
}
// NewDefaultTransport returns an *http.Transport configured with fixed
// dial, keep-alive, idle-connection and TLS-handshake timeouts, a pool of up
// to 100 idle connections, proxy settings read from the environment
// (HTTP_PROXY, HTTPS_PROXY, NO_PROXY via http.ProxyFromEnvironment), and a TLS
// config whose InsecureSkipVerify is set to skipVerify.
//
// Passing skipVerify = true disables server certificate and hostname
// verification; only do so against endpoints you already trust.
func NewDefaultTransport(skipVerify bool) *http.Transport {
	// net.Dialer.DualStack is deprecated: RFC 6555 "Happy Eyeballs" fast
	// fallback is enabled by default since Go 1.12, so it is not set here.
	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	return &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		DialContext:           dialer.DialContext,
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		TLSClientConfig:       &tls.Config{InsecureSkipVerify: skipVerify},
	}
}
// NewHTTPClient builds the HTTPClient used by the providers. Requests are
// sent through tr, and cookies are kept in a cookie jar configured with the
// public suffix list. No response status check is installed; callers may set
// CheckResponseStatus afterwards. Any error from creating the jar is returned
// unchanged.
func NewHTTPClient(tr http.RoundTripper) (*HTTPClient, error) {
	jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
	if err != nil {
		return nil, err
	}

	return &HTTPClient{
		Client: http.Client{
			Transport: tr,
			Jar:       jar,
		},
	}, nil
}
// Do sends req with the embedded http.Client after setting a saml2aws
// User-Agent header. Unless logrus is at debug level, a console spinner is
// shown while the request is in flight.
//
// The request is logged before sending. On a successful round trip the
// response is logged and then, if CheckResponseStatus is configured, passed
// to it. A validation error is returned together with the response. The
// response is logged before validation so that rejected responses still
// appear in debug or dump output.
//
// As with http.Client.Do, the caller owns resp.Body and must close it
// whenever resp is non-nil, including when a validation error is returned.
func (hc *HTTPClient) Do(req *http.Request) (*http.Response, error) {
	cs := spinner.CharSets[14]
	// use a NON unicode spinner for windows
	if runtime.GOOS == "windows" {
		cs = spinner.CharSets[26]
	}

	// The spinner is suppressed at debug level, presumably so it does not
	// interleave with the debug log lines.
	if logrus.GetLevel() != logrus.DebugLevel {
		s := spinner.New(cs, 100*time.Millisecond)
		defer s.Stop()
		s.Start()
	}

	// A hand-built request may have a nil Header; Set on a nil map panics.
	if req.Header == nil {
		req.Header = make(http.Header)
	}
	req.Header.Set("User-Agent", fmt.Sprintf("saml2aws/1.0 (%s %s) Versent", runtime.GOOS, runtime.GOARCH))

	hc.logHTTPRequest(req)

	resp, err := hc.Client.Do(req)
	if err != nil {
		// http.Client.Do may return a non-nil response alongside an error
		// (e.g. a CheckRedirect failure); pass both through untouched.
		return resp, err
	}

	// Log before validating so a rejected response is still visible.
	hc.logHTTPResponse(resp)

	if hc.CheckResponseStatus != nil {
		if err := hc.CheckResponseStatus(req, resp); err != nil {
			return resp, err
		}
	}

	return resp, nil
}
// DisableFollowRedirect turns off automatic redirect following. When a
// redirect is received, Do returns that redirect response itself, with its
// body left unclosed, instead of requesting the new location.
func (hc *HTTPClient) DisableFollowRedirect() {
	stopAtFirstResponse := func(*http.Request, []*http.Request) error {
		return http.ErrUseLastResponse
	}
	hc.Client.CheckRedirect = stopAtFirstResponse
}
// EnableFollowRedirect restores net/http's default redirect policy by
// clearing CheckRedirect. With the default policy, the client follows
// redirects and stops after 10 consecutive requests.
func (hc *HTTPClient) EnableFollowRedirect() {
	hc.Client.CheckRedirect = nil
}
// SuccessOrRedirectResponseValidator accepts any response whose status code
// lies in [200, 400), i.e. a 2xx success or 3xx redirect. For any other code
// it returns an error naming the request URL and the response status. Its
// signature matches HTTPClient.CheckResponseStatus.
func SuccessOrRedirectResponseValidator(req *http.Request, resp *http.Response) error {
	code := resp.StatusCode
	if code < 200 || code >= 400 {
		return errors.Errorf("request for url: %s failed status: %s", req.URL.String(), resp.Status)
	}
	return nil
}
// logHTTPRequest records an outgoing request. When content dumping is
// enabled, the output of dump.RequestString is printed to stdout. Otherwise
// the URL and method are logged at debug level.
func (hc *HTTPClient) logHTTPRequest(req *http.Request) {
	if !dump.ContentEnable() {
		fields := logrus.Fields{
			"URL":    req.URL.String(),
			"method": req.Method,
		}
		logrus.WithField("http", "client").WithFields(fields).Debug("HTTP Req")
		return
	}
	fmt.Println(dump.RequestString(req))
}
// logHTTPResponse records a received response. When content dumping is
// enabled, the output of dump.ResponseString is printed to stdout. Otherwise
// only the status line is logged at debug level.
func (hc *HTTPClient) logHTTPResponse(resp *http.Response) {
	if !dump.ContentEnable() {
		logrus.WithField("http", "client").
			WithFields(logrus.Fields{"Status": resp.Status}).
			Debug("HTTP Res")
		return
	}
	fmt.Println(dump.ResponseString(resp))
}