forked from cloudflare/cloudflared
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrequest_handler.go
152 lines (131 loc) · 4.22 KB
/
request_handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package socks
import (
"fmt"
"io"
"net"
"strings"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/ipaccess"
)
// RequestHandler is implemented by anything that can service a parsed SOCKS5
// request. Handle receives the request together with the client connection
// and returns an error if the request could not be served.
type RequestHandler interface {
	Handle(req *Request, conn io.ReadWriter) error
}
// StandardRequestHandler is the default RequestHandler: it carries the dialer
// used to reach destinations and an optional IP access policy.
type StandardRequestHandler struct {
	// dialer opens the outbound connection for connect commands.
	dialer Dialer
	// accessPolicy, when non-nil, is consulted before dialing a destination.
	accessPolicy *ipaccess.Policy
}
// NewRequestHandler builds a StandardRequestHandler that uses dialer to reach
// destinations and accessPolicy to vet them. A nil accessPolicy disables the
// access check performed for connect commands.
func NewRequestHandler(dialer Dialer, accessPolicy *ipaccess.Policy) RequestHandler {
	handler := new(StandardRequestHandler)
	handler.dialer = dialer
	handler.accessPolicy = accessPolicy
	return handler
}
// Handle dispatches a SOCKS5 request to the handler for its command.
//
// Connect, bind and associate commands are forwarded to their dedicated
// handlers. Any other command is answered with a commandNotSupported reply
// and reported as an error; if that reply cannot be written, the write error
// is returned instead (wrapped, so callers can inspect it with errors.Is/As).
func (h *StandardRequestHandler) Handle(req *Request, conn io.ReadWriter) error {
	switch req.Command {
	case connectCommand:
		return h.handleConnect(conn, req)
	case bindCommand:
		return h.handleBind(conn, req)
	case associateCommand:
		return h.handleAssociate(conn, req)
	default:
		if err := sendReply(conn, commandNotSupported, nil); err != nil {
			return fmt.Errorf("Failed to send reply: %w", err)
		}
		return fmt.Errorf("Unsupported command: %v", req.Command)
	}
}
// handleConnect services a SOCKS5 connect command: it optionally enforces the
// IP access policy, dials the destination, replies to the client, and then
// relays bytes in both directions.
//
// When an access policy is configured and the request carries no IP, the FQDN
// is resolved first and the resulting IP is stored on req.DestAddr so the
// policy check can evaluate it.
//
// An error is returned when resolution fails, the policy denies the
// destination, the dial fails, a reply cannot be written, or either relay
// direction reports an error. Underlying errors are wrapped with %w so
// callers can inspect them with errors.Is / errors.As.
func (h *StandardRequestHandler) handleConnect(conn io.ReadWriter, req *Request) error {
	if h.accessPolicy != nil {
		if req.DestAddr.IP == nil {
			addr, err := net.ResolveIPAddr("ip", req.DestAddr.FQDN)
			if err != nil {
				// Best-effort reply: the resolution failure is the error worth
				// reporting, so a failed write here is intentionally ignored.
				_ = sendReply(conn, ruleFailure, req.DestAddr)
				return fmt.Errorf("unable to resolve host to confirm access: %w", err)
			}
			req.DestAddr.IP = addr.IP
		}
		if allowed, rule := h.accessPolicy.Allowed(req.DestAddr.IP, req.DestAddr.Port); !allowed {
			// Best-effort reply: the denial itself is what gets returned.
			_ = sendReply(conn, ruleFailure, req.DestAddr)
			if rule != nil {
				return fmt.Errorf("Connect to %v denied due to iprule: %s", req.DestAddr, rule.String())
			}
			return fmt.Errorf("Connect to %v denied", req.DestAddr)
		}
	}

	target, localAddr, err := h.dialer.Dial(req.DestAddr.Address())
	if err != nil {
		// Pick the reply code from the dial error text; anything that is not
		// recognised is reported as host unreachable.
		msg := err.Error()
		resp := hostUnreachable
		if strings.Contains(msg, "refused") {
			resp = connectionRefused
		} else if strings.Contains(msg, "network is unreachable") {
			resp = networkUnreachable
		}
		if replyErr := sendReply(conn, resp, nil); replyErr != nil {
			return fmt.Errorf("Failed to send reply: %w", replyErr)
		}
		return fmt.Errorf("Connect to %v failed: %w", req.DestAddr, err)
	}
	defer target.Close()

	// Send success
	if err := sendReply(conn, successReply, localAddr); err != nil {
		return fmt.Errorf("Failed to send reply: %w", err)
	}

	// Start proxying. The channel has room for both results so neither
	// goroutine blocks on send if we return after the first error.
	proxyDone := make(chan error, 2)
	go func() {
		_, e := io.Copy(target, req.bufConn)
		proxyDone <- e
	}()
	go func() {
		_, e := io.Copy(conn, target)
		proxyDone <- e
	}()

	// Wait for both directions, bailing out on the first error.
	for i := 0; i < 2; i++ {
		if e := <-proxyDone; e != nil {
			return e
		}
	}
	return nil
}
// handleBind answers a bind command. Bind is not implemented, so the client
// always receives a commandNotSupported reply; the only error returned is a
// (wrapped) failure to write that reply.
// TODO: Support bind command
func (h *StandardRequestHandler) handleBind(conn io.ReadWriter, req *Request) error {
	if err := sendReply(conn, commandNotSupported, nil); err != nil {
		return fmt.Errorf("Failed to send reply: %w", err)
	}
	return nil
}
// handleAssociate answers an associate command. Associate is not implemented,
// so the client always receives a commandNotSupported reply; the only error
// returned is a (wrapped) failure to write that reply.
// TODO: Support associate command
func (h *StandardRequestHandler) handleAssociate(conn io.ReadWriter, req *Request) error {
	if err := sendReply(conn, commandNotSupported, nil); err != nil {
		return fmt.Errorf("Failed to send reply: %w", err)
	}
	return nil
}
// StreamHandler serves SOCKS5 on tunnelConn using a dialer built from
// originConn and no access policy. A Serve error is logged at debug level
// and otherwise discarded.
func StreamHandler(tunnelConn io.ReadWriter, originConn net.Conn, log *zerolog.Logger) {
	handler := NewRequestHandler(NewConnDialer(originConn), nil)
	err := NewConnectionHandler(handler).Serve(tunnelConn)
	if err == nil {
		return
	}
	log.Debug().Err(err).Msg("Socks stream handler error")
}
// StreamNetHandler serves SOCKS5 on tunnelConn using a dialer from
// NewNetDialer and the supplied access policy. A Serve error is logged at
// debug level and otherwise discarded.
func StreamNetHandler(tunnelConn io.ReadWriter, accessPolicy *ipaccess.Policy, log *zerolog.Logger) {
	handler := NewRequestHandler(NewNetDialer(), accessPolicy)
	err := NewConnectionHandler(handler).Serve(tunnelConn)
	if err == nil {
		return
	}
	log.Debug().Err(err).Msg("Socks stream handler error")
}