# StreamBasedCache Demo - New York Taxi Rides

## Install dependencies (Colab only)

In [None]:
# Detect whether the notebook is running on Google Colab: the `google.colab`
# package can only be imported there.
try:
    import google.colab
except ImportError:
    GOOGLE_COLAB = False
else:
    GOOGLE_COLAB = True

In [None]:
# On Colab only: install the snappy development package with apt
# (`-y` answers prompts automatically, `-qq` keeps the output quiet).
# A later cell pip-installs `python-snappy`.
if GOOGLE_COLAB:
    !sudo apt-get -yqq install libsnappy-dev

In [None]:
# On Colab only: install python-snappy, bokeh and Apache Beam (with the `gcp`
# extra) built from a pinned commit of the ostrokach/beam fork.
if GOOGLE_COLAB:
    !pip install -q \
        python-snappy bokeh \
        "git+https://github.com/ostrokach/beam.git@e2aa065f2717cfbf0490514cf164b69c0beb0fab#egg=apache_beam[gcp]&subdirectory=sdks/python"

In [None]:
# On Colab only: run Colab's interactive user authentication.
# NOTE(review): presumably this provides the Google Cloud credentials used by
# the Pub/Sub and Dataflow calls below -- confirm.
if GOOGLE_COLAB:
    from google.colab import auth
    auth.authenticate_user()

In [None]:
# @title Google Cloud Project Info { display-mode: "form" }
# Default Google Cloud settings, exported as environment variables that the
# "Parameters" cells read back (BEAM_PROJECT_ID, BEAM_TEMP_LOCATION).
# They are only set on Colab or when PUBSUB_EMULATOR_HOST is not defined, so
# an environment that configures a Pub/Sub emulator keeps its own settings.
import os
if GOOGLE_COLAB or "PUBSUB_EMULATOR_HOST" not in os.environ:
    os.environ["BEAM_PROJECT_ID"] = "strokach-beam-dev"  # @param {type:"string"}
    os.environ["BEAM_TEMP_LOCATION"] = "gs://strokach-beam-dev/dataflow-temp"  # @param {type:"string"}

## Imports

In [None]:
import atexit
import contextlib
import gc
import itertools
import json
import logging
import os
import sys
import time
import uuid
from contextlib import ExitStack
from datetime import datetime
from pathlib import Path

import apache_beam as beam
import bokeh
import dateutil
import pytz
from apache_beam import transforms
from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.options.pipeline_options import GoogleCloudOptions, PipelineOptions
from apache_beam.runners.direct.direct_runner import DirectRunner
from apache_beam.runners.interactive.cache_manager import CacheManager
from apache_beam.runners.interactive.caching import pubsub_utils, streambasedcache
from apache_beam.transforms import combiners, window
from bokeh.core.properties import value
from bokeh.io import output_notebook, push_notebook, show
from bokeh.layouts import row
from bokeh.models import ColumnDataSource, Label, LabelSet, Legend, LegendItem, Range1d
from bokeh.models.annotations import Title
from bokeh.plotting import figure, show
from bokeh.tile_providers import Vendors, get_provider
from bokeh.transform import factor_cmap
from google.api_core import exceptions as gexc
from google.cloud import pubsub_v1

In [None]:
# Tell Bokeh to render its output inline in the notebook.
output_notebook()

In [None]:
# Best effort: load IPython's autoreload extension and switch it to mode 2.
# Any failure is reported with a short message instead of stopping the
# notebook.
try:
    %load_ext autoreload
    %autoreload 2
except Exception:
    print("No autoreload")

## Boilerplate

In [None]:
# Keep only the program name in sys.argv, dropping every other argument.
# NOTE(review): presumably this stops the notebook kernel's own command-line
# arguments from being parsed as Beam pipeline options -- confirm.
sys.argv = sys.argv[:1]
# Only let WARNING-and-above records through the "werkzeug" logger.
logging.getLogger("werkzeug").setLevel(logging.WARNING)

## Parameters

### Configurable

