Skip to content

Commit

Permalink
wireguard: be able to bind to addresses.
Browse files Browse the repository at this point in the history
This allows wireguard to bind to a specific IPv4:port or IPv6:port, in
addition to the default `[::]:port` where all the IPv4 and IPv6
addresses are listened on.

It addes a new `WGDEVICE_A_BIND_ADDR` field to the netlink interface
to the userspace.  The address family could either be IPv4 or IPv6 and
is ultimately identified in the kernel space.

Signed-off-by: FireflyTang <Tang.Rulin.Phys@gmail.com>
Signed-off-by: Benda Xu <heroxbd@gentoo.org>
  • Loading branch information
FireflyTang committed Oct 21, 2020
1 parent c4d6fe7 commit 5fa9808
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 27 deletions.
4 changes: 3 additions & 1 deletion drivers/net/wireguard/device.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ static int wg_open(struct net_device *dev)
dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE;

mutex_lock(&wg->device_update_lock);
ret = wg_socket_init(wg, wg->incoming_port);
ret = wg_socket_init(wg, &wg->bind_addr, wg->incoming_port);
if (ret < 0)
goto out;
list_for_each_entry(peer, &wg->peer_list, peer_list) {
Expand Down Expand Up @@ -228,6 +228,7 @@ static void wg_destruct(struct net_device *dev)
mutex_lock(&wg->device_update_lock);
rcu_assign_pointer(wg->creating_net, NULL);
wg->incoming_port = 0;
memset(&wg->bind_addr, 0, sizeof(struct addr_struct));
wg_socket_reinit(wg, NULL, NULL);
/* The final references are cleared in the below calls to destroy_workqueue. */
wg_peer_remove_all(wg);
Expand Down Expand Up @@ -302,6 +303,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
wg_cookie_checker_init(&wg->cookie_checker, wg);
INIT_LIST_HEAD(&wg->peer_list);
wg->device_update_gen = 1;
memset(&wg->bind_addr, 0, sizeof(struct addr_struct));

wg->peer_hashtable = wg_pubkey_hashtable_alloc();
if (!wg->peer_hashtable)
Expand Down
9 changes: 9 additions & 0 deletions drivers/net/wireguard/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ struct crypt_queue {
};
};

struct addr_struct {
union {
struct sockaddr addr;
struct sockaddr_in addr4;
struct sockaddr_in6 addr6;
};
};

struct wg_device {
struct net_device *dev;
struct crypt_queue encrypt_queue, decrypt_queue;
Expand All @@ -56,6 +64,7 @@ struct wg_device {
unsigned int num_peers, device_update_gen;
u32 fwmark;
u16 incoming_port;
struct addr_struct bind_addr;
};

int wg_device_init(void);
Expand Down
85 changes: 72 additions & 13 deletions drivers/net/wireguard/netlink.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = {
[WGDEVICE_A_FLAGS] = { .type = NLA_U32 },
[WGDEVICE_A_LISTEN_PORT] = { .type = NLA_U16 },
[WGDEVICE_A_FWMARK] = { .type = NLA_U32 },
[WGDEVICE_A_PEERS] = { .type = NLA_NESTED }
[WGDEVICE_A_PEERS] = { .type = NLA_NESTED },
[WGDEVICE_A_BIND_ADDR] = { .type = NLA_BINARY, .len = sizeof(struct sockaddr_in6) }
};

static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = {
Expand Down Expand Up @@ -236,6 +237,18 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name))
goto out;

if (wg->bind_addr.addr.sa_family == AF_INET) {
if (nla_put(skb, WGDEVICE_A_BIND_ADDR,
sizeof(struct sockaddr_in),
&wg->bind_addr.addr4))
goto out;
} else if (wg->bind_addr.addr.sa_family == AF_INET6) {
if (nla_put(skb, WGDEVICE_A_BIND_ADDR,
sizeof(struct sockaddr_in6),
&wg->bind_addr.addr6))
goto out;
}

