## Install dependencies (Colab only)

In [None]:
# Detect whether this notebook is running inside Google Colab: the
# `google.colab` package is only importable there.
GOOGLE_COLAB = False
try:
    import google.colab
except ImportError:
    pass
else:
    GOOGLE_COLAB = True

In [None]:
if GOOGLE_COLAB:
    # System snappy library; presumably required so the `python-snappy` package
    # installed in the next cell can build/load on Colab.
    !sudo apt-get -yqq install libsnappy-dev

In [None]:
if GOOGLE_COLAB:
    # Install python-snappy, bokeh (plotting) and Apache Beam with the [gcp]
    # extras, built from the `sdks/python` directory of the ostrokach/beam
    # GitHub repository pinned to a specific commit (presumably because the
    # interactive `streambasedcache` module imported below lives there).
    !pip install -q \
        python-snappy bokeh \
        "git+https://github.com/ostrokach/beam.git@e2aa065f2717cfbf0490514cf164b69c0beb0fab#egg=apache_beam[gcp]&subdirectory=sdks/python"

In [None]:
if GOOGLE_COLAB:
    # Authenticate the Colab user; presumably this supplies the credentials
    # used by the Google Cloud clients (Pub/Sub, GCS temp location) below.
    from google.colab import auth
    auth.authenticate_user()

In [None]:
# @title Google Cloud Project Info { display-mode: "form" }
import os
# Point Beam at a real GCP project and GCS temp location unless a local Pub/Sub
# emulator is configured (and we are not on Colab). The values are editable via
# the Colab form (`# @param`).
if GOOGLE_COLAB or "PUBSUB_EMULATOR_HOST" not in os.environ:
    os.environ["BEAM_PROJECT_ID"] = "strokach-playground"  # @param {type:"string"}
    os.environ["BEAM_TEMP_LOCATION"] = "gs://strokach/dataflow_temp"  # @param {type:"string"}

## Imports

In [None]:
import atexit
import contextlib
import gc
import itertools
import json
import logging
import math
import os
import sys
import tempfile
import threading
import time
import uuid
from contextlib import ExitStack
from datetime import datetime

import apache_beam as beam
import bokeh
import pytz
from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.options.pipeline_options import (GoogleCloudOptions,
                                                  PipelineOptions)
from apache_beam.runners.interactive.cache_manager import CacheManager
from apache_beam.runners.interactive.caching import streambasedcache
from apache_beam.runners.interactive.display import data_server
from bokeh import plotting
from google.cloud import pubsub_v1

## Boilerplate

In [None]:
# Keep only the program name in argv: the kernel's own command-line flags would
# otherwise presumably be picked up when Beam parses pipeline options.
sys.argv = [*sys.argv[:1]]
# Quieten the werkzeug request log (presumably emitted by the plot data servers).
werkzeug_level = logging.WARNING
logging.getLogger("werkzeug").setLevel(level=werkzeug_level)

## Parameters

### Configurable

In [None]:
# Identifier for this notebook (not referenced by any of the visible cells).
NOTEBOOK_NAME = "streambasedcache"

In [None]:
# True: data servers bind to and plots fetch from "localhost".
# False: servers bind to 0.0.0.0 and plots use this machine's first IP address.
LOCAL = True

In [None]:
# GCP project and temp location for the pipelines, taken from the environment
# (see the Colab form cell). Without BEAM_TEMP_LOCATION, fall back to a local
# temporary directory.
project_id = os.getenv("BEAM_PROJECT_ID", "test-project")
temp_location = os.getenv("BEAM_TEMP_LOCATION")
if temp_location is None:
    # Keep a module-level reference: the directory is removed as soon as the
    # TemporaryDirectory object is finalized.
    _temporary_directory = tempfile.TemporaryDirectory()
    temp_location = _temporary_directory.name

### Derived

In [None]:
# Options shared by every pipeline in this notebook. streaming=True, presumably
# because the pipelines read unbounded data from Pub/Sub.
options = PipelineOptions(
    project=project_id, temp_location=temp_location, streaming=True,
)
# Last expression of the cell: shows the resolved options as cell output.
options.display_data()

In [None]:
# Creates the "input", "sine" and "cosine" caches used by the cells below.
cache_manager = CacheManager(options)

In [None]:
# Address the browser-side Bokeh plots use to reach the data servers below.
if LOCAL:
    HOST_IP = "localhost"