In [None]:
# Name of this notebook; also used as the name of a working directory that is
# created next to it (relative to the current working directory).
NOTEBOOK_NAME = "streambasedcache-new_york_taxirides"
NOTEBOOK_PATH = Path(NOTEBOOK_NAME)
# exist_ok=True makes re-running this cell harmless.
NOTEBOOK_PATH.mkdir(exist_ok=True)

In [None]:
# `tempfile` is not part of the notebook's import cell, so import it here;
# without it the fallback branch below fails with a NameError.
import tempfile

# Google Cloud project to run under; falls back to a placeholder project id
# when BEAM_PROJECT_ID is not set.
project_id = os.getenv("BEAM_PROJECT_ID", "test-project")

# Location for the pipeline's temporary files. When BEAM_TEMP_LOCATION is not
# set, fall back to a fresh local temporary directory.
try:
    temp_location = os.environ["BEAM_TEMP_LOCATION"]
except KeyError:
    # Keep a module-level reference to the TemporaryDirectory object: the
    # directory is removed once that object is cleaned up.
    # (The variable name keeps its original spelling in case other cells use it.)
    _tempporary_directory = tempfile.TemporaryDirectory()
    temp_location = _tempporary_directory.name

In [None]:
# Full path of the Pub/Sub topic with the taxi-ride events; the "Pipeline"
# section creates a temporary subscription to this topic.
taxirides_topic_path = "projects/pubsub-public-data/topics/taxirides-realtime"

### Derived

In [None]:
def download_url(url, folder):
    """Copy or download ``url`` into ``folder``.

    The destination file keeps the last path component of ``url`` as its name.

    Args:
        url: Either a local source (a ``file://`` URL or an absolute path
            starting with ``/``) or any URL that ``urllib`` can open.
        folder: Destination directory; created, including missing parents,
            if it does not exist yet.
    """
    # Imported locally because the notebook's import cell provides neither
    # name; without these imports every call fails with a NameError.
    import shutil
    from urllib.request import urlopen

    filename = url.rsplit("/", 1)[-1]
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)

    file_scheme = "file://"
    if url.startswith((file_scheme, "/")):
        # Strip only the leading scheme; a later "file://" inside the path
        # must be left alone.
        local_path = url[len(file_scheme):] if url.startswith(file_scheme) else url
        shutil.copy(local_path, folder.as_posix())
    else:
        chunk_size = 16 * 1024
        # Open the response first so a failed request does not leave an empty
        # file behind; both handles are closed even if the transfer fails.
        with urlopen(url) as response, (folder / filename).open("wb") as f:
            shutil.copyfileobj(response, f, chunk_size)

In [None]:
# Local path of the Beam SDK source tarball that is handed to the pipeline
# options below (`sdk_location=`).
sdk_location = Path("../dist/apache-beam-2.16.0.dev0.tar.gz").resolve()
sdk_location.parent.mkdir(exist_ok=True)

# Download the tarball from a pinned commit unless it is already on disk.
# download_url() names the file after the last URL component, which matches
# the file name of `sdk_location`.
sdk_url = "https://raw.githubusercontent.com/ostrokach/beam-notebooks/1bd1de8eb3b9dc59f76272819e79a07bb42944f2/dist/apache-beam-2.16.0.dev0.tar.gz"
if not sdk_location.is_file():
    download_url(sdk_url, sdk_location.parent)
# Sanity check: the tarball must exist at this point.
assert sdk_location.is_file()

In [None]:
# Pipeline options for a streaming job on the Dataflow runner, using the
# project, temp location and SDK tarball configured in the cells above.
# NOTE(review): `temp_location` may be a local temporary directory when
# BEAM_TEMP_LOCATION is unset; confirm that is acceptable for DataflowRunner.
options = PipelineOptions(
    project=project_id,
    temp_location=temp_location,
    streaming=True,
    #     runner="DirectRunner",
    runner="DataflowRunner",
    sdk_location=sdk_location.as_posix(),
    setup_file="../setup.py",
    job_name="test-demo",
)
# Last expression of the cell, so the notebook displays the resolved options.
options.display_data()

In [None]:
# Cache manager built from the pipeline options; the "Pipeline" section uses
# it to create the cache that the pipeline writes to and the plots read from.
cache_manager = CacheManager(options)

