# StreamBasedCache Demo - New York Taxi Rides

## Install dependencies

In [None]:
# Detect whether this notebook is running inside Google Colab; later cells use
# the flag to decide whether to install dependencies and authenticate.
GOOGLE_COLAB = False
try:
    import google.colab  # noqa: F401  (imported only to probe for Colab)
except ImportError:
    pass
else:
    GOOGLE_COLAB = True

In [None]:
# On Colab only: install the Snappy C development package, presumably needed
# by the python-snappy pip package installed in the next cell.
if GOOGLE_COLAB:
    !sudo apt-get -yqq install libsnappy-dev

In [None]:
# On Colab only: install extra Python packages and upgrade bokeh to the latest
# release.
# NOTE(review): `-U bokeh` pulls the newest Bokeh; APIs used below
# (`bokeh.tile_providers`, `plot_height`) may be deprecated/removed in newer
# releases -- consider pinning a compatible version.
if GOOGLE_COLAB:
    !pip install -q python-snappy Faker pyproj
    !pip install -q -U bokeh

In [None]:
# On Colab only: install Apache Beam (with GCP extras) from the
# `feature/streambasedcache` branch, which provides the `streambasedcache`
# module imported below.
if GOOGLE_COLAB:
    !pip install "git+https://github.com/ostrokach/beam.git@feature/streambasedcache#egg=apache_beam[gcp]&subdirectory=sdks/python"

## Imports

In [None]:
from __future__ import print_function

import copy
import itertools
import json
import logging
import os
import pickle
import shutil
import tempfile
import time
import uuid
from collections import Counter, OrderedDict
from datetime import datetime

import apache_beam as beam
import bokeh
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytz
import requests
import tqdm
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.options.pipeline_options import (GoogleCloudOptions,
                                                  PipelineOptions)
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.interactive import caching
from apache_beam.runners.interactive.caching import streambasedcache
from apache_beam.transforms import combiners, window
from apache_beam.transforms.ptransform import ptransform_fn
from bokeh.core.properties import value
from bokeh.io import output_notebook, push_notebook, show
from bokeh.layouts import row
from bokeh.models import (ColumnDataSource, Label, LabelSet, Legend,
                          LegendItem, Range1d)
from bokeh.models.annotations import Title
from bokeh.plotting import figure, output_file, show
from bokeh.tile_providers import Vendors, get_provider
from bokeh.transform import factor_cmap, factor_mark
from faker import Faker
from google.api_core import exceptions as gexc
from google.cloud import pubsub

In [None]:
# Show up to 100 columns when displaying DataFrames. Use the fully qualified
# option name: the short form "max_columns" also matches
# "styler.render.max_columns" in newer pandas and raises OptionError.
pd.set_option("display.max_columns", 100)

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Parameters

In [None]:
# On Colab only: authenticate the user so the notebook can reach the Google
# Cloud resources (GCS temp location, Pub/Sub topics) used below.
if GOOGLE_COLAB:
    from google.colab import auth
    auth.authenticate_user()

In [None]:
#@title Google Cloud Project Info { display-mode: "form" }
project_id = "strokach-playground" #@param {type:"string"}
gcs_temp_location = "gs://strokach/dataflow_temp" #@param {type:"string"}
# project_id: GCP project that owns the Pub/Sub topics created by this notebook.
# gcs_temp_location: GCS path passed to PipelineOptions as `temp_location`.

In [None]:
# Per-notebook working directory. Any OSError from mkdir (typically "already
# exists" on a re-run) is deliberately ignored.
NOTEBOOK_NAME = "streambasedcache-new_york_taxirides_from_file"
try:
    os.mkdir(NOTEBOOK_NAME)
except OSError:
    pass

In [None]:
# Default pipeline options: streaming mode, with the GCS temp location and GCP
# project taken from the form fields.
options = PipelineOptions(
    temp_location=gcs_temp_location, streaming=True, project=project_id
)
options.display_data()

In [None]:
# Enable IPython autoreload so edits to imported modules are picked up without
# restarting the kernel; failure is reported but not fatal.
try:
    %load_ext autoreload
    %autoreload 2
