@@ -12,6 +12,7 @@ import (
12
12
"os"
13
13
14
14
"github.com/jsimonetti/rtnetlink/rtnl"
15
+ "go.uber.org/zap"
15
16
"golang.zx2c4.com/wireguard/conn"
16
17
"golang.zx2c4.com/wireguard/device"
17
18
"golang.zx2c4.com/wireguard/ipc"
@@ -56,7 +57,7 @@ func NewDevice(address netaddr.IPPrefix, privateKey wgtypes.Key, listenPort uint
56
57
}
57
58
58
59
// Run the device.
59
- func (dev * Device ) Run (ctx context.Context , peers PeerSource ) error {
60
+ func (dev * Device ) Run (ctx context.Context , logger * zap. Logger , peers PeerSource ) error {
60
61
client , err := wgctrl .New ()
61
62
if err != nil {
62
63
return fmt .Errorf ("error initializing Wireguard client: %w" , err )
@@ -71,10 +72,10 @@ func (dev *Device) Run(ctx context.Context, peers PeerSource) error {
71
72
72
73
defer rtnlClient .Close () //nolint:errcheck
73
74
74
- logger := device .NewLogger (
75
- device . LogLevelVerbose ,
76
- fmt . Sprintf ( "(%s) " , interfaceName ) ,
77
- )
75
+ wgLogger := & device.Logger {
76
+ Verbosef : logger . Sugar (). Debugf ,
77
+ Errorf : logger . Sugar (). Errorf ,
78
+ }
78
79
79
80
uapi , err := ipc .UAPIListen (interfaceName , dev .fileUAPI )
80
81
if err != nil {
@@ -83,7 +84,7 @@ func (dev *Device) Run(ctx context.Context, peers PeerSource) error {
83
84
84
85
defer uapi .Close () //nolint:errcheck
85
86
86
- device := device .NewDevice (dev .tun , conn .NewDefaultBind (), logger )
87
+ device := device .NewDevice (dev .tun , conn .NewDefaultBind (), wgLogger )
87
88
88
89
defer device .Close ()
89
90
@@ -124,6 +125,8 @@ func (dev *Device) Run(ctx context.Context, peers PeerSource) error {
124
125
return fmt .Errorf ("error bringing link up: %w" , err )
125
126
}
126
127
128
+ logger .Info ("wireguard device set up" , zap .String ("interface" , interfaceName ), zap .Stringer ("address" , dev .address ))
129
+
127
130
for {
128
131
select {
129
132
case <- ctx .Done ():
@@ -133,24 +136,81 @@ func (dev *Device) Run(ctx context.Context, peers PeerSource) error {
133
136
case <- device .Wait ():
134
137
return nil
135
138
case peerEvent := <- peers .EventCh ():
136
- cfg := wgtypes.Config {
137
- Peers : []wgtypes.PeerConfig {
138
- {
139
- PublicKey : peerEvent .PubKey ,
140
- Remove : peerEvent .Remove ,
141
- ReplaceAllowedIPs : true ,
142
- AllowedIPs : []net.IPNet {
143
- * netaddr .IPPrefixFrom (peerEvent .Address , peerEvent .Address .BitLen ()).IPNet (),
144
- },
145
- },
146
- },
139
+ if err := dev .handlePeerEvent (client , logger , peerEvent ); err != nil {
140
+ return err
141
+ }
142
+ }
143
+ }
144
+ }
145
+
146
+ func (dev * Device ) checkDuplicateUpdate (client * wgctrl.Client , logger * zap.Logger , peerEvent PeerEvent ) (bool , error ) {
147
+ oldCfg , err := client .Device (interfaceName )
148
+ if err != nil {
149
+ return false , fmt .Errorf ("error retrieving Wireguard configuration: %w" , err )
150
+ }
151
+
152
+ // check if this update can be skipped
153
+ pubKey := peerEvent .PubKey .String ()
154
+
155
+ for _ , oldPeer := range oldCfg .Peers {
156
+ if oldPeer .PublicKey .String () == pubKey {
157
+ if len (oldPeer .AllowedIPs ) != 1 {
158
+ break
147
159
}
148
160
149
- if err = client .ConfigureDevice (interfaceName , cfg ); err != nil {
150
- return fmt .Errorf ("error configuring Wireguard peers: %w" , err )
161
+ if prefix , ok := netaddr .FromStdIPNet (& oldPeer .AllowedIPs [0 ]); ok {
162
+ if prefix .IP () == peerEvent .Address {
163
+ // skip the update
164
+ logger .Info ("skipping peer update" , zap .String ("public_key" , pubKey ))
165
+
166
+ return true , nil
167
+ }
151
168
}
169
+
170
+ break
152
171
}
153
172
}
173
+
174
+ return false , nil
175
+ }
176
+
177
+ func (dev * Device ) handlePeerEvent (client * wgctrl.Client , logger * zap.Logger , peerEvent PeerEvent ) error {
178
+ if ! peerEvent .Remove {
179
+ skipEvent , err := dev .checkDuplicateUpdate (client , logger , peerEvent )
180
+ if err != nil {
181
+ return err
182
+ }
183
+
184
+ if skipEvent {
185
+ return nil
186
+ }
187
+ }
188
+
189
+ cfg := wgtypes.Config {
190
+ Peers : []wgtypes.PeerConfig {
191
+ {
192
+ PublicKey : peerEvent .PubKey ,
193
+ Remove : peerEvent .Remove ,
194
+ },
195
+ },
196
+ }
197
+
198
+ if ! peerEvent .Remove {
199
+ cfg .Peers [0 ].ReplaceAllowedIPs = true
200
+ cfg .Peers [0 ].AllowedIPs = []net.IPNet {
201
+ * netaddr .IPPrefixFrom (peerEvent .Address , peerEvent .Address .BitLen ()).IPNet (),
202
+ }
203
+
204
+ logger .Info ("updating peer" , zap .Stringer ("public_key" , peerEvent .PubKey ), zap .Stringer ("address" , peerEvent .Address ))
205
+ } else {
206
+ logger .Info ("removing peer" , zap .Stringer ("public_key" , peerEvent .PubKey ))
207
+ }
208
+
209
+ if err := client .ConfigureDevice (interfaceName , cfg ); err != nil {
210
+ return fmt .Errorf ("error configuring Wireguard peers: %w" , err )
211
+ }
212
+
213
+ return nil
154
214
}
155
215
156
216
// Close the device.
0 commit comments