# StreamBasedCache

## Install dependencies

In [None]:
try:
    # Only used to detect whether we are running inside Google Colab.
    import google.colab
except ImportError:
    GOOGLE_COLAB = False
else:
    GOOGLE_COLAB = True

In [None]:
if GOOGLE_COLAB:
    # System library presumably needed by python-snappy (installed in the next cell).
    !sudo apt-get -yqq install libsnappy-dev

In [None]:
if GOOGLE_COLAB:
    # Extra Python packages for this notebook; bokeh is upgraded for the interactive maps.
    !pip install -q python-snappy Faker pyproj
    !pip install -q -U bokeh

In [None]:
if GOOGLE_COLAB:
    # NOTE(review): installs Apache Beam from a personal fork/branch
    # (ostrokach/beam@feature/streambasedcache). This presumably provides the
    # `caching.PubSubBasedCache` used later; upstream Beam may not include it.
    !pip install "git+https://github.com/ostrokach/beam.git@feature/streambasedcache#egg=apache_beam[gcp]&subdirectory=sdks/python"

## Imports

In [None]:
from __future__ import print_function

import copy
import itertools
import json
import logging
import os
import pickle
import shutil
import tempfile
import time
import uuid
from datetime import datetime

import numpy as np
import pandas as pd
import pytz
import requests
import tqdm
from faker import Faker
from google.api_core import exceptions as gexc
from google.cloud import pubsub

import apache_beam as beam
import pyproj
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.interactive import caching
from apache_beam.transforms import combiners
from apache_beam.transforms import window
from apache_beam.transforms.ptransform import ptransform_fn
from bokeh.io import output_notebook
from bokeh.io import push_notebook
from bokeh.io import show
from bokeh.layouts import row
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure

In [None]:
# Show up to 100 DataFrame columns before truncating. Use the fully qualified
# option name: since pandas 1.4 the short pattern "max_columns" also matches
# "styler.render.max_columns", and set_option raises OptionError on the ambiguity.
pd.set_option("display.max_columns", 100)

In [None]:
# IPython magic: render matplotlib figures inline in the cell output.
%matplotlib inline

## Parameters

In [None]:
# On Colab, prompt the user to sign in so the Google Cloud clients used later
# (GCS temp location, Pub/Sub) can presumably pick up their credentials.
if GOOGLE_COLAB:
    from google.colab import auth
    auth.authenticate_user()

In [None]:
#@title Google Cloud Project Info { display-mode: "form" }
project_id = "strokach-playground" #@param {type:"string"}
gcs_temp_location = "gs://strokach/dataflow_temp" #@param {type:"string"}
# Colab renders the `#@param` lines above as form fields; keep those comments intact.
# The defaults point at the original author's project and bucket. Replace them with
# your own GCP project ID and a writable gs:// path (used as the pipeline temp_location).

In [None]:
import errno

NOTEBOOK_NAME = "streambasedcache_chicago"

# Create a working directory named after the notebook. An already existing
# directory is fine (re-running the cell). Any other failure (permissions, a
# regular file with the same name, ...) is raised instead of silently ignored.
try:
    os.mkdir(NOTEBOOK_NAME)
except OSError as err:
    if err.errno != errno.EEXIST or not os.path.isdir(NOTEBOOK_NAME):
        raise

In [None]:
# Streaming pipeline options bound to the project and GCS temp bucket chosen above;
# display_data() as the last expression shows the effective settings.
options = PipelineOptions(
    project=project_id,
    temp_location=gcs_temp_location,
    streaming=True,
)
options.display_data()

In [None]:
try:
    # IPython autoreload, mode 2: re-import edited modules before each cell runs.
    %load_ext autoreload
    %autoreload 2
except Exception:
    # Best effort only; the notebook works without autoreload.
    print("No autoreload")

## Dataset

### Load data

In [None]:
def download_file(url, outfile=None):
    """Download ``url`` to a local file and return the local path.

    Args:
        url: HTTP(S) URL to fetch.
        outfile: destination path. Defaults to the last path component of
            ``url`` (previously this argument was required but ignored).

    Returns:
        str: the path the response body was written to.

    Raises:
        requests.HTTPError: if the server answers with an error status.
            No file is created in that case.
    """
    local_filename = outfile if outfile is not None else url.split('/')[-1]
    with requests.get(url, stream=True) as r:
        # Fail before opening the destination, so a 4xx/5xx page is never
        # saved under a .csv.gz name that later cells would try to read.
        r.raise_for_status()
        with open(local_filename, 'wb') as f:
            # Stream the body to disk instead of loading it into memory.
            shutil.copyfileobj(r.raw, f)
    return local_filename

