Skip to content
This repository was archived by the owner on Nov 23, 2017. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def _sock_sendall(self, fut, registered, sock, data):
data = data[n:]
self.add_writer(fd, self._sock_sendall, fut, True, sock, data)

@coroutine
def sock_connect(self, sock, address):
"""Connect to a remote socket at address.

Expand All @@ -390,24 +391,16 @@ def sock_connect(self, sock, address):
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")

fut = self.create_future()
if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
self._sock_connect(fut, sock, address)
else:
if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
resolved = base_events._ensure_resolved(
address, family=sock.family, proto=sock.proto, loop=self)
resolved.add_done_callback(
lambda resolved: self._on_resolved(fut, sock, resolved))

return fut

def _on_resolved(self, fut, sock, resolved):
try:
if not resolved.done():
yield from resolved
_, _, _, _, address = resolved.result()[0]
except Exception as exc:
fut.set_exception(exc)
else:
self._sock_connect(fut, sock, address)

fut = self.create_future()
self._sock_connect(fut, sock, address)
return (yield from fut)

def _sock_connect(self, fut, sock, address):
fd = sock.fileno()
Expand All @@ -418,8 +411,8 @@ def _sock_connect(self, fut, sock, address):
# connection runs in background. We have to wait until the socket
# becomes writable to be notified when the connection succeed or
# fails.
fut.add_done_callback(functools.partial(self._sock_connect_done,
fd))
fut.add_done_callback(
functools.partial(self._sock_connect_done, fd))
self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
except Exception as exc:
fut.set_exception(exc)
Expand Down
124 changes: 101 additions & 23 deletions tests/test_selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import errno
import socket
import threading
import time
import unittest
from unittest import mock
try:
Expand Down Expand Up @@ -337,18 +339,6 @@ def test__sock_sendall_none(self):
(10, self.loop._sock_sendall, f, True, sock, b'data'),
self.loop.add_writer.call_args[0])

def test_sock_connect(self):
sock = test_utils.mock_nonblocking_socket()
self.loop._sock_connect = mock.Mock()

f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
self.assertIsInstance(f, asyncio.Future)
self.loop._run_once()
future_in, sock_in, address_in = self.loop._sock_connect.call_args[0]
self.assertEqual(future_in, f)
self.assertEqual(sock_in, sock)
self.assertEqual(address_in, ('127.0.0.1', 8080))

def test_sock_connect_timeout(self):
# asyncio issue #205: sock_connect() must unregister the socket on
# timeout error
Expand All @@ -360,29 +350,34 @@ def test_sock_connect_timeout(self):
sock.connect.side_effect = BlockingIOError

# first call to sock_connect() registers the socket
fut = self.loop.sock_connect(sock, ('127.0.0.1', 80))
fut = self.loop.create_task(
self.loop.sock_connect(sock, ('127.0.0.1', 80)))
self.loop._run_once()
self.assertTrue(sock.connect.called)
self.assertTrue(self.loop.add_writer.called)
self.assertEqual(len(fut._callbacks), 1)

# on timeout, the socket must be unregistered
sock.connect.reset_mock()
fut.set_exception(asyncio.TimeoutError)
with self.assertRaises(asyncio.TimeoutError):
fut.cancel()
with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(fut)
self.assertTrue(self.loop.remove_writer.called)

def test_sock_connect_resolve_using_socket_params(self):
@mock.patch('socket.getaddrinfo')
def test_sock_connect_resolve_using_socket_params(self, m_gai):
addr = ('need-resolution.com', 8080)
sock = test_utils.mock_nonblocking_socket()
self.loop.getaddrinfo = mock.Mock()
self.loop.sock_connect(sock, addr)
while not self.loop.getaddrinfo.called:
m_gai.side_effect = (None, None, None, None, ('127.0.0.1', 0))
m_gai._is_coroutine = False
con = self.loop.create_task(self.loop.sock_connect(sock, addr))
while not m_gai.called:
self.loop._run_once()
self.loop.getaddrinfo.assert_called_with(
*addr, type=sock.type, family=sock.family, proto=sock.proto,
flags=0)
m_gai.assert_called_with(
addr[0], addr[1], sock.family, sock.type, sock.proto, 0)

con.cancel()
with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(con)

def test__sock_connect(self):
f = asyncio.Future(loop=self.loop)
Expand Down Expand Up @@ -1778,5 +1773,88 @@ def test_fatal_error_connected(self, m_exc):
exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY))


class SelectorLoopFunctionalTests(unittest.TestCase):

def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()

@asyncio.coroutine
def recv_all(self, sock, nbytes):
buf = b''
while len(buf) < nbytes:
buf += yield from self.loop.sock_recv(sock, nbytes - len(buf))
return buf

def test_sock_connect_sock_write_race(self):
TIMEOUT = 3.0
PAYLOAD = b'DATA' * 1024 * 1024

class Server(threading.Thread):
def __init__(self, *args, srv_sock, **kwargs):
super().__init__(*args, **kwargs)
self.srv_sock = srv_sock

def run(self):
with self.srv_sock:
srv_sock.listen(100)

sock, addr = self.srv_sock.accept()
sock.settimeout(TIMEOUT)

with sock:
sock.sendall(b'helo')

buf = bytearray()
while len(buf) < len(PAYLOAD):
pack = sock.recv(1024 * 65)
if not pack:
break
buf.extend(pack)

@asyncio.coroutine
def client(addr):
sock = socket.socket()
with sock:
sock.setblocking(False)

started = time.monotonic()
while True:
if time.monotonic() - started > TIMEOUT:
self.fail('unable to connect to the socket')
return
try:
yield from self.loop.sock_connect(sock, addr)
except OSError:
yield from asyncio.sleep(0.05, loop=self.loop)
else:
break

# Give 'Server' thread a chance to accept and send b'helo'
time.sleep(0.1)

data = yield from self.recv_all(sock, 4)
self.assertEqual(data, b'helo')
yield from self.loop.sock_sendall(sock, PAYLOAD)

srv_sock = socket.socket()
srv_sock.settimeout(TIMEOUT)
srv_sock.bind(('127.0.0.1', 0))
srv_addr = srv_sock.getsockname()

srv = Server(srv_sock=srv_sock, daemon=True)
srv.start()

try:
self.loop.run_until_complete(
asyncio.wait_for(client(srv_addr), loop=self.loop,
timeout=TIMEOUT))
finally:
srv.join()


if __name__ == '__main__':
unittest.main()