## Function definitions

### Pipeline-specific

In [None]:
class AddMercatorCoords(beam.DoFn):
    """Add spherical Mercator coordinates to an element.

    Reads ``element["longitude"]`` and ``element["latitude"]`` (treated as
    degrees: the formulas convert with pi/180) and emits a copy of the element
    with two extra keys, ``"utm_x"`` and ``"utm_y"``.
    """

    def process(self, element):
        """Yield a copy of ``element`` extended with ``utm_x`` and ``utm_y``."""
        # Imported here so that the DoFn carries its own dependency.
        import numpy as np

        r_major = 6378137.000

        # Constant conversion factor from degrees to projected units. The
        # previous code recovered it as utm_x / longitude, which is undefined
        # at longitude == 0 and fell back to 0, collapsing utm_y to 0 for
        # every point on that meridian regardless of latitude.
        scale = r_major * 2 * np.pi / 360

        # Work on a shallow copy: a DoFn should not modify its input element.
        element = dict(element)
        element["utm_x"] = scale * element["longitude"]
        element["utm_y"] = (
            180.0 / np.pi * np.log(np.tan((np.pi / 4.0) + element["latitude"] * (np.pi / 180.0 / 2.0))) * scale
        )
        yield element


# Smoke test: run the DoFn once on a trivial element.
next(AddMercatorCoords().process({"longitude": 0, "latitude": 0}))

In [None]:
class FilterByRegion(beam.DoFn):
    """Keep only elements whose ``utm_x`` / ``utm_y`` fall inside a rectangle.

    Both ranges are ``(lower, upper)`` pairs and are half-open: the lower
    bound is included, the upper bound is excluded.
    """

    def __init__(self, utm_x_range, utm_y_range):
        self.utm_x_range = utm_x_range
        self.utm_y_range = utm_y_range

    def process(self, element):
        """Yield ``element`` unchanged if it lies inside both ranges."""
        x_bounds = self.utm_x_range
        if not (x_bounds[0] <= element["utm_x"] < x_bounds[1]):
            return

        # The y coordinate is only looked at once the x coordinate matched.
        y_bounds = self.utm_y_range
        if not (y_bounds[0] <= element["utm_y"] < y_bounds[1]):
            return

        yield element

In [None]:
class FilterRideStatus(beam.DoFn):
    """Keep only elements whose ``ride_status`` is one of the wanted values."""

    def __init__(self, ride_status):
        super().__init__()
        # Container that is queried with `in` for every element.
        self._ride_status = ride_status

    def process(self, element):
        """Yield ``element`` unchanged if its ``ride_status`` is accepted."""
        if element["ride_status"] not in self._ride_status:
            return
        yield element

In [None]:
class AddWindowRange(beam.DoFn):
    """Wrap an element together with information about its window.

    The output is a dict with the original element under ``"events"``, the
    window bounds as US/Eastern ISO-8601 strings (``"window_start_est"``,
    ``"window_end_est"``) and as microsecond integers
    (``"window_start_micros"``, ``"window_end_micros"``), plus the element
    timestamp in microseconds (``"timestamp_micros"``).
    """

    @staticmethod
    def _to_eastern_iso(ts):
        """Format a Beam timestamp as US/Eastern local time with its UTC offset."""
        # Imported here so that the DoFn carries its own dependency.
        import pytz

        local = (
            ts.to_utc_datetime()
            .replace(tzinfo=pytz.UTC)
            .astimezone(pytz.timezone("US/Eastern"))
        )
        # isoformat() writes the real offset of that instant (-04:00 or
        # -05:00). The previous format string hard-coded "-04:00", which is
        # wrong whenever US/Eastern is on standard time. Forcing microseconds
        # keeps the "YYYY-MM-DDTHH:MM:SS.ffffff+HH:MM" shape of the old output.
        return local.isoformat(timespec="microseconds")

    def process(
        self, element, window=beam.DoFn.WindowParam, timestamp=beam.DoFn.TimestampParam
    ):
        """Yield one dict describing ``element`` and the window it belongs to."""
        yield {
            "events": element,
            "window_start_est": self._to_eastern_iso(window.start),
            "window_end_est": self._to_eastern_iso(window.end),
            "window_start_micros": window.start.micros,
            "window_end_micros": window.end.micros,
            "timestamp_micros": timestamp.micros,
        }