In [None]:
# December 2018 trips: read the local copy, downloading it on first use.
_trips_2018_12_csv = "chicago_taxi_trips_2018_12.csv.gz"
try:
    chicago_taxi_trips_2018_12 = pd.read_csv(_trips_2018_12_csv)
except (IOError, OSError):
    # Missing file: IOError on Python 2, FileNotFoundError (an OSError) on Python 3.
    # Pass the destination explicitly so the download lands exactly where the
    # read above looks for it.
    local_filename = download_file(
        "https://storage.googleapis.com/strokach/inputs/" + _trips_2018_12_csv,
        _trips_2018_12_csv,
    )
    chicago_taxi_trips_2018_12 = pd.read_csv(local_filename)

In [None]:
# January 2019 trips: read the local copy, downloading it on first use.
_trips_2019_01_csv = "chicago_taxi_trips_2019_01.csv.gz"
try:
    chicago_taxi_trips_2019_01 = pd.read_csv(_trips_2019_01_csv)
except (IOError, OSError):
    # Missing file: IOError on Python 2, FileNotFoundError (an OSError) on Python 3.
    # Pass the destination explicitly so the download lands exactly where the
    # read above looks for it.
    local_filename = download_file(
        "https://storage.googleapis.com/strokach/inputs/" + _trips_2019_01_csv,
        _trips_2019_01_csv,
    )
    chicago_taxi_trips_2019_01 = pd.read_csv(local_filename)

### Validate data

In [None]:
# Make sure the latitude/longitude columns are populated for exactly the same
# trips (same rows, same order) as the pickup/dropoff location columns.
df1 = chicago_taxi_trips_2018_12[
    chicago_taxi_trips_2018_12[["pickup_location", "dropoff_location"]].notnull().all(axis=1)
]

df2 = chicago_taxi_trips_2018_12[
    chicago_taxi_trips_2018_12[
        ["pickup_latitude", "pickup_longitude", "dropoff_latitude", "dropoff_longitude"]
    ].notnull().all(axis=1)
]

# Index.equals also handles subsets of different length (the assert fails);
# element-wise `==` would raise ValueError("Lengths must match") instead.
assert df1.index.equals(df2.index)

In [None]:
# Make sure the latitude/longitude columns are populated for exactly the same
# trips (same rows, same order) as the pickup/dropoff location columns.
df1 = chicago_taxi_trips_2019_01[
    chicago_taxi_trips_2019_01[["pickup_location", "dropoff_location"]].notnull().all(axis=1)
]

df2 = chicago_taxi_trips_2019_01[
    chicago_taxi_trips_2019_01[
        ["pickup_latitude", "pickup_longitude", "dropoff_latitude", "dropoff_longitude"]
    ].notnull().all(axis=1)
]

# Index.equals also handles subsets of different length (the assert fails);
# element-wise `==` would raise ValueError("Lengths must match") instead.
assert df1.index.equals(df2.index)

In [None]:
# df1 was last bound by the January 2019 validation cell: trips with non-null locations.
df1.head(10)

### Create `events_df`

In [None]:
# Columns that must be non-null for a trip to be turned into start/stop events.
nonull_columns = [
    "pickup_latitude",
    "pickup_longitude",
    "dropoff_latitude",
    "dropoff_longitude",
]

# Trips from both months that have complete pickup/drop-off coordinates,
# December 2018 first, then January 2019.
trips = itertools.chain.from_iterable(
    trips_df[trips_df[nonull_columns].notnull().all(axis=1)].itertuples()
    for trips_df in (chicago_taxi_trips_2018_12, chicago_taxi_trips_2019_01)
)

# Each trip contributes two events that share the same "index": a "start" event
# at the pickup point and a "stop" event at the drop-off point, which also
# carries the trip statistics. The loop variable is `trip`, not `row`, so it
# does not clobber `bokeh.layouts.row` imported at the top of the notebook.
events = []
for i, trip in enumerate(trips):
    start_event = {
        "index": i,
        "event_type": "start",
        "unique_key": trip.unique_key,
        "taxi_id": trip.taxi_id,
        "timestamp": trip.trip_start_timestamp,
        "latitude": trip.pickup_latitude,
        "longitude": trip.pickup_longitude,
    }

    stop_event = {
        "index": i,
        "event_type": "stop",
        "unique_key": trip.unique_key,
        "taxi_id": trip.taxi_id,
        "timestamp": trip.trip_end_timestamp,
        "latitude": trip.dropoff_latitude,
        "longitude": trip.dropoff_longitude,
        "trip_seconds": trip.trip_seconds,
        "trip_miles": trip.trip_miles,
        "trip_total": trip.trip_total,
    }

    events.extend([start_event, stop_event])


