Skip to content

Commit

Permalink
util/linuxfw: fix IPv6 availability check for nftables (#12009) (#12123)
Browse files Browse the repository at this point in the history
* util/linuxfw: fix IPv6 NAT availability check for nftables

When running firewall in nftables mode,
there is no need for a separate NAT availability check
(unlike with iptables, there are no hosts that support nftables, but not IPv6 NAT - see #11353).
This change fixes a firewall NAT availability check that was using the no-longer set ipv6NATAvailable field
by removing the field and using a method that, for nftables, just checks that IPv6 is available.

Updates #12008

Signed-off-by: Irbe Krumina <irbe@tailscale.com>
(cherry picked from commit 7ef2f72)
  • Loading branch information
irbekrm authored May 14, 2024
1 parent 32cb8a3 commit 9d2768a
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 78 deletions.
2 changes: 1 addition & 1 deletion util/linuxfw/iptables_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func newIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
}
supportsV6Filter = checkSupportsV6Filter(ipt6, logf)
supportsV6NAT = checkSupportsV6NAT(ipt6, logf)
logf("v6 = %v, v6filter = %v, v6nat = %v", supportsV6, supportsV6Filter, supportsV6NAT)
logf("netfilter running in iptables mode v6 = %v, v6filter = %v, v6nat = %v", supportsV6, supportsV6Filter, supportsV6NAT)
}
return &iptablesRunner{
ipt4: ipt4,
Expand Down
7 changes: 7 additions & 0 deletions util/linuxfw/linuxfw.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,19 @@ func getTailscaleSubnetRouteMark() []byte {
return []byte{0x00, 0x04, 0x00, 0x00}
}

// checkIPv6ForTest can be set in tests.
var checkIPv6ForTest func(logger.Logf) error

// checkIPv6 checks whether the system appears to have a working IPv6
// network stack. It returns an error explaining what looks wrong or
// missing. It does not check that IPv6 is currently functional or
// that there's a global address, just that the system would support
// IPv6 if it were on an IPv6 network.
func CheckIPv6(logf logger.Logf) error {
if f := checkIPv6ForTest; f != nil {
return f(logf)
}

_, err := os.Stat("/proc/sys/net/ipv6")
if os.IsNotExist(err) {
return err
Expand Down
49 changes: 18 additions & 31 deletions util/linuxfw/nftables_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ type chainInfo struct {
chainPolicy *nftables.ChainPolicy
}

// nftable contains nat and filter tables for the given IP family (Proto).
type nftable struct {
Proto nftables.TableFamily
Proto nftables.TableFamily // IPv4 or IPv6
Filter *nftables.Table
Nat *nftables.Table
}
Expand All @@ -69,11 +70,10 @@ type nftable struct {
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains
type nftablesRunner struct {
conn *nftables.Conn
nft4 *nftable
nft6 *nftable
nft4 *nftable // IPv4 tables
nft6 *nftable // IPv6 tables

v6Available bool
v6NATAvailable bool
v6Available bool // whether the host supports IPv6
}

func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) {
Expand Down Expand Up @@ -598,6 +598,10 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
if err != nil {
return nil, fmt.Errorf("nftables connection: %w", err)
}
return newNfTablesRunnerWithConn(logf, conn), nil
}

func newNfTablesRunnerWithConn(logf logger.Logf, conn *nftables.Conn) *nftablesRunner {
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}

v6err := CheckIPv6(logf)
Expand All @@ -609,8 +613,8 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {

if supportsV6 {
nft6 = &nftable{Proto: nftables.TableFamilyIPv6}
logf("v6nat availability: true")
}
logf("netfilter running in nftables mode, v6 = %v", supportsV6)

// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables

Expand All @@ -619,7 +623,7 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
nft4: nft4,
nft6: nft6,
v6Available: supportsV6,
}, nil
}
}

// newLoadSaddrExpr creates a new nftables expression that loads the source
Expand Down Expand Up @@ -837,24 +841,15 @@ func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
return n.conn.Flush()
}

