/
local_dialer.go
128 lines (113 loc) · 2.92 KB
/
local_dialer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package netx
import (
"context"
"crypto/tls"
"net"
"net/http"
"runtime"
"sync"
"sync/atomic"
"time"
"get.pme.sh/pmesh/config"
"get.pme.sh/pmesh/subnet"
"get.pme.sh/pmesh/xlog"
"golang.org/x/net/http2"
)
// TracedConn wraps a net.Conn so that its lifetime is reflected in the
// package-wide socket counter reported by GetLocalDialerSocketCount.
// Values are only produced by NewTracedConn while DebugSocketCount is set.
type TracedConn struct {
	net.Conn
	// closed ensures the counter is decremented at most once, even when
	// several layers call Close on the same connection.
	closed atomic.Bool
}

// localDialerSocketCount is the number of TracedConns created but not yet closed.
var localDialerSocketCount atomic.Int64

// DebugSocketCount enables socket tracing in NewTracedConn. It is a plain
// bool read without synchronization. NOTE(review): presumably meant to be set
// once at startup before any dialing happens — confirm.
var DebugSocketCount = false

// GetLocalDialerSocketCount returns the number of traced connections that
// are currently open. It only moves while DebugSocketCount is enabled.
func GetLocalDialerSocketCount() int64 {
	return localDialerSocketCount.Load()
}

// NewTracedConn returns conn unchanged unless DebugSocketCount is set. When
// it is set, the live-socket counter is incremented and conn is wrapped in a
// TracedConn whose Close decrements the counter again. A log line is emitted
// whenever the count reaches a nonzero multiple of 256.
func NewTracedConn(conn net.Conn) net.Conn {
	if !DebugSocketCount {
		return conn
	}
	n := localDialerSocketCount.Add(1)
	if n&255 == 0 && n != 0 {
		xlog.Info().Int64("n", n).Stringer("local", conn.LocalAddr()).Stringer("remote", conn.RemoteAddr()).Msg("Socket created")
	}
	return &TracedConn{Conn: conn}
}

