Skip to content

Commit

Permalink
updates with unit + leak tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tarunKoyalwar committed May 9, 2024
1 parent f1cdfbe commit dd1cec4
Show file tree
Hide file tree
Showing 10 changed files with 218 additions and 40 deletions.
12 changes: 10 additions & 2 deletions fastdialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@ type Dialer struct {
group simpleflight.Group[string]
l4HandlerCache gcache.Cache[string, *l4ConnHandler]
sg *sizedwaitgroup.SizedWaitGroup
ctx context.Context
cancel context.CancelFunc
}

// NewDialer instance
func NewDialer(options Options) (*Dialer, error) {
func NewDialerWithCtx(ctx context.Context, options Options) (*Dialer, error) {
var resolvers []string
// Add system resolvers as the first to be tried
if options.ResolversFile {
Expand Down Expand Up @@ -190,9 +191,15 @@ func NewDialer(options Options) (*Dialer, error) {
tmp := sizedwaitgroup.New(options.MaxOpenConnections)
dx.sg = &tmp
}
dx.ctx, dx.cancel = context.WithCancel(ctx)
return dx, nil
}

// NewDialer instance
func NewDialer(options Options) (*Dialer, error) {
return NewDialerWithCtx(context.Background(), options)
}

// Dial function compatible with net/http
func (d *Dialer) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
return d.dial(ctx, &dialOptions{
Expand Down Expand Up @@ -404,6 +411,7 @@ func (d *Dialer) GetDNSData(hostname string) (*retryabledns.DNSData, error) {

// Close instance and cleanups
func (d *Dialer) Close() {
d.cancel()
if d.mDnsCache != nil {
d.mDnsCache.Purge()
}
Expand Down
11 changes: 11 additions & 0 deletions fastdialer/dialer/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ func NewConnWrap(nd net.Conn) ConnWrapper {
// DialTLS connects to the address on the named network using TLS.
// If ztlsFallback is true, it will fallback to ZTLS if the handshake fails.
func (d *connWrap) DialTLS(ctx context.Context, network, address string, config *tls.Config, ztlsFallback bool) (net.Conn, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}

if config == nil {
config = getDefaultTLSConfig()
}
Expand All @@ -57,6 +61,9 @@ func (d *connWrap) DialTLS(ctx context.Context, network, address string, config

// DialTLSAndImpersonate connects to the address on the named network using TLS and impersonates with given data
func (d *connWrap) DialTLSAndImpersonate(ctx context.Context, network, address string, config *tls.Config, strategy impersonate.Strategy, identify *impersonate.Identity) (net.Conn, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
// clone existing tls config
uTLSConfig := &utls.Config{
InsecureSkipVerify: config.InsecureSkipVerify,
Expand Down Expand Up @@ -84,6 +91,10 @@ func (d *connWrap) DialTLSAndImpersonate(ctx context.Context, network, address s

// DialZTLS connects to the address on the named network using ZTLS.
func (d *connWrap) DialZTLS(ctx context.Context, network, address string, config *ztls.Config) (net.Conn, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}

if config == nil {
config = AsZTLSConfig(getDefaultTLSConfig())
config.CipherSuites = ztls.ChromeCiphers // for reliable fallback
Expand Down
3 changes: 3 additions & 0 deletions fastdialer/dialer/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ func NewSimpleDialer(nd *net.Dialer, pd proxy.Dialer, timeout time.Duration) Sim
}

func (d *simpleDialer) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
if d.pd != nil {
ctx, cancel := context.WithTimeoutCause(ctx, d.timeout, errors.New("dialer timeout"))
defer cancel()
Expand Down
6 changes: 6 additions & 0 deletions fastdialer/dialer_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ type dialOptions struct {
}

func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, err error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %v", r)
Expand Down Expand Up @@ -113,6 +116,9 @@ func (d *Dialer) getLayer4Conn(ctx context.Context, network, hostname string, po
// no need to use handler at all if given input is ip
// or only one ip is available
for _, ip := range ips {
if ctx.Err() != nil {
return nil, "", ctx.Err()
}
d.acquire()
conn, err := d.simpleDialer.Dial(ctx, network, net.JoinHostPort(ip, port))
conn = d.releaseWithHook(conn)
Expand Down
144 changes: 111 additions & 33 deletions fastdialer/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ type l4ConnHandler struct {
cancel context.CancelFunc
// poolingChan is the channel that continiously dials to the address
// and stores the results in the cache
poolingChan chan dialResult
poolingChan chan *dialResult
// initChan contains intial dial results
initChan chan dialResult
initChan chan *dialResult
// synchronize initChan etc to avoid parallel data race
m sync.Mutex
}

// temporary struct to store dial results
Expand All @@ -53,6 +55,9 @@ type dialResult struct {

// getDialHandler returns a new dialHandler instance for the given address or returns an existing one
func getDialHandler(ctx context.Context, fd *Dialer, hostname, network, port string, ips []string) (*l4ConnHandler, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
// if handler already exists for this address, return it no need to create a new one
if h, err := fd.l4HandlerCache.GetIFPresent(network + ":" + hostname + ":" + port); !errors.Is(err, gcache.KeyNotFoundError) && h != nil {
return h, nil
Expand All @@ -76,7 +81,7 @@ func getDialHandler(ctx context.Context, fd *Dialer, hostname, network, port str
network: network,
port: port,
firstFlight: &atomic.Bool{},
poolingChan: make(chan dialResult, 3), // cache 3 connections
poolingChan: make(chan *dialResult, handlerOpts.PoolSize), // cache size
ips: ips,
}
h.firstFlight.Store(true)
Expand All @@ -85,7 +90,7 @@ func getDialHandler(ctx context.Context, fd *Dialer, hostname, network, port str
// context since it's lifetime is limited to dialing only
// so this context should be inherited from fastdialer instance context
// if it has any
ctx, cancel := context.WithCancel(context.TODO())
ctx, cancel := context.WithCancel(fd.ctx)
h.cancel = cancel
h.ctx = ctx
go h.run(ctx)
Expand All @@ -102,6 +107,9 @@ func getDialHandler(ctx context.Context, fd *Dialer, hostname, network, port str
// dialFirst performs the first dial to the address
// and stored results to be used by other dials
func (d *l4ConnHandler) dialFirst(ctx context.Context) error {
if ctx.Err() != nil {
return ctx.Err()
}
_, err, _ := d.fd.group.Do(d.hostname+":"+d.port, func() (interface{}, error) {
errX := d.dialAllParallel(ctx)
if errX != nil {
Expand All @@ -116,73 +124,136 @@ func (d *l4ConnHandler) dialFirst(ctx context.Context) error {
// dialAllParallel dials to all ip addresses in parallel and returns error if all of them failed
// if any of them succeeded, it puts them in initChan to be used by immediate calls
func (d *l4ConnHandler) dialAllParallel(ctx context.Context) error {
if ctx.Err() != nil {
return ctx.Err()
}

ch := make(chan dialResult, len(d.ips))
go func() {
var wg sync.WaitGroup
defer close(ch)
defer wg.Wait()

var wg sync.WaitGroup
alive := []string{}

defer func() {
if ctx.Err() != nil {
return
}
wg.Wait()
// only store alive ips
d.m.Lock()
d.ips = alive
d.m.Unlock()
}()

for _, ip := range d.ips {
wg.Add(1)
go func(ip string) {
defer wg.Done()
if ctx.Err() != nil {
return
}

d.fd.acquire() // no-op if max open connections is not set
conn, err := d.fd.simpleDialer.Dial(ctx, d.network, net.JoinHostPort(ip, d.port))
conn = d.fd.releaseWithHook(conn)

ch <- dialResult{conn, err, ip}
select {
case ch <- dialResult{conn, err, ip}:
case <-ctx.Done():
if conn != nil {
_ = conn.Close()
}
return
}
if err == nil {
d.m.Lock()
alive = append(alive, ip)
d.m.Unlock()
}
}(ip)
}
wg.Wait()
// only store alive ips
d.ips = alive
}()

var err error
idle := []net.Conn{}

for res := range ch {
if res.err != nil {
err = multierr.Append(err, res.err)
continue
loop:
for {
select {
case <-ctx.Done():
return ctx.Err()
case res, ok := <-ch:
if !ok {
break loop
}
if res.err != nil {
err = multierr.Append(err, res.err)
continue
}
// put conn in cache
idle = append(idle, res.conn)
}
// put conn in cache
idle = append(idle, res.conn)
}

if len(idle) == 0 {
return err
}
// put all in initChan
d.initChan = make(chan dialResult, len(idle))
d.m.Lock()
d.initChan = make(chan *dialResult, len(idle))
d.m.Unlock()
for _, conn := range idle {
d.initChan <- dialResult{conn: conn}
select {
case <-ctx.Done():
return ctx.Err()
case d.initChan <- &dialResult{conn: conn}:
}
}
return nil
}

// run continiously dials to the address and stores the results in the cache
// it runs in background and is used by getConn to get connections
func (d *l4ConnHandler) run(ctx context.Context) {
defer close(d.poolingChan)
defer close(d.initChan)
defer func() {
d.m.Lock()
close(d.poolingChan)
if d.initChan != nil {
close(d.initChan)
}
d.m.Unlock()
}()

var lastResult *dialResult
index := 0

d.m.Lock()
ips := d.ips
d.m.Unlock()

for {
select {
case <-ctx.Done():
if lastResult != nil && lastResult.conn != nil {
_ = lastResult.conn.Close()
}
return
case d.poolingChan <- lastResult:

default:
// dial new conn and put it in buffered chan
// reset index if it is out of bounds
if index >= len(ips) {
index = 0
}
ip := ips[index]

// dial new conn and put it in buffered chan
d.fd.acquire() // no-op if max open connections is not set
conn, err := d.fd.simpleDialer.Dial(ctx, d.network, net.JoinHostPort(d.hostname, d.port))
conn, err := d.fd.simpleDialer.Dial(ctx, d.network, net.JoinHostPort(ip, d.port))
conn = d.fd.releaseWithHook(conn)

d.poolingChan <- dialResult{conn, err, d.hostname}
// this is to avoid blocking when context is cancelled
lastResult = &dialResult{conn, err, ip}
index++
}
}
}
Expand All @@ -196,14 +267,21 @@ func (d *l4ConnHandler) getConn(ctx context.Context) (net.Conn, string, error) {
return nil, "", err
}
}

select {
case <-ctx.Done():
return nil, "", ctx.Err()
case res := <-d.initChan:
return res.conn, res.ip, res.err
case res := <-d.poolingChan:
return res.conn, res.ip, res.err
for {
select {
case <-ctx.Done():
return nil, "", ctx.Err()
case res := <-d.initChan:
if res == nil {
continue
}
return res.conn, res.ip, res.err
case res := <-d.poolingChan:
if res == nil {
continue
}
return res.conn, res.ip, res.err
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion fastdialer/resolverfile_unix.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// +build !windows
//go:build !windows

package fastdialer

Expand Down
2 changes: 1 addition & 1 deletion fastdialer/resolverfile_windows.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// +build windows
//go:build windows

package fastdialer

Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ require (
golang.org/x/net v0.23.0
)

require github.com/logrusorgru/aurora/v4 v4.0.0 // indirect

require (
github.com/akrylysov/pogreb v0.10.1 // indirect
github.com/andybalholm/brotli v1.0.6 // indirect
Expand All @@ -39,6 +41,7 @@ require (
github.com/remeh/sizedwaitgroup v1.0.0
github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect
github.com/syndtr/goleveldb v1.0.0 // indirect
github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9
github.com/tidwall/btree v1.4.3 // indirect
github.com/tidwall/buntdb v1.3.0 // indirect
github.com/tidwall/gjson v1.14.3 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/logrusorgru/aurora/v4 v4.0.0 h1:sRjfPpun/63iADiSvGGjgA1cAYegEWMPCJdUpJYn9JA=
github.com/logrusorgru/aurora/v4 v4.0.0/go.mod h1:lP0iIa2nrnT/qoFXcOZSrZQpJ1o6n2CUf/hyHi2Q4ZQ=
github.com/microcosm-cc/bluemonday v1.0.25 h1:4NEwSfiJ+Wva0VxN5B8OwMicaJvD8r9tlJWm9rtloEg=
github.com/microcosm-cc/bluemonday v1.0.25/go.mod h1:ZIOjCQp1OrzBBPIJmfX4qDYFuhU02nx4bn030ixfHLE=
github.com/miekg/dns v1.1.56 h1:5imZaSeoRNvpM9SzWNhEcP9QliKiz20/dA2QabIGVnE=
Expand Down Expand Up @@ -116,6 +118,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9 h1:GXIyLuIJ5Qk46lI8WJ83qHBZKUI3zhmMmuoY9HICUIQ=
github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9/go.mod h1:uQdBQGrE1fZ2EyOs0pLcCDd1bBV4rSThieuIIGhXZ50=
github.com/tidwall/assert v0.1.0 h1:aWcKyRBUAdLoVebxo95N7+YZVTFF/ASTr7BN4sLP6XI=
github.com/tidwall/assert v0.1.0/go.mod h1:QLYtGyeqse53vuELQheYl9dngGCJQ+mTtlxcktb+Kj8=
github.com/tidwall/btree v1.4.3 h1:Lf5U/66bk0ftNppOBjVoy/AIPBrLMkheBp4NnSNiYOo=
Expand Down
Loading

0 comments on commit dd1cec4

Please sign in to comment.