diff --git a/distributed/batched.py b/distributed/batched.py
index 2aeabf6005..04bb8bc51f 100644
--- a/distributed/batched.py
+++ b/distributed/batched.py
@@ -1,16 +1,17 @@
 from __future__ import print_function, division, absolute_import
 
 from datetime import timedelta
+from functools import partial
 import logging
 from timeit import default_timer
 
-from tornado import gen
+from tornado import gen, locks
 from tornado.queues import Queue
 from tornado.iostream import StreamClosedError
 from tornado.ioloop import PeriodicCallback, IOLoop
 
 from .core import read, write
-from .utils import log_errors
+from .utils import ignoring, log_errors
 
 
 logger = logging.getLogger(__name__)
@@ -38,18 +39,19 @@ class BatchedSend(object):
     def __init__(self, interval, loop=None):
         self.loop = loop or IOLoop.current()
         self.interval = interval / 1000.
-        self.last_transmission = 0
+
+        self.waker = locks.Event()
+        self.stopped = locks.Event()
+        self.please_stop = False
         self.buffer = []
         self.stream = None
-        self.last_payload = []
-        self.last_send = gen.sleep(0)
         self.message_count = 0
         self.batch_count = 0
+        self.next_deadline = None
 
     def start(self, stream):
         self.stream = stream
-        if self.buffer:
-            self.send_next()
+        self.loop.add_callback(self._background_send)
 
     def __str__(self):
         return '<BatchedSend: %d in buffer>' % len(self.buffer)
@@ -57,68 +59,56 @@ def __str__(self):
     __repr__ = __str__
 
     @gen.coroutine
-    def send_next(self, wait=True):
-        try:
-            now = default_timer()
-            if wait:
-                wait_time = min(self.last_transmission + self.interval - now,
-                                self.interval)
-                yield gen.sleep(wait_time)
-            yield self.last_send
-            self.buffer, payload = [], self.buffer
-            self.last_payload = payload
-            self.last_transmission = now
+    def _background_send(self):
+        while not self.please_stop:
+            with ignoring(gen.TimeoutError):
+                yield self.waker.wait(self.next_deadline)
+                self.waker.clear()
+            if not self.buffer:
+                # Nothing to send
+                self.next_deadline = None
+                continue
+            if (self.next_deadline is not None and
+                    self.loop.time() < self.next_deadline):
+                # Send interval not expired yet
+                continue
+            payload, self.buffer = self.buffer, []
             self.batch_count += 1
-            self.last_send = write(self.stream, payload)
-        except Exception as e:
-            logger.exception(e)
-            raise
+            try:
+                yield write(self.stream, payload)
+            except Exception:
+                logger.exception("Error in batched write")
+                break
+            self.next_deadline = self.loop.time() + self.interval
 
-    @gen.coroutine
-    def _write(self, payload):
-        yield gen.sleep(0)
-        yield write(self.stream, payload)
+        self.stopped.set()
 
     def send(self, msg):
-        """ Send a message to the other side
+        """ Schedule a message for sending to the other side
 
         This completes quickly and synchronously
         """
-        try:
-            self.message_count += 1
-            if self.stream is None:  # not yet started
-                self.buffer.append(msg)
-                return
-
-            if self.stream._closed:
-                raise StreamClosedError()
+        if self.stream is not None and self.stream._closed:
+            raise StreamClosedError()
 
-            if self.buffer:
-                self.buffer.append(msg)
-                return
-
-            # If we're new and early,
-            now = default_timer()
-            if (now < self.last_transmission + self.interval
-                or not self.last_send._done):
-                self.buffer.append(msg)
-                self.loop.add_callback(self.send_next)
-                return
-
-            self.buffer.append(msg)
-            self.loop.add_callback(self.send_next, wait=False)
-        except StreamClosedError:
-            raise
-        except Exception as e:
-            logger.exception(e)
+        self.message_count += 1
+        self.buffer.append(msg)
+        self.waker.set()
 
     @gen.coroutine
     def close(self, ignore_closed=False):
         """ Flush existing messages and then close stream """
+        if self.stream is None:
+            return
+        self.please_stop = True
+        self.waker.set()
+        yield self.stopped.wait()
         try:
-            if self.stream._write_buffer:
-                yield self.last_send
             if self.buffer:
+                if self.next_deadline is not None:
+                    delay = self.next_deadline - self.loop.time()
+                    if delay > 0:
+                        yield gen.sleep(delay)
                 self.buffer, payload = [], self.buffer
                 yield write(self.stream, payload)
         except StreamClosedError:
diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py
index e7059ffb53..ba49e0fbfc 100644
--- a/distributed/tests/test_batched.py
+++ b/distributed/tests/test_batched.py
@@ -37,7 +37,7 @@ def handle_stream(self, stream, address):
                 self.count += 1
                 yield write(stream, msg)
             except StreamClosedError as e:
-                pass
+                return
 
     def listen(self, port=0):
         while True:
@@ -113,7 +113,6 @@ def test_BatchedSend():
         assert str(len(b.buffer)) in str(b)
         assert str(len(b.buffer)) in repr(b)
         b.start(stream)
-        yield b.last_send
 
         yield gen.sleep(0.020)
 
@@ -135,41 +134,30 @@ def test_send_before_start():
         stream = yield client.connect('127.0.0.1', e.port)
 
         b = BatchedSend(interval=10)
-        yield b.last_send
 
         b.send('hello')
-        b.send('hello')
+        b.send('world')
         b.start(stream)
-        result = yield read(stream); assert result == ['hello', 'hello']
+        result = yield read(stream); assert result == ['hello', 'world']
 
 
 @gen_test()
-def test_send_after_stream_start_before_stream_finish():
+def test_send_after_stream_start():
     with echo_server() as e:
         client = TCPClient()
         stream = yield client.connect('127.0.0.1', e.port)
 
         b = BatchedSend(interval=10)
-        yield b.last_send
         b.start(stream)
         b.send('hello')
-        result = yield read(stream); assert result == ['hello']
-
-
-@gen_test()
-def test_send_after_stream_finish():
-    with echo_server() as e:
-        client = TCPClient()
-        stream = yield client.connect('127.0.0.1', e.port)
-
-        b = BatchedSend(interval=10)
-        b.start(stream)
-        yield b.last_send
+        b.send('world')
+        result = yield read(stream)
+        if len(result) < 2:
+            result += yield read(stream)
+        assert result == ['hello', 'world']
 
-        b.send('hello')
-        result = yield read(stream); assert result == ['hello']
-
 
 @gen_test()
 def test_send_before_close():
@@ -179,7 +167,6 @@ def test_send_before_close():
 
         b = BatchedSend(interval=10)
         b.start(stream)
-        yield b.last_send
 
         cnt = int(e.count)
         b.send('hello')
@@ -203,7 +190,6 @@ def test_close_closed():
 
         b = BatchedSend(interval=10)
         b.start(stream)
-        yield b.last_send
 
         b.send(123)
         stream.close()  # external closing
@@ -211,6 +197,24 @@ def test_close_closed():
         yield b.close(ignore_closed=True)
 
 
+@gen_test()
+def test_close_not_started():
+    b = BatchedSend(interval=10)
+    yield b.close()
+
+
+@gen_test()
+def test_close_twice():
+    with echo_server() as e:
+        client = TCPClient()
+        stream = yield client.connect('127.0.0.1', e.port)
+
+        b = BatchedSend(interval=10)
+        b.start(stream)
+        yield b.close()
+        yield b.close()
+
+
 @slow
 @gen_test(timeout=50)
 def test_stress():
@@ -253,14 +257,12 @@ def test_sending_traffic_jam():
 
         b = BatchedSend(interval=0.01)
         b.start(stream)
-        yield b.last_send
 
         n = 50
         msg = {'x': to_serialize(data)}
 
         for i in range(n):
             b.send(assoc(msg, 'i', i))
-            print(len(b.buffer))
             yield gen.sleep(0.001)
 
         results = []
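
For reference, a minimal usage sketch of the reworked BatchedSend interface, modeled on the tests above. The key change in this patch is that all writes now happen inside the single _background_send coroutine, so send() never touches the stream directly. The connection address and the 10 ms interval below are illustrative; BatchedSend, read and write are the objects touched by this diff (distributed.batched / distributed.core).

from tornado import gen
from tornado.tcpclient import TCPClient

from distributed.batched import BatchedSend
from distributed.core import read


@gen.coroutine
def example(host, port):
    # Address is illustrative; any connected tornado IOStream works.
    stream = yield TCPClient().connect(host, port)

    b = BatchedSend(interval=10)   # batch outgoing messages every 10 ms
    b.start(stream)                # schedules the _background_send coroutine

    # send() only appends to the buffer and sets the waker Event;
    # the background coroutine writes the whole batch once the interval expires.
    b.send('hello')
    b.send('world')

    reply = yield read(stream)     # against the tests' echo server: the batched messages

    # close() flushes any remaining buffered messages before closing the stream.
    yield b.close()
    raise gen.Return(reply)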