Skip to content

Commit

Permalink
feat(wireguard): WIREGUARD_ALLOWED_IPS variable
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Jul 6, 2023
1 parent 9c0f187 commit 6afa315
Show file tree
Hide file tree
Showing 11 changed files with 225 additions and 69 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
WIREGUARD_PRIVATE_KEY= \
WIREGUARD_PRESHARED_KEY= \
WIREGUARD_PUBLIC_KEY= \
WIREGUARD_ALLOWED_IPS= \
WIREGUARD_ADDRESSES= \
WIREGUARD_MTU=1400 \
WIREGUARD_IMPLEMENTATION=auto \
Expand Down
2 changes: 2 additions & 0 deletions internal/configuration/settings/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ var (
ErrUpdaterPeriodTooSmall = errors.New("VPN server data updater period is too small")
ErrVPNProviderNameNotValid = errors.New("VPN provider name is not valid")
ErrVPNTypeNotValid = errors.New("VPN type is not valid")
ErrWireguardAllowedIPNotSet = errors.New("allowed IP is not set")
ErrWireguardAllowedIPsNotSet = errors.New("allowed IPs is not set")
ErrWireguardEndpointIPNotSet = errors.New("endpoint IP is not set")
ErrWireguardEndpointPortNotAllowed = errors.New("endpoint port is not allowed")
ErrWireguardEndpointPortNotSet = errors.New("endpoint port is not set")
Expand Down
36 changes: 33 additions & 3 deletions internal/configuration/settings/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ type Wireguard struct {
PreSharedKey *string `json:"pre_shared_key"`
// Addresses are the Wireguard interface addresses.
Addresses []netip.Prefix `json:"addresses"`
// AllowedIPs are the Wireguard allowed IPs.
// If left unset, they default to "0.0.0.0/0"
// and, if IPv6 is supported, "::0".
AllowedIPs []netip.Prefix `json:"allowed_ips"`
// Interface is the name of the Wireguard interface
// to create. It cannot be the empty string in the
// internal state.
Expand Down Expand Up @@ -89,13 +93,26 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error)
}
for i, ipNet := range w.Addresses {
if !ipNet.IsValid() {
return fmt.Errorf("%w: for address at index %d: %s",
ErrWireguardInterfaceAddressNotSet, i, ipNet.String())
return fmt.Errorf("%w: for address at index %d",
ErrWireguardInterfaceAddressNotSet, i)
}

if !ipv6Supported && ipNet.Addr().Is6() {
return fmt.Errorf("%w: address %s",
ErrWireguardInterfaceAddressIPv6, ipNet)
ErrWireguardInterfaceAddressIPv6, ipNet.String())
}
}

// Validate AllowedIPs
// WARNING: do not check for IPv6 networks in the allowed IPs,
// the wireguard code will take care to ignore it.
if len(w.AllowedIPs) == 0 {
return fmt.Errorf("%w", ErrWireguardAllowedIPsNotSet)
}
for i, allowedIP := range w.AllowedIPs {
if !allowedIP.IsValid() {
return fmt.Errorf("%w: for allowed ip %d of %d",
ErrWireguardAllowedIPNotSet, i+1, len(w.AllowedIPs))
}
}

Expand All @@ -118,6 +135,7 @@ func (w *Wireguard) copy() (copied Wireguard) {
PrivateKey: gosettings.CopyPointer(w.PrivateKey),
PreSharedKey: gosettings.CopyPointer(w.PreSharedKey),
Addresses: gosettings.CopySlice(w.Addresses),
AllowedIPs: gosettings.CopySlice(w.AllowedIPs),
Interface: w.Interface,
MTU: w.MTU,
Implementation: w.Implementation,
Expand All @@ -128,6 +146,7 @@ func (w *Wireguard) mergeWith(other Wireguard) {
w.PrivateKey = gosettings.MergeWithPointer(w.PrivateKey, other.PrivateKey)
w.PreSharedKey = gosettings.MergeWithPointer(w.PreSharedKey, other.PreSharedKey)
w.Addresses = gosettings.MergeWithSlice(w.Addresses, other.Addresses)
w.AllowedIPs = gosettings.MergeWithSlice(w.AllowedIPs, other.AllowedIPs)
w.Interface = gosettings.MergeWithString(w.Interface, other.Interface)
w.MTU = gosettings.MergeWithNumber(w.MTU, other.MTU)
w.Implementation = gosettings.MergeWithString(w.Implementation, other.Implementation)
Expand All @@ -137,6 +156,7 @@ func (w *Wireguard) overrideWith(other Wireguard) {
w.PrivateKey = gosettings.OverrideWithPointer(w.PrivateKey, other.PrivateKey)
w.PreSharedKey = gosettings.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey)
w.Addresses = gosettings.OverrideWithSlice(w.Addresses, other.Addresses)
w.AllowedIPs = gosettings.OverrideWithSlice(w.AllowedIPs, other.AllowedIPs)
w.Interface = gosettings.OverrideWithString(w.Interface, other.Interface)
w.MTU = gosettings.OverrideWithNumber(w.MTU, other.MTU)
w.Implementation = gosettings.OverrideWithString(w.Implementation, other.Implementation)
Expand All @@ -150,6 +170,11 @@ func (w *Wireguard) setDefaults(vpnProvider string) {
defaultNordVPNPrefix := netip.PrefixFrom(defaultNordVPNAddress, defaultNordVPNAddress.BitLen())
w.Addresses = gosettings.DefaultSlice(w.Addresses, []netip.Prefix{defaultNordVPNPrefix})
}
defaultAllowedIPs := []netip.Prefix{
netip.PrefixFrom(netip.IPv4Unspecified(), 0),
netip.PrefixFrom(netip.IPv6Unspecified(), 0),
}
w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs)
w.Interface = gosettings.DefaultString(w.Interface, "wg0")
const defaultMTU = 1400
w.MTU = gosettings.DefaultNumber(w.MTU, defaultMTU)
Expand Down Expand Up @@ -178,6 +203,11 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
addressesNode.Appendf(address.String())
}