except Exception:
    print("No autoreload")

## Load data

In [None]:
# Parquet file of taxi events (relative path), presumably produced by the
# companion "new_york_taxirides_to_events" notebook.
input_file = "new_york_taxirides_to_events/new-york-taxi-events.parquet"

In [None]:
# Load all events into pandas. integer_object_nulls=True keeps integer columns
# that contain nulls as Python ints (object dtype) instead of casting to float.
events_df = pq.read_table(input_file).to_pandas(integer_object_nulls=True)
events_df.head(2)

## Workflow

In [None]:
# Viewport of the map figure (x and y extents on Mercator axes) -- presumably
# Web Mercator metres framing New York City.
MERCATOR_X_RANGE = (-8240000, -8220000)
MERCATOR_Y_RANGE = (4950000, 5000000)

### Basic bokeh plot using pandas

In [None]:
# Bucket events into 15-minute batches. group_id is the batch start in Unix
# seconds: milliseconds -> seconds, floored to a multiple of seconds_per_batch.
seconds_per_batch = 15 * 60  # 15 minutes
events_df["group_id"] = (
    (events_df["timestamp_milliseconds"] / 1000)
    // seconds_per_batch
    * seconds_per_batch
).astype(int)  # builtin int; the `np.int` alias for it was removed in NumPy 1.24

In [None]:
# Warn Colab users that the live-updating plots below are not supported there yet.
if GOOGLE_COLAB:
    print("Interactive plot does not work on colab yet!")


def create_map():
    """Create an empty Bokeh map for plotting taxi start/stop events.

    The figure has a CartoDB Positron tile background and Mercator axes limited
    to ``MERCATOR_X_RANGE`` / ``MERCATOR_Y_RANGE``. It also has a scatter
    renderer coloured by ``event_type`` ("start" / "stop") and a legend for
    the two event types.

    Returns:
        tuple: ``(figure, ColumnDataSource)``. Replace the source's ``data``
        (columns ``x``, ``y``, ``event_type``) to redraw the points.
    """
    # Import the palette module explicitly: `import bokeh` by itself does not
    # guarantee that the `bokeh.palettes` submodule is loaded.
    from bokeh.palettes import d3

    # Two contrasting colours from the 20-colour D3 palette, one per event type.
    cmap = d3["Category20b"][20]
    colors = [cmap[1], cmap[-2]]
    event_types = ["start", "stop"]

    # Data source; callers replace `source.data` to update the plot.
    source = ColumnDataSource(data=dict(x=[], y=[], event_type=[]))

    # Background map
    fg = figure(
        x_range=MERCATOR_X_RANGE,
        y_range=MERCATOR_Y_RANGE,
        x_axis_type="mercator",
        y_axis_type="mercator",
        title_location="above",
        plot_height=600,
    )
    fg.add_tile(get_provider(Vendors.CARTODBPOSITRON))

    # Scatterplot, coloured by the "event_type" column.
    fg.circle(
        x="x",
        y="y",
        source=source,
        size=2,
        color=factor_cmap("event_type", colors, event_types),
        fill_alpha=0.8,
    )

    # Legend: one empty glyph per event type, drawn in the matching colour.
    legend_glyphs = [fg.circle(x=[], y=[], color=color) for color in colors]
    legend = Legend(
        items=[(name, [glyph]) for name, glyph in zip(event_types, legend_glyphs)]
    )
    fg.add_layout(legend)
    return fg, source

In [None]:
# Preview: draw 15-minute batches from a 10% sample of the events on the map
# (times shown in UTC). MAX_FRAMES=1 draws only the first batch; raise it, or
# set it to None, to animate more batches.
MAX_FRAMES = 1

output_notebook()

fg, source = create_map()

# Updates
handle = show(fg, notebook_handle=True)

