Skip to content

Commit

Permalink
feat(portforward): port redirection with `VPN_PORT_FORWARDING_LISTENI…
Browse files Browse the repository at this point in the history
…NG_PORT`
  • Loading branch information
qdm12 committed Nov 23, 2023
1 parent 8318be3 commit 4105f74
Show file tree
Hide file tree
Showing 14 changed files with 226 additions and 6 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
# # Private Internet Access only:
PRIVATE_INTERNET_ACCESS_OPENVPN_ENCRYPTION_PRESET= \
VPN_PORT_FORWARDING=off \
VPN_PORT_FORWARDING_LISTENING_PORT=0 \
VPN_PORT_FORWARDING_PROVIDER= \
VPN_PORT_FORWARDING_STATUS_FILE="/tmp/gluetun/forwarded_port" \
# # Cyberghost only:
Expand Down
2 changes: 1 addition & 1 deletion internal/configuration/settings/dot.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (d DoT) toLinesNode() (node *gotree.Node) {
return node
}

update := "disabled"
update := "disabled" //nolint:goconst
if *d.UpdatePeriod > 0 {
update = "every " + d.UpdatePeriod.String()
}
Expand Down
21 changes: 18 additions & 3 deletions internal/configuration/settings/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ type PortForwarding struct {
// to write to a file. It cannot be nil for the
// internal state
Filepath *string `json:"status_file_path"`
// ListeningPort is the port traffic would be redirected to from the
// forwarded port. The redirection is disabled if it is set to 0, which
// is its default as well.
ListeningPort *uint16 `json:"listening_port"`
}

func (p PortForwarding) Validate(vpnProvider string) (err error) {
Expand Down Expand Up @@ -61,28 +65,32 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {

func (p *PortForwarding) Copy() (copied PortForwarding) {
return PortForwarding{
Enabled: gosettings.CopyPointer(p.Enabled),
Provider: gosettings.CopyPointer(p.Provider),
Filepath: gosettings.CopyPointer(p.Filepath),
Enabled: gosettings.CopyPointer(p.Enabled),
Provider: gosettings.CopyPointer(p.Provider),
Filepath: gosettings.CopyPointer(p.Filepath),
ListeningPort: gosettings.CopyPointer(p.ListeningPort),
}
}

func (p *PortForwarding) mergeWith(other PortForwarding) {
p.Enabled = gosettings.MergeWithPointer(p.Enabled, other.Enabled)
p.Provider = gosettings.MergeWithPointer(p.Provider, other.Provider)
p.Filepath = gosettings.MergeWithPointer(p.Filepath, other.Filepath)
p.ListeningPort = gosettings.MergeWithPointer(p.ListeningPort, other.ListeningPort)
}

func (p *PortForwarding) OverrideWith(other PortForwarding) {
p.Enabled = gosettings.OverrideWithPointer(p.Enabled, other.Enabled)
p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider)
p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath)
p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort)
}

func (p *PortForwarding) setDefaults() {
p.Enabled = gosettings.DefaultPointer(p.Enabled, false)
p.Provider = gosettings.DefaultPointer(p.Provider, "")
p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port")
p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0)
}

func (p PortForwarding) String() string {
Expand All @@ -95,6 +103,13 @@ func (p PortForwarding) toLinesNode() (node *gotree.Node) {
}

node = gotree.New("Automatic port forwarding settings:")

listeningPort := "disabled"
if *p.ListeningPort != 0 {
listeningPort = fmt.Sprintf("%d", *p.ListeningPort)
}
node.Appendf("Redirection listening port: %s", listeningPort)

if *p.Provider == "" {
node.Appendf("Use port forwarding code for current provider")
} else {
Expand Down
5 changes: 5 additions & 0 deletions internal/configuration/sources/env/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,10 @@ func (s *Source) readPortForward() (
"PRIVATE_INTERNET_ACCESS_VPN_PORT_FORWARDING_STATUS_FILE",
))

portForwarding.ListeningPort, err = s.env.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT")
if err != nil {
return portForwarding, err
}

return portForwarding, nil
}
23 changes: 23 additions & 0 deletions internal/firewall/enable.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ func (c *Config) disable(ctx context.Context) (err error) {
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv6 policies: %w", err)
}