In [None]:
class DumpPubsubMessage(beam.DoFn):
    """Serialize an element into a ``PubsubMessage``.

    The message payload is the element encoded as UTF-8 JSON. The message
    carries a single attribute, ``"ts"``: the element timestamp in
    milliseconds (microseconds divided by 1000, truncated), as a decimal
    string.
    """

    def __init__(self):
        # Initialize the base DoFn as well, like FilterRideStatus does.
        super(DumpPubsubMessage, self).__init__()

    def process(
        self, element, window=beam.DoFn.WindowParam, timestamp=beam.DoFn.TimestampParam
    ):
        """Yield one ``PubsubMessage`` built from ``element``."""
        # Import inside the method, like the other DoFns in this notebook do,
        # so the function does not rely on notebook-level globals being
        # available where it runs.
        import json

        from apache_beam.io.gcp.pubsub import PubsubMessage

        attributes = {"ts": str(int(timestamp.micros / 1000.0))}
        payload = json.dumps(element).encode("utf-8")
        yield PubsubMessage(payload, attributes)

### Plotting

In [None]:
def create_map():
    """Build the taxi-ride map figure and the data source that feeds it.

    Returns:
        A ``(figure, source)`` tuple. ``source`` is an initially empty
        ``ColumnDataSource`` with the columns ``x``, ``y`` and ``ride_status``;
        points added to it are drawn on the figure, colored by ride status.
    """
    ride_statuses = ["pickup", "enroute", "dropoff"]

    # Three colors picked from the 20-color "Category20b" palette, one per
    # ride status (same order as `ride_statuses`).
    palette = bokeh.palettes.d3["Category20b"][20]
    status_colors = [palette[1], palette[9], palette[-2]]

    # Data source, empty until points are added to it.
    source = ColumnDataSource(data={"x": [], "y": [], "ride_status": []})

    # Figure restricted to the configured Mercator ranges, with a tile
    # background map.
    fg = figure(
        x_range=MERCATOR_X_RANGE,
        y_range=MERCATOR_Y_RANGE,
        x_axis_type="mercator",
        y_axis_type="mercator",
        title_location="above",
        plot_height=600,
    )
    tile_provider = get_provider(Vendors.CARTODBPOSITRON)
    fg.add_tile(tile_provider)

    # Scatter plot of the rides; the color follows the `ride_status` column.
    point_color = factor_cmap("ride_status", status_colors, ride_statuses)
    fg.circle(
        x="x",
        y="y",
        source=source,
        size=2,
        color=point_color,
        fill_alpha=0.8,
    )

    # Legend: one empty placeholder renderer per status, so that each legend
    # entry shows the matching color.
    legend_items = []
    for status, color in zip(ride_statuses, status_colors):
        placeholder = fg.circle(x=[], y=[], color=color)
        legend_items.append((status, [placeholder]))
    fg.add_layout(Legend(items=legend_items))

    return fg, source

## Pipeline

In [None]:
# (lower, upper) bounds of the region of interest, in the same coordinates as
# the `utm_x` / `utm_y` values produced by AddMercatorCoords. They are used
# both to filter the stream (FilterByRegion) and as the axis ranges of the map
# (create_map).
MERCATOR_X_RANGE = (-8240000, -8220000)
MERCATOR_Y_RANGE = (4950000, 5000000)

In [None]:
# Subscription to the taxi-rides topic, created under our own project; the
# pipeline reads from it via `input_subscription.name`.
# NOTE(review): judging by the class name the subscription is temporary --
# confirm when and how it gets deleted.
input_subscription = pubsub_utils.TemporaryPubsubSubscription(project_id, taxirides_topic_path)

In [None]:
# Cache named "temp": the pipeline writes its output into it
# (`temp_cache.writer()`) and the cells below read it back (`temp_cache.read()`).
temp_cache = cache_manager.create_default_cache("temp")

