Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions disco/udp/stun.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ func (rt *stunRoundTripper) roundTrip(ctx context.Context, udpConn *net.UDPConn,
rt.init()
txID := stun.NewTxID()
ch := make(chan stunResponse)
defer close(ch)
rt.stunResponseMapMutex.Lock()
rt.stunResponseMap[string(txID[:])] = ch
rt.stunResponseMapMutex.Unlock()
Expand Down Expand Up @@ -87,7 +86,6 @@ func (c *stunRoundTripper) recvResponse(b []byte, peerAddr net.Addr) {
return
}
resp := stunResponse{txid: string(txid[:]), addr: addr}
defer func() { recover() }()
select {
case r <- resp:
default:
Expand Down
3 changes: 0 additions & 3 deletions disco/ws/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,6 @@ func (c *WSConn) Write(p []byte) (n int, err error) {
func (c *WSConn) Close() error {
c.closed.Store(true)
close(c.closedSig)
close(c.datagrams)
close(c.events)
close(c.connData)
close(c.connEOF)
if conn := c.rawConn.Load(); conn != nil {
_ = conn.WriteControl(websocket.CloseMessage,
Expand Down
19 changes: 12 additions & 7 deletions netlink/addr_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,27 @@ func AddrSubscribe(ctx context.Context, ch chan<- AddrUpdate) error {
return fmt.Errorf("syscall socket: %w", err)
}
go func() {
err := runAddrMsgReadLoop(fd, ch)
err := runAddrMsgReadLoop(ctx, fd, ch)
if err != nil {
slog.Error("AddrSubscribe", "err", fmt.Errorf("msg read loop exited: %w", err))
}
}()
go func() {
<-ctx.Done()
syscall.Close(fd)
close(ch)
}()
return nil
}

func runAddrMsgReadLoop(fd int, ch chan<- AddrUpdate) error {
func runAddrMsgReadLoop(ctx context.Context, fd int, ch chan<- AddrUpdate) error {
buf := make([]byte, os.Getpagesize())
for {
select {
case <-ctx.Done():
err := syscall.Close(fd)
if err != nil {
slog.Error("runAddrMsgReadLoop", "err", fmt.Errorf("syscall close: %w", err))
}
close(ch)
return nil
default:
}
n, err := syscall.Read(fd, buf)
if err != nil {
return fmt.Errorf("syscall read: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion netlink/addr_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ func AddrSubscribe(ctx context.Context, ch chan<- AddrUpdate) error {
return err
}
go func() {
defer close(ch)
for {
select {
case <-ctx.Done():
close(ch)
return
case e := <-rawChan:
ch <- AddrUpdate{
Expand Down
19 changes: 12 additions & 7 deletions netlink/route_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,27 @@ func RouteSubscribe(ctx context.Context, ch chan<- RouteUpdate) error {
return fmt.Errorf("syscall socket: %w", err)
}
go func() {
err := runRouteMsgReadLoop(fd, ch)
err := runRouteMsgReadLoop(ctx, fd, ch)
if err != nil {
slog.Error("RouteSubscribe", "err", fmt.Errorf("msg read loop exited: %w", err))
}
}()
go func() {
<-ctx.Done()
syscall.Close(fd)
close(ch)
}()
return nil
}

func runRouteMsgReadLoop(fd int, ch chan<- RouteUpdate) error {
func runRouteMsgReadLoop(ctx context.Context, fd int, ch chan<- RouteUpdate) error {
buf := make([]byte, os.Getpagesize())
for {
select {
case <-ctx.Done():
err := syscall.Close(fd)
if err != nil {
slog.Error("runRouteMsgReadLoop", "err", fmt.Errorf("syscall close: %w", err))
}
close(ch)
return nil
default:
}
n, err := syscall.Read(fd, buf)
if err != nil {
return fmt.Errorf("syscall read: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion netlink/route_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ func RouteSubscribe(ctx context.Context, ch chan<- RouteUpdate) error {
return err
}
go func() {
defer close(ch)
for {
select {
case <-ctx.Done():
close(ch)
return
case e := <-rawChan:
if e.Dst == nil || e.Gw == nil {
Expand Down
6 changes: 3 additions & 3 deletions p2p/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
func (c *PacketConn) Close() error {
c.closeOnce.Do(func() {
close(c.closeChan)
c.deadlineRead.Close()
c.udpConn.Close()
c.wsConn.Close()
_ = c.deadlineRead.Close()
_ = c.udpConn.Close()
_ = c.wsConn.Close()
})
return nil
}
Expand Down
17 changes: 10 additions & 7 deletions peermap/peermap.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,17 @@ func (p *peerConn) Read(b []byte) (n int, err error) {
return
}

wsb, ok := <-p.connData
if !ok {
select {
case <-p.closeChan:
return 0, io.EOF
}
n = copy(b, wsb)
if n < len(wsb) {
p.connBuf = wsb[n:]
case wsb, ok := <-p.connData:
if !ok {
return 0, io.EOF
}
n = copy(b, wsb)
if n < len(wsb) {
p.connBuf = wsb[n:]
}
}
return
}
Expand All @@ -115,7 +119,6 @@ func (p *peerConn) Close() error {
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(2*time.Second))
p.conn.Close()
close(p.closeChan)
close(p.connData)
p.broadcastLeave()
})
return nil
Expand Down
121 changes: 74 additions & 47 deletions vpn/vpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,17 @@ func (vpn *VPN) Run(ctx context.Context, nic *nic.VirtualNIC, packetConn net.Pac
var wg sync.WaitGroup
wg.Add(5)
go vpn.routingTableUpdate(ctx, &wg)
go vpn.nicRead(&wg, nic)
go vpn.nicWrite(&wg, nic)
go vpn.packetConnRead(&wg, packetConn)
go vpn.packetConnWrite(&wg, packetConn)
go vpn.nicRead(ctx, &wg, nic)
go vpn.nicWrite(ctx, &wg, nic)
go vpn.packetConnRead(ctx, &wg, packetConn)
go vpn.packetConnWrite(ctx, &wg, packetConn)

<-ctx.Done()
packetConn.Close()
nic.Close()
wg.Wait()
close(vpn.inbound)
close(vpn.outbound)
wg.Wait()
return nil
}

Expand All @@ -83,7 +83,7 @@ func (vpn *VPN) routingTableUpdate(ctx context.Context, wg *sync.WaitGroup) {
}

// nicRead read ip packet from nic device and send to outbound channel
func (vpn *VPN) nicRead(wg *sync.WaitGroup, nic *nic.VirtualNIC) {
func (vpn *VPN) nicRead(ctx context.Context, wg *sync.WaitGroup, nic *nic.VirtualNIC) {
defer wg.Done()
for {
packet, err := nic.Read()
Expand All @@ -93,12 +93,17 @@ func (vpn *VPN) nicRead(wg *sync.WaitGroup, nic *nic.VirtualNIC) {
}
panic(err)
}
vpn.outbound <- packet
select {
case <-ctx.Done():
return
case vpn.outbound <- packet:
}

}
}

// nicWrite read ip packet from inbound channel and write to nic device
func (vpn *VPN) nicWrite(wg *sync.WaitGroup, vnic *nic.VirtualNIC) {
func (vpn *VPN) nicWrite(ctx context.Context, wg *sync.WaitGroup, vnic *nic.VirtualNIC) {
defer wg.Done()
handle := func(pkt *nic.Packet) *nic.Packet {
for _, in := range vpn.cfg.InboundHandlers {
Expand All @@ -109,22 +114,30 @@ func (vpn *VPN) nicWrite(wg *sync.WaitGroup, vnic *nic.VirtualNIC) {
}
return pkt
}
for packet := range vpn.inbound {
in := packet
if packet = handle(packet); packet == nil {
nic.RecyclePacket(in)
continue
}
err := vnic.Write(packet)
if err != nil {
slog.Debug("WriteTo nic device", "err", err.Error())
for {
select {
case <-ctx.Done():
return
case packet, ok := <-vpn.inbound:
if !ok {
return
}
in := packet
if packet = handle(packet); packet == nil {
nic.RecyclePacket(in)
continue
}
err := vnic.Write(packet)
if err != nil {
slog.Debug("WriteTo nic device", "err", err.Error())
}
nic.RecyclePacket(packet)
}
nic.RecyclePacket(packet)
}
}

// packetConnRead read ip packet from packet conn and send to inbound channel
func (vpn *VPN) packetConnRead(wg *sync.WaitGroup, packetConn net.PacketConn) {
func (vpn *VPN) packetConnRead(ctx context.Context, wg *sync.WaitGroup, packetConn net.PacketConn) {
defer wg.Done()
buf := make([]byte, cmp.Or(vpn.cfg.MTU, (2<<15)-8-40-40)+40)
for {
Expand All @@ -135,12 +148,17 @@ func (vpn *VPN) packetConnRead(wg *sync.WaitGroup, packetConn net.PacketConn) {
}
panic(err)
}
vpn.inbound <- nic.GetPacket(buf[:n])

select {
case <-ctx.Done():
return
case vpn.inbound <- nic.GetPacket(buf[:n]):
}
}
}

// packetConnWrite read ip packet from outbound channel and write to packet conn
func (vpn *VPN) packetConnWrite(wg *sync.WaitGroup, packetConn net.PacketConn) {
func (vpn *VPN) packetConnWrite(ctx context.Context, wg *sync.WaitGroup, packetConn net.PacketConn) {
defer wg.Done()
sendPacketToPeer := func(packet *nic.Packet, srcIP, dstIP net.IP) {
defer nic.RecyclePacket(packet)
Expand All @@ -167,38 +185,47 @@ func (vpn *VPN) packetConnWrite(wg *sync.WaitGroup, packetConn net.PacketConn) {
}
return pkt
}
for packet := range vpn.outbound {
out := packet
if packet = handle(packet); packet == nil {
nic.RecyclePacket(out)
continue
}
pkt := packet.AsBytes()
if packet.Ver() == 4 {
header, err := ipv4.ParseHeader(pkt)
if err != nil {
panic(err)
for {
select {
case <-ctx.Done():
return
case packet, ok := <-vpn.outbound:
if !ok {
return
}
if header.Dst.String() == netlink.Show().IPv4 {
vpn.inbound <- packet
out := packet

if packet = handle(packet); packet == nil {
nic.RecyclePacket(out)
continue
}
sendPacketToPeer(packet, header.Src, header.Dst)
continue
}
if packet.Ver() == 6 {
header, err := ipv6.ParseHeader(pkt)
if err != nil {
panic(err)
pkt := packet.AsBytes()
if packet.Ver() == 4 {
header, err := ipv4.ParseHeader(pkt)
if err != nil {
panic(err)
}
if header.Dst.String() == netlink.Show().IPv4 {
vpn.inbound <- packet
continue
}
sendPacketToPeer(packet, header.Src, header.Dst)
continue
}
if header.Dst.String() == netlink.Show().IPv6 {
vpn.inbound <- packet
if packet.Ver() == 6 {
header, err := ipv6.ParseHeader(pkt)
if err != nil {
panic(err)
}
if header.Dst.String() == netlink.Show().IPv6 {
vpn.inbound <- packet
continue
}
sendPacketToPeer(packet, header.Src, header.Dst)
continue
}
sendPacketToPeer(packet, header.Src, header.Dst)
continue
slog.Warn("Received invalid packet", "packet", hex.EncodeToString(pkt))
nic.RecyclePacket(packet)
}
slog.Warn("Received invalid packet", "packet", hex.EncodeToString(pkt))
nic.RecyclePacket(packet)
}
}