// getTables gets the available nftable in nftables runner.
// getTables returns tables for IP families that this host was determined to
// support (either IPv4 and IPv6 or just IPv4).
func (n *nftablesRunner) getTables() []*nftable {
if n.v6Available {
if n.HasIPV6() {
return []*nftable{n.nft4, n.nft6}
}
return []*nftable{n.nft4}
}

// getNATTables gets the available nftable in nftables runner.
// If the system does not support IPv6 NAT, only the IPv4 nftable
// will be returned.
func (n *nftablesRunner) getNATTables() []*nftable {
if n.v6NATAvailable {
return n.getTables()
}
return []*nftable{n.nft4}
}

// AddChains creates custom Tailscale chains in netfilter via nftables
// if the ts-chain doesn't already exist.
func (n *nftablesRunner) AddChains() error {
Expand Down Expand Up @@ -883,9 +878,7 @@ func (n *nftablesRunner) AddChains() error {
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
return fmt.Errorf("create input chain: %w", err)
}
}

for _, table := range n.getNATTables() {
// Create the nat table if it doesn't exist, this table name is the same
// as the name used by iptables-nft and ufw. We install rules into the
// same conventional table so that `accept` verdicts from our jump
Expand Down Expand Up @@ -923,7 +916,7 @@ const (
// can be used. It cleans up the dummy chains after creation.
func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
polAccept := ptr.To(nftables.ChainPolicyAccept)
for _, table := range n.getNATTables() {
for _, table := range n.getTables() {
nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
if err != nil {
return fmt.Errorf("create nat table: %w", err)
Expand Down Expand Up @@ -980,7 +973,7 @@ func (n *nftablesRunner) DelChains() error {
return fmt.Errorf("delete chain: %w", err)
}

if n.v6NATAvailable {
if n.HasIPV6NAT() {
if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
return fmt.Errorf("delete chain: %w", err)
}
Expand Down Expand Up @@ -1046,9 +1039,7 @@ func (n *nftablesRunner) AddHooks() error {
if err != nil {
return fmt.Errorf("Addhook: %w", err)
}
}

for _, table := range n.getNATTables() {
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
if err != nil {
return fmt.Errorf("get INPUT chain: %w", err)
Expand Down Expand Up @@ -1102,9 +1093,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
if err != nil {
return fmt.Errorf("delhook: %w", err)
}
}

for _, table := range n.getNATTables() {
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
if err != nil {
return fmt.Errorf("get INPUT chain: %w", err)
Expand Down Expand Up @@ -1612,9 +1601,7 @@ func (n *nftablesRunner) DelBase() error {
return fmt.Errorf("get forward chain: %v", err)
}
conn.FlushChain(forwardChain)
}

for _, table := range n.getNATTables() {
postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil {
return fmt.Errorf("get postrouting chain v4: %v", err)
Expand Down Expand Up @@ -1684,7 +1671,7 @@ func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, cha
func (n *nftablesRunner) AddSNATRule() error {
conn := n.conn

for _, table := range n.getNATTables() {
for _, table := range n.getTables() {
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil {
return fmt.Errorf("get postrouting chain v4: %w", err)
Expand Down Expand Up @@ -1727,7 +1714,7 @@ func (n *nftablesRunner) DelSNATRule() error {
&expr.Masq{},
}

for _, table := range n.getNATTables() {
for _, table := range n.getTables() {
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil {
return fmt.Errorf("get postrouting chain v4: %w", err)
Expand Down
127 changes: 81 additions & 46 deletions util/linuxfw/nftables_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"github.com/mdlayher/netlink"
"github.com/vishvananda/netns"
"tailscale.com/net/tsaddr"
"tailscale.com/tstest"
"tailscale.com/types/logger"
)

// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
Expand Down Expand Up @@ -503,19 +505,6 @@ func cleanupSysConn(t *testing.T, ns netns.NsHandle) {
}
}

func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner {
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
nft6 := &nftable{Proto: nftables.TableFamilyIPv6}

return &nftablesRunner{
conn: conn,
nft4: nft4,
nft6: nft6,
v6Available: true,
v6NATAvailable: true,
}
}

func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) {
t.Helper()
got, err := conn.ListChainsOfTableFamily(fam)
Expand All @@ -526,42 +515,76 @@ func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wa
t.Fatalf("len(got) = %d, want %d", len(got), wantCount)
}
}

func TestAddAndDelNetfilterChains(t *testing.T) {
conn := newSysConn(t)
checkChains(t, conn, nftables.TableFamilyIPv4, 0)
checkChains(t, conn, nftables.TableFamilyIPv6, 0)

runner := newFakeNftablesRunner(t, conn)
if err := runner.AddChains(); err != nil {
t.Fatalf("runner.AddChains() failed: %v", err)
}

tables, err := conn.ListTables()
func checkTables(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) {
t.Helper()
got, err := conn.ListTablesOfFamily(fam)
if err != nil {
t.Fatalf("conn.ListTables() failed: %v", err)
t.Fatalf("conn.ListTablesOfFamily(%v) failed: %v", fam, err)
}

if len(tables) != 4 {
t.Fatalf("len(tables) = %d, want 4", len(tables))
if len(got) != wantCount {
t.Fatalf("len(got) = %d, want %d", len(got), wantCount)
}
}

checkChains(t, conn, nftables.TableFamilyIPv4, 6)
checkChains(t, conn, nftables.TableFamilyIPv6, 6)
func TestAddAndDelNetfilterChains(t *testing.T) {
type test struct {
hostHasIPv6 bool
initIPv4ChainCount int
initIPv6ChainCount int
ipv4TableCount int
ipv6TableCount int
ipv4ChainCount int
ipv6ChainCount int
ipv4ChainCountPostDelete int
ipv6ChainCountPostDelete int
}
tests := []test{
{
hostHasIPv6: true,
initIPv4ChainCount: 0,
initIPv6ChainCount: 0,
ipv4TableCount: 2,
ipv6TableCount: 2,
ipv4ChainCount: 6,
ipv6ChainCount: 6,
ipv4ChainCountPostDelete: 3,
ipv6ChainCountPostDelete: 3,
},
{ // host without IPv6 support
ipv4TableCount: 2,
ipv4ChainCount: 6,
ipv4ChainCountPostDelete: 3,
}}
for _, tt := range tests {
t.Logf("running a test case for IPv6 support: %v", tt.hostHasIPv6)
conn := newSysConn(t)
runner := newFakeNftablesRunnerWithConn(t, conn, tt.hostHasIPv6)

// Check that we start off with no chains.
checkChains(t, conn, nftables.TableFamilyIPv4, tt.initIPv4ChainCount)
checkChains(t, conn, nftables.TableFamilyIPv6, tt.initIPv6ChainCount)

runner.DelChains()
if err := runner.AddChains(); err != nil {
t.Fatalf("runner.AddChains() failed: %v", err)
}

// The default chains should still be present.
checkChains(t, conn, nftables.TableFamilyIPv4, 3)
checkChains(t, conn, nftables.TableFamilyIPv6, 3)
// Check that the amount of tables for each IP family is as expected.
checkTables(t, conn, nftables.TableFamilyIPv4, tt.ipv4TableCount)
checkTables(t, conn, nftables.TableFamilyIPv6, tt.ipv6TableCount)

tables, err = conn.ListTables()
if err != nil {
t.Fatalf("conn.ListTables() failed: %v", err)
}
// Check that the amount of chains for each IP family is as expected.
checkChains(t, conn, nftables.TableFamilyIPv4, tt.ipv4ChainCount)
checkChains(t, conn, nftables.TableFamilyIPv6, tt.ipv6ChainCount)

if err := runner.DelChains(); err != nil {
t.Fatalf("runner.DelChains() failed: %v", err)
}

if len(tables) != 4 {
t.Fatalf("len(tables) = %d, want 4", len(tables))
// Test that the tables as well as the default chains are still present.
checkChains(t, conn, nftables.TableFamilyIPv4, tt.ipv4ChainCountPostDelete)
checkChains(t, conn, nftables.TableFamilyIPv6, tt.ipv6ChainCountPostDelete)
checkTables(t, conn, nftables.TableFamilyIPv4, tt.ipv4TableCount)
checkTables(t, conn, nftables.TableFamilyIPv6, tt.ipv6TableCount)
}
}

Expand Down Expand Up @@ -665,7 +688,8 @@ func checkChainRules(t *testing.T, conn *nftables.Conn, chain *nftables.Chain, w
func TestNFTAddAndDelNetfilterBase(t *testing.T) {
conn := newSysConn(t)

runner := newFakeNftablesRunner(t, conn)
runner := newFakeNftablesRunnerWithConn(t, conn, true)

if err := runner.AddChains(); err != nil {
t.Fatalf("AddChains() failed: %v", err)
}
Expand Down Expand Up @@ -759,7 +783,7 @@ func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nf
func TestNFTAddAndDelLoopbackRule(t *testing.T) {
conn := newSysConn(t)

runner := newFakeNftablesRunner(t, conn)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.AddChains(); err != nil {
t.Fatalf("AddChains() failed: %v", err)
}
Expand Down Expand Up @@ -817,7 +841,7 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {

func TestNFTAddAndDelHookRule(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.AddChains(); err != nil {
t.Fatalf("AddChains() failed: %v", err)
}
Expand Down Expand Up @@ -868,11 +892,11 @@ func (t *testFWDetector) nftDetect() (int, error) {
// postrouting chains are cleaned up.
func TestCreateDummyPostroutingChains(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.createDummyPostroutingChains(); err != nil {
t.Fatalf("createDummyPostroutingChains() failed: %v", err)
}
for _, table := range runner.getNATTables() {
for _, table := range runner.getTables() {
nt, err := getTableIfExists(conn, table.Proto, tsDummyTableName)
if err != nil {
t.Fatalf("getTableIfExists() failed: %v", err)
Expand Down Expand Up @@ -929,3 +953,14 @@ func TestPickFirewallModeFromInstalledRules(t *testing.T) {
})
}
}

func newFakeNftablesRunnerWithConn(t *testing.T, conn *nftables.Conn, hasIPv6 bool) *nftablesRunner {
t.Helper()
if !hasIPv6 {
tstest.Replace(t, &checkIPv6ForTest, func(logger.Logf) error {
return errors.New("test: no IPv6")
})

}
return newNfTablesRunnerWithConn(t.Logf, conn)
}

0 comments on commit 9d2768a

Please sign in to comment.