-
Notifications
You must be signed in to change notification settings - Fork 283
/
client.go
154 lines (139 loc) · 4.19 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package httputil
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
"github.com/pomerium/pomerium/internal/tripper"
)
// ErrTokenRevoked signifies a token revocation or expiration error. Do returns
// it when the server answers 400 Bad Request with a JSON body whose
// error_description is "Token expired or revoked".
var ErrTokenRevoked = errors.New("token expired or revoked")
// loggingRoundTripper is an http.RoundTripper that delegates to another round
// tripper and writes a debug log entry for each request it forwards.
type loggingRoundTripper struct {
// base is the round tripper that actually performs the request.
base http.RoundTripper
// customize holds callbacks that RoundTrip invokes with the log event before
// the event is written, e.g. to attach additional fields.
customize []func(event *zerolog.Event) *zerolog.Event
}
// RoundTrip forwards req to the wrapped round tripper and then writes a debug
// log entry with the request method, authority, path, elapsed time and
// response status code. When the wrapped round tripper yields no response, the
// status code is logged as 500. The response and error of the wrapped round
// tripper are returned unchanged.
func (l loggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	start := time.Now()
	res, err := l.base.RoundTrip(req)

	// A failed round trip carries no response; log it as an internal error.
	statusCode := http.StatusInternalServerError
	if res != nil {
		statusCode = res.StatusCode
	}

	evt := log.Debug(req.Context()).
		Str("method", req.Method).
		Str("authority", req.URL.Host).
		Str("path", req.URL.Path).
		Dur("duration", time.Since(start)).
		Int("response-code", statusCode)
	// Each customizer returns the event to continue with; honor that result
	// instead of assuming the callback only mutates the event in place.
	for _, f := range l.customize {
		evt = f(evt)
	}
	evt.Msg("outbound http-request")

	return res, err
}
// NewLoggingRoundTripper wraps base in an http.RoundTripper that logs every
// request it forwards. A nil base is replaced by http.DefaultTransport. Each
// customize callback is handed the log event before it is written.
func NewLoggingRoundTripper(base http.RoundTripper, customize ...func(event *zerolog.Event) *zerolog.Event) http.RoundTripper {
	rt := loggingRoundTripper{base: http.DefaultTransport, customize: customize}
	if base != nil {
		rt.base = base
	}
	return rt
}
// NewLoggingClient returns a copy of base whose transport is a chain of the
// metrics round tripper (created with name) and a logging round tripper
// wrapped around base's original transport. A nil base means
// http.DefaultClient; the client passed in is not modified.
func NewLoggingClient(base *http.Client, name string, customize ...func(event *zerolog.Event) *zerolog.Event) *http.Client {
	src := base
	if src == nil {
		src = http.DefaultClient
	}

	// Shallow-copy so the transport swap below stays local to the new client.
	clone := *src
	blank := func() string { return "" }
	logging := NewLoggingRoundTripper(clone.Transport, customize...)
	clone.Transport = tripper.NewChain(metrics.HTTPMetricsRoundTripper(blank, name)).Then(logging)
	return &clone
}
// httpClient pairs an embedded *http.Client with the round tripper that its Do
// method chains behind a metrics round tripper and installs as the client's
// transport.
type httpClient struct {
*http.Client
// requestIDTripper is the round tripper Do passes to the chain's Then call;
// getDefaultClient sets it to requestid.NewRoundTripper(http.DefaultTransport).
requestIDTripper http.RoundTripper
}
// Do rebuilds the embedded client's transport — a chain made of the metrics
// round tripper created with the name "idp_http_client", finished with
// requestIDTripper — and then sends req through the embedded client. The
// transport is reassigned on every call.
func (c *httpClient) Do(req *http.Request) (*http.Response, error) {
	blank := func() string { return "" }
	metricsTripper := metrics.HTTPMetricsRoundTripper(blank, "idp_http_client")
	c.Client.Transport = tripper.NewChain(metricsTripper).Then(c.requestIDTripper)
	return c.Client.Do(req)
}
// getDefaultClient builds a fresh httpClient whose requests time out after one
// minute, so a stalled request cannot leak by hanging indefinitely. Its round
// tripper is requestid.NewRoundTripper applied to http.DefaultTransport.
func getDefaultClient() *httpClient {
	client := &http.Client{Timeout: time.Minute}
	return &httpClient{
		Client:           client,
		requestIDTripper: requestid.NewRoundTripper(http.DefaultTransport),
	}
}
// Do provides a simple helper interface to make HTTP requests.
//
// Only GET and POST are supported; any other method returns a "Bad Request"
// error without sending anything. For POST, params are form-encoded into the
// request body. For GET, non-nil params replace the query string of endpoint.
// The request always carries a form-urlencoded Content-Type and the given
// User-Agent; entries in headers are applied afterwards and may override both.
//
// A 200 response is JSON-decoded into response when response is non-nil. Any
// other status yields an error whose message is the HTTP status text, except
// that a 400 whose JSON body has error_description "Token expired or revoked"
// yields ErrTokenRevoked. Errors from building the request, performing it,
// reading the body, or decoding JSON are returned as-is.
func Do(ctx context.Context, method, endpoint, userAgent string, headers map[string]string, params url.Values, response interface{}) error {
	var body io.Reader
	switch method {
	case http.MethodPost:
		body = bytes.NewBufferString(params.Encode())
	case http.MethodGet:
		if params != nil {
			// Parse to obtain a copy of the URL whose query can be replaced.
			// The error must be checked: url.Parse returns a nil URL on
			// failure, and writing to it would panic.
			u, err := url.Parse(endpoint)
			if err != nil {
				return err
			}
			u.RawQuery = params.Encode()
			endpoint = u.String()
		}
	default:
		return fmt.Errorf("%s", http.StatusText(http.StatusBadRequest))
	}

	req, err := http.NewRequestWithContext(ctx, method, endpoint, body)
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("User-Agent", userAgent)
	for k, v := range headers {
		req.Header.Set(k, v)
	}

	resp, err := getDefaultClient().Do(req)
	if err != nil {
		return err
	}
	// Register the close before reading so the body is released on every
	// return path below.
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}

	if resp.StatusCode != http.StatusOK {
		if resp.StatusCode == http.StatusBadRequest {
			// A 400 may carry an OAuth-style error document; a revoked or
			// expired token is reported through a dedicated sentinel.
			var errResp struct {
				Error            string `json:"error"`
				ErrorDescription string `json:"error_description"`
			}
			if json.Unmarshal(respBody, &errResp) == nil && errResp.ErrorDescription == "Token expired or revoked" {
				return ErrTokenRevoked
			}
		}
		// "%s" keeps the message identical while avoiding a non-constant
		// format string.
		return fmt.Errorf("%s", http.StatusText(resp.StatusCode))
	}

	if response != nil {
		// encoding/json decodes through the interface into the pointer it
		// holds, so response has to be a non-nil pointer to be populated.
		return json.Unmarshal(respBody, &response)
	}
	return nil
}