down_read(&wg->static_identity.lock);
if (wg->static_identity.has_identity) {
if (nla_put(skb, WGDEVICE_A_PRIVATE_KEY,
Expand Down Expand Up @@ -311,19 +324,12 @@ static int wg_get_device_done(struct netlink_callback *cb)
return 0;
}

static int set_port(struct wg_device *wg, u16 port)
static int set_socket(struct wg_device *wg, struct addr_struct *bind_addr, u16 port)
{
struct wg_peer *peer;

if (wg->incoming_port == port)
return 0;
list_for_each_entry(peer, &wg->peer_list, peer_list)
wg_socket_clear_peer_endpoint_src(peer);
if (!netif_running(wg->dev)) {
wg->incoming_port = port;
return 0;
}
return wg_socket_init(wg, port);
return wg_socket_init(wg, bind_addr, port);
}

static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
Expand Down Expand Up @@ -491,11 +497,25 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs)
return ret;
}

static bool bind_addr_eq(const struct sockaddr *a, const struct sockaddr *b)
{
return (a->sa_family == AF_INET && b->sa_family == AF_INET &&
((struct sockaddr_in *)a)->sin_addr.s_addr ==
((struct sockaddr_in *)b)->sin_addr.s_addr) ||
(a->sa_family == AF_INET6 && b->sa_family == AF_INET6 &&
ipv6_addr_equal(&((struct sockaddr_in6 *)a)->sin6_addr,
&((struct sockaddr_in6 *)b)->sin6_addr)) ||
(a->sa_family == AF_UNSPEC && b->sa_family == AF_UNSPEC);
}

static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
{
struct wg_device *wg = lookup_interface(info->attrs, skb);
u32 flags = 0;
int ret;
bool if_set_socket = false;
u16 port_new;
struct addr_struct bind_addr_new;

if (IS_ERR(wg)) {
ret = PTR_ERR(wg);
Expand All @@ -505,13 +525,18 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
rtnl_lock();
mutex_lock(&wg->device_update_lock);

port_new = wg->incoming_port;
bind_addr_new = wg->bind_addr;

if (info->attrs[WGDEVICE_A_FLAGS])
flags = nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]);
ret = -EOPNOTSUPP;
if (flags & ~__WGDEVICE_F_ALL)
goto out;

if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
if (info->attrs[WGDEVICE_A_LISTEN_PORT] ||
info->attrs[WGDEVICE_A_BIND_ADDR] ||
info->attrs[WGDEVICE_A_FWMARK]) {
struct net *net;
rcu_read_lock();
net = rcu_dereference(wg->creating_net);
Expand All @@ -532,8 +557,42 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
}

if (info->attrs[WGDEVICE_A_LISTEN_PORT]) {
ret = set_port(wg,
nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT]));
u16 port = nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT]);

if (wg->incoming_port != port) {
if (!netif_running(wg->dev)) {
wg->incoming_port = port;
} else {
if_set_socket = true;
port_new = port;
}
}
}

if (info->attrs[WGDEVICE_A_BIND_ADDR]) {
struct sockaddr *addr = nla_data(info->attrs[WGDEVICE_A_BIND_ADDR]);

if (!bind_addr_eq(&wg->bind_addr.addr, addr)) {
size_t len = nla_len(info->attrs[WGDEVICE_A_BIND_ADDR]);
if ((addr->sa_family == AF_INET &&
len == sizeof(struct sockaddr_in)) ||
((addr->sa_family == AF_INET6 ||
addr->sa_family == AF_UNSPEC) &&
len == sizeof(struct sockaddr_in6))) {
if (!netif_running(wg->dev)) {
memcpy(&wg->bind_addr.addr, addr, len);
} else {
if_set_socket = true;
memcpy(&bind_addr_new, addr, len);
}
} else {
goto out;
}
}
}

if (if_set_socket) {
ret = set_socket(wg, &bind_addr_new, port_new);
if (ret)
goto out;
}
Expand Down
43 changes: 31 additions & 12 deletions drivers/net/wireguard/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ static void set_sock_opts(struct socket *sock)
sk_set_memalloc(sock->sk);
}

