From 6ed18a93ccc988470ed72b600b91e091ff0597a3 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Fri, 22 Mar 2019 13:01:34 -0400 Subject: [PATCH 1/3] Add ConsumerFactory for batched kafka Inspired by #226, but in this case, consumers are thread-local, because they should not be called simultaneously from separate threads, and this should also work across processes. --- streamz/sources.py | 79 +++++++++++++++++++++++++++---------- streamz/tests/test_kafka.py | 21 ++++++++++ 2 files changed, 79 insertions(+), 21 deletions(-) diff --git a/streamz/sources.py b/streamz/sources.py index 2cd47545..cb2fe203 100644 --- a/streamz/sources.py +++ b/streamz/sources.py @@ -1,5 +1,6 @@ from glob import glob import os +import threading import time import tornado.ioloop @@ -228,16 +229,52 @@ def _close_consumer(self): self.stopped = True +class ConsumerFactory: + """Keeps active consumer for a given thread + + This exists so that each task in from_kafka_batched does not need to + recreate a kafka consumer, which takes some time. In the Dask case, these + consumers are created in the workers and will also be reused between tasks + and threads in the same worker. 
+ """ + consumers = threading.local() + + def __init__(self, consumer_params, topic): + consumer_params['enable.auto.commit'] = 'false' + self.consumer_params = consumer_params + self.topic = topic + self.key = str(sorted(consumer_params.items())) + self._init() + + def _init(self): + import confluent_kafka as ck + if not hasattr(self.consumers, 'map'): + # race unlikely here: consumers is a threading.local + self.consumers.map = {} + consumers = self.consumers.map + if self.key not in consumers: + consumers[self.key] = ck.Consumer(self.consumer_params) + + def get_consumer(self, partition, low): + import confluent_kafka as ck + consumer = self.consumers.map[self.key] + topic_partition = ck.TopicPartition(self.topic, partition, low) + consumer.assign([topic_partition]) + return consumer + + class FromKafkaBatched(Stream): """Base class for both local and cluster-based batched kafka processing""" def __init__(self, topic, consumer_params, poll_interval='1s', npartitions=1, **kwargs): self.consumer_params = consumer_params + self.factory = ConsumerFactory(self.consumer_params, topic) self.topic = topic self.npartitions = npartitions self.positions = [0] * npartitions self.poll_interval = convert_interval(poll_interval) self.stopped = True + self.consumer = None super(FromKafkaBatched, self).__init__(ensure_io_loop=True, **kwargs) @@ -258,8 +295,7 @@ def poll_kafka(self): current_position = self.positions[partition] lowest = max(current_position, low) if high > lowest: - out.append((self.consumer_params, self.topic, partition, - lowest, high - 1)) + out.append((self.factory, partition, lowest, high - 1)) self.positions[partition] = high for part in out: @@ -270,6 +306,7 @@ def poll_kafka(self): finally: self.consumer.unsubscribe() self.consumer.close() + self.consumer = None def start(self): import confluent_kafka as ck @@ -336,29 +373,29 @@ def from_kafka_batched(topic, consumer_params, poll_interval='1s', return source.starmap(get_message_batch) -def get_message_batch(kafka_params, topic, 
partition, low, high, timeout=None): +def get_message_batch(factory, partition, low, high, timeout=None): """Fetch a batch of kafka messages in given topic/partition This will block until messages are available, or timeout is reached. """ - import confluent_kafka as ck + factory._init() t0 = time.time() - consumer = ck.Consumer(kafka_params) - tp = ck.TopicPartition(topic, partition, low) - consumer.assign([tp]) + consumer = factory.get_consumer(partition, low) out = [] - try: - while True: - msg = consumer.poll(0) - if msg and msg.value() and msg.error() is None: - if high >= msg.offset(): - out.append(msg.value()) - if high <= msg.offset(): - break - else: - time.sleep(0.1) - if timeout is not None and time.time() - t0 > timeout: - break - finally: - consumer.close() + while True: + msg = consumer.poll(0) + if msg and msg.value() and msg.error() is None: + if high >= msg.offset(): + out.append(msg.value()) + if high <= msg.offset(): + break + else: + time.sleep(0.1) + if timeout is not None and time.time() - t0 > timeout: + break + if low <= high: + try: + consumer.commit(asynchronous=False) + except factory.ck.KafkaError: + pass return out diff --git a/streamz/tests/test_kafka.py b/streamz/tests/test_kafka.py index d66fabad..1cedc46a 100644 --- a/streamz/tests/test_kafka.py +++ b/streamz/tests/test_kafka.py @@ -219,3 +219,24 @@ def test_kafka_dask_batch(c, s, w1, w2): yield await_for(lambda: any(out), 10, period=0.2) assert b'value-1' in out[0] stream.upstream.stopped = True + + +def test_consumer_factory(): + from streamz.sources import ConsumerFactory + import pickle + j = random.randint(0, 10000) + ARGS = {'bootstrap.servers': 'localhost:9092', + 'group.id': 'streamz-test%i' % j} + with kafka_service() as kafka: + kafka, TOPIC = kafka + factory1 = ConsumerFactory(ARGS, TOPIC) + consumer1 = factory1.get_consumer(0, 0) + factory2 = ConsumerFactory(ARGS, TOPIC) + assert factory1.consumers is factory2.consumers + factory3 = pickle.loads(pickle.dumps(factory1)) 
+ assert factory3 is not factory1 + assert factory1.consumers is factory3.consumers + consumer2 = factory1.get_consumer(0, 0) + assert consumer1 is consumer2 + consumer3 = factory1.get_consumer(0, 0) + assert consumer1 is consumer3 From fec0f364e5c418d940646d2ea0c9f079c2ddbb80 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Sun, 28 Apr 2019 17:45:58 -0400 Subject: [PATCH 2/3] cov --- streamz/compatibility.py | 2 +- streamz/sources.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/streamz/compatibility.py b/streamz/compatibility.py index 8763ebb4..b68e3d03 100644 --- a/streamz/compatibility.py +++ b/streamz/compatibility.py @@ -1,6 +1,6 @@ import sys -if sys.version_info[0] == 2: +if sys.version_info[0] == 2: # pragma: no cover from thread import get_ident as get_thread_identity import __builtin__ as builtins else: diff --git a/streamz/sources.py b/streamz/sources.py index 1856a58f..541a167a 100644 --- a/streamz/sources.py +++ b/streamz/sources.py @@ -540,6 +540,6 @@ def get_message_batch(factory, partition, low, high, timeout=None): if low <= high: try: consumer.commit(asynchronous=False) - except factory.ck.KafkaError: + except factory.ck.KafkaError: # pragma: no cover pass return out From 5bdd358875631a2771c0447d6f79a4967595e320 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Mon, 29 Apr 2019 19:47:46 -0400 Subject: [PATCH 3/3] Use different tcp port To guard against previous server taking time to quit --- streamz/tests/test_sources.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streamz/tests/test_sources.py b/streamz/tests/test_sources.py index f51dc38a..0ff79c6d 100644 --- a/streamz/tests/test_sources.py +++ b/streamz/tests/test_sources.py @@ -34,7 +34,7 @@ def test_tcp(): @gen_test(timeout=60) def test_tcp_async(): - port = 9876 + port = 9877 s = Source.from_tcp(port) out = s.sink_to_list() s.start()