// Close closes the underlying connection. The live-socket counter is
// decremented only on the first call, so repeated Close calls do not skew it.
// Every call is still forwarded, so later calls return whatever error the
// underlying conn reports for a double close.
func (c *TracedConn) Close() error {
	if c.closed.CompareAndSwap(false, true) {
		n := localDialerSocketCount.Add(-1)
		if n&255 == 0 && n != 0 && DebugSocketCount {
			xlog.Info().Int64("n", n).Stringer("local", c.Conn.LocalAddr()).Stringer("remote", c.Conn.RemoteAddr()).Msg("Socket closed")
		}
	}
	return c.Conn.Close()
}
// LocalDialer dials local upstreams. Its DialContext works on a copy of the
// embedded net.Dialer, applies MakeLocalDialer to that copy, and wraps the
// resulting connection with NewTracedConn.
type LocalDialer struct {
	net.Dialer
}
// localDialerAllocator lazily builds, on first use, the subnet allocator that
// MakeLocalDialer draws source addresses from. The allocator is built from
// *config.DialerSubnet. NOTE(review): the meaning of the boolean passed to
// subnet.NewAllocator is not visible here — confirm.
var localDialerAllocator = sync.OnceValue(func() *subnet.Allocator {
	return subnet.NewAllocator(*config.DialerSubnet, true)
})
// MakeLocalDialer prepares d for dialing local upstreams and returns it.
// A nil d is replaced by a freshly allocated dialer; otherwise d is modified
// in place. Fields left at their zero value get defaults:
//   - Timeout becomes 15 seconds.
//   - KeepAlive becomes -1, which disables keep-alive probes.
// On every platform except darwin, LocalAddr is always overwritten with a
// TCP address whose IP comes from the dialer subnet allocator and whose port
// is 0 (ephemeral).
func MakeLocalDialer(d *net.Dialer) *net.Dialer {
	dialer := d
	if dialer == nil {
		dialer = new(net.Dialer)
	}
	if dialer.Timeout == 0 {
		dialer.Timeout = 15 * time.Second
	}
	if dialer.KeepAlive == 0 {
		// Negative KeepAlive disables TCP keep-alive probes for this dialer.
		dialer.KeepAlive = -1
	}
	switch runtime.GOOS {
	case "darwin":
		// No source-address binding here. NOTE(review): presumably because the
		// extra loopback addresses are not routable on macOS by default — confirm.
	default:
		src := localDialerAllocator().Generate()
		dialer.LocalAddr = &net.TCPAddr{IP: src, Port: 0}
	}
	return dialer
}
// DialContext connects to addr on the named network. It applies
// MakeLocalDialer to a private copy of the embedded net.Dialer, so the
// receiver's own settings are never modified. A successful connection is
// passed through NewTracedConn.
func (d LocalDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	base := d.Dialer // copy: MakeLocalDialer mutates the dialer it is given
	prepared := MakeLocalDialer(&base)
	c, err := prepared.DialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}
	return NewTracedConn(c), nil
}
// MakeLocalTransport configures opts for talking to local upstreams and
// returns it. opts is modified in place; a nil opts is replaced by a newly
// allocated http.Transport instead of panicking.
//
// The following settings are applied:
//   - Response compression is disabled.
//   - idle sets both MaxIdleConns and MaxIdleConnsPerHost.
//   - max sets MaxConnsPerHost (0 means no limit, per net/http).
//   - Idle connections are dropped after 10 seconds.
//   - Dialing goes through LocalDialer.
func MakeLocalTransport(idle int, max int, opts *http.Transport) *http.Transport {
	if opts == nil {
		opts = &http.Transport{}
	}
	var ldial LocalDialer
	opts.DisableCompression = true
	opts.MaxIdleConns = idle
	opts.MaxIdleConnsPerHost = idle
	opts.MaxConnsPerHost = max
	opts.IdleConnTimeout = 10 * time.Second
	opts.DialContext = ldial.DialContext
	return opts
}
// LocalTransport is the shared HTTP/1.x transport for local upstreams, built
// by MakeLocalTransport with these settings:
//   - 16384 idle connections, both in total and per host.
//   - No per-host connection limit.
//   - A one-minute response-header timeout.
//   - TLS certificate verification disabled.
var LocalTransport = MakeLocalTransport(16384, 0, &http.Transport{
	ResponseHeaderTimeout: time.Minute,
	TLSClientConfig:       &tls.Config{InsecureSkipVerify: true},
})
// LocalH2Transport is an HTTP/2 transport for local upstreams. Compression is
// disabled. Every TLS connection is dialed over a MakeLocalDialer dialer with
// certificate verification turned off, and the result is passed through
// NewTracedConn.
var LocalH2Transport = &http2.Transport{
	DisableCompression: true,
	DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
		// NOTE(review): cfg is mutated in place; presumably http2.Transport
		// hands each dial its own clone of TLSClientConfig — confirm.
		cfg.InsecureSkipVerify = true
		netDialer := MakeLocalDialer(nil)
		tlsDialer := &tls.Dialer{NetDialer: netDialer, Config: cfg}
		tlsConn, err := tlsDialer.DialContext(ctx, network, addr)
		if err != nil {
			return nil, err
		}
		// NOTE(review): with DebugSocketCount on, the *tls.Conn is wrapped, which
		// hides its ConnectionState from type assertions — confirm nothing
		// downstream relies on it.
		return NewTracedConn(tlsConn), nil
	},
}
// ResetConn aborts conn instead of closing it gracefully.
//
// For a *net.TCPConn, SO_LINGER is first set to zero. The operating system
// then discards any unsent data and, on common platforms, answers the close
// with an RST instead of a FIN, so the peer sees a connection reset. Other
// connection types are simply closed. A nil conn is ignored.
func ResetConn(conn net.Conn) {
	if conn == nil {
		return
	}
	if tcp, ok := conn.(*net.TCPConn); ok {
		// Best effort: if SetLinger fails we still close the socket below.
		_ = tcp.SetLinger(0)
	}
	_ = conn.Close()
}
// ResetRequestConn tries to hijack the connection behind w and abort it with
// ResetConn. If hijacking fails, the connection is left to the server.
// It never returns normally: it always panics with http.ErrAbortHandler,
// which net/http treats as a request to abort the response.
func ResetRequestConn(w http.ResponseWriter) {
	conn, _, err := http.NewResponseController(w).Hijack()
	if err == nil {
		ResetConn(conn)
	}
	panic(http.ErrAbortHandler)
}