Skip to content

Commit

Permalink
Add mutex to cached liveness scanner to prevent race conditions (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmwample committed Nov 15, 2021
1 parent fd58c5a commit bed93bb
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 16 deletions.
65 changes: 49 additions & 16 deletions application/liveness/liveness.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"os/exec"
"strconv"
"sync"
"time"
)

Expand All @@ -33,6 +34,7 @@ type CachedLivenessTester struct {
ipCache map[string]cacheElement
signal chan bool
cacheExpirationTime time.Duration
m sync.RWMutex
}

// UncachedLivenessTester implements LivenessTester interface without caching,
Expand All @@ -42,6 +44,9 @@ type UncachedLivenessTester struct {

// Init parses cache expiry duration and initializes the Cache.
func (blt *CachedLivenessTester) Init(expirationTime string) error {
blt.m.Lock()
defer blt.m.Unlock()

blt.ipCache = make(map[string]cacheElement)
blt.signal = make(chan bool)

Expand All @@ -62,6 +67,9 @@ func (blt *CachedLivenessTester) Stop() {

// ClearExpiredCache cleans out stale entries in the cache.
func (blt *CachedLivenessTester) ClearExpiredCache() {
blt.m.Lock()
defer blt.m.Unlock()

for ipAddr, status := range blt.ipCache {
if time.Since(status.cachedTime) > blt.cacheExpirationTime {
delete(blt.ipCache, ipAddr)
Expand Down Expand Up @@ -106,17 +114,23 @@ func (blt *CachedLivenessTester) PeriodicScan(t string) {

for _, ip := range records {
if ip[0] != "saddr" {
if _, ok := blt.ipCache[ip[0]]; !ok {
var val cacheElement
val.isLive = true
val.cachedTime = time.Now()
blt.ipCache[ip[0]] = val
_, err := f.WriteString(ip[0] + "/32" + "\n")
if err != nil {
fmt.Println("Unable to write blocklist file", err)
f.Close()
func() {
// closure to ensure mutex unlocks in case of error.
blt.m.Lock()
defer blt.m.Unlock()

if _, ok := blt.ipCache[ip[0]]; !ok {
var val cacheElement
val.isLive = true
val.cachedTime = time.Now()
blt.ipCache[ip[0]] = val
_, err := f.WriteString(ip[0] + "/32" + "\n")
if err != nil {
fmt.Println("Unable to write blocklist file", err)
f.Close()
}
}
}
}()
}
}
f.Close()
Expand Down Expand Up @@ -145,21 +159,40 @@ func (blt *CachedLivenessTester) PeriodicScan(t string) {
// immediately and no network probes are sent. If the host was measured not
// live, the entry is stale, or there is not entry then network probes are sent
// and the result is then added to the cache.
//
// Lock on mutex is taken for lookup, then for cache update. Do NOT hold mutex
// while scanning for liveness as this will make cache extremely slow.
func (blt *CachedLivenessTester) PhantomIsLive(addr string, port uint16) (bool, error) {
// cache lookup internal function to use RLock
if live, err := blt.phantomLookup(addr, port); live || err != nil {
return live, err
}

// existing phantomIsLive() implementation
isLive, err := phantomIsLive(net.JoinHostPort(addr, strconv.Itoa(int(port))))
var val cacheElement
val.isLive = isLive
val.cachedTime = time.Now()

blt.m.Lock()
defer blt.m.Unlock()

blt.ipCache[addr] = val
return isLive, err
}

func (blt *CachedLivenessTester) phantomLookup(addr string, port uint16) (bool, error) {
blt.m.RLock()
defer blt.m.RUnlock()

if status, ok := blt.ipCache[addr]; ok {
if time.Since(status.cachedTime) < blt.cacheExpirationTime {
if status.isLive {
return true, fmt.Errorf(CACHED_PHANTOM_MSG)
}
}
}
isLive, err := phantomIsLive(net.JoinHostPort(addr, strconv.Itoa(int(port))))
var val cacheElement
val.isLive = isLive
val.cachedTime = time.Now()
blt.ipCache[addr] = val
return isLive, err
return false, nil
}

// PhantomIsLive sends 4 TCP syn packets to determine if the host will respond
Expand Down
43 changes: 43 additions & 0 deletions application/liveness/liveness_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package liveness
import (
"fmt"
"os"
"sync"
"testing"
"time"
)
Expand Down Expand Up @@ -98,3 +99,45 @@ func TestCachedLiveness(t *testing.T) {
}

}

func TestCachedLivenessThreaded(t *testing.T) {

test_cases := [...]struct {
address string
port uint16
expected bool
}{
{"1.1.1.1", 80, true},
{"192.0.0.2", 443, false},
{"2606:4700:4700::64", 443, true},
}

iterations := 10
failed := false
var wg sync.WaitGroup

clt := CachedLivenessTester{}
err := clt.Init("1h")
if err != nil {
t.Fatalf("failed to init cached liveness tester: %s", err)
}

for i := 0; i < iterations; i++ {
wg.Add(1)

go func(j int) {
test := test_cases[j%len(test_cases)]
liveness, response := clt.PhantomIsLive(test.address, test.port)
if liveness != test.expected {
t.Logf("%s:%d -> %v (expected %v)\n", test.address, test.port, response, test.expected)
}
wg.Done()
}(i)
}

wg.Wait()

if failed {
t.Fatalf("failed")
}
}

0 comments on commit bed93bb

Please sign in to comment.