Skip to content

Commit

Permalink
[3.13] gh-122133: Rework pure Python socketpair tests to avoid use of…
Browse files Browse the repository at this point in the history
… importlib.reload. (GH-122493) (#122504)

gh-122133: Rework pure Python socketpair tests to avoid use of importlib.reload. (GH-122493)

(cherry picked from commit f071f01)

Co-authored-by: Russell Keith-Magee <russell@keith-magee.com>
Co-authored-by: Gregory P. Smith <greg@krypto.org>
  • Loading branch information
3 people authored Jul 31, 2024
1 parent 4a6365c commit c21a361
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 77 deletions.
121 changes: 58 additions & 63 deletions Lib/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,16 +592,65 @@ def fromshare(info):
return socket(0, 0, 0, info)
__all__.append("fromshare")

if hasattr(_socket, "socketpair"):
# Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain.
# This is used if _socket doesn't natively provide socketpair. It's
# always defined so that it can be patched in for testing purposes.
def _fallback_socketpair(family=AF_INET, type=SOCK_STREAM, proto=0):
if family == AF_INET:
host = _LOCALHOST
elif family == AF_INET6:
host = _LOCALHOST_V6
else:
raise ValueError("Only AF_INET and AF_INET6 socket address families "
"are supported")
if type != SOCK_STREAM:
raise ValueError("Only SOCK_STREAM socket type is supported")
if proto != 0:
raise ValueError("Only protocol zero is supported")

# We create a connected TCP socket. Note the trick with
# setblocking(False) that prevents us from having to create a thread.
lsock = socket(family, type, proto)
try:
lsock.bind((host, 0))
lsock.listen()
# On IPv6, ignore flow_info and scope_id
addr, port = lsock.getsockname()[:2]
csock = socket(family, type, proto)
try:
csock.setblocking(False)
try:
csock.connect((addr, port))
except (BlockingIOError, InterruptedError):
pass
csock.setblocking(True)
ssock, _ = lsock.accept()
except:
csock.close()
raise
finally:
lsock.close()

def socketpair(family=None, type=SOCK_STREAM, proto=0):
"""socketpair([family[, type[, proto]]]) -> (socket object, socket object)
# Authenticating avoids using a connection from something else
# able to connect to {host}:{port} instead of us.
# We expect only AF_INET and AF_INET6 families.
try:
if (
ssock.getsockname() != csock.getpeername()
or csock.getsockname() != ssock.getpeername()
):
raise ConnectionError("Unexpected peer connection")
except:
# getsockname() and getpeername() can fail
# if either socket isn't connected.
ssock.close()
csock.close()
raise

Create a pair of socket objects from the sockets returned by the platform
socketpair() function.
The arguments are the same as for socket() except the default family is
AF_UNIX if defined on the platform; otherwise, the default is AF_INET.
"""
return (ssock, csock)

if hasattr(_socket, "socketpair"):
def socketpair(family=None, type=SOCK_STREAM, proto=0):
if family is None:
try:
family = AF_UNIX
Expand All @@ -613,61 +662,7 @@ def socketpair(family=None, type=SOCK_STREAM, proto=0):
return a, b

else:

# Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain.
def socketpair(family=AF_INET, type=SOCK_STREAM, proto=0):
if family == AF_INET:
host = _LOCALHOST
elif family == AF_INET6:
host = _LOCALHOST_V6
else:
raise ValueError("Only AF_INET and AF_INET6 socket address families "
"are supported")
if type != SOCK_STREAM:
raise ValueError("Only SOCK_STREAM socket type is supported")
if proto != 0:
raise ValueError("Only protocol zero is supported")

# We create a connected TCP socket. Note the trick with
# setblocking(False) that prevents us from having to create a thread.
lsock = socket(family, type, proto)
try:
lsock.bind((host, 0))
lsock.listen()
# On IPv6, ignore flow_info and scope_id
addr, port = lsock.getsockname()[:2]
csock = socket(family, type, proto)
try:
csock.setblocking(False)
try:
csock.connect((addr, port))
except (BlockingIOError, InterruptedError):
pass
csock.setblocking(True)
ssock, _ = lsock.accept()
except:
csock.close()
raise
finally:
lsock.close()

# Authenticating avoids using a connection from something else
# able to connect to {host}:{port} instead of us.
# We expect only AF_INET and AF_INET6 families.
try:
if (
ssock.getsockname() != csock.getpeername()
or csock.getsockname() != ssock.getpeername()
):
raise ConnectionError("Unexpected peer connection")
except:
# getsockname() and getpeername() can fail
# if either socket isn't connected.
ssock.close()
csock.close()
raise

return (ssock, csock)
socketpair = _fallback_socketpair
__all__.append("socketpair")

socketpair.__doc__ = """socketpair([family[, type[, proto]]]) -> (socket object, socket object)
Expand Down
20 changes: 6 additions & 14 deletions Lib/test/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -4861,7 +4861,6 @@ def _testSend(self):


class PurePythonSocketPairTest(SocketPairTest):

# Explicitly use socketpair AF_INET or AF_INET6 to ensure that is the
# code path we're using regardless platform is the pure python one where
# `_socket.socketpair` does not exist. (AF_INET does not work with
Expand All @@ -4876,28 +4875,21 @@ def socketpair(self):
# Local imports in this class make for easy security fix backporting.

def setUp(self):
import _socket
self._orig_sp = getattr(_socket, 'socketpair', None)
if self._orig_sp is not None:
if hasattr(_socket, "socketpair"):
self._orig_sp = socket.socketpair
# This forces the version using the non-OS provided socketpair
# emulation via an AF_INET socket in Lib/socket.py.
del _socket.socketpair
import importlib
global socket
socket = importlib.reload(socket)
socket.socketpair = socket._fallback_socketpair
else:
pass # This platform already uses the non-OS provided version.
# This platform already uses the non-OS provided version.
self._orig_sp = None
super().setUp()

def tearDown(self):
super().tearDown()
import _socket
if self._orig_sp is not None:
# Restore the default socket.socketpair definition.
_socket.socketpair = self._orig_sp
import importlib
global socket
socket = importlib.reload(socket)
socket.socketpair = self._orig_sp

def test_recv(self):
msg = self.serv.recv(1024)
Expand Down

0 comments on commit c21a361

Please sign in to comment.