Skip to content
Permalink
Browse files

net: implement ReadMsg/WriteMsg on windows

This means {Read,Write}Msg{UDP,IP} now work on windows.

Fixes golang#9252

Change-Id: Ifb105f9ad18d61289b22d7358a95faabe73d2d02
  • Loading branch information...
tmm1 committed Nov 7, 2017
1 parent 8c8bc3d commit 774fc9ccdbd3c710a3b4437b7906fdbdeb365e24
@@ -154,7 +154,7 @@ var pkgDeps = map[string][]string{
"syscall",
},

"internal/poll": {"L0", "internal/race", "syscall", "time", "unicode/utf16", "unicode/utf8"},
"internal/poll": {"L0", "internal/race", "syscall", "time", "unicode/utf16", "unicode/utf8", "internal/syscall/windows"},
"os": {"L1", "os", "syscall", "time", "internal/poll", "internal/syscall/windows"},
"path/filepath": {"L2", "os", "syscall", "internal/syscall/windows"},
"io/ioutil": {"L2", "os", "path/filepath", "time"},
@@ -7,6 +7,7 @@ package poll
import (
"errors"
"internal/race"
"internal/syscall/windows"
"io"
"runtime"
"sync"
@@ -926,3 +927,151 @@ func (fd *FD) RawWrite(f func(uintptr) bool) error {
}
}
}

func sockaddrToRaw(sa syscall.Sockaddr) (unsafe.Pointer, int32, error) {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
var raw syscall.RawSockaddrInet4
raw.Family = syscall.AF_INET
p := (*[2]byte)(unsafe.Pointer(&raw.Port))
p[0] = byte(sa.Port >> 8)
p[1] = byte(sa.Port)
for i := 0; i < len(sa.Addr); i++ {
raw.Addr[i] = sa.Addr[i]
}
return unsafe.Pointer(&raw), int32(unsafe.Sizeof(raw)), nil
case *syscall.SockaddrInet6:
var raw syscall.RawSockaddrInet6
raw.Family = syscall.AF_INET6
p := (*[2]byte)(unsafe.Pointer(&raw.Port))
p[0] = byte(sa.Port >> 8)
p[1] = byte(sa.Port)
raw.Scope_id = sa.ZoneId
for i := 0; i < len(sa.Addr); i++ {
raw.Addr[i] = sa.Addr[i]
}
return unsafe.Pointer(&raw), int32(unsafe.Sizeof(raw)), nil
default:
return nil, 0, syscall.EWINDOWS
}
}

func recvmsg(fd syscall.Handle, p, oob []byte, flags int) (n, oobn int, recvflags int, from syscall.Sockaddr, err error) {
err = windows.LoadWSARecvMsg()
if err != nil {
return
}

var buf syscall.WSABuf
if len(p) > 0 {
buf.Buf = &p[0]
buf.Len = uint32(len(p))
}
var msg windows.WSAMsg
var bytesReceived uint32
rsa := new(syscall.RawSockaddrAny)
msg.Name = uintptr(unsafe.Pointer(rsa))
msg.Namelen = int32(unsafe.Sizeof(*rsa))
msg.Buffers = &buf
msg.BufferCount = 1
if len(oob) > 0 {
msg.Control.Buf = &oob[0]
msg.Control.Len = uint32(len(oob))
}
controlLen := msg.Control.Len
err = windows.WSARecvMsg(fd, &msg, &bytesReceived, nil, nil)
if err == windows.WSAEMSGSIZE && (msg.Flags&windows.MSG_CTRUNC) != 0 {
// On windows, EMSGSIZE is raised in addition to MSG_CTRUNC, and
// the original untruncated length of the control data is returned.
// We reset the length back to the truncated portion which was received,
// so the caller doesn't try to go out of bounds.
// We also ignore the EMSGSIZE to emulate behavior of other platforms.
msg.Control.Len = controlLen
err = nil
}
if err != nil {
return
}
oobn = int(msg.Control.Len)
n = int(bytesReceived)
recvflags = int(msg.Flags)
from, err = rsa.Sockaddr()
return
}

func sendmsgN(fd syscall.Handle, p, oob []byte, to syscall.Sockaddr, flags int) (n int, err error) {
err = windows.LoadWSASendMsg()
if err != nil {
return
}

var buf syscall.WSABuf
if len(p) > 0 {
buf.Buf = &p[0]
buf.Len = uint32(len(p))
}
var msg windows.WSAMsg
var bytesSent uint32
if to != nil {
var sa unsafe.Pointer
sa, msg.Namelen, err = sockaddrToRaw(to)
if err != nil {
return
}
msg.Name = uintptr(sa)
}
msg.Buffers = &buf
msg.BufferCount = 1
if len(oob) > 0 {
msg.Control.Buf = &oob[0]
msg.Control.Len = uint32(len(oob))
}
err = windows.WSASendMsg(fd, &msg, uint32(flags), &bytesSent, nil, nil)
return int(bytesSent), err
}

// ReadMsg wraps the WSARecvMsg network call.
func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, error) {
if err := fd.readLock(); err != nil {
return 0, 0, 0, nil, err
}
defer fd.readUnlock()
if err := fd.pd.prepareRead(fd.isFile); err != nil {
return 0, 0, 0, nil, err
}
for {
n, oobn, flags, sa, err := recvmsg(fd.Sysfd, p, oob, 0)
if err != nil {
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitRead(fd.isFile); err == nil {
continue
}
}
}
err = fd.eofError(n, err)
return n, oobn, flags, sa, err
}
}

// WriteMsg wraps the WSASendMsg network call.
func (fd *FD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (int, int, error) {
if err := fd.writeLock(); err != nil {
return 0, 0, err
}
defer fd.writeUnlock()
if err := fd.pd.prepareWrite(fd.isFile); err != nil {
return 0, 0, err
}
for {
n, err := sendmsgN(fd.Sysfd, p, oob, sa, 0)
if err == syscall.EAGAIN && fd.pd.pollable() {
if err = fd.pd.waitWrite(fd.isFile); err == nil {
continue
}
}
if err != nil {
return n, 0, err
}
return n, len(oob), err
}
}
@@ -4,7 +4,11 @@

package windows

import "syscall"
import (
"sync"
"syscall"
"unsafe"
)

const (
ERROR_SHARING_VIOLATION syscall.Errno = 32
@@ -115,9 +119,72 @@ const (
const (
WSA_FLAG_OVERLAPPED = 0x01
WSA_FLAG_NO_HANDLE_INHERIT = 0x80

WSAEMSGSIZE syscall.Errno = 10040

MSG_TRUNC = 0x0100
MSG_CTRUNC = 0x0200
)

type WSAMsg struct {
Name uintptr
Namelen int32
Buffers *syscall.WSABuf
BufferCount uint32
Control syscall.WSABuf
Flags uint32
}

const socket_error = uintptr(^uint32(0))

//sys WSASocket(af int32, typ int32, protocol int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = ws2_32.WSASocketW
//sys WSASendMsg(fd syscall.Handle, msg *WSAMsg, flags uint32, bytesSent *uint32, overlapped *syscall.Overlapped, croutine *byte) (err error) [failretval==socket_error] = ws2_32.WSASendMsg

func LoadWSASendMsg() error {
return procWSASendMsg.Find()
}

var recvMsgFunc struct {
once sync.Once
addr uintptr
err error
}

func LoadWSARecvMsg() error {
recvMsgFunc.once.Do(func() {
var s syscall.Handle
s, recvMsgFunc.err = syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
if recvMsgFunc.err != nil {
return
}
defer syscall.CloseHandle(s)
var n uint32
recvMsgFunc.err = syscall.WSAIoctl(s,
syscall.SIO_GET_EXTENSION_FUNCTION_POINTER,
(*byte)(unsafe.Pointer(&WSAID_WSARECVMSG)),
uint32(unsafe.Sizeof(WSAID_WSARECVMSG)),
(*byte)(unsafe.Pointer(&recvMsgFunc.addr)),
uint32(unsafe.Sizeof(recvMsgFunc.addr)),
&n, nil, 0)
})
return recvMsgFunc.err
}

func WSARecvMsg(fd syscall.Handle, msg *WSAMsg, bytesReceived *uint32, overlapped *syscall.Overlapped, croutine *byte) (err error) {
err = LoadWSARecvMsg()
if err != nil {
return
}
r1, _, e1 := syscall.Syscall6(recvMsgFunc.addr, 5, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(unsafe.Pointer(bytesReceived)), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(croutine)), 0)
if r1 == socket_error {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}

const (
ComputerNameNetBIOS = 0
@@ -48,6 +48,7 @@ var (
procMoveFileExW = modkernel32.NewProc("MoveFileExW")
procGetModuleFileNameW = modkernel32.NewProc("GetModuleFileNameW")
procWSASocketW = modws2_32.NewProc("WSASocketW")
procWSASendMsg = modws2_32.NewProc("WSASendMsg")
procGetACP = modkernel32.NewProc("GetACP")
procGetConsoleCP = modkernel32.NewProc("GetConsoleCP")
procMultiByteToWideChar = modkernel32.NewProc("MultiByteToWideChar")
@@ -123,6 +124,18 @@ func WSASocket(af int32, typ int32, protocol int32, protinfo *syscall.WSAProtoco
return
}

func WSASendMsg(fd syscall.Handle, msg *WSAMsg, flags uint32, bytesSent *uint32, overlapped *syscall.Overlapped, croutine *byte) (err error) {
r1, _, e1 := syscall.Syscall6(procWSASendMsg.Addr(), 6, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(flags), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(croutine)))
if r1 == socket_error {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}

func GetACP() (acp uint32) {
r0, _, _ := syscall.Syscall(procGetACP.Addr(), 0, 0, 0, 0)
acp = uint32(r0)
@@ -223,17 +223,21 @@ func (fd *netFD) accept() (*netFD, error) {
return netfd, nil
}

func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob)
runtime.KeepAlive(fd)
return n, oobn, flags, sa, wrapSyscallError("wsarecvmsg", err)
}

func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError("wsasendmsg", err)
}

// Unimplemented functions.

func (fd *netFD) dup() (*os.File, error) {
// TODO: Implement this
return nil, syscall.EWINDOWS
}

func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
return 0, 0, 0, nil, syscall.EWINDOWS
}

func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
return 0, 0, syscall.EWINDOWS
}
@@ -21,7 +21,7 @@ import (
// change the behavior of these methods; use Read or ReadMsgIP
// instead.

// BUG(mikio): On NaCl, Plan 9 and Windows, the ReadMsgIP and
// BUG(mikio): On NaCl and Plan 9, the ReadMsgIP and
// WriteMsgIP methods of IPConn are not implemented.

// BUG(mikio): On Windows, the File method of IPConn is not
@@ -9,7 +9,7 @@ import (
"syscall"
)

// BUG(mikio): On NaCl, Plan 9 and Windows, the ReadMsgUDP and
// BUG(mikio): On NaCl and Plan 9, the ReadMsgUDP and
// WriteMsgUDP methods of UDPConn are not implemented.

// BUG(mikio): On Windows, the File method of UDPConn is not
@@ -161,7 +161,7 @@ func testWriteToConn(t *testing.T, raddr string) {
}
_, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, nil)
switch runtime.GOOS {
case "nacl", "windows": // see golang.org/issue/9252
case "nacl": // see golang.org/issue/9252
t.Skipf("not implemented yet on %s", runtime.GOOS)
default:
if err != nil {
@@ -204,7 +204,7 @@ func testWriteToPacketConn(t *testing.T, raddr string) {
}
_, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, ra)
switch runtime.GOOS {
case "nacl", "windows": // see golang.org/issue/9252
case "nacl": // see golang.org/issue/9252
t.Skipf("not implemented yet on %s", runtime.GOOS)
default:
if err != nil {

0 comments on commit 774fc9c

Please sign in to comment.
You can’t perform that action at this time.