Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 59 additions & 35 deletions src/internal/poll/fd_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ var wsaMsgPool = sync.Pool{

// newWSAMsg creates a new WSAMsg with the provided parameters.
// Use [freeWSAMsg] to free it.
func newWSAMsg(p []byte, oob []byte, flags int, unconnected bool) *windows.WSAMsg {
func newWSAMsg(p []byte, oob []byte, flags int, rsa *wsaRsa) *windows.WSAMsg {
// The returned object can't be allocated in the stack because it is accessed asynchronously
// by Windows in between several system calls. If the stack frame is moved while that happens,
// then Windows may access invalid memory.
Expand All @@ -164,34 +164,46 @@ func newWSAMsg(p []byte, oob []byte, flags int, unconnected bool) *windows.WSAMs
Buf: unsafe.SliceData(oob),
}
msg.Flags = uint32(flags)
if unconnected {
msg.Name = wsaRsaPool.Get().(*syscall.RawSockaddrAny)
msg.Namelen = int32(unsafe.Sizeof(syscall.RawSockaddrAny{}))
if rsa != nil {
msg.Name = &rsa.name
msg.Namelen = rsa.namelen
}
return msg
}

func freeWSAMsg(msg *windows.WSAMsg) {
// Clear pointers to buffers so they can be released by garbage collector.
msg.Name = nil
msg.Namelen = 0
msg.Buffers.Len = 0
msg.Buffers.Buf = nil
msg.Control.Len = 0
msg.Control.Buf = nil
if msg.Name != nil {
*msg.Name = syscall.RawSockaddrAny{}
wsaRsaPool.Put(msg.Name)
msg.Name = nil
msg.Namelen = 0
}
wsaMsgPool.Put(msg)
}

// wsaRsa bundles a [syscall.RawSockaddrAny] with its length for efficient caching.
//
// When used by WSARecvFrom, wsaRsa must be on the heap. See
// https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsarecvfrom.
type wsaRsa struct {
name syscall.RawSockaddrAny
namelen int32
}

var wsaRsaPool = sync.Pool{
New: func() any {
return new(syscall.RawSockaddrAny)
return new(wsaRsa)
},
}

func newWSARsa() *wsaRsa {
rsa := wsaRsaPool.Get().(*wsaRsa)
rsa.name = syscall.RawSockaddrAny{}
rsa.namelen = int32(unsafe.Sizeof(syscall.RawSockaddrAny{}))
return rsa
}

var operationPool = sync.Pool{
New: func() any {
return new(operation)
Expand Down Expand Up @@ -737,19 +749,18 @@ func (fd *FD) ReadFrom(buf []byte) (int, syscall.Sockaddr, error) {

fd.pin('r', &buf[0])

rsa := wsaRsaPool.Get().(*syscall.RawSockaddrAny)
rsa := newWSARsa()
defer wsaRsaPool.Put(rsa)
n, err := fd.execIO('r', func(o *operation) (qty uint32, err error) {
rsan := int32(unsafe.Sizeof(*rsa))
var flags uint32
err = syscall.WSARecvFrom(fd.Sysfd, newWsaBuf(buf), 1, &qty, &flags, rsa, &rsan, &o.o, nil)
err = syscall.WSARecvFrom(fd.Sysfd, newWsaBuf(buf), 1, &qty, &flags, &rsa.name, &rsa.namelen, &o.o, nil)
return qty, err
})
err = fd.eofError(n, err)
if err != nil {
return n, nil, err
}
sa, _ := rsa.Sockaddr()
sa, _ := rsa.name.Sockaddr()
return n, sa, nil
}

Expand All @@ -768,19 +779,18 @@ func (fd *FD) ReadFromInet4(buf []byte, sa4 *syscall.SockaddrInet4) (int, error)

fd.pin('r', &buf[0])

rsa := wsaRsaPool.Get().(*syscall.RawSockaddrAny)
rsa := newWSARsa()
defer wsaRsaPool.Put(rsa)
n, err := fd.execIO('r', func(o *operation) (qty uint32, err error) {
rsan := int32(unsafe.Sizeof(*rsa))
var flags uint32
err = syscall.WSARecvFrom(fd.Sysfd, newWsaBuf(buf), 1, &qty, &flags, rsa, &rsan, &o.o, nil)
err = syscall.WSARecvFrom(fd.Sysfd, newWsaBuf(buf), 1, &qty, &flags, &rsa.name, &rsa.namelen, &o.o, nil)
return qty, err
})
err = fd.eofError(n, err)
if err != nil {
return n, err
}
rawToSockaddrInet4(rsa, sa4)
rawToSockaddrInet4(&rsa.name, sa4)
return n, err
}

Expand All @@ -799,19 +809,18 @@ func (fd *FD) ReadFromInet6(buf []byte, sa6 *syscall.SockaddrInet6) (int, error)

fd.pin('r', &buf[0])

rsa := wsaRsaPool.Get().(*syscall.RawSockaddrAny)
rsa := newWSARsa()
defer wsaRsaPool.Put(rsa)
n, err := fd.execIO('r', func(o *operation) (qty uint32, err error) {
rsan := int32(unsafe.Sizeof(*rsa))
var flags uint32
err = syscall.WSARecvFrom(fd.Sysfd, newWsaBuf(buf), 1, &qty, &flags, rsa, &rsan, &o.o, nil)
err = syscall.WSARecvFrom(fd.Sysfd, newWsaBuf(buf), 1, &qty, &flags, &rsa.name, &rsa.namelen, &o.o, nil)
return qty, err
})
err = fd.eofError(n, err)
if err != nil {
return n, err
}
rawToSockaddrInet6(rsa, sa6)
rawToSockaddrInet6(&rsa.name, sa6)
return n, err
}

Expand Down Expand Up @@ -1371,7 +1380,9 @@ func (fd *FD) ReadMsg(p []byte, oob []byte, flags int) (int, int, int, syscall.S
p = p[:maxRW]
}

msg := newWSAMsg(p, oob, flags, true)
rsa := newWSARsa()
defer wsaRsaPool.Put(rsa)
msg := newWSAMsg(p, oob, flags, rsa)
defer freeWSAMsg(msg)
n, err := fd.execIO('r', func(o *operation) (qty uint32, err error) {
err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil)
Expand All @@ -1396,7 +1407,9 @@ func (fd *FD) ReadMsgInet4(p []byte, oob []byte, flags int, sa4 *syscall.Sockadd
p = p[:maxRW]
}

msg := newWSAMsg(p, oob, flags, true)
rsa := newWSARsa()
defer wsaRsaPool.Put(rsa)
msg := newWSAMsg(p, oob, flags, rsa)
defer freeWSAMsg(msg)
n, err := fd.execIO('r', func(o *operation) (qty uint32, err error) {
err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil)
Expand All @@ -1420,7 +1433,9 @@ func (fd *FD) ReadMsgInet6(p []byte, oob []byte, flags int, sa6 *syscall.Sockadd
p = p[:maxRW]
}

msg := newWSAMsg(p, oob, flags, true)
rsa := newWSARsa()
defer wsaRsaPool.Put(rsa)
msg := newWSAMsg(p, oob, flags, rsa)
defer freeWSAMsg(msg)
n, err := fd.execIO('r', func(o *operation) (qty uint32, err error) {
err = windows.WSARecvMsg(fd.Sysfd, msg, &qty, &o.o, nil)
Expand All @@ -1444,15 +1459,18 @@ func (fd *FD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (int, int, err
}
defer fd.writeUnlock()

msg := newWSAMsg(p, oob, 0, sa != nil)
defer freeWSAMsg(msg)
var rsa *wsaRsa
if sa != nil {
rsa = newWSARsa()
defer wsaRsaPool.Put(rsa)
var err error
msg.Namelen, err = sockaddrToRaw(msg.Name, sa)
rsa.namelen, err = sockaddrToRaw(&rsa.name, sa)
if err != nil {
return 0, 0, err
}
}
msg := newWSAMsg(p, oob, 0, rsa)
defer freeWSAMsg(msg)
n, err := fd.execIO('w', func(o *operation) (qty uint32, err error) {
err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil)
return qty, err
Expand All @@ -1471,11 +1489,14 @@ func (fd *FD) WriteMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (in
}
defer fd.writeUnlock()

msg := newWSAMsg(p, oob, 0, sa != nil)
defer freeWSAMsg(msg)
var rsa *wsaRsa
if sa != nil {
msg.Namelen = sockaddrInet4ToRaw(msg.Name, sa)
rsa = newWSARsa()
defer wsaRsaPool.Put(rsa)
rsa.namelen = sockaddrInet4ToRaw(&rsa.name, sa)
}
msg := newWSAMsg(p, oob, 0, rsa)
defer freeWSAMsg(msg)
n, err := fd.execIO('w', func(o *operation) (qty uint32, err error) {
err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil)
return qty, err
Expand All @@ -1494,11 +1515,14 @@ func (fd *FD) WriteMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (in
}
defer fd.writeUnlock()

msg := newWSAMsg(p, oob, 0, sa != nil)
defer freeWSAMsg(msg)
var rsa *wsaRsa
if sa != nil {
msg.Namelen = sockaddrInet6ToRaw(msg.Name, sa)
rsa = newWSARsa()
defer wsaRsaPool.Put(rsa)
rsa.namelen = sockaddrInet6ToRaw(&rsa.name, sa)
}
msg := newWSAMsg(p, oob, 0, rsa)
defer freeWSAMsg(msg)
n, err := fd.execIO('w', func(o *operation) (qty uint32, err error) {
err = windows.WSASendMsg(fd.Sysfd, msg, 0, nil, &o.o, nil)
return qty, err
Expand Down
Loading