Skip to content

Commit

Permalink
improvement[javalib]: JVM behavior parity for unresolved addresses (#…
Browse files Browse the repository at this point in the history
…3803)

On JVM, bind and connect throw exception when the socket address
parameter is unresolved
Adapt scala-native code to match behaviour
  • Loading branch information
RustedBones committed Mar 5, 2024
1 parent fdb0f08 commit 62d4b78
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 180 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,48 +99,22 @@ private[net] abstract class AbstractPlainDatagramSocketImpl
portOpt.map(inet.ntohs(_).toInt)
}

private def prepareSockaddrIn4(
inetAddress: InetAddress,
port: Int,
hints: Ptr[addrinfo],
ret: Ptr[Ptr[addrinfo]]
): Unit = {
hints.ai_family = posix.sys.socket.AF_UNSPEC
hints.ai_flags = AI_NUMERICHOST
hints.ai_socktype = posix.sys.socket.SOCK_DGRAM

Zone.acquire { implicit z =>
val cIP = toCString(inetAddress.getHostAddress())
if (getaddrinfo(cIP, toCString(port.toString), hints, ret) != 0) {
throw new BindException(
"Couldn't resolve address: " + inetAddress.getHostAddress()
)
}
}
}

private def bind4(addr: InetAddress, port: Int): Unit = {
val hints = stackalloc[addrinfo]()
val sa4Ptr = stackalloc[Ptr[addrinfo]]()

prepareSockaddrIn4(addr, port, hints, sa4Ptr)
val sa4 = (!sa4Ptr).ai_addr
val sa4Len = (!sa4Ptr).ai_addrlen
val sa4Family = (!sa4Ptr).ai_family
val sa4 = stackalloc[in.sockaddr_in]()
val sa4Len = sizeof[in.sockaddr_in].toUInt
SocketHelpers.prepareSockaddrIn4(addr, port, sa4)

val bindRes = posix.sys.socket.bind(
fd.fd,
sa4,
sa4.asInstanceOf[Ptr[posix.sys.socket.sockaddr]],
sa4Len
)

freeaddrinfo(!sa4Ptr)

if (bindRes < 0) {
throwCannotBind(addr)
}

this.localport = fetchLocalPort(sa4Family).getOrElse {
this.localport = fetchLocalPort(posix.sys.socket.AF_INET).getOrElse {
throwCannotBind(addr)
}
}
Expand Down Expand Up @@ -178,11 +152,9 @@ private[net] abstract class AbstractPlainDatagramSocketImpl

private def send4(p: DatagramPacket): Unit = {
val insAddr = p.getSocketAddress().asInstanceOf[InetSocketAddress]
val hints = stackalloc[addrinfo]()
val sa4Ptr = stackalloc[Ptr[addrinfo]]()
prepareSockaddrIn4(insAddr.getAddress, insAddr.getPort, hints, sa4Ptr)
val sa4 = (!sa4Ptr).ai_addr
val sa4Len = (!sa4Ptr).ai_addrlen
val sa4 = stackalloc[in.sockaddr_in]()
val sa4Len = sizeof[in.sockaddr_in].toUInt
SocketHelpers.prepareSockaddrIn4(insAddr.getAddress, insAddr.getPort, sa4)

val buffer = p.getData()
val cArr = buffer.at(p.getOffset())
Expand All @@ -192,12 +164,10 @@ private[net] abstract class AbstractPlainDatagramSocketImpl
cArr,
len.toUInt,
posix.sys.socket.MSG_NOSIGNAL,
sa4,
sa4.asInstanceOf[Ptr[posix.sys.socket.sockaddr]],
sa4Len
)

freeaddrinfo(!sa4Ptr)

if (ret < 0) {
throw new IOException("Could not send the datagram packet to the client")
}
Expand Down Expand Up @@ -234,13 +204,15 @@ private[net] abstract class AbstractPlainDatagramSocketImpl
}

private def connect4(address: InetAddress, port: Int): Unit = {
val hints = stackalloc[addrinfo]()
val sa4Ptr = stackalloc[Ptr[addrinfo]]()
prepareSockaddrIn4(address, port, hints, sa4Ptr)
val sa4 = (!sa4Ptr).ai_addr
val sa4Len = (!sa4Ptr).ai_addrlen
val sa4 = stackalloc[in.sockaddr_in]()
val sa4Len = sizeof[in.sockaddr_in].toUInt
SocketHelpers.prepareSockaddrIn4(address, port, sa4)

val connectRet = posix.sys.socket.connect(fd.fd, sa4, sa4Len)
val connectRet = posix.sys.socket.connect(
fd.fd,
sa4.asInstanceOf[Ptr[posix.sys.socket.sockaddr]],
sa4Len
)

if (connectRet < 0) {
throw new ConnectException(
Expand Down
132 changes: 52 additions & 80 deletions javalib/src/main/scala/java/net/AbstractPlainSocketImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,45 +94,35 @@ private[net] abstract class AbstractPlainSocketImpl extends SocketImpl {
}

private def bind4(addr: InetAddress, port: Int): Unit = {
val hints = stackalloc[addrinfo]()
val ret = stackalloc[Ptr[addrinfo]]()
hints.ai_family = socket.AF_UNSPEC
hints.ai_flags = AI_NUMERICHOST
hints.ai_socktype = socket.SOCK_STREAM

Zone.acquire { implicit z =>
val cIP = toCString(addr.getHostAddress())
if (getaddrinfo(cIP, toCString(port.toString), hints, ret) != 0) {
throw new BindException(
"Couldn't resolve address: " + addr.getHostAddress()
)
}
}
val sa4 = stackalloc[in.sockaddr_in]()
val sa4Len = sizeof[in.sockaddr_in].toUInt
SocketHelpers.prepareSockaddrIn4(addr, port, sa4)

val bindRes = socket.bind(fd.fd, (!ret).ai_addr, (!ret).ai_addrlen)

val family = (!ret).ai_family
freeaddrinfo(!ret)
val bindRes = socket.bind(
fd.fd,
sa4.asInstanceOf[Ptr[socket.sockaddr]],
sa4Len
)

if (bindRes < 0) {
if (bindRes < 0)
throwCannotBind(addr)
}

this.localport = fetchLocalPort(family).getOrElse {
this.localport = fetchLocalPort(socket.AF_INET).getOrElse {
throwCannotBind(addr)
}
}

private def bind6(addr: InetAddress, port: Int): Unit = {
val sa6 = stackalloc[in.sockaddr_in6]()
val sa6Len = sizeof[in.sockaddr_in6].toUInt

// By contract, all the bytes in sa6 are zero going in.
SocketHelpers.prepareSockaddrIn6(addr, port, sa6)

val bindRes = socket.bind(
fd.fd,
sa6.asInstanceOf[Ptr[socket.sockaddr]],
sizeof[in.sockaddr_in6].toUInt
sa6Len
)

if (bindRes < 0)
Expand Down Expand Up @@ -183,46 +173,19 @@ private[net] abstract class AbstractPlainSocketImpl extends SocketImpl {
s.fd = new FileDescriptor(newFd)
}

override def connect(host: String, port: Int): Unit = {
val addr = InetAddress.getByName(host)
connect(addr, port)
}

override def connect(address: InetAddress, port: Int): Unit = {
connect(new InetSocketAddress(address, port), 0)
}

private def connect4(address: SocketAddress, timeout: Int): Unit = {
val inetAddr = address.asInstanceOf[InetSocketAddress]
val hints = stackalloc[addrinfo]()
val ret = stackalloc[Ptr[addrinfo]]()
hints.ai_family = socket.AF_UNSPEC
hints.ai_flags = AI_NUMERICHOST | AI_NUMERICSERV
hints.ai_socktype = socket.SOCK_STREAM
val remoteAddress = inetAddr.getAddress.getHostAddress()

Zone.acquire { implicit z =>
val cIP = toCString(remoteAddress)
val cPort = toCString(inetAddr.getPort.toString)

val retCode = getaddrinfo(cIP, cPort, hints, ret)
private def connect4(addr: InetAddress, port: Int, timeout: Int): Unit = {
val sa4 = stackalloc[in.sockaddr_in]()
val sa4Len = sizeof[in.sockaddr_in].toUInt
SocketHelpers.prepareSockaddrIn4(addr, port, sa4)

if (retCode != 0) {
throw new ConnectException(
s"Could not resolve address: ${remoteAddress}"
+ s" on port: ${inetAddr.getPort}"
+ s" return code: ${retCode}"
)
}
}

val family = (!ret).ai_family
if (timeout != 0)
setSocketFdBlocking(fd, blocking = false)

val connectRet = socket.connect(fd.fd, (!ret).ai_addr, (!ret).ai_addrlen)

freeaddrinfo(!ret) // Must be after last use of ai_addr.
val connectRet = socket.connect(
fd.fd,
sa4.asInstanceOf[Ptr[socket.sockaddr]],
sa4Len
)

if (connectRet < 0) {
def inProgress = mapLastError(
Expand All @@ -236,37 +199,34 @@ private[net] abstract class AbstractPlainSocketImpl extends SocketImpl {
tryPollOnConnect(timeout)
} else {
throw new ConnectException(
s"Could not connect to address: ${remoteAddress}"
+ s" on port: ${inetAddr.getPort}"
+ s", errno: ${lastError()}"
s"Could not connect to address: $addr on port: $port, errno: ${lastError()}"
)
}
}

this.address = inetAddr.getAddress
this.port = inetAddr.getPort
this.localport = fetchLocalPort(family).getOrElse {
this.address = addr
this.port = port
this.localport = fetchLocalPort(socket.AF_INET).getOrElse {
throw new ConnectException(
"Could not resolve a local port when connecting"
)
}
}

private def connect6(address: SocketAddress, timeout: Int): Unit = {
val insAddr = address.asInstanceOf[InetSocketAddress]

private def connect6(addr: InetAddress, port: Int, timeout: Int): Unit = {
val sa6 = stackalloc[in.sockaddr_in6]()
val sa6Len = sizeof[in.sockaddr_in6].toUInt

// By contract, all the bytes in sa6 are zero going in.
SocketHelpers.prepareSockaddrIn6(insAddr.getAddress, insAddr.getPort, sa6)
SocketHelpers.prepareSockaddrIn6(addr, port, sa6)

if (timeout != 0)
setSocketFdBlocking(fd, blocking = false)

val connectRet = socket.connect(
fd.fd,
sa6.asInstanceOf[Ptr[socket.sockaddr]],
sizeof[in.sockaddr_in6].toUInt
sa6Len
)

if (connectRet < 0) {
Expand All @@ -281,17 +241,14 @@ private[net] abstract class AbstractPlainSocketImpl extends SocketImpl {
if (timeout > 0 && inProgress) {
tryPollOnConnect(timeout)
} else {
val ra = insAddr.getAddress.getHostAddress()
throw new ConnectException(
s"Could not connect to address: ${ra}"
+ s" on port: ${insAddr.getPort}"
+ s", errno: ${lastError()}"
s"Could not connect to address: $addr on port: $port, errno: ${lastError()}"
)
}
}

this.address = insAddr.getAddress
this.port = insAddr.getPort
this.address = addr
this.port = port
this.localport = fetchLocalPort(sa6.sin6_family.toInt).getOrElse {
throw new ConnectException(
"Could not resolve a local port when connecting"
Expand All @@ -300,13 +257,28 @@ private[net] abstract class AbstractPlainSocketImpl extends SocketImpl {
}

private lazy val connectFunc =
if (useIPv4Only) connect4(_: SocketAddress, _: Int)
else connect6(_: SocketAddress, _: Int)
if (useIPv4Only) connect4(_: InetAddress, _: Int, _: Int)
else connect6(_: InetAddress, _: Int, _: Int)

override def connect(address: SocketAddress, timeout: Int): Unit = {
throwIfClosed("connect") // Do not send negative fd.fd to poll()
override def connect(host: String, port: Int): Unit = {
throwIfClosed("connect")
val addr = InetAddress.getByName(host)
connectFunc(addr, port, 0)
}
override def connect(address: InetAddress, port: Int): Unit = {
throwIfClosed("connect")
connectFunc(address, port, 0)
}

connectFunc(address, timeout)
override def connect(address: SocketAddress, timeout: Int): Unit = {
throwIfClosed("connect")
val insAddr = address match {
case insAddr: InetSocketAddress => insAddr
case _ => throw new IllegalArgumentException("Unsupported address type")
}
val addr = insAddr.getAddress
val port = insAddr.getPort
connectFunc(addr, port, timeout)
}

override def close(): Unit = {
Expand Down
54 changes: 31 additions & 23 deletions javalib/src/main/scala/java/net/DatagramSocket.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,32 +89,35 @@ class DatagramSocket protected (
throw new SocketException("already bound")
}

if (addr != null && !addr.isInstanceOf[InetSocketAddress]) {
throw new IllegalArgumentException(
"Endpoint is of unsupported SocketAddress subclass"
)
val insAddr = addr match {
case null =>
new InetSocketAddress(SocketHelpers.getWildcardAddressForBind(), 0)
case insAddr: InetSocketAddress =>
insAddr
case _ =>
throw new IllegalArgumentException(
"Endpoint is of unsupported SocketAddress subclass"
)
}

val inetAddr =
if (addr == null ||
addr.asInstanceOf[InetSocketAddress].getAddress == null)
new InetSocketAddress(SocketHelpers.getWildcardAddressForBind(), 0)
else {
addr.asInstanceOf[InetSocketAddress]
}
if (insAddr.isUnresolved)
throw new SocketException("Unresolved address")

checkClosedAndCreate()

impl.bind(inetAddr.getPort, inetAddr.getAddress)
this.localAddr = inetAddr.getAddress
impl.bind(insAddr.getPort, insAddr.getAddress)
this.localAddr = insAddr.getAddress
this.localPort = impl.localport
bound = true
}

private[net] def checkAddress(addr: InetAddress, op: String) = addr match {
case null =>
case null =>
throw new IllegalArgumentException(op + ": null address")
case _: Inet4Address | _: Inet6Address =>
case _ => new IllegalArgumentException(op + ": invalid address type")
()
case _ =>
throw new IllegalArgumentException(op + ": invalid address type")
}

def connect(address: InetAddress, port: Int): Unit = {
Expand All @@ -136,14 +139,19 @@ class DatagramSocket protected (
}
}

def connect(address: SocketAddress): Unit = address match {
case iaddr: InetSocketAddress =>
connect(iaddr.getAddress, iaddr.getPort)
case _ =>
throw new IllegalArgumentException(
"Invalid address argument to connect - " +
"either of unsupported SocketAddress subclass or null"
)
def connect(address: SocketAddress): Unit = {
if (address == null)
throw new IllegalArgumentException("Address can't be null")

val inetAddr = address match {
case inetAddr: InetSocketAddress => inetAddr
case _ => throw new IllegalArgumentException("Unsupported address type")
}

if (inetAddr.isUnresolved)
throw new SocketException("Unresolved address")

connect(inetAddr.getAddress, inetAddr.getPort)
}

def disconnect(): Unit = {
Expand Down

0 comments on commit 62d4b78

Please sign in to comment.