### Demo task

You are given a sequence of numbers ($0, 1, 2, \ldots, N$) arriving in real time. Calculate the sum of this sequence.
Use a stateful approach to store and update the current result.

The output should be one number.

**Example**
* Input sequence: `0, 1, 2, 3`.
* Output: `6`


In [1]:
import os
from time import sleep
from pyspark import SparkContext
from pyspark.streaming import StreamingContext

# Size of the input sequence and the streaming batch interval.
NUM_BATCHES = 10  # how many numbers the sequence holds: 0 .. NUM_BATCHES - 1
BATCH_TIMEOUT = 5  # batch interval of the streaming context, in seconds

sc = SparkContext(master='local[4]')

# One single-element RDD per number; queueStream feeds these to the stream
# batch by batch, simulating numbers that arrive in real time.
batches = [sc.parallelize([value]) for value in range(NUM_BATCHES)]

ssc = StreamingContext(sc, batchDuration=BATCH_TIMEOUT)
dstream = ssc.queueStream(rdds=batches)

In [2]:
# Module-level flags shared by the foreachRDD callbacks and the final wait loop:
# `finished` -> an empty batch was seen; `printed` -> the result was printed.
finished = printed = False

def set_ending_flag(rdd):
    """Raise the module-level ``finished`` flag when *rdd* has no elements.

    Intended as a ``foreachRDD`` callback on the input stream: receiving an
    empty batch means the input sequence has been fully consumed.
    """
    global finished
    if not rdd.isEmpty():
        return
    finished = True

def print_only_at_the_end(rdd):
    """Print the accumulated sum once, on the first batch after the stream ends.

    Used as a ``foreachRDD`` callback on the stream of running totals, so *rdd*
    holds the current state value. Nothing is printed until the module-level
    ``finished`` flag is set, and the result is printed at most once; afterwards
    ``printed`` is raised so the driver's wait loop can stop the contexts.
    """
    global printed
    if not finished or printed:
        return
    # take(1) fetches just the single state value instead of collecting the
    # whole RDD. An empty state means nothing was summed, so the sum is 0;
    # indexing an empty list here would raise and leave `printed` False, so
    # the wait loop would never terminate.
    head = rdd.take(1)
    print(head[0] if head else 0)
    printed = True

# If we have received an empty RDD, the stream is finished: `set_ending_flag`
# raises `finished`, the printing callback then emits the result, and the wait
# loop at the end stops the contexts once `printed` is True.
# NOTE(review): this relies on Spark running a batch's output operations in
# registration order (this one is registered before the printing callback),
# so `finished` is already set when the result is checked — confirm if
# spark.streaming.concurrentJobs is ever raised above 1.

dstream.foreachRDD(set_ending_flag)

In [3]:
def aggregator(values, old):
    """Update function for ``updateStateByKey``: add this batch to the total.

    *values* are the new numbers for the key in the current batch; *old* is
    the previous running total, or a falsy value (e.g. None) before the first
    update, in which case the total starts from 0.
    """
    previous = old if old else 0
    return previous + sum(values)

# `updateStateByKey` only works on key-value pairs, so every number is tagged
# with the same dummy key "res"; after aggregation the key is dropped again
# and only the running total reaches the printing callback.
(
    dstream
    .map(lambda number: ('res', number))
    .updateStateByKey(aggregator)
    .map(lambda key_and_total: key_and_total[1])
    .foreachRDD(print_only_at_the_end)
)

In [4]:
ssc.checkpoint('./checkpoint')  # checkpoint directory for storing the running state
ssc.start()
try:
    # Poll the flag with a short sleep instead of `while ...: pass`, which
    # would pin a CPU core at 100% for the whole run.
    while not printed:
        sleep(0.1)
finally:
    # Release the contexts even if the wait is interrupted (e.g. a notebook
    # KeyboardInterrupt), so a new SparkContext can be created afterwards.
    ssc.stop()
    sc.stop()

45
