diff --git a/fastdialer/dialer.go b/fastdialer/dialer.go index 8967cd1..5daef0a 100644 --- a/fastdialer/dialer.go +++ b/fastdialer/dialer.go @@ -13,10 +13,12 @@ import ( "time" "github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate" + "github.com/projectdiscovery/fastdialer/fastdialer/metafiles" "github.com/projectdiscovery/hmap/store/hybrid" "github.com/projectdiscovery/networkpolicy" retryabledns "github.com/projectdiscovery/retryabledns" cryptoutil "github.com/projectdiscovery/utils/crypto" + "github.com/projectdiscovery/utils/env" errorutil "github.com/projectdiscovery/utils/errors" iputil "github.com/projectdiscovery/utils/ip" ptrutil "github.com/projectdiscovery/utils/ptr" @@ -28,22 +30,24 @@ import ( // option to disable ztls fallback in case of handshake error // reads from env variable DISABLE_ZTLS_FALLBACK -var disableZTLSFallback = false +var ( + disableZTLSFallback = false + MaxDNSCacheSize = 10 * 1024 * 1024 // 10 MB +) func init() { // enable permissive parsing for ztls, so that it can allow permissive parsing for X509 certificates asn1.AllowPermissiveParsing = true - value := os.Getenv("DISABLE_ZTLS_FALLBACK") - if strings.EqualFold(value, "true") { - disableZTLSFallback = true - } + disableZTLSFallback = env.GetEnvOrDefault("DISABLE_ZTLS_FALLBACK", false) + MaxDNSCacheSize = env.GetEnvOrDefault("MAX_DNS_CACHE_SIZE", 10*1024*1024) } // Dialer structure containing data information type Dialer struct { options *Options dnsclient *retryabledns.Client - hm *hybrid.HybridMap + dnsCache *hybrid.HybridMap + hostsFileData *hybrid.HybridMap dialerHistory *hybrid.HybridMap dialerTLSData *hybrid.HybridMap dialer *net.Dialer @@ -62,12 +66,8 @@ func NewDialer(options Options) (*Dialer, error) { } } - cacheOptions := getHMapConfiguration(options) resolvers = append(resolvers, options.BaseResolvers...) - hm, err := hybrid.New(cacheOptions) - if err != nil { - return nil, err - } + var err error var dialerHistory *hybrid.HybridMap if options.WithDialerHistory { // we need to use disk to store all the dialed ips @@ -78,6 +78,22 @@ func NewDialer(options Options) (*Dialer, error) { return nil, err } } + // when loading in memory set max size to 10 MB + var dnsCache *hybrid.HybridMap + if options.CacheType == Memory { + opts := hybrid.DefaultMemoryOptions + opts.MaxMemorySize = MaxDNSCacheSize + dnsCache, err = hybrid.New(opts) + if err != nil { + return nil, err + } + } else { + dnsCache, err = hybrid.New(hybrid.DefaultHybridOptions) + if err != nil { + return nil, err + } + } + var dialerTLSData *hybrid.HybridMap if options.WithTLSData { dialerTLSData, err = hybrid.New(hybrid.DefaultDiskOptions) @@ -97,10 +113,14 @@ func NewDialer(options Options) (*Dialer, error) { } } + var hostsFileData *hybrid.HybridMap // load hardcoded values from host file if options.HostsFile { - // nolint:errcheck // if they cannot be loaded it's not a hard failure - loadHostsFile(hm) + if options.CacheType == Memory { + hostsFileData, _ = metafiles.GetHostsFileDnsData(metafiles.InMemory) + } else { + hostsFileData, _ = metafiles.GetHostsFileDnsData(metafiles.Hybrid) + } } dnsclient, err := retryabledns.New(resolvers, options.MaxRetries) if err != nil { @@ -128,7 +148,17 @@ func NewDialer(options Options) (*Dialer, error) { return nil, err } - return &Dialer{dnsclient: dnsclient, hm: hm, dialerHistory: dialerHistory, dialerTLSData: dialerTLSData, dialer: dialer, proxyDialer: options.ProxyDialer, options: &options, networkpolicy: np}, nil + return &Dialer{ + dnsclient: dnsclient, + dnsCache: dnsCache, + hostsFileData: hostsFileData, + dialerHistory: dialerHistory, + dialerTLSData: dialerTLSData, + dialer: dialer, + proxyDialer: options.ProxyDialer, + options: &options, + networkpolicy: np, + }, nil } // Dial function compatible with net/http @@ -398,8 +428,8 @@ func (d *Dialer) dial(ctx context.Context, network, address string, shouldUseTLS // Close instance and cleanups func (d *Dialer) Close() { - if d.hm != nil { - d.hm.Close() + if d.dnsCache != nil { + d.dnsCache.Close() } if d.options.WithDialerHistory && d.dialerHistory != nil { d.dialerHistory.Close() @@ -407,6 +437,7 @@ func (d *Dialer) Close() { if d.options.WithTLSData { d.dialerTLSData.Close() } + // donot close hosts file as it is meant to be shared } // GetDialedIP returns the ip dialed by the HTTP client @@ -447,11 +478,17 @@ func (d *Dialer) GetTLSData(hostname string) (*cryptoutil.TLSData, error) { func (d *Dialer) GetDNSDataFromCache(hostname string) (*retryabledns.DNSData, error) { hostname = asAscii(hostname) var data retryabledns.DNSData - dataBytes, ok := d.hm.Get(hostname) + var dataBytes []byte + var ok bool + if d.hostsFileData != nil { + dataBytes, ok = d.hostsFileData.Get(hostname) + } if !ok { - return nil, NoDNSDataError + dataBytes, ok = d.dnsCache.Get(hostname) + if !ok { + return nil, NoDNSDataError + } } - err := data.Unmarshal(dataBytes) return &data, err } @@ -498,7 +535,7 @@ func (d *Dialer) GetDNSData(hostname string) (*retryabledns.DNSData, error) { } if len(data.A)+len(data.AAAA) > 0 { b, _ := data.Marshal() - err = d.hm.Set(hostname, b) + err = d.dnsCache.Set(hostname, b) } if err != nil { return nil, err @@ -508,30 +545,6 @@ func (d *Dialer) GetDNSData(hostname string) (*retryabledns.DNSData, error) { return data, nil } -func getHMapConfiguration(options Options) hybrid.Options { - var cacheOptions hybrid.Options - switch options.CacheType { - case Memory: - cacheOptions = hybrid.DefaultMemoryOptions - if options.CacheMemoryMaxItems > 0 { - cacheOptions.MaxMemorySize = options.CacheMemoryMaxItems - } - case Disk: - cacheOptions = hybrid.DefaultDiskOptions - cacheOptions.DBType = getHMAPDBType(options) - case Hybrid: - cacheOptions = hybrid.DefaultHybridOptions - } - if options.WithCleanup { - cacheOptions.Cleanup = options.WithCleanup - if options.CacheMemoryMaxItems > 0 { - cacheOptions.MaxMemorySize = options.CacheMemoryMaxItems - } - cacheOptions.DBType = getHMAPDBType(options) - } - return cacheOptions -} - func getHMAPDBType(options Options) hybrid.DBType { switch options.DiskDbType { case Pogreb: diff --git a/fastdialer/metafiles/doc.go b/fastdialer/metafiles/doc.go new file mode 100644 index 0000000..fc7af1c --- /dev/null +++ b/fastdialer/metafiles/doc.go @@ -0,0 +1,3 @@ +// metafiles are metadata files related to networking like +// /etc/hosts etc +package metafiles diff --git a/fastdialer/hostsfile.go b/fastdialer/metafiles/hostsfile.go similarity index 74% rename from fastdialer/hostsfile.go rename to fastdialer/metafiles/hostsfile.go index 29b2a3b..5bb7a95 100644 --- a/fastdialer/hostsfile.go +++ b/fastdialer/metafiles/hostsfile.go @@ -1,10 +1,11 @@ -package fastdialer +package metafiles import ( "bufio" "net" "os" "path/filepath" + "runtime/debug" "strings" "github.com/dimchansky/utfbom" @@ -12,7 +13,8 @@ import ( "github.com/projectdiscovery/retryabledns" ) -func loadHostsFile(hm *hybrid.HybridMap) error { +// loads Entries from hosts file if max is -1 it will load all entries to given hybrid map +func loadHostsFile(hm *hybrid.HybridMap, max int) error { osHostsFilePath := os.ExpandEnv(filepath.FromSlash(HostsFilePath)) if env, isset := os.LookupEnv("HOSTS_PATH"); isset && len(env) > 0 { @@ -28,6 +30,9 @@ func loadHostsFile(hm *hybrid.HybridMap) error { dnsDatas := make(map[string]retryabledns.DNSData) scanner := bufio.NewScanner(utfbom.SkipOnly(file)) for scanner.Scan() { + if max > 0 && len(dnsDatas) == MaxHostsEntires { + break + } ip, hosts := HandleHostLine(scanner.Text()) if ip == "" || len(hosts) == 0 { continue @@ -53,10 +58,15 @@ func loadHostsFile(hm *hybrid.HybridMap) error { dnsdataBytes, _ := dnsdata.Marshal() _ = hm.Set(host, dnsdataBytes) } + if len(dnsDatas) > 10000 && max < 0 { + // this freeups memory when loading large hosts files + // useful when loading all entries to hybrid storage + debug.FreeOSMemory() + } return nil } -const commentChar string = "#" +const CommentChar string = "#" // HandleHostLine a hosts file line func HandleHostLine(raw string) (ip string, hosts []string) { @@ -67,7 +77,7 @@ func HandleHostLine(raw string) (ip string, hosts []string) { // trim comment if HasComment(raw) { - commentSplit := strings.Split(raw, commentChar) + commentSplit := strings.Split(raw, CommentChar) raw = commentSplit[0] } @@ -88,10 +98,10 @@ func HandleHostLine(raw string) (ip string, hosts []string) { // IsComment check if the file is a comment func IsComment(raw string) bool { - return strings.HasPrefix(strings.TrimSpace(raw), commentChar) + return strings.HasPrefix(strings.TrimSpace(raw), CommentChar) } // HasComment check if the line has a comment func HasComment(raw string) bool { - return strings.Contains(raw, commentChar) + return strings.Contains(raw, CommentChar) } diff --git a/fastdialer/hostsfile_unix.go b/fastdialer/metafiles/hostsfile_unix.go similarity index 70% rename from fastdialer/hostsfile_unix.go rename to fastdialer/metafiles/hostsfile_unix.go index 7b43a0f..b2d0590 100644 --- a/fastdialer/hostsfile_unix.go +++ b/fastdialer/metafiles/hostsfile_unix.go @@ -1,6 +1,7 @@ +//go:build !windows // +build !windows -package fastdialer +package metafiles // HostsFilePath in unix file os const HostsFilePath = "/etc/hosts" diff --git a/fastdialer/hostsfile_windows.go b/fastdialer/metafiles/hostsfile_windows.go similarity index 69% rename from fastdialer/hostsfile_windows.go rename to fastdialer/metafiles/hostsfile_windows.go index 3ab7b7c..09bbdb6 100644 --- a/fastdialer/hostsfile_windows.go +++ b/fastdialer/metafiles/hostsfile_windows.go @@ -1,5 +1,6 @@ +//go:build windows // +build windows -package fastdialer +package metafiles const HostsFilePath = "${SystemRoot}/System32/drivers/etc/hosts" diff --git a/fastdialer/metafiles/shared.go b/fastdialer/metafiles/shared.go new file mode 100644 index 0000000..f280072 --- /dev/null +++ b/fastdialer/metafiles/shared.go @@ -0,0 +1,89 @@ +package metafiles + +import ( + "runtime" + "sync" + + "github.com/projectdiscovery/hmap/store/hybrid" + "github.com/projectdiscovery/utils/env" +) + +type StorageType int + +const ( + InMemory StorageType = iota + Hybrid +) + +var ( + MaxHostsEntires = 4096 + // LoadAllEntries is a switch when true loads all entries to hybrid storage + // backend and uses it even if in-memory storage backend was requested + LoadAllEntries = false +) + +func init() { + MaxHostsEntires = env.GetEnvOrDefault("HF_MAX_HOSTS", 4096) + LoadAllEntries = env.GetEnvOrDefault("HF_LOAD_ALL", false) +} + +// GetHostsFileDnsData returns the immutable dns data that is constant throughout the program +// lifecycle and shouldn't be purged by cache etc. +func GetHostsFileDnsData(storage StorageType) (*hybrid.HybridMap, error) { + if LoadAllEntries { + storage = Hybrid + } + switch storage { + case InMemory: + return getHFInMemory() + case Hybrid: + return getHFHybridStorage() + } + return nil, nil +} + +var hostsMemOnce = &sync.Once{} + +// getImm +func getHFInMemory() (*hybrid.HybridMap, error) { + var hm *hybrid.HybridMap + var err error + hostsMemOnce.Do(func() { + opts := hybrid.DefaultMemoryOptions + hm, err = hybrid.New(opts) + if err != nil { + return + } + err = loadHostsFile(hm, MaxHostsEntires) + if err != nil { + hm.Close() + return + } + }) + return hm, nil +} + +var hostsHybridOnce = &sync.Once{} + +func getHFHybridStorage() (*hybrid.HybridMap, error) { + var hm *hybrid.HybridMap + var err error + hostsHybridOnce.Do(func() { + opts := hybrid.DefaultHybridOptions + opts.Cleanup = true + hm, err = hybrid.New(opts) + if err != nil { + return + } + err = loadHostsFile(hm, -1) + if err != nil { + hm.Close() + return + } + // set finalizer for cleanup + runtime.SetFinalizer(hm, func(hm *hybrid.HybridMap) { + _ = hm.Close() + }) + }) + return hm, nil +} diff --git a/fastdialer/resolverfile.go b/fastdialer/resolverfile.go index 96fd665..1593ec7 100644 --- a/fastdialer/resolverfile.go +++ b/fastdialer/resolverfile.go @@ -8,8 +8,19 @@ import ( "strings" "github.com/dimchansky/utfbom" + "github.com/projectdiscovery/fastdialer/fastdialer/metafiles" + "github.com/projectdiscovery/utils/env" ) +var ( + MaxResolverEntries = 4096 +) + +func init() { + // use -1 for all entries + MaxResolverEntries = env.GetEnvOrDefault("MAX_RESOLVERS", 4096) +} + func loadResolverFile() ([]string, error) { osResolversFilePath := os.ExpandEnv(filepath.FromSlash(ResolverFilePath)) @@ -27,6 +38,9 @@ func loadResolverFile() ([]string, error) { scanner := bufio.NewScanner(utfbom.SkipOnly(file)) for scanner.Scan() { + if MaxResolverEntries != -1 && len(systemResolvers) >= MaxResolverEntries { + break + } resolverIP := HandleResolverLine(scanner.Text()) if resolverIP == "" { continue @@ -39,13 +53,13 @@ func loadResolverFile() ([]string, error) { // HandleLine a resolver file line func HandleResolverLine(raw string) (ip string) { // ignore comment - if IsComment(raw) { + if metafiles.IsComment(raw) { return } // trim comment - if HasComment(raw) { - commentSplit := strings.Split(raw, commentChar) + if metafiles.HasComment(raw) { + commentSplit := strings.Split(raw, metafiles.CommentChar) raw = commentSplit[0] }