-
Notifications
You must be signed in to change notification settings - Fork 43
/
mitm_transparent_ssl.go
415 lines (362 loc) · 12.4 KB
/
mitm_transparent_ssl.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
package crep
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"net"
"sync"
"time"
"github.com/yaklang/yaklang/common/log"
"github.com/yaklang/yaklang/common/netx"
"github.com/yaklang/yaklang/common/utils"
)
var (
	// setHijackRequestLock is a package-level mutex. It is not referenced in
	// the visible part of this file; presumably it serialises updates to the
	// hijack-request callbacks elsewhere in the package (verify against callers).
	setHijackRequestLock = new(sync.Mutex)

	// fallbackHttpFrame is the canned HTTP/1.1 response written back to the
	// client when a request would loop back to this proxy's own listen address.
	//
	// The frame is spelled out with explicit CRLF line endings and the empty
	// line that separates headers from body, so that any HTTP client can parse
	// it. The body is exactly the 11 bytes announced by Content-Length.
	fallbackHttpFrame = []byte("HTTP/1.1 200 OK\r\n" +
		"Content-Length: 11\r\n" +
		"\r\n" +
		"origin_fail")
)
// ServeTransparentTLS listens on addr (TCP) and hands every accepted
// connection to handleHTTPS in its own goroutine.
//
// The call blocks until either
//   - ctx is cancelled, in which case the listener is closed and the result of
//     Close is returned, or
//   - Accept fails while ctx is still alive, in which case the listener is
//     closed and the accept error is returned. This avoids blocking forever on
//     a listener that nobody is accepting from any more.
//
// It returns an error immediately if the MITM config is missing or the
// address cannot be listened on.
func (m *MITMServer) ServeTransparentTLS(ctx context.Context, addr string) error {
	if m.mitmConfig == nil {
		return utils.Errorf("mitm config empty")
	}

	l, err := net.Listen("tcp", addr)
	if err != nil {
		return utils.Errorf("listen tcp://%v failed; %s", addr, err)
	}
	addrVerbose := fmt.Sprintf("tcp://%v", addr)

	// Buffered so the accept goroutine can always deliver its error and exit,
	// even if this function has already returned through ctx.Done().
	acceptFailed := make(chan error, 1)

	go func() {
		log.Infof("start to server transparent mitm server: tcp://%v", addr)
		for {
			conn, err := l.Accept()
			if err != nil {
				if ctx.Err() != nil {
					// The listener was closed because ctx was cancelled:
					// this is a normal shutdown, not a failure.
					return
				}
				acceptFailed <- err
				return
			}
			log.Infof("recv tcp conn from %v", conn.RemoteAddr().String())

			go func() {
				defer conn.Close()
				log.Infof("start check tls/http connection... for %s", conn.RemoteAddr().String())
				if err := m.handleHTTPS(ctx, conn, addr); err != nil {
					log.Errorf("handle conn from [%s] failed: %s", conn.RemoteAddr().String(), err)
				}
			}()
		}
	}()

	select {
	case <-ctx.Done():
		return l.Close()
	case err := <-acceptFailed:
		// The accept error is the interesting one; a secondary Close error
		// would only obscure it.
		_ = l.Close()
		return utils.Errorf("%v accept new conn failed: %s", addrVerbose, err)
	}
}
// handleHTTPS serves one intercepted client connection.
//
// It peeks at the first byte to decide between TLS (0x16, the TLS handshake
// record type) and plain HTTP, reads the first HTTP request to learn the
// target from its Host header, resolves and dials that target, and then either
//
//   - forwards bytes in both directions untouched (hijack mode off), or
//   - relays request/response pairs one at a time, passing them through the
//     configured hijack and mirror callbacks (hijack mode on).
//
// Parameters:
//   - ctx:    only its deadline is consulted, to bound DNS lookup and dialing.
//   - conn:   the accepted client connection; the caller closes it.
//   - origin: the proxy's own listen address, used to detect requests that
//     would loop back to the proxy itself.
//
// It returns nil when the connection finished normally (including the
// loopback case) and a descriptive error otherwise.
func (m *MITMServer) handleHTTPS(ctx context.Context, conn net.Conn, origin string) error {
	pc := utils.NewPeekableNetConn(conn)
	raw, err := pc.Peek(1)
	if err != nil {
		return utils.Errorf("peek [%s] failed: %s", conn.RemoteAddr(), err)
	}
	if len(raw) == 0 {
		// Guard the raw[0] access below against an empty peek.
		return utils.Errorf("peek [%s] returned no data", conn.RemoteAddr())
	}
	log.Infof("peek one char[%#v] for %v", raw, conn.RemoteAddr().String())

	var (
		isHTTPS  bool
		httpConn net.Conn
	)
	switch raw[0] {
	case 0x16: // TLS handshake record: terminate TLS locally
		log.Infof("serving https for: %s", conn.RemoteAddr().String())
		tconn := tls.Server(pc, m.mitmConfig.TLS())
		if err := tconn.Handshake(); err != nil {
			return utils.Errorf("tls handshake failed: %s", err)
		}
		log.Infof("conn: %s handshake finished", conn.RemoteAddr().String())
		// SNI is only meaningful for TLS. Requiring it for plain HTTP as well
		// would reject every non-TLS connection, so the check lives here.
		if tconn.ConnectionState().ServerName == "" {
			return utils.Errorf("SNI empty...")
		}
		httpConn = tconn
		isHTTPS = true
	default: // anything else is treated as plain HTTP
		log.Infof("start to serve http for %s", conn.RemoteAddr().String())
		httpConn = pc
	}

	log.Infof("start to handle http request for %s", conn.RemoteAddr().String())

	// Everything read while parsing the first request is also captured in
	// readerBuffer so the raw bytes can be replayed to the remote side.
	var readerBuffer bytes.Buffer
	reqReader := io.TeeReader(httpConn, &readerBuffer)
	firstRequest, err := utils.ReadHTTPRequestFromBufioReader(bufio.NewReader(reqReader))
	if err != nil {
		return utils.Errorf("read request failed: %s for %s", err, conn.RemoteAddr().String())
	}
	log.Infof("read request finished for %v", httpConn.RemoteAddr())

	scheme := "http"
	if isHTTPS {
		scheme = "https"
	}
	fakeURL := fmt.Sprintf("%s://%v", scheme, firstRequest.Host)

	// Timeout and deadline control: default to 30s, or the time left until the
	// context deadline when one is set and still in the future.
	timeout := 30 * time.Second
	var ctxDDL time.Time
	if ddl, ok := ctx.Deadline(); ok {
		ctxDDL = ddl
		timeout = time.Until(ddl)
		if timeout <= 0 {
			timeout = 30 * time.Second
		}
	}

	host, port, err := utils.ParseStringToHostPort(fakeURL)
	if err != nil {
		return utils.Errorf("cannot identify target[%s]: %s", fakeURL, err)
	}

	// originHost keeps the name from the Host header; it is the DNS cache key
	// and the ServerName presented to the remote TLS server.
	originHost := host
	if (!utils.IsIPv4(host)) && (!utils.IsIPv6(utils.FixForParseIP(host))) {
		log.Infof("start to handle dns items for %v", host)
		if cachedTarget, ok := m.dnsCache.Load(originHost); ok {
			cached := cachedTarget.(string)
			log.Infof("dns cache matched: %v -> %v", host, cached)
			host = cached
		} else {
			resolved := netx.LookupFirst(host, netx.WithTimeout(timeout), netx.WithDNSServers(m.DNSServers...))
			if resolved == "" {
				return utils.Errorf("cannot query dns host[%s]", host)
			}
			log.Infof("dns query finished for %v: results: [%#v]", host, resolved)
			// Cache under the hostname that was looked up, not under the
			// resolved address, otherwise later lookups can never hit.
			m.dnsCache.Store(originHost, resolved)
			host = resolved
		}
	}

	target := utils.HostPort(host, port)

	// A request aimed at our own listen address would loop forever; answer it
	// with a canned response instead.
	if target == origin {
		log.Infof("lookback: %s", origin)
		// Best effort: the connection is finished either way.
		_, _ = httpConn.Write(fallbackHttpFrame)
		return nil
	}

	log.Infof("start to connect remote addr: %v", target)
	dialer := &net.Dialer{
		Timeout:  timeout,
		Deadline: ctxDDL,
	}
	var remoteConn net.Conn
	if !isHTTPS {
		log.Infof("tcp connect to %s", target)
		remoteConn, err = dialer.Dial("tcp", target)
		if err != nil {
			return utils.Errorf("remote tcp://%v failed to dial: %s", target, err)
		}
	} else {
		log.Infof("tcp+tls connect to %s", target)
		remoteConn, err = tls.DialWithDialer(dialer, "tcp", target, &tls.Config{
			InsecureSkipVerify: true,
			MinVersion:         tls.VersionSSL30, // nolint[:staticcheck]
			MaxVersion:         tls.VersionTLS13,
			ServerName:         originHost,
		})
		if err != nil {
			return utils.Errorf("remote tcp+tls://%v failed: %s", target, err)
		}
	}
	defer remoteConn.Close()

	// Forwarding mode: no hijacking, and none of the callbacks take effect.
	if m.transparentHijackMode == nil || !m.transparentHijackMode.IsSet() {
		// NOTE(review): readerBuffer holds everything the bufio reader pulled
		// from the client, which may already include body bytes; confirm that
		// the body copy below cannot send those bytes a second time.
		_, err = remoteConn.Write(readerBuffer.Bytes())
		if err != nil {
			return utils.Errorf("write first http.Request raw []byte failed: %s", err)
		}
		if firstRequest.Body != nil {
			// Best effort: a failure here surfaces in the copy loops below.
			n, _ := io.Copy(remoteConn, firstRequest.Body)
			log.Infof("request have body len: %v", n)
		}

		wg := new(sync.WaitGroup)
		wg.Add(2)
		log.Infof("start to do transparent traffic")
		go func() { // client -> remote
			defer wg.Done()
			defer remoteConn.Close()
			io.Copy(remoteConn, httpConn)
		}()
		go func() { // remote -> client
			defer wg.Done()
			defer remoteConn.Close()
			io.Copy(httpConn, remoteConn)
		}()
		wg.Wait()
		log.Infof("finished conn from %s to %s", conn.LocalAddr().String(), conn.RemoteAddr().String())
		return nil
	}

	// Hijack mode: callbacks only take effect from here on.
	// Start with the first request, which has already been parsed above.
	reqBytes := readerBuffer.Bytes()
	if m.transparentHijackRequest == nil && m.transparentHijackRequestManager == nil {
		// Requests are not hijacked: write straight through without waiting
		// for the whole request to be read.
		log.Infof("write first request for %v", remoteConn.RemoteAddr().String())
		_, err = remoteConn.Write(reqBytes)
		if err != nil {
			return utils.Errorf("write first http.Request raw []byte failed: %s", err)
		}
		if firstRequest.Body != nil {
			n, _ := io.Copy(remoteConn, firstRequest.Body)
			log.Infof("request have body len: %v", n)
		}
		// Reading the body grows readerBuffer, which invalidates the slice
		// taken above; refresh it so the mirror callback sees the full request.
		reqBytes = readerBuffer.Bytes()
		log.Infof("write first request finished for %v", remoteConn.RemoteAddr().String())
	} else {
		// Requests are hijacked: drain the body so readerBuffer holds the
		// complete raw request before the callbacks see it.
		if firstRequest.Body != nil {
			_, _ = ioutil.ReadAll(firstRequest.Body)
		}
		reqBytes = readerBuffer.Bytes()
		if m.transparentHijackRequest != nil {
			reqBytes = m.transparentHijackRequest(isHTTPS, reqBytes)
		}
		if m.transparentHijackRequestManager != nil {
			reqBytes = m.transparentHijackRequestManager.Hijacked(isHTTPS, reqBytes)
		}
		if _, err = remoteConn.Write(reqBytes); err != nil {
			return utils.Errorf("write first http.Request raw []byte failed: %s", err)
		}
	}

	// Parse the first response, capturing its raw bytes in rspRaw.
	var rspRaw bytes.Buffer
	responseReader := io.TeeReader(remoteConn, &rspRaw)
	// When responses are not hijacked, stream to the client as we read to
	// keep latency low.
	if m.transparentHijackResponse == nil {
		responseReader = io.TeeReader(responseReader, httpConn)
	}
	rsp, err := utils.ReadHTTPResponseFromBufioReader(bufio.NewReader(responseReader), firstRequest)
	if err != nil {
		return utils.Errorf("read response for req[%v]->%v failed: %s", firstRequest.URL.String(), remoteConn.RemoteAddr(), err)
	}

	// Drain the body (remote -> local). Without response hijacking the tee
	// above has already forwarded it; with hijacking nothing has been written
	// to the client yet and httpConn.Write must be called explicitly below.
	if rsp.Body != nil {
		rspBody, _ := ioutil.ReadAll(rsp.Body)
		if len(rspBody) > 0 {
			log.Infof("rsp body found length: %v", len(rspBody))
		}
	}
	log.Info("first req and rsp recv finished!")

	rspBytes := rspRaw.Bytes()
	// Hijacking a response requires the complete response first, so this
	// path trades speed for the ability to rewrite.
	if m.transparentHijackResponse != nil {
		rspBytes = m.transparentHijackResponse(isHTTPS, rspBytes)
		_, err = httpConn.Write(rspBytes)
		if err != nil {
			return utils.Errorf("feedback response bytes from [%s] to [%s] failed: %s",
				remoteConn.RemoteAddr().String(), httpConn.RemoteAddr().String(), err,
			)
		}
	}

	if m.transparentOriginMirror != nil {
		go m.transparentOriginMirror(isHTTPS, readerBuffer.Bytes(), rspRaw.Bytes())
	}
	if m.transparentHijackedMirror != nil {
		go m.transparentHijackedMirror(isHTTPS, reqBytes, rspBytes)
	}

	if rsp.Close {
		return nil
	}

	// Keep-alive: relay further request/response pairs on the same connections.
	for {
		var reqRaw bytes.Buffer

		// Skip stray bytes that cannot start an HTTP request line. A byte is
		// accepted when it lies in 'A'..'z' (65..122), which is the union of
		// the two ranges 'A'..'Z' and 91..122.
		buf := make([]byte, 1)
		for {
			n, err := httpConn.Read(buf)
			if err != nil {
				return utils.Errorf("httpConn read failed: %s", err)
			}
			if n == 0 {
				continue
			}
			if buf[0] >= 'A' && buf[0] <= 'z' {
				break
			}
		}

		// Read the next request from the client, re-attaching the byte
		// consumed above and capturing the raw request in reqRaw.
		reqReader := io.TeeReader(
			io.MultiReader(bytes.NewReader(buf), httpConn),
			&reqRaw,
		)
		// Without request hijacking, forward bytes as they are read.
		if m.transparentHijackRequest == nil && m.transparentHijackRequestManager == nil {
			reqReader = io.TeeReader(reqReader, remoteConn)
		}
		req, err := utils.ReadHTTPRequestFromBufioReader(bufio.NewReader(reqReader))
		if err != nil {
			return utils.Errorf("read http request from: %s failed: %s", httpConn.RemoteAddr().String(), err)
		}

		// Drain the body so reqRaw holds the whole request; when requests
		// are not hijacked the tee above forwards it to remoteConn as well.
		if req.Body != nil {
			_, _ = ioutil.ReadAll(req.Body)
		}

		// reqRaw now includes the body (if any), so it is safe to hijack.
		// NOTE(review): only one of the two request hooks runs here, whereas
		// the first request applies both in sequence; confirm which is intended.
		reqBytes := reqRaw.Bytes()
		switch {
		case m.transparentHijackRequest != nil:
			reqBytes = m.transparentHijackRequest(isHTTPS, reqBytes)
			_, err = remoteConn.Write(reqBytes)
			if err != nil {
				return utils.Errorf("write http request from [%v] to [%s] failed: %s", httpConn.RemoteAddr().String(), remoteConn.RemoteAddr().String(), err)
			}
		case m.transparentHijackRequestManager != nil:
			reqBytes = m.transparentHijackRequestManager.Hijacked(isHTTPS, reqBytes)
			_, err := remoteConn.Write(reqBytes)
			if err != nil {
				return utils.Errorf("write http request from [%v] to [%s] failed: %s", httpConn.RemoteAddr().String(), remoteConn.RemoteAddr().String(), err)
			}
		}

		// Read the response, streaming it to the client unless it is hijacked.
		var rspRaw bytes.Buffer
		remoteResponseReader := io.TeeReader(remoteConn, &rspRaw)
		if m.transparentHijackResponse == nil {
			remoteResponseReader = io.TeeReader(remoteResponseReader, httpConn)
		}
		rsp, err := utils.ReadHTTPResponseFromBufioReader(bufio.NewReader(remoteResponseReader), req)
		if err != nil {
			return utils.Errorf("read http response from: %s failed: %s", remoteConn.RemoteAddr().String(), err)
		}
		// As above: drain the body so rspRaw is complete.
		if rsp.Body != nil {
			_, _ = ioutil.ReadAll(rsp.Body)
		}

		// Mirror the original, un-hijacked traffic.
		if m.transparentOriginMirror != nil {
			go m.transparentOriginMirror(isHTTPS, reqRaw.Bytes(), rspRaw.Bytes())
		}

		// A hijacked response has not been written to the client yet, so it
		// is written here synchronously.
		rspBytes := rspRaw.Bytes()
		if m.transparentHijackResponse != nil {
			rspBytes = m.transparentHijackResponse(isHTTPS, rspBytes)
			_, err = httpConn.Write(rspBytes)
			if err != nil {
				return utils.Errorf("write http response from [%s] to [%s] failed: %s", remoteConn.RemoteAddr().String(), httpConn.RemoteAddr().String(), err)
			}
		}

		// Mirror the traffic as it looks after hijacking.
		if m.transparentHijackedMirror != nil {
			go m.transparentHijackedMirror(isHTTPS, reqBytes, rspBytes)
		}

		// The pair is complete and already delivered; stop if the response
		// asked for the connection to be closed.
		if rsp.Close {
			return nil
		}
	}
}