else:
    # `subprocess` is not among the notebook's imports; without this import the
    # non-local branch fails with NameError. `hostname -I` (Linux) prints the
    # host's addresses separated by spaces; the first one is used.
    import subprocess

    HOST_IP = (
        subprocess.check_output(["hostname", "-I"], universal_newlines=True)
        .strip()
        .split()[0]
    )

In [None]:
# Best-effort: enable IPython's autoreload so edited modules are re-imported
# automatically; any failure just prints a message.
try:
    %load_ext autoreload
    %autoreload 2
except Exception:
    print("No autoreload")

## Function definitions

### General

In [None]:
def current_time_milliseconds(timezone=pytz.UTC):
    """Return the current wall-clock time in ``timezone`` as integer milliseconds.

    The wall-clock reading in ``timezone`` is measured from the naive epoch
    1970-01-01 00:00, i.e. it is encoded *as if* it were UTC. For ``pytz.UTC``
    this is ordinary Unix time in milliseconds; for any other zone the result is
    shifted by that zone's current UTC offset.

    Args:
        timezone: A ``pytz`` (or other ``tzinfo``) time zone. Defaults to UTC.

    Returns:
        int: Milliseconds, truncated toward zero.
    """
    # NOTE(review): the shift looks intentional (presumably so the Bokeh
    # datetime axis, which formats values as UTC, shows local wall-clock
    # times), but it means non-UTC results are not true Unix timestamps.
    # datetime.now(tz) replaces the utcnow() + astimezone() round trip
    # (utcnow() is deprecated since Python 3.12); dropping tzinfo keeps the
    # local wall-clock reading.
    local_wall_clock = datetime.now(timezone).replace(tzinfo=None)
    unix_time = (local_wall_clock - datetime(1970, 1, 1)).total_seconds()
    # ReadFromPubSub expects timestamps to be in milliseconds
    return int(unix_time * 1000)


current_time_milliseconds(pytz.timezone("US/Pacific"))

In [None]:
def close_all_contexts():
    """Close every ``ExitStack`` object currently tracked by the garbage collector.

    Used as a blunt reset for the notebook's resource stacks. Each stack is
    printed before it is closed. An exception raised while closing one stack is
    printed and ignored, so the remaining stacks still get closed.

    NOTE(review): this closes *every* ExitStack in the process, including any
    owned by libraries, not only the ones created by this notebook.
    """
    live_stacks = (obj for obj in gc.get_objects() if isinstance(obj, ExitStack))
    for stack in live_stacks:
        print(stack)
        try:
            # ExitStack.close() is exactly __exit__(None, None, None).
            stack.close()
        except Exception as error:
            print(error)


# Also release everything when the interpreter shuts down.
atexit.register(close_all_contexts)

### PubSub-specific

In [None]:
class EventPublisher(threading.Thread):
    """Background thread that periodically publishes timestamp events to Pub/Sub.

    Every ``time_between_events`` seconds it publishes the JSON payload
    ``{"ts": <ms>}`` with a ``timestamp`` message attribute carrying the same
    value. ``<ms>`` comes from ``current_time_milliseconds`` for US/Pacific.

    Usable as a context manager: entering starts the thread, exiting signals it
    to stop and waits for it to finish.
    """

    def __init__(self, topic_path, time_between_events):
        """
        Args:
            topic_path (str): Pub/Sub topic path to publish to.
            time_between_events (float): Seconds to wait between publishes.
        """
        # Daemon thread: Python joins non-daemon threads *before* running atexit
        # handlers, so a publisher that was never stopped would hang shutdown.
        super().__init__(daemon=True)
        self.topic_path = topic_path
        self.time_between_events = time_between_events
        self._stop_event = threading.Event()

    def run(self):
        pub_client = pubsub_v1.PublisherClient()
        while not self.stopped():
            timestamp = current_time_milliseconds(pytz.timezone("US/Pacific"))
            element = {"ts": timestamp}
            # publish() is asynchronous; the returned future is not awaited.
            pub_client.publish(
                self.topic_path,
                json.dumps(element).encode("utf-8"),
                timestamp=str(timestamp),
            )
            # Event.wait instead of time.sleep so stop() takes effect at once.
            self._stop_event.wait(self.time_between_events)

    def stop(self):
        """Ask the publishing loop to exit."""
        self._stop_event.set()

    def stopped(self):
        """Return True once stop() has been called."""
        return self._stop_event.is_set()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        self.stop()
        # Wait for the loop to finish so nothing is published after the caller
        # tears down the topic; join() on a never-started thread would raise.
        if self.is_alive():
            self.join()

