Skip to content

Commit

Permalink
Merge pull request #501 from inercia/weave-225-fix-1
Browse files Browse the repository at this point in the history
Fix for negative TTLs in WeaveDNS
  • Loading branch information
awh committed Apr 13, 2015
2 parents 01af41c + 6f545a4 commit 29b3b91
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 149 deletions.
84 changes: 16 additions & 68 deletions nameserver/cache.go
Expand Up @@ -22,6 +22,8 @@ const (
defPendingTimeout int = 5 // timeout for a resolution
)

// nullTTL is the "no explicit TTL" sentinel: a caller-supplied ttl is only
// applied when it differs from nullTTL; otherwise the validity is calculated
// from the TTLs found in the reply (when there is one).
const nullTTL = 0 // a null TTL

type entryStatus uint8

const (
Expand Down Expand Up @@ -61,24 +63,15 @@ type cacheEntry struct {
validUntil time.Time // obtained from the reply and stored here for convenience/speed
putTime time.Time

waitChan chan struct{}

index int // for fast lookups in the heap
}

func newCacheEntry(question *dns.Question, reply *dns.Msg, status entryStatus, flags uint8, now time.Time) *cacheEntry {
func newCacheEntry(question *dns.Question, now time.Time) *cacheEntry {
e := &cacheEntry{
Status: status,
Flags: flags,
question: *question,
index: -1,
}

if e.Status == stPending {
e.validUntil = now.Add(time.Duration(defPendingTimeout) * time.Second)
e.waitChan = make(chan struct{})
} else {
e.setReply(reply, flags, now)
Status: stPending,
validUntil: now.Add(time.Second * time.Duration(defPendingTimeout)),
question: *question,
index: -1,
}

return e
Expand Down Expand Up @@ -133,9 +126,7 @@ func (e cacheEntry) hasExpired(now time.Time) bool {

// set the reply for the entry
// returns true if the entry's validUntil time has changed
func (e *cacheEntry) setReply(reply *dns.Msg, flags uint8, now time.Time) bool {
shouldNotify := (e.Status == stPending)

func (e *cacheEntry) setReply(reply *dns.Msg, ttl int, flags uint8, now time.Time) bool {
var prevValidUntil time.Time
if e.Status == stResolved {
prevValidUntil = e.validUntil
Expand All @@ -145,10 +136,9 @@ func (e *cacheEntry) setReply(reply *dns.Msg, flags uint8, now time.Time) bool {
e.Flags = flags
e.putTime = now

if e.Flags&CacheNoLocalReplies != 0 {
// use a fixed timeout for negative local resolutions
e.validUntil = now.Add(time.Second * time.Duration(negLocalTTL))
} else {
if ttl != nullTTL {
e.validUntil = now.Add(time.Second * time.Duration(ttl))
} else if reply != nil {
// calculate the validUntil from the reply TTL
var minTTL uint32 = math.MaxUint32
for _, rr := range reply.Answer {
Expand All @@ -165,37 +155,9 @@ func (e *cacheEntry) setReply(reply *dns.Msg, flags uint8, now time.Time) bool {
e.ReplyLen = reply.Len()
}

if shouldNotify {
close(e.waitChan) // notify all the waiters by closing the channel
}

return (prevValidUntil != e.validUntil)
}

// waitReply blocks until a reply is available for this entry or the timeout
// expires.
//
// If the entry is already resolved, the reply is returned immediately through
// getReply. Otherwise, when timeout is positive, the call waits for waitChan
// to be closed and then returns whatever getReply yields, or errTimeout once
// timeout has elapsed. With a non-positive timeout on an unresolved entry no
// wait is attempted and errCouldNotResolve is returned.
//
// timeout is already a time.Duration and is used as-is: callers pass values
// such as 1*time.Second, so it must not be scaled by time.Second again
// (doing so turned a one-second timeout into a multi-year wait).
//
// NOTE(review): now is captured by the caller before any wait, so getReply
// sees a timestamp that may be up to timeout old - confirm this is acceptable.
func (e *cacheEntry) waitReply(request *dns.Msg, timeout time.Duration, maxLen int, now time.Time) (*dns.Msg, error) {
	if e.Status == stResolved {
		return e.getReply(request, maxLen, now)
	}
	if timeout <= 0 {
		return nil, errCouldNotResolve
	}

	// An explicit timer can be stopped as soon as the reply arrives, instead
	// of lingering until it fires as a time.After timer would.
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-e.waitChan:
		return e.getReply(request, maxLen, now)
	case <-timer.C:
		return nil, errTimeout
	}
}

// close closes the wait channel of an entry that is still pending, waking up
// anything blocked on it; entries in any other status are left untouched.
func (e *cacheEntry) close() {
	if e.Status != stPending {
		return
	}
	close(e.waitChan)
}

//////////////////////////////////////////////////////////////////////////////////////

// An entriesPtrHeap is a min-heap of cache entries.
Expand Down Expand Up @@ -280,27 +242,26 @@ func (c *Cache) Purge(now time.Time) {
}

// Put adds a reply to the cache.
func (c *Cache) Put(request *dns.Msg, reply *dns.Msg, flags uint8, now time.Time) int {
func (c *Cache) Put(request *dns.Msg, reply *dns.Msg, ttl int, flags uint8, now time.Time) int {
c.lock.Lock()
defer c.lock.Unlock()

question := request.Question[0]
key := cacheKey(question)
ent, found := c.entries[key]
if found {
Debug.Printf("[cache msgid %d] replacing response in cache", request.MsgHdr.Id)
updated := ent.setReply(reply, flags, now)
updated := ent.setReply(reply, ttl, flags, now)
if updated {
heap.Fix(&c.entriesH, ent.index)
}
} else {
// If we will add a new item and the capacity has been exceeded, make some room...
if len(c.entriesH) >= c.Capacity {
lowestEntry := heap.Pop(&c.entriesH).(*cacheEntry)
lowestEntry.close()
delete(c.entries, cacheKey(lowestEntry.question))
}
ent = newCacheEntry(&question, reply, stResolved, flags, now)
ent = newCacheEntry(&question, now)
ent.setReply(reply, ttl, flags, now)
heap.Push(&c.entriesH, ent)
c.entries[key] = ent
}
Expand Down Expand Up @@ -329,20 +290,7 @@ func (c *Cache) Get(request *dns.Msg, maxLen int, now time.Time) (reply *dns.Msg
} else {
// we are the first asking for this name: create an entry with no reply... the caller must wait
Debug.Printf("[cache msgid %d] addind in pending state", request.MsgHdr.Id)
c.entries[key] = newCacheEntry(&question, nil, stPending, 0, now)
}
return
}

// Wait for a reply for a question in the cache
// Notice that the caller could Get() and then Wait() for a question, but the corresponding cache
// entry could have been removed in between. In that case, the caller should retry the query (and
// the user should increase the cache size!)
func (c *Cache) Wait(request *dns.Msg, timeout time.Duration, maxLen int, now time.Time) (reply *dns.Msg, err error) {
// do not try to lock the cache: otherwise, no one else could `Put()` the reply
question := request.Question[0]
if entry, found := c.entries[cacheKey(question)]; found {
reply, err = entry.waitReply(request, timeout, maxLen, now)
c.entries[key] = newCacheEntry(&question, now)
}
return
}
Expand Down
76 changes: 12 additions & 64 deletions nameserver/cache_test.go
Expand Up @@ -34,7 +34,7 @@ func TestCacheLength(t *testing.T) {
reply := makeAddressReply(questionMsg, question, ips)
reply.Answer[0].Header().Ttl = uint32(i)

l.Put(questionMsg, reply, 0, insTime)
l.Put(questionMsg, reply, 0, 0, insTime)
}

wt.AssertEqualInt(t, l.Len(), cacheLen, "cache length")
Expand Down Expand Up @@ -69,20 +69,20 @@ func TestCacheEntries(t *testing.T) {
resp, err := l.Get(questionMsg, minUDPSize, time.Now())
wt.AssertNoErr(t, err)
if resp != nil {
t.Logf("Got '%s'", resp)
t.Logf("Got\n%s", resp)
t.Fatalf("ERROR: Did not expect a reponse from Get() yet")
}
t.Logf("Trying to get it again")
resp, err = l.Get(questionMsg, minUDPSize, time.Now())
wt.AssertNoErr(t, err)
if resp != nil {
t.Logf("Got '%s'", resp)
t.Logf("Got\n%s", resp)
t.Fatalf("ERROR: Did not expect a reponse from Get() yet")
}

t.Logf("Inserting the reply")
reply1 := makeAddressReply(questionMsg, question, []net.IP{net.ParseIP("10.0.1.1")})
l.Put(questionMsg, reply1, 0, time.Now())
l.Put(questionMsg, reply1, nullTTL, 0, time.Now())

timeGet1 := time.Now()
t.Logf("Checking we can Get() the reply now")
Expand All @@ -93,13 +93,6 @@ func TestCacheEntries(t *testing.T) {
wt.AssertType(t, resp.Answer[0], (*dns.A)(nil), "DNS record")
ttlGet1 := resp.Answer[0].Header().Ttl

t.Logf("Checking a Wait() with timeout=0 gets the same result")
resp, err = l.Wait(questionMsg, time.Duration(0)*time.Second, minUDPSize, time.Now())
wt.AssertNoErr(t, err)
wt.AssertTrue(t, resp != nil, "reponse from a Wait(timeout=0)")
t.Logf("Received '%s'", resp.Answer[0])
wt.AssertType(t, resp.Answer[0], (*dns.A)(nil), "DNS record")

timeGet2 := timeGet1.Add(time.Duration(1) * time.Second)
t.Logf("Checking that a second Get(), after 1 second, gets the same result, but with reduced TTL")
resp, err = l.Get(questionMsg, minUDPSize, timeGet2)
Expand All @@ -115,13 +108,13 @@ func TestCacheEntries(t *testing.T) {
resp, err = l.Get(questionMsg, minUDPSize, timeGet3)
wt.AssertNoErr(t, err)
if resp != nil {
t.Logf("Got '%s'", resp)
t.Logf("Got\n%s", resp)
t.Fatalf("ERROR: Did NOT expect a reponse from the second Get()")
}

t.Logf("Checking that an Remove() results in Get() returning nothing")
replyTemp := makeAddressReply(questionMsg, question, []net.IP{net.ParseIP("10.0.9.9")})
l.Put(questionMsg, replyTemp, 0, time.Now())
l.Put(questionMsg, replyTemp, nullTTL, 0, time.Now())
lenBefore := l.Len()
l.Remove(question)
wt.AssertEqualInt(t, l.Len(), lenBefore-1, "cache length")
Expand All @@ -135,10 +128,10 @@ func TestCacheEntries(t *testing.T) {
t.Logf("Inserting a two replies for the same query")
timePut2 := time.Now()
reply2 := makeAddressReply(questionMsg, question, []net.IP{net.ParseIP("10.0.1.2")})
l.Put(questionMsg, reply2, 0, timePut2)
l.Put(questionMsg, reply2, nullTTL, 0, timePut2)
timePut3 := timePut2.Add(time.Duration(1) * time.Second)
reply3 := makeAddressReply(questionMsg, question, []net.IP{net.ParseIP("10.0.1.3")})
l.Put(questionMsg, reply3, 0, timePut3)
l.Put(questionMsg, reply3, nullTTL, 0, timePut3)

t.Logf("Checking we get the last one...")
resp, err = l.Get(questionMsg, minUDPSize, timePut3)
Expand All @@ -162,7 +155,7 @@ func TestCacheEntries(t *testing.T) {
resp, err = l.Get(questionMsg, minUDPSize, timePut3.Add(time.Duration(localTTL)*time.Second))
wt.AssertNoErr(t, err)
if resp != nil {
t.Logf("Received '%s'", resp.Answer[0])
t.Logf("Got\n%s", resp.Answer[0])
t.Fatalf("ERROR: Did NOT expect a reponse from the Get()")
}
wt.AssertEqualInt(t, l.Len(), lenBefore-1, "cache length (after getting an expired entry)")
Expand All @@ -180,13 +173,10 @@ func TestCacheEntries(t *testing.T) {
t.Logf("Checking that an Remove() between Get() and Put() does not break things")
replyTemp2 := makeAddressReply(questionMsg2, question2, []net.IP{net.ParseIP("10.0.9.9")})
l.Remove(question2)
l.Put(questionMsg2, replyTemp2, 0, time.Now())
l.Put(questionMsg2, replyTemp2, nullTTL, 0, time.Now())
resp, err = l.Get(questionMsg2, minUDPSize, time.Now())
wt.AssertNoErr(t, err)
wt.AssertNotNil(t, resp, "reponse from Get()")
resp, err = l.Wait(questionMsg2, time.Duration(0)*time.Second, minUDPSize, time.Now())
wt.AssertNoErr(t, err)
wt.AssertNotNil(t, resp, "reponse from Get()")

questionMsg3 := new(dns.Msg)
questionMsg3.SetQuestion("some.other.name", dns.TypeA)
Expand All @@ -195,7 +185,7 @@ func TestCacheEntries(t *testing.T) {

t.Logf("Checking that a entry with CacheNoLocalReplies return an error")
timePut3 = time.Now()
l.Put(questionMsg3, nil, CacheNoLocalReplies, timePut3)
l.Put(questionMsg3, nil, nullTTL, CacheNoLocalReplies, timePut3)
resp, err = l.Get(questionMsg3, minUDPSize, timePut3)
wt.AssertNil(t, resp, "Get() response with CacheNoLocalReplies")
wt.AssertNotNil(t, err, "Get() error with CacheNoLocalReplies")
Expand All @@ -208,51 +198,9 @@ func TestCacheEntries(t *testing.T) {

l.Remove(question3)
t.Logf("Checking that Put&Get with CacheNoLocalReplies with a Remove in the middle returns nothing")
l.Put(questionMsg3, nil, CacheNoLocalReplies, time.Now())
l.Put(questionMsg3, nil, nullTTL, CacheNoLocalReplies, time.Now())
l.Remove(question3)
resp, err = l.Get(questionMsg3, minUDPSize, time.Now())
wt.AssertNil(t, resp, "Get() reponse with CacheNoLocalReplies")
wt.AssertNil(t, err, "Get() error with CacheNoLocalReplies")
}

// Check that waiters are unblocked when the name they are waiting for is inserted
//
// The test creates a cache with room for cacheLen entries, starts one goroutine
// per name ("name0" .. "name255") that first calls Get() and then Wait() with a
// one-second timeout, and finally Put()s an address reply for every name from
// the test goroutine.
//
// NOTE(review): nothing visible here waits for the spawned goroutines to finish
// before the test function returns, so their assertions and logging may run
// after the test has completed - confirm whether synchronization is needed.
func TestCacheBlockingOps(t *testing.T) {
InitDefaultLogging(true)

const cacheLen = 256

l, err := NewCache(cacheLen)
wt.AssertNoErr(t, err)

// the requests are remembered so the replies can be inserted for them below
requests := []*dns.Msg{}

t.Logf("Starting 256 queries that will block...")
for i := 0; i < cacheLen; i++ {
questionName := fmt.Sprintf("name%d", i)
questionMsg := new(dns.Msg)
questionMsg.SetQuestion(questionName, dns.TypeA)
questionMsg.RecursionDesired = true

requests = append(requests, questionMsg)

// the request is passed as an argument so each goroutine works on its own message
go func(request *dns.Msg) {
t.Logf("Querying about %s...", request.Question[0].Name)
// only the error of this first Get() is checked; the reply is discarded
_, err := l.Get(request, minUDPSize, time.Now())
wt.AssertNoErr(t, err)
t.Logf("Waiting for %s...", request.Question[0].Name)
// wait (with a timeout of 1*time.Second) for the reply inserted by the loop below
r, err := l.Wait(request, 1*time.Second, minUDPSize, time.Now())
t.Logf("Obtained response for %s:\n%s", request.Question[0].Name, r)
wt.AssertNoErr(t, err)
}(questionMsg)
}

// insert the IPs for those names
for i, requestMsg := range requests {
// i ranges over 0..255, used as the last octet of the address
ip := net.ParseIP(fmt.Sprintf("10.0.1.%d", i))
ips := []net.IP{ip}
reply := makeAddressReply(requestMsg, &requestMsg.Question[0], ips)

t.Logf("Inserting response for %s...", requestMsg.Question[0].Name)
l.Put(requestMsg, reply, 0, time.Now())
}
}

0 comments on commit 29b3b91

Please sign in to comment.