Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not monkey-patch socket, ensure socketpair returns non-blocking sockets #918

Merged
merged 3 commits into from Jan 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 4 additions & 6 deletions pika/adapters/select_connection.py
Expand Up @@ -15,8 +15,6 @@

import pika.compat

from pika.compat import dictkeys
from pika.compat import SOCKET_ERROR
from pika.adapters.base_connection import BaseConnection

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -535,7 +533,7 @@ def stop(self):
# Send byte to interrupt the poll loop, use send() instead of
# os.write for Windows compatibility
self._w_interrupt.send(b'X')
except SOCKET_ERROR as err:
except pika.compat.SOCKET_ERROR as err:
if err.errno != errno.EWOULDBLOCK:
raise
except Exception as err:
Expand Down Expand Up @@ -614,7 +612,7 @@ def _dispatch_fd_events(self, fd_event_map):

self._processing_fd_event_map = fd_event_map

for fileno in dictkeys(fd_event_map):
for fileno in pika.compat.dictkeys(fd_event_map):
if fileno not in fd_event_map:
# the fileno has been removed from the map under our feet.
continue
Expand All @@ -635,7 +633,7 @@ def _get_interrupt_pair():
so use a pair of simple TCP sockets instead. The sockets will be
closed and garbage collected by python when the ioloop itself is.
"""
return socket.socketpair()
return pika.compat._nonblocking_socketpair()

def _read_interrupt(self, interrupt_fd, events): # pylint: disable=W0613
""" Read the interrupt byte(s). We ignore the event mask as we can ony
Expand All @@ -647,7 +645,7 @@ def _read_interrupt(self, interrupt_fd, events): # pylint: disable=W0613
try:
# NOTE Use recv instead of os.read for windows compatibility
self._r_interrupt.recv(512)
except SOCKET_ERROR as err:
except pika.compat.SOCKET_ERROR as err:
if err.errno != errno.EAGAIN:
raise

Expand Down
87 changes: 40 additions & 47 deletions pika/compat/__init__.py
Expand Up @@ -167,51 +167,44 @@ def get_linux_version(release_str):
_LOCALHOST = '127.0.0.1'
_LOCALHOST_V6 = '::1'

if not hasattr(socket, 'socketpair'):
# Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain.
def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
if family == socket.AF_INET:
host = _LOCALHOST
elif family == socket.AF_INET6:
host = _LOCALHOST_V6
else:
raise ValueError(
'Only AF_INET and AF_INET6 socket address families '
'are supported')
if type != socket.SOCK_STREAM:
raise ValueError('Only SOCK_STREAM socket type is supported')
if proto != 0:
raise ValueError('Only protocol zero is supported')

# We create a connected TCP socket. Note the trick with
# setblocking(False) that prevents us from having to create a thread.
lsock = socket.socket(family, type, proto)
def _nonblocking_socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

"""
Returns a pair of sockets in the manner of socketpair with the additional
feature that they will be non-blocking. Prior to Python 3.5, socketpair
did not exist on Windows at all.
"""
if family == socket.AF_INET:
host = _LOCALHOST
elif family == socket.AF_INET6:
host = _LOCALHOST_V6
else:
raise ValueError(
'Only AF_INET and AF_INET6 socket address families '
'are supported')
if type != socket.SOCK_STREAM:
raise ValueError('Only SOCK_STREAM socket type is supported')
if proto != 0:
raise ValueError('Only protocol zero is supported')

lsock = socket.socket(family, type, proto)
try:
lsock.bind((host, 0))
lsock.listen(min(socket.SOMAXCONN, 128))
# On IPv6, ignore flow_info and scope_id
addr, port = lsock.getsockname()[:2]
csock = socket.socket(family, type, proto)
try:
lsock.bind((host, 0))
lsock.listen(min(socket.SOMAXCONN, 128))
# On IPv6, ignore flow_info and scope_id
addr, port = lsock.getsockname()[:2]
csock = socket.socket(family, type, proto)
try:
csock.setblocking(False)
if _sys.version_info >= (3, 0):
try:
csock.connect((addr, port))
except (BlockingIOError, InterruptedError):
pass
else:
try:
csock.connect((addr, port))
except SOCKET_ERROR as e:
if e.errno != errno.EWOULDBLOCK:
raise
csock.setblocking(True)
ssock, _ = lsock.accept()
except Exception:
csock.close()
raise
finally:
lsock.close()
return (ssock, csock)

socket.socketpair = socketpair
csock.connect((addr, port))
ssock, _ = lsock.accept()
except Exception:
csock.close()
raise
finally:
lsock.close()

# Make sockets non-blocking to prevent deadlocks
# See https://github.com/pika/pika/issues/917
csock.setblocking(False)
ssock.setblocking(False)

return (ssock, csock)
16 changes: 8 additions & 8 deletions tests/acceptance/forward_server.py
Expand Up @@ -15,10 +15,9 @@
import threading
import traceback

from pika.compat import PY3
from pika.compat import SOCKET_ERROR
import pika.compat

if PY3:
if pika.compat.PY3:

def buffer(object, offset, size): # pylint: disable=W0622
"""array etc. have the buffer protocol"""
Expand Down Expand Up @@ -376,7 +375,8 @@ def handle(self): # pylint: disable=R0912
remote_dest_sock.getpeername())
else:
# Echo set-up
remote_dest_sock, remote_src_sock = socket.socketpair()
remote_dest_sock, remote_src_sock = \
pika.compat._nonblocking_socketpair()

try:
local_forwarder = threading.Thread(
Expand Down Expand Up @@ -421,7 +421,7 @@ def _forward(self, src_sock, dest_sock): # pylint: disable=R0912
while True:
try:
nbytes = src_sock.recv_into(rx_buf)
except SOCKET_ERROR as exc:
except pika.compat.SOCKET_ERROR as exc:
if exc.errno == errno.EINTR:
continue
elif exc.errno == errno.ECONNRESET:
Expand All @@ -442,7 +442,7 @@ def _forward(self, src_sock, dest_sock): # pylint: disable=R0912

try:
dest_sock.sendall(buffer(rx_buf, 0, nbytes))
except SOCKET_ERROR as exc:
except pika.compat.SOCKET_ERROR as exc:
if exc.errno == errno.EPIPE:
# Destination peer closed its end of the connection
_trace("%s Destination peer %s closed its end of "
Expand Down Expand Up @@ -499,7 +499,7 @@ def echo(port=0):
while True:
try:
data = sock.recv(4 * 1024) # pylint: disable=E1101
except SOCKET_ERROR as exc:
except pika.compat.SOCKET_ERROR as exc:
if exc.errno == errno.EINTR:
continue
else:
Expand All @@ -521,6 +521,6 @@ def _safe_shutdown_socket(sock, how=socket.SHUT_RDWR):
"""
try:
sock.shutdown(how)
except SOCKET_ERROR as exc:
except pika.compat.SOCKET_ERROR as exc:
if exc.errno != errno.ENOTCONN:
raise