/
resolver-mdns.go
441 lines (387 loc) · 11.9 KB
/
resolver-mdns.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
package resolver
import (
"context"
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/safing/portmaster/network/netutils"
"github.com/miekg/dns"
"github.com/safing/portbase/log"
)
// DNSClassMulticast is the IN class with the most significant class bit
// (1<<15) set. mDNS reuses that bit: in a question it requests a unicast
// response, and in a resource record it marks the record as "cache-flush".
const DNSClassMulticast = 1<<15 | dns.ClassINET
// mDNS sockets, opened by listenToMDNS. The multicast sockets join the mDNS
// groups on port 5353. The unicast sockets are bound to an ephemeral port, and
// queryMulticastDNS also uses them to send queries.
var (
	multicast4Conn, multicast6Conn *net.UDPConn
	unicast4Conn, unicast6Conn     *net.UDPConn
)

// Outstanding mDNS questions, keyed by DNS message ID. questionsLock guards
// access to questions.
var (
	questionsLock sync.Mutex
	questions     = make(map[uint16]*savedQuestion)
)

var (
	// mDNSResolver is the resolver entry that answers via multicast DNS.
	mDNSResolver = &Resolver{
		ConfigURL: ServerSourceMDNS,
		Info: &ResolverInfo{
			Type:    ServerTypeMDNS,
			Source:  ServerSourceMDNS,
			IPScope: netutils.SiteLocal,
		},
		Conn: &mDNSResolverConn{},
	}

	// mDNSResolvers is mDNSResolver wrapped in a list.
	mDNSResolvers = []*Resolver{mDNSResolver}
)
// mDNSResolverConn is the connection implementation of mDNSResolver. It holds
// no state of its own; queries go out over the package-level mDNS sockets.
type mDNSResolverConn struct{}

// Query resolves q via multicast DNS.
func (*mDNSResolverConn) Query(ctx context.Context, q *Query) (*RRCache, error) {
	return queryMulticastDNS(ctx, q)
}

// ReportFailure does nothing; failures are not tracked for mDNS.
func (*mDNSResolverConn) ReportFailure() {}

// IsFailing always reports false, since failures are not tracked for mDNS.
func (*mDNSResolverConn) IsFailing() bool { return false }

