Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix re-entrant GetOrHandshake issues #1044

Merged
merged 5 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
swapPrimary trafficDecision = 3
migrateRelays trafficDecision = 4
tryRehandshake trafficDecision = 5
sendTestPacket trafficDecision = 6
)

type connectionManager struct {
Expand Down Expand Up @@ -176,7 +177,7 @@ func (n *connectionManager) Run(ctx context.Context) {
}

func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
decision, hostinfo, primary := n.makeTrafficDecision(localIndex, p, nb, out, now)
decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)

switch decision {
case deleteTunnel:
Expand All @@ -197,6 +198,9 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,

case tryRehandshake:
n.tryRehandshake(hostinfo)

case sendTestPacket:
n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
}

n.resetRelayTrafficCheck(hostinfo)
Expand Down Expand Up @@ -289,7 +293,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
}
}

func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []byte, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
n.hostMap.RLock()
defer n.hostMap.RUnlock()

Expand Down Expand Up @@ -356,6 +360,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []
return deleteTunnel, hostinfo, nil
}

decision := doNothing
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
if !outTraffic {
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
Expand All @@ -380,7 +385,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []
}

// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
decision = sendTestPacket

} else {
if n.l.Level >= logrus.DebugLevel {
Expand All @@ -390,7 +395,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []

n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
return doNothing, nil, nil
return decision, hostinfo, nil
}

func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
Expand Down
5 changes: 3 additions & 2 deletions connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ var vpnIp iputil.VpnIp

func newTestLighthouse() *LightHouse {
lh := &LightHouse{
l: test.NewLogger(),
addrMap: map[iputil.VpnIp]*RemoteList{},
l: test.NewLogger(),
addrMap: map[iputil.VpnIp]*RemoteList{},
queryChan: make(chan iputil.VpnIp, 10),
}
lighthouses := map[iputil.VpnIp]struct{}{}
staticList := map[iputil.VpnIp]struct{}{}
Expand Down
4 changes: 4 additions & 0 deletions examples/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ logging:
# A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out
#try_interval: 100ms
#retries: 20

# query_buffer is the size of the buffer channel for querying lighthouses
#query_buffer: 64

# trigger_buffer is the size of the buffer channel for quickly sending handshakes
# after receiving the response for lighthouse queries
#trigger_buffer: 64
Expand Down
7 changes: 3 additions & 4 deletions handshake_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
// the learned public ip for them. Query again to short circuit the promotion counter
hm.lightHouse.QueryServer(vpnIp, hm.f)
hm.lightHouse.QueryServer(vpnIp)
}

// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
Expand Down Expand Up @@ -374,13 +374,13 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
// StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo {
hm.Lock()
defer hm.Unlock()

if hh, ok := hm.vpnIps[vpnIp]; ok {
// We are already trying to handshake with this vpn ip
if cacheCb != nil {
cacheCb(hh)
}
hm.Unlock()
return hh.hostinfo
}

Expand Down Expand Up @@ -421,8 +421,7 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
}
}

hm.Unlock()
hm.lightHouse.QueryServer(vpnIp, hm.f)
hm.lightHouse.QueryServer(vpnIp)
return hostinfo
}

Expand Down
2 changes: 1 addition & 1 deletion hostmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
}

i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
ifce.lightHouse.QueryServer(i.vpnIp, ifce)
ifce.lightHouse.QueryServer(i.vpnIp)
}
}

Expand Down
2 changes: 1 addition & 1 deletion inside.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
f.lightHouse.QueryServer(hostinfo.vpnIp, f)
f.lightHouse.QueryServer(hostinfo.vpnIp)
hostinfo.lastRebindCount = f.rebindCount
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")
Expand Down
75 changes: 52 additions & 23 deletions lighthouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ type LightHouse struct {
// IP's of relays that can be used by peers to access me
relaysForMe atomic.Pointer[[]iputil.VpnIp]

queryChan chan iputil.VpnIp

calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote

metrics *MessageMetrics
Expand Down Expand Up @@ -110,6 +112,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
nebulaPort: nebulaPort,
punchConn: pc,
punchy: p,
queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)),
l: l,
}
lighthouses := make(map[iputil.VpnIp]struct{})
Expand Down Expand Up @@ -139,6 +142,8 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
}
})

h.startQueryWorker()

return &h, nil
}

Expand Down Expand Up @@ -443,9 +448,9 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
return nil
}

func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList {
func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
if !lh.IsLighthouseIP(ip) {
lh.QueryServer(ip, f)
lh.QueryServer(ip)
}
lh.RLock()
if v, ok := lh.addrMap[ip]; ok {
Expand All @@ -456,30 +461,14 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList {
return nil
}

// This is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) {
if lh.amLighthouse {
return
}

if lh.IsLighthouseIP(ip) {
return
}

// Send a query to the lighthouses and hope for the best next time
query, err := NewLhQueryByInt(ip).Marshal()
if err != nil {
lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload")
// QueryServer is asynchronous so no reply should be expected
func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
// Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses
if lh.amLighthouse || lh.IsLighthouseIP(ip) {
return
}

lighthouses := lh.GetLighthouses()
lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses)))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for n := range lighthouses {
f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
}
lh.queryChan <- ip
brad-defined marked this conversation as resolved.
Show resolved Hide resolved
}

func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
Expand Down Expand Up @@ -752,6 +741,46 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
}

func (lh *LightHouse) startQueryWorker() {
if lh.amLighthouse {
return
}

go func() {
nb := make([]byte, 12, 12)
out := make([]byte, mtu)

for {
select {
case <-lh.ctx.Done():
return
case ip := <-lh.queryChan:
lh.innerQueryServer(ip, nb, out)
}
}
}()
}

func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) {
if lh.IsLighthouseIP(ip) {
return
}

// Send a query to the lighthouses and hope for the best next time
query, err := NewLhQueryByInt(ip).Marshal()
if err != nil {
lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload")
return
}

lighthouses := lh.GetLighthouses()
lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses)))

for n := range lighthouses {
lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
}
}

func (lh *LightHouse) StartUpdateWorker() {
interval := lh.GetUpdateInterval()
if lh.amLighthouse || interval == 0 {
Expand Down
2 changes: 1 addition & 1 deletion ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
}

var cm *CacheMap
rl := ifce.lightHouse.Query(vpnIp, ifce)
rl := ifce.lightHouse.Query(vpnIp)
if rl != nil {
cm = rl.CopyCache()
}
Expand Down