/
SocketProxy.go
125 lines (106 loc) · 3.03 KB
/
SocketProxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// +build linux
package transproxy
import (
"fmt"
"io"
"net"
"syscall"
"time"
"github.com/alexcesaro/log"
"github.com/xiqingping/transproxy/proxy"
)
// sockaddr appears to mirror the layout of C's `struct sockaddr`: a 16-bit
// address family followed by 14 bytes of family-specific address data.
//
// NOTE(review): nothing in this file references sockaddr; getOriginalDst
// decodes the address from a syscall.IPv6Mreq buffer instead. Confirm whether
// other files in the package use it before removing.
type sockaddr struct {
	family uint16   // address family
	data   [14]byte // raw, family-specific address bytes
}
// SO_ORIGINAL_DST is the Linux netfilter socket option (level IPPROTO_IP,
// value 80 in linux/netfilter_ipv4.h) that reports the destination address a
// redirected connection had before NAT. The C-style name is kept because the
// identifier is exported.
const SO_ORIGINAL_DST = 80
// getOriginalDst returns the pre-NAT IPv4 destination of tcpConn, read with
// the SO_ORIGINAL_DST socket option (e.g. for a connection redirected to this
// process by an iptables REDIRECT rule).
//
// The option needs a raw file descriptor, obtained by duplicating the socket
// with tcpConn.File(). Calling Fd() on that duplicate can switch the shared
// socket into blocking mode, so on success tcpConn is closed and replaced by a
// fresh *net.TCPConn built with net.FileConn (which restores non-blocking
// mode). Callers must use the returned newTCPConn instead of tcpConn.
//
// On error, newTCPConn is tcpConn itself, still open, and the caller remains
// responsible for closing it. The duplicated descriptor is always released.
func getOriginalDst(tcpConn *net.TCPConn) (addr net.TCPAddr, newTCPConn *net.TCPConn, err error) {
	newTCPConn = tcpConn

	connFile, err := tcpConn.File()
	if err != nil {
		return
	}
	// The duplicate is only needed for getsockopt and for FileConn (which
	// dups it again), so release it on every path.
	defer connFile.Close()

	// SO_ORIGINAL_DST fills a 16-byte struct sockaddr_in, and
	// GetsockoptIPv6Mreq is used only because it reads a 16-byte buffer.
	// Example result: Multiaddr: [2 0 31 144 206 190 36 45 0 0 0 0 0 0 0 0]
	//   bytes 2-3: port, big-endian (31<<8 | 144 = 8080)
	//   bytes 4-7: IPv4 address (206.190.36.45)
	mreq, err := syscall.GetsockoptIPv6Mreq(int(connFile.Fd()), syscall.IPPROTO_IP, SO_ORIGINAL_DST)
	if err != nil {
		return
	}

	newConn, err := net.FileConn(connFile)
	if err != nil {
		return
	}
	replacement, ok := newConn.(*net.TCPConn)
	if !ok {
		newConn.Close()
		err = fmt.Errorf("unexpected connection type %T from net.FileConn", newConn)
		return
	}

	// Close the original only once a usable replacement exists, so every
	// error path above hands the caller an open connection.
	tcpConn.Close()
	newTCPConn = replacement

	// Copy the bytes so addr.IP does not alias the syscall buffer.
	m := mreq.Multiaddr
	addr.IP = net.IP{m[4], m[5], m[6], m[7]}
	addr.Port = int(m[2])<<8 | int(m[3])
	return
}
// SocketProxy relays one transparently redirected client connection to its
// original destination, either by dialing it directly or through proxyDial.
type SocketProxy struct {
	// conn is the client connection as returned by getOriginalDst.
	conn *net.TCPConn
	// dest is the original destination reported for conn.
	dest net.TCPAddr
	// logger receives progress and error messages.
	logger log.Logger
	// bl lists destination IPs that Run reaches via proxyDial instead of
	// dialing directly.
	bl *BlackList
	// proxyConn is the upstream connection; it is assigned by Run.
	proxyConn net.Conn
	// proxyDial is the dialer used for black-listed or unreachable hosts.
	proxyDial proxy.Dialer
}
// NewSocketProxy prepares a SocketProxy for conn, a client connection that was
// redirected to this process. It resolves the connection's original
// destination with getOriginalDst. If that fails, the connection returned by
// getOriginalDst is closed, and the error is returned with a nil *SocketProxy.
//
// The SocketProxy keeps the connection returned by getOriginalDst, which may
// be a different *net.TCPConn from conn; conn should not be used afterwards.
func NewSocketProxy(conn *net.TCPConn, bl *BlackList, proxyDial proxy.Dialer, logger log.Logger) (sp *SocketProxy, err error) {
	origDst, clientConn, err := getOriginalDst(conn)
	if err != nil {
		clientConn.Close()
		return nil, err
	}

	sp = &SocketProxy{
		proxyDial: proxyDial,
		bl:        bl,
		logger:    logger,
		dest:      origDst,
		conn:      clientConn,
	}
	return sp, nil
}
// String describes the proxied flow for log messages, in the form
// "[client->local->destination]".
func (sp *SocketProxy) String() string {
	client := sp.conn.RemoteAddr()
	local := sp.conn.LocalAddr()
	target := sp.dest.String()
	return fmt.Sprintf("[%v->%v->%v]", client, local, target)
}
// copyAndCloseReader streams rc into w until rc is exhausted or either side
// fails, then closes rc. The writer is left open. The copy result is
// deliberately discarded: this is a best-effort relay step with no caller to
// report to.
func copyAndCloseReader(rc io.ReadCloser, w io.Writer) {
	defer rc.Close()
	_, _ = io.Copy(w, rc)
}
// Run connects to the original destination and relays traffic in both
// directions until each has finished, then closes both connections.
//
// A destination whose IP is in the black list is dialed through sp.proxyDial.
// Any other destination is first dialed directly over IPv4 with a 10 second
// timeout. If that fails, sp.proxyDial is tried, and on success the IP is added
// to the black list. If no upstream connection can be established, the client
// connection is closed and Run returns.
//
// Each direction half-closes its destination on a clean EOF (when the
// connection supports CloseWrite), so the peer sees end-of-stream while the
// other direction keeps flowing. Run waits for both directions, so the relay
// goroutine never outlives it.
func (sp *SocketProxy) Run() {
	var err error
	needProxy := false
	if sp.bl.Contains(sp.dest.IP) {
		sp.logger.Debug(sp, "Dest in black list")
		needProxy = true
	} else {
		sp.proxyConn, err = net.DialTimeout("tcp4", sp.dest.String(), 10*time.Second)
		if err != nil {
			sp.logger.Warning(sp, "Direct dial to", sp.dest, "error:", err, ", try proxy")
			needProxy = true
		}
	}

	if needProxy {
		sp.proxyConn, err = sp.proxyDial.Dial("tcp", sp.dest.String())
		if err != nil {
			sp.logger.Warning(sp, "Proxy dial to", sp.dest, "error:", err)
			sp.conn.Close()
			sp.logger.Info(sp, "Finished")
			return
		}
		sp.logger.Debug(sp, "Add", sp.dest.IP, "to black list")
		sp.bl.Add(sp.dest.IP)
	}

	sp.logger.Info(sp, "Start proxy ...")

	// relay copies src into dst. On a clean EOF it half-closes dst so the peer
	// learns the stream ended. If dst cannot be half-closed, or the copy
	// failed, it closes both connections so the opposite direction unblocks
	// instead of hanging. Close/CloseWrite errors are ignored: the connection
	// is being torn down anyway.
	relay := func(dst, src net.Conn) {
		_, copyErr := io.Copy(dst, src)
		if hc, ok := dst.(interface{ CloseWrite() error }); ok && copyErr == nil {
			hc.CloseWrite()
			return
		}
		dst.Close()
		src.Close()
	}

	done := make(chan struct{})
	go func() {
		defer close(done)
		relay(sp.conn, sp.proxyConn) // upstream -> client
	}()
	relay(sp.proxyConn, sp.conn) // client -> upstream
	<-done

	sp.conn.Close()
	sp.proxyConn.Close()
	sp.logger.Info(sp, "Finished")
}