# Start events have no trip statistics; their trip_* columns become NaN.
events_columns = [
    "index",
    "event_type",
    "unique_key",
    "taxi_id",
    "timestamp",
    "latitude",
    "longitude",
    "trip_seconds",
    "trip_miles",
    "trip_total",
]
events_df = pd.DataFrame(events, columns=events_columns)

In [None]:
# Preview the start/stop event table built in the previous cell.
events_df.head()

In [None]:
from collections import Counter

# How many events (pickups + drop-offs) occur at each distinct (latitude, longitude) point.
c = Counter(map(tuple, events_df[["latitude", "longitude"]].values))

In [None]:
# One row per distinct location with its event count, least-used first.
counts = pd.DataFrame(
    [(lat, lon, n) for (lat, lon), n in c.items()],
    columns=["latitude", "longitude", "count"],
).sort_values("count")

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Counts above 100 are clipped to 100 (see the x-label), so they all fall into the last bin.
capped_counts = np.clip(counts["count"], 0, 100)
plt.hist(capped_counts, bins=50, range=(0, 100))
plt.xlabel("Number of pickups / drop-offs in location\n(Capped at 100)")
plt.ylabel("Number of locations")
plt.title("Chicago - December 2018 / January 2019")

### Add `timestamp_seconds` and Web Mercator (`utm_x`, `utm_y`) columns

In [None]:
def timestamp_to_seconds(timestamp_str):
    """Convert a ``"%Y-%m-%d %H:%M:%S UTC"`` timestamp string to Unix seconds.

    Args:
        timestamp_str: e.g. ``"2018-12-06 00:00:00 UTC"``.

    Returns:
        float: seconds since 1970-01-01 00:00:00 UTC.

    Raises:
        ValueError: if ``timestamp_str`` does not match the format above.
    """
    # Keep the import local so the function has no module-level dependencies.
    from datetime import datetime

    # strptime returns a naive datetime that, per the literal " UTC" suffix,
    # represents UTC. Subtracting the naive Unix epoch therefore gives the POSIX
    # timestamp directly. This avoids a tz library and the deprecated
    # `datetime.utcfromtimestamp`.
    parsed = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S UTC")
    return (parsed - datetime(1970, 1, 1)).total_seconds()


timestamp_to_seconds("2018-12-06 00:00:00 UTC")

In [None]:
# Unix seconds for every event; needs the `timestamp_to_seconds` cell above.
# NOTE(review): row-wise `apply` calls strptime once per event (two per trip), which
# is slow on large frames. A vectorised `pd.to_datetime(..., format=...)` would be faster.
events_df["timestamp_seconds"] = events_df["timestamp"].apply(timestamp_to_seconds)

In [None]:
# Projected x/y for every event, computed row by row behind a tqdm progress bar.
# `geographic_to_utm` projects to EPSG:3857 (Web Mercator), so despite the `utm_`
# prefix these are not UTM coordinates.
# NOTE(review): `geographic_to_utm` is only defined further down, in the "Functions"
# section. Running the notebook top-to-bottom raises NameError here unless that
# cell has already been executed.
# NOTE(review): `tqdm.tqdm_notebook` is deprecated in recent tqdm releases in favour
# of `tqdm.notebook.tqdm`.
events_df["utm_x"], events_df["utm_y"] = list(
    zip(
        *[
            geographic_to_utm(*ll)
            for ll in tqdm.tqdm_notebook(
                events_df[["longitude", "latitude"]].values, total=len(events_df)
            )
        ]
    )
)

### Sort events

In [None]:
# Put events in replay order (chronological). `kind="mergesort"` is a stable sort:
# events that share a timestamp keep their construction order, so each trip's
# "start" event stays ahead of its "stop" event. The default quicksort gives no
# such guarantee.
events_df = events_df.sort_values("timestamp_seconds", ascending=True, kind="mergesort")

## Functions

