Skip to content
This repository was archived by the owner on Nov 23, 2017. It is now read-only.

Fix ordering issues in UNIX read/write pipe transport constructors #370

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,6 @@ def create_connection(self, protocol_factory, host=None, port=None, *,
raise ValueError(
'host and port was not specified and no sock specified')

sock.setblocking(False)

transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
if self._debug:
Expand All @@ -721,14 +719,17 @@ def create_connection(self, protocol_factory, host=None, port=None, *,

@coroutine
def _create_connection_transport(self, sock, protocol_factory, ssl,
server_hostname):
server_hostname, server_side=False):

sock.setblocking(False)

protocol = protocol_factory()
waiter = self.create_future()
if ssl:
sslcontext = None if isinstance(ssl, bool) else ssl
transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter,
server_side=False, server_hostname=server_hostname)
server_side=server_side, server_hostname=server_hostname)
else:
transport = self._make_socket_transport(sock, protocol, waiter)

Expand Down Expand Up @@ -979,6 +980,25 @@ def create_server(self, protocol_factory, host=None, port=None,
logger.info("%r is serving", server)
return server

@coroutine
def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None):
"""Handle an accepted connection.

This is used by servers that accept connections outside of
asyncio but that use asyncio to handle connections.

This method is a coroutine. When completed, the coroutine
returns a (transport, protocol) pair.
"""
transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
sock = transport.get_extra_info('socket')
logger.debug("%r handled: (%r, %r)", sock, transport, protocol)
return transport, protocol