days_of_week = [
    "Monday",
    "Tuesday",
    "Wednesday",
    "Thursday",
    "Friday",
    "Saturday",
    "Sunday",
]
# Group by the scalar column name: grouping by a one-element list yields
# 1-tuples as keys in pandas >= 2, which utcfromtimestamp cannot handle.
batches = events_df.sample(frac=0.1).groupby("group_id")
for group_start, gp in itertools.islice(batches, MAX_FRAMES):
    # group_start is the batch start in Unix seconds.
    dt = datetime.utcfromtimestamp(group_start).replace(tzinfo=pytz.UTC)
    dt_str = days_of_week[dt.weekday()] + " " + dt.strftime("%b %d %Y %I:%M:%S %p")
    fg.title.text = dt_str
    fg.title.align = "center"
    source.data = {
        "x": gp["utm_x"].values,
        "y": gp["utm_y"].values,
        "event_type": gp["event_type"].values,
    }
    push_notebook(handle=handle)
    time.sleep(0.2)

### Write dataset to cache

In [None]:
# Small, time-ordered test subset: the 100 earliest events of an (unseeded,
# hence run-dependent) random 10% sample. These rows are later written to the
# input cache.
events_df_sample = events_df.sample(frac=0.1).sort_values("timestamp_milliseconds", ascending=True).head(100)
events_df_sample.head(5)

In [None]:
def create_cache(location, cache_class, *args, **kwargs):
    """Create a cache at a uniquely suffixed location, retrying on ``IOError``.

    Each attempt appends a fresh random hex suffix to ``location``. Up to three
    attempts are made.

    Args:
        location: Base location, e.g. a Pub/Sub topic path.
        cache_class: Called as ``cache_class(full_location, *args, **kwargs)``.
        *args, **kwargs: Forwarded unchanged to ``cache_class``.

    Returns:
        The first successfully constructed cache instance.

    Raises:
        IOError: The error from the last attempt, if every attempt fails.
    """
    last_error = None
    for _ in range(3):
        full_location = "{}-{}".format(location, uuid.uuid4().hex)
        try:
            return cache_class(full_location, *args, **kwargs)
        except IOError as e:
            # Keep our own reference: in Python 3 the `as` name is unbound when
            # the except block ends, so re-raising `e` after the loop would be
            # an UnboundLocalError instead of the real error.
            last_error = e
    raise last_error

In [None]:
# Input cache: a PubSubBasedCache at a uniquely suffixed location
# "projects/<project_id>/topics/input-<random hex>".
cache = create_cache(
    "projects/{}/topics/input".format(project_id),
    streambasedcache.PubSubBasedCache,
#     with_attributes=["timestamp_milliseconds"],
#     timestamp_attribute="timestamp_milliseconds",
)

cache.location

In [None]:
# Show the text output written by the letter test pipeline (WriteToText prefix /tmp/pipeline-gc-test3).
!cat /tmp/pipeline-gc-test3-00000-of-00001

In [None]:
# `string` is otherwise only imported in a later cell; import it here so this
# cell also works when the notebook is run top to bottom.
import string

# Write each ASCII letter to the input cache (presumably one message each).
cache.write((c for c in string.ascii_letters))

In [None]:
# Display the letters that were written to the cache (requires `string` to be imported).
string.ascii_letters

In [None]:
# Pipeline options for the letter test pipeline. Choose exactly one runner:
# passing `runner=` twice is a SyntaxError. DirectRunner runs locally, so the
# WriteToText output lands in this machine's /tmp. sdk_location / setup_file
# are used when staging the job on DataflowRunner.
runner_name = "DirectRunner"  # or "DataflowRunner"
options = PipelineOptions(
    temp_location=gcs_temp_location, streaming=True, project=project_id,
    runner=runner_name,
    sdk_location=os.path.expanduser(
        "~/workspace/beam/sdks/python/dist/apache-beam-2.15.0.dev0.tar.gz"
    ),
    setup_file="../setup.py"
)
options.display_data()

In [None]:
# Re-display the current pipeline options.
options.display_data()

