-
Notifications
You must be signed in to change notification settings - Fork 163
/
hostpool.go
66 lines (52 loc) · 1.39 KB
/
hostpool.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// Copyright (C) 2017 ScyllaDB
package scyllaclient
import (
"net"
"net/http"
"github.com/hailocab/go-hostpool"
"github.com/pkg/errors"
"github.com/scylladb/scylla-operator/pkg/util/httpx"
)
var (
	// errPoolServerError is the error handed to a pooled host's Mark when the
	// round trip itself succeeded but the host answered 401, 403 or any status
	// >= 500. The response therefore counts against that host instead of being
	// reported as a success.
	errPoolServerError = errors.New("server error")
)
// hostPool wraps next so that every request is sent to a concrete host on the
// given port.
//
// The target host is read from the request context under the ctxHost key.
// When no string value is stored there, a host is taken from pool instead.
// Only pool-selected hosts receive feedback. After the round trip, the pool
// response is marked with:
//   - the transport error, if the round trip failed;
//   - errPoolServerError, for 401, 403 and 5xx replies;
//   - nil, otherwise.
//
// The request actually sent is a clone addressed to host:port. The caller's
// request is also rewritten (host only, without the port) so that error
// messages mention the host that was used.
func hostPool(next http.RoundTripper, pool hostpool.HostPool, port string) http.RoundTripper {
	return httpx.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
		// A host pinned in the context takes precedence over the pool.
		host, pinned := req.Context().Value(ctxHost).(string)

		var picked hostpool.HostPoolResponse
		if !pinned {
			picked = pool.Get()
			host = picked.Host()
		}

		// Address the clone; it is the request that is handed to next.
		out := httpx.CloneRequest(req)
		addr := net.JoinHostPort(host, port)
		out.URL.Host = addr
		out.Host = addr

		// RoundTrip shall not modify requests. This one does so on purpose, so
		// that error messages name the chosen host; see
		// https://github.com/scylladb/mermaid/issues/266. It is acceptable
		// because we own the whole process and req itself is not sent.
		req.URL.Host = host
		req.Host = host

		resp, err := next.RoundTrip(out)

		if picked != nil {
			// A nil outcome reports the host as healthy.
			var outcome error
			switch {
			case err != nil:
				outcome = err
			case resp.StatusCode == http.StatusUnauthorized,
				resp.StatusCode == http.StatusForbidden,
				resp.StatusCode >= http.StatusInternalServerError:
				outcome = errPoolServerError
			}
			picked.Mark(outcome)
		}

		return resp, err
	})
}