forked from polyverse-security/vulcand
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fwd.go
140 lines (121 loc) · 3.29 KB
/
fwd.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// Package forward implements an HTTP handler that forwards requests to a remote
// server and serves the response back to the client.
package forward
import (
"io"
"net/http"
"net/url"
"os"
"strconv"
"time"
"github.com/mailgun/vulcand/Godeps/_workspace/src/github.com/mailgun/oxy/utils"
)
// ReqRewriter is implemented by types that modify the outgoing copy of a
// request (for example its headers or body) before the Forwarder sends it to
// the backend.
type ReqRewriter interface {
	// Rewrite mutates req in place.
	Rewrite(req *http.Request)
}
// optSetter is a functional option applied to a Forwarder by New; returning a
// non-nil error aborts construction.
type optSetter func(f *Forwarder) error

// RoundTripper returns an option that makes the Forwarder send outgoing
// requests through r. If r is nil, New falls back to http.DefaultTransport.
func RoundTripper(r http.RoundTripper) optSetter {
	setTransport := func(fwd *Forwarder) error {
		fwd.roundTripper = r
		return nil
	}
	return setTransport
}
// Rewriter returns an option that installs r as the request rewriter, which
// is applied to every outgoing request copy. If r is nil, New installs its
// default HeaderRewriter instead.
func Rewriter(r ReqRewriter) optSetter {
	setRewriter := func(fwd *Forwarder) error {
		fwd.rewriter = r
		return nil
	}
	return setRewriter
}
// ErrorHandler returns an option that sets the handler used to respond to the
// client when the round trip to the backend fails. If h is nil, New falls
// back to utils.DefaultHandler.
func ErrorHandler(h utils.ErrorHandler) optSetter {
	setErrHandler := func(fwd *Forwarder) error {
		fwd.errHandler = h
		return nil
	}
	return setErrHandler
}
// Logger returns an option that sets the logger receiving the Forwarder's
// per-request and error messages. If l is nil, New falls back to
// utils.NullLogger.
func Logger(l utils.Logger) optSetter {
	setLogger := func(fwd *Forwarder) error {
		fwd.log = l
		return nil
	}
	return setLogger
}
// Forwarder is an HTTP handler (see ServeHTTP) that sends a copy of each
// incoming request through its round tripper and writes the backend's
// response back to the client. Construct it with New so that unset fields
// receive defaults.
type Forwarder struct {
	// errHandler writes the client response when the round trip fails.
	errHandler utils.ErrorHandler
	// roundTripper performs the outgoing request to the backend.
	roundTripper http.RoundTripper
	// rewriter adjusts the outgoing request copy before it is sent; may be
	// nil on a Forwarder not built by New.
	rewriter ReqRewriter
	// log receives round-trip and error log lines.
	log utils.Logger
}
// New constructs a Forwarder. The options are applied in the order given; the
// first one that returns an error aborts construction, and that error is
// returned together with a nil *Forwarder.
//
// Any component still unset after the options run receives a default:
//   - round tripper: http.DefaultTransport
//   - rewriter: a *HeaderRewriter with TrustForwardHeader enabled and Hostname
//     set from os.Hostname(), or "localhost" if that lookup fails
//   - logger: utils.NullLogger
//   - error handler: utils.DefaultHandler
func New(setters ...optSetter) (*Forwarder, error) {
	fwd := new(Forwarder)
	for i := range setters {
		if err := setters[i](fwd); err != nil {
			return nil, err
		}
	}

	if fwd.roundTripper == nil {
		fwd.roundTripper = http.DefaultTransport
	}
	if fwd.rewriter == nil {
		hostname := "localhost"
		if name, err := os.Hostname(); err == nil {
			hostname = name
		}
		fwd.rewriter = &HeaderRewriter{Hostname: hostname, TrustForwardHeader: true}
	}
	if fwd.log == nil {
		fwd.log = utils.NullLogger
	}
	if fwd.errHandler == nil {
		fwd.errHandler = utils.DefaultHandler
	}
	return fwd, nil
}
// ServeHTTP forwards req to the backend and streams the backend's response
// back to w.
//
// A copy of req (built by copyRequest) is sent through the configured round
// tripper. If the round trip fails, the error is logged and passed to the
// error handler, which is responsible for answering the client. On success,
// the backend's headers and status code are copied to w, followed by the body.
func (f *Forwarder) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	// Keep the monotonic clock reading (time.Now().UTC() would strip it) so the
	// measured duration is immune to wall-clock adjustments.
	start := time.Now()
	response, err := f.roundTripper.RoundTrip(f.copyRequest(req, req.URL))
	if err != nil {
		f.log.Errorf("Error forwarding to %v, err: %v", req.URL, err)
		f.errHandler.ServeHTTP(w, req, err)
		return
	}
	// Close the backend body on every exit path so the transport can release
	// the connection.
	defer response.Body.Close()

	// The logged duration covers the round trip only (up to receipt of the
	// backend response), not the body copy below.
	duration := time.Since(start)
	if req.TLS != nil {
		f.log.Infof("Round trip: %v, code: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v",
			req.URL, response.StatusCode, duration,
			req.TLS.Version,
			req.TLS.DidResume,
			req.TLS.CipherSuite,
			req.TLS.ServerName)
	} else {
		f.log.Infof("Round trip: %v, code: %v, duration: %v",
			req.URL, response.StatusCode, duration)
	}

	utils.CopyHeaders(w.Header(), response.Header)
	w.WriteHeader(response.StatusCode)
	written, err := io.Copy(w, response.Body)
	if err != nil {
		// Status and headers are already on the wire, so the client cannot be
		// told about the failure; record it instead of dropping it silently.
		f.log.Errorf("Error copying response body from %v after %d bytes, err: %v", req.URL, written, err)
	}
	if written != 0 {
		// NOTE(review): headers were already sent by WriteHeader, so with a
		// standard ResponseWriter this Set does not reach the client. It is kept
		// for ResponseWriter wrappers that may inspect w.Header() afterwards —
		// confirm whether anything relies on it.
		w.Header().Set(ContentLength, strconv.FormatInt(written, 10))
	}
}
// copyRequest builds the outgoing request for the backend identified by u
// (only u.Scheme and u.Host are used).
//
// The result is a shallow copy of req, except that:
//   - URL is replaced with utils.CopyURL(req.URL), retargeted to u.
//   - Header is a new map populated from req.Header.
//   - The protocol is forced to HTTP/1.1 and Close is cleared, so the backend
//     connection can stay persistent whatever the client requested.
//
// Finally, the configured rewriter, if any, is applied to the copy.
func (f *Forwarder) copyRequest(req *http.Request, u *url.URL) *http.Request {
	outReq := new(http.Request)
	*outReq = *req // includes shallow copies of maps, but we handle this below

	outReq.URL = utils.CopyURL(req.URL)
	outReq.URL.Scheme = u.Scheme
	outReq.URL.Host = u.Host
	// For server-side requests, forward the request target exactly as the
	// client sent it. RequestURI already contains the query string, so RawQuery
	// is cleared to avoid sending it twice. Requests built in code (e.g. via
	// http.NewRequest) have an empty RequestURI; for those, keep the path and
	// query from the copied URL rather than silently dropping the query.
	if req.RequestURI != "" {
		outReq.URL.Opaque = req.RequestURI
		outReq.URL.RawQuery = ""
	}

	outReq.Proto = "HTTP/1.1"
	outReq.ProtoMajor = 1
	outReq.ProtoMinor = 1

	// Overwrite close flag so we can keep persistent connection for the backend servers
	outReq.Close = false

	// Give the copy its own header map so rewriting cannot mutate req.Header.
	outReq.Header = make(http.Header)
	utils.CopyHeaders(outReq.Header, req.Header)

	if f.rewriter != nil {
		f.rewriter.Rewrite(outReq)
	}
	return outReq
}