In [None]:
# Cancel the pipeline started by a previous execution of the cell below, if
# there is one. On the first run `pr` does not exist yet, and the resulting
# NameError is deliberately ignored.
with contextlib.suppress(NameError):
    pr.cancel()

In [None]:
# Streaming pipeline: read taxi-ride events from Pub/Sub, project and filter
# them, batch them into one list per window and write the result to the cache.
p = beam.Pipeline(options=options)

out = (
    p
    # Read full messages (payload + attributes); element timestamps are taken
    # from the "ts" message attribute.
    | "Read"
    >> beam.io.ReadFromPubSub(
        subscription=input_subscription.name, with_attributes=True, timestamp_attribute="ts"
    )
    # Decode the JSON payload into a Python object.
    | beam.Map(lambda message: json.loads(message.data.decode()))
    | "Add Mercator coords" >> beam.ParDo(AddMercatorCoords())
    | "Filter to New York" >> beam.ParDo(FilterByRegion(MERCATOR_X_RANGE, MERCATOR_Y_RANGE))
    # Keep only rides whose id starts with "a".
    | "Subsample" >> beam.Filter(lambda e: e["ride_id"][0] == "a")
    # Fixed windows of size 1 (seconds, in Beam's convention).
    | "Window" >> beam.WindowInto(window.FixedWindows(1))
    # Collect all events of a window into a single list; without_defaults()
    # means no output is produced for empty windows.
    | "Combine" >> beam.CombineGlobally(combiners.ToListCombineFn()).without_defaults()
    | "Add window info" >> beam.ParDo(AddWindowRange())
    #     | beam.Map(lambda e: print(e) or e)
    | "Write" >> temp_cache.writer()
)

# Start the pipeline; `pr` is kept so that a re-run can cancel this job.
pr = p.run()

In [None]:
# Peek at the cache: print the values of at most 5 elements.
# NOTE(review): `seek_to_start` and `timeout` are options of the cache reader;
# their exact semantics (and the timeout's unit) are not visible here --
# confirm against the cache implementation.
for element in itertools.islice(temp_cache.read(seek_to_start=False, timeout=10), 5):
    print(element.value)

In [None]:
# Live map: show the figure once, then keep feeding it with the batches of
# events read from the cache.
output_notebook()

fg, source = create_map()
fg.title.text = "-"
fg.title.align = "center"

# Updates
# notebook_handle=True returns a handle that push_notebook() uses to update
# the already displayed figure in place.
handle = show(fg, notebook_handle=True)

# Indexed by datetime.weekday(), where Monday is 0.
days_of_week = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]

# NOTE(review): the semantics of `burnin`, `seek_to_start` and `timeout` are
# defined by the cache reader and are not visible here -- confirm.
for timestamped_value in temp_cache.read(burnin=10, seek_to_start=True, timeout=100):
    timestamp = timestamped_value.timestamp
    # List of events of one window, as produced by the pipeline above.
    data = timestamped_value.value["events"]
    # NOTE(review): `micros` is divided by 1e9 here, although microseconds ->
    # seconds would be a division by 1e6, and 1 ms is added on top. This may
    # compensate for how the cache reader builds the timestamp -- confirm
    # before changing it.
    dt = (
        datetime.utcfromtimestamp(timestamp.micros / 1000.0 / 1000.0 / 1000.0 + 0.001)
        .replace(tzinfo=pytz.UTC)
        .astimezone(pytz.timezone("US/Eastern"))
    )
    # Title: weekday name followed by the US/Eastern date and time.
    # The time uses the 12-hour clock (%I) without an AM/PM marker.
    dt_str = days_of_week[dt.weekday()] + " " + dt.strftime("%b %d %Y %I:%M:%S %f")
    fg.title.text = dt_str
    fg.title.align = "center"
    # Append the new points; rollover=100 keeps only the most recent 100
    # points in the data source.
    source.stream(
        {
            "x": [d["utm_x"] for d in data],
            "y": [d["utm_y"] for d in data],
            "ride_status": [d["ride_status"] for d in data],
        },
        rollover=100,
    )
    push_notebook(handle=handle)

    # Short pause between updates.
    time.sleep(0.05)