@coroutine
def connect_read_pipe(self, protocol_factory, pipe):
protocol = protocol_factory()
Expand Down
2 changes: 1 addition & 1 deletion asyncio/locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def acquire(self):
This method blocks until the lock is unlocked, then sets it to
locked and returns True.
"""
if not self._waiters and not self._locked:
if not self._locked and all(w.cancelled() for w in self._waiters):
self._locked = True
return True

Expand Down
31 changes: 21 additions & 10 deletions asyncio/unix_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,14 +305,20 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None):
self._loop = loop
self._pipe = pipe
self._fileno = pipe.fileno()
self._protocol = protocol
self._closing = False

mode = os.fstat(self._fileno).st_mode
if not (stat.S_ISFIFO(mode) or
stat.S_ISSOCK(mode) or
stat.S_ISCHR(mode)):
self._pipe = None
self._fileno = None
self._protocol = None
raise ValueError("Pipe transport is for pipes/sockets only.")

_set_nonblocking(self._fileno)
self._protocol = protocol
self._closing = False

self._loop.call_soon(self._protocol.connection_made, self)
# only start reading when connection_made() has been called
self._loop.call_soon(self._loop.add_reader,
Expand Down Expand Up @@ -421,25 +427,30 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None):
self._extra['pipe'] = pipe
self._pipe = pipe
self._fileno = pipe.fileno()
self._protocol = protocol
self._buffer = []
self._conn_lost = 0
self._closing = False # Set when close() or write_eof() called.

mode = os.fstat(self._fileno).st_mode
is_char = stat.S_ISCHR(mode)
is_fifo = stat.S_ISFIFO(mode)
is_socket = stat.S_ISSOCK(mode)
if not (is_socket or
stat.S_ISFIFO(mode) or
stat.S_ISCHR(mode)):
if not (is_char or is_fifo or is_socket):
self._pipe = None
self._fileno = None
self._protocol = None
raise ValueError("Pipe transport is only for "
"pipes, sockets and character devices")

_set_nonblocking(self._fileno)
self._protocol = protocol
self._buffer = []
self._conn_lost = 0
self._closing = False # Set when close() or write_eof() called.

self._loop.call_soon(self._protocol.connection_made, self)

# On AIX, the reader trick (to be notified when the read end of the
# socket is closed) only works for sockets. On other platforms it
# works for pipes and sockets. (Exception: OS X 10.4? Issue #19294.)
if is_socket or not sys.platform.startswith("aix"):
if is_socket or (is_fifo and not sys.platform.startswith("aix")):
# only start reading when connection_made() has been called
self._loop.call_soon(self._loop.add_reader,
self._fileno, self._read_ready)
Expand Down
11 changes: 6 additions & 5 deletions examples/crawl.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,11 +594,12 @@ def report(self, stats, file=None):
stats.add('html')
size = len(self.body or b'')
stats.add('html_bytes', size)
print(self.url, self.response.status,
self.ctype, self.encoding,
size,
'%d/%d' % (len(self.new_urls or ()), len(self.urls or ())),
file=file)
if self.log.level:
print(self.url, self.response.status,
self.ctype, self.encoding,
size,
'%d/%d' % (len(self.new_urls or ()), len(self.urls or ())),
file=file)
elif self.response is None:
print(self.url, 'no response object')
else:
Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
# - hg ci && hg push

import os
import sys
try:
from setuptools import setup, Extension
except ImportError:
# Use distutils.core as a fallback.
# We won't be able to build the Wheel file on Windows.
from distutils.core import setup, Extension

if sys.version_info < (3, 3, 0):
raise RuntimeError("asyncio requires Python 3.3.0+")

extensions = []
if os.name == 'nt':
ext = Extension(
Expand Down
154 changes: 154 additions & 0 deletions tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from unittest import mock
import weakref

if sys.platform != 'win32':
import tty

import asyncio
from asyncio import proactor_events
Expand Down Expand Up @@ -744,6 +746,85 @@ def test_create_connection_local_addr_in_use(self):
self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
self.assertIn(str(httpd.address), cm.exception.strerror)

def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
loop = self.loop

class MyProto(MyBaseProto):

def connection_lost(self, exc):
super().connection_lost(exc)
loop.call_soon(loop.stop)

def data_received(self, data):
super().data_received(data)
self.transport.write(expected_response)

lsock = socket.socket()
lsock.bind(('127.0.0.1', 0))
lsock.listen(1)
addr = lsock.getsockname()

message = b'test data'
reponse = None
expected_response = b'roger'

def client():
global response
try:
csock = socket.socket()
if client_ssl is not None:
csock = client_ssl.wrap_socket(csock)
csock.connect(addr)
csock.sendall(message)
response = csock.recv(99)
csock.close()
except Exception as exc:
print(
"Failure in client thread in test_connect_accepted_socket",
exc)

thread = threading.Thread(target=client, daemon=True)
thread.start()

conn, _ = lsock.accept()
proto = MyProto(loop=loop)
proto.loop = loop
f = loop.create_task(
loop.connect_accepted_socket(
(lambda : proto), conn, ssl=server_ssl))
loop.run_forever()
conn.close()
lsock.close()

thread.join(1)
self.assertFalse(thread.is_alive())
self.assertEqual(proto.state, 'CLOSED')
self.assertEqual(proto.nbytes, len(message))
self.assertEqual(response, expected_response)

@unittest.skipIf(ssl is None, 'No ssl module')
def test_ssl_connect_accepted_socket(self):
if (sys.platform == 'win32' and
sys.version_info < (3, 5) and
isinstance(self.loop, proactor_events.BaseProactorEventLoop)
):
raise unittest.SkipTest(
'SSL not supported with proactor event loops before Python 3.5'
)

server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
server_context.load_cert_chain(ONLYCERT, ONLYKEY)
if hasattr(server_context, 'check_hostname'):
server_context.check_hostname = False
server_context.verify_mode = ssl.CERT_NONE

client_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
if hasattr(server_context, 'check_hostname'):
client_context.check_hostname = False
client_context.verify_mode = ssl.CERT_NONE

self.test_connect_accepted_socket(server_context, client_context)

@mock.patch('asyncio.base_events.socket')
def create_server_multiple_hosts(self, family, hosts, mock_sock):
@asyncio.coroutine
Expand Down Expand Up @@ -1547,6 +1628,79 @@ def reader(data):
self.loop.run_until_complete(proto.done)
self.assertEqual('CLOSED', proto.state)

@unittest.skipUnless(sys.platform != 'win32',
"Don't support pipes for Windows")
# select, poll and kqueue don't support character devices (PTY) on Mac OS X
# older than 10.6 (Snow Leopard)
@support.requires_mac_ver(10, 6)
def test_bidirectional_pty(self):
master, read_slave = os.openpty()
write_slave = os.dup(read_slave)
tty.setraw(read_slave)

slave_read_obj = io.open(read_slave, 'rb', 0)
read_proto = MyReadPipeProto(loop=self.loop)
read_connect = self.loop.connect_read_pipe(lambda: read_proto,
slave_read_obj)
read_transport, p = self.loop.run_until_complete(read_connect)
self.assertIs(p, read_proto)
self.assertIs(read_transport, read_proto.transport)
self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
self.assertEqual(0, read_proto.nbytes)


slave_write_obj = io.open(write_slave, 'wb', 0)
write_proto = MyWritePipeProto(loop=self.loop)
write_connect = self.loop.connect_write_pipe(lambda: write_proto,
slave_write_obj)
write_transport, p = self.loop.run_until_complete(write_connect)
self.assertIs(p, write_proto)
self.assertIs(write_transport, write_proto.transport)
self.assertEqual('CONNECTED', write_proto.state)

data = bytearray()
def reader(data):
chunk = os.read(master, 1024)
data += chunk
return len(data)

write_transport.write(b'1')
test_utils.run_until(self.loop, lambda: reader(data) >= 1, timeout=10)
self.assertEqual(b'1', data)
self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
self.assertEqual('CONNECTED', write_proto.state)

os.write(master, b'a')
test_utils.run_until(self.loop, lambda: read_proto.nbytes >= 1,
timeout=10)
self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
self.assertEqual(1, read_proto.nbytes)
self.assertEqual('CONNECTED', write_proto.state)

write_transport.write(b'2345')
test_utils.run_until(self.loop, lambda: reader(data) >= 5, timeout=10)
self.assertEqual(b'12345', data)
self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
self.assertEqual('CONNECTED', write_proto.state)

os.write(master, b'bcde')
test_utils.run_until(self.loop, lambda: read_proto.nbytes >= 5,
timeout=10)
self.assertEqual(['INITIAL', 'CONNECTED'], read_proto.state)
self.assertEqual(5, read_proto.nbytes)
self.assertEqual('CONNECTED', write_proto.state)

os.close(master)

read_transport.close()
self.loop.run_until_complete(read_proto.done)
self.assertEqual(
['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], read_proto.state)

write_transport.close()
self.loop.run_until_complete(write_proto.done)
self.assertEqual('CLOSED', write_proto.state)

def test_prompt_cancellation(self):
r, w = test_utils.socketpair()
r.setblocking(False)
Expand Down