allowedIPsNode := node.Appendf("Allowed IPs:")
for _, allowedIP := range w.AllowedIPs {
allowedIPsNode.Appendf(allowedIP.String())
}

interfaceNode := node.Appendf("Network interface: %s", w.Interface)
interfaceNode.Appendf("MTU: %d", w.MTU)

Expand Down
4 changes: 4 additions & 0 deletions internal/configuration/sources/env/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ func (s *Source) readWireguard() (wireguard settings.Wireguard, err error) {
if err != nil {
return wireguard, err // already wrapped
}
wireguard.AllowedIPs, err = s.env.CSVNetipPrefixes("WIREGUARD_ALLOWED_IPS")
if err != nil {
return wireguard, err // already wrapped
}
mtuPtr, err := s.env.Uint16Ptr("WIREGUARD_MTU")
if err != nil {
return wireguard, err
Expand Down
8 changes: 8 additions & 0 deletions internal/provider/utils/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,13 @@ func BuildWireguardSettings(connection models.Connection,
settings.Addresses = append(settings.Addresses, addressCopy)
}

settings.AllowedIPs = make([]netip.Prefix, 0, len(userSettings.AllowedIPs))
for _, allowedIP := range userSettings.AllowedIPs {
if !ipv6Supported && allowedIP.Addr().Is6() {
continue
}
settings.AllowedIPs = append(settings.AllowedIPs, allowedIP)
}

return settings
}
7 changes: 7 additions & 0 deletions internal/provider/utils/wireguard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ func Test_BuildWireguardSettings(t *testing.T) {
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
},
AllowedIPs: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32),
},
Interface: "wg1",
},
ipv6Supported: false,
Expand All @@ -46,6 +50,9 @@ func Test_BuildWireguardSettings(t *testing.T) {
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
},
AllowedIPs: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
},
RulePriority: 101,
IPv6: boolPtr(false),
},
Expand Down
3 changes: 3 additions & 0 deletions internal/wireguard/constructor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ func Test_New(t *testing.T) {
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
},
AllowedIPs: []netip.Prefix{
allIPv4(),
},
FirewallMark: 100,
MTU: device.DefaultMTU,
IPv6: ptr(false),
Expand Down
21 changes: 20 additions & 1 deletion internal/wireguard/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,30 @@ package wireguard
import (
"fmt"
"net/netip"
"strings"

"github.com/qdm12/gluetun/internal/netlink"
)

// TODO add IPv6 route if IPv6 is supported
func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
firewallMark int) (err error) {
for _, dst := range destinations {
err = w.addRoute(link, dst, firewallMark)
if err == nil {
continue
}

if dst.Addr().Is6() && strings.Contains(err.Error(), "permission denied") {
w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
"Ignoring and continuing execution; "+
"Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+
"Full error string: %s", err)
continue
}
return fmt.Errorf("adding route for destination %s: %w", dst, err)
}
return nil
}

func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
firewallMark int) (err error) {
Expand Down
34 changes: 5 additions & 29 deletions internal/wireguard/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net"
"strings"

"github.com/qdm12/gluetun/internal/netlink"
"golang.org/x/sys/unix"
Expand Down Expand Up @@ -103,19 +102,21 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return w.netlink.LinkSetDown(link)
})

err = w.addRoute(link, allIPv4(), w.settings.FirewallMark)
err = w.addRoutes(link, w.settings.AllowedIPs, w.settings.FirewallMark)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
return
}

if *w.settings.IPv6 {
// requires net.ipv6.conf.all.disable_ipv6=0
err = w.setupIPv6(link, &closers)
ruleCleanup6, err := w.addRule(w.settings.RulePriority,
w.settings.FirewallMark, unix.AF_INET6)
if err != nil {
waitError <- fmt.Errorf("setting up IPv6: %w", err)
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
return
}
closers.add("removing IPv6 rule", stepOne, ruleCleanup6)
}

ruleCleanup, err := w.addRule(w.settings.RulePriority,
Expand All @@ -132,31 +133,6 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
waitError <- waitAndCleanup()
}

