-
Notifications
You must be signed in to change notification settings - Fork 0
/
response_reverse_proxy.go
108 lines (90 loc) · 2.79 KB
/
response_reverse_proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package http
import (
"context"
"errors"
"fmt"
"github.com/topvisor/go-prifma/pkg/prifma"
"github.com/topvisor/go-prifma/pkg/utils"
"net"
"net/http"
"net/http/httptrace"
"net/http/httputil"
)
// ResponseReverseProxy proxies a request to its upstream through an
// httputil.ReverseProxy and records what happened during the exchange.
type ResponseReverseProxy struct {
	// RoundTrippers is consulted (via Get) for the transport used to reach
	// the upstream for a given request result.
	RoundTrippers RoundTrippersMap
	// ResponseCode is the status received from the upstream, or the status
	// of the error response written locally when proxying failed.
	ResponseCode int
	// Error is the value Write returns; nil means the exchange is not
	// considered a proxy failure.
	Error error
	// LAddr and RAddr are the local and remote addresses of the upstream
	// connection, captured when a connection is obtained; nil until then.
	LAddr, RAddr net.Addr
}
// NewResponseReverseProxy creates a ResponseReverseProxy that selects upstream
// transports from roundTrippers. Until a response or a handled error
// overwrites them, the recorded outcome is 500 Internal Server Error with a
// matching error value.
func NewResponseReverseProxy(roundTrippers RoundTrippersMap) *ResponseReverseProxy {
	const initialCode = http.StatusInternalServerError

	p := new(ResponseReverseProxy)
	p.ResponseCode = initialCode
	p.Error = errors.New(http.StatusText(initialCode))
	p.RoundTrippers = roundTrippers
	return p
}
// Write forwards the request carried by result to its upstream and streams
// the upstream reply (or a locally generated error response) into rw.
//
// The upstream connection's local/remote addresses are captured through an
// httptrace hook. Outcome bookkeeping is delegated to SaveResponse
// (ModifyResponse) and ErrorHandler. The returned error is whatever t.Error
// holds once proxying has finished.
func (t *ResponseReverseProxy) Write(rw http.ResponseWriter, result prifma.HandleRequestResult) error {
	proxy := &httputil.ReverseProxy{
		Director:  utils.RemoveProxyHeaders,
		Transport: t.RoundTrippers.Get(result),
		// A negative interval makes ReverseProxy flush after every write,
		// so the upstream body is streamed without buffering delays.
		FlushInterval:  -1,
		ModifyResponse: t.SaveResponse,
		ErrorHandler:   t.ErrorHandler,
	}

	inbound := result.GetRequest()
	trace := &httptrace.ClientTrace{
		GotConn: func(info httptrace.GotConnInfo) {
			t.LAddr, t.RAddr = info.Conn.LocalAddr(), info.Conn.RemoteAddr()
		},
	}
	outbound := inbound.WithContext(httptrace.WithClientTrace(inbound.Context(), trace))
	// ReverseProxy derives an X-Forwarded-For value from RemoteAddr; an
	// empty address keeps the client IP from being added (presumably the
	// intent here — confirm against utils.RemoveProxyHeaders).
	outbound.RemoteAddr = ""

	proxy.ServeHTTP(rw, outbound)
	return t.Error
}
// GetCode reports the status code recorded for the proxied request.
func (t *ResponseReverseProxy) GetCode() int { return t.ResponseCode }
// GetLAddr reports the local address of the upstream connection, or nil if
// none was recorded.
func (t *ResponseReverseProxy) GetLAddr() net.Addr { return t.LAddr }
// GetRAddr reports the remote (upstream) address of the connection used, or
// nil if none was recorded.
func (t *ResponseReverseProxy) GetRAddr() net.Addr { return t.RAddr }
// SaveResponse serves as the ReverseProxy ModifyResponse hook. It stores the
// upstream status code and clears Error, because an upstream response means
// the proxying itself succeeded. It never rejects a response.
func (t *ResponseReverseProxy) SaveResponse(resp *http.Response) error {
	t.ResponseCode, t.Error = resp.StatusCode, nil
	return nil
}
// ErrorHandler serves as the ReverseProxy ErrorHandler hook. It writes a
// plain-text error response tagged with an X-Prifma-Error header, stores the
// status it sent in ResponseCode, and sets Error to the failure Write should
// report (nil for outcomes that are not treated as proxy failures):
//
//   - context.DeadlineExceeded: 504 Gateway Timeout; Error set.
//   - context.Canceled: prifma.StatusClientClosedRequest; Error nil.
//   - *net.OpError with Op "dial": 502 Bad Gateway; Error nil.
//   - any other *net.OpError: 500 Internal Server Error; Error set.
//   - any other error: 502 Bad Gateway; Error set.
//
// Errors are matched with errors.Is / errors.As, so wrapped errors are
// classified by their underlying cause.
func (t *ResponseReverseProxy) ErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
	// fail writes the error response and records the status code sent.
	fail := func(code int, header, body string) {
		rw.Header().Add("X-Prifma-Error", header)
		http.Error(rw, body, code)
		t.ResponseCode = code
	}

	t.Error = nil

	var opErr *net.OpError
	switch {
	case errors.Is(err, context.DeadlineExceeded):
		text := http.StatusText(http.StatusGatewayTimeout)
		fail(http.StatusGatewayTimeout, text, text)
		t.Error = fmt.Errorf("%d, %s", http.StatusGatewayTimeout, text)
	case errors.Is(err, context.Canceled):
		// The client went away; this is not reported as a proxy failure.
		fail(prifma.StatusClientClosedRequest, prifma.StatusTextClientClosedRequest, prifma.StatusTextClientClosedRequest)
	case errors.As(err, &opErr):
		msg := err.Error()
		if opErr.Op == "dial" {
			// Upstream unreachable: surfaced via the status code only.
			fail(http.StatusBadGateway, msg, msg)
			return
		}
		fail(http.StatusInternalServerError, msg, msg)
		t.Error = fmt.Errorf("%d, %s", http.StatusInternalServerError, msg)
	default:
		// Every error must still produce a response. Without this case the
		// client would receive an empty 200 from net/http while Write
		// returned nil.
		msg := err.Error()
		fail(http.StatusBadGateway, msg, msg)
		t.Error = fmt.Errorf("%d, %s", http.StatusBadGateway, msg)
	}
}