2 changes: 1 addition & 1 deletion streamz/compatibility.py
@@ -1,6 +1,6 @@
import sys

if sys.version_info[0] == 2:
if sys.version_info[0] == 2: # pragma: no cover
from thread import get_ident as get_thread_identity
import __builtin__ as builtins
else:
79 changes: 58 additions & 21 deletions streamz/sources.py
@@ -1,5 +1,6 @@
from glob import glob
import os
import threading

import time
import tornado.ioloop
@@ -372,16 +373,52 @@ def _close_consumer(self):
self.stopped = True


class ConsumerFactory:
"""Keeps active consumer for a given thread

This exists so that each tasks in from_kafka_batched does not need to
recreate a kafka consumer, which takes some time. In the Dask case, these
consumers are created in the workers and will also be reused between tasks
and threads in the same worker.
"""
consumers = threading.local()

def __init__(self, consumer_params, topic):
consumer_params['enable.auto.commit'] = 'false'
self.consumer_params = consumer_params
self.topic = topic
self.key = str(sorted(consumer_params.items()))
self._init()

def _init(self):
import confluent_kafka as ck
if not hasattr(self.consumers, 'map'):
# race unlikely here
self.consumers.map = {}
consumers = self.consumers.map
if self.key not in consumers:
consumers[self.key] = ck.Consumer(self.consumer_params)

def get_consumer(self, partition, low):
import confluent_kafka as ck
consumer = self.consumers.map[self.key]
topic_partition = ck.TopicPartition(self.topic, partition, low)
consumer.assign([topic_partition])
return consumer


class FromKafkaBatched(Stream):
"""Base class for both local and cluster-based batched kafka processing"""
def __init__(self, topic, consumer_params, poll_interval='1s',
npartitions=1, **kwargs):
self.consumer_params = consumer_params
self.factory = ConsumerFactory(self.consumer_params, topic)
self.topic = topic
self.npartitions = npartitions
self.positions = [0] * npartitions
self.poll_interval = convert_interval(poll_interval)
self.stopped = True
self.consumer = None

super(FromKafkaBatched, self).__init__(ensure_io_loop=True, **kwargs)

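To make the intended reuse concrete, here is a minimal usage sketch of ConsumerFactory on its own (the broker address, group id and topic name are placeholders, not part of this change):

from streamz.sources import ConsumerFactory
import pickle

params = {'bootstrap.servers': 'localhost:9092', 'group.id': 'example-group'}
factory = ConsumerFactory(params, 'my-topic')

# Equal parameters produce the same cache key, so a second factory in the
# same thread reuses the already-created confluent_kafka Consumer.
other = ConsumerFactory(params, 'my-topic')
assert factory.get_consumer(0, 0) is other.get_consumer(0, 0)

# After pickling (as happens when a batch task is shipped to a Dask worker),
# calling _init() in the receiving thread recreates or reuses that thread's
# consumer before get_consumer() is used.
remote = pickle.loads(pickle.dumps(factory))
remote._init()
consumer = remote.get_consumer(0, 0)
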
@@ -402,8 +439,7 @@ def poll_kafka(self):
current_position = self.positions[partition]
lowest = max(current_position, low)
if high > lowest:
out.append((self.consumer_params, self.topic, partition,
lowest, high - 1))
out.append((self.factory, partition, lowest, high - 1))
self.positions[partition] = high

for part in out:
@@ -414,6 +450,7 @@ def poll_kafka(self):
finally:
self.consumer.unsubscribe()
self.consumer.close()
self.consumer = None

def start(self):
import confluent_kafka as ck
@@ -480,29 +517,29 @@ def from_kafka_batched(topic, consumer_params, poll_interval='1s',
return source.starmap(get_message_batch)
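As a worked example of the bookkeeping above (illustrative numbers only): if positions[0] is 40 and the watermark offsets for partition 0 come back as low=35 and high=60, then lowest = max(40, 35) = 40, poll_kafka emits the tuple (factory, 0, 40, 59), and positions[0] advances to 60; starmap then unpacks each such tuple into one call of get_message_batch below.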


def get_message_batch(kafka_params, topic, partition, low, high, timeout=None):
def get_message_batch(factory, partition, low, high, timeout=None):
"""Fetch a batch of kafka messages in given topic/partition

This will block until messages are available, or timeout is reached.
"""
import confluent_kafka as ck
factory._init()
t0 = time.time()
consumer = ck.Consumer(kafka_params)
tp = ck.TopicPartition(topic, partition, low)
consumer.assign([tp])
consumer = factory.get_consumer(partition, low)
out = []
try:
while True:
msg = consumer.poll(0)
if msg and msg.value() and msg.error() is None:
if high >= msg.offset():
out.append(msg.value())
if high <= msg.offset():
break
else:
time.sleep(0.1)
if timeout is not None and time.time() - t0 > timeout:
break
finally:
consumer.close()
while True:
msg = consumer.poll(0)
if msg and msg.value() and msg.error() is None:
if high >= msg.offset():
out.append(msg.value())
if high <= msg.offset():
break
else:
time.sleep(0.1)
if timeout is not None and time.time() - t0 > timeout:
break
if low <= high:
try:
consumer.commit(asynchronous=False)
except ck.KafkaError:  # pragma: no cover
pass
return out
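Putting the pieces together, the call that a worker ends up running for one batch looks roughly like the sketch below (a non-authoritative example; the connection details, topic name and offset range are made up):

from streamz.sources import ConsumerFactory, get_message_batch

factory = ConsumerFactory({'bootstrap.servers': 'localhost:9092',
                           'group.id': 'example-group'}, 'my-topic')
# Fetch offsets 0 through 99 (inclusive) of partition 0, giving up after
# 30 seconds if the broker never delivers that many messages.
values = get_message_batch(factory, partition=0, low=0, high=99, timeout=30)

Because the consumer is no longer closed at the end of the batch, it stays cached on the worker thread and is picked up again by the next batch that lands there.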
21 changes: 21 additions & 0 deletions streamz/tests/test_kafka.py
@@ -219,3 +219,24 @@ def test_kafka_dask_batch(c, s, w1, w2):
yield await_for(lambda: any(out), 10, period=0.2)
assert b'value-1' in out[0]
stream.upstream.stopped = True


def test_consumer_factory():
from streamz.sources import ConsumerFactory
import pickle
j = random.randint(0, 10000)
ARGS = {'bootstrap.servers': 'localhost:9092',
'group.id': 'streamz-test%i' % j}
with kafka_service() as kafka:
kafka, TOPIC = kafka
factory1 = ConsumerFactory(ARGS, TOPIC)
consumer1 = factory1.get_consumer(0, 0)
factory2 = ConsumerFactory(ARGS, TOPIC)
assert factory1.consumers is factory2.consumers
factory3 = pickle.loads(pickle.dumps(factory1))
assert factory3 is not factory1
assert factory1.consumers is factory3.consumers
consumer2 = factory1.get_consumer(0, 0)
assert consumer1 is consumer2
consumer3 = factory3.get_consumer(0, 0)
assert consumer1 is consumer3
2 changes: 1 addition & 1 deletion streamz/tests/test_sources.py
@@ -34,7 +34,7 @@ def test_tcp():

@gen_test(timeout=60)
def test_tcp_async():
port = 9876
port = 9877
s = Source.from_tcp(port)
out = s.sink_to_list()
s.start()