// ResetFailure does nothing; failures are not tracked for mDNS.
func (*mDNSResolverConn) ResetFailure() {}
// savedQuestion records an outstanding mDNS query. Incoming responses are
// matched to it by message ID, which lets responses without a question
// section still be attributed to a question.
type savedQuestion struct {
	// question is the question that was sent.
	question dns.Question
	// expires is when cleanSavedQuestions may drop this entry.
	expires time.Time
	// response receives a matching response via a non-blocking send.
	response chan *RRCache
}
// indexOfRR returns the index of the first record in *list whose owner name
// and record type equal those of entry, or -1 if there is none.
// The name comparison is exact (case-sensitive).
func indexOfRR(entry *dns.RR_Header, list *[]dns.RR) int {
	records := *list
	for i := range records {
		hdr := records[i].Header()
		if hdr.Name == entry.Name && hdr.Rrtype == entry.Rrtype {
			return i
		}
	}
	return -1
}
// listenToMDNS opens the four mDNS sockets (IPv4/IPv6, multicast/unicast).
// For each socket that opens, it starts a listener service worker, and it
// starts one worker that handles the decoded messages.
//
// It then blocks until module.Ctx is done. On return, the deferred Close
// calls release all opened sockets. A socket that cannot be opened is logged
// and skipped, and its package variable is left nil.
//
//nolint:gocyclo,gocognit // TODO: make simpler
func listenToMDNS(ctx context.Context) error {
	messages := make(chan *dns.Msg, 32)

	// TODO: init and start every listener in its own service worker
	// this will make them more resilient and actually able to restart
	sockets := []struct {
		target  **net.UDPConn // package variable that holds the socket
		open    func() (*net.UDPConn, error)
		failFmt string // warning format logged when open fails
		worker  string // service worker name
	}{
		{
			target: &multicast4Conn,
			open: func() (*net.UDPConn, error) {
				return net.ListenMulticastUDP("udp4", nil, &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251), Port: 5353})
			},
			failFmt: "intel(mdns): failed to create udp4 listen multicast socket: %s",
			worker:  "mdns udp4 multicast listener",
		},
		{
			target: &multicast6Conn,
			open: func() (*net.UDPConn, error) {
				return net.ListenMulticastUDP("udp6", nil, &net.UDPAddr{IP: net.IP([]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb}), Port: 5353})
			},
			failFmt: "intel(mdns): failed to create udp6 listen multicast socket: %s",
			worker:  "mdns udp6 multicast listener",
		},
		{
			target: &unicast4Conn,
			open: func() (*net.UDPConn, error) {
				return net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
			},
			failFmt: "intel(mdns): failed to create udp4 listen socket: %s",
			worker:  "mdns udp4 unicast listener",
		},
		{
			target: &unicast6Conn,
			open: func() (*net.UDPConn, error) {
				return net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
			},
			failFmt: "intel(mdns): failed to create udp6 listen socket: %s",
			worker:  "mdns udp6 unicast listener",
		},
	}

	for _, s := range sockets {
		conn, err := s.open()
		*s.target = conn // nil when opening failed
		if err != nil {
			// TODO: retry after some time
			log.Warningf(s.failFmt, err)
			continue
		}

		module.StartServiceWorker(s.worker, 0, func(ctx context.Context) error {
			return listenForDNSPackets(ctx, conn, messages)
		})
		// Deferred inside the loop on purpose: every socket must stay open
		// until this function returns.
		defer conn.Close()
	}

	module.StartServiceWorker("mdns message handler", 0, func(ctx context.Context) error {
		return handleMDNSMessages(ctx, messages)
	})

	// Block until shutdown.
	<-module.Ctx.Done()
	return nil
}
// handleMDNSMessages processes decoded mDNS messages from messages until ctx
// is done.
//
// It ignores messages that are not successful responses with at least one
// answer. For all other messages:
//   - The question is taken from the message itself. If the message has no
//     question section, the saved question with the same message ID is used.
//     A saved question that differs from the message's question is ignored.
//   - If the question's class is IN (with or without the top class bit), all
//     in-scope records are merged into an RRCache for that question, which is
//     then saved. A cached entry modified within the last two seconds that has
//     not expired is extended instead of replaced. The result is also offered
//     to the waiting query with a non-blocking send.
//   - Each in-scope A, AAAA and PTR record is also cached on its own, under
//     its owner name and record type. PTRs whose name starts with "_" are
//     skipped, as are records matching the question's name and type.
//
// Expired saved questions are cleaned up after every message that is not
// ignored.
//
//nolint:gocyclo,gocognit // TODO
func handleMDNSMessages(ctx context.Context, messages chan *dns.Msg) error {
	for {
		select {
		case <-ctx.Done():
			return nil
		case message := <-messages:
			// Only successful responses that carry answers are of interest.
			if !message.Response || message.Rcode != dns.RcodeSuccess || len(message.Answer) == 0 {
				continue
			}

			// Look up a question registered for this message ID.
			questionsLock.Lock()
			savedQ := questions[message.MsgHdr.Id]
			questionsLock.Unlock()

			// Get the question; some responders do not echo it back.
			var question *dns.Question
			if len(message.Question) > 0 {
				question = &message.Question[0]
				// The saved question belongs to a different query: disregard it.
				if savedQ != nil && message.Question[0].String() != savedQ.question.String() {
					savedQ = nil
				}
			} else if savedQ != nil {
				question = &savedQ.question
			}

			saveFullRequest := false
			if question != nil {
				// Only the INTERNET class (optionally with the top bit set) is handled.
				if question.Qclass != dns.ClassINET && question.Qclass != DNSClassMulticast {
					continue
				}
				saveFullRequest = true
			}

			var rrCache *RRCache
			if saveFullRequest {
				// Start a fresh entry unless a recent (< 2s), unexpired one is cached;
				// in that case this message's records are merged into it.
				var err error
				rrCache, err = GetRRCache(question.Name, dns.Type(question.Qtype))
				if err != nil || rrCache.Modified < time.Now().Add(-2*time.Second).Unix() || rrCache.Expired() {
					rrCache = &RRCache{
						Domain:   question.Name,
						Question: dns.Type(question.Qtype),
						RCode:    dns.RcodeSuccess,
						Resolver: mDNSResolver.Info.Copy(),
					}
				}
			}

			// collect handles the in-scope records of one message section. If target
			// is non-nil, each record is merged into it, replacing a record with the
			// same name and type. A, AAAA and PTR records are also remembered for
			// individual caching (later sections win for the same key).
			scavengedRecords := make(map[string]dns.RR)
			collect := func(records []dns.RR, target *[]dns.RR) {
				for _, rr := range records {
					hdr := rr.Header()
					if !domainInScope(hdr.Name, multicastDomains) {
						continue
					}
					if target != nil {
						if k := indexOfRR(hdr, target); k == -1 {
							*target = append(*target, rr)
						} else {
							(*target)[k] = rr
						}
					}
					switch rr.(type) {
					case *dns.A:
						scavengedRecords[hdr.Name+"A"] = rr
					case *dns.AAAA:
						scavengedRecords[hdr.Name+"AAAA"] = rr
					case *dns.PTR:
						// Names starting with "_" are skipped (presumably DNS-SD
						// service names).
						if !strings.HasPrefix(hdr.Name, "_") {
							scavengedRecords[hdr.Name+"PTR"] = rr
						}
					}
				}
			}

			// Only take field addresses when rrCache exists; nil means "don't merge".
			var answerSec, nsSec, extraSec *[]dns.RR
			if saveFullRequest {
				answerSec, nsSec, extraSec = &rrCache.Answer, &rrCache.Ns, &rrCache.Extra
			}
			collect(message.Answer, answerSec)
			collect(message.Ns, nsSec)
			collect(message.Extra, extraSec)

			var questionID string
			if saveFullRequest {
				rrCache.Clean(minMDnsTTL)
				if err := rrCache.Save(); err != nil {
					log.Warningf("resolver: failed to cache RR %s: %s", rrCache.Domain, err)
				}
				// Hand the result to the waiting query without blocking.
				if savedQ != nil {
					select {
					case savedQ.response <- rrCache:
					default:
					}
				}
				questionID = question.Name + dns.Type(question.Qtype).String()
			}

			for k, rr := range scavengedRecords {
				// Already saved as part of the full response above.
				if saveFullRequest && k == questionID {
					continue
				}
				hdr := rr.Header()
				single := &RRCache{
					Domain: hdr.Name,
					// Cache under the record's own type, not its class.
					Question: dns.Type(hdr.Rrtype),
					RCode:    dns.RcodeSuccess,
					Answer:   []dns.RR{rr},
					Resolver: mDNSResolver.Info.Copy(),
				}
				single.Clean(minMDnsTTL)
				if err := single.Save(); err != nil {
					log.Warningf("resolver: failed to cache RR %s: %s", single.Domain, err)
				}
			}
		}
		cleanSavedQuestions()
	}
}
// listenForDNSPackets reads datagrams from conn, decodes each one as a DNS
// message and forwards it to messages.
//
// Datagrams that fail to decode are logged and skipped. A read error is
// returned, except while the module is stopping, when nil is returned. ctx is
// only checked while handing a message off; a blocked Read has to be ended by
// closing conn.
func listenForDNSPackets(ctx context.Context, conn *net.UDPConn, messages chan *dns.Msg) error {
	packet := make([]byte, 65536)
	for {
		n, readErr := conn.Read(packet)
		if readErr != nil {
			if module.IsStopping() {
				return nil
			}
			log.Debugf("resolver: failed to read packet: %s", readErr)
			return readErr
		}

		msg := new(dns.Msg)
		if unpackErr := msg.Unpack(packet[:n]); unpackErr != nil {
			log.Debugf("resolver: failed to unpack message: %s", unpackErr)
			continue
		}

		select {
		case <-ctx.Done():
			return nil
		case messages <- msg:
		}
	}
}
// queryMulticastDNS sends q to the mDNS multicast groups and waits up to one
// second for a matching response.
//
// Queries go out through the unicast sockets. A queries use IPv4 only, AAAA
// queries IPv6 only, and all other types use both.
//
// The question is registered in questions under the query's message ID so
// that handleMDNSMessages can hand a matching response back. The entry stays
// registered until it expires, so late responses can still be matched.
//
// If no response arrives in time, the cache is checked once more, since the
// handler may have cached a matching record meanwhile. Otherwise an NXDomain
// RRCache is returned.
//
// An error is returned if no unicast socket is open, the query cannot be
// packed, the query could not be sent on any socket, or ctx is done while
// waiting.
func queryMulticastDNS(ctx context.Context, q *Query) (*RRCache, error) {
	// check for active connections
	if unicast4Conn == nil && unicast6Conn == nil {
		return nil, errors.New("unicast mdns connections not initialized")
	}

	// trace log
	log.Tracer(ctx).Trace("resolver: resolving with mDNS")

	// create query
	dnsQuery := new(dns.Msg)
	dnsQuery.SetQuestion(q.FQDN, uint16(q.QType))
	// request unicast response
	// q.Question[0].Qclass |= 1 << 15
	dnsQuery.RecursionDesired = false

	// Buffered so the handler's non-blocking send succeeds even if the
	// response arrives before this function reaches the select below.
	response := make(chan *RRCache, 1)

	// Register the question. Hold the lock only for the map update:
	// handleMDNSMessages takes the same lock for every incoming message, so
	// holding it while waiting would block delivery of the response.
	questionsLock.Lock()
	questions[dnsQuery.MsgHdr.Id] = &savedQuestion{
		question: dnsQuery.Question[0],
		expires:  time.Now().Add(10 * time.Second),
		response: response,
	}
	questionsLock.Unlock()

	// pack query
	buf, err := dnsQuery.Pack()
	if err != nil {
		return nil, fmt.Errorf("failed to pack query: %w", err)
	}

	// send writes the packed query to dst via conn.
	send := func(conn *net.UDPConn, dst *net.UDPAddr) error {
		if err := conn.SetWriteDeadline(time.Now().Add(1 * time.Second)); err != nil {
			return fmt.Errorf("failed to configure query (set timout): %w", err)
		}
		if _, err := conn.WriteToUDP(buf, dst); err != nil {
			return fmt.Errorf("failed to send query: %w", err)
		}
		return nil
	}

	// Send on every applicable socket. Only fail if nothing could be sent;
	// e.g. an unroutable IPv6 group must not break a query that went out on IPv4.
	var sent bool
	var sendErr error
	if unicast4Conn != nil && uint16(q.QType) != dns.TypeAAAA {
		if err := send(unicast4Conn, &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251), Port: 5353}); err != nil {
			sendErr = err
		} else {
			sent = true
		}
	}
	if unicast6Conn != nil && uint16(q.QType) != dns.TypeA {
		if err := send(unicast6Conn, &net.UDPAddr{IP: net.ParseIP("ff02::fb"), Port: 5353}); err != nil {
			if sendErr == nil {
				sendErr = err
			}
		} else {
			sent = true
		}
	}
	if sendErr != nil {
		if !sent {
			return nil, sendErr
		}
		log.Debugf("resolver: mdns query partially failed: %s", sendErr)
	}

	// wait for response or timeout
	select {
	case rrCache := <-response:
		if rrCache != nil {
			return rrCache, nil
		}
	case <-time.After(1 * time.Second):
		// The handler may have cached a matching record in the meantime.
		if rrCache, err := GetRRCache(q.FQDN, q.QType); err == nil && rrCache != nil {
			return rrCache, nil
		}
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	// Respond with NXDomain.
	return &RRCache{
		Domain:   q.FQDN,
		Question: q.QType,
		RCode:    dns.RcodeNameError,
		Resolver: mDNSResolver.Info.Copy(),
	}, nil
}
// cleanSavedQuestions removes every registered question whose expiry time has
// passed. It holds questionsLock for the whole sweep.
func cleanSavedQuestions() {
	questionsLock.Lock()
	defer questionsLock.Unlock()

	cutoff := time.Now()
	for id, sq := range questions {
		// Deleting while ranging over a map is allowed in Go.
		if !sq.expires.Before(cutoff) {
			continue
		}
		delete(questions, id)
	}
}