Skip to content

Commit

Permalink
Merge pull request #51 from projectdiscovery/feature-zone-transfer
Browse files Browse the repository at this point in the history
Adding support for AXFR zone transfer
  • Loading branch information
Mzack9999 committed May 26, 2022
2 parents a15667f + 9c32c61 commit fe66c1a
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 69 deletions.
233 changes: 164 additions & 69 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/projectdiscovery/retryabledns/doh"
"github.com/projectdiscovery/retryabledns/hostsfile"
"github.com/projectdiscovery/retryablehttp-go"
"github.com/projectdiscovery/sliceutil"
)

var internalRangeCheckerInstance *internalRangeChecker
Expand Down Expand Up @@ -52,7 +53,7 @@ func New(baseResolvers []string, maxRetries int) *Client {

// New creates a new dns client with options
func NewWithOptions(options Options) *Client {
parsedBaseResolvers := parseResolvers(deduplicate(options.BaseResolvers))
parsedBaseResolvers := parseResolvers(sliceutil.Dedupe(options.BaseResolvers))
var knownHosts map[string][]string
if options.Hostsfile {
knownHosts, _ = hostsfile.ParseDefault()
Expand Down Expand Up @@ -185,13 +186,27 @@ func (c *Client) NS(host string) (*DNSData, error) {
return c.QueryMultiple(host, []uint16{dns.TypeNS})
}

func (c *Client) AXFR(host string) (*AXFRData, error) {
return c.axfr(host)
}

// QueryMultiple sends a provided dns request and return the data with a specific resolver
func (c *Client) QueryMultipleWithResolver(host string, requestTypes []uint16, resolver Resolver) (*DNSData, error) {
return c.queryMultiple(host, requestTypes, resolver)
}

// CAA helper function
func (c *Client) CAA(host string) (*DNSData, error) {
return c.QueryMultiple(host, []uint16{dns.TypeCAA})
}

// QueryMultiple sends a provided dns request and return the data
func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, error) {
return c.queryMultiple(host, requestTypes, nil)
}

// QueryMultiple sends a provided dns request and return the data
func (c *Client) queryMultiple(host string, requestTypes []uint16, resolver Resolver) (*DNSData, error) {
var (
dnsdata DNSData
err error
Expand All @@ -212,47 +227,73 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er

msg := &dns.Msg{}
msg.Id = dns.Id()
msg.RecursionDesired = true
msg.Question = make([]dns.Question, 1)
msg.SetEdns0(4096, false)

for _, requestType := range requestTypes {
name := dns.Fqdn(host)
msg.Question = make([]dns.Question, 1)

// In case of PTR adjust the domain name
if requestType == dns.TypePTR {
switch requestType {
case dns.TypeAXFR:
msg.SetAxfr(name)
case dns.TypePTR: // In case of PTR adjust the domain name
var err error
if net.ParseIP(host) != nil {
name, err = dns.ReverseAddr(host)
if err != nil {
return nil, err
}
}
fallthrough
default:
// Enable Extension Mechanisms for DNS for all messages
msg.RecursionDesired = true
question := dns.Question{
Name: name,
Qtype: requestType,
Qclass: dns.ClassINET,
}
msg.Question[0] = question
}

question := dns.Question{
Name: name,
Qtype: requestType,
Qclass: dns.ClassINET,
}
msg.Question[0] = question

// Enable Extension Mechanisms for DNS for all messages
msg.SetEdns0(4096, false)

var resp *dns.Msg
var (
resp *dns.Msg
trResp chan *dns.Envelope
)
for i := 0; i < c.options.MaxRetries; i++ {
index := atomic.AddUint32(&c.serversIndex, 1)
resolver := c.resolvers[index%uint32(len(c.resolvers))]

if resolver == nil {
resolver = c.resolvers[index%uint32(len(c.resolvers))]
}
switch r := resolver.(type) {
case *NetworkResolver:
switch r.Protocol {
case TCP:
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
case UDP:
resp, _, err = c.udpClient.Exchange(msg, resolver.String())
case DOT:
resp, _, err = c.dotClient.Exchange(msg, resolver.String())
if requestType == dns.TypeAXFR {
var dnsconn *dns.Conn
switch r.Protocol {
case TCP:
dnsconn, err = c.tcpClient.Dial(resolver.String())
case UDP:
dnsconn, err = c.udpClient.Dial(resolver.String())
case DOT:
dnsconn, err = c.dotClient.Dial(resolver.String())
default:
dnsconn, err = c.tcpClient.Dial(resolver.String())
}
if err != nil {
break
}
defer dnsconn.Close()
dnsTransfer := &dns.Transfer{Conn: dnsconn}
trResp, err = dnsTransfer.In(msg, resolver.String())
} else {
switch r.Protocol {
case TCP:
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
case UDP:
resp, _, err = c.udpClient.Exchange(msg, resolver.String())
case DOT:
resp, _, err = c.dotClient.Exchange(msg, resolver.String())
}
}
case *DohResolver:
method := doh.MethodPost
Expand All @@ -261,26 +302,37 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
}
resp, err = c.dohClient.QueryWithDOHMsg(method, doh.Resolver{URL: r.URL}, msg)
}
if err != nil || resp == nil {

if err != nil || (trResp == nil && resp == nil) {
continue
}

// https://github.com/projectdiscovery/retryabledns/issues/25
if resp.Truncated && c.TCPFallback {
if resp != nil && resp.Truncated && c.TCPFallback {
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
if err != nil || resp == nil {
continue
}
}

err = dnsdata.ParseFromMsg(resp)
switch requestType {
case dns.TypeAXFR:
err = dnsdata.ParseFromEnvelopeChan(trResp)
default:
err = dnsdata.ParseFromMsg(resp)
}

// populate anyway basic info
dnsdata.Host = host
dnsdata.StatusCode = dns.RcodeToString[resp.Rcode]
dnsdata.StatusCodeRaw = resp.Rcode
switch {
case resp != nil:
dnsdata.StatusCode = dns.RcodeToString[resp.Rcode]
dnsdata.StatusCodeRaw = resp.Rcode
dnsdata.Raw += resp.String()
case trResp != nil:
// pass
}
dnsdata.Timestamp = time.Now()
dnsdata.Raw += resp.String()
dnsdata.Resolver = append(dnsdata.Resolver, resolver.String())

if err != nil || !dnsdata.contains() {
Expand All @@ -289,7 +341,10 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
dnsdata.dedupe()

// stop on success
if resp.Rcode == dns.RcodeSuccess {
if resp != nil && resp.Rcode == dns.RcodeSuccess {
break
}
if trResp != nil {
break
}
}
Expand Down Expand Up @@ -336,7 +391,7 @@ func (c *Client) QueryParallel(host string, requestType uint16, resolvers []stri
return dnsdatas, nil
}

// QueryMultiple sends a provided dns request and return the data
// Trace the requested domain with the provided query type
func (c *Client) Trace(host string, requestType uint16, maxrecursion int) (*TraceData, error) {
var tracedata TraceData
host = dns.CanonicalName(host)
Expand Down Expand Up @@ -388,7 +443,7 @@ func (c *Client) Trace(host string, requestType uint16, maxrecursion int) (*Trac
}
}
}
newNSResolvers = deduplicate(newNSResolvers)
newNSResolvers = sliceutil.Dedupe(newNSResolvers)

// if we have no new resolvers => return
if len(newNSResolvers) == 0 {
Expand All @@ -413,6 +468,40 @@ func (c *Client) Trace(host string, requestType uint16, maxrecursion int) (*Trac
return &tracedata, nil
}

func (c *Client) axfr(host string) (*AXFRData, error) {
// obtain ns servers
dnsData, err := c.NS(host)
if err != nil {
return nil, err
}
// resolve ns servers to ips
var resolvers []Resolver

for _, ns := range dnsData.NS {
nsData, err := c.A(ns)
if err != nil {
continue
}
for _, a := range nsData.A {
resolvers = append(resolvers, &NetworkResolver{Protocol: TCP, Host: a, Port: "53"})
}
}

resolvers = append(resolvers, c.resolvers...)

var data []*DNSData
// perform zone transfer for each ns
for _, resolver := range resolvers {
nsData, err := c.QueryMultipleWithResolver(host, []uint16{dns.TypeAXFR}, resolver)
if err != nil {
continue
}
data = append(data, nsData)
}

return &AXFRData{Host: host, DNSData: data}, nil
}

// DNSData is the data for a DNS request response
type DNSData struct {
Host string `json:"host,omitempty"`
Expand All @@ -426,27 +515,25 @@ type DNSData struct {
SOA []string `json:"soa,omitempty"`
NS []string `json:"ns,omitempty"`
TXT []string `json:"txt,omitempty"`
CAA []string `json:"caa,omitempty"`
AllRecords []string `json:"all,omitempty"`
Raw string `json:"raw,omitempty"`
HasInternalIPs bool `json:"has_internal_ips"`
HasInternalIPs bool `json:"has_internal_ips,omitempty"`
InternalIPs []string `json:"internal_ips,omitempty"`
StatusCode string `json:"status_code,omitempty"`
StatusCodeRaw int `json:"status_code_raw,omitempty"`
TraceData *TraceData `json:"trace,omitempty"`
AXFRData *AXFRData `json:"axfr,omitempty"`
RawResp *dns.Msg `json:"raw_resp,omitempty"`
Timestamp time.Time `json:"timestamp,omitempty"`
CAA []string `json:"caa,omitempty"`
}

// CheckInternalIPs when set to true returns if DNS response IPs
// belong to internal IP ranges.
var CheckInternalIPs = false

// ParseFromMsg and enrich data
func (d *DNSData) ParseFromMsg(msg *dns.Msg) error {
allRecords := append(msg.Answer, msg.Extra...)
allRecords = append(allRecords, msg.Ns...)

for _, record := range allRecords {
func (d *DNSData) ParseFromRR(rrs []dns.RR) error {
for _, record := range rrs {
switch recordType := record.(type) {
case *dns.A:
if CheckInternalIPs && internalRangeCheckerInstance != nil && internalRangeCheckerInstance.ContainsIPv4(recordType.A) {
Expand Down Expand Up @@ -478,11 +565,29 @@ func (d *DNSData) ParseFromMsg(msg *dns.Msg) error {
}
d.AAAA = append(d.AAAA, trimChars(recordType.AAAA.String()))
}
d.AllRecords = append(d.AllRecords, record.String())
}

return nil
}

// ParseFromMsg and enrich data
func (d *DNSData) ParseFromMsg(msg *dns.Msg) error {
allRecords := append(msg.Answer, msg.Extra...)
allRecords = append(allRecords, msg.Ns...)
return d.ParseFromRR(allRecords)
}

func (d *DNSData) ParseFromEnvelopeChan(envChan chan *dns.Envelope) error {
var allRecords []dns.RR
for env := range envChan {
if env.Error != nil {
return env.Error
}
allRecords = append(allRecords, env.RR...)
}
return d.ParseFromRR(allRecords)
}

func (d *DNSData) contains() bool {
return len(d.A) > 0 || len(d.AAAA) > 0 || len(d.CNAME) > 0 || len(d.MX) > 0 || len(d.NS) > 0 || len(d.PTR) > 0 || len(d.TXT) > 0 || len(d.SOA) > 0 || len(d.CAA) > 0
}
Expand All @@ -498,16 +603,17 @@ func trimChars(s string) string {
}

func (d *DNSData) dedupe() {
d.Resolver = deduplicate(d.Resolver)
d.A = deduplicate(d.A)
d.AAAA = deduplicate(d.AAAA)
d.CNAME = deduplicate(d.CNAME)
d.MX = deduplicate(d.MX)
d.PTR = deduplicate(d.PTR)
d.SOA = deduplicate(d.SOA)
d.NS = deduplicate(d.NS)
d.TXT = deduplicate(d.TXT)
d.CAA = deduplicate(d.CAA)
d.Resolver = sliceutil.Dedupe(d.Resolver)
d.A = sliceutil.Dedupe(d.A)
d.AAAA = sliceutil.Dedupe(d.AAAA)
d.CNAME = sliceutil.Dedupe(d.CNAME)
d.MX = sliceutil.Dedupe(d.MX)
d.PTR = sliceutil.Dedupe(d.PTR)
d.SOA = sliceutil.Dedupe(d.SOA)
d.NS = sliceutil.Dedupe(d.NS)
d.TXT = sliceutil.Dedupe(d.TXT)
d.CAA = sliceutil.Dedupe(d.CAA)
d.AllRecords = sliceutil.Dedupe(d.AllRecords)
}

// Marshal encodes the dnsdata to a binary representation
Expand All @@ -527,24 +633,13 @@ func (d *DNSData) Unmarshal(b []byte) error {
return dec.Decode(&d)
}

// deduplicate returns a new slice with duplicates values removed.
func deduplicate(s []string) []string {
if len(s) < 2 {
return s
}
var results []string
seen := make(map[string]struct{})
for _, val := range s {
if _, ok := seen[val]; !ok {
results = append(results, val)
seen[val] = struct{}{}
}
}
return results
}

// TraceData contains the trace information for a dns query
type TraceData struct {
Host string `json:"host,omitempty"`
DNSData []*DNSData `json:"chain,omitempty"`
}

type AXFRData struct {
Host string `json:"host,omitempty"`
DNSData []*DNSData `json:"chain,omitempty"`
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/projectdiscovery/fileutil v0.0.0-20210926202739-6050d0acf73c
github.com/projectdiscovery/iputil v0.0.0-20210804143329-3a30fcde43f3
github.com/projectdiscovery/retryablehttp-go v1.0.2
github.com/projectdiscovery/sliceutil v0.0.0-20220225084130-8392ac12fa6d
github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d
github.com/stretchr/testify v1.7.1
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ github.com/projectdiscovery/mapcidr v0.0.7 h1:WK6WFimbWjUxfvcHEgofYNqIyqQh0vTDKz
github.com/projectdiscovery/mapcidr v0.0.7/go.mod h1:7CzdUdjuLVI0s33dQ33lWgjg3vPuLFw2rQzZ0RxkT00=
github.com/projectdiscovery/retryablehttp-go v1.0.2 h1:LV1/KAQU+yeWhNVlvveaYFsjBYRwXlNEq0PvrezMV0U=
github.com/projectdiscovery/retryablehttp-go v1.0.2/go.mod h1:dx//aY9V247qHdsRf0vdWHTBZuBQ2vm6Dq5dagxrDYI=
github.com/projectdiscovery/sliceutil v0.0.0-20220225084130-8392ac12fa6d h1:wIQPYRZEwTeJuoZLv3NT9r+il2fAv1ObRzTdHkNgOxk=
github.com/projectdiscovery/sliceutil v0.0.0-20220225084130-8392ac12fa6d/go.mod h1:QHXvznfPfA5f0AZUIBkbLapoUJJlsIDgUlkKva6dOr4=
github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d h1:lrdpJCBOvRrTnm44Ov7O3tLd3oOWhCvVUhTKkWwibq4=
github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d/go.mod h1:oTRc18WBv9t6BpaN9XBY+QmG28PUpsyDzRht56Qf49I=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down

0 comments on commit fe66c1a

Please sign in to comment.