Skip to content
This repository has been archived by the owner on Nov 23, 2017. It is now read-only.

Pause server #448

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,15 @@ def _run_until_complete_cb(fut):

class Server(events.AbstractServer):

def __init__(self, loop, sockets):
def __init__(self, loop, sockets, protocol_factory, ssl, backlog, *,
max_connections=None):
self._loop = loop
self.sockets = sockets
self._protocol_factory = protocol_factory
self._ssl = ssl
self._backlog = backlog
self._max_connections = max_connections
self._paused = False
self._active_count = 0
self._waiters = []

Expand All @@ -188,14 +194,37 @@ def __repr__(self):
def _attach(self):
assert self.sockets is not None
self._active_count += 1
if self._max_connections is not None and \
not self._paused and \
self._active_count >= self._max_connections:
self.pause()

def _detach(self):
assert self._active_count > 0
self._active_count -= 1
if self._active_count == 0 and self.sockets is None:
self._wakeup()
elif self._paused and self._max_connections is not None and \
self._active_count < self._max_connections:
self.resume()

def pause(self):
"""Pause future calls to accept()."""
assert not self._paused
self._paused = True
for sock in self.sockets:
self._loop.remove_reader(sock.fileno())

def resume(self):
"""Resume use of accept() on listening socket(s)."""
assert self._paused
self._paused = False
for sock in self.sockets:
self._loop._start_serving(self._protocol_factory, sock, self._ssl,
self, self._backlog)

def close(self):
self._protocol_factory = None
sockets = self.sockets
if sockets is None:
return
Expand Down Expand Up @@ -943,7 +972,8 @@ def create_server(self, protocol_factory, host=None, port=None,
backlog=100,
ssl=None,
reuse_address=None,
reuse_port=None):
reuse_port=None,
max_connections=None):
"""Create a TCP server.

The host parameter can be a string, in that case the TCP server is bound
Expand Down Expand Up @@ -1026,7 +1056,8 @@ def create_server(self, protocol_factory, host=None, port=None,
raise ValueError('Neither host/port nor sock were specified')
sockets = [sock]

server = Server(self, sockets)
server = Server(self, sockets, protocol_factory, ssl, backlog,
max_connections=max_connections)
for sock in sockets:
sock.listen(backlog)
sock.setblocking(False)
Expand Down
6 changes: 4 additions & 2 deletions asyncio/unix_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ def create_unix_connection(self, protocol_factory, path, *,

@coroutine
def create_unix_server(self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None):
sock=None, backlog=100, ssl=None,
max_connections=None):
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')

Expand Down Expand Up @@ -294,7 +295,8 @@ def create_unix_server(self, protocol_factory, path=None, *,
'A UNIX Domain Stream Socket was expected, got {!r}'
.format(sock))

server = base_events.Server(self, [sock])
server = base_events.Server(self, [sock], protocol_factory, ssl,
backlog, max_connections=max_connections)
sock.listen(backlog)
sock.setblocking(False)
self._start_serving(protocol_factory, sock, ssl, server)
Expand Down
109 changes: 109 additions & 0 deletions tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,109 @@ def connection_made(self, transport):

server.close()

def test_create_server_max_connections(self):
protos = []
on_data = asyncio.Event(loop=self.loop)

class MaxConnTestProto(MyBaseProto):
def connection_made(self, transport):
super().connection_made(transport)
protos.append(self)
def data_received(self, data):
super().data_received(data)
on_data.set()

f = self.loop.create_server(lambda: MaxConnTestProto(loop=self.loop),
'0.0.0.0', 0, max_connections=2)
server = self.loop.run_until_complete(f)
port = server.sockets[0].getsockname()[1]
self._test_create_server_max_connections(server, socket.socket,
('127.0.0.1', port),
protos, on_data)

@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server_max_connections(self):
protos = []
on_data = asyncio.Event(loop=self.loop)

class MaxConnTestProto(MyBaseProto):
def connection_made(self, transport):
super().connection_made(transport)
protos.append(self)
def data_received(self, data):
super().data_received(data)
on_data.set()

factory = lambda: MaxConnTestProto(loop=self.loop)
server, path = self._make_unix_server(factory, max_connections=2)
socket_factory = lambda: socket.socket(socket.AF_UNIX)
self._test_create_server_max_connections(server, socket_factory, path,
protos, on_data)

def _test_create_server_max_connections(self, server, socket_factory,
connect_to, protos, on_data):
sock_fd = server.sockets[0].fileno()

# Low water..
c1 = socket_factory()
c1.connect(connect_to)
c1.sendall(b'x')
self.loop.run_until_complete(on_data.wait())
on_data.clear()
self.assertFalse(server._paused)
self.loop._selector.get_key(sock_fd) # has reader

# High water..
c2 = socket_factory()
c2.connect(connect_to)
c2.sendall(b'x')
self.loop.run_until_complete(on_data.wait())
on_data.clear()
self.assertEqual(server._active_count, 2)
self.assertTrue(server._paused)
self.assertRaises(KeyError, self.loop._selector.get_key, sock_fd)

# Low water again..
p = protos.pop(0)
p.transport.close()
self.loop.run_until_complete(p.done)
self.assertFalse(server._paused)
self.loop._selector.get_key(sock_fd) # has reader

# cleanup
p = protos.pop(0)
p.transport.close()
self.loop.run_until_complete(p.done)
c1.close()
c2.close()
server.close()
self.assertFalse(protos)

def test_create_server_pause_resume(self):
f = self.loop.create_server(lambda: None, '0.0.0.0', 0)
server = self.loop.run_until_complete(f)
sock_fd = server.sockets[0].fileno()
self._test_create_server_pause_resume(server, sock_fd)

@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
def test_create_unix_server_pause_resume(self):
server, path = self._make_unix_server(lambda: None)
sock_fd = server.sockets[0].fileno()
self._test_create_server_pause_resume(server, sock_fd)

def _test_create_server_pause_resume(self, server, sock_fd):
server.pause()
self.assertTrue(server._paused)
self.assertRaises(KeyError, self.loop._selector.get_key, sock_fd)
self.assertRaises(AssertionError, server.pause)

server.resume()
self.assertFalse(server._paused)
self.loop._selector.get_key(sock_fd) # has reader
self.assertRaises(AssertionError, server.resume)

server.close()

def test_server_close(self):
f = self.loop.create_server(MyProto, '0.0.0.0', 0)
server = self.loop.run_until_complete(f)
Expand Down Expand Up @@ -2162,6 +2265,12 @@ def test_create_datagram_endpoint(self):

def test_remove_fds_after_closing(self):
raise unittest.SkipTest("IocpEventLoop does not have add_reader()")

def test_create_server_max_connections(self):
raise unittest.SkipTest("IocpEventLoop incompatible with max_connections")

def test_create_server_pause_resume(self):
raise unittest.SkipTest("IocpEventLoop incompatible with Server pause")
else:
from asyncio import selectors

Expand Down