Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not monkey-patch socket, ensure socketpair returns non-blocking sockets #918

Merged
merged 3 commits into from
Jan 11, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions pika/adapters/select_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import pika.compat

from pika.compat import dictkeys
from pika.compat import SOCKET_ERROR
from pika.adapters.base_connection import BaseConnection

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -535,7 +533,7 @@ def stop(self):
# Send byte to interrupt the poll loop, use send() instead of
# os.write for Windows compatibility
self._w_interrupt.send(b'X')
except SOCKET_ERROR as err:
except pika.compat.SOCKET_ERROR as err:
if err.errno != errno.EWOULDBLOCK:
raise
except Exception as err:
Expand Down Expand Up @@ -614,7 +612,7 @@ def _dispatch_fd_events(self, fd_event_map):

self._processing_fd_event_map = fd_event_map

for fileno in dictkeys(fd_event_map):
for fileno in pika.compat.dictkeys(fd_event_map):
if fileno not in fd_event_map:
# the fileno has been removed from the map under our feet.
continue
Expand All @@ -635,7 +633,7 @@ def _get_interrupt_pair():
so use a pair of simple TCP sockets instead. The sockets will be
closed and garbage collected by python when the ioloop itself is.
"""
return socket.socketpair()
return pika.compat.socketpair()

def _read_interrupt(self, interrupt_fd, events): # pylint: disable=W0613
""" Read the interrupt byte(s). We ignore the event mask as we can ony
Expand All @@ -647,7 +645,7 @@ def _read_interrupt(self, interrupt_fd, events): # pylint: disable=W0613
try:
# NOTE Use recv instead of os.read for windows compatibility
self._r_interrupt.recv(512)
except SOCKET_ERROR as err:
except pika.compat.SOCKET_ERROR as err:
if err.errno != errno.EAGAIN:
raise

Expand Down
78 changes: 31 additions & 47 deletions pika/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,51 +167,35 @@ def get_linux_version(release_str):
_LOCALHOST = '127.0.0.1'
_LOCALHOST_V6 = '::1'

if not hasattr(socket, 'socketpair'):
# Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain.
def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
if family == socket.AF_INET:
host = _LOCALHOST
elif family == socket.AF_INET6:
host = _LOCALHOST_V6
else:
raise ValueError(
'Only AF_INET and AF_INET6 socket address families '
'are supported')
if type != socket.SOCK_STREAM:
raise ValueError('Only SOCK_STREAM socket type is supported')
if proto != 0:
raise ValueError('Only protocol zero is supported')

# We create a connected TCP socket. Note the trick with
# setblocking(False) that prevents us from having to create a thread.
lsock = socket.socket(family, type, proto)
def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a function comment that explains the reasons for having this method - Windows doesn't provide socket.socketpair.

if family == socket.AF_INET:
host = _LOCALHOST
elif family == socket.AF_INET6:
host = _LOCALHOST_V6
else:
raise ValueError(
'Only AF_INET and AF_INET6 socket address families '
'are supported')
if type != socket.SOCK_STREAM:
raise ValueError('Only SOCK_STREAM socket type is supported')
if proto != 0:
raise ValueError('Only protocol zero is supported')

lsock = socket.socket(family, type, proto)
try:
lsock.bind((host, 0))
lsock.listen(min(socket.SOMAXCONN, 128))
# On IPv6, ignore flow_info and scope_id
addr, port = lsock.getsockname()[:2]
csock = socket.socket(family, type, proto)
try:
lsock.bind((host, 0))
lsock.listen(min(socket.SOMAXCONN, 128))
# On IPv6, ignore flow_info and scope_id
addr, port = lsock.getsockname()[:2]
csock = socket.socket(family, type, proto)
try:
csock.setblocking(False)
if _sys.version_info >= (3, 0):
try:
csock.connect((addr, port))
except (BlockingIOError, InterruptedError):
pass
else:
try:
csock.connect((addr, port))
except SOCKET_ERROR as e:
if e.errno != errno.EWOULDBLOCK:
raise
csock.setblocking(True)
ssock, _ = lsock.accept()
except Exception:
csock.close()
raise
finally:
lsock.close()
return (ssock, csock)

socket.socketpair = socketpair
csock.connect((addr, port))
ssock, _ = lsock.accept()
except Exception:
csock.close()
raise
finally:
lsock.close()
csock.setblocking(False)
ssock.setblocking(False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My 2 cents: this function's name is exactly like the real deal in socket module which returns blocking sockets; if some other part of pika or pika's user use it, they might expect that it behaves just like the one in socket. For this reason, I would add the logic that makes the sockets non-blocking in the specialized method _get_interrupt_pair augmented with a comment, such as "make sockets non-blocking to avoid potential deadlock".

return (ssock, csock)
15 changes: 7 additions & 8 deletions tests/acceptance/forward_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
import threading
import traceback

from pika.compat import PY3
from pika.compat import SOCKET_ERROR
import pika.compat

if PY3:
if pika.compat.PY3:

def buffer(object, offset, size): # pylint: disable=W0622
"""array etc. have the buffer protocol"""
Expand Down Expand Up @@ -376,7 +375,7 @@ def handle(self): # pylint: disable=R0912
remote_dest_sock.getpeername())
else:
# Echo set-up
remote_dest_sock, remote_src_sock = socket.socketpair()
remote_dest_sock, remote_src_sock = pika.compat.socketpair()

try:
local_forwarder = threading.Thread(
Expand Down Expand Up @@ -421,7 +420,7 @@ def _forward(self, src_sock, dest_sock): # pylint: disable=R0912
while True:
try:
nbytes = src_sock.recv_into(rx_buf)
except SOCKET_ERROR as exc:
except pika.compat.SOCKET_ERROR as exc:
if exc.errno == errno.EINTR:
continue
elif exc.errno == errno.ECONNRESET:
Expand All @@ -442,7 +441,7 @@ def _forward(self, src_sock, dest_sock): # pylint: disable=R0912

try:
dest_sock.sendall(buffer(rx_buf, 0, nbytes))
except SOCKET_ERROR as exc:
except pika.compat.SOCKET_ERROR as exc:
if exc.errno == errno.EPIPE:
# Destination peer closed its end of the connection
_trace("%s Destination peer %s closed its end of "
Expand Down Expand Up @@ -499,7 +498,7 @@ def echo(port=0):
while True:
try:
data = sock.recv(4 * 1024) # pylint: disable=E1101
except SOCKET_ERROR as exc:
except pika.compat.SOCKET_ERROR as exc:
if exc.errno == errno.EINTR:
continue
else:
Expand All @@ -521,6 +520,6 @@ def _safe_shutdown_socket(sock, how=socket.SHUT_RDWR):
"""
try:
sock.shutdown(how)
except SOCKET_ERROR as exc:
except pika.compat.SOCKET_ERROR as exc:
if exc.errno != errno.ENOTCONN:
raise