From 6f545a42966b7958682cff3aae1bd2eaa8a50a4b Mon Sep 17 00:00:00 2001 From: Alvaro Saurin Date: Thu, 9 Apr 2015 15:53:56 +0200 Subject: [PATCH] Fix for negative TTLs by using a different lookups set depending on the value stored in the cache. --- nameserver/cache.go | 84 ++++++++-------------------------------- nameserver/cache_test.go | 76 ++++++------------------------------ nameserver/server.go | 42 ++++++++++++-------- 3 files changed, 53 insertions(+), 149 deletions(-) diff --git a/nameserver/cache.go b/nameserver/cache.go index fc7db41cb6..a27226d44c 100644 --- a/nameserver/cache.go +++ b/nameserver/cache.go @@ -22,6 +22,8 @@ const ( defPendingTimeout int = 5 // timeout for a resolution ) +const nullTTL = 0 // a null TTL + type entryStatus uint8 const ( @@ -61,24 +63,15 @@ type cacheEntry struct { validUntil time.Time // obtained from the reply and stored here for convenience/speed putTime time.Time - waitChan chan struct{} - index int // for fast lookups in the heap } -func newCacheEntry(question *dns.Question, reply *dns.Msg, status entryStatus, flags uint8, now time.Time) *cacheEntry { +func newCacheEntry(question *dns.Question, now time.Time) *cacheEntry { e := &cacheEntry{ - Status: status, - Flags: flags, - question: *question, - index: -1, - } - - if e.Status == stPending { - e.validUntil = now.Add(time.Duration(defPendingTimeout) * time.Second) - e.waitChan = make(chan struct{}) - } else { - e.setReply(reply, flags, now) + Status: stPending, + validUntil: now.Add(time.Second * time.Duration(defPendingTimeout)), + question: *question, + index: -1, } return e @@ -133,9 +126,7 @@ func (e cacheEntry) hasExpired(now time.Time) bool { // set the reply for the entry // returns True if the entry has changed the validUntil time -func (e *cacheEntry) setReply(reply *dns.Msg, flags uint8, now time.Time) bool { - shouldNotify := (e.Status == stPending) - +func (e *cacheEntry) setReply(reply *dns.Msg, ttl int, flags uint8, now time.Time) bool { var prevValidUntil time.Time if e.Status == stResolved { prevValidUntil = e.validUntil @@ -145,10 +136,9 @@ func (e *cacheEntry) setReply(reply *dns.Msg, flags uint8, now time.Time) bool { e.Flags = flags e.putTime = now - if e.Flags&CacheNoLocalReplies != 0 { - // use a fixed timeout for negative local resolutions - e.validUntil = now.Add(time.Second * time.Duration(negLocalTTL)) - } else { + if ttl != nullTTL { + e.validUntil = now.Add(time.Second * time.Duration(ttl)) + } else if reply != nil { // calculate the validUntil from the reply TTL var minTTL uint32 = math.MaxUint32 for _, rr := range reply.Answer { @@ -165,37 +155,9 @@ func (e *cacheEntry) setReply(reply *dns.Msg, flags uint8, now time.Time) bool { e.ReplyLen = reply.Len() } - if shouldNotify { - close(e.waitChan) // notify all the waiters by closing the channel - } - return (prevValidUntil != e.validUntil) } -// wait until a valid reply is set in the cache -func (e *cacheEntry) waitReply(request *dns.Msg, timeout time.Duration, maxLen int, now time.Time) (*dns.Msg, error) { - if e.Status == stResolved { - return e.getReply(request, maxLen, now) - } - - if timeout > 0 { - select { - case <-e.waitChan: - return e.getReply(request, maxLen, now) - case <-time.After(time.Second * timeout): - return nil, errTimeout - } - } - - return nil, errCouldNotResolve -} - -func (e *cacheEntry) close() { - if e.Status == stPending { - close(e.waitChan) - } -} - ////////////////////////////////////////////////////////////////////////////////////// // An entriesPtrHeap is a min-heap of cache entries. @@ -280,7 +242,7 @@ func (c *Cache) Purge(now time.Time) { } // Add adds a reply to the cache. -func (c *Cache) Put(request *dns.Msg, reply *dns.Msg, flags uint8, now time.Time) int { +func (c *Cache) Put(request *dns.Msg, reply *dns.Msg, ttl int, flags uint8, now time.Time) int { c.lock.Lock() defer c.lock.Unlock() @@ -288,8 +250,7 @@ func (c *Cache) Put(request *dns.Msg, reply *dns.Msg, flags uint8, now time.Time key := cacheKey(question) ent, found := c.entries[key] if found { - Debug.Printf("[cache msgid %d] replacing response in cache", request.MsgHdr.Id) - updated := ent.setReply(reply, flags, now) + updated := ent.setReply(reply, ttl, flags, now) if updated { heap.Fix(&c.entriesH, ent.index) } @@ -297,10 +258,10 @@ func (c *Cache) Put(request *dns.Msg, reply *dns.Msg, flags uint8, now time.Time // If we will add a new item and the capacity has been exceeded, make some room... if len(c.entriesH) >= c.Capacity { lowestEntry := heap.Pop(&c.entriesH).(*cacheEntry) - lowestEntry.close() delete(c.entries, cacheKey(lowestEntry.question)) } - ent = newCacheEntry(&question, reply, stResolved, flags, now) + ent = newCacheEntry(&question, now) + ent.setReply(reply, ttl, flags, now) heap.Push(&c.entriesH, ent) c.entries[key] = ent } @@ -329,20 +290,7 @@ func (c *Cache) Get(request *dns.Msg, maxLen int, now time.Time) (reply *dns.Msg } else { // we are the first asking for this name: create an entry with no reply... the caller must wait Debug.Printf("[cache msgid %d] addind in pending state", request.MsgHdr.Id) - c.entries[key] = newCacheEntry(&question, nil, stPending, 0, now) - } - return -} - -// Wait for a reply for a question in the cache -// Notice that the caller could Get() and then Wait() for a question, but the corresponding cache -// entry could have been removed in between. In that case, the caller should retry the query (and -// the user should increase the cache size!) -func (c *Cache) Wait(request *dns.Msg, timeout time.Duration, maxLen int, now time.Time) (reply *dns.Msg, err error) { - // do not try to lock the cache: otherwise, no one else could `Put()` the reply - question := request.Question[0] - if entry, found := c.entries[cacheKey(question)]; found { - reply, err = entry.waitReply(request, timeout, maxLen, now) + c.entries[key] = newCacheEntry(&question, now) } return } diff --git a/nameserver/cache_test.go b/nameserver/cache_test.go index bccb8d327a..95738757e5 100644 --- a/nameserver/cache_test.go +++ b/nameserver/cache_test.go @@ -34,7 +34,7 @@ func TestCacheLength(t *testing.T) { reply := makeAddressReply(questionMsg, question, ips) reply.Answer[0].Header().Ttl = uint32(i) - l.Put(questionMsg, reply, 0, insTime) + l.Put(questionMsg, reply, 0, 0, insTime) } wt.AssertEqualInt(t, l.Len(), cacheLen, "cache length") @@ -69,20 +69,20 @@ func TestCacheEntries(t *testing.T) { resp, err := l.Get(questionMsg, minUDPSize, time.Now()) wt.AssertNoErr(t, err) if resp != nil { - t.Logf("Got '%s'", resp) + t.Logf("Got\n%s", resp) t.Fatalf("ERROR: Did not expect a reponse from Get() yet") } t.Logf("Trying to get it again") resp, err = l.Get(questionMsg, minUDPSize, time.Now()) wt.AssertNoErr(t, err) if resp != nil { - t.Logf("Got '%s'", resp) + t.Logf("Got\n%s", resp) t.Fatalf("ERROR: Did not expect a reponse from Get() yet") } t.Logf("Inserting the reply") reply1 := makeAddressReply(questionMsg, question, []net.IP{net.ParseIP("10.0.1.1")}) - l.Put(questionMsg, reply1, 0, time.Now()) + l.Put(questionMsg, reply1, nullTTL, 0, time.Now()) timeGet1 := time.Now() t.Logf("Checking we can Get() the reply now") @@ -93,13 +93,6 @@ func TestCacheEntries(t *testing.T) { wt.AssertType(t, resp.Answer[0], (*dns.A)(nil), "DNS record") ttlGet1 := resp.Answer[0].Header().Ttl - t.Logf("Checking a Wait() with timeout=0 gets the same result") - resp, err = l.Wait(questionMsg, time.Duration(0)*time.Second, minUDPSize, time.Now()) - wt.AssertNoErr(t, err) - wt.AssertTrue(t, resp != nil, "reponse from a Wait(timeout=0)") - t.Logf("Received '%s'", resp.Answer[0]) - wt.AssertType(t, resp.Answer[0], (*dns.A)(nil), "DNS record") - timeGet2 := timeGet1.Add(time.Duration(1) * time.Second) t.Logf("Checking that a second Get(), after 1 second, gets the same result, but with reduced TTL") resp, err = l.Get(questionMsg, minUDPSize, timeGet2) @@ -115,13 +108,13 @@ func TestCacheEntries(t *testing.T) { resp, err = l.Get(questionMsg, minUDPSize, timeGet3) wt.AssertNoErr(t, err) if resp != nil { - t.Logf("Got '%s'", resp) + t.Logf("Got\n%s", resp) t.Fatalf("ERROR: Did NOT expect a reponse from the second Get()") } t.Logf("Checking that an Remove() results in Get() returning nothing") replyTemp := makeAddressReply(questionMsg, question, []net.IP{net.ParseIP("10.0.9.9")}) - l.Put(questionMsg, replyTemp, 0, time.Now()) + l.Put(questionMsg, replyTemp, nullTTL, 0, time.Now()) lenBefore := l.Len() l.Remove(question) wt.AssertEqualInt(t, l.Len(), lenBefore-1, "cache length") @@ -135,10 +128,10 @@ func TestCacheEntries(t *testing.T) { t.Logf("Inserting a two replies for the same query") timePut2 := time.Now() reply2 := makeAddressReply(questionMsg, question, []net.IP{net.ParseIP("10.0.1.2")}) - l.Put(questionMsg, reply2, 0, timePut2) + l.Put(questionMsg, reply2, nullTTL, 0, timePut2) timePut3 := timePut2.Add(time.Duration(1) * time.Second) reply3 := makeAddressReply(questionMsg, question, []net.IP{net.ParseIP("10.0.1.3")}) - l.Put(questionMsg, reply3, 0, timePut3) + l.Put(questionMsg, reply3, nullTTL, 0, timePut3) t.Logf("Checking we get the last one...") resp, err = l.Get(questionMsg, minUDPSize, timePut3) @@ -162,7 +155,7 @@ func TestCacheEntries(t *testing.T) { resp, err = l.Get(questionMsg, minUDPSize, timePut3.Add(time.Duration(localTTL)*time.Second)) wt.AssertNoErr(t, err) if resp != nil { - t.Logf("Received '%s'", resp.Answer[0]) + t.Logf("Got\n%s", resp.Answer[0]) t.Fatalf("ERROR: Did NOT expect a reponse from the Get()") } wt.AssertEqualInt(t, l.Len(), lenBefore-1, "cache length (after getting an expired entry)") @@ -180,13 +173,10 @@ func TestCacheEntries(t *testing.T) { t.Logf("Checking that an Remove() between Get() and Put() does not break things") replyTemp2 := makeAddressReply(questionMsg2, question2, []net.IP{net.ParseIP("10.0.9.9")}) l.Remove(question2) - l.Put(questionMsg2, replyTemp2, 0, time.Now()) + l.Put(questionMsg2, replyTemp2, nullTTL, 0, time.Now()) resp, err = l.Get(questionMsg2, minUDPSize, time.Now()) wt.AssertNoErr(t, err) wt.AssertNotNil(t, resp, "reponse from Get()") - resp, err = l.Wait(questionMsg2, time.Duration(0)*time.Second, minUDPSize, time.Now()) - wt.AssertNoErr(t, err) - wt.AssertNotNil(t, resp, "reponse from Get()") questionMsg3 := new(dns.Msg) questionMsg3.SetQuestion("some.other.name", dns.TypeA) @@ -195,7 +185,7 @@ func TestCacheEntries(t *testing.T) { t.Logf("Checking that a entry with CacheNoLocalReplies return an error") timePut3 = time.Now() - l.Put(questionMsg3, nil, CacheNoLocalReplies, timePut3) + l.Put(questionMsg3, nil, nullTTL, CacheNoLocalReplies, timePut3) resp, err = l.Get(questionMsg3, minUDPSize, timePut3) wt.AssertNil(t, resp, "Get() response with CacheNoLocalReplies") wt.AssertNotNil(t, err, "Get() error with CacheNoLocalReplies") @@ -208,51 +198,9 @@ func TestCacheEntries(t *testing.T) { l.Remove(question3) t.Logf("Checking that Put&Get with CacheNoLocalReplies with a Remove in the middle returns nothing") - l.Put(questionMsg3, nil, CacheNoLocalReplies, time.Now()) + l.Put(questionMsg3, nil, nullTTL, CacheNoLocalReplies, time.Now()) l.Remove(question3) resp, err = l.Get(questionMsg3, minUDPSize, time.Now()) wt.AssertNil(t, resp, "Get() reponse with CacheNoLocalReplies") wt.AssertNil(t, err, "Get() error with CacheNoLocalReplies") } - -// Check that waiters are unblocked when the name they are waiting for is inserted -func TestCacheBlockingOps(t *testing.T) { - InitDefaultLogging(true) - - const cacheLen = 256 - - l, err := NewCache(cacheLen) - wt.AssertNoErr(t, err) - - requests := []*dns.Msg{} - - t.Logf("Starting 256 queries that will block...") - for i := 0; i < cacheLen; i++ { - questionName := fmt.Sprintf("name%d", i) - questionMsg := new(dns.Msg) - questionMsg.SetQuestion(questionName, dns.TypeA) - questionMsg.RecursionDesired = true - - requests = append(requests, questionMsg) - - go func(request *dns.Msg) { - t.Logf("Querying about %s...", request.Question[0].Name) - _, err := l.Get(request, minUDPSize, time.Now()) - wt.AssertNoErr(t, err) - t.Logf("Waiting for %s...", request.Question[0].Name) - r, err := l.Wait(request, 1*time.Second, minUDPSize, time.Now()) - t.Logf("Obtained response for %s:\n%s", request.Question[0].Name, r) - wt.AssertNoErr(t, err) - }(questionMsg) - } - - // insert the IPs for those names - for i, requestMsg := range requests { - ip := net.ParseIP(fmt.Sprintf("10.0.1.%d", i)) - ips := []net.IP{ip} - reply := makeAddressReply(requestMsg, &requestMsg.Question[0], ips) - - t.Logf("Inserting response for %s...", requestMsg.Question[0].Name) - l.Put(requestMsg, reply, 0, time.Now()) - } -} diff --git a/nameserver/server.go b/nameserver/server.go index f0b5247967..ff05be7218 100644 --- a/nameserver/server.go +++ b/nameserver/server.go @@ -140,8 +140,8 @@ func NewDNSServer(config DNSServerConfig, zone Zone, iface *net.Interface) (s *D // (we use the same protocol for asking upstream servers) mux := func(proto dnsProtocol) *dns.ServeMux { m := dns.NewServeMux() - m.HandleFunc(s.Domain, s.queryHandler([]Lookup{s.Zone, s.mdnsCli}, proto)) - m.HandleFunc(RDNSDomain, s.rdnsHandler([]Lookup{s.Zone, s.mdnsCli}, proto)) + m.HandleFunc(s.Domain, s.queryHandler(proto)) + m.HandleFunc(RDNSDomain, s.rdnsHandler(proto)) m.HandleFunc(".", s.notUsHandler(proto)) return m } @@ -215,26 +215,32 @@ func (s *DNSServer) Stop() error { return nil } -func (s *DNSServer) queryHandler(lookups []Lookup, proto dnsProtocol) dns.HandlerFunc { +func (s *DNSServer) queryHandler(proto dnsProtocol) dns.HandlerFunc { return func(w dns.ResponseWriter, r *dns.Msg) { now := time.Now() q := r.Question[0] maxLen := getMaxReplyLen(r, proto) + lookups := []Lookup{s.Zone, s.mdnsCli} Debug.Printf("Query: %+v", q) if q.Qtype != dns.TypeA { Debug.Printf("[dns msgid %d] Unsuported query type %s", r.MsgHdr.Id, dns.TypeToString[q.Qtype]) m := makeDNSNotImplResponse(r) - s.cache.Put(r, m, 0, now) + s.cache.Put(r, m, negLocalTTL, 0, now) w.WriteMsg(m) return } reply, err := s.cache.Get(r, maxLen, time.Now()) if err != nil { - Debug.Printf("[dns msgid %d] Error from cache: %s", r.MsgHdr.Id, err) - w.WriteMsg(makeDNSFailResponse(r)) - return + if err == errNoLocalReplies { + Debug.Printf("[dns msgid %d] Cached 'no local replies' - skipping local lookup", r.MsgHdr.Id) + lookups = []Lookup{s.Zone} + } else { + Debug.Printf("[dns msgid %d] Error from cache: %s", r.MsgHdr.Id, err) + w.WriteMsg(makeDNSFailResponse(r)) + return + } } if reply != nil { Debug.Printf("[dns msgid %d] Returning reply from cache: %s/%d answers", @@ -248,7 +254,7 @@ func (s *DNSServer) queryHandler(lookups []Lookup, proto dnsProtocol) dns.Handle Debug.Printf("[dns msgid %d] Caching response for %s-query for \"%s\": %s [code:%s]", m.MsgHdr.Id, dns.TypeToString[q.Qtype], q.Name, ip, dns.RcodeToString[m.Rcode]) - s.cache.Put(r, m, 0, now) + s.cache.Put(r, m, nullTTL, 0, now) w.WriteMsg(m) return } @@ -257,17 +263,17 @@ func (s *DNSServer) queryHandler(lookups []Lookup, proto dnsProtocol) dns.Handle Info.Printf("[dns msgid %d] No results for type %s query %s", r.MsgHdr.Id, dns.TypeToString[q.Qtype], q.Name) - m := makeDNSFailResponse(r) - s.cache.Put(r, m, 0, now) - w.WriteMsg(m) + s.cache.Put(r, nil, negLocalTTL, CacheNoLocalReplies, now) + w.WriteMsg(makeDNSFailResponse(r)) } } -func (s *DNSServer) rdnsHandler(lookups []Lookup, proto dnsProtocol) dns.HandlerFunc { +func (s *DNSServer) rdnsHandler(proto dnsProtocol) dns.HandlerFunc { fallback := s.notUsHandler(proto) return func(w dns.ResponseWriter, r *dns.Msg) { now := time.Now() q := r.Question[0] + lookups := []Lookup{s.Zone, s.mdnsCli} maxLen := getMaxReplyLen(r, proto) Debug.Printf("Reverse query: %+v", q) @@ -275,7 +281,7 @@ func (s *DNSServer) rdnsHandler(lookups []Lookup, proto dnsProtocol) dns.Handler Warning.Printf("[dns msgid %d] Unexpected reverse query type %s: %+v", r.MsgHdr.Id, dns.TypeToString[q.Qtype], q) m := makeDNSNotImplResponse(r) - s.cache.Put(r, m, 0, now) + s.cache.Put(r, m, negLocalTTL, 0, now) w.WriteMsg(m) return } @@ -284,12 +290,12 @@ func (s *DNSServer) rdnsHandler(lookups []Lookup, proto dnsProtocol) dns.Handler if err != nil { if err == errNoLocalReplies { Debug.Printf("[dns msgid %d] Cached 'no local replies' - skipping local lookup", r.MsgHdr.Id) - fallback(w, r) + lookups = []Lookup{s.Zone} } else { Debug.Printf("[dns msgid %d] Error from cache: %s", r.MsgHdr.Id, err) w.WriteMsg(makeDNSFailResponse(r)) + return } - return } if reply != nil { Debug.Printf("[dns msgid %d] Returning reply from cache: %s/%d answers", @@ -302,14 +308,16 @@ func (s *DNSServer) rdnsHandler(lookups []Lookup, proto dnsProtocol) dns.Handler m := makePTRReply(r, &q, []string{name}) Debug.Printf("[dns msgid %d] Caching response for %s-query for \"%s\": %s [code:%s]", m.MsgHdr.Id, dns.TypeToString[q.Qtype], q.Name, name, dns.RcodeToString[m.Rcode]) - s.cache.Put(r, m, 0, now) + s.cache.Put(r, m, nullTTL, 0, now) w.WriteMsg(m) return } now = time.Now() } - s.cache.Put(r, nil, CacheNoLocalReplies, now) + Info.Printf("[dns msgid %d] No results for %s-query about '%s' [caching no-local] -> sending to fallback server", + r.MsgHdr.Id, dns.TypeToString[q.Qtype], q.Name) + s.cache.Put(r, nil, negLocalTTL, CacheNoLocalReplies, now) fallback(w, r) } }