In [None]:
@contextlib.contextmanager
def create_pubsub_topic(project_id, prefix):
    """Create a uniquely named Pub/Sub topic and delete it when the context exits.

    Args:
        project_id: GCP project that owns the topic.
        prefix: Readable name prefix; a random UUID hex suffix makes it unique.

    Yields:
        str: The topic path ``projects/<project_id>/topics/<prefix>-<hex>``.
    """
    topic_path = "projects/{}/topics/{}-{}".format(project_id, prefix, uuid.uuid4().hex)
    pub_client = pubsub_v1.PublisherClient()
    # Keyword arguments work with both google-cloud-pubsub 1.x (positional
    # `name`/`topic`) and 2.x (first positional parameter is `request`).
    pub_client.create_topic(name=topic_path)
    try:
        yield topic_path
    finally:
        pub_client.delete_topic(topic=topic_path)

In [None]:
@contextlib.contextmanager
def create_pubsub_subscription(topic_path, suffix=""):
    """Create a subscription on ``topic_path`` and delete it when the context exits.

    The subscription path mirrors the topic path (``/topics/`` replaced by
    ``/subscriptions/``), with ``-<suffix>`` appended when ``suffix`` is non-empty.

    Args:
        topic_path: Pub/Sub topic path to subscribe to.
        suffix: Optional suffix, e.g. to create several subscriptions per topic.

    Yields:
        str: The subscription path.
    """
    subscription_path = topic_path.replace("/topics/", "/subscriptions/")
    if suffix:
        subscription_path += "-{}".format(suffix)
    sub_client = pubsub_v1.SubscriberClient()
    # Keyword arguments work with both google-cloud-pubsub 1.x (positional
    # `name`/`topic`/`subscription`) and 2.x (first positional is `request`).
    sub_client.create_subscription(name=subscription_path, topic=topic_path)
    try:
        yield subscription_path
    finally:
        sub_client.delete_subscription(subscription=subscription_path)

In [None]:
@contextlib.contextmanager
def run_pipeline(pipeline):
    """Start ``pipeline`` and cancel the running job when the context exits.

    Cancellation happens on normal exit and when the body raises.

    Yields:
        The object returned by ``pipeline.run()`` (the pipeline result handle).
    """
    handle = pipeline.run()
    with contextlib.ExitStack() as cleanup:
        cleanup.callback(handle.cancel)
        yield handle

### Pipeline-specific

In [None]:
def decode_pubsub_message(message):
    """Parse the UTF-8 encoded JSON payload of a Pub/Sub message.

    Args:
        message: Object with a bytes ``data`` attribute (e.g. a ``PubsubMessage``).

    Returns:
        The deserialized JSON value.
    """
    raw_payload = message.data
    payload_text = raw_payload.decode("utf-8")
    return json.loads(payload_text)

In [None]:
def milliseconds_to_iso(milliseconds, timezone=None):
    """Format a Unix timestamp in milliseconds as an ISO-8601 string.

    Args:
        milliseconds: Milliseconds since the Unix epoch.
        timezone: Optional time-zone name understood by ``pytz.timezone``
            (e.g. ``"US/Pacific"``). UTC when None.

    Returns:
        str: ISO-8601 date-time with an explicit UTC offset.
    """
    # Local import, presumably so the function stays self-contained when it is
    # pickled into a Beam transform and run on a worker.
    import pytz

    tzinfo = pytz.timezone(timezone) if timezone is not None else pytz.UTC
    # Build an aware UTC datetime directly; datetime.utcfromtimestamp() is
    # deprecated since Python 3.12.
    dt = datetime.fromtimestamp(milliseconds / 1000, tz=pytz.UTC).astimezone(tzinfo)
    return dt.isoformat()


milliseconds_to_iso(12)

In [None]:
def custom_sin(x, period_degrees=360):
    """Sine wave with a period of ``period_degrees`` units of ``x``.

    With the default period of 360 this is the sine of ``x`` in degrees.
    """
    import math

    fraction_of_period = x / period_degrees
    return math.sin(fraction_of_period * 2 * math.pi)

In [None]:
def custom_cos(x, period_degrees=360):
    """Cosine wave with a period of ``period_degrees`` units of ``x``.

    With the default period of 360 this is the cosine of ``x`` in degrees.
    """
    import math

    fraction_of_period = x / period_degrees
    return math.cos(fraction_of_period * 2 * math.pi)

## Run pipeline

### Reset state