In [None]:
def expand_timestamp(timestamp_str):
    """Parse a ``"%Y-%m-%d %H:%M:%S UTC"`` string into an aware America/Chicago datetime.

    Example: ``"2018-12-06 00:00:00 UTC"`` becomes 2018-12-05 18:00 CST (UTC-6).
    """
    from datetime import datetime

    naive_utc = datetime.strptime(timestamp_str, "%Y-%m-%d %H:%M:%S UTC")
    # Attaching UTC via replace() is safe because UTC has no DST transitions.
    aware_utc = naive_utc.replace(tzinfo=pytz.UTC)
    return aware_utc.astimezone(pytz.timezone('America/Chicago'))


expand_timestamp("2018-12-06 00:00:00 UTC")

In [None]:
# Memoised results shared by all calls that do not pass their own `_cache`.
_MERCATOR_CACHE = {}
# Lazily created pyproj Transformer (building one is far costlier than using it).
_MERCATOR_TRANSFORMER = None


def geographic_to_utm(longitude, latitude, _cache=None):
    """Project a WGS84 longitude/latitude pair (EPSG:4326) to Web Mercator (EPSG:3857).

    Despite the name, the result is not UTM. It is Web Mercator metres, the
    coordinate system of the Bokeh ``mercator`` axes and map tiles used in this
    notebook.

    Args:
        longitude: degrees east.
        latitude: degrees north.
        _cache: optional dict used to memoise results. Defaults to a shared
            module-level dict (replacing the former mutable default argument).

    Returns:
        tuple: ``(x, y)`` in EPSG:3857 metres.
    """
    global _MERCATOR_TRANSFORMER

    cache = _MERCATOR_CACHE if _cache is None else _cache
    key = (longitude, latitude)
    if key in cache:
        return cache[key]

    if _MERCATOR_TRANSFORMER is None:
        from pyproj import Transformer

        # always_xy=True takes (lon, lat) input order, like the legacy
        # `Proj(init=...)` + `pyproj.transform` calls this replaces.
        _MERCATOR_TRANSFORMER = Transformer.from_crs(
            "EPSG:4326", "EPSG:3857", always_xy=True
        )

    x, y = _MERCATOR_TRANSFORMER.transform(longitude, latitude)
    cache[key] = (x, y)
    return x, y


geographic_to_utm(-87.632746, 41.880994)

## Workflow

In [None]:
from bokeh.io import push_notebook, show, output_notebook
from bokeh.layouts import row
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure, show
from bokeh.tile_providers import Vendors, get_provider
from bokeh.models.annotations import Title

In [None]:
# Geographic bounding box (degrees) of every pickup / drop-off event.
LONGITUDE_RANGE = (events_df["longitude"].min(), events_df["longitude"].max())
LATITUDE_RANGE = (events_df["latitude"].min(), events_df["latitude"].max())

# Project the south-west and north-east corners; zip() regroups the two (x, y)
# pairs into (x_min, x_max) and (y_min, y_max).
(x_min, x_max), (y_min, y_max) = zip(
    geographic_to_utm(longitude=LONGITUDE_RANGE[0], latitude=LATITUDE_RANGE[0]),
    geographic_to_utm(longitude=LONGITUDE_RANGE[1], latitude=LATITUDE_RANGE[1]),
)

# Plot ranges for the Bokeh mercator axes.
MERCATOR_X_RANGE = (x_min, x_max)
MERCATOR_Y_RANGE = (y_min, y_max)

In [None]:
from bokeh.io import output_notebook
from bokeh.io import push_notebook
from bokeh.io import show
from bokeh.layouts import row
from bokeh.models import ColumnDataSource
from bokeh.models import Label
from bokeh.models import LabelSet
from bokeh.models import Range1d
from bokeh.models.annotations import Title
from bokeh.plotting import figure
from bokeh.plotting import output_file
from bokeh.plotting import show
from bokeh.tile_providers import Vendors
from bokeh.tile_providers import get_provider

output_notebook()

if GOOGLE_COLAB:
    print("Interactive plot does not work on colab yet!")

# range bounds supplied in web mercator coordinates
fig = figure(
    x_range=MERCATOR_X_RANGE,
    y_range=MERCATOR_Y_RANGE,
    x_axis_type="mercator",
    y_axis_type="mercator",
)
fig.add_tile(get_provider(Vendors.CARTODBPOSITRON))
fig.title.align = "center"

source = ColumnDataSource(data=dict(x=[], y=[]))

fig.circle(x="x", y="y", size=2, fill_color="blue", fill_alpha=0.8, source=source)

handle = show(fig, notebook_handle=True)