In [None]:
class FormatDoFn(beam.DoFn):
    """Wrap each element together with the bounds of its window.

    Yields a dict with the element under ``"events"`` and the window start/end
    as integer Unix milliseconds.
    """

    def process(self, element, window=beam.DoFn.WindowParam):
        yield {
            "events": element,
            # Timestamp.micros -> milliseconds (int() truncates toward zero).
            "window_start_milliseconds": int(window.start.micros / 1000),
            "window_end_milliseconds": int(window.end.micros / 1000)
        }


# next(FormatDoFn().process({}))

In [None]:
# Debug / sanity-check output.
print("hello")

In [None]:
# List the output shards written by the letter test pipeline.
!ls /tmp/pipeline-gc-test3*

In [None]:
# Show the contents of the (single) output shard.
!cat /tmp/pipeline-gc-test3-00000-of-00001

In [None]:
from apache_beam.utils.timestamp import Timestamp

In [None]:
class IndexAssigningStatefulDoFn(beam.DoFn):
    """Pair each value with a running, per-key counter.

    Input elements are ``(key, value)`` pairs and outputs are
    ``(value, index)``. The counter lives in a combining state cell reduced
    with ``sum``, so a fresh key presumably starts at 0.
    """

    INDEX_STATE = beam.transforms.userstate.CombiningValueStateSpec("index", sum)

    def process(self, element, index=beam.DoFn.StateParam(INDEX_STATE)):
        _, payload = element
        position = index.read()
        yield payload, position
        # Bump the counter only after emitting, so the first value gets the
        # initial state.
        index.add(1)

In [None]:
import string

In [None]:
# The ASCII letters in reverse order ("ZYX...Azyx...a").
reversed_letters = "".join(reversed(string.ascii_letters))

In [None]:
# Current Unix time in whole seconds, as a string. time.time() is portable,
# unlike strftime("%s"), which is a glibc extension (unsupported on Windows).
str(int(time.time()))

In [None]:
# Test pipeline: read the letters back from the cache, give each a synthetic
# event time, number them, window them and write the result to text files.
with beam.Pipeline(options=options) as p:
    pcoll = (
        p
        | cache.reader()
        # Slow the stream down and use the letter's position in the reversed
        # alphabet as its event time in seconds.
        # NOTE(review): assumes cache.reader() yields the letters written above.
        | beam.Map(lambda e: time.sleep(0.5) or beam.window.TimestampedValue(e, Timestamp(seconds=reversed_letters.index(e))))
        # Single key, so the stateful DoFn numbers every element in one sequence.
        | beam.Map(lambda e: (0, e))
        | beam.ParDo(IndexAssigningStatefulDoFn())
        | beam.WindowInto(window.FixedWindows(0.5))
        | beam.ParDo(FormatDoFn())
        | beam.Map(lambda e: print(e) or e)  # debug: echo each record
        | beam.io.WriteToText(os.path.join("/tmp", "pipeline-gc-test3"))
    )

In [None]:
# Read back everything the cache yields; how `timeout` bounds the wait is
# defined by the cache implementation (presumably seconds without new messages).
list(cache.read(timeout=1))

In [None]:
# Debug output.
print("Asdf")

In [None]:
# Write the sampled events to the input cache, one dict per DataFrame row
# (itertuples() namedtuples also carry the row index under "Index").
cache.write((tup._asdict() for tup in events_df_sample.itertuples()))

In [None]:
# Peek at the first five (timestamp, message) items from the input cache,
# pausing one second after each, then display them. The result must be the
# cell's last expression; previously `del` came last, so nothing was shown.
cache_source = cache.read(return_timestamp=True)
time.sleep(5)  # give the subscription time to receive messages
messages = []
for item in itertools.islice(cache_source, 5):
    messages.append(item)
    time.sleep(1)
del cache_source
messages

### Process data from subscription to cache

In [None]:
class Limit(beam.PTransform):
    """Keep at most ``num_elements`` elements of a PCollection.

    Takes a fixed-size global sample (``Sample.FixedSizeGlobally``) and
    flattens it back into individual elements.

    NOTE(review): a global combine only emits when its window closes, so on an
    unbounded (streaming) input this presumably never produces output.
    """

    def __init__(self, num_elements=1000):
        self.num_elements = num_elements

    def expand(self, pcoll):
        from apache_beam import transforms
        from apache_beam.transforms import combiners

        sample = pcoll | combiners.Sample.FixedSizeGlobally(self.num_elements)
        return sample | transforms.FlatMap(lambda chosen: list(chosen))

