diff --git a/android/src/main/java/com/tailscale/ipn/ui/view/ExitNodePicker.kt b/android/src/main/java/com/tailscale/ipn/ui/view/ExitNodePicker.kt index 7e8b02f008..6a9396f008 100644 --- a/android/src/main/java/com/tailscale/ipn/ui/view/ExitNodePicker.kt +++ b/android/src/main/java/com/tailscale/ipn/ui/view/ExitNodePicker.kt @@ -3,7 +3,6 @@ package com.tailscale.ipn.ui.view -import android.os.Build import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Row diff --git a/libtailscale/ranges_calc/ranges_calc.go b/libtailscale/ranges_calc/ranges_calc.go index ab04cbdcfa..dc7b7d728f 100644 --- a/libtailscale/ranges_calc/ranges_calc.go +++ b/libtailscale/ranges_calc/ranges_calc.go @@ -1,68 +1,71 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + package ranges_calc import ( - "fmt" - "math/big" - "net/netip" - "sort" + "fmt" + "math/big" + "net/netip" + "sort" ) // Internal representation of an IP range [Start, End] (inclusive) type ipRange struct { - Start netip.Addr - End netip.Addr + Start netip.Addr + End netip.Addr } // space describes the address space (32 for IPv4, 128 for IPv6) type space struct { - bits uint + bits uint } // ---------- netip.Addr <-> big.Int ---------- func (s space) addrToInt(a netip.Addr) *big.Int { - if s.bits == 32 { - b := a.As4() - return new(big.Int).SetBytes(b[:]) - } - b := a.As16() - return new(big.Int).SetBytes(b[:]) + if s.bits == 32 { + b := a.As4() + return new(big.Int).SetBytes(b[:]) + } + b := a.As16() + return new(big.Int).SetBytes(b[:]) } func (s space) intToAddr(i *big.Int) netip.Addr { - b := i.FillBytes(make([]byte, s.bits/8)) - if s.bits == 32 { - var a [4]byte - copy(a[:], b) - return netip.AddrFrom4(a) - } - var a [16]byte - copy(a[:], b) - return netip.AddrFrom16(a) + b := i.FillBytes(make([]byte, s.bits/8)) + if s.bits == 32 { + var a [4]byte + copy(a[:], b) + return netip.AddrFrom4(a) + } + var a [16]byte + copy(a[:], b) + return netip.AddrFrom16(a) } // ---------- merge overlapping ranges ---------- func (s space) mergeRanges(ranges []ipRange) []ipRange { - if len(ranges) == 0 { - return nil - } - sort.Slice(ranges, func(i, j int) bool { - return ranges[i].Start.Compare(ranges[j].Start) < 0 - }) - merged := []ipRange{ranges[0]} - one := big.NewInt(1) - for _, r := range ranges[1:] { - last := &merged[len(merged)-1] - lastEnd := s.addrToInt(last.End) - curStart := s.addrToInt(r.Start) - if curStart.Cmp(new(big.Int).Add(lastEnd, one)) <= 0 { - if r.End.Compare(last.End) > 0 { - last.End = r.End - } - } else { - merged = append(merged, r) - } - } - return merged + if len(ranges) == 0 { + return nil + } + sort.Slice(ranges, func(i, j int) bool { + return ranges[i].Start.Compare(ranges[j].Start) < 0 + }) + merged := []ipRange{ranges[0]} + one := big.NewInt(1) + for _, r := range ranges[1:] { + last := &merged[len(merged)-1] + lastEnd := s.addrToInt(last.End) + curStart := s.addrToInt(r.Start) + if curStart.Cmp(new(big.Int).Add(lastEnd, one)) <= 0 { + if r.End.Compare(last.End) > 0 { + last.End = r.End + } + } else { + merged = append(merged, r) + } + } + return merged } // ---------- range -> minimal number of CIDRs ---------- @@ -70,125 +73,125 @@ func (s space) mergeRanges(ranges []ipRange) []ipRange { // by one or more CIDR prefixes. This function calculates the minimal set of CIDR // prefixes that cover the given range. func (s space) rangeToCIDRs(r ipRange) []netip.Prefix { - var result []netip.Prefix - cur := s.addrToInt(r.Start) - last := s.addrToInt(r.End) - one := big.NewInt(1) - - for cur.Cmp(last) <= 0 { - // Find the largest power-of-2 block starting at cur - var maxSize uint - for size := uint(0); size <= s.bits; size++ { - block := new(big.Int).Lsh(one, size) - if new(big.Int).And(cur, new(big.Int).Sub(block, one)).Cmp(big.NewInt(0)) != 0 { - break - } - maxSize = size - } - - // Shrink maxSize if it would go past last - for { - block := new(big.Int).Lsh(one, maxSize) - lastAddr := new(big.Int).Add(cur, new(big.Int).Sub(block, one)) - if lastAddr.Cmp(last) <= 0 { - break - } - if maxSize == 0 { - break - } - maxSize-- - } - - prefixLen := int(s.bits - maxSize) - result = append(result, netip.PrefixFrom(s.intToAddr(cur), prefixLen)) - cur = cur.Add(cur, new(big.Int).Lsh(one, maxSize)) - } - - return result + var result []netip.Prefix + cur := s.addrToInt(r.Start) + last := s.addrToInt(r.End) + one := big.NewInt(1) + + for cur.Cmp(last) <= 0 { + // Find the largest power-of-2 block starting at cur + var maxSize uint + for size := uint(0); size <= s.bits; size++ { + block := new(big.Int).Lsh(one, size) + if new(big.Int).And(cur, new(big.Int).Sub(block, one)).Cmp(big.NewInt(0)) != 0 { + break + } + maxSize = size + } + + // Shrink maxSize if it would go past last + for { + block := new(big.Int).Lsh(one, maxSize) + lastAddr := new(big.Int).Add(cur, new(big.Int).Sub(block, one)) + if lastAddr.Cmp(last) <= 0 { + break + } + if maxSize == 0 { + break + } + maxSize-- + } + + prefixLen := int(s.bits - maxSize) + result = append(result, netip.PrefixFrom(s.intToAddr(cur), prefixLen)) + cur = cur.Add(cur, new(big.Int).Lsh(one, maxSize)) + } + + return result } // ---------- CIDR -> range ---------- // prefixToRange converts a netip.Prefix to an ipRange with Start and End addresses. // Start is the network address and End is the broadcast address. func (s space) prefixToRange(p netip.Prefix) ipRange { - p = p.Masked() - start := s.addrToInt(p.Addr()) - hostBits := int(s.bits) - p.Bits() - size := new(big.Int).Lsh(big.NewInt(1), uint(hostBits)) - size.Sub(size, big.NewInt(1)) - end := new(big.Int).Add(start, size) - return ipRange{Start: p.Addr(), End: s.intToAddr(end)} + p = p.Masked() + start := s.addrToInt(p.Addr()) + hostBits := int(s.bits) - p.Bits() + size := new(big.Int).Lsh(big.NewInt(1), uint(hostBits)) + size.Sub(size, big.NewInt(1)) + end := new(big.Int).Add(start, size) + return ipRange{Start: p.Addr(), End: s.intToAddr(end)} } // ---------- helper: subtract disallowed from allowed ---------- func (s space) subtractRanges(allowed []ipRange, disallowed []ipRange) []ipRange { - if len(allowed) == 0 { - return nil - } - if len(disallowed) == 0 { - return allowed - } - - var result []ipRange - for _, a := range allowed { - cur := []ipRange{a} - for _, d := range disallowed { - cur2 := []ipRange{} - for _, r := range cur { - cur2 = append(cur2, s.subtractOneRange(r, d)...) - } - cur = cur2 - if len(cur) == 0 { - break - } - } - result = append(result, cur...) - } - return s.mergeRanges(result) + if len(allowed) == 0 { + return nil + } + if len(disallowed) == 0 { + return allowed + } + + var result []ipRange + for _, a := range allowed { + cur := []ipRange{a} + for _, d := range disallowed { + cur2 := []ipRange{} + for _, r := range cur { + cur2 = append(cur2, s.subtractOneRange(r, d)...) + } + cur = cur2 + if len(cur) == 0 { + break + } + } + result = append(result, cur...) + } + return s.mergeRanges(result) } // subtractOneRange subtracts a single disallowed range from a single allowed range func (s space) subtractOneRange(allowed ipRange, disallowed ipRange) []ipRange { - aStart := s.addrToInt(allowed.Start) - aEnd := s.addrToInt(allowed.End) - dStart := s.addrToInt(disallowed.Start) - dEnd := s.addrToInt(disallowed.End) - one := big.NewInt(1) - - // No overlap - if aEnd.Cmp(dStart) < 0 || aStart.Cmp(dEnd) > 0 { - return []ipRange{allowed} - } - - var result []ipRange - - // left side - if aStart.Cmp(dStart) < 0 { - result = append(result, ipRange{ - Start: allowed.Start, - End: s.intToAddr(new(big.Int).Sub(dStart, one)), - }) - } - - // right side - if aEnd.Cmp(dEnd) > 0 { - result = append(result, ipRange{ - Start: s.intToAddr(new(big.Int).Add(dEnd, one)), - End: allowed.End, - }) - } - - return result + aStart := s.addrToInt(allowed.Start) + aEnd := s.addrToInt(allowed.End) + dStart := s.addrToInt(disallowed.Start) + dEnd := s.addrToInt(disallowed.End) + one := big.NewInt(1) + + // No overlap + if aEnd.Cmp(dStart) < 0 || aStart.Cmp(dEnd) > 0 { + return []ipRange{allowed} + } + + var result []ipRange + + // left side + if aStart.Cmp(dStart) < 0 { + result = append(result, ipRange{ + Start: allowed.Start, + End: s.intToAddr(new(big.Int).Sub(dStart, one)), + }) + } + + // right side + if aEnd.Cmp(dEnd) > 0 { + result = append(result, ipRange{ + Start: s.intToAddr(new(big.Int).Add(dEnd, one)), + End: allowed.End, + }) + } + + return result } // rangesCalc performs the calculation: Routes (allowed) minus LocalRoutes (disallowed) type rangesCalc struct { - allowed []netip.Prefix - disallowed []netip.Prefix + allowed []netip.Prefix + disallowed []netip.Prefix } func newRangesCalc(routes, localRoutes []netip.Prefix) *rangesCalc { - return &rangesCalc{allowed: routes, disallowed: localRoutes} + return &rangesCalc{allowed: routes, disallowed: localRoutes} } const maxCalculatedRoutes = 500 @@ -197,80 +200,80 @@ const maxCalculatedRoutes = 500 // separate IPv4 and IPv6 prefix lists. If the resulting route set exceeds // a conservative cap, an error is returned so the caller can fail fast. func (rc *rangesCalc) calculate() (ipv4 []netip.Prefix, ipv6 []netip.Prefix, err error) { - var out4 []netip.Prefix - var out6 []netip.Prefix - - // Collect IPv4 and IPv6 separately - var allowed4 []ipRange - var disallowed4 []ipRange - var allowed6 []ipRange - var disallowed6 []ipRange - - for _, p := range rc.allowed { - if p.Addr().Is4() { - s := space{bits: 32} - r := s.prefixToRange(p) - allowed4 = append(allowed4, r) - } else { - s := space{bits: 128} - r := s.prefixToRange(p) - allowed6 = append(allowed6, r) - } - } - - for _, p := range rc.disallowed { - // Skip loopback prefixes; mirror behavior of ExcludeRoutes handling. - if p.Addr().IsLoopback() { - continue - } - if p.Addr().Is4() { - s := space{bits: 32} - r := s.prefixToRange(p) - disallowed4 = append(disallowed4, r) - } else { - s := space{bits: 128} - r := s.prefixToRange(p) - disallowed6 = append(disallowed6, r) - } - } - - // Process IPv4 - if len(allowed4) > 0 { - s := space{bits: 32} - mergedAllowed := s.mergeRanges(allowed4) - mergedDisallowed := s.mergeRanges(disallowed4) - finalAllowed := s.subtractRanges(mergedAllowed, mergedDisallowed) - for _, r := range finalAllowed { - for _, pref := range s.rangeToCIDRs(r) { - out4 = append(out4, pref) - } - } - } - - // Process IPv6 - if len(allowed6) > 0 { - s := space{bits: 128} - mergedAllowed := s.mergeRanges(allowed6) - mergedDisallowed := s.mergeRanges(disallowed6) - finalAllowed := s.subtractRanges(mergedAllowed, mergedDisallowed) - for _, r := range finalAllowed { - for _, pref := range s.rangeToCIDRs(r) { - out6 = append(out6, pref) - } - } - } - - total := len(out4) + len(out6) - if total > maxCalculatedRoutes { - return nil, nil, fmt.Errorf("calculated routes (%d) exceed cap (%d)", total, maxCalculatedRoutes) - } - - return out4, out6, nil + var out4 []netip.Prefix + var out6 []netip.Prefix + + // Collect IPv4 and IPv6 separately + var allowed4 []ipRange + var disallowed4 []ipRange + var allowed6 []ipRange + var disallowed6 []ipRange + + for _, p := range rc.allowed { + if p.Addr().Is4() { + s := space{bits: 32} + r := s.prefixToRange(p) + allowed4 = append(allowed4, r) + } else { + s := space{bits: 128} + r := s.prefixToRange(p) + allowed6 = append(allowed6, r) + } + } + + for _, p := range rc.disallowed { + // Skip loopback prefixes; mirror behavior of ExcludeRoutes handling. + if p.Addr().IsLoopback() { + continue + } + if p.Addr().Is4() { + s := space{bits: 32} + r := s.prefixToRange(p) + disallowed4 = append(disallowed4, r) + } else { + s := space{bits: 128} + r := s.prefixToRange(p) + disallowed6 = append(disallowed6, r) + } + } + + // Process IPv4 + if len(allowed4) > 0 { + s := space{bits: 32} + mergedAllowed := s.mergeRanges(allowed4) + mergedDisallowed := s.mergeRanges(disallowed4) + finalAllowed := s.subtractRanges(mergedAllowed, mergedDisallowed) + for _, r := range finalAllowed { + for _, pref := range s.rangeToCIDRs(r) { + out4 = append(out4, pref) + } + } + } + + // Process IPv6 + if len(allowed6) > 0 { + s := space{bits: 128} + mergedAllowed := s.mergeRanges(allowed6) + mergedDisallowed := s.mergeRanges(disallowed6) + finalAllowed := s.subtractRanges(mergedAllowed, mergedDisallowed) + for _, r := range finalAllowed { + for _, pref := range s.rangeToCIDRs(r) { + out6 = append(out6, pref) + } + } + } + + total := len(out4) + len(out6) + if total > maxCalculatedRoutes { + return nil, nil, fmt.Errorf("calculated routes (%d) exceed cap (%d)", total, maxCalculatedRoutes) + } + + return out4, out6, nil } // Calculate is the exported helper that computes effective allowed prefixes // given allowed routes and localRoutes to exclude. func Calculate(routes, localRoutes []netip.Prefix) (ipv4 []netip.Prefix, ipv6 []netip.Prefix, err error) { - rc := newRangesCalc(routes, localRoutes) - return rc.calculate() + rc := newRangesCalc(routes, localRoutes) + return rc.calculate() } diff --git a/libtailscale/ranges_calc/ranges_calc_test.go b/libtailscale/ranges_calc/ranges_calc_test.go index 301788c003..a5e5b8a0ad 100644 --- a/libtailscale/ranges_calc/ranges_calc_test.go +++ b/libtailscale/ranges_calc/ranges_calc_test.go @@ -1,72 +1,75 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + package ranges_calc import ( - "fmt" - "net/netip" - "testing" + "fmt" + "net/netip" + "testing" ) func TestCalculate_NoDisallowed(t *testing.T) { - allowed := []netip.Prefix{} - p, _ := netip.ParsePrefix("10.0.0.0/8") - allowed = append(allowed, p) + allowed := []netip.Prefix{} + p, _ := netip.ParsePrefix("10.0.0.0/8") + allowed = append(allowed, p) - v4, v6, err := Calculate(allowed, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if len(v6) != 0 { - t.Fatalf("expected no IPv6 prefixes, got %d", len(v6)) - } - if len(v4) == 0 { - t.Fatalf("expected some IPv4 prefixes, got none") - } + v4, v6, err := Calculate(allowed, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(v6) != 0 { + t.Fatalf("expected no IPv6 prefixes, got %d", len(v6)) + } + if len(v4) == 0 { + t.Fatalf("expected some IPv4 prefixes, got none") + } } func TestCalculate_LoopbackIgnored(t *testing.T) { - allowed := []netip.Prefix{} - a, _ := netip.ParsePrefix("127.0.0.0/8") - allowed = append(allowed, a) + allowed := []netip.Prefix{} + a, _ := netip.ParsePrefix("127.0.0.0/8") + allowed = append(allowed, a) - // disallowed contains a loopback address which should be ignored. - d := []netip.Prefix{} - lp, _ := netip.ParsePrefix("127.0.0.1/32") - d = append(d, lp) + // disallowed contains a loopback address which should be ignored. + d := []netip.Prefix{} + lp, _ := netip.ParsePrefix("127.0.0.1/32") + d = append(d, lp) - v4a, _, err := Calculate(allowed, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + v4a, _, err := Calculate(allowed, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - v4b, _, err := Calculate(allowed, d) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + v4b, _, err := Calculate(allowed, d) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - // Results should be identical because loopback in disallowed is skipped. - if len(v4a) != len(v4b) { - t.Fatalf("loopback disallowed altered result: before=%d after=%d", len(v4a), len(v4b)) - } + // Results should be identical because loopback in disallowed is skipped. + if len(v4a) != len(v4b) { + t.Fatalf("loopback disallowed altered result: before=%d after=%d", len(v4a), len(v4b)) + } } func TestCalculate_CapExceeded(t *testing.T) { - // Create more than maxCalculatedRoutes separate /32 prefixes. - want := maxCalculatedRoutes + 1 - allowed := make([]netip.Prefix, 0, want) - for i := 0; i < want; i++ { - // Generate addresses 10.X.Y.1 where X = i/256, Y = i%256 - x := (i / 256) % 256 - y := i % 256 - s := fmt.Sprintf("10.%d.%d.1/32", x, y) - p, err := netip.ParsePrefix(s) - if err != nil { - t.Fatalf("parse prefix %q: %v", s, err) - } - allowed = append(allowed, p) - } + // Create more than maxCalculatedRoutes separate /32 prefixes. + want := maxCalculatedRoutes + 1 + allowed := make([]netip.Prefix, 0, want) + for i := 0; i < want; i++ { + // Generate addresses 10.X.Y.1 where X = i/256, Y = i%256 + x := (i / 256) % 256 + y := i % 256 + s := fmt.Sprintf("10.%d.%d.1/32", x, y) + p, err := netip.ParsePrefix(s) + if err != nil { + t.Fatalf("parse prefix %q: %v", s, err) + } + allowed = append(allowed, p) + } - _, _, err := Calculate(allowed, nil) - if err == nil { - t.Fatalf("expected error when exceeding cap (%d), got nil", maxCalculatedRoutes) - } + _, _, err := Calculate(allowed, nil) + if err == nil { + t.Fatalf("expected error when exceeding cap (%d), got nil", maxCalculatedRoutes) + } }