# Replay a random 10% sample of the events one timestamp at a time, replacing
# the plotted points each step. Group by the bare column name, not a one-element
# list: with a list key, pandas >= 2 yields 1-tuples, which are not valid Bokeh
# title text.
for timestamp, group in events_df.sample(frac=0.1).groupby("timestamp"):
    xs, ys = zip(
        *[geographic_to_utm(lon, lat) for lon, lat in group[["longitude", "latitude"]].values]
    )
    fig.title.text = timestamp
    source.data = {"x": np.asarray(xs), "y": np.asarray(ys)}
    push_notebook(handle=handle)

### Create a Pub/Sub subscription to the public taxi-rides topic

In [None]:
# Pub/Sub client used to (re)create the subscription in the next cell.
sub_client = pubsub.SubscriberClient()

In [None]:
subscription_name = "projects/{}/subscriptions/taxirides-realtime-sub".format(project_id)
# Topic in Google's public `pubsub-public-data` project.
taxirides_topic = "projects/pubsub-public-data/topics/taxirides-realtime"

# Keyword arguments work with both google-cloud-pubsub 1.x and 2.x. In 2.x the
# first positional parameter is `request`, so positional names/topics break.
try:
    sub_client.create_subscription(name=subscription_name, topic=taxirides_topic)
except gexc.AlreadyExists:
    # Delete and recreate rather than reuse. Presumably this starts from a fresh
    # backlog instead of messages accumulated by an older subscription.
    sub_client.delete_subscription(subscription=subscription_name)
    sub_client.create_subscription(name=subscription_name, topic=taxirides_topic)

### Process data from subscription to cache

In [None]:
# Bounding box (degrees) used to filter the Pub/Sub taxi stream. The values point at
# the New York City area; the trailing comments hold an alternative box.
# NOTE(review): -74.747 lies well west of NYC (its western edge is around -74.26);
# this may be a typo for -74.047. Confirm.
LONGITUDE_RANGE = (-74.747, -73.969)  # (-74.07, -73.90)
LATITUDE_RANGE = (40.699, 40.720)  # (40.73, 40.77)

In [None]:
# Project the south-west and north-east corners of the stream's bounding box;
# zip() regroups the two (x, y) pairs into (x_min, x_max) and (y_min, y_max).
(x_min, x_max), (y_min, y_max) = zip(
    geographic_to_utm(longitude=LONGITUDE_RANGE[0], latitude=LATITUDE_RANGE[0]),
    geographic_to_utm(longitude=LONGITUDE_RANGE[1], latitude=LATITUDE_RANGE[1]),
)

# Plot ranges for the Bokeh mercator axes of the live dashboard.
MERCATOR_X_RANGE = (x_min, x_max)
MERCATOR_Y_RANGE = (y_min, y_max)

In [None]:
# Deliberate stop. It keeps "Run all" from continuing into the cells below, which
# start a streaming Beam pipeline on a live Pub/Sub subscription and an endless
# dashboard loop.
raise RuntimeError(
    "Stopping here on purpose: run the cells below manually to start the "
    "streaming pipeline and the live dashboard."
)

In [None]:
# Cancel the pipeline started by an earlier run of the pipeline cell, if any.
# NameError means `p_result` was never bound in this kernel (nothing to cancel).
try:
    p_result.cancel()
except NameError:
    pass

In [None]:
# Cache backed by the Pub/Sub topic "temp-2" in this project. The pipeline writes to
# it and the dashboard reads from it.
# NOTE(review): `mode="overwrite"` presumably discards earlier cache contents. Confirm
# against `caching.PubSubBasedCache`, which is not part of upstream Beam.
temp = caching.PubSubBasedCache(
    "projects/{}/topics/temp-2".format(project_id), mode="overwrite"
)

In [None]:
class ToList(beam.PTransform):
    """PTransform that gathers all elements of each window into a single list.

    Wraps ``CombineGlobally(ToListCombineFn())``. ``without_defaults()`` turns off
    the default output for empty input, which Beam requires when CombineGlobally
    runs under non-global windowing (as with the fixed windows used here).
    """

    def __init__(self, label='ToList'):  # pylint: disable=useless-super-delegation
        # Supplies a default step label when none is given.
        super(ToList, self).__init__(label)

    def expand(self, pcoll):
        to_list = beam.CombineGlobally(combiners.ToListCombineFn()).without_defaults()
        return pcoll | self.label >> to_list


