Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use generics for CIDRTrees to avoid casting issues #1004

Merged
merged 1 commit into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 14 additions & 29 deletions allow_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ import (

type AllowList struct {
// The values of this cidrTree are `bool`, signifying allow/deny
cidrTree *cidr.Tree6
cidrTree *cidr.Tree6[bool]
}

type RemoteAllowList struct {
AllowList *AllowList

// Inside Range Specific, keys of this tree are inside CIDRs and values
// are *AllowList
insideAllowLists *cidr.Tree6
insideAllowLists *cidr.Tree6[*AllowList]
}

type LocalAllowList struct {
Expand Down Expand Up @@ -88,7 +88,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
}

tree := cidr.NewTree6()
tree := cidr.NewTree6[bool]()

// Keep track of the rules we have added for both ipv4 and ipv6
type allowListRules struct {
Expand Down Expand Up @@ -218,13 +218,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
return nameRules, nil
}

func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) {
func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
value := c.Get(k)
if value == nil {
return nil, nil
}

remoteAllowRanges := cidr.NewTree6()
remoteAllowRanges := cidr.NewTree6[*AllowList]()

rawMap, ok := value.(map[interface{}]interface{})
if !ok {
Expand Down Expand Up @@ -257,41 +257,26 @@ func (al *AllowList) Allow(ip net.IP) bool {
return true
}

result := al.cidrTree.MostSpecificContains(ip)
switch v := result.(type) {
case bool:
return v
default:
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
}
_, result := al.cidrTree.MostSpecificContains(ip)
return result
}

func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
if al == nil {
return true
}

result := al.cidrTree.MostSpecificContainsIpV4(ip)
switch v := result.(type) {
case bool:
return v
default:
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
}
_, result := al.cidrTree.MostSpecificContainsIpV4(ip)
return result
}

func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
if al == nil {
return true
}

result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
switch v := result.(type) {
case bool:
return v
default:
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
}
_, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
return result
}

func (al *LocalAllowList) Allow(ip net.IP) bool {
Expand Down Expand Up @@ -352,9 +337,9 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {

func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
if al.insideAllowLists != nil {
inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
if inside != nil {
return inside.(*AllowList)
ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
if ok {
return inside
}
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion allow_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
func TestAllowList_Allow(t *testing.T) {
assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))

tree := cidr.NewTree6()
tree := cidr.NewTree6[bool]()
tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
Expand Down
4 changes: 2 additions & 2 deletions calculated_remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
}

func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4, error) {
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) {
value := c.Get(k)
if value == nil {
return nil, nil
}

calculatedRemotes := cidr.NewTree4()
calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]()

rawMap, ok := value.(map[any]any)
if !ok {
Expand Down
62 changes: 34 additions & 28 deletions cidr/tree4.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,36 @@ import (
"github.com/slackhq/nebula/iputil"
)

type Node struct {
left *Node
right *Node
parent *Node
value interface{}
type Node[T any] struct {
left *Node[T]
right *Node[T]
parent *Node[T]
hasValue bool
value T
}

type entry struct {
type entry[T any] struct {
CIDR *net.IPNet
Value *interface{}
Value T
}

type Tree4 struct {
root *Node
list []entry
type Tree4[T any] struct {
root *Node[T]
list []entry[T]
}

const (
startbit = iputil.VpnIp(0x80000000)
)

func NewTree4() *Tree4 {
tree := new(Tree4)
tree.root = &Node{}
tree.list = []entry{}
func NewTree4[T any]() *Tree4[T] {
tree := new(Tree4[T])
tree.root = &Node[T]{}
tree.list = []entry[T]{}
return tree
}

func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
bit := startbit
node := tree.root
next := tree.root
Expand Down Expand Up @@ -68,14 +69,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
}
}

tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
node.value = val
node.hasValue = true
return
}

// Build up the rest of the tree we don't already have
for bit&mask != 0 {
next = &Node{}
next = &Node[T]{}
next.parent = node

if ip&bit != 0 {
Expand All @@ -90,17 +92,18 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {

// Final node marks our cidr, set the value
node.value = val
tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
node.hasValue = true
tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
}

// Contains finds the first match, which may be the least specific
func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
bit := startbit
node := tree.root

for node != nil {
if node.value != nil {
return node.value
if node.hasValue {
return true, node.value
}

if ip&bit != 0 {
Expand All @@ -113,17 +116,18 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {

}

return value
return false, value
}

// MostSpecificContains finds the most specific match
func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
bit := startbit
node := tree.root

for node != nil {
if node.value != nil {
if node.hasValue {
value = node.value
ok = true
}

if ip&bit != 0 {
Expand All @@ -135,11 +139,12 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
bit >>= 1
}

return value
return ok, value
}

// Match finds the most specific match
func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
// TODO this is exact match
func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
bit := startbit
node := tree.root
lastNode := node
Expand All @@ -157,11 +162,12 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {

if bit == 0 && lastNode != nil {
value = lastNode.value
ok = true
}
return value
return ok, value
}

// List will return all CIDRs and their current values. Do not modify the contents!
func (tree *Tree4) List() []entry {
func (tree *Tree4[T]) List() []entry[T] {
return tree.list
}
Loading