-
-
Notifications
You must be signed in to change notification settings - Fork 325
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(protonvpn): port forwarding support with NAT-PMP (#1543)
Co-authored-by: Nicholas Xavier <nicho@nicho.dev>
- Loading branch information
Showing
19 changed files
with
1,118 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
package natpmp | ||
|
||
import ( | ||
"encoding/binary" | ||
"errors" | ||
"fmt" | ||
) | ||
|
||
var ( | ||
ErrRequestSizeTooSmall = errors.New("message size is too small") | ||
) | ||
|
||
func checkRequest(request []byte) (err error) { | ||
const minMessageSize = 2 // version number + operation code | ||
if len(request) < minMessageSize { | ||
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)", | ||
ErrRequestSizeTooSmall, minMessageSize, len(request)) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
var ( | ||
ErrResponseSizeTooSmall = errors.New("response size is too small") | ||
ErrResponseSizeUnexpected = errors.New("response size is unexpected") | ||
ErrProtocolVersionUnknown = errors.New("protocol version is unknown") | ||
ErrOperationCodeUnexpected = errors.New("operation code is unexpected") | ||
) | ||
|
||
func checkResponse(response []byte, expectedOperationCode byte, | ||
expectedResponseSize uint) (err error) { | ||
const minResponseSize = 4 | ||
if len(response) < minResponseSize { | ||
return fmt.Errorf("%w: need at least %d bytes and got %d byte(s)", | ||
ErrResponseSizeTooSmall, minResponseSize, len(response)) | ||
} | ||
|
||
if len(response) != int(expectedResponseSize) { | ||
return fmt.Errorf("%w: expected %d bytes and got %d byte(s)", | ||
ErrResponseSizeUnexpected, expectedResponseSize, len(response)) | ||
} | ||
|
||
protocolVersion := response[0] | ||
if protocolVersion != 0 { | ||
return fmt.Errorf("%w: %d", ErrProtocolVersionUnknown, protocolVersion) | ||
} | ||
|
||
operationCode := response[1] | ||
if operationCode != expectedOperationCode { | ||
return fmt.Errorf("%w: expected 0x%x and got 0x%x", | ||
ErrOperationCodeUnexpected, expectedOperationCode, operationCode) | ||
} | ||
|
||
resultCode := binary.BigEndian.Uint16(response[2:4]) | ||
err = checkResultCode(resultCode) | ||
if err != nil { | ||
return fmt.Errorf("result code: %w", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
var ( | ||
ErrVersionNotSupported = errors.New("version is not supported") | ||
ErrNotAuthorized = errors.New("not authorized") | ||
ErrNetworkFailure = errors.New("network failure") | ||
ErrOutOfResources = errors.New("out of resources") | ||
ErrOperationCodeNotSupported = errors.New("operation code is not supported") | ||
ErrResultCodeUnknown = errors.New("result code is unknown") | ||
) | ||
|
||
// checkResultCode checks the result code and returns an error | ||
// if the result code is not a success (0). | ||
// See https://www.ietf.org/rfc/rfc6886.html#section-3.5 | ||
// | ||
//nolint:gomnd | ||
func checkResultCode(resultCode uint16) (err error) { | ||
switch resultCode { | ||
case 0: | ||
return nil | ||
case 1: | ||
return fmt.Errorf("%w", ErrVersionNotSupported) | ||
case 2: | ||
return fmt.Errorf("%w", ErrNotAuthorized) | ||
case 3: | ||
return fmt.Errorf("%w", ErrNetworkFailure) | ||
case 4: | ||
return fmt.Errorf("%w", ErrOutOfResources) | ||
case 5: | ||
return fmt.Errorf("%w", ErrOperationCodeNotSupported) | ||
default: | ||
return fmt.Errorf("%w: %d", ErrResultCodeUnknown, resultCode) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
package natpmp | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func Test_checkRequest(t *testing.T) { | ||
t.Parallel() | ||
|
||
testCases := map[string]struct { | ||
request []byte | ||
err error | ||
errMessage string | ||
}{ | ||
"too_short": { | ||
request: []byte{1}, | ||
err: ErrRequestSizeTooSmall, | ||
errMessage: "message size is too small: need at least 2 bytes and got 1 byte(s)", | ||
}, | ||
"success": { | ||
request: []byte{0, 0}, | ||
}, | ||
} | ||
|
||
for name, testCase := range testCases { | ||
testCase := testCase | ||
t.Run(name, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
err := checkRequest(testCase.request) | ||
|
||
assert.ErrorIs(t, err, testCase.err) | ||
if testCase.err != nil { | ||
assert.EqualError(t, err, testCase.errMessage) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func Test_checkResponse(t *testing.T) { | ||
t.Parallel() | ||
|
||
testCases := map[string]struct { | ||
response []byte | ||
expectedOperationCode byte | ||
expectedResponseSize uint | ||
err error | ||
errMessage string | ||
}{ | ||
"too_short": { | ||
response: []byte{1}, | ||
err: ErrResponseSizeTooSmall, | ||
errMessage: "response size is too small: need at least 4 bytes and got 1 byte(s)", | ||
}, | ||
"size_mismatch": { | ||
response: []byte{0, 0, 0, 0}, | ||
expectedResponseSize: 5, | ||
err: ErrResponseSizeUnexpected, | ||
errMessage: "response size is unexpected: expected 5 bytes and got 4 byte(s)", | ||
}, | ||
"protocol_unknown": { | ||
response: []byte{1, 0, 0, 0}, | ||
expectedResponseSize: 4, | ||
err: ErrProtocolVersionUnknown, | ||
errMessage: "protocol version is unknown: 1", | ||
}, | ||
"operation_code_unexpected": { | ||
response: []byte{0, 2, 0, 0}, | ||
expectedOperationCode: 1, | ||
expectedResponseSize: 4, | ||
err: ErrOperationCodeUnexpected, | ||
errMessage: "operation code is unexpected: expected 0x1 and got 0x2", | ||
}, | ||
"result_code_failure": { | ||
response: []byte{0, 1, 0, 1}, | ||
expectedOperationCode: 1, | ||
expectedResponseSize: 4, | ||
err: ErrVersionNotSupported, | ||
errMessage: "result code: version is not supported", | ||
}, | ||
"success": { | ||
response: []byte{0, 1, 0, 0}, | ||
expectedOperationCode: 1, | ||
expectedResponseSize: 4, | ||
}, | ||
} | ||
|
||
for name, testCase := range testCases { | ||
testCase := testCase | ||
t.Run(name, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
err := checkResponse(testCase.response, | ||
testCase.expectedOperationCode, | ||
testCase.expectedResponseSize) | ||
|
||
assert.ErrorIs(t, err, testCase.err) | ||
if testCase.err != nil { | ||
assert.EqualError(t, err, testCase.errMessage) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func Test_checkResultCode(t *testing.T) { | ||
t.Parallel() | ||
|
||
testCases := map[string]struct { | ||
resultCode uint16 | ||
err error | ||
errMessage string | ||
}{ | ||
"success": {}, | ||
"version_unsupported": { | ||
resultCode: 1, | ||
err: ErrVersionNotSupported, | ||
errMessage: "version is not supported", | ||
}, | ||
"not_authorized": { | ||
resultCode: 2, | ||
err: ErrNotAuthorized, | ||
errMessage: "not authorized", | ||
}, | ||
"network_failure": { | ||
resultCode: 3, | ||
err: ErrNetworkFailure, | ||
errMessage: "network failure", | ||
}, | ||
"out_of_resources": { | ||
resultCode: 4, | ||
err: ErrOutOfResources, | ||
errMessage: "out of resources", | ||
}, | ||
"unsupported_operation_code": { | ||
resultCode: 5, | ||
err: ErrOperationCodeNotSupported, | ||
errMessage: "operation code is not supported", | ||
}, | ||
"unknown": { | ||
resultCode: 6, | ||
err: ErrResultCodeUnknown, | ||
errMessage: "result code is unknown: 6", | ||
}, | ||
} | ||
|
||
for name, testCase := range testCases { | ||
testCase := testCase | ||
t.Run(name, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
err := checkResultCode(testCase.resultCode) | ||
|
||
assert.ErrorIs(t, err, testCase.err) | ||
if testCase.err != nil { | ||
assert.EqualError(t, err, testCase.errMessage) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package natpmp | ||
|
||
import ( | ||
"context" | ||
"encoding/binary" | ||
"fmt" | ||
"net/netip" | ||
"time" | ||
) | ||
|
||
// ExternalAddress fetches the duration since the start of epoch and the external | ||
// IPv4 address of the gateway. | ||
// See https://www.ietf.org/rfc/rfc6886.html#section-3.2 | ||
func (c *Client) ExternalAddress(ctx context.Context, gateway netip.Addr) ( | ||
durationSinceStartOfEpoch time.Duration, | ||
externalIPv4Address netip.Addr, err error) { | ||
request := []byte{0, 0} // version 0, operationCode 0 | ||
const responseSize = 12 | ||
response, err := c.rpc(ctx, gateway, request, responseSize) | ||
if err != nil { | ||
return 0, externalIPv4Address, fmt.Errorf("executing remote procedure call: %w", err) | ||
} | ||
|
||
secondsSinceStartOfEpoch := binary.BigEndian.Uint32(response[4:8]) | ||
durationSinceStartOfEpoch = time.Duration(secondsSinceStartOfEpoch) * time.Second | ||
externalIPv4Address = netip.AddrFrom4([4]byte{response[8], response[9], response[10], response[11]}) | ||
return durationSinceStartOfEpoch, externalIPv4Address, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package natpmp | ||
|
||
import ( | ||
"context" | ||
"net/netip" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func Test_Client_ExternalAddress(t *testing.T) { | ||
t.Parallel() | ||
|
||
canceledCtx, cancel := context.WithCancel(context.Background()) | ||
cancel() | ||
|
||
testCases := map[string]struct { | ||
ctx context.Context | ||
gateway netip.Addr | ||
initialRetry time.Duration | ||
exchanges []udpExchange | ||
durationSinceStartOfEpoch time.Duration | ||
externalIPv4Address netip.Addr | ||
err error | ||
errMessage string | ||
}{ | ||
"failure": { | ||
ctx: canceledCtx, | ||
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), | ||
initialRetry: time.Millisecond, | ||
err: context.Canceled, | ||
errMessage: "executing remote procedure call: reading from udp connection: context canceled", | ||
}, | ||
"success": { | ||
ctx: context.Background(), | ||
gateway: netip.AddrFrom4([4]byte{127, 0, 0, 1}), | ||
initialRetry: time.Millisecond, | ||
exchanges: []udpExchange{{ | ||
request: []byte{0, 0}, | ||
response: []byte{0x0, 0x80, 0x0, 0x0, 0x0, 0x13, 0xf2, 0x4f, 0x49, 0x8c, 0x36, 0x9a}, | ||
}}, | ||
durationSinceStartOfEpoch: time.Duration(0x13f24f) * time.Second, | ||
externalIPv4Address: netip.AddrFrom4([4]byte{0x49, 0x8c, 0x36, 0x9a}), | ||
}, | ||
} | ||
|
||
for name, testCase := range testCases { | ||
testCase := testCase | ||
t.Run(name, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
remoteAddress := launchUDPServer(t, testCase.exchanges) | ||
|
||
client := Client{ | ||
serverPort: uint16(remoteAddress.Port), | ||
initialRetry: testCase.initialRetry, | ||
maxRetries: 1, | ||
} | ||
|
||
durationSinceStartOfEpoch, externalIPv4Address, err := | ||
client.ExternalAddress(testCase.ctx, testCase.gateway) | ||
assert.ErrorIs(t, err, testCase.err) | ||
if testCase.err != nil { | ||
assert.EqualError(t, err, testCase.errMessage) | ||
} | ||
assert.Equal(t, testCase.durationSinceStartOfEpoch, durationSinceStartOfEpoch) | ||
assert.Equal(t, testCase.externalIPv4Address, externalIPv4Address) | ||
}) | ||
} | ||
} |
Oops, something went wrong.