
DStream.updateStateByKey() (#56)
* implement updateStateByKey()

* rewrite streaming test cases without StreamingTestCase class
svenkreiss committed Apr 30, 2017
1 parent 347c633 commit 4595a7c
Showing 6 changed files with 222 additions and 118 deletions.
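For orientation before the diff, a minimal usage sketch of the new API, adapted from the doctest added to pysparkling/streaming/dstream.py below (the batch interval and queue contents are illustrative):

    import pysparkling

    sc = pysparkling.Context()
    ssc = pysparkling.streaming.StreamingContext(sc, 0.2)

    # running count per key across batches; state is None for keys not seen before
    (
        ssc
        .queueStream([[('a', 1)], [('a', 2), ('b', 4), ('b', 3)]])
        .updateStateByKey(lambda values, state: (state or 0) + sum(values))
        .foreachRDD(lambda rdd: print(sorted(rdd.collect())))
    )

    ssc.start()
    ssc.awaitTermination(0.5)
    # prints [('a', 1)] and then [('a', 3), ('b', 7)]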
69 changes: 69 additions & 0 deletions pysparkling/streaming/dstream.py
@@ -482,6 +482,50 @@ def transform(self, func):

return TransformedDStream(self, func)

def updateStateByKey(self, func):
"""Process with state.
:param func: Evaluated per key. Takes list of input_values and a state.
:rtype: DStream
This example shows how to return the latest value per key:
>>> import pysparkling
>>> sc = pysparkling.Context()
>>> ssc = pysparkling.streaming.StreamingContext(sc, 0.2)
>>> (
... ssc
... .queueStream([[('a', 1), ('b', 3)], [('a', 2), ('c', 4)]])
... .updateStateByKey(lambda input_values, state:
... state
... if not input_values
... else input_values[-1])
... .foreachRDD(lambda rdd: print(sorted(rdd.collect())))
... )
>>> ssc.start()
>>> ssc.awaitTermination(0.5)
[('a', 1), ('b', 3)]
[('a', 2), ('b', 3), ('c', 4)]

This example counts values per key:
>>> sc = pysparkling.Context()
>>> ssc = pysparkling.streaming.StreamingContext(sc, 0.2)
>>> (
... ssc
... .queueStream([[('a', 1)], [('a', 2), ('b', 4), ('b', 3)]])
... .updateStateByKey(lambda input_values, state:
... (state if state is not None else 0) +
... sum(input_values))
... .foreachRDD(lambda rdd: print(sorted(rdd.collect())))
... )
>>> ssc.start()
>>> ssc.awaitTermination(0.5)
[('a', 1)]
[('a', 3), ('b', 7)]
"""
return StatefulDStream(self, func)

def window(self, windowDuration, slideDuration=None):
"""Windowed RDD.
@@ -577,3 +621,28 @@ def _step(self, time_):
self._current_time = time_
self._current_rdd = getattr(self._prev1._current_rdd, self._op)(
self._prev2._current_rdd, self._num_partitions)


class StatefulDStream(DStream):
def __init__(self, prev, state_update_fn):
super(StatefulDStream, self).__init__(prev._stream, prev._context)
self._prev = prev
self._func = state_update_fn
self._state_rdd = EmptyRDD(self._context._context)

def convert_fn(self, joined):
input_values, state_list = joined
state = state_list[-1] if len(state_list) > 0 else None

return self._func(input_values, state)

def _step(self, time_):
if time_ <= self._current_time:
return

self._prev._step(time_)
self._current_time = time_

combined = self._prev._current_rdd.cogroup(self._state_rdd)
self._state_rdd = combined.mapValues(self.convert_fn)
self._current_rdd = self._state_rdd
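
StatefulDStream drives the update through cogroup: on each step the new batch RDD is cogrouped with the state RDD, and convert_fn hands the update function the list of new values plus the last stored state (or None). A pure-Python sketch of that per-key contract, independent of Spark/pysparkling, using the counting function from the docstring example (the dictionary bookkeeping here is illustrative):

    def update(input_values, state):
        # same counting function as in the docstring example above
        return (state if state is not None else 0) + sum(input_values)

    batches = [[('a', 1)], [('a', 2), ('b', 4), ('b', 3)]]
    state = {}
    for batch in batches:
        per_key = {}
        for k, v in batch:
            per_key.setdefault(k, []).append(v)
        # keys already in the state but absent from this batch are still
        # passed to the update function, with an empty list of new values
        for k in set(state) | set(per_key):
            state[k] = update(per_key.get(k, []), state.get(k))
        print(sorted(state.items()))
    # prints [('a', 1)] and then [('a', 3), ('b', 7)]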
19 changes: 14 additions & 5 deletions scripts/pyspark_streaming.py
@@ -1,3 +1,8 @@
"""Explore PySpark API.
Run with `spark-submit scripts/pyspark_streaming.py`.
"""

from __future__ import print_function

import pyspark
@@ -40,11 +45,15 @@ def union(ssc):


def updateStateByKey(ssc):
def processStateUpdateByKey(input_stream, state):
print('i', input_stream)
print('s', state)
return state if not input_stream else input_stream[-1]

ssc.checkpoint('checkpoints/')
(ssc
.queueStream([[('a', 1), ('b', 3)], [('a', 2), ('c', 4)]])
.updateStateByKey(lambda input_stream, state:
state if not input_stream else input_stream[-1])
.queueStream([[('a', 1), ('b', 3)], [('a', 2), ('a', 5), ('c', 4)]])
.updateStateByKey(processStateUpdateByKey)
.pprint()
)

@@ -63,8 +72,8 @@ def stream_log(ssc):
# save_text(ssc)
# window(ssc)
# union(ssc)
# updateStateByKey(ssc)
stream_log(ssc)
updateStateByKey(ssc)
# stream_log(ssc)

ssc.start()
time.sleep(3.0)
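For comparison, a self-contained sketch of the experiment this script runs against real PySpark (app name, batch interval, and checkpoint path are illustrative; Spark requires a checkpoint directory for stateful operations like updateStateByKey):

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext(appName='updateStateByKey-example')
    ssc = StreamingContext(sc, 1)
    ssc.checkpoint('checkpoints/')  # required for stateful transformations

    (
        ssc
        .queueStream([[('a', 1), ('b', 3)], [('a', 2), ('a', 5), ('c', 4)]])
        .updateStateByKey(lambda values, state: state if not values else values[-1])
        .pprint()
    )

    ssc.start()
    ssc.awaitTermination(3.0)
    ssc.stop()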
22 changes: 0 additions & 22 deletions tests/streaming_test_case.py

This file was deleted.

64 changes: 42 additions & 22 deletions tests/test_streaming_files.py
@@ -1,53 +1,73 @@
from .streaming_test_case import StreamingTestCase
import pysparkling
import tornado.testing


class TextFile(StreamingTestCase):
class TextFile(tornado.testing.AsyncTestCase):

def test_connect(self):
self.result = 0
self.expect = 22
sc = pysparkling.Context()
ssc = pysparkling.streaming.StreamingContext(sc, 0.1)

result = []
(
self.stream_c.textFileStream('LICENS*', process_all=True)
ssc.textFileStream('LICENS*', process_all=True)
.count()
.foreachRDD(lambda rdd: self.incr_result(rdd.collect()[0]))
.foreachRDD(lambda rdd: result.append(rdd.collect()[0]))
)

ssc.start()
ssc.awaitTermination(timeout=0.3)
self.assertEqual(sum(result), 22)

def test_save(self):
self.result = 0
self.expect = 0
sc = pysparkling.Context()
ssc = pysparkling.streaming.StreamingContext(sc, 0.1)

(
self.stream_c.textFileStream('LICENS*')
ssc.textFileStream('LICENS*')
.count()
.saveAsTextFiles('tests/textout/')
)

def test_save_gz(self):
self.result = 0
self.expect = 0
sc = pysparkling.Context()
ssc = pysparkling.streaming.StreamingContext(sc, 0.1)

(
self.stream_c.textFileStream('LICENS*')
ssc.textFileStream('LICENS*')
.count()
.saveAsTextFiles('tests/textout/', suffix='.gz')
)


class BinaryFile(StreamingTestCase):
class BinaryFile(tornado.testing.AsyncTestCase):

def test_read_file(self):
self.result = 0
self.expect = 1
sc = pysparkling.Context()
ssc = pysparkling.streaming.StreamingContext(sc, 0.1)

result = []
(
self.stream_c.fileBinaryStream('LICENS*', process_all=True)
ssc.fileBinaryStream('LICENS*', process_all=True)
.count()
.foreachRDD(lambda rdd: self.incr_result(rdd.collect()[0]))
.foreachRDD(lambda rdd: result.append(rdd.collect()[0]))
)

ssc.start()
ssc.awaitTermination(timeout=0.3)
self.assertEqual(sum(result), 1)

def test_read_chunks(self):
self.result = 0
self.expect = 28
sc = pysparkling.Context()
ssc = pysparkling.streaming.StreamingContext(sc, 0.1)

result = []
(
self.stream_c.fileBinaryStream('LICENS*', recordLength=40,
process_all=True)
ssc.fileBinaryStream('LICENS*', recordLength=40, process_all=True)
.count()
.foreachRDD(lambda rdd: self.incr_result(rdd.collect()[0]))
.foreachRDD(lambda rdd: result.append(rdd.collect()[0]))
)

ssc.start()
ssc.awaitTermination(timeout=0.3)
self.assertEqual(sum(result), 28)
50 changes: 35 additions & 15 deletions tests/test_streaming_queue.py
@@ -1,34 +1,54 @@
from __future__ import print_function

from .streaming_test_case import StreamingTestCase
import pysparkling
import tornado.testing


class TestCount(StreamingTestCase):
class TestCount(tornado.testing.AsyncTestCase):

def test_count(self):
self.result = 0
self.expect = 23
sc = pysparkling.Context()
ssc = pysparkling.streaming.StreamingContext(sc, 0.1)

result = []
(
self.stream_c.queueStream([range(20), ['a', 'b'], ['c']])
ssc.queueStream([range(20), ['a', 'b'], ['c']])
.count()
.foreachRDD(lambda rdd: self.incr_result(rdd.collect()[0]))
.foreachRDD(lambda rdd: result.append(rdd.collect()[0]))
)

ssc.start()
ssc.awaitTermination(timeout=0.3)
self.assertEqual(sum(result), 23)

def test_groupByKey(self):
self.result = []
self.expect = [[('a', [2, 5]), ('b', [8])], [('a', [2]), ('b', [3])]]
sc = pysparkling.Context()
ssc = pysparkling.streaming.StreamingContext(sc, 0.1)

result = []
(
self.stream_c.queueStream([[('a', 5), ('b', 8), ('a', 2)],
[('a', 2), ('b', 3)]])
ssc.queueStream([[('a', 5), ('b', 8), ('a', 2)],
[('a', 2), ('b', 3)]])
.groupByKey().mapPartitions(sorted).mapValues(sorted)
.foreachRDD(lambda rdd: self.append_result(rdd.collect()))
.foreachRDD(lambda rdd: result.append(rdd.collect()))
)

ssc.start()
ssc.awaitTermination(timeout=0.25)
self.assertEqual(
result, [[('a', [2, 5]), ('b', [8])], [('a', [2]), ('b', [3])]])

def test_mapValues(self):
self.result = []
self.expect = [[('a', [2, 5, 8]), ('b', [3, 6, 8])]]
sc = pysparkling.Context()
ssc = pysparkling.streaming.StreamingContext(sc, 0.1)

result = []
(
self.stream_c.queueStream([[('a', [5, 8, 2]), ('b', [6, 3, 8])]])
ssc.queueStream([[('a', [5, 8, 2]), ('b', [6, 3, 8])]])
.mapValues(sorted)
.foreachRDD(lambda rdd: self.append_result(rdd.collect()))
.foreachRDD(lambda rdd: result.append(rdd.collect()))
)

ssc.start()
ssc.awaitTermination(timeout=0.15)
self.assertEqual(result, [[('a', [2, 5, 8]), ('b', [3, 6, 8])]])
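
The rewritten tests do not yet cover the new operator; a test for updateStateByKey in the same self-contained style might look like this (a sketch, not part of this commit):

    class TestUpdateStateByKey(tornado.testing.AsyncTestCase):

        def test_running_count(self):
            sc = pysparkling.Context()
            ssc = pysparkling.streaming.StreamingContext(sc, 0.1)

            result = []
            (
                ssc.queueStream([[('a', 1)], [('a', 2), ('b', 4), ('b', 3)]])
                .updateStateByKey(lambda values, state: (state or 0) + sum(values))
                .foreachRDD(lambda rdd: result.append(sorted(rdd.collect())))
            )

            ssc.start()
            ssc.awaitTermination(timeout=0.25)
            self.assertEqual(result, [[('a', 1)], [('a', 3), ('b', 7)]])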
