-
-
Notifications
You must be signed in to change notification settings - Fork 334
/
pick.go
59 lines (48 loc) · 1.79 KB
/
pick.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
package utils
import (
"errors"
"fmt"
"math/rand"
"net/netip"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
)
// ErrNoConnectionToPickFrom is returned by pickConnection when the
// slice of connections it is given is empty.
var ErrNoConnectionToPickFrom = errors.New("no connection to pick from")
// pickConnection selects a single connection out of the given pool.
//
// A target IP counts as set when selection.TargetIP is valid and not
// the unspecified address. With a target IP set and Wireguard as the
// VPN type, the connection whose IP equals the target IP is looked up,
// because the right server public key is required. In all other cases
// a connection is drawn at random using randSource. If a target IP is
// set, the drawn connection's IP is then overwritten with it.
//
// ErrNoConnectionToPickFrom is returned when connections is empty.
func pickConnection(connections []models.Connection,
	selection settings.ServerSelection, randSource rand.Source) (
	connection models.Connection, err error) {
	if len(connections) == 0 {
		return models.Connection{}, ErrNoConnectionToPickFrom
	}

	targetIP := selection.TargetIP
	hasTargetIP := targetIP.IsValid() && !targetIP.IsUnspecified()

	switch {
	case hasTargetIP && selection.VPN == vpn.Wireguard:
		// Wireguard needs the public key of that exact server, so the
		// matching connection is searched for instead of drawn at random.
		return getTargetIPConnection(connections, targetIP)
	case hasTargetIP:
		connection = pickRandomConnection(connections, randSource)
		connection.IP = targetIP
	default:
		connection = pickRandomConnection(connections, randSource)
	}
	return connection, nil
}
// pickRandomConnection returns one element of connections, chosen by
// a *rand.Rand wrapping the given source.
// connections must not be empty: Intn panics when its argument is <= 0.
func pickRandomConnection(connections []models.Connection,
	source rand.Source) models.Connection {
	generator := rand.New(source) //nolint:gosec
	index := generator.Intn(len(connections))
	return connections[index]
}
// errTargetIPNotFound is wrapped by getTargetIPConnection when no
// connection has an IP address equal to the requested target IP.
var errTargetIPNotFound = errors.New("target IP address not found")

// getTargetIPConnection returns the first connection whose IP address
// equals targetIP.
// Matching uses netip.Addr equality, so an IPv4 address does not match
// its IPv4-mapped IPv6 form.
// If no connection matches, it returns a zero-value connection and an
// error wrapping errTargetIPNotFound that reports how many connections
// were searched.
func getTargetIPConnection(connections []models.Connection,
	targetIP netip.Addr) (connection models.Connection, err error) {
	// The loop variable is deliberately not named `connection`: that
	// would shadow the named result and make the not-found return below
	// depend implicitly on the named result never having been assigned.
	for _, candidate := range connections {
		if candidate.IP == targetIP {
			return candidate, nil
		}
	}
	return models.Connection{}, fmt.Errorf("%w: in %d filtered connections",
		errTargetIPNotFound, len(connections))
}