Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

util/linuxfw: fix IPv6 availability check for nftables #12009

Merged
merged 3 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion util/linuxfw/iptables_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func newIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
}
supportsV6Filter = checkSupportsV6Filter(ipt6, logf)
supportsV6NAT = checkSupportsV6NAT(ipt6, logf)
logf("v6 = %v, v6filter = %v, v6nat = %v", supportsV6, supportsV6Filter, supportsV6NAT)
logf("netfilter running in iptables mode v6 = %v, v6filter = %v, v6nat = %v", supportsV6, supportsV6Filter, supportsV6NAT)
}
return &iptablesRunner{
ipt4: ipt4,
Expand Down
7 changes: 7 additions & 0 deletions util/linuxfw/linuxfw.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,19 @@ func getTailscaleSubnetRouteMark() []byte {
return []byte{0x00, 0x04, 0x00, 0x00}
}

// checkIPv6ForTest can be set in tests.
var checkIPv6ForTest func(logger.Logf) error

// checkIPv6 checks whether the system appears to have a working IPv6
// network stack. It returns an error explaining what looks wrong or
// missing. It does not check that IPv6 is currently functional or
// that there's a global address, just that the system would support
// IPv6 if it were on an IPv6 network.
func CheckIPv6(logf logger.Logf) error {
if f := checkIPv6ForTest; f != nil {
return f(logf)
}

_, err := os.Stat("/proc/sys/net/ipv6")
if os.IsNotExist(err) {
return err
Expand Down
49 changes: 18 additions & 31 deletions util/linuxfw/nftables_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ type chainInfo struct {
chainPolicy *nftables.ChainPolicy
}

// nftable contains nat and filter tables for the given IP family (Proto).
type nftable struct {
Proto nftables.TableFamily
Proto nftables.TableFamily // IPv4 or IPv6
Filter *nftables.Table
Nat *nftables.Table
}
Expand All @@ -69,11 +70,10 @@ type nftable struct {
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains
type nftablesRunner struct {
conn *nftables.Conn
nft4 *nftable
nft6 *nftable
nft4 *nftable // IPv4 tables
nft6 *nftable // IPv6 tables

v6Available bool
v6NATAvailable bool
v6Available bool // whether the host supports IPv6
}

func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) {
Expand Down Expand Up @@ -590,6 +590,10 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
if err != nil {
return nil, fmt.Errorf("nftables connection: %w", err)
}
return newNfTablesRunnerWithConn(logf, conn), nil
}

func newNfTablesRunnerWithConn(logf logger.Logf, conn *nftables.Conn) *nftablesRunner {
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}

v6err := CheckIPv6(logf)
Expand All @@ -601,8 +605,8 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {

if supportsV6 {
nft6 = &nftable{Proto: nftables.TableFamilyIPv6}
logf("v6nat availability: true")
}
logf("netfilter running in nftables mode, v6 = %v", supportsV6)

// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables

Expand All @@ -611,7 +615,7 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
nft4: nft4,
nft6: nft6,
v6Available: supportsV6,
}, nil
}
}

// newLoadSaddrExpr creates a new nftables expression that loads the source
Expand Down Expand Up @@ -829,24 +833,15 @@ func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
return n.conn.Flush()
}

// getTables gets the available nftable in nftables runner.
// getTables returns tables for IP families that this host was determined to
// support (either IPv4 and IPv6 or just IPv4).
func (n *nftablesRunner) getTables() []*nftable {
if n.v6Available {
if n.HasIPV6() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency, does not change anything functionally

return []*nftable{n.nft4, n.nft6}
}
return []*nftable{n.nft4}
}

// getNATTables gets the available nftable in nftables runner.
// If the system does not support IPv6 NAT, only the IPv4 nftable
// will be returned.
func (n *nftablesRunner) getNATTables() []*nftable {
if n.v6NATAvailable {
return n.getTables()
}
return []*nftable{n.nft4}
}

// AddChains creates custom Tailscale chains in netfilter via nftables
// if the ts-chain doesn't already exist.
func (n *nftablesRunner) AddChains() error {
Expand Down Expand Up @@ -875,9 +870,7 @@ func (n *nftablesRunner) AddChains() error {
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
return fmt.Errorf("create input chain: %w", err)
}
}

for _, table := range n.getNATTables() {
// Create the nat table if it doesn't exist, this table name is the same
// as the name used by iptables-nft and ufw. We install rules into the
// same conventional table so that `accept` verdicts from our jump
Expand Down Expand Up @@ -915,7 +908,7 @@ const (
// can be used. It cleans up the dummy chains after creation.
func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
polAccept := ptr.To(nftables.ChainPolicyAccept)
for _, table := range n.getNATTables() {
for _, table := range n.getTables() {
nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
if err != nil {
return fmt.Errorf("create nat table: %w", err)
Expand Down Expand Up @@ -972,7 +965,7 @@ func (n *nftablesRunner) DelChains() error {
return fmt.Errorf("delete chain: %w", err)
}

if n.v6NATAvailable {
if n.HasIPV6NAT() {
if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
return fmt.Errorf("delete chain: %w", err)
}
Expand Down Expand Up @@ -1038,9 +1031,7 @@ func (n *nftablesRunner) AddHooks() error {
if err != nil {
return fmt.Errorf("Addhook: %w", err)
}
}

for _, table := range n.getNATTables() {
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
if err != nil {
return fmt.Errorf("get INPUT chain: %w", err)
Expand Down Expand Up @@ -1094,9 +1085,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
if err != nil {
return fmt.Errorf("delhook: %w", err)
}
}

for _, table := range n.getNATTables() {
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
if err != nil {
return fmt.Errorf("get INPUT chain: %w", err)
Expand Down Expand Up @@ -1604,9 +1593,7 @@ func (n *nftablesRunner) DelBase() error {
return fmt.Errorf("get forward chain: %v", err)
}
conn.FlushChain(forwardChain)
}

for _, table := range n.getNATTables() {
postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil {
return fmt.Errorf("get postrouting chain v4: %v", err)
Expand Down Expand Up @@ -1676,7 +1663,7 @@ func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, cha
func (n *nftablesRunner) AddSNATRule() error {
conn := n.conn

for _, table := range n.getNATTables() {
for _, table := range n.getTables() {
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil {
return fmt.Errorf("get postrouting chain v4: %w", err)
Expand Down Expand Up @@ -1719,7 +1706,7 @@ func (n *nftablesRunner) DelSNATRule() error {
&expr.Masq{},
}

for _, table := range n.getNATTables() {
for _, table := range n.getTables() {
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil {
return fmt.Errorf("get postrouting chain v4: %w", err)
Expand Down
127 changes: 81 additions & 46 deletions util/linuxfw/nftables_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"github.com/mdlayher/netlink"
"github.com/vishvananda/netns"
"tailscale.com/net/tsaddr"
"tailscale.com/tstest"
"tailscale.com/types/logger"
)

// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
Expand Down Expand Up @@ -503,19 +505,6 @@ func cleanupSysConn(t *testing.T, ns netns.NsHandle) {
}
}

func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner {
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
nft6 := &nftable{Proto: nftables.TableFamilyIPv6}

return &nftablesRunner{
conn: conn,
nft4: nft4,
nft6: nft6,
v6Available: true,
v6NATAvailable: true,
}
}

func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) {
t.Helper()
got, err := conn.ListChainsOfTableFamily(fam)
Expand All @@ -526,42 +515,76 @@ func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wa
t.Fatalf("len(got) = %d, want %d", len(got), wantCount)
}
}

func TestAddAndDelNetfilterChains(t *testing.T) {
conn := newSysConn(t)
checkChains(t, conn, nftables.TableFamilyIPv4, 0)
checkChains(t, conn, nftables.TableFamilyIPv6, 0)

runner := newFakeNftablesRunner(t, conn)
if err := runner.AddChains(); err != nil {
t.Fatalf("runner.AddChains() failed: %v", err)
}

tables, err := conn.ListTables()
func checkTables(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) {
t.Helper()
got, err := conn.ListTablesOfFamily(fam)
if err != nil {
t.Fatalf("conn.ListTables() failed: %v", err)
t.Fatalf("conn.ListTablesOfFamily(%v) failed: %v", fam, err)
}

if len(tables) != 4 {
t.Fatalf("len(tables) = %d, want 4", len(tables))
if len(got) != wantCount {
t.Fatalf("len(got) = %d, want %d", len(got), wantCount)
}
}

checkChains(t, conn, nftables.TableFamilyIPv4, 6)
checkChains(t, conn, nftables.TableFamilyIPv6, 6)
func TestAddAndDelNetfilterChains(t *testing.T) {
type test struct {
hostHasIPv6 bool
initIPv4ChainCount int
initIPv6ChainCount int
ipv4TableCount int
ipv6TableCount int
ipv4ChainCount int
ipv6ChainCount int
ipv4ChainCountPostDelete int
ipv6ChainCountPostDelete int
}
tests := []test{
{
hostHasIPv6: true,
initIPv4ChainCount: 0,
initIPv6ChainCount: 0,
ipv4TableCount: 2,
ipv6TableCount: 2,
ipv4ChainCount: 6,
ipv6ChainCount: 6,
ipv4ChainCountPostDelete: 3,
ipv6ChainCountPostDelete: 3,
},
{ // host without IPv6 support
ipv4TableCount: 2,
ipv4ChainCount: 6,
ipv4ChainCountPostDelete: 3,
}}
for _, tt := range tests {
t.Logf("running a test case for IPv6 support: %v", tt.hostHasIPv6)
conn := newSysConn(t)
runner := newFakeNftablesRunnerWithConn(t, conn, tt.hostHasIPv6)

// Check that we start off with no chains.
checkChains(t, conn, nftables.TableFamilyIPv4, tt.initIPv4ChainCount)
checkChains(t, conn, nftables.TableFamilyIPv6, tt.initIPv6ChainCount)

runner.DelChains()
if err := runner.AddChains(); err != nil {
t.Fatalf("runner.AddChains() failed: %v", err)
}

// The default chains should still be present.
checkChains(t, conn, nftables.TableFamilyIPv4, 3)
checkChains(t, conn, nftables.TableFamilyIPv6, 3)
// Check that the amount of tables for each IP family is as expected.
checkTables(t, conn, nftables.TableFamilyIPv4, tt.ipv4TableCount)
checkTables(t, conn, nftables.TableFamilyIPv6, tt.ipv6TableCount)

tables, err = conn.ListTables()
if err != nil {
t.Fatalf("conn.ListTables() failed: %v", err)
}
// Check that the amount of chains for each IP family is as expected.
checkChains(t, conn, nftables.TableFamilyIPv4, tt.ipv4ChainCount)
checkChains(t, conn, nftables.TableFamilyIPv6, tt.ipv6ChainCount)

if err := runner.DelChains(); err != nil {
t.Fatalf("runner.DelChains() failed: %v", err)
}

if len(tables) != 4 {
t.Fatalf("len(tables) = %d, want 4", len(tables))
// Test that the tables as well as the default chains are still present.
checkChains(t, conn, nftables.TableFamilyIPv4, tt.ipv4ChainCountPostDelete)
checkChains(t, conn, nftables.TableFamilyIPv6, tt.ipv6ChainCountPostDelete)
checkTables(t, conn, nftables.TableFamilyIPv4, tt.ipv4TableCount)
checkTables(t, conn, nftables.TableFamilyIPv6, tt.ipv6TableCount)
}
}

Expand Down Expand Up @@ -665,7 +688,8 @@ func checkChainRules(t *testing.T, conn *nftables.Conn, chain *nftables.Chain, w
func TestNFTAddAndDelNetfilterBase(t *testing.T) {
conn := newSysConn(t)

runner := newFakeNftablesRunner(t, conn)
runner := newFakeNftablesRunnerWithConn(t, conn, true)

if err := runner.AddChains(); err != nil {
t.Fatalf("AddChains() failed: %v", err)
}
Expand Down Expand Up @@ -759,7 +783,7 @@ func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nf
func TestNFTAddAndDelLoopbackRule(t *testing.T) {
conn := newSysConn(t)

runner := newFakeNftablesRunner(t, conn)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.AddChains(); err != nil {
t.Fatalf("AddChains() failed: %v", err)
}
Expand Down Expand Up @@ -817,7 +841,7 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {

func TestNFTAddAndDelHookRule(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.AddChains(); err != nil {
t.Fatalf("AddChains() failed: %v", err)
}
Expand Down Expand Up @@ -868,11 +892,11 @@ func (t *testFWDetector) nftDetect() (int, error) {
// postrouting chains are cleaned up.
func TestCreateDummyPostroutingChains(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.createDummyPostroutingChains(); err != nil {
t.Fatalf("createDummyPostroutingChains() failed: %v", err)
}
for _, table := range runner.getNATTables() {
for _, table := range runner.getTables() {
nt, err := getTableIfExists(conn, table.Proto, tsDummyTableName)
if err != nil {
t.Fatalf("getTableIfExists() failed: %v", err)
Expand Down Expand Up @@ -929,3 +953,14 @@ func TestPickFirewallModeFromInstalledRules(t *testing.T) {
})
}
}

func newFakeNftablesRunnerWithConn(t *testing.T, conn *nftables.Conn, hasIPv6 bool) *nftablesRunner {
t.Helper()
if !hasIPv6 {
tstest.Replace(t, &checkIPv6ForTest, func(logger.Logf) error {
return errors.New("test: no IPv6")
})

}
return newNfTablesRunnerWithConn(t.Logf, conn)
}
Loading