diff --git a/disco/udp/stun.go b/disco/udp/stun.go index 69bcbce..a805f4b 100644 --- a/disco/udp/stun.go +++ b/disco/udp/stun.go @@ -29,7 +29,6 @@ func (rt *stunRoundTripper) roundTrip(ctx context.Context, udpConn *net.UDPConn, rt.init() txID := stun.NewTxID() ch := make(chan stunResponse) - defer close(ch) rt.stunResponseMapMutex.Lock() rt.stunResponseMap[string(txID[:])] = ch rt.stunResponseMapMutex.Unlock() @@ -87,7 +86,6 @@ func (c *stunRoundTripper) recvResponse(b []byte, peerAddr net.Addr) { return } resp := stunResponse{txid: string(txid[:]), addr: addr} - defer func() { recover() }() select { case r <- resp: default: diff --git a/disco/ws/ws.go b/disco/ws/ws.go index a305313..6237f2d 100644 --- a/disco/ws/ws.go +++ b/disco/ws/ws.go @@ -100,9 +100,6 @@ func (c *WSConn) Write(p []byte) (n int, err error) { func (c *WSConn) Close() error { c.closed.Store(true) close(c.closedSig) - close(c.datagrams) - close(c.events) - close(c.connData) close(c.connEOF) if conn := c.rawConn.Load(); conn != nil { _ = conn.WriteControl(websocket.CloseMessage, diff --git a/netlink/addr_darwin.go b/netlink/addr_darwin.go index 0bfc5df..6e54792 100644 --- a/netlink/addr_darwin.go +++ b/netlink/addr_darwin.go @@ -19,22 +19,27 @@ func AddrSubscribe(ctx context.Context, ch chan<- AddrUpdate) error { return fmt.Errorf("syscall socket: %w", err) } go func() { - err := runAddrMsgReadLoop(fd, ch) + err := runAddrMsgReadLoop(ctx, fd, ch) if err != nil { slog.Error("AddrSubscribe", "err", fmt.Errorf("msg read loop exited: %w", err)) } }() - go func() { - <-ctx.Done() - syscall.Close(fd) - close(ch) - }() return nil } -func runAddrMsgReadLoop(fd int, ch chan<- AddrUpdate) error { +func runAddrMsgReadLoop(ctx context.Context, fd int, ch chan<- AddrUpdate) error { buf := make([]byte, os.Getpagesize()) for { + select { + case <-ctx.Done(): + err := syscall.Close(fd) + if err != nil { + slog.Error("runAddrMsgReadLoop", "err", fmt.Errorf("syscall close: %w", err)) + } + close(ch) + return nil + default: + } n, err := syscall.Read(fd, buf) if err != nil { return fmt.Errorf("syscall read: %w", err) diff --git a/netlink/addr_linux.go b/netlink/addr_linux.go index e5309b3..7f08388 100644 --- a/netlink/addr_linux.go +++ b/netlink/addr_linux.go @@ -13,10 +13,10 @@ func AddrSubscribe(ctx context.Context, ch chan<- AddrUpdate) error { return err } go func() { - defer close(ch) for { select { case <-ctx.Done(): + close(ch) return case e := <-rawChan: ch <- AddrUpdate{ diff --git a/netlink/route_darwin.go b/netlink/route_darwin.go index 2b125a2..8b62da4 100644 --- a/netlink/route_darwin.go +++ b/netlink/route_darwin.go @@ -20,22 +20,27 @@ func RouteSubscribe(ctx context.Context, ch chan<- RouteUpdate) error { return fmt.Errorf("syscall socket: %w", err) } go func() { - err := runRouteMsgReadLoop(fd, ch) + err := runRouteMsgReadLoop(ctx, fd, ch) if err != nil { slog.Error("RouteSubscribe", "err", fmt.Errorf("msg read loop exited: %w", err)) } }() - go func() { - <-ctx.Done() - syscall.Close(fd) - close(ch) - }() return nil } -func runRouteMsgReadLoop(fd int, ch chan<- RouteUpdate) error { +func runRouteMsgReadLoop(ctx context.Context, fd int, ch chan<- RouteUpdate) error { buf := make([]byte, os.Getpagesize()) for { + select { + case <-ctx.Done(): + err := syscall.Close(fd) + if err != nil { + slog.Error("runRouteMsgReadLoop", "err", fmt.Errorf("syscall close: %w", err)) + } + close(ch) + return nil + default: + } n, err := syscall.Read(fd, buf) if err != nil { return fmt.Errorf("syscall read: %w", err) diff --git a/netlink/route_linux.go b/netlink/route_linux.go index 51693fb..6d3146e 100644 --- a/netlink/route_linux.go +++ b/netlink/route_linux.go @@ -18,10 +18,10 @@ func RouteSubscribe(ctx context.Context, ch chan<- RouteUpdate) error { return err } go func() { - defer close(ch) for { select { case <-ctx.Done(): + close(ch) return case e := <-rawChan: if e.Dst == nil || e.Gw == nil { diff --git a/p2p/conn.go b/p2p/conn.go index 229b7dc..3951e7d 100644 --- a/p2p/conn.go +++ b/p2p/conn.go @@ -145,9 +145,9 @@ func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *PacketConn) Close() error { c.closeOnce.Do(func() { close(c.closeChan) - c.deadlineRead.Close() - c.udpConn.Close() - c.wsConn.Close() + _ = c.deadlineRead.Close() + _ = c.udpConn.Close() + _ = c.wsConn.Close() }) return nil } diff --git a/peermap/peermap.go b/peermap/peermap.go index c68dd63..274d028 100644 --- a/peermap/peermap.go +++ b/peermap/peermap.go @@ -85,13 +85,17 @@ func (p *peerConn) Read(b []byte) (n int, err error) { return } - wsb, ok := <-p.connData - if !ok { + select { + case <-p.closeChan: return 0, io.EOF - } - n = copy(b, wsb) - if n < len(wsb) { - p.connBuf = wsb[n:] + case wsb, ok := <-p.connData: + if !ok { + return 0, io.EOF + } + n = copy(b, wsb) + if n < len(wsb) { + p.connBuf = wsb[n:] + } } return } @@ -115,7 +119,6 @@ func (p *peerConn) Close() error { websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(2*time.Second)) p.conn.Close() close(p.closeChan) - close(p.connData) p.broadcastLeave() }) return nil diff --git a/vpn/vpn.go b/vpn/vpn.go index 2712d48..e90bbf8 100644 --- a/vpn/vpn.go +++ b/vpn/vpn.go @@ -47,17 +47,17 @@ func (vpn *VPN) Run(ctx context.Context, nic *nic.VirtualNIC, packetConn net.Pac var wg sync.WaitGroup wg.Add(5) go vpn.routingTableUpdate(ctx, &wg) - go vpn.nicRead(&wg, nic) - go vpn.nicWrite(&wg, nic) - go vpn.packetConnRead(&wg, packetConn) - go vpn.packetConnWrite(&wg, packetConn) + go vpn.nicRead(ctx, &wg, nic) + go vpn.nicWrite(ctx, &wg, nic) + go vpn.packetConnRead(ctx, &wg, packetConn) + go vpn.packetConnWrite(ctx, &wg, packetConn) <-ctx.Done() packetConn.Close() nic.Close() + wg.Wait() close(vpn.inbound) close(vpn.outbound) - wg.Wait() return nil } @@ -83,7 +83,7 @@ func (vpn *VPN) routingTableUpdate(ctx context.Context, wg *sync.WaitGroup) { } // nicRead read ip packet from nic device and send to outbound channel -func (vpn *VPN) nicRead(wg *sync.WaitGroup, nic *nic.VirtualNIC) { +func (vpn *VPN) nicRead(ctx context.Context, wg *sync.WaitGroup, nic *nic.VirtualNIC) { defer wg.Done() for { packet, err := nic.Read() @@ -93,12 +93,17 @@ func (vpn *VPN) nicRead(wg *sync.WaitGroup, nic *nic.VirtualNIC) { } panic(err) } - vpn.outbound <- packet + select { + case <-ctx.Done(): + return + case vpn.outbound <- packet: + } + } } // nicWrite read ip packet from inbound channel and write to nic device -func (vpn *VPN) nicWrite(wg *sync.WaitGroup, vnic *nic.VirtualNIC) { +func (vpn *VPN) nicWrite(ctx context.Context, wg *sync.WaitGroup, vnic *nic.VirtualNIC) { defer wg.Done() handle := func(pkt *nic.Packet) *nic.Packet { for _, in := range vpn.cfg.InboundHandlers { @@ -109,22 +114,30 @@ func (vpn *VPN) nicWrite(wg *sync.WaitGroup, vnic *nic.VirtualNIC) { } return pkt } - for packet := range vpn.inbound { - in := packet - if packet = handle(packet); packet == nil { - nic.RecyclePacket(in) - continue - } - err := vnic.Write(packet) - if err != nil { - slog.Debug("WriteTo nic device", "err", err.Error()) + for { + select { + case <-ctx.Done(): + return + case packet, ok := <-vpn.inbound: + if !ok { + return + } + in := packet + if packet = handle(packet); packet == nil { + nic.RecyclePacket(in) + continue + } + err := vnic.Write(packet) + if err != nil { + slog.Debug("WriteTo nic device", "err", err.Error()) + } + nic.RecyclePacket(packet) } - nic.RecyclePacket(packet) } } // packetConnRead read ip packet from packet conn and send to inbound channel -func (vpn *VPN) packetConnRead(wg *sync.WaitGroup, packetConn net.PacketConn) { +func (vpn *VPN) packetConnRead(ctx context.Context, wg *sync.WaitGroup, packetConn net.PacketConn) { defer wg.Done() buf := make([]byte, cmp.Or(vpn.cfg.MTU, (2<<15)-8-40-40)+40) for { @@ -135,12 +148,17 @@ func (vpn *VPN) packetConnRead(wg *sync.WaitGroup, packetConn net.PacketConn) { } panic(err) } - vpn.inbound <- nic.GetPacket(buf[:n]) + + select { + case <-ctx.Done(): + return + case vpn.inbound <- nic.GetPacket(buf[:n]): + } } } // packetConnWrite read ip packet from outbound channel and write to packet conn -func (vpn *VPN) packetConnWrite(wg *sync.WaitGroup, packetConn net.PacketConn) { +func (vpn *VPN) packetConnWrite(ctx context.Context, wg *sync.WaitGroup, packetConn net.PacketConn) { defer wg.Done() sendPacketToPeer := func(packet *nic.Packet, srcIP, dstIP net.IP) { defer nic.RecyclePacket(packet) @@ -167,38 +185,47 @@ func (vpn *VPN) packetConnWrite(wg *sync.WaitGroup, packetConn net.PacketConn) { } return pkt } - for packet := range vpn.outbound { - out := packet - if packet = handle(packet); packet == nil { - nic.RecyclePacket(out) - continue - } - pkt := packet.AsBytes() - if packet.Ver() == 4 { - header, err := ipv4.ParseHeader(pkt) - if err != nil { - panic(err) + for { + select { + case <-ctx.Done(): + return + case packet, ok := <-vpn.outbound: + if !ok { + return } - if header.Dst.String() == netlink.Show().IPv4 { - vpn.inbound <- packet + out := packet + + if packet = handle(packet); packet == nil { + nic.RecyclePacket(out) continue } - sendPacketToPeer(packet, header.Src, header.Dst) - continue - } - if packet.Ver() == 6 { - header, err := ipv6.ParseHeader(pkt) - if err != nil { - panic(err) + pkt := packet.AsBytes() + if packet.Ver() == 4 { + header, err := ipv4.ParseHeader(pkt) + if err != nil { + panic(err) + } + if header.Dst.String() == netlink.Show().IPv4 { + vpn.inbound <- packet + continue + } + sendPacketToPeer(packet, header.Src, header.Dst) + continue } - if header.Dst.String() == netlink.Show().IPv6 { - vpn.inbound <- packet + if packet.Ver() == 6 { + header, err := ipv6.ParseHeader(pkt) + if err != nil { + panic(err) + } + if header.Dst.String() == netlink.Show().IPv6 { + vpn.inbound <- packet + continue + } + sendPacketToPeer(packet, header.Src, header.Dst) continue } - sendPacketToPeer(packet, header.Src, header.Dst) - continue + slog.Warn("Received invalid packet", "packet", hex.EncodeToString(pkt)) + nic.RecyclePacket(packet) } - slog.Warn("Received invalid packet", "packet", hex.EncodeToString(pkt)) - nic.RecyclePacket(packet) } }