In [None]:
class ToList(beam.PTransform):
    """Condense an entire PCollection into a single list element.

    Uses ``CombineGlobally(ToListCombineFn())`` with ``without_defaults()``,
    so no default (empty-list) output is produced for empty input.
    """

    def __init__(self, label="ToList"):  # pylint: disable=useless-super-delegation
        super(ToList, self).__init__(label)

    def expand(self, pcoll):
        gather = beam.CombineGlobally(combiners.ToListCombineFn())
        return pcoll | self.label >> gather.without_defaults()

In [None]:
class BuildRecordFn(beam.DoFn):
    """Key each element by the end of its window.

    Emits a single ``(window_end, elements)`` pair per input, where
    ``window_end`` comes from ``window.end.to_utc_datetime()``.
    """

    def process(self, elements, window=beam.DoFn.WindowParam):
        end_of_window = window.end.to_utc_datetime()
        record = (end_of_window, elements)
        return [record]

In [None]:
class FormatMessage(beam.DoFn):
    """Extract a message's payload and stamp it with its event time.

    Expects ``element.data`` to be a dict with a ``"timestamp_milliseconds"``
    entry (Unix epoch milliseconds). Yields that dict as a
    ``TimestampedValue`` at that time.
    """

    def process(self, element):
        from apache_beam.utils import timestamp

        data = element.data
        # Milliseconds -> microseconds, the unit Timestamp is constructed from.
        event_time = timestamp.Timestamp(
            micros=int(data["timestamp_milliseconds"]) * 1000
        )
        yield beam.window.TimestampedValue(data, event_time)

In [None]:
class AddTimestampAttribute(beam.DoFn):
    """Wrap an element in a PubsubMessage carrying a timestamp attribute.

    The ``timestamp_milliseconds`` attribute is read from ``element[0]``, so
    ``element`` is expected to be a non-empty sequence of dicts.

    NOTE(review): ``element`` itself is passed as the message ``data``; Pub/Sub
    payloads are normally bytes, so confirm the downstream writer serializes it.
    """

    def process(self, element):
        from apache_beam.io.gcp.pubsub import PubsubMessage

        # Pub/Sub attribute values are strings, hence str().
        message = PubsubMessage(
            data=element,
            attributes={"timestamp_milliseconds": str(element[0]["timestamp_milliseconds"])},
        )
        yield message

In [None]:
class FormatDoFn(beam.DoFn):
    """Wrap a batch of events together with its window bounds.

    Yields a dict containing:
      - the batch under ``"events"``;
      - the window start/end as ISO-8601 strings in US/Eastern time;
      - the ``timestamp_milliseconds`` of the batch's first event;
      - the window end as Unix milliseconds.
    """

    @staticmethod
    def _to_eastern_iso(ts):
        """Format a Beam ``Timestamp`` as ISO-8601 in US/Eastern time.

        The UTC offset comes from the converted datetime, so it is -05:00 for
        EST (winter) and -04:00 for EDT (summer), not a hard-coded -04:00.
        """
        import pytz

        local = (
            ts.to_utc_datetime()
            .replace(tzinfo=pytz.UTC)
            .astimezone(pytz.timezone("US/Eastern"))
        )
        offset = local.strftime("%z")  # e.g. "-0500"
        return local.strftime("%Y-%m-%dT%H:%M:%S.%f") + offset[:3] + ":" + offset[3:]

    def process(self, element, window=beam.DoFn.WindowParam):
        yield {
            "events": element,
            "window_start": self._to_eastern_iso(window.start),
            "window_end": self._to_eastern_iso(window.end),
            "timestamp_milliseconds": element[0]["timestamp_milliseconds"],
            "window_end_milliseconds": int(window.end.micros / 1000)
        }


# next(FormatDoFn().process({}))