const remove = true
err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("removing port redirections: %w", err)
}

return nil
}

Expand Down Expand Up @@ -124,6 +131,11 @@ func (c *Config) enable(ctx context.Context) (err error) {
return err
}

err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("redirecting ports: %w", err)
}

if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
return fmt.Errorf("running user defined post firewall rules: %w", err)
}
Expand Down Expand Up @@ -188,3 +200,14 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
}
return nil
}

func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) {
for _, portRedirection := range c.portRedirections {
err = c.redirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
portRedirection.destinationPort, remove)
if err != nil {
return err
}
}
return nil
}
1 change: 1 addition & 0 deletions internal/firewall/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type Config struct { //nolint:maligned
vpnIntf string
outboundSubnets []netip.Prefix
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
portRedirections portRedirections
stateMutex sync.Mutex
}

Expand Down
32 changes: 32 additions & 0 deletions internal/firewall/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,38 @@ func (c *Config) acceptInputToPort(ctx context.Context, intf string, port uint16
})
}

// Used for VPN server side port forwarding, with intf set to the VPN tunnel interface.
func (c *Config) redirectPort(ctx context.Context, intf string,
sourcePort, destinationPort uint16, remove bool) (err error) {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}

err = c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -d 127.0.0.1 -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("-t nat %s PREROUTING %s -d 127.0.0.1 -p udp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
})
if err != nil {
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}

err = c.runIP6tablesInstructions(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -d ::1 -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("-t nat %s PREROUTING %s -d ::1 -p udp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
})
if err != nil {
return fmt.Errorf("redirecting IPv6 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
return nil
}

func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) {
Expand Down
119 changes: 119 additions & 0 deletions internal/firewall/redirect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package firewall

import (
"context"
"fmt"
)

// RedirectPort redirects a source port to a destination port on the interface
// intf. If intf is empty, it is set to "*" which means all interfaces.
// If a redirection for the source port given already exists, it is removed first.
// If the destination port is zero, the redirection for the source port is removed
// and no new redirection is added.
func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
destinationPort uint16) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()

if sourcePort == 0 {
panic("source port cannot be 0")
}

newRedirection := portRedirection{
interfaceName: intf,
sourcePort: sourcePort,
destinationPort: destinationPort,
}

if !c.enabled {
c.logger.Info("firewall disabled, only updating redirected ports internal state")
if destinationPort == 0 {
c.portRedirections.remove(intf, sourcePort)
return nil
}
exists, conflict := c.portRedirections.check(newRedirection)
switch {
case exists:
return nil
case conflict != nil:
c.portRedirections.remove(conflict.interfaceName,
conflict.sourcePort)
}
c.portRedirections.append(newRedirection)
return nil
}

exists, conflict := c.portRedirections.check(newRedirection)
switch {
case exists:
return nil
case conflict != nil:
const remove = true
err = c.redirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
conflict.destinationPort, remove)
if err != nil {
return fmt.Errorf("removing conflicting redirection: %w", err)
}
c.portRedirections.remove(conflict.interfaceName,
conflict.sourcePort)
}

const remove = false
err = c.redirectPort(ctx, intf, sourcePort, destinationPort, remove)
if err != nil {
return fmt.Errorf("redirecting port: %w", err)
}
c.portRedirections.append(newRedirection)

return nil
}

type portRedirection struct {
interfaceName string
sourcePort uint16
destinationPort uint16
}

type portRedirections []portRedirection

func (p *portRedirections) remove(intf string, sourcePort uint16) {
slice := *p
for i, redirection := range slice {
interfaceMatch := intf == "" || intf == redirection.interfaceName
if redirection.sourcePort == sourcePort && interfaceMatch {
// Remove redirection - note: order does not matter
slice[i] = slice[len(slice)-1]
slice = slice[:len(slice)-1]
}
}
*p = slice
}