int wg_socket_init(struct wg_device *wg, u16 port)
int wg_socket_init(struct wg_device *wg, struct addr_struct *bind_addr, u16 port)
{
struct net *net;
int ret;
Expand All @@ -366,6 +366,7 @@ int wg_socket_init(struct wg_device *wg, u16 port)
struct udp_port_cfg port6 = {
.family = AF_INET6,
.local_ip6 = IN6ADDR_ANY_INIT,
.local_udp_port = htons(port),
.use_udp6_tx_checksums = true,
.use_udp6_rx_checksums = true,
.ipv6_v6only = true
Expand All @@ -383,20 +384,30 @@ int wg_socket_init(struct wg_device *wg, u16 port)
retry:
#endif

ret = udp_sock_create(net, &port4, &new4);
if (ret < 0) {
pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
goto out;
if (bind_addr->addr.sa_family == AF_INET ||
bind_addr->addr.sa_family == AF_UNSPEC) {
if (bind_addr->addr.sa_family == AF_INET)
port4.local_ip = bind_addr->addr4.sin_addr;
ret = udp_sock_create(net, &port4, &new4);
if (ret < 0) {
pr_err("%s: Could not create IPv4 socket\n",
wg->dev->name);
goto out;
}
set_sock_opts(new4);
setup_udp_tunnel_sock(net, new4, &cfg);
}
set_sock_opts(new4);
setup_udp_tunnel_sock(net, new4, &cfg);

#if IS_ENABLED(CONFIG_IPV6)
if (ipv6_mod_enabled()) {
port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
if (ipv6_mod_enabled() &&
(bind_addr->addr.sa_family == AF_INET6 ||
bind_addr->addr.sa_family == AF_UNSPEC)) {
if (bind_addr->addr.sa_family == AF_INET6)
port6.local_ip6 = bind_addr->addr6.sin6_addr;
ret = udp_sock_create(net, &port6, &new6);
if (ret < 0) {
udp_tunnel_sock_release(new4);
if (new4)
udp_tunnel_sock_release(new4);
if (ret == -EADDRINUSE && !port && retries++ < 100)
goto retry;
pr_err("%s: Could not create IPv6 socket\n",
Expand All @@ -408,7 +419,7 @@ int wg_socket_init(struct wg_device *wg, u16 port)
}
#endif

wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL);
wg_socket_reinit(wg, new4 ? new4->sk : NULL, new6 ? new6->sk : NULL);
ret = 0;
out:
put_net(net);
Expand All @@ -427,8 +438,16 @@ void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
lockdep_is_held(&wg->socket_update_lock));
rcu_assign_pointer(wg->sock4, new4);
rcu_assign_pointer(wg->sock6, new6);
if (new4)
if (new4) {
wg->incoming_port = ntohs(inet_sk(new4)->inet_sport);
wg->bind_addr.addr4.sin_addr.s_addr = inet_sk(new4)->inet_saddr;
wg->bind_addr.addr.sa_family = new6 ? AF_UNSPEC : AF_INET;
} else if (new6) {
wg->incoming_port = ntohs(inet_sk(new6)->inet_sport);
memcpy(&wg->bind_addr.addr6.sin6_addr,
&inet6_sk(new6)->saddr, sizeof(struct in6_addr));
wg->bind_addr.addr.sa_family = AF_INET6;
}
mutex_unlock(&wg->socket_update_lock);
synchronize_rcu();
sock_free(old4);
Expand Down
2 changes: 1 addition & 1 deletion drivers/net/wireguard/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <linux/if_vlan.h>
#include <linux/if_ether.h>

int wg_socket_init(struct wg_device *wg, u16 port);
int wg_socket_init(struct wg_device *wg, struct addr_struct *bind_addr, u16 port);
void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
struct sock *new6);
int wg_socket_send_buffer_to_peer(struct wg_peer *peer, void *data,
Expand Down
3 changes: 3 additions & 0 deletions include/uapi/linux/wireguard.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
* 0: NLA_NESTED
* ...
* ...
* WGDEVICE_A_BIND_ADDR: NLA_BINARY(struct sockaddr_in6)
*
* It is possible that all of the allowed IPs of a single peer will not
* fit within a single netlink message. In that case, the same peer will
Expand Down Expand Up @@ -114,6 +115,7 @@
* 0: NLA_NESTED
* ...
* ...
* WGDEVICE_A_BIND_ADDR: NLA_BINARY(struct sockaddr_in6)
*
* It is possible that the amount of configuration data exceeds that of
* the maximum message length accepted by the kernel. In that case, several
Expand Down Expand Up @@ -157,6 +159,7 @@ enum wgdevice_attribute {
WGDEVICE_A_LISTEN_PORT,
WGDEVICE_A_FWMARK,
WGDEVICE_A_PEERS,
WGDEVICE_A_BIND_ADDR,
__WGDEVICE_A_LAST
};
#define WGDEVICE_A_MAX (__WGDEVICE_A_LAST - 1)
Expand Down

0 comments on commit 5fa9808

Please sign in to comment.