In [None]:
class AddWindowRange(beam.DoFn):
    """Annotate an event dict with the bounds of its window.

    Adds ``window_start_est`` / ``window_end_est`` (ISO-8601 strings in
    US/Eastern time) and ``window_start_milliseconds`` /
    ``window_end_milliseconds`` (Unix milliseconds). The input dict is
    shallow-copied instead of modified, because Beam forbids mutating input
    elements.
    """

    @staticmethod
    def _to_eastern_iso(ts):
        """Format a Beam ``Timestamp`` as ISO-8601 in US/Eastern time.

        The UTC offset comes from the converted datetime, so it is -05:00 for
        EST (winter) and -04:00 for EDT (summer), not a hard-coded -04:00.
        """
        import pytz

        local = (
            ts.to_utc_datetime()
            .replace(tzinfo=pytz.UTC)
            .astimezone(pytz.timezone("US/Eastern"))
        )
        offset = local.strftime("%z")  # e.g. "-0500"
        return local.strftime("%Y-%m-%dT%H:%M:%S.%f") + offset[:3] + ":" + offset[3:]

    def process(self, element, window=beam.DoFn.WindowParam):
        annotated = dict(element)
        annotated["window_start_est"] = self._to_eastern_iso(window.start)
        annotated["window_end_est"] = self._to_eastern_iso(window.end)
        annotated["window_start_milliseconds"] = int(window.start.micros / 1000)
        annotated["window_end_milliseconds"] = int(window.end.micros / 1000)
        yield annotated

In [None]:
# Cancel the pipeline started by a previous run of this cell, if any.
try:
    pr.cancel()
except NameError:
    pass

# Output cache for the windowed batches; messages carry a
# "timestamp_milliseconds" attribute (presumably also used as message time).
temp = create_cache(
    "projects/{}/topics/temp".format(project_id),
    streambasedcache.PubSubBasedCache,
    with_attributes=["timestamp_milliseconds"],
    timestamp_attribute="timestamp_milliseconds",
)

p = beam.Pipeline(runner=BundleBasedDirectRunner(), options=options)

input_cache = cache  # the cache the events were written to

out_pcoll = (
    p
    | "Read" >> input_cache.reader()  # Ideally, we could limit input at the reader level
    # | "Limit" >> Limit(100)  # NOTE: global combine; presumably never fires on unbounded input
    | "Extract data" >> beam.ParDo(FormatMessage())
    | "Window" >> beam.WindowInto(window.FixedWindows(15 * 60))
    | "Add window info" >> beam.ParDo(AddWindowRange())
    | "Pair with end of window" >> beam.Map(lambda e: (e["window_end_milliseconds"], e))
    | "Group by end of window" >> beam.GroupByKey()
    # One message per window, keyed by the window end in Unix milliseconds.
    # list() materializes the grouped values, which GroupByKey may hand over
    # as a lazy iterable.
    | "Reduce" >> beam.Map(lambda e: {"timestamp_milliseconds": e[0], "events": list(e[1])})
    | "Write" >> temp.writer()
)
# Explicitly specify dependencies so that cache is not automatically garbage collected
# out_pcoll._deps = [cache]

pr = p.run()

In [None]:
# data_source  = temp.read(timeout=10, return_timestamp=True)
# out = list(itertools.islice(data_source, 3))
# out

### Interactive dashboard

In [None]:
# Live dashboard: draw each message from the `temp` cache (one window of
# events per message) as a map frame titled with its timestamp in UTC.
fg, source = create_map()
fg.title.text = "-"
fg.title.align = "center"

# Updates
handle = show(fg, notebook_handle=True)

days_of_week = [
    "Monday",
    "Tuesday",
    "Wednesday",
    "Thursday",
    "Friday",
    "Saturday",
    "Sunday",
]

data_source = temp.read(timeout=5, with_attributes=["timestamp_milliseconds"], return_timestamp=True)
time.sleep(2)  # give the subscription a moment before reading

