Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/custom lookup func #9

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 57 additions & 45 deletions dnscache.go
Original file line number Diff line number Diff line change
@@ -1,77 +1,89 @@
package dnscache

// Package dnscache caches DNS lookups

import (
"net"
"sync"
"time"
"net"
"sync"
"time"
)

var LookupFunc = net.LookupIP

type Resolver struct {
lock sync.RWMutex
cache map[string][]net.IP
lock sync.RWMutex
cache map[string][]net.IP
LookupFunc func(host string) ([]net.IP, error)
}

func New(refreshRate time.Duration) *Resolver {
resolver := &Resolver {
cache: make(map[string][]net.IP, 64),
}
if refreshRate > 0 {
go resolver.autoRefresh(refreshRate)
}
return resolver
resolver := &Resolver{
cache: make(map[string][]net.IP, 64),
}
if refreshRate > 0 {
go resolver.autoRefresh(refreshRate)
}
return resolver
}

func (r *Resolver) Fetch(address string) ([]net.IP, error) {
r.lock.RLock()
ips, exists := r.cache[address]
r.lock.RUnlock()
if exists { return ips, nil }
r.lock.RLock()
ips, exists := r.cache[address]
r.lock.RUnlock()
if exists {
return ips, nil
}

return r.Lookup(address)
return r.Lookup(address)
}

func (r *Resolver) FetchOne(address string) (net.IP, error) {
ips, err := r.Fetch(address)
if err != nil || len(ips) == 0 { return nil, err}
return ips[0], nil
ips, err := r.Fetch(address)
if err != nil || len(ips) == 0 {
return nil, err
}
return ips[0], nil
}

func (r *Resolver) FetchOneString(address string) (string, error) {
ip, err := r.FetchOne(address)
if err != nil || ip == nil { return "", err }
return ip.String(), nil
ip, err := r.FetchOne(address)
if err != nil || ip == nil {
return "", err
}
return ip.String(), nil
}

func (r *Resolver) Refresh() {
i := 0
r.lock.RLock()
addresses := make([]string, len(r.cache))
for key, _ := range r.cache {
addresses[i] = key
i++
}
r.lock.RUnlock()
i := 0
r.lock.RLock()
addresses := make([]string, len(r.cache))
for key, _ := range r.cache {
addresses[i] = key
i++
}
r.lock.RUnlock()

for _, address := range addresses {
r.Lookup(address)
time.Sleep(time.Second * 2)
}
for _, address := range addresses {
r.Lookup(address)
time.Sleep(time.Second * 2)
}
}

func (r *Resolver) Lookup(address string) ([]net.IP, error) {
ips, err := net.LookupIP(address)
if err != nil { return nil, err }
ips, err := LookupFunc(address)
if err != nil {
return nil, err
}

r.lock.Lock()
r.cache[address] = ips
r.lock.Unlock()
return ips, nil
r.lock.Lock()
r.cache[address] = ips
r.lock.Unlock()
return ips, nil
}

func (r *Resolver) autoRefresh(rate time.Duration) {
for {
time.Sleep(rate)
r.Refresh()
}
for {
time.Sleep(rate)
r.Refresh()
}
}
116 changes: 68 additions & 48 deletions dnscache_test.go
Original file line number Diff line number Diff line change
@@ -1,79 +1,99 @@
package dnscache

import (
"net"
"sort"
"time"
"testing"
"net"
"sort"
"testing"
"time"
)

var testIpList = []string{"1.123.58.14", "31.85.32.110"}
var testLookupFunc = func(host string) ([]net.IP, error) {
var ips []net.IP
for i := 0; i < len(testIpList); i += 1 {
ip := net.ParseIP(testIpList[i])
ips = append(ips, ip)
}
return ips, nil
}

func TestFetchReturnsAndErrorOnInvalidLookup(t *testing.T) {
ips, err := New(0).Lookup("invalid.viki.io")
if ips != nil {
t.Errorf("Expecting nil ips, got %v", ips)
}
expected := "lookup invalid.viki.io: no such host"
if err.Error() != expected {
t.Errorf("Expecting %q error, got %q", expected, err.Error())
}
ips, err := New(0).Lookup("invalid.viki.io")
if ips != nil {
t.Errorf("Expecting nil ips, got %v", ips)
}
expected := "lookup invalid.viki.io: no such host"
if err.Error() != expected {
t.Errorf("Expecting %q error, got %q", expected, err.Error())
}
}

func TestFetchReturnsAListOfIps(t *testing.T) {
ips, _ := New(0).Lookup("dnscache.go.test.viki.io")
assertIps(t, ips, []string{"1.123.58.13", "31.85.32.110"})
LookupFunc = testLookupFunc
ips, _ := New(0).Lookup("dnscache.go.test.viki.io")
assertIps(t, ips, testIpList)
LookupFunc = net.LookupIP
}