In [None]:
class BuildRecordFn(beam.DoFn):
    """Pair each combined batch with the end of the window it belongs to.

    Emits exactly one ``(window_end, elements)`` tuple per input, where
    ``window_end`` is the window's end as a UTC ``datetime``.
    """

    def process(self, elements, window=beam.DoFn.WindowParam):
        # `window` here is the element's window (injected by Beam), not the
        # module-level `apache_beam.transforms.window` imported by the notebook.
        record = (window.end.to_utc_datetime(), elements)
        return [record]

In [None]:
# Streaming pipeline: read messages (with attributes) from the subscription created
# above, decode and parse them, keep those inside LONGITUDE_RANGE / LATITUDE_RANGE,
# add projected coordinates, and group them into 2-minute (2 * 60 s) fixed windows.
# Each window is collapsed into one list tagged with the window end and written to
# the `temp` cache.
# NOTE(review): DecodeTaxiMessage, load_json, SelectWithinGeographicRange and
# add_mercator_coords are not defined anywhere in this notebook as shown. They must
# come from elsewhere (a removed cell or an import). Confirm before running.
p = beam.Pipeline(runner=BundleBasedDirectRunner(), options=options)

out = (
    p
    | "Read" >> beam.io.ReadFromPubSub(subscription=subscription_name, with_attributes=True)
#     | "echo" >> beam.Map(lambda e: print(e) or e)
    | "Decode PubSub message" >> beam.ParDo(DecodeTaxiMessage())
    | "Load JSON" >> beam.Map(load_json)
    | "Filter coords" >> beam.ParDo(SelectWithinGeographicRange(LONGITUDE_RANGE, LATITUDE_RANGE))
    | "Add UTM coords" >> beam.Map(add_mercator_coords)
    | "Window" >> beam.WindowInto(window.FixedWindows(2 * 60))
    | "Combine" >> ToList()
    | 'AddWindowEndTimestamp' >> beam.ParDo(BuildRecordFn())
#     | "echo" >> beam.Map(lambda e: print(e) or e)
    | "Write" >> temp.writer()
)

# Keep the result handle so the "cancel" cell can stop this pipeline on a re-run.
p_result = p.run()

In [None]:
# Print the first two records read back from the cache.
# Note: this binds `row` (used by the next cell), shadowing `bokeh.layouts.row`.
for row in itertools.islice(temp.read(), 2):
    print(row)

In [None]:
# Format the first field of the last record printed above. The pipeline's
# BuildRecordFn emits (window_end, elements) tuples.
row[0].strftime("%Y-%m-%d %H:%M:%S")

### Interactive dashboard

In [None]:
import logging

# Hide everything below CRITICAL from the "google.auth._default" logger
# (credential-discovery messages).
logging.getLogger("google.auth._default").setLevel(logging.CRITICAL)

In [None]:
from bokeh.io import push_notebook, show, output_notebook
from bokeh.layouts import row
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure, show
from bokeh.tile_providers import Vendors, get_provider
from bokeh.models.annotations import Title

In [None]:
# This cell used to hold the incomplete expression `bokeh.plotting.`, apparently a
# tab-completion leftover. It was a SyntaxError, and `bokeh` itself is never bound
# by the `from bokeh... import` statements, so it has been removed.

In [None]:
output_notebook()

if GOOGLE_COLAB:
    print("Interactive plot does not work on colab yet!")

# range bounds supplied in web mercator coordinates
fig = figure(
    x_range=MERCATOR_X_RANGE,
    y_range=MERCATOR_Y_RANGE,
    x_axis_type="mercator",
    y_axis_type="mercator",
)
fig.add_tile(get_provider(Vendors.CARTODBPOSITRON))

source = ColumnDataSource(data=dict(x=[], y=[]))

fig.circle(x="x", y="y", size=2, fill_color="blue", fill_alpha=0.8, source=source)

handle = show(fig, notebook_handle=True)

# Append each cached window's points to the map and show the window end as the title.
# The loop variable is `record`, not `row`, so bokeh's `row` imported above stays intact.
# The existing title's text is updated instead of allocating a new Title model per record.
# NOTE(review): `source.stream` is called without `rollover`, so the number of plotted
# points grows without bound while this runs.
# NOTE(review): `while True` re-reads `temp.read()` whenever it is exhausted. If read()
# replays from the beginning, points will be duplicated. Confirm its semantics.
while True:
    for record in temp.read():
        window_end, taxi_events = record[0], record[1]
        fig.title.text = window_end.strftime("%Y-%m-%d %H:%M:%S")
        source.stream({"x": [e["x"] for e in taxi_events], "y": [e["y"] for e in taxi_events]})
        push_notebook(handle=handle)