for timestamp, message in data_source:
    data = message.data["events"]
    dt = datetime.utcfromtimestamp(timestamp).replace(tzinfo=pytz.UTC)
    # The 12-hour clock (%I) needs %p for AM/PM; %f would print microseconds.
    dt_str = days_of_week[dt.weekday()] + " " + dt.strftime("%b %d %Y %I:%M:%S %p")
    fg.title.text = dt_str
    fg.title.align = "center"
    source.data = {
        "x": [d["utm_x"] for d in data],
        "y": [d["utm_y"] for d in data],
        "event_type": [d["event_type"] for d in data],
    }
    push_notebook(handle=handle)
    time.sleep(0.2)

In [None]:
# Live dashboard that merges consecutive messages sharing a timestamp into one
# frame. Messages older than the frame being built are skipped.
fg, source = create_map()
fg.title.text = "-"
fg.title.align = "center"

# Updates
handle = show(fg, notebook_handle=True)

days_of_week = [
    "Monday",
    "Tuesday",
    "Wednesday",
    "Thursday",
    "Friday",
    "Saturday",
    "Sunday",
]


def _new_batch(batch_timestamp, events):
    """Start a frame buffer holding ``events`` for ``batch_timestamp``."""
    return {
        "timestamp": batch_timestamp,
        "utm_x": [d["utm_x"] for d in events],
        "utm_y": [d["utm_y"] for d in events],
        "event_type": [d["event_type"] for d in events],
    }


def _draw_batch(batch):
    """Draw a buffered frame, titled with that frame's own timestamp (UTC)."""
    dt = datetime.utcfromtimestamp(batch["timestamp"]).replace(tzinfo=pytz.UTC)
    dt_str = days_of_week[dt.weekday()] + " " + dt.strftime("%b %d %Y %I:%M:%S %p")
    fg.title.text = dt_str
    fg.title.align = "center"
    source.data = {
        "x": batch["utm_x"],
        "y": batch["utm_y"],
        "event_type": batch["event_type"],
    }
    push_notebook(handle=handle)
    time.sleep(0.2)


data_source = temp.read(timeout=5, return_timestamp=True)
time.sleep(5)

current_batch = None
for timestamp, message in data_source:
    data = message.data["events"]
    if current_batch is None:
        current_batch = _new_batch(timestamp, data)
    elif timestamp == current_batch["timestamp"]:
        current_batch["utm_x"] += [d["utm_x"] for d in data]
        current_batch["utm_y"] += [d["utm_y"] for d in data]
        current_batch["event_type"] += [d["event_type"] for d in data]
    elif timestamp < current_batch["timestamp"]:
        # Late message for an earlier frame; skip it.
        continue
    else:
        # A newer timestamp closes the current frame: draw it, then start the next.
        _draw_batch(current_batch)
        current_batch = _new_batch(timestamp, data)

# The last frame is never closed by a newer message, so draw it explicitly.
if current_batch is not None:
    _draw_batch(current_batch)

In [None]:
# Draw up to 100 messages from the `temp` cache, titling each frame with the
# 15-minute batch start (`group_id`, Unix seconds) of its first event, in UTC.
output_notebook()

fg, source = create_map()

# Updates
handle = show(fg, notebook_handle=True)

days_of_week = [
    "Monday",
    "Tuesday",
    "Wednesday",
    "Thursday",
    "Friday",
    "Saturday",
    "Sunday",
]

data_source = temp.read(return_timestamp=False)
time.sleep(5)
for message in itertools.islice(data_source, 100):
    # The pipeline writes {"timestamp_milliseconds": ..., "events": [...]}
    # payloads; read the events the same way as the other dashboard cells.
    data = message.data["events"]
    index = data[0]["group_id"]
    dt = datetime.utcfromtimestamp(index).replace(tzinfo=pytz.UTC)
    dt_str = days_of_week[dt.weekday()] + " " + dt.strftime("%b %d %Y %I:%M:%S %p")
    fg.title.text = dt_str
    fg.title.align = "center"
    source.data = {
        "x": [d["utm_x"] for d in data],
        "y": [d["utm_y"] for d in data],
        "event_type": [d["event_type"] for d in data],
    }
    push_notebook(handle=handle)
    time.sleep(0.2)