func (p *portRedirections) check(dryRun portRedirection) (alreadyExists bool,
conflict *portRedirection) {
slice := *p
for _, redirection := range slice {
interfaceMatch := redirection.interfaceName == "" ||
redirection.interfaceName == dryRun.interfaceName

if redirection.sourcePort == dryRun.sourcePort &&
redirection.destinationPort == dryRun.destinationPort &&
interfaceMatch {
return true, nil
}

if redirection.sourcePort == dryRun.sourcePort &&
interfaceMatch {
// Source port has a redirection already for the same interface or all interfaces
return false, &redirection
}
}
return false, nil
}

// append should be called after running `check` to avoid rule conflicts.
func (p *portRedirections) append(newRedirection portRedirection) {
slice := *p
slice = append(slice, newRedirection)
*p = slice
}
2 changes: 2 additions & 0 deletions internal/portforward/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ type Routing interface {
type PortAllower interface {
SetAllowedPort(ctx context.Context, port uint16, intf string) (err error)
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
RedirectPort(ctx context.Context, intf string, sourcePort,
destinationPort uint16) (err error)
}

type Logger interface {
Expand Down
5 changes: 3 additions & 2 deletions internal/portforward/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
settings: Settings{
VPNIsUp: ptrTo(false),
Service: service.Settings{
Enabled: settings.Enabled,
Filepath: *settings.Filepath,
Enabled: settings.Enabled,
Filepath: *settings.Filepath,
ListeningPort: *settings.ListeningPort,
},
},
routing: routing,
Expand Down
2 changes: 2 additions & 0 deletions internal/portforward/service/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
type PortAllower interface {
SetAllowedPort(ctx context.Context, port uint16, intf string) (err error)
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
RedirectPort(ctx context.Context, intf string, sourcePort,
destinationPort uint16) (err error)
}

type Routing interface {
Expand Down
3 changes: 3 additions & 0 deletions internal/portforward/service/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Settings struct {
Filepath string
Interface string // needed for PIA and ProtonVPN, tun0 for example
ServerName string // needed for PIA
ListeningPort uint16
}

func (s Settings) Copy() (copied Settings) {
Expand All @@ -22,6 +23,7 @@ func (s Settings) Copy() (copied Settings) {
copied.Filepath = s.Filepath
copied.Interface = s.Interface
copied.ServerName = s.ServerName
copied.ListeningPort = s.ListeningPort
return copied
}

Expand All @@ -31,6 +33,7 @@ func (s *Settings) OverrideWith(update Settings) {
s.Filepath = gosettings.OverrideWithString(s.Filepath, update.Filepath)
s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface)
s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName)
s.ListeningPort = gosettings.OverrideWithNumber(s.ListeningPort, update.ListeningPort)
}

var (
Expand Down
7 changes: 7 additions & 0 deletions internal/portforward/service/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
return nil, fmt.Errorf("allowing port in firewall: %w", err)
}

if s.settings.ListeningPort != 0 {
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort)
if err != nil {
return nil, fmt.Errorf("redirecting port in firewall: %w", err)
}
}

err = s.writePortForwardedFile(port)
if err != nil {
_ = s.cleanup()
Expand Down
9 changes: 9 additions & 0 deletions internal/portforward/service/stop.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ func (s *Service) cleanup() (err error) {
return fmt.Errorf("blocking previous port in firewall: %w", err)
}

if s.settings.ListeningPort != 0 {
ctx := context.Background()
const listeningPort = 0 // 0 to clear the redirection
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, s.port, listeningPort)
if err != nil {
return fmt.Errorf("removing previous port redirection in firewall: %w", err)
}
}

s.port = 0

filepath := s.settings.Filepath
Expand Down

0 comments on commit 4105f74

Please sign in to comment.