diff --git a/Dockerfile b/Dockerfile index 7ed456b2b..76753ed3c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -96,6 +96,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ WIREGUARD_PRIVATE_KEY= \ WIREGUARD_PRESHARED_KEY= \ WIREGUARD_PUBLIC_KEY= \ + WIREGUARD_ALLOWED_IPS= \ WIREGUARD_ADDRESSES= \ WIREGUARD_MTU=1400 \ WIREGUARD_IMPLEMENTATION=auto \ diff --git a/internal/configuration/settings/errors.go b/internal/configuration/settings/errors.go index 3803fe7e0..c11b646a4 100644 --- a/internal/configuration/settings/errors.go +++ b/internal/configuration/settings/errors.go @@ -34,6 +34,8 @@ var ( ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small") ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid") ErrVPNTypeNotValid = errors.New("VPN type is not valid") + ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set") + ErrWireguardAllowedIPsNotSet = errors.New("allowed IPs is not set") ErrWireguardEndpointIPNotSet = errors.New("endpoint IP is not set") ErrWireguardEndpointPortNotAllowed = errors.New("endpoint port is not allowed") ErrWireguardEndpointPortNotSet = errors.New("endpoint port is not set") diff --git a/internal/configuration/settings/wireguard.go b/internal/configuration/settings/wireguard.go index e26fd1ed4..fdc99ea0d 100644 --- a/internal/configuration/settings/wireguard.go +++ b/internal/configuration/settings/wireguard.go @@ -25,6 +25,10 @@ type Wireguard struct { PreSharedKey *string `json:"pre_shared_key"` // Addresses are the Wireguard interface addresses. Addresses []netip.Prefix `json:"addresses"` + // AllowedIPs are the Wireguard allowed IPs. + // If left unset, they default to "0.0.0.0/0" + // and, if IPv6 is supported, "::0". + AllowedIPs []netip.Prefix `json:"allowed_ips"` // Interface is the name of the Wireguard interface // to create. It cannot be the empty string in the // internal state. @@ -89,13 +93,26 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error) } for i, ipNet := range w.Addresses { if !ipNet.IsValid() { - return fmt.Errorf("%w: for address at index %d: %s", - ErrWireguardInterfaceAddressNotSet, i, ipNet.String()) + return fmt.Errorf("%w: for address at index %d", + ErrWireguardInterfaceAddressNotSet, i) } if !ipv6Supported && ipNet.Addr().Is6() { return fmt.Errorf("%w: address %s", - ErrWireguardInterfaceAddressIPv6, ipNet) + ErrWireguardInterfaceAddressIPv6, ipNet.String()) + } + } + + // Validate AllowedIPs + // WARNING: do not check for IPv6 networks in the allowed IPs, + // the wireguard code will take care to ignore it. + if len(w.AllowedIPs) == 0 { + return fmt.Errorf("%w", ErrWireguardAllowedIPsNotSet) + } + for i, allowedIP := range w.AllowedIPs { + if !allowedIP.IsValid() { + return fmt.Errorf("%w: for allowed ip %d of %d", + ErrWireguardAllowedIPNotSet, i+1, len(w.AllowedIPs)) } } @@ -118,6 +135,7 @@ func (w *Wireguard) copy() (copied Wireguard) { PrivateKey: gosettings.CopyPointer(w.PrivateKey), PreSharedKey: gosettings.CopyPointer(w.PreSharedKey), Addresses: gosettings.CopySlice(w.Addresses), + AllowedIPs: gosettings.CopySlice(w.AllowedIPs), Interface: w.Interface, MTU: w.MTU, Implementation: w.Implementation, @@ -128,6 +146,7 @@ func (w *Wireguard) mergeWith(other Wireguard) { w.PrivateKey = gosettings.MergeWithPointer(w.PrivateKey, other.PrivateKey) w.PreSharedKey = gosettings.MergeWithPointer(w.PreSharedKey, other.PreSharedKey) w.Addresses = gosettings.MergeWithSlice(w.Addresses, other.Addresses) + w.AllowedIPs = gosettings.MergeWithSlice(w.AllowedIPs, other.AllowedIPs) w.Interface = gosettings.MergeWithString(w.Interface, other.Interface) w.MTU = gosettings.MergeWithNumber(w.MTU, other.MTU) w.Implementation = gosettings.MergeWithString(w.Implementation, other.Implementation) @@ -137,6 +156,7 @@ func (w *Wireguard) overrideWith(other Wireguard) { w.PrivateKey = gosettings.OverrideWithPointer(w.PrivateKey, other.PrivateKey) w.PreSharedKey = gosettings.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey) w.Addresses = gosettings.OverrideWithSlice(w.Addresses, other.Addresses) + w.AllowedIPs = gosettings.OverrideWithSlice(w.AllowedIPs, other.AllowedIPs) w.Interface = gosettings.OverrideWithString(w.Interface, other.Interface) w.MTU = gosettings.OverrideWithNumber(w.MTU, other.MTU) w.Implementation = gosettings.OverrideWithString(w.Implementation, other.Implementation) @@ -150,6 +170,11 @@ func (w *Wireguard) setDefaults(vpnProvider string) { defaultNordVPNPrefix := netip.PrefixFrom(defaultNordVPNAddress, defaultNordVPNAddress.BitLen()) w.Addresses = gosettings.DefaultSlice(w.Addresses, []netip.Prefix{defaultNordVPNPrefix}) } + defaultAllowedIPs := []netip.Prefix{ + netip.PrefixFrom(netip.IPv4Unspecified(), 0), + netip.PrefixFrom(netip.IPv6Unspecified(), 0), + } + w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs) w.Interface = gosettings.DefaultString(w.Interface, "wg0") const defaultMTU = 1400 w.MTU = gosettings.DefaultNumber(w.MTU, defaultMTU) @@ -178,6 +203,11 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) { addressesNode.Appendf(address.String()) } + allowedIPsNode := node.Appendf("Allowed IPs:") + for _, allowedIP := range w.AllowedIPs { + allowedIPsNode.Appendf(allowedIP.String()) + } + interfaceNode := node.Appendf("Network interface: %s", w.Interface) interfaceNode.Appendf("MTU: %d", w.MTU) diff --git a/internal/configuration/sources/env/wireguard.go b/internal/configuration/sources/env/wireguard.go index 8a5b2edc1..20694afa8 100644 --- a/internal/configuration/sources/env/wireguard.go +++ b/internal/configuration/sources/env/wireguard.go @@ -19,6 +19,10 @@ func (s *Source) readWireguard() (wireguard settings.Wireguard, err error) { if err != nil { return wireguard, err // already wrapped } + wireguard.AllowedIPs, err = s.env.CSVNetipPrefixes("WIREGUARD_ALLOWED_IPS") + if err != nil { + return wireguard, err // already wrapped + } mtuPtr, err := s.env.Uint16Ptr("WIREGUARD_MTU") if err != nil { return wireguard, err diff --git a/internal/provider/utils/wireguard.go b/internal/provider/utils/wireguard.go index 3a50ca765..6f1a1f3c6 100644 --- a/internal/provider/utils/wireguard.go +++ b/internal/provider/utils/wireguard.go @@ -32,5 +32,13 @@ func BuildWireguardSettings(connection models.Connection, settings.Addresses = append(settings.Addresses, addressCopy) } + settings.AllowedIPs = make([]netip.Prefix, 0, len(userSettings.AllowedIPs)) + for _, allowedIP := range userSettings.AllowedIPs { + if !ipv6Supported && allowedIP.Addr().Is6() { + continue + } + settings.AllowedIPs = append(settings.AllowedIPs, allowedIP) + } + return settings } diff --git a/internal/provider/utils/wireguard_test.go b/internal/provider/utils/wireguard_test.go index b6dfb8de0..7d0aa71c2 100644 --- a/internal/provider/utils/wireguard_test.go +++ b/internal/provider/utils/wireguard_test.go @@ -34,6 +34,10 @@ func Test_BuildWireguardSettings(t *testing.T) { netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32), netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32), }, + AllowedIPs: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), + netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32), + }, Interface: "wg1", }, ipv6Supported: false, @@ -46,6 +50,9 @@ func Test_BuildWireguardSettings(t *testing.T) { Addresses: []netip.Prefix{ netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32), }, + AllowedIPs: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), + }, RulePriority: 101, IPv6: boolPtr(false), }, diff --git a/internal/wireguard/constructor_test.go b/internal/wireguard/constructor_test.go index a56ab9035..28409cd02 100644 --- a/internal/wireguard/constructor_test.go +++ b/internal/wireguard/constructor_test.go @@ -48,6 +48,9 @@ func Test_New(t *testing.T) { Addresses: []netip.Prefix{ netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32), }, + AllowedIPs: []netip.Prefix{ + allIPv4(), + }, FirewallMark: 100, MTU: device.DefaultMTU, IPv6: ptr(false), diff --git a/internal/wireguard/route.go b/internal/wireguard/route.go index c133a9379..9fcfedae0 100644 --- a/internal/wireguard/route.go +++ b/internal/wireguard/route.go @@ -3,11 +3,30 @@ package wireguard import ( "fmt" "net/netip" + "strings" "github.com/qdm12/gluetun/internal/netlink" ) -// TODO add IPv6 route if IPv6 is supported +func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix, + firewallMark int) (err error) { + for _, dst := range destinations { + err = w.addRoute(link, dst, firewallMark) + if err == nil { + continue + } + + if dst.Addr().Is6() && strings.Contains(err.Error(), "permission denied") { + w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+ + "Ignoring and continuing execution; "+ + "Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+ + "Full error string: %s", err) + continue + } + return fmt.Errorf("adding route for destination %s: %w", dst, err) + } + return nil +} func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix, firewallMark int) (err error) { diff --git a/internal/wireguard/run.go b/internal/wireguard/run.go index 342cb2088..cecd7ecca 100644 --- a/internal/wireguard/run.go +++ b/internal/wireguard/run.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net" - "strings" "github.com/qdm12/gluetun/internal/netlink" "golang.org/x/sys/unix" @@ -103,7 +102,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< return w.netlink.LinkSetDown(link) }) - err = w.addRoute(link, allIPv4(), w.settings.FirewallMark) + err = w.addRoutes(link, w.settings.AllowedIPs, w.settings.FirewallMark) if err != nil { waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err) return @@ -111,11 +110,13 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< if *w.settings.IPv6 { // requires net.ipv6.conf.all.disable_ipv6=0 - err = w.setupIPv6(link, &closers) + ruleCleanup6, err := w.addRule(w.settings.RulePriority, + w.settings.FirewallMark, unix.AF_INET6) if err != nil { - waitError <- fmt.Errorf("setting up IPv6: %w", err) + waitError <- fmt.Errorf("adding IPv6 rule: %w", err) return } + closers.add("removing IPv6 rule", stepOne, ruleCleanup6) } ruleCleanup, err := w.addRule(w.settings.RulePriority, @@ -132,31 +133,6 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< waitError <- waitAndCleanup() } -func (w *Wireguard) setupIPv6(link netlink.Link, closers *closers) (err error) { - // requires net.ipv6.conf.all.disable_ipv6=0 - err = w.addRoute(link, allIPv6(), w.settings.FirewallMark) - if err != nil { - if strings.Contains(err.Error(), "permission denied") { - w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+ - "Ignoring and continuing execution; "+ - "Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+ - "Full error string: %s", err) - return nil - } - return fmt.Errorf("%w: %s", ErrRouteAdd, err) - } - - ruleCleanup6, ruleErr := w.addRule( - w.settings.RulePriority, w.settings.FirewallMark, - unix.AF_INET6) - if ruleErr != nil { - return fmt.Errorf("adding IPv6 rule: %w", ruleErr) - } - - closers.add("removing IPv6 rule", stepOne, ruleCleanup6) - return nil -} - type waitAndCleanupFunc func() error func setupKernelSpace(ctx context.Context, diff --git a/internal/wireguard/settings.go b/internal/wireguard/settings.go index a43f14e9b..fcf659e98 100644 --- a/internal/wireguard/settings.go +++ b/internal/wireguard/settings.go @@ -26,6 +26,10 @@ type Settings struct { // Addresses assigned to the client. // Note IPv6 addresses are ignored if IPv6 is not supported. Addresses []netip.Prefix + // AllowedIPs is the IP networks to be routed through + // the Wireguard interface. + // Note IPv6 addresses are ignored if IPv6 is not supported. + AllowedIPs []netip.Prefix // FirewallMark to be used in routing tables and IP rules. // It defaults to 51820 if left to 0. FirewallMark int @@ -68,6 +72,13 @@ func (s *Settings) SetDefaults() { s.IPv6 = &ipv6 } + if len(s.AllowedIPs) == 0 { + s.AllowedIPs = append(s.AllowedIPs, allIPv4()) + if *s.IPv6 { + s.AllowedIPs = append(s.AllowedIPs, allIPv6()) + } + } + if s.Implementation == "" { const defaultImplementation = "auto" s.Implementation = defaultImplementation @@ -75,19 +86,22 @@ func (s *Settings) SetDefaults() { } var ( - ErrInterfaceNameInvalid = errors.New("invalid interface name") - ErrPrivateKeyMissing = errors.New("private key is missing") - ErrPrivateKeyInvalid = errors.New("cannot parse private key") - ErrPublicKeyMissing = errors.New("public key is missing") - ErrPublicKeyInvalid = errors.New("cannot parse public key") - ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key") - ErrEndpointAddrMissing = errors.New("endpoint address is missing") - ErrEndpointPortMissing = errors.New("endpoint port is missing") - ErrAddressMissing = errors.New("interface address is missing") - ErrAddressNotValid = errors.New("interface address is not valid") - ErrFirewallMarkMissing = errors.New("firewall mark is missing") - ErrMTUMissing = errors.New("MTU is missing") - ErrImplementationInvalid = errors.New("invalid implementation") + ErrInterfaceNameInvalid = errors.New("invalid interface name") + ErrPrivateKeyMissing = errors.New("private key is missing") + ErrPrivateKeyInvalid = errors.New("cannot parse private key") + ErrPublicKeyMissing = errors.New("public key is missing") + ErrPublicKeyInvalid = errors.New("cannot parse public key") + ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key") + ErrEndpointAddrMissing = errors.New("endpoint address is missing") + ErrEndpointPortMissing = errors.New("endpoint port is missing") + ErrAddressMissing = errors.New("interface address is missing") + ErrAddressNotValid = errors.New("interface address is not valid") + ErrAllowedIPsMissing = errors.New("allowed IPs are missing") + ErrAllowedIPNotValid = errors.New("allowed IP is not valid") + ErrAllowedIPv6NotSupported = errors.New("allowed IPv6 address not supported") + ErrFirewallMarkMissing = errors.New("firewall mark is missing") + ErrMTUMissing = errors.New("MTU is missing") + ErrImplementationInvalid = errors.New("invalid implementation") ) var interfaceNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) @@ -132,6 +146,20 @@ func (s *Settings) Check() (err error) { } } + if len(s.AllowedIPs) == 0 { + return fmt.Errorf("%w", ErrAllowedIPsMissing) + } + for i, allowedIP := range s.AllowedIPs { + switch { + case !allowedIP.IsValid(): + return fmt.Errorf("%w: for allowed IP %d of %d", + ErrAllowedIPNotValid, i+1, len(s.AllowedIPs)) + case allowedIP.Addr().Is6() && !*s.IPv6: + return fmt.Errorf("%w: for allowed IP %s", + ErrAllowedIPv6NotSupported, allowedIP) + } + } + if s.FirewallMark == 0 { return fmt.Errorf("%w", ErrFirewallMarkMissing) } @@ -247,5 +275,16 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) { } } + if len(s.AllowedIPs) > 0 { + lines = append(lines, fieldPrefix+"Allowed IPs:") + for i, allowedIP := range s.AllowedIPs { + prefix := fieldPrefix + if i == len(s.AllowedIPs)-1 { + prefix = lastFieldPrefix + } + lines = append(lines, indent+prefix+allowedIP.String()) + } + } + return lines } diff --git a/internal/wireguard/settings_test.go b/internal/wireguard/settings_test.go index cb29a5ce4..4d7363590 100644 --- a/internal/wireguard/settings_test.go +++ b/internal/wireguard/settings_test.go @@ -1,12 +1,10 @@ package wireguard import ( - "errors" "net/netip" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/device" ) @@ -23,6 +21,7 @@ func Test_Settings_SetDefaults(t *testing.T) { expected: Settings{ InterfaceName: "wg0", FirewallMark: 51820, + AllowedIPs: []netip.Prefix{allIPv4()}, MTU: device.DefaultMTU, IPv6: ptr(false), Implementation: "auto", @@ -36,6 +35,7 @@ func Test_Settings_SetDefaults(t *testing.T) { InterfaceName: "wg0", FirewallMark: 51820, Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820), + AllowedIPs: []netip.Prefix{allIPv4()}, MTU: device.DefaultMTU, IPv6: ptr(false), Implementation: "auto", @@ -46,6 +46,7 @@ func Test_Settings_SetDefaults(t *testing.T) { InterfaceName: "wg1", FirewallMark: 999, Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 9999), + AllowedIPs: []netip.Prefix{allIPv4()}, MTU: device.DefaultMTU, IPv6: ptr(true), Implementation: "userspace", @@ -54,6 +55,7 @@ func Test_Settings_SetDefaults(t *testing.T) { InterfaceName: "wg1", FirewallMark: 999, Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 9999), + AllowedIPs: []netip.Prefix{allIPv4()}, MTU: device.DefaultMTU, IPv6: ptr(true), Implementation: "userspace", @@ -82,37 +84,43 @@ func Test_Settings_Check(t *testing.T) { ) testCases := map[string]struct { - settings Settings - err error + settings Settings + errWrapped error + errMessage string }{ "empty settings": { - err: errors.New("invalid interface name: "), + errWrapped: ErrInterfaceNameInvalid, + errMessage: "invalid interface name: ", }, "bad interface name": { settings: Settings{ InterfaceName: "$H1T", }, - err: errors.New("invalid interface name: $H1T"), + errWrapped: ErrInterfaceNameInvalid, + errMessage: "invalid interface name: $H1T", }, "empty private key": { settings: Settings{ InterfaceName: "wg0", }, - err: ErrPrivateKeyMissing, + errWrapped: ErrPrivateKeyMissing, + errMessage: "private key is missing", }, "bad private key": { settings: Settings{ InterfaceName: "wg0", PrivateKey: "bad key", }, - err: ErrPrivateKeyInvalid, + errWrapped: ErrPrivateKeyInvalid, + errMessage: "cannot parse private key", }, "empty public key": { settings: Settings{ InterfaceName: "wg0", PrivateKey: validKey1, }, - err: ErrPublicKeyMissing, + errWrapped: ErrPublicKeyMissing, + errMessage: "public key is missing", }, "bad public key": { settings: Settings{ @@ -120,7 +128,8 @@ func Test_Settings_Check(t *testing.T) { PrivateKey: validKey1, PublicKey: "bad key", }, - err: errors.New("cannot parse public key: bad key"), + errWrapped: ErrPublicKeyInvalid, + errMessage: "cannot parse public key: bad key", }, "bad preshared key": { settings: Settings{ @@ -129,7 +138,8 @@ func Test_Settings_Check(t *testing.T) { PublicKey: validKey2, PreSharedKey: "bad key", }, - err: errors.New("cannot parse pre-shared key"), + errWrapped: ErrPreSharedKeyInvalid, + errMessage: "cannot parse pre-shared key", }, "invalid endpoint address": { settings: Settings{ @@ -137,7 +147,8 @@ func Test_Settings_Check(t *testing.T) { PrivateKey: validKey1, PublicKey: validKey2, }, - err: ErrEndpointAddrMissing, + errWrapped: ErrEndpointAddrMissing, + errMessage: "endpoint address is missing", }, "zero endpoint port": { settings: Settings{ @@ -146,7 +157,8 @@ func Test_Settings_Check(t *testing.T) { PublicKey: validKey2, Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 0), }, - err: ErrEndpointPortMissing, + errWrapped: ErrEndpointPortMissing, + errMessage: "endpoint port is missing", }, "no address": { settings: Settings{ @@ -155,7 +167,8 @@ func Test_Settings_Check(t *testing.T) { PublicKey: validKey2, Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820), }, - err: ErrAddressMissing, + errWrapped: ErrAddressMissing, + errMessage: "interface address is missing", }, "invalid address": { settings: Settings{ @@ -165,7 +178,53 @@ func Test_Settings_Check(t *testing.T) { Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820), Addresses: []netip.Prefix{{}}, }, - err: errors.New("interface address is not valid: for address 1 of 1"), + errWrapped: ErrAddressNotValid, + errMessage: "interface address is not valid: for address 1 of 1", + }, + + "no allowed IP": { + settings: Settings{ + InterfaceName: "wg0", + PrivateKey: validKey1, + PublicKey: validKey2, + Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820), + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 24), + }, + }, + errWrapped: ErrAllowedIPsMissing, + errMessage: "allowed IPs are missing", + }, + "invalid allowed IP": { + settings: Settings{ + InterfaceName: "wg0", + PrivateKey: validKey1, + PublicKey: validKey2, + Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820), + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 24), + }, + AllowedIPs: []netip.Prefix{{}}, + }, + errWrapped: ErrAllowedIPNotValid, + errMessage: "allowed IP is not valid: for allowed IP 1 of 1", + }, + "ipv6 allowed IP": { + settings: Settings{ + InterfaceName: "wg0", + PrivateKey: validKey1, + PublicKey: validKey2, + Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820), + Addresses: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 24), + }, + AllowedIPs: []netip.Prefix{ + allIPv6(), + }, + IPv6: ptrTo(false), + }, + errWrapped: ErrAllowedIPv6NotSupported, + errMessage: "allowed IPv6 address not supported: for allowed IP ::/0", }, "zero firewall mark": { settings: Settings{ @@ -173,11 +232,13 @@ func Test_Settings_Check(t *testing.T) { PrivateKey: validKey1, PublicKey: validKey2, Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820), + AllowedIPs: []netip.Prefix{allIPv4()}, Addresses: []netip.Prefix{ netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), }, }, - err: ErrFirewallMarkMissing, + errWrapped: ErrFirewallMarkMissing, + errMessage: "firewall mark is missing", }, "missing_MTU": { settings: Settings{ @@ -185,12 +246,14 @@ func Test_Settings_Check(t *testing.T) { PrivateKey: validKey1, PublicKey: validKey2, Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820), + AllowedIPs: []netip.Prefix{allIPv4()}, Addresses: []netip.Prefix{ netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), }, FirewallMark: 999, }, - err: ErrMTUMissing, + errWrapped: ErrMTUMissing, + errMessage: "MTU is missing", }, "invalid implementation": { settings: Settings{ @@ -198,6 +261,7 @@ func Test_Settings_Check(t *testing.T) { PrivateKey: validKey1, PublicKey: validKey2, Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820), + AllowedIPs: []netip.Prefix{allIPv4()}, Addresses: []netip.Prefix{ netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), }, @@ -205,7 +269,8 @@ func Test_Settings_Check(t *testing.T) { MTU: 1420, Implementation: "x", }, - err: errors.New("invalid implementation: x"), + errWrapped: ErrImplementationInvalid, + errMessage: "invalid implementation: x", }, "all valid": { settings: Settings{ @@ -213,11 +278,15 @@ func Test_Settings_Check(t *testing.T) { PrivateKey: validKey1, PublicKey: validKey2, Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820), + AllowedIPs: []netip.Prefix{ + allIPv6(), + }, Addresses: []netip.Prefix{ netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), }, FirewallMark: 999, MTU: 1420, + IPv6: ptrTo(true), Implementation: "userspace", }, }, @@ -230,11 +299,9 @@ func Test_Settings_Check(t *testing.T) { err := testCase.settings.Check() - if testCase.err != nil { - require.Error(t, err) - assert.Equal(t, testCase.err.Error(), err.Error()) - } else { - assert.NoError(t, err) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) } }) }