Skip to content

Commit

Permalink
fix memory leaks (#241)
Browse files Browse the repository at this point in the history
* make hostsfile cache shared + separate dnsCache from hostsFile data

* add max resolvers
  • Loading branch information
tarunKoyalwar committed Jan 16, 2024
1 parent 811e7df commit e5ed9c0
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 55 deletions.
101 changes: 57 additions & 44 deletions fastdialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import (
"time"

"github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate"
"github.com/projectdiscovery/fastdialer/fastdialer/metafiles"
"github.com/projectdiscovery/hmap/store/hybrid"
"github.com/projectdiscovery/networkpolicy"
retryabledns "github.com/projectdiscovery/retryabledns"
cryptoutil "github.com/projectdiscovery/utils/crypto"
"github.com/projectdiscovery/utils/env"
errorutil "github.com/projectdiscovery/utils/errors"
iputil "github.com/projectdiscovery/utils/ip"
ptrutil "github.com/projectdiscovery/utils/ptr"
Expand All @@ -28,22 +30,24 @@ import (

// option to disable ztls fallback in case of handshake error
// reads from env variable DISABLE_ZTLS_FALLBACK
var disableZTLSFallback = false
var (
disableZTLSFallback = false
MaxDNSCacheSize = 10 * 1024 * 1024 // 10 MB
)

func init() {
// enable permissive parsing for ztls, so that it can allow permissive parsing for X509 certificates
asn1.AllowPermissiveParsing = true
value := os.Getenv("DISABLE_ZTLS_FALLBACK")
if strings.EqualFold(value, "true") {
disableZTLSFallback = true
}
disableZTLSFallback = env.GetEnvOrDefault("DISABLE_ZTLS_FALLBACK", false)
MaxDNSCacheSize = env.GetEnvOrDefault("MAX_DNS_CACHE_SIZE", 10*1024*1024)
}

// Dialer structure containing data information
type Dialer struct {
options *Options
dnsclient *retryabledns.Client
hm *hybrid.HybridMap
dnsCache *hybrid.HybridMap
hostsFileData *hybrid.HybridMap
dialerHistory *hybrid.HybridMap
dialerTLSData *hybrid.HybridMap
dialer *net.Dialer
Expand All @@ -62,12 +66,8 @@ func NewDialer(options Options) (*Dialer, error) {
}
}

cacheOptions := getHMapConfiguration(options)
resolvers = append(resolvers, options.BaseResolvers...)
hm, err := hybrid.New(cacheOptions)
if err != nil {
return nil, err
}
var err error
var dialerHistory *hybrid.HybridMap
if options.WithDialerHistory {
// we need to use disk to store all the dialed ips
Expand All @@ -78,6 +78,22 @@ func NewDialer(options Options) (*Dialer, error) {
return nil, err
}
}
// when loading in memory set max size to 10 MB
var dnsCache *hybrid.HybridMap
if options.CacheType == Memory {
opts := hybrid.DefaultMemoryOptions
opts.MaxMemorySize = MaxDNSCacheSize
dnsCache, err = hybrid.New(opts)
if err != nil {
return nil, err
}
} else {
dnsCache, err = hybrid.New(hybrid.DefaultHybridOptions)
if err != nil {
return nil, err
}
}

var dialerTLSData *hybrid.HybridMap
if options.WithTLSData {
dialerTLSData, err = hybrid.New(hybrid.DefaultDiskOptions)
Expand All @@ -97,10 +113,14 @@ func NewDialer(options Options) (*Dialer, error) {
}
}

var hostsFileData *hybrid.HybridMap
// load hardcoded values from host file
if options.HostsFile {
// nolint:errcheck // if they cannot be loaded it's not a hard failure
loadHostsFile(hm)
if options.CacheType == Memory {
hostsFileData, _ = metafiles.GetHostsFileDnsData(metafiles.InMemory)
} else {
hostsFileData, _ = metafiles.GetHostsFileDnsData(metafiles.Hybrid)
}
}
dnsclient, err := retryabledns.New(resolvers, options.MaxRetries)
if err != nil {
Expand Down Expand Up @@ -128,7 +148,17 @@ func NewDialer(options Options) (*Dialer, error) {
return nil, err
}

return &Dialer{dnsclient: dnsclient, hm: hm, dialerHistory: dialerHistory, dialerTLSData: dialerTLSData, dialer: dialer, proxyDialer: options.ProxyDialer, options: &options, networkpolicy: np}, nil
return &Dialer{
dnsclient: dnsclient,
dnsCache: dnsCache,
hostsFileData: hostsFileData,
dialerHistory: dialerHistory,
dialerTLSData: dialerTLSData,
dialer: dialer,
proxyDialer: options.ProxyDialer,
options: &options,
networkpolicy: np,
}, nil
}

// Dial function compatible with net/http
Expand Down Expand Up @@ -398,15 +428,16 @@ func (d *Dialer) dial(ctx context.Context, network, address string, shouldUseTLS

// Close instance and cleanups
func (d *Dialer) Close() {
if d.hm != nil {
d.hm.Close()
if d.dnsCache != nil {
d.dnsCache.Close()
}
if d.options.WithDialerHistory && d.dialerHistory != nil {
d.dialerHistory.Close()
}
if d.options.WithTLSData {
d.dialerTLSData.Close()
}
// do not close the hosts file data as it is meant to be shared
}

// GetDialedIP returns the ip dialed by the HTTP client
Expand Down Expand Up @@ -447,11 +478,17 @@ func (d *Dialer) GetTLSData(hostname string) (*cryptoutil.TLSData, error) {
func (d *Dialer) GetDNSDataFromCache(hostname string) (*retryabledns.DNSData, error) {
hostname = asAscii(hostname)
var data retryabledns.DNSData
dataBytes, ok := d.hm.Get(hostname)
var dataBytes []byte
var ok bool
if d.hostsFileData != nil {
dataBytes, ok = d.hostsFileData.Get(hostname)
}
if !ok {
return nil, NoDNSDataError
dataBytes, ok = d.dnsCache.Get(hostname)
if !ok {
return nil, NoDNSDataError
}
}

err := data.Unmarshal(dataBytes)
return &data, err
}
Expand Down Expand Up @@ -498,7 +535,7 @@ func (d *Dialer) GetDNSData(hostname string) (*retryabledns.DNSData, error) {
}
if len(data.A)+len(data.AAAA) > 0 {
b, _ := data.Marshal()
err = d.hm.Set(hostname, b)
err = d.dnsCache.Set(hostname, b)
}
if err != nil {
return nil, err
Expand All @@ -508,30 +545,6 @@ func (d *Dialer) GetDNSData(hostname string) (*retryabledns.DNSData, error) {
return data, nil
}

// getHMapConfiguration derives the hybrid map options for the DNS cache from
// the dialer options: the cache type selects the base defaults, and the
// cleanup flag layers the memory limit and disk database type on top.
func getHMapConfiguration(options Options) hybrid.Options {
	var cfg hybrid.Options
	limitItems := options.CacheMemoryMaxItems > 0

	// pick the base configuration for the requested cache type
	switch options.CacheType {
	case Memory:
		cfg = hybrid.DefaultMemoryOptions
		if limitItems {
			cfg.MaxMemorySize = options.CacheMemoryMaxItems
		}
	case Disk:
		cfg = hybrid.DefaultDiskOptions
		cfg.DBType = getHMAPDBType(options)
	case Hybrid:
		cfg = hybrid.DefaultHybridOptions
	}

	if !options.WithCleanup {
		return cfg
	}

	// with cleanup enabled the memory limit and db type are applied
	// regardless of the cache type selected above
	cfg.Cleanup = true
	if limitItems {
		cfg.MaxMemorySize = options.CacheMemoryMaxItems
	}
	cfg.DBType = getHMAPDBType(options)
	return cfg
}

func getHMAPDBType(options Options) hybrid.DBType {
switch options.DiskDbType {
case Pogreb:
Expand Down
3 changes: 3 additions & 0 deletions fastdialer/metafiles/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Package metafiles deals with networking-related metadata files,
// such as /etc/hosts.
package metafiles
22 changes: 16 additions & 6 deletions fastdialer/hostsfile.go → fastdialer/metafiles/hostsfile.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
package fastdialer
package metafiles

import (
"bufio"
"net"
"os"
"path/filepath"
"runtime/debug"
"strings"

"github.com/dimchansky/utfbom"
"github.com/projectdiscovery/hmap/store/hybrid"
"github.com/projectdiscovery/retryabledns"
)

func loadHostsFile(hm *hybrid.HybridMap) error {
// loadHostsFile loads entries from the hosts file into the given hybrid map; if max is -1 it loads all entries
func loadHostsFile(hm *hybrid.HybridMap, max int) error {
osHostsFilePath := os.ExpandEnv(filepath.FromSlash(HostsFilePath))

if env, isset := os.LookupEnv("HOSTS_PATH"); isset && len(env) > 0 {
Expand All @@ -28,6 +30,9 @@ func loadHostsFile(hm *hybrid.HybridMap) error {
dnsDatas := make(map[string]retryabledns.DNSData)
scanner := bufio.NewScanner(utfbom.SkipOnly(file))
for scanner.Scan() {
if max > 0 && len(dnsDatas) == MaxHostsEntires {
break
}
ip, hosts := HandleHostLine(scanner.Text())
if ip == "" || len(hosts) == 0 {
continue
Expand All @@ -53,10 +58,15 @@ func loadHostsFile(hm *hybrid.HybridMap) error {
dnsdataBytes, _ := dnsdata.Marshal()
_ = hm.Set(host, dnsdataBytes)
}
if len(dnsDatas) > 10000 && max < 0 {
// this frees up memory when loading large hosts files;
// useful when loading all entries to hybrid storage
debug.FreeOSMemory()
}
return nil
}

const commentChar string = "#"
const CommentChar string = "#"

// HandleHostLine a hosts file line
func HandleHostLine(raw string) (ip string, hosts []string) {
Expand All @@ -67,7 +77,7 @@ func HandleHostLine(raw string) (ip string, hosts []string) {

// trim comment
if HasComment(raw) {
commentSplit := strings.Split(raw, commentChar)
commentSplit := strings.Split(raw, CommentChar)
raw = commentSplit[0]
}

Expand All @@ -88,10 +98,10 @@ func HandleHostLine(raw string) (ip string, hosts []string) {

// IsComment checks if the line is a comment
func IsComment(raw string) bool {
return strings.HasPrefix(strings.TrimSpace(raw), commentChar)
return strings.HasPrefix(strings.TrimSpace(raw), CommentChar)
}

// HasComment check if the line has a comment
func HasComment(raw string) bool {
return strings.Contains(raw, commentChar)
return strings.Contains(raw, CommentChar)
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//go:build !windows
// +build !windows

package fastdialer
package metafiles

// HostsFilePath in unix file os
const HostsFilePath = "/etc/hosts"
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//go:build windows
// +build windows

package fastdialer
package metafiles

const HostsFilePath = "${SystemRoot}/System32/drivers/etc/hosts"
89 changes: 89 additions & 0 deletions fastdialer/metafiles/shared.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package metafiles

import (
"runtime"
"sync"

"github.com/projectdiscovery/hmap/store/hybrid"
"github.com/projectdiscovery/utils/env"
)

// StorageType identifies which storage backend GetHostsFileDnsData should
// use for the hosts file data.
type StorageType int

const (
// InMemory requests the in-memory backend; the loader is given
// MaxHostsEntires as its max argument.
InMemory StorageType = iota
// Hybrid requests the hybrid backend; the loader is given -1 as its max
// argument (load all entries).
Hybrid
)

var (
// MaxHostsEntires is the max value passed to the hosts file loader for the
// in-memory backend. It defaults to 4096 and is overridden by the
// HF_MAX_HOSTS environment variable in init.
MaxHostsEntires = 4096
// LoadAllEntries is a switch when true loads all entries to hybrid storage
// backend and uses it even if in-memory storage backend was requested.
// It is overridden by the HF_LOAD_ALL environment variable in init.
LoadAllEntries = false
)

// init applies environment overrides to the package-level settings.
// Each variable keeps its declared default when the corresponding
// environment variable (HF_LOAD_ALL / HF_MAX_HOSTS) is not provided.
func init() {
	LoadAllEntries = env.GetEnvOrDefault("HF_LOAD_ALL", LoadAllEntries)
	MaxHostsEntires = env.GetEnvOrDefault("HF_MAX_HOSTS", MaxHostsEntires)
}

// GetHostsFileDnsData returns the hosts file dns data for the requested
// storage backend. This data is meant to stay constant for the whole program
// lifecycle and must not be purged by any cache.
//
// When LoadAllEntries is set, the hybrid backend is used no matter which
// backend was requested. An unknown storage type yields (nil, nil).
func GetHostsFileDnsData(storage StorageType) (*hybrid.HybridMap, error) {
	if LoadAllEntries || storage == Hybrid {
		return getHFHybridStorage()
	}
	if storage == InMemory {
		return getHFInMemory()
	}
	return nil, nil
}

// Shared state for the in-memory hosts file data. The map and error are
// stored at package level so that every caller observes the result of the
// single initialization performed under hostsMemOnce.
var (
	hostsMemOnce = &sync.Once{}
	hostsMemMap  *hybrid.HybridMap
	hostsMemErr  error
)

// getHFInMemory returns the shared in-memory hosts file map, creating and
// loading it on the first call (MaxHostsEntires is passed to the loader as
// its max argument).
//
// The outcome of the first call is remembered: later calls return the same
// map, or the same error if initialization failed. On failure the returned
// map is nil.
func getHFInMemory() (*hybrid.HybridMap, error) {
	hostsMemOnce.Do(func() {
		hm, err := hybrid.New(hybrid.DefaultMemoryOptions)
		if err != nil {
			hostsMemErr = err
			return
		}
		if err := loadHostsFile(hm, MaxHostsEntires); err != nil {
			// do not hand out a map that could not be populated
			_ = hm.Close()
			hostsMemErr = err
			return
		}
		hostsMemMap = hm
	})
	return hostsMemMap, hostsMemErr
}

// Shared state for the hybrid-storage hosts file data. The map and error are
// stored at package level so that every caller observes the result of the
// single initialization performed under hostsHybridOnce.
var (
	hostsHybridOnce = &sync.Once{}
	hostsHybridMap  *hybrid.HybridMap
	hostsHybridErr  error
)

// getHFHybridStorage returns the shared hybrid-storage hosts file map,
// creating it and loading all hosts file entries on the first call.
//
// The outcome of the first call is remembered: later calls return the same
// map, or the same error if initialization failed. On failure the returned
// map is nil.
func getHFHybridStorage() (*hybrid.HybridMap, error) {
	hostsHybridOnce.Do(func() {
		opts := hybrid.DefaultHybridOptions
		opts.Cleanup = true
		hm, err := hybrid.New(opts)
		if err != nil {
			hostsHybridErr = err
			return
		}
		if err := loadHostsFile(hm, -1); err != nil {
			// do not hand out a map that could not be populated
			_ = hm.Close()
			hostsHybridErr = err
			return
		}
		// set finalizer for cleanup
		// NOTE(review): the map is now retained by a package-level variable
		// so it can be shared, which means this finalizer will not run while
		// that reference exists — confirm whether an explicit cleanup hook
		// is needed at shutdown.
		runtime.SetFinalizer(hm, func(hm *hybrid.HybridMap) {
			_ = hm.Close()
		})
		hostsHybridMap = hm
	})
	return hostsHybridMap, hostsHybridErr
}

0 comments on commit e5ed9c0

Please sign in to comment.