func TestCallingLookupAddsTheItemToTheCache(t *testing.T) {
r := New(0)
r.Lookup("dnscache.go.test.viki.io")
assertIps(t, r.cache["dnscache.go.test.viki.io"], []string{"1.123.58.13", "31.85.32.110"})
LookupFunc = testLookupFunc
r := New(0)
r.Lookup("dnscache.go.test.viki.io")
assertIps(t, r.cache["dnscache.go.test.viki.io"], testIpList)
LookupFunc = net.LookupIP
}

func TestFetchLoadsValueFromTheCache(t *testing.T) {
r := New(0)
r.cache["invalid.viki.io"] = []net.IP{net.ParseIP("1.1.2.3")}
ips, _ := r.Fetch("invalid.viki.io")
assertIps(t, ips, []string{"1.1.2.3"})
r := New(0)
r.cache["invalid.viki.io"] = []net.IP{net.ParseIP("1.1.2.3")}
ips, _ := r.Fetch("invalid.viki.io")
assertIps(t, ips, []string{"1.1.2.3"})
}

func TestFetchOneLoadsTheFirstValue(t *testing.T) {
r := New(0)
r.cache["something.viki.io"] = []net.IP{net.ParseIP("1.1.2.3"), net.ParseIP("100.100.102.103")}
ip, _ := r.FetchOne("something.viki.io")
assertIps(t, []net.IP{ip}, []string{"1.1.2.3"})
LookupFunc = testLookupFunc
r := New(0)
r.cache["something.viki.io"] = []net.IP{net.ParseIP("1.1.2.3"), net.ParseIP("100.100.102.103")}
ip, _ := r.FetchOne("something.viki.io")
assertIps(t, []net.IP{ip}, []string{"1.1.2.3"})
LookupFunc = net.LookupIP
}

func TestFetchOneStringLoadsTheFirstValue(t *testing.T) {
r := New(0)
r.cache["something.viki.io"] = []net.IP{net.ParseIP("100.100.102.103"), net.ParseIP("100.100.102.104")}
ip, _ := r.FetchOneString("something.viki.io")
if ip != "100.100.102.103" {
t.Errorf("expected 100.100.102.103 but got %v", ip)
}
r := New(0)
r.cache["something.viki.io"] = []net.IP{net.ParseIP("100.100.102.103"), net.ParseIP("100.100.102.104")}
ip, _ := r.FetchOneString("something.viki.io")
if ip != "100.100.102.103" {
t.Errorf("expected 100.100.102.103 but got %v", ip)
}
}

func TestFetchLoadsTheIpAndCachesIt(t *testing.T) {
r := New(0)
ips, _ := r.Fetch("dnscache.go.test.viki.io")
assertIps(t, ips, []string{"1.123.58.13", "31.85.32.110"})
assertIps(t, r.cache["dnscache.go.test.viki.io"], []string{"1.123.58.13", "31.85.32.110"})
LookupFunc = testLookupFunc
r := New(0)
ips, _ := r.Fetch("dnscache.go.test.viki.io")
assertIps(t, ips, testIpList)
assertIps(t, r.cache["dnscache.go.test.viki.io"], testIpList)
LookupFunc = net.LookupIP
}

func TestItReloadsTheIpsAtAGivenInterval(t *testing.T) {
r := New(1)
r.cache["dnscache.go.test.viki.io"] = nil
time.Sleep(time.Second * 2)
assertIps(t, r.cache["dnscache.go.test.viki.io"], []string{"1.123.58.13", "31.85.32.110"})
LookupFunc = testLookupFunc
r := New(1)
r.cache["dnscache.go.test.viki.io"] = nil
time.Sleep(time.Second * 2)
assertIps(t, r.cache["dnscache.go.test.viki.io"], testIpList)
LookupFunc = net.LookupIP
}

func assertIps(t *testing.T, actuals []net.IP, expected []string) {
if len(actuals) != len(expected) {
t.Errorf("Expecting %d ips, got %d", len(expected), len(actuals))
}
sort.Strings(expected)
for _, ip := range actuals {
if sort.SearchStrings(expected, ip.String()) == -1 {
t.Errorf("Got an unexpected ip: %v:", actuals[0])
}
}
if len(actuals) != len(expected) {
t.Errorf("Expecting %d ips, got %d", len(expected), len(actuals))
}
sort.Strings(expected)
for _, ip := range actuals {
if sort.SearchStrings(expected, ip.String()) == -1 {
t.Errorf("Got an unexpected ip: %v:", actuals[0])
}
}
}