Skip to content

Commit

Permalink
feat(protonvpn): port forwarding support with NAT-PMP (#1543)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicholas Xavier <nicho@nicho.dev>
  • Loading branch information
qdm12 and nichogx committed Jun 30, 2023
1 parent fae6544 commit 8ad16cd
Show file tree
Hide file tree
Showing 19 changed files with 1,118 additions and 17 deletions.
5 changes: 4 additions & 1 deletion internal/configuration/settings/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ func (p PortForwarding) validate(vpnProvider string) (err error) {
if *p.Provider != "" {
providerSelected = *p.Provider
}
validProviders := []string{providers.PrivateInternetAccess}
validProviders := []string{
providers.PrivateInternetAccess,
providers.Protonvpn,
}
if err = validate.IsOneOf(providerSelected, validProviders...); err != nil {
return fmt.Errorf("%w: %w", ErrPortForwardingEnabled, err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/configuration/settings/wireguardselection.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type WireguardSelection struct {
// It is only used with VPN providers generating Wireguard
// configurations specific to each server and user.
// To indicate it should not be used, it should be set
// to netaddr.IPv4Unspecified(). It can never be the zero value
// to netip.IPv4Unspecified(). It can never be the zero value
// in the internal state.
EndpointIP netip.Addr `json:"endpoint_ip"`
// EndpointPort is a the server port to use for the VPN server.
Expand Down
94 changes: 94 additions & 0 deletions internal/natpmp/checks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package natpmp

import (
"encoding/binary"
"errors"
"fmt"
)

var (
ErrRequestSizeTooSmall = errors.New("message size is too small")
)

func checkRequest(request []byte) (err error) {
const minMessageSize = 2 // version number + operation code
if len(request) < minMessageSize {
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)",
ErrRequestSizeTooSmall, minMessageSize, len(request))
}

return nil
}

var (
ErrResponseSizeTooSmall = errors.New("response size is too small")
ErrResponseSizeUnexpected = errors.New("response size is unexpected")
ErrProtocolVersionUnknown = errors.New("protocol version is unknown")
ErrOperationCodeUnexpected = errors.New("operation code is unexpected")
)

func checkResponse(response []byte, expectedOperationCode byte,
expectedResponseSize uint) (err error) {
const minResponseSize = 4
if len(response) < minResponseSize {
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)",
ErrResponseSizeTooSmall, minResponseSize, len(response))
}

if len(response) != int(expectedResponseSize) {
return fmt.Errorf("%w: expected %d bytes and got %d byte(s)",
ErrResponseSizeUnexpected, expectedResponseSize, len(response))
}

protocolVersion := response[0]
if protocolVersion != 0 {
return fmt.Errorf("%w: %d", ErrProtocolVersionUnknown, protocolVersion)
}

operationCode := response[1]
if operationCode != expectedOperationCode {
return fmt.Errorf("%w: expected 0x%x and got 0x%x",
ErrOperationCodeUnexpected, expectedOperationCode, operationCode)
}

resultCode := binary.BigEndian.Uint16(response[2:4])
err = checkResultCode(resultCode)
if err != nil {
return fmt.Errorf("result code: %w", err)
}

return nil
}

var (
ErrVersionNotSupported = errors.New("version is not supported")
ErrNotAuthorized = errors.New("not authorized")
ErrNetworkFailure = errors.New("network failure")
ErrOutOfResources = errors.New("out of resources")
ErrOperationCodeNotSupported = errors.New("operation code is not supported")
ErrResultCodeUnknown = errors.New("result code is unknown")
)

// checkResultCode checks the result code and returns an error
// if the result code is not a success (0).
// See https://www.ietf.org/rfc/rfc6886.html#section-3.5
//
//nolint:gomnd
func checkResultCode(resultCode uint16) (err error) {
switch resultCode {
case 0:
return nil
case 1:
return fmt.Errorf("%w", ErrVersionNotSupported)
case 2:
return fmt.Errorf("%w", ErrNotAuthorized)
case 3:
return fmt.Errorf("%w", ErrNetworkFailure)
case 4:
return fmt.Errorf("%w", ErrOutOfResources)
case 5:
return fmt.Errorf("%w", ErrOperationCodeNotSupported)
default:
return fmt.Errorf("%w: %d", ErrResultCodeUnknown, resultCode)
}
}
161 changes: 161 additions & 0 deletions internal/natpmp/checks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package natpmp

import (
"testing"

"github.com/stretchr/testify/assert"
)

func Test_checkRequest(t *testing.T) {
t.Parallel()

testCases := map[string]struct {
request []byte
err error
errMessage string
}{
"too_short": {
request: []byte{1},
err: ErrRequestSizeTooSmall,
errMessage: "message size is too small: need at least 2 bytes and got 1 byte(s)",
},
"success": {
request: []byte{0, 0},
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()

err := checkRequest(testCase.request)

assert.ErrorIs(t, err, testCase.err)
if testCase.err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}

func Test_checkResponse(t *testing.T) {
t.Parallel()

testCases := map[string]struct {
response []byte
expectedOperationCode byte
expectedResponseSize uint
err error
errMessage string
}{
"too_short": {
response: []byte{1},
err: ErrResponseSizeTooSmall,
errMessage: "response size is too small: need at least 4 bytes and got 1 byte(s)",
},
"size_mismatch": {
response: []byte{0, 0, 0, 0},
expectedResponseSize: 5,
err: ErrResponseSizeUnexpected,
errMessage: "response size is unexpected: expected 5 bytes and got 4 byte(s)",
},
"protocol_unknown": {
response: []byte{1, 0, 0, 0},
expectedResponseSize: 4,
err: ErrProtocolVersionUnknown,
errMessage: "protocol version is unknown: 1",
},
"operation_code_unexpected": {
response: []byte{0, 2, 0, 0},
expectedOperationCode: 1,
expectedResponseSize: 4,
err: ErrOperationCodeUnexpected,
errMessage: "operation code is unexpected: expected 0x1 and got 0x2",
},
"result_code_failure": {
response: []byte{0, 1, 0, 1},
expectedOperationCode: 1,
expectedResponseSize: 4,
err: ErrVersionNotSupported,
errMessage: "result code: version is not supported",
},
"success": {
response: []byte{0, 1, 0, 0},
expectedOperationCode: 1,
expectedResponseSize: 4,
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()

err := checkResponse(testCase.response,
testCase.expectedOperationCode,
testCase.expectedResponseSize)

assert.ErrorIs(t, err, testCase.err)
if testCase.err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}

func Test_checkResultCode(t *testing.T) {
t.Parallel()

testCases := map[string]struct {
resultCode uint16
err error
errMessage string
}{
"success": {},
"version_unsupported": {
resultCode: 1,
err: ErrVersionNotSupported,
errMessage: "version is not supported",
},
"not_authorized": {
resultCode: 2,
err: ErrNotAuthorized,
errMessage: "not authorized",
},
"network_failure": {
resultCode: 3,
err: ErrNetworkFailure,
errMessage: "network failure",
},
"out_of_resources": {
resultCode: 4,
err: ErrOutOfResources,
errMessage: "out of resources",
},
"unsupported_operation_code": {
resultCode: 5,
err: ErrOperationCodeNotSupported,
errMessage: "operation code is not supported",
},
"unknown": {
resultCode: 6,
err: ErrResultCodeUnknown,
errMessage: "result code is unknown: 6",
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()

err := checkResultCode(testCase.resultCode)

assert.ErrorIs(t, err, testCase.err)
if testCase.err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}
}
28 changes: 28 additions & 0 deletions internal/natpmp/externaladdress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package natpmp

import (
"context"
"encoding/binary"
"fmt"
"net/netip"
"time"
)

// ExternalAddress fetches the duration since the start of epoch and the external
// IPv4 address of the gateway.
// See https://www.ietf.org/rfc/rfc6886.html#section-3.2
func (c *Client) ExternalAddress(ctx context.Context, gateway netip.Addr) (
durationSinceStartOfEpoch time.Duration,
externalIPv4Address netip.Addr, err error) {
request := []byte{0, 0} // version 0, operationCode 0
const responseSize = 12
response, err := c.rpc(ctx, gateway, request, responseSize)
if err != nil {
return 0, externalIPv4Address, fmt.Errorf("executing remote procedure call: %w", err)
}

secondsSinceStartOfEpoch := binary.BigEndian.Uint32(response[4:8])
durationSinceStartOfEpoch = time.Duration(secondsSinceStartOfEpoch) * time.Second
externalIPv4Address = netip.AddrFrom4([4]byte{response[8], response[9], response[10], response[11]})
return durationSinceStartOfEpoch, externalIPv4Address, nil
}
71 changes: 71 additions & 0 deletions internal/natpmp/externaladdress_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package natpmp

import (
"context"
"net/netip"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func Test_Client_ExternalAddress(t *testing.T) {
t.Parallel()

canceledCtx, cancel := context.WithCancel(context.Background())
cancel()

testCases := map[string]struct {
ctx context.Context
gateway netip.Addr
initialRetry time.Duration
exchanges []udpExchange
durationSinceStartOfEpoch time.Duration
externalIPv4Address netip.Addr
err error
errMessage string
}{
"failure": {
ctx: canceledCtx,
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
initialRetry: time.Millisecond,
err: context.Canceled,
errMessage: "executing remote procedure call: reading from udp connection: context canceled",
},
"success": {
ctx: context.Background(),
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}),
initialRetry: time.Millisecond,
exchanges: []udpExchange{{
request: []byte{0, 0},
response: []byte{0x0, 0x80, 0x0, 0x0, 0x0, 0x13, 0xf2, 0x4f, 0x49, 0x8c, 0x36, 0x9a},
}},
durationSinceStartOfEpoch: time.Duration(0x13f24f) * time.Second,
externalIPv4Address: netip.AddrFrom4([4]byte{0x49, 0x8c, 0x36, 0x9a}),
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()

remoteAddress := launchUDPServer(t, testCase.exchanges)

client := Client{
serverPort: uint16(remoteAddress.Port),
initialRetry: testCase.initialRetry,
maxRetries: 1,
}

durationSinceStartOfEpoch, externalIPv4Address, err :=
client.ExternalAddress(testCase.ctx, testCase.gateway)
assert.ErrorIs(t, err, testCase.err)
if testCase.err != nil {
assert.EqualError(t, err, testCase.errMessage)
}
assert.Equal(t, testCase.durationSinceStartOfEpoch, durationSinceStartOfEpoch)
assert.Equal(t, testCase.externalIPv4Address, externalIPv4Address)
})
}
}

0 comments on commit 8ad16cd

Please sign in to comment.