Skip to content

Commit

Permalink
Added WSA as a fallback for Windows
Browse files Browse the repository at this point in the history
  • Loading branch information
lewishazell committed Nov 11, 2023
1 parent 872b5ee commit 8980eb2
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 24 deletions.
3 changes: 1 addition & 2 deletions udp/udp_generic.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//go:build !windows && (!linux || android) && !e2e_testing
// +build !windows
//go:build (!linux || android) && !e2e_testing
// +build !linux android
// +build !e2e_testing

Expand Down
11 changes: 10 additions & 1 deletion udp/udp_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
return nil, fmt.Errorf("multiple udp listeners not supported on windows")
}

rc, err := NewRIOListener(l, ip, port)
var rc Conn
var err error

rc, err = NewRIOListener(l, ip, port)
if err == nil {
return rc, nil
}

l.WithError(err).Error("Falling back to WSA")
rc, err = NewWsaListener(l, ip, port, multi, batch)
if err == nil {
return rc, nil
}
Expand Down
47 changes: 26 additions & 21 deletions udp/udp_wsa_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ import (
"golang.org/x/sys/windows"
)

// Assert we meet the standard conn interface
var _ Conn = &WsaConn{}

//TODO: make it support reload as best you can!

type Conn struct {
type WsaConn struct {
sysFd windows.Handle
l *logrus.Logger
batch int
Expand All @@ -34,7 +37,7 @@ type msghdr struct {
Flags *uint32
}

type rawMessage struct {
type wsaMessage struct {
Len *uint32
Hdr msghdr
}
Expand Down Expand Up @@ -111,7 +114,7 @@ func MAKEWORD(low, high uint8) uint32 {
return uint32(ret)
}

func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*Conn, error) {
func NewWsaListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (*WsaConn, error) {
var wsaData windows.WSAData

l.Debug("Library [ws2_32.dll] loaded at ", modws2_32.Handle())
Expand Down Expand Up @@ -152,30 +155,30 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (
return nil, fmt.Errorf("unable to bind to socket: %s", err)
}

return &Conn{sysFd: fd, l: l, batch: batch}, err
return &WsaConn{sysFd: fd, l: l, batch: batch}, err
}

func (u *Conn) Rebind() error {
func (u *WsaConn) Rebind() error {
return nil
}

func (u *Conn) SetRecvBuffer(n int) error {
func (u *WsaConn) SetRecvBuffer(n int) error {
return windows.SetsockoptInt(u.sysFd, windows.SOL_SOCKET, windows.SO_RCVBUF, n)
}

func (u *Conn) SetSendBuffer(n int) error {
func (u *WsaConn) SetSendBuffer(n int) error {
return windows.SetsockoptInt(u.sysFd, windows.SOL_SOCKET, windows.SO_SNDBUF, n)
}

func (u *Conn) GetRecvBuffer() (int, error) {
func (u *WsaConn) GetRecvBuffer() (int, error) {
return windows.GetsockoptInt(u.sysFd, windows.SOL_SOCKET, windows.SO_RCVBUF)
}

func (u *Conn) GetSendBuffer() (int, error) {
func (u *WsaConn) GetSendBuffer() (int, error) {
return windows.GetsockoptInt(u.sysFd, windows.SOL_SOCKET, windows.SO_SNDBUF)
}

func (u *Conn) LocalAddr() (*Addr, error) {
func (u *WsaConn) LocalAddr() (*Addr, error) {
sa, err := windows.Getsockname(u.sysFd)
if err != nil {
return nil, err
Expand All @@ -194,8 +197,8 @@ func (u *Conn) LocalAddr() (*Addr, error) {
return addr, nil
}

func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, []windows.RawSockaddrAny) {
msgs := make([]rawMessage, n)
func (u *WsaConn) PrepareWsaMessages(n int) ([]wsaMessage, [][]byte, []windows.RawSockaddrAny) {
msgs := make([]wsaMessage, n)

// all require allocation to sequential memory addresses
buffers := make([][]byte, n)
Expand Down Expand Up @@ -237,13 +240,13 @@ func (u *Conn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, []windows.RawS
return msgs, buffers, names
}

func (u *Conn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
func (u *WsaConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
plaintext := make([]byte, MTU)
h := &header.H{}
fwPacket := &firewall.Packet{}
udpAddr := &Addr{}
nb := make([]byte, 12, 12)
msgs, buffers, names := u.PrepareRawMessages(u.batch)
msgs, buffers, names := u.PrepareWsaMessages(u.batch)

read := u.ReadMulti
if u.batch == 1 {
Expand Down Expand Up @@ -291,7 +294,7 @@ func RawsockAddrToIPAndPort(rsa *windows.RawSockaddrAny) (net.IP, uint16, error)
return nil, 0, syscall.EAFNOSUPPORT
}

func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
func (u *WsaConn) ReadSingle(msgs []wsaMessage) (int, error) {
for {
len, sa, err := windows.Recvfrom(u.sysFd, msgs[0].Hdr.Buf, 0)

Expand Down Expand Up @@ -323,7 +326,7 @@ func (u *Conn) ReadSingle(msgs []rawMessage) (int, error) {
}
}

func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
func (u *WsaConn) ReadMulti(msgs []wsaMessage) (int, error) {
flags := uint32(0)
err := windows.WSARecvFrom(
u.sysFd,
Expand Down Expand Up @@ -355,7 +358,7 @@ func (u *Conn) ReadMulti(msgs []rawMessage) (int, error) {
return int(index), err
}

func (u *Conn) WriteTo(b []byte, addr *Addr) error {
func (u *WsaConn) WriteTo(b []byte, addr *Addr) error {
var buf [16]byte
copy(buf[:], addr.IP.To16())
sa := &windows.SockaddrInet6{Addr: buf, Port: int(addr.Port)}
Expand All @@ -374,7 +377,7 @@ func (u *Conn) WriteTo(b []byte, addr *Addr) error {
}
}

func (u *Conn) ReloadConfig(c *config.C) {
func (u *WsaConn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0)
if b > 0 {
err := u.SetRecvBuffer(b)
Expand Down Expand Up @@ -406,7 +409,9 @@ func (u *Conn) ReloadConfig(c *config.C) {
}
}

func NewUDPStatsEmitter(udpConns []*Conn) func() {
// No UDP stats for non-linux
return func() {}
func (u *WsaConn) Close() error {
windows.Close(u.sysFd)
windows.WSACleanup()

return nil
}

0 comments on commit 8980eb2

Please sign in to comment.