func (w *Wireguard) setupIPv6(link netlink.Link, closers *closers) (err error) {
// requires net.ipv6.conf.all.disable_ipv6=0
err = w.addRoute(link, allIPv6(), w.settings.FirewallMark)
if err != nil {
if strings.Contains(err.Error(), "permission denied") {
w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
"Ignoring and continuing execution; "+
"Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+
"Full error string: %s", err)
return nil
}
return fmt.Errorf("%w: %s", ErrRouteAdd, err)
}

ruleCleanup6, ruleErr := w.addRule(
w.settings.RulePriority, w.settings.FirewallMark,
unix.AF_INET6)
if ruleErr != nil {
return fmt.Errorf("adding IPv6 rule: %w", ruleErr)
}

closers.add("removing IPv6 rule", stepOne, ruleCleanup6)
return nil
}

type waitAndCleanupFunc func() error

func setupKernelSpace(ctx context.Context,
Expand Down
65 changes: 52 additions & 13 deletions internal/wireguard/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ type Settings struct {
// Addresses assigned to the client.
// Note IPv6 addresses are ignored if IPv6 is not supported.
Addresses []netip.Prefix
// AllowedIPs is the IP networks to be routed through
// the Wireguard interface.
// Note IPv6 addresses are ignored if IPv6 is not supported.
AllowedIPs []netip.Prefix
// FirewallMark to be used in routing tables and IP rules.
// It defaults to 51820 if left to 0.
FirewallMark int
Expand Down Expand Up @@ -68,26 +72,36 @@ func (s *Settings) SetDefaults() {
s.IPv6 = &ipv6
}

if len(s.AllowedIPs) == 0 {
s.AllowedIPs = append(s.AllowedIPs, allIPv4())
if *s.IPv6 {
s.AllowedIPs = append(s.AllowedIPs, allIPv6())
}
}

if s.Implementation == "" {
const defaultImplementation = "auto"
s.Implementation = defaultImplementation
}
}

var (
ErrInterfaceNameInvalid = errors.New("invalid interface name")
ErrPrivateKeyMissing = errors.New("private key is missing")
ErrPrivateKeyInvalid = errors.New("cannot parse private key")
ErrPublicKeyMissing = errors.New("public key is missing")
ErrPublicKeyInvalid = errors.New("cannot parse public key")
ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key")
ErrEndpointAddrMissing = errors.New("endpoint address is missing")
ErrEndpointPortMissing = errors.New("endpoint port is missing")
ErrAddressMissing = errors.New("interface address is missing")
ErrAddressNotValid = errors.New("interface address is not valid")
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
ErrMTUMissing = errors.New("MTU is missing")
ErrImplementationInvalid = errors.New("invalid implementation")
ErrInterfaceNameInvalid = errors.New("invalid interface name")
ErrPrivateKeyMissing = errors.New("private key is missing")
ErrPrivateKeyInvalid = errors.New("cannot parse private key")
ErrPublicKeyMissing = errors.New("public key is missing")
ErrPublicKeyInvalid = errors.New("cannot parse public key")
ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key")
ErrEndpointAddrMissing = errors.New("endpoint address is missing")
ErrEndpointPortMissing = errors.New("endpoint port is missing")
ErrAddressMissing = errors.New("interface address is missing")
ErrAddressNotValid = errors.New("interface address is not valid")
ErrAllowedIPsMissing = errors.New("allowed IPs are missing")
ErrAllowedIPNotValid = errors.New("allowed IP is not valid")
ErrAllowedIPv6NotSupported = errors.New("allowed IPv6 address not supported")
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
ErrMTUMissing = errors.New("MTU is missing")
ErrImplementationInvalid = errors.New("invalid implementation")
)

var interfaceNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
Expand Down Expand Up @@ -132,6 +146,20 @@ func (s *Settings) Check() (err error) {
}
}

if len(s.AllowedIPs) == 0 {
return fmt.Errorf("%w", ErrAllowedIPsMissing)
}
for i, allowedIP := range s.AllowedIPs {
switch {
case !allowedIP.IsValid():
return fmt.Errorf("%w: for allowed IP %d of %d",
ErrAllowedIPNotValid, i+1, len(s.AllowedIPs))
case allowedIP.Addr().Is6() && !*s.IPv6:
return fmt.Errorf("%w: for allowed IP %s",
ErrAllowedIPv6NotSupported, allowedIP)
}
}

if s.FirewallMark == 0 {
return fmt.Errorf("%w", ErrFirewallMarkMissing)
}
Expand Down Expand Up @@ -247,5 +275,16 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
}
}

if len(s.AllowedIPs) > 0 {
lines = append(lines, fieldPrefix+"Allowed IPs:")
for i, allowedIP := range s.AllowedIPs {
prefix := fieldPrefix
if i == len(s.AllowedIPs)-1 {
prefix = lastFieldPrefix
}
lines = append(lines, indent+prefix+allowedIP.String())
}
}

return lines
}
Loading

0 comments on commit 6afa315

Please sign in to comment.