Skip to content

Commit

Permalink
util/linuxfw: fix IPv6 NAT availability check for nftables
Browse files Browse the repository at this point in the history
When running firewall in nftables mode,
there is no need for a separate NAT availability check
(unlike with iptables, there are no hosts that support nftables, but not IPv6 NAT - see #11353).
This change fixes a firewall NAT availability check that was using the no-longer set ipv6NATAvailable field
by removing the field and using a method that, for nftables, just checks that IPv6 is available.

Updates #12008

Signed-off-by: Irbe Krumina <irbe@tailscale.com>
  • Loading branch information
irbekrm committed May 5, 2024
1 parent ed843e6 commit 1155420
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 34 deletions.
2 changes: 1 addition & 1 deletion util/linuxfw/iptables_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func newIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
}
supportsV6Filter = checkSupportsV6Filter(ipt6, logf)
supportsV6NAT = checkSupportsV6NAT(ipt6, logf)
logf("v6 = %v, v6filter = %v, v6nat = %v", supportsV6, supportsV6Filter, supportsV6NAT)
logf("netfilter running in iptables mode v6 = %v, v6filter = %v, v6nat = %v", supportsV6, supportsV6Filter, supportsV6NAT)
}
return &iptablesRunner{
ipt4: ipt4,
Expand Down
38 changes: 11 additions & 27 deletions util/linuxfw/nftables_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type chainInfo struct {
chainPolicy *nftables.ChainPolicy
}

// nftable knows how to run commands to create/delete nftables rules for a
// particular IP familiy (Proto).
type nftable struct {
Proto nftables.TableFamily
Filter *nftables.Table
Expand Down Expand Up @@ -72,8 +74,7 @@ type nftablesRunner struct {
nft4 *nftable
nft6 *nftable

v6Available bool
v6NATAvailable bool
v6Available bool
}

func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) {
Expand Down Expand Up @@ -601,8 +602,8 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {

if supportsV6 {
nft6 = &nftable{Proto: nftables.TableFamilyIPv6}
logf("v6nat availability: true")
}
logf("netfilter running in nftables mode, v6 = %v", supportsV6)

// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables

Expand Down Expand Up @@ -829,24 +830,15 @@ func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
return n.conn.Flush()
}

// getTables gets the available nftable in nftables runner.
// getTables returns a list of nftables runners for IP families that this host
// was determined to support (either IPv4 and IPv6 or just IPv4).
func (n *nftablesRunner) getTables() []*nftable {
if n.v6Available {
if n.HasIPV6() {
return []*nftable{n.nft4, n.nft6}
}
return []*nftable{n.nft4}
}

// getNATTables gets the available nftable in nftables runner.
// If the system does not support IPv6 NAT, only the IPv4 nftable
// will be returned.
func (n *nftablesRunner) getNATTables() []*nftable {
if n.v6NATAvailable {
return n.getTables()
}
return []*nftable{n.nft4}
}

// AddChains creates custom Tailscale chains in netfilter via nftables
// if the ts-chain doesn't already exist.
func (n *nftablesRunner) AddChains() error {
Expand Down Expand Up @@ -875,9 +867,7 @@ func (n *nftablesRunner) AddChains() error {
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
return fmt.Errorf("create input chain: %w", err)
}
}

for _, table := range n.getNATTables() {
// Create the nat table if it doesn't exist, this table name is the same
// as the name used by iptables-nft and ufw. We install rules into the
// same conventional table so that `accept` verdicts from our jump
Expand Down Expand Up @@ -915,7 +905,7 @@ const (
// can be used. It cleans up the dummy chains after creation.
func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
polAccept := ptr.To(nftables.ChainPolicyAccept)
for _, table := range n.getNATTables() {
for _, table := range n.getTables() {
nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
if err != nil {
return fmt.Errorf("create nat table: %w", err)
Expand Down Expand Up @@ -972,7 +962,7 @@ func (n *nftablesRunner) DelChains() error {
return fmt.Errorf("delete chain: %w", err)
}

if n.v6NATAvailable {
if n.HasIPV6NAT() {
if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
return fmt.Errorf("delete chain: %w", err)
}
Expand Down Expand Up @@ -1038,9 +1028,7 @@ func (n *nftablesRunner) AddHooks() error {
if err != nil {
return fmt.Errorf("Addhook: %w", err)
}
}

for _, table := range n.getNATTables() {
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
if err != nil {
return fmt.Errorf("get INPUT chain: %w", err)
Expand Down Expand Up @@ -1094,9 +1082,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
if err != nil {
return fmt.Errorf("delhook: %w", err)
}
}

for _, table := range n.getNATTables() {
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
if err != nil {
return fmt.Errorf("get INPUT chain: %w", err)
Expand Down Expand Up @@ -1604,9 +1590,7 @@ func (n *nftablesRunner) DelBase() error {
return fmt.Errorf("get forward chain: %v", err)
}
conn.FlushChain(forwardChain)
}

for _, table := range n.getNATTables() {
postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil {
return fmt.Errorf("get postrouting chain v4: %v", err)
Expand Down Expand Up @@ -1676,7 +1660,7 @@ func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, cha
func (n *nftablesRunner) AddSNATRule() error {
conn := n.conn

for _, table := range n.getNATTables() {
for _, table := range n.getTables() {
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil {
return fmt.Errorf("get postrouting chain v4: %w", err)
Expand Down Expand Up @@ -1719,7 +1703,7 @@ func (n *nftablesRunner) DelSNATRule() error {
&expr.Masq{},
}

for _, table := range n.getNATTables() {
for _, table := range n.getTables() {
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil {
return fmt.Errorf("get postrouting chain v4: %w", err)
Expand Down
11 changes: 5 additions & 6 deletions util/linuxfw/nftables_runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,11 +508,10 @@ func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner {
nft6 := &nftable{Proto: nftables.TableFamilyIPv6}

return &nftablesRunner{
conn: conn,
nft4: nft4,
nft6: nft6,
v6Available: true,
v6NATAvailable: true,
conn: conn,
nft4: nft4,
nft6: nft6,
v6Available: true,
}
}

Expand Down Expand Up @@ -872,7 +871,7 @@ func TestCreateDummyPostroutingChains(t *testing.T) {
if err := runner.createDummyPostroutingChains(); err != nil {
t.Fatalf("createDummyPostroutingChains() failed: %v", err)
}
for _, table := range runner.getNATTables() {
for _, table := range runner.getTables() {
nt, err := getTableIfExists(conn, table.Proto, tsDummyTableName)
if err != nil {
t.Fatalf("getTableIfExists() failed: %v", err)
Expand Down

0 comments on commit 1155420

Please sign in to comment.