diff --git a/go.mod b/go.mod index 39156d2f6d..98973239e5 100644 --- a/go.mod +++ b/go.mod @@ -99,7 +99,7 @@ require ( github.com/redis/go-redis/v9 v9.1.0 github.com/seh-msft/burpxml v1.0.1 github.com/stretchr/testify v1.9.0 - github.com/tarunKoyalwar/goleak v0.0.0-20240426214851-746d64600adc + github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9 github.com/zmap/zgrab2 v0.1.8-0.20230806160807-97ba87c0e706 golang.org/x/term v0.19.0 gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index 4a95dbd040..281ee725da 100644 --- a/go.sum +++ b/go.sum @@ -1016,8 +1016,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= -github.com/tarunKoyalwar/goleak v0.0.0-20240426214851-746d64600adc h1:/5P5I7oDqdLee8W9Moof0xSD8tT1qEVzhObSI9CqHkg= -github.com/tarunKoyalwar/goleak v0.0.0-20240426214851-746d64600adc/go.mod h1:uQdBQGrE1fZ2EyOs0pLcCDd1bBV4rSThieuIIGhXZ50= +github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9 h1:GXIyLuIJ5Qk46lI8WJ83qHBZKUI3zhmMmuoY9HICUIQ= +github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9/go.mod h1:uQdBQGrE1fZ2EyOs0pLcCDd1bBV4rSThieuIIGhXZ50= github.com/tidwall/assert v0.1.0 h1:aWcKyRBUAdLoVebxo95N7+YZVTFF/ASTr7BN4sLP6XI= github.com/tidwall/assert v0.1.0/go.mod h1:QLYtGyeqse53vuELQheYl9dngGCJQ+mTtlxcktb+Kj8= github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= diff --git a/lib/sdk.go b/lib/sdk.go index 63925f47ca..7a28cbc53e 100644 --- a/lib/sdk.go +++ b/lib/sdk.go @@ -18,7 +18,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/hosterrorscache" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/interactsh" - "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolinit" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/headless/engine" "github.com/projectdiscovery/nuclei/v3/pkg/reporting" "github.com/projectdiscovery/nuclei/v3/pkg/templates" @@ -206,8 +206,6 @@ func (e *NucleiEngine) Close() { if e.rateLimiter != nil { e.rateLimiter.Stop() } - // close global shared resources - protocolstate.Close() if e.inputProvider != nil { e.inputProvider.Close() } @@ -217,6 +215,8 @@ func (e *NucleiEngine) Close() { if e.httpxClient != nil { _ = e.httpxClient.Close() } + // close global shared resources + protocolinit.Close() } // ExecuteCallbackWithCtx executes templates on targets and calls callback on each result(only if results are found) diff --git a/lib/sdk_private.go b/lib/sdk_private.go index 13e8746510..6f8c929e38 100644 --- a/lib/sdk_private.go +++ b/lib/sdk_private.go @@ -35,7 +35,7 @@ import ( "github.com/projectdiscovery/ratelimit" ) -var sharedInit sync.Once = sync.Once{} +var sharedInit *sync.Once // applyRequiredDefaults to options func (e *NucleiEngine) applyRequiredDefaults() { @@ -118,6 +118,10 @@ func (e *NucleiEngine) init() error { e.parser = templates.NewParser() + if sharedInit == nil || protocolstate.ShouldInit() { + sharedInit = &sync.Once{} + } + sharedInit.Do(func() { _ = protocolstate.Init(e.opts) _ = protocolinit.Init(e.opts) diff --git a/pkg/protocols/common/protocolinit/init.go b/pkg/protocols/common/protocolinit/init.go index 2ab5d7ca9c..c8268337f5 100644 --- a/pkg/protocols/common/protocolinit/init.go +++ b/pkg/protocols/common/protocolinit/init.go @@ -13,7 +13,6 @@ import ( // Init initializes the client pools for the protocols func Init(options *types.Options) error { - if err := protocolstate.Init(options); err != nil { return err } @@ -39,5 +38,5 @@ func Init(options *types.Options) error { } func Close() { - protocolstate.Dialer.Close() + protocolstate.Close() } diff --git a/pkg/protocols/common/protocolstate/state.go b/pkg/protocols/common/protocolstate/state.go index 67820ec074..02e30e06fd 100644 --- a/pkg/protocols/common/protocolstate/state.go +++ b/pkg/protocols/common/protocolstate/state.go @@ -22,6 +22,10 @@ var ( Dialer *fastdialer.Dialer ) +func ShouldInit() bool { + return Dialer == nil +} + // Init creates the Dialer instance based on user configuration func Init(options *types.Options) error { if Dialer != nil { @@ -210,5 +214,6 @@ func Close() { if Dialer != nil { Dialer.Close() } + Dialer = nil StopActiveMemGuardian() } diff --git a/pkg/protocols/http/httpclientpool/clientpool.go b/pkg/protocols/http/httpclientpool/clientpool.go index 3e2baf55a5..96b0939511 100644 --- a/pkg/protocols/http/httpclientpool/clientpool.go +++ b/pkg/protocols/http/httpclientpool/clientpool.go @@ -16,7 +16,6 @@ import ( "golang.org/x/net/proxy" "golang.org/x/net/publicsuffix" - "github.com/projectdiscovery/fastdialer/fastdialer" "github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils" @@ -28,9 +27,6 @@ import ( ) var ( - // Dialer is a copy of the fastdialer from protocolstate - Dialer *fastdialer.Dialer - rawHttpClient *rawhttp.Client forceMaxRedirects int normalClient *retryablehttp.Client @@ -146,8 +142,8 @@ func GetRawHTTP(options *types.Options) *rawhttp.Client { rawHttpOptions.Proxy = types.ProxyURL } else if types.ProxySocksURL != "" { rawHttpOptions.Proxy = types.ProxySocksURL - } else if Dialer != nil { - rawHttpOptions.FastDialer = Dialer + } else if protocolstate.Dialer != nil { + rawHttpOptions.FastDialer = protocolstate.Dialer } rawHttpOptions.Timeout = GetHttpTimeout(options) rawHttpClient = rawhttp.NewClient(rawHttpOptions) @@ -167,10 +163,6 @@ func Get(options *types.Options, configuration *Configuration) (*retryablehttp.C func wrappedGet(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) { var err error - if Dialer == nil { - Dialer = protocolstate.Dialer - } - hash := configuration.Hash() if client, ok := clientPool.Get(hash); ok { return client, nil @@ -237,15 +229,15 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl transport := &http.Transport{ ForceAttemptHTTP2: options.ForceAttemptHTTP2, - DialContext: Dialer.Dial, + DialContext: protocolstate.Dialer.Dial, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { if options.TlsImpersonate { - return Dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil) + return protocolstate.Dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil) } if options.HasClientCertificates() || options.ForceAttemptHTTP2 { - return Dialer.DialTLSWithConfig(ctx, network, addr, tlsConfig) + return protocolstate.Dialer.DialTLSWithConfig(ctx, network, addr, tlsConfig) } - return Dialer.DialTLS(ctx, network, addr) + return protocolstate.Dialer.DialTLS(ctx, network, addr) }, MaxIdleConns: maxIdleConns, MaxIdleConnsPerHost: maxIdleConnsPerHost, diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go index fc3df039ac..c46db8f4b5 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -777,7 +777,7 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ if input.MetaInput.CustomIP != "" { outputEvent["ip"] = input.MetaInput.CustomIP } else { - outputEvent["ip"] = httpclientpool.Dialer.GetDialedIP(hostname) + outputEvent["ip"] = protocolstate.Dialer.GetDialedIP(hostname) } if len(generatedRequest.interactshURLs) > 0 { @@ -873,7 +873,7 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ if input.MetaInput.CustomIP != "" { outputEvent["ip"] = input.MetaInput.CustomIP } else { - outputEvent["ip"] = httpclientpool.Dialer.GetDialedIP(hostname) + outputEvent["ip"] = protocolstate.Dialer.GetDialedIP(hostname) } if request.options.Interactsh != nil { request.options.Interactsh.MakePlaceholders(generatedRequest.interactshURLs, outputEvent)