-
Notifications
You must be signed in to change notification settings - Fork 4.9k
/
hostresolver.go
133 lines (108 loc) · 2.99 KB
/
hostresolver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package requestdecorator
import (
	"context"
	"errors"
	"fmt"
	"net"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/miekg/dns"
	"github.com/patrickmn/go-cache"
	"github.com/rs/zerolog/log"
)
// cnameResolv is one CNAME answer: the canonical name the queried host points
// to, and the TTL the answering server attached to that record.
type cnameResolv struct {
	TTL    time.Duration
	Record string
}

// byTTL implements sort.Interface over CNAME answers, ordering them from the
// longest TTL to the shortest.
type byTTL []*cnameResolv

// Len reports the number of answers held.
func (b byTTL) Len() int {
	return len(b)
}

// Swap exchanges the answers at positions i and j.
func (b byTTL) Swap(i, j int) {
	b[i], b[j] = b[j], b[i]
}

// Less reports whether the answer at i must come before the answer at j,
// which is the case when its TTL is strictly longer (descending order).
func (b byTTL) Less(i, j int) bool {
	return b[j].TTL < b[i].TTL
}
// Resolver flattens CNAME chains of request hosts and caches the outcome per host.
type Resolver struct {
	// CnameFlattening is a configuration flag; CNAMEFlatten does not read it.
	// NOTE(review): presumably checked by the caller before calling
	// CNAMEFlatten — confirm.
	CnameFlattening bool

	// ResolvConfig is the path of the resolver configuration file (parsed by
	// dns.ClientConfigFromFile) that lists the DNS servers to query.
	ResolvConfig string

	// ResolvDepth bounds how many CNAME hops CNAMEFlatten follows; a value <= 0
	// disables resolution and the host is returned unchanged.
	ResolvDepth int

	// cacheOnce guards the lazy creation of cache so that concurrent first
	// calls to CNAMEFlatten do not race on the field.
	cacheOnce sync.Once
	cache     *cache.Cache
}

// CNAMEFlatten follows the CNAME chain of host for at most hr.ResolvDepth hops
// and returns the last name reached. It returns host itself when no CNAME
// record is found.
//
// The outcome is cached per host. The cache entry lives for the TTL of the
// first CNAME hop. When no hop was resolved, or that TTL is 0, the cache's
// default expiration applies (30 minutes for the cache created here).
// Resolution errors are logged and end the walk early; the last name reached
// is still returned and cached.
func (hr *Resolver) CNAMEFlatten(ctx context.Context, host string) string {
	hr.cacheOnce.Do(func() {
		// Keep a cache that was already set on the struct.
		if hr.cache == nil {
			hr.cache = cache.New(30*time.Minute, 5*time.Minute)
		}
	})

	if value, found := hr.cache.Get(host); found {
		return value.(string)
	}

	logger := log.Ctx(ctx)

	result := host
	request := host
	// DefaultExpiration (0) tells go-cache to use the expiration the cache was
	// created with.
	cacheDuration := cache.DefaultExpiration

	for depth := range hr.ResolvDepth {
		resolv, err := cnameResolve(ctx, request, hr.ResolvConfig)
		if err != nil {
			logger.Error().Err(err).Msgf("CNAME flattening for hostname %q", host)
			break
		}
		if resolv == nil {
			// request has no CNAME: it is the end of the chain.
			break
		}

		result = resolv.Record
		if depth == 0 {
			cacheDuration = resolv.TTL
		}
		request = resolv.Record
	}

	hr.cache.Set(host, result, cacheDuration)
	return result
}
// cnameResolve asks every DNS server listed in the resolver configuration file
// resolvPath for a CNAME record of host, and returns the answer carrying the
// highest TTL.
//
// It returns (nil, nil) when host is an IP literal or when no server returned
// a CNAME record. Per-server failures are logged and skipped. An error is
// returned only when the configuration file cannot be loaded; it wraps the
// underlying cause.
func cnameResolve(ctx context.Context, host, resolvPath string) (*cnameResolv, error) {
	// An IP literal cannot have a CNAME; skip reading the configuration file.
	if net.ParseIP(host) != nil {
		return nil, nil
	}

	config, err := dns.ClientConfigFromFile(resolvPath)
	if err != nil {
		return nil, fmt.Errorf("invalid resolver configuration file: %s: %w", resolvPath, err)
	}

	client := &dns.Client{Timeout: 30 * time.Second}
	m := &dns.Msg{}
	m.SetQuestion(dns.Fqdn(host), dns.TypeCNAME)

	logger := log.Ctx(ctx)
	result := make([]*cnameResolv, 0, len(config.Servers))
	for _, server := range config.Servers {
		tempRecord, err := getRecord(client, m, server, config.Port)
		if err != nil {
			// A missing record is the normal outcome for a host without CNAME;
			// only genuine failures are worth error level.
			if errors.Is(err, errNoCNAMERecord) {
				logger.Debug().Err(err).Msgf("CNAME lookup for hostname %q", host)
			} else {
				logger.Error().Err(err).Msgf("CNAME lookup for hostname %q", host)
			}
			continue
		}
		result = append(result, tempRecord)
	}

	if len(result) == 0 {
		return nil, nil
	}

	sort.Sort(byTTL(result))
	return result[0], nil
}
var (
	// errNoCNAMERecord is the sentinel signalling that a DNS server returned no
	// CNAME answer for the queried host; detect it with errors.Is.
	errNoCNAMERecord = errors.New("no CNAME record for host")
)
// getRecord sends msg to server:port and extracts the CNAME record from the
// reply. The returned Record has its trailing root dot removed, and TTL is
// the record's TTL in seconds.
//
// It returns an error wrapping errNoCNAMERecord when the reply is nil, or when
// it is a success or NXDOMAIN reply with an empty answer section. A plain
// error is returned when:
//   - the exchange itself fails,
//   - the server replies with any other rcode (e.g. SERVFAIL, REFUSED), or
//   - the answer section holds no CNAME record.
func getRecord(client *dns.Client, msg *dns.Msg, server, port string) (*cnameResolv, error) {
	resp, _, err := client.Exchange(msg, net.JoinHostPort(server, port))
	if err != nil {
		return nil, fmt.Errorf("exchange error for server %s: %w", server, err)
	}
	if resp == nil {
		return nil, fmt.Errorf("%w: %s", errNoCNAMERecord, server)
	}

	// A failure rcode means the server could not answer at all; reporting it
	// as "no CNAME" would hide resolver problems behind debug-level logs.
	if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError {
		return nil, fmt.Errorf("server %s answered with rcode %s", server, dns.RcodeToString[resp.Rcode])
	}

	if len(resp.Answer) == 0 {
		return nil, fmt.Errorf("%w: %s", errNoCNAMERecord, server)
	}

	// The CNAME is not guaranteed to be the first answer record (a DNAME, for
	// example, precedes the CNAME synthesized from it), so scan for it.
	for _, answer := range resp.Answer {
		if rr, ok := answer.(*dns.CNAME); ok {
			return &cnameResolv{
				TTL:    time.Duration(rr.Hdr.Ttl) * time.Second,
				Record: strings.TrimSuffix(rr.Target, "."),
			}, nil
		}
	}

	return nil, fmt.Errorf("invalid response type for server %s", server)
}