In [None]:
# Close every ExitStack left over from a previous run (publisher, pipelines,
# data servers) so the cells below start from a clean state.
close_all_contexts()

### Start publisher

In [None]:
# Tear down the publisher and its topic from a previous run of this cell;
# NameError just means this is the first run.
try:
    publisher_stack.__exit__(None, None, None)
except NameError:
    pass

# One ExitStack releases resources in reverse order: the publisher is told to
# stop before its topic is deleted.
publisher_stack = ExitStack()

input_topic = publisher_stack.enter_context(create_pubsub_topic(project_id, "event-stream"))
# Starts the background thread publishing an event every 0.8 s.
publisher = publisher_stack.enter_context(EventPublisher(input_topic, time_between_events=0.8))

### Read from topic

In [None]:
# NOTE(review): redundant. The next cell performs the same teardown of
# pipeline_stack before recreating it, so this cell can likely be removed.
try:
    pipeline_stack.__exit__(None, None, None)
except NameError:
    pass

In [None]:
# Tear down the subscription and pipeline from a previous run of this cell;
# NameError just means this is the first run.
try:
    pipeline_stack.__exit__(None, None, None)
except NameError:
    pass

pipeline_stack = ExitStack()

# Fresh subscription (random 8-hex-char suffix) on the publisher's topic. It is
# deleted when pipeline_stack closes, after the pipeline below is cancelled.
input_subscription = pipeline_stack.enter_context(create_pubsub_subscription(input_topic, uuid.uuid4().hex[:8]))

# Cache written by this pipeline and read by the sine/cosine pipelines below.
input_cache = cache_manager.create_cache_from_defaults("input")

p = beam.Pipeline(options=options)
out_pcoll = (
    p
    # NOTE(review): EventPublisher sends the event time in the "timestamp"
    # message attribute; "ts" exists only inside the JSON payload. Confirm
    # which attribute is meant as the event-time source here.
    | "Read" >> beam.io.ReadFromPubSub(subscription=input_subscription, with_attributes=True, timestamp_attribute="ts")
    # Same logic as decode_pubsub_message: parse the UTF-8 JSON payload.
    | "Decode" >> beam.Map(lambda message: json.loads(message.data.decode("utf-8")))
    # dict.update() returns None, so `... or e` yields the updated dict.
    | "Add timestamp"
    >> beam.Map(lambda e: e.update({"ts_iso": milliseconds_to_iso(e["ts"], timezone="US/Pacific")}) or e)
    | "Write" >> input_cache.writer()
)

# Run the pipeline; it is cancelled when pipeline_stack is closed.
pr = pipeline_stack.enter_context(run_pipeline(p))

In [None]:
# Print the first 10 elements read from the input cache. seek_to_start=False
# presumably skips already-cached data; timeout=5 presumably bounds the wait.
for element in itertools.islice(input_cache.read(seek_to_start=False, timeout=5), 10):
    print(element)

### Plot a sine wave 

In [None]:
# Start pipeline
# Tear down the sine pipeline from a previous run of this cell (NameError on
# the first run), then create a fresh stack and the "sine" output cache.
try:
    sine_pipeline_stack.__exit__(None, None, None)
except NameError:
    pass

sine_pipeline_stack = ExitStack()

sine_cache = cache_manager.create_cache_from_defaults("sine")

In [None]:
# Streaming pipeline: read elements from the input cache, add plot coordinates
# x = e["ts"] (milliseconds) and y = sine of x with a 50,000 ms period
# (100000 / 2), then write them to the sine cache.
p = beam.Pipeline(options=options)

_ = (
    p
    | "Read" >> input_cache.reader(seek_to_start=False)
    # dict.update() returns None, so `... or e` yields the updated dict.
    | "Add coords" >> beam.Map(lambda e: e.update({"x": e["ts"], "y": custom_sin((e["ts"]), (100000 / 2))}) or e)
    | "Write" >> sine_cache.writer()
)

# Run it; the job is cancelled when sine_pipeline_stack is closed.
pr = sine_pipeline_stack.enter_context(run_pipeline(p))

In [None]:
# Show top elements
# Print the first 5 elements read from the sine cache (timeout=5 presumably
# bounds the wait for new data).
for element in itertools.islice(sine_cache.read(seek_to_start=False, timeout=5), 5):
    print(element)

In [None]:
# Start data server
def parse_cache_data(messages):
    """Lazily yield ``(x, y)`` tuples from cached timestamped values.

    Args:
        messages: Iterable of objects whose ``value`` is a dict with ``"x"`` and
            ``"y"`` keys.
    """
    yield from ((entry.value["x"], entry.value["y"]) for entry in messages)


# Tear down the previous queue/server for the sine plot (NameError on the first run).
try:
    sine_plot_stack.__exit__(None, None, None)
except NameError:
    pass

sine_plot_stack = ExitStack()

# Queue fed from the sine cache; released when sine_plot_stack closes.
data_queue = sine_plot_stack.enter_context(sine_cache.read_to_queue(seek_to_start=False))

# App that presumably serves queued items, converted to (x, y) pairs by
# parse_cache_data, to the polling plot.
app = data_server.create_data_publisher_app(data_queue, processors=[parse_cache_data], timeout=5)
# port=0 presumably lets the server choose a free port; it is read back later via
# sine_data_endpoint.server.port. Binds to all interfaces unless LOCAL.
sine_data_endpoint = sine_plot_stack.enter_context(
    data_server.ServerThread(app, host=("localhost" if LOCAL else "0.0.0.0"), port=0, threaded=True)
)

In [None]:
# Configure plotting
def generate_plot(data_url):
    """Build a live-updating Bokeh scatter plot fed from an HTTP endpoint.

    The browser polls ``data_url`` every 500 ms and appends the returned points.
    The endpoint must return a JSON list of ``[x, y]`` pairs, which the
    JavaScript adapter below unpacks. x values are formatted as date-times,
    presumably milliseconds since the epoch. The visible x window follows the
    newest data with a width of 100000 x units. The y axis is fixed to
    [-1.1, 1.1].

    Args:
        data_url: URL polled by the browser-side ``AjaxDataSource``.

    Returns:
        The configured ``bokeh.plotting`` figure.
    """
    from bokeh.models import DatetimeTickFormatter

    # Convert the endpoint's list of [x, y] pairs into column data. `let` keeps
    # the loop index local: an undeclared index leaks a global variable and
    # throws a ReferenceError when the code runs in strict mode.
    adapter = bokeh.models.CustomJS(
        code="""
        const result = {x: [], y: []};
        const pts = cb_data.response;
        for (let i = 0; i < pts.length; i++) {
            result.x.push(pts[i][0]);
            result.y.push(pts[i][1]);
        }
        return result;
    """
    )

    source = bokeh.models.AjaxDataSource(
        data_url=data_url, polling_interval=500, adapter=adapter, mode="append"
    )

    p = plotting.figure(
        plot_height=300,
        plot_width=800,
        background_fill_color="lightgrey",
        title="",
        y_range=(-1.1, 1.1),
    )
    p.circle("x", "y", source=source)

    # Scroll with the newest data, showing a fixed-width window.
    p.x_range.follow = "end"
    p.x_range.follow_interval = 100000

    p.xaxis.major_label_orientation = math.pi / 4
    p.xaxis.formatter = DatetimeTickFormatter(
        milliseconds=["%H:%M:%S"],
        seconds=["%H:%M:%S"],
        minsec=["%H:%M:%S"],
        minutes=["%H:%M:%S"],
    )

    return p

In [None]:
# Create plot
# Render inline in the notebook, resetting any earlier output target first.
bokeh.io.reset_output()
# bokeh.io.output_file("sines.html")
bokeh.io.output_notebook(hide_banner=True)

# The browser polls this URL, so the host must be reachable from the browser
# (HOST_IP) and the port is the one the sine data server bound to.
data_url = "http://{}:{}/data".format(HOST_IP, sine_data_endpoint.server.port)
plot = generate_plot(data_url)
bokeh.io.show(plot)

In [None]:
# sine_data_endpoint.stop()

### Plot a cosine wave

In [None]:
# Start pipeline
# Tear down the cosine pipeline from a previous run of this cell (NameError on
# the first run), then create a fresh stack and the "cosine" output cache.
try:
    cosine_pipeline_stack.__exit__(None, None, None)
except NameError:
    pass

cosine_pipeline_stack = ExitStack()

cosine_cache = cache_manager.create_cache_from_defaults("cosine")

In [None]:
# Streaming pipeline: read elements from the input cache, add plot coordinates
# x = e["ts"] (milliseconds) and y = cosine of x with a 50,000 ms period
# (100000 / 2), then write them to the cosine cache.
p = beam.Pipeline(options=options)

_ = (
    p
    | "Read" >> input_cache.reader(seek_to_start=False)
    # dict.update() returns None, so `... or e` yields the updated dict.
    | "Add coords" >> beam.Map(lambda e: e.update({"x": e["ts"], "y": custom_cos((e["ts"]), (100000 / 2))}) or e)
    | "Write" >> cosine_cache.writer()
)

# Run it; the job is cancelled when cosine_pipeline_stack is closed.
pr = cosine_pipeline_stack.enter_context(run_pipeline(p))

In [None]:
# Start data server
# Tear down the previous queue/server for the cosine plot (NameError on the first run).
try:
    cosine_plot_stack.__exit__(None, None, None)
except NameError:
    pass

cosine_plot_stack = ExitStack()

# Queue fed from the cosine cache; released when cosine_plot_stack closes.
data_queue = cosine_plot_stack.enter_context(cosine_cache.read_to_queue(seek_to_start=False))

app = data_server.create_data_publisher_app(data_queue, processors=[parse_cache_data], timeout=5)
# Server startup is retried on ValueError (cause not visible here). The number
# of attempts is bounded so a persistent failure surfaces as an error instead of
# spinning forever.
_MAX_SERVER_START_ATTEMPTS = 10
_last_server_error = None
for _attempt in range(_MAX_SERVER_START_ATTEMPTS):
    try:
        cosine_data_endpoint = cosine_plot_stack.enter_context(
            data_server.ServerThread(
                app, host=("localhost" if LOCAL else "0.0.0.0"), port=0, threaded=True
            )
        )
        break
    except ValueError as e:
        print(e)
        _last_server_error = e
else:
    # Release the queue entered above before giving up.
    cosine_plot_stack.close()
    raise RuntimeError(
        "Could not start the cosine data server after {} attempts".format(_MAX_SERVER_START_ATTEMPTS)
    ) from _last_server_error

In [None]:
# Configure plotting
def generate_plot(data_url):
    """Build a live-updating Bokeh scatter plot (red markers) fed from an HTTP endpoint.

    The browser polls ``data_url`` every 500 ms and appends the returned points.
    The endpoint must return a JSON list of ``[x, y]`` pairs, which the
    JavaScript adapter below unpacks. x values are formatted as date-times,
    presumably milliseconds since the epoch. The visible x window follows the
    newest data with a width of 100000 x units.

    Args:
        data_url: URL polled by the browser-side ``AjaxDataSource``.

    Returns:
        The configured ``bokeh.plotting`` figure.
    """
    from bokeh.models import DatetimeTickFormatter

    # Convert the endpoint's list of [x, y] pairs into column data. `let` keeps
    # the loop index local: an undeclared index leaks a global variable and
    # throws a ReferenceError when the code runs in strict mode.
    adapter = bokeh.models.CustomJS(
        code="""
        const result = {x: [], y: []};
        const pts = cb_data.response;
        for (let i = 0; i < pts.length; i++) {
            result.x.push(pts[i][0]);
            result.y.push(pts[i][1]);
        }
        return result;
    """
    )

    source = bokeh.models.AjaxDataSource(
        data_url=data_url, polling_interval=500, adapter=adapter, mode="append"
    )

    # NOTE(review): y_range is twice the [-1, 1] range of the cosine values fed
    # to this plot; confirm whether (-1.1, 1.1), as used for the sine plot, was
    # intended.
    p = plotting.figure(
        plot_height=300,
        plot_width=800,
        background_fill_color="lightgrey",
        title="",
        y_range=(-2.2, 2.2),
    )
    p.circle("x", "y", source=source, color="red")

    # Scroll with the newest data, showing a fixed-width window.
    p.x_range.follow = "end"
    p.x_range.follow_interval = 100000

    p.xaxis.major_label_orientation = math.pi / 4
    p.xaxis.formatter = DatetimeTickFormatter(
        milliseconds=["%H:%M:%S"],
        seconds=["%H:%M:%S"],
        minsec=["%H:%M:%S"],
        minutes=["%H:%M:%S"],
    )

    return p

In [None]:
# Create plot
# Render inline in the notebook, resetting any earlier output target first.
bokeh.io.reset_output()
# bokeh.io.output_file("sines.html")
bokeh.io.output_notebook(hide_banner=True)

# The browser polls this URL, so the host must be reachable from the browser
# (HOST_IP) and the port is the one the cosine data server bound to.
data_url = "http://{}:{}/data".format(HOST_IP, cosine_data_endpoint.server.port)
plot = generate_plot(data_url)
bokeh.io.show(plot)