Skip to content

Commit

Permalink
fix non gc-able dialer
Browse files Browse the repository at this point in the history
closes #5165
  • Loading branch information
Mzack9999 committed May 15, 2024
1 parent 9adfc53 commit 47ca8fe
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 25 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ require (
github.com/redis/go-redis/v9 v9.1.0
github.com/seh-msft/burpxml v1.0.1
github.com/stretchr/testify v1.9.0
github.com/tarunKoyalwar/goleak v0.0.0-20240426214851-746d64600adc
github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9
github.com/zmap/zgrab2 v0.1.8-0.20230806160807-97ba87c0e706
golang.org/x/term v0.19.0
gopkg.in/yaml.v3 v3.0.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1016,8 +1016,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
github.com/tarunKoyalwar/goleak v0.0.0-20240426214851-746d64600adc h1:/5P5I7oDqdLee8W9Moof0xSD8tT1qEVzhObSI9CqHkg=
github.com/tarunKoyalwar/goleak v0.0.0-20240426214851-746d64600adc/go.mod h1:uQdBQGrE1fZ2EyOs0pLcCDd1bBV4rSThieuIIGhXZ50=
github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9 h1:GXIyLuIJ5Qk46lI8WJ83qHBZKUI3zhmMmuoY9HICUIQ=
github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9/go.mod h1:uQdBQGrE1fZ2EyOs0pLcCDd1bBV4rSThieuIIGhXZ50=
github.com/tidwall/assert v0.1.0 h1:aWcKyRBUAdLoVebxo95N7+YZVTFF/ASTr7BN4sLP6XI=
github.com/tidwall/assert v0.1.0/go.mod h1:QLYtGyeqse53vuELQheYl9dngGCJQ+mTtlxcktb+Kj8=
github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI=
Expand Down
6 changes: 3 additions & 3 deletions lib/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/protocols"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/hosterrorscache"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/interactsh"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolinit"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/headless/engine"
"github.com/projectdiscovery/nuclei/v3/pkg/reporting"
"github.com/projectdiscovery/nuclei/v3/pkg/templates"
Expand Down Expand Up @@ -206,8 +206,6 @@ func (e *NucleiEngine) Close() {
if e.rateLimiter != nil {
e.rateLimiter.Stop()
}
// close global shared resources
protocolstate.Close()
if e.inputProvider != nil {
e.inputProvider.Close()
}
Expand All @@ -217,6 +215,8 @@ func (e *NucleiEngine) Close() {
if e.httpxClient != nil {
_ = e.httpxClient.Close()
}
// close global shared resources
protocolinit.Close()
}

// ExecuteCallbackWithCtx executes templates on targets and calls callback on each result(only if results are found)
Expand Down
6 changes: 5 additions & 1 deletion lib/sdk_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
"github.com/projectdiscovery/ratelimit"
)

var sharedInit sync.Once = sync.Once{}
var sharedInit *sync.Once

// applyRequiredDefaults to options
func (e *NucleiEngine) applyRequiredDefaults() {
Expand Down Expand Up @@ -118,6 +118,10 @@ func (e *NucleiEngine) init() error {

e.parser = templates.NewParser()

if sharedInit == nil || protocolstate.ShouldInit() {
sharedInit = &sync.Once{}
}

sharedInit.Do(func() {
_ = protocolstate.Init(e.opts)
_ = protocolinit.Init(e.opts)
Expand Down
3 changes: 1 addition & 2 deletions pkg/protocols/common/protocolinit/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

// Init initializes the client pools for the protocols
func Init(options *types.Options) error {

if err := protocolstate.Init(options); err != nil {
return err
}
Expand All @@ -39,5 +38,5 @@ func Init(options *types.Options) error {
}

func Close() {
protocolstate.Dialer.Close()
protocolstate.Close()
}
5 changes: 5 additions & 0 deletions pkg/protocols/common/protocolstate/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ var (
Dialer *fastdialer.Dialer
)

func ShouldInit() bool {
return Dialer == nil
}

// Init creates the Dialer instance based on user configuration
func Init(options *types.Options) error {
if Dialer != nil {
Expand Down Expand Up @@ -210,5 +214,6 @@ func Close() {
if Dialer != nil {
Dialer.Close()
}
Dialer = nil
StopActiveMemGuardian()
}
20 changes: 6 additions & 14 deletions pkg/protocols/http/httpclientpool/clientpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"golang.org/x/net/proxy"
"golang.org/x/net/publicsuffix"

"github.com/projectdiscovery/fastdialer/fastdialer"
"github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils"
Expand All @@ -28,9 +27,6 @@ import (
)

var (
// Dialer is a copy of the fastdialer from protocolstate
Dialer *fastdialer.Dialer

rawHttpClient *rawhttp.Client
forceMaxRedirects int
normalClient *retryablehttp.Client
Expand Down Expand Up @@ -146,8 +142,8 @@ func GetRawHTTP(options *types.Options) *rawhttp.Client {
rawHttpOptions.Proxy = types.ProxyURL
} else if types.ProxySocksURL != "" {
rawHttpOptions.Proxy = types.ProxySocksURL
} else if Dialer != nil {
rawHttpOptions.FastDialer = Dialer
} else if protocolstate.Dialer != nil {
rawHttpOptions.FastDialer = protocolstate.Dialer
}
rawHttpOptions.Timeout = GetHttpTimeout(options)
rawHttpClient = rawhttp.NewClient(rawHttpOptions)
Expand All @@ -167,10 +163,6 @@ func Get(options *types.Options, configuration *Configuration) (*retryablehttp.C
func wrappedGet(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
var err error

if Dialer == nil {
Dialer = protocolstate.Dialer
}

hash := configuration.Hash()
if client, ok := clientPool.Get(hash); ok {
return client, nil
Expand Down Expand Up @@ -237,15 +229,15 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl

transport := &http.Transport{
ForceAttemptHTTP2: options.ForceAttemptHTTP2,
DialContext: Dialer.Dial,
DialContext: protocolstate.Dialer.Dial,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if options.TlsImpersonate {
return Dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil)
return protocolstate.Dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil)
}
if options.HasClientCertificates() || options.ForceAttemptHTTP2 {
return Dialer.DialTLSWithConfig(ctx, network, addr, tlsConfig)
return protocolstate.Dialer.DialTLSWithConfig(ctx, network, addr, tlsConfig)
}
return Dialer.DialTLS(ctx, network, addr)
return protocolstate.Dialer.DialTLS(ctx, network, addr)
},
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
Expand Down
4 changes: 2 additions & 2 deletions pkg/protocols/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ
if input.MetaInput.CustomIP != "" {
outputEvent["ip"] = input.MetaInput.CustomIP
} else {
outputEvent["ip"] = httpclientpool.Dialer.GetDialedIP(hostname)
outputEvent["ip"] = protocolstate.Dialer.GetDialedIP(hostname)
}

if len(generatedRequest.interactshURLs) > 0 {
Expand Down Expand Up @@ -873,7 +873,7 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ
if input.MetaInput.CustomIP != "" {
outputEvent["ip"] = input.MetaInput.CustomIP
} else {
outputEvent["ip"] = httpclientpool.Dialer.GetDialedIP(hostname)
outputEvent["ip"] = protocolstate.Dialer.GetDialedIP(hostname)
}
if request.options.Interactsh != nil {
request.options.Interactsh.MakePlaceholders(generatedRequest.interactshURLs, outputEvent)
Expand Down

0 comments on commit 47ca8fe

Please sign in to comment.