Skip to content

Commit

Permalink
Adapt sockets to work in multithreading mode (#3128)
Browse files Browse the repository at this point in the history
Both implementations now have handling for possible interruption of syscall (important for usage with Boehm GC)
  • Loading branch information
WojciechMazur committed Feb 2, 2023
1 parent d8fa21a commit 4c81860
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 60 deletions.
76 changes: 44 additions & 32 deletions javalib/src/main/scala/java/net/UnixPlainSocketImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ package java.net

import scala.scalanative.unsigned._
import scala.scalanative.unsafe._
import scala.scalanative.libc._
import scala.scalanative.posix.errno._
import scala.scalanative.posix.fcntl._
import scala.scalanative.posix.poll._
import scala.scalanative.posix.pollEvents._
import scala.scalanative.posix.pollOps._
import scala.scalanative.posix.sys.socket

import java.io.{FileDescriptor, IOException}
import scala.annotation.tailrec

private[net] class UnixPlainSocketImpl extends AbstractPlainSocketImpl {

Expand All @@ -33,41 +34,54 @@ private[net] class UnixPlainSocketImpl extends AbstractPlainSocketImpl {
fd = new FileDescriptor(sock)
}

protected def tryPollOnConnect(timeout: Int): Unit = {
final protected def tryPollOnConnect(timeout: Int): Unit = {
val hasTimeout = timeout > 0
val deadline = if (hasTimeout) System.currentTimeMillis() + timeout else 0L
val nAlloc = 1.toUInt
val pollFd: Ptr[struct_pollfd] = stackalloc[struct_pollfd](nAlloc)

pollFd.fd = fd.fd
pollFd.revents = 0
pollFd.events = (POLLIN | POLLOUT).toShort

val pollRes = poll(pollFd, nAlloc, timeout)
val revents = pollFd.revents

setSocketFdBlocking(fd, blocking = true)

pollRes match {
case err if err < 0 =>
throw new SocketException(s"connect failed, poll errno: ${errno.errno}")
def failWithTimeout() = throw new SocketTimeoutException(
s"connect timed out, SO_TIMEOUT: ${timeout}"
)

@tailrec def loop(remainingTimeout: Int): Unit = {
val pollRes = poll(pollFd, nAlloc, remainingTimeout)
val revents = pollFd.revents

pollRes match {
case err if err < 0 =>
val errCode = errno
if (errCode == EINTR && hasTimeout) {
val remaining = deadline - System.currentTimeMillis()
if (remaining > 0) loop(remaining.toInt)
else failWithTimeout()
} else
throw new SocketException(s"connect failed, poll errno: $errCode")

case 0 => failWithTimeout()

case _ =>
if ((revents & POLLNVAL) != 0) {
val msg = s"connect failed, invalid poll request: ${revents}"
throw new ConnectException(msg)
} else if ((revents & (POLLIN | POLLHUP)) != 0) {
// Not enough information at this point to report remote host:port.
val msg = "Connection refused"
throw new ConnectException(msg)
} else if ((revents & POLLERR) != 0) { // an error was recognized.
val msg = s"connect failed, poll POLLERR: ${revents}"
throw new ConnectException(msg)
} // else should be POLLOUT - Open for Business, ignore XSI bits if set
}
}

case 0 =>
throw new SocketTimeoutException(
s"connect timed out, SO_TIMEOUT: ${timeout}"
)
try loop(timeout)
finally setSocketFdBlocking(fd, blocking = true)

case _ =>
if ((revents & POLLNVAL) != 0) {
val msg = s"connect failed, invalid poll request: ${revents}"
throw new ConnectException(msg)
} else if ((revents & (POLLIN | POLLHUP)) != 0) {
// Not enough information at this point to report remote host:port.
val msg = "Connection refused"
throw new ConnectException(msg)
} else if ((revents & POLLERR) != 0) { // an error was recognized.
val msg = s"connect failed, poll POLLERR: ${revents}"
throw new ConnectException(msg)
} // else should be POLLOUT - Open for Business, ignore XSI bits if set
}
}

protected def tryPollOnAccept(): Unit = {
Expand All @@ -83,7 +97,7 @@ private[net] class UnixPlainSocketImpl extends AbstractPlainSocketImpl {

pollRes match {
case err if err < 0 =>
throw new SocketException(s"accept failed, poll errno: ${errno.errno}")
throw new SocketException(s"accept failed, poll errno: $errno")

case 0 =>
throw new SocketTimeoutException(
Expand Down Expand Up @@ -123,8 +137,7 @@ private[net] class UnixPlainSocketImpl extends AbstractPlainSocketImpl {

if (opts == -1) {
throw new ConnectException(
"connect failed, fcntl F_GETFL" +
s", errno: ${errno.errno}"
s"connect failed, fcntl F_GETFL, errno: $errno"
)
}

Expand All @@ -138,8 +151,7 @@ private[net] class UnixPlainSocketImpl extends AbstractPlainSocketImpl {
if (ret == -1) {
throw new ConnectException(
"connect failed, " +
s"fcntl F_SETFL for opts: ${opts}" +
s", errno: ${errno.errno}"
s"fcntl F_SETFL for opts: $opts, errno: $errno"
)
}
}
Expand Down
69 changes: 41 additions & 28 deletions javalib/src/main/scala/java/net/WindowsPlainSocketImpl.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package java.net

import java.io.{FileDescriptor, IOException}
import scala.scalanative.libc._
import scala.scalanative.posix.sys.{socket => unixSocket}
import scala.scalanative.unsafe._
import scala.scalanative.unsigned._
import scala.scalanative.windows._
import scala.annotation.tailrec

private[net] class WindowsPlainSocketImpl extends AbstractPlainSocketImpl {
import WinSocketApi._
Expand All @@ -31,39 +31,51 @@ private[net] class WindowsPlainSocketImpl extends AbstractPlainSocketImpl {
)
}

protected def tryPollOnConnect(timeout: Int): Unit = {
final protected def tryPollOnConnect(timeout: Int): Unit = {
val hasTimeout = timeout > 0
val deadline = if (hasTimeout) System.currentTimeMillis() + timeout else 0L
val nAlloc = 1.toUInt
val pollFd: Ptr[WSAPollFd] = stackalloc[WSAPollFd](nAlloc)

pollFd.socket = fd.handle
pollFd.revents = 0.toShort
pollFd.events = (POLLIN | POLLOUT).toShort

val pollRes = WSAPoll(pollFd, nAlloc, timeout)
val revents = pollFd.revents

setSocketFdBlocking(fd, blocking = true)

pollRes match {
case err if err < 0 =>
throw new SocketException(s"connect failed, poll errno: ${errno.errno}")

case 0 =>
throw new SocketTimeoutException(
s"connect timed out, SO_TIMEOUT: ${timeout}"
)
def failWithTimeout() = throw new SocketTimeoutException(
s"connect timed out, SO_TIMEOUT: ${timeout}"
)

case _ =>
if ((revents & POLLNVAL) != 0) {
throw new ConnectException(
s"connect failed, invalid poll request: ${revents}"
)
} else if ((revents & (POLLERR | POLLHUP)) != 0) {
throw new ConnectException(
s"connect failed, POLLERR or POLLHUP set: ${revents}"
)
}
@tailrec def loop(remainingTimeout: Int): Unit = {
val pollRes = WSAPoll(pollFd, nAlloc, remainingTimeout)
val revents = pollFd.revents

pollRes match {
case err if err < 0 =>
val errCode = WSAGetLastError()
if (errCode == WSAEINTR && hasTimeout) {
val remaining = deadline - System.currentTimeMillis()
if (remaining > 0) loop(remaining.toInt)
else failWithTimeout()
} else
throw new SocketException(s"connect failed, poll errno: ${errCode}")

case 0 => failWithTimeout()

case _ =>
if ((revents & POLLNVAL) != 0) {
throw new ConnectException(
s"connect failed, invalid poll request: ${revents}"
)
} else if ((revents & (POLLERR | POLLHUP)) != 0) {
throw new ConnectException(
s"connect failed, POLLERR or POLLHUP set: ${revents}"
)
}
}
}

try loop(timeout)
finally setSocketFdBlocking(fd, blocking = true)
}

protected def tryPollOnAccept(): Unit = {
Expand All @@ -79,7 +91,9 @@ private[net] class WindowsPlainSocketImpl extends AbstractPlainSocketImpl {

pollRes match {
case err if err < 0 =>
throw new SocketException(s"accept failed, poll errno: ${errno.errno}")
throw new SocketException(
s"accept failed, poll errno: ${WSAGetLastError()}"
)

case 0 =>
throw new SocketTimeoutException(
Expand All @@ -97,8 +111,7 @@ private[net] class WindowsPlainSocketImpl extends AbstractPlainSocketImpl {
)
} else if (((revents & POLLIN) | (revents & POLLOUT)) == 0) {
throw new SocketException(
"accept failed, neither POLLIN nor POLLOUT set, " +
s"revents, ${revents}"
s"accept failed, neither POLLIN nor POLLOUT set, revents, ${revents}"
)
}
}
Expand Down

0 comments on commit 4c81860

Please sign in to comment.