# StreamBasedCache

## Install dependencies

In [1]:
# Flag telling the rest of the notebook whether it is running on Google Colab,
# determined by whether the `google.colab` package can be imported.
try:
    import google.colab
except ImportError:
    GOOGLE_COLAB = False
else:
    GOOGLE_COLAB = True

In [2]:
# Colab only: install the libsnappy development package with apt
# (presumably needed to build python-snappy, installed in the next cell).
if GOOGLE_COLAB:
    !sudo apt-get -yqq install libsnappy-dev

In [3]:
# Colab only: install the python-snappy and Faker packages from PyPI.
if GOOGLE_COLAB:
    !pip install -q python-snappy Faker

In [4]:
# Colab only: install apache_beam (with the gcp extra) from the
# `feature/streambasedcache` branch of the ostrokach/beam fork.
if GOOGLE_COLAB:
    !pip install "git+https://github.com/ostrokach/beam.git@feature/streambasedcache#egg=apache_beam[gcp]&subdirectory=sdks/python"

## Imports

In [5]:
import copy
import itertools
import json
import logging
import os
import pickle
import tempfile
import uuid

import apache_beam as beam
import numpy as np
import pandas as pd
import tqdm
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.options.pipeline_options import GoogleCloudOptions, PipelineOptions
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.interactive import caching
from apache_beam.transforms.ptransform import ptransform_fn
from faker import Faker
from google.api_core import exceptions as gexc
from google.cloud import pubsub

import pyproj
from bokeh.io import output_notebook, push_notebook, show
from bokeh.layouts import row
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure

## Parameters

In [6]:
# Colab only: run Colab's user authentication flow
# (presumably so that the Google Cloud clients used below are authorized).
if GOOGLE_COLAB:
    from google.colab import auth
    auth.authenticate_user()

In [7]:
#@title Google Cloud Project Info { display-mode: "form" }
# Project id and GCS temp location; both are passed to PipelineOptions below,
# and the project id is also used to build Pub/Sub resource names.
project_id = "strokach-playground" #@param {type:"string"}
gcs_temp_location = "gs://strokach/dataflow_temp" #@param {type:"string"}

In [8]:
# Pipeline options for a streaming pipeline in the configured project.
options = PipelineOptions(
    project=project_id,
    streaming=True,
    temp_location=gcs_temp_location,
)
# Last expression of the cell: shows the resulting option values.
options.display_data()



{'project': 'strokach-playground',
 'streaming': True,
 'temp_location': 'gs://strokach/dataflow_temp'}

In [9]:
# Best effort: enable IPython's autoreload extension in mode 2; if that
# raises, just report it and carry on.
try:
    %load_ext autoreload
    %autoreload 2
except Exception:
    print("No autoreload")

## Functions

### Helper functions

In [10]:
class AverageFn(beam.CombineFn):
    """Average the passenger count and the timestamp of ride dictionaries.

    The accumulator is a ``(passenger_count_sum, timepoint_sum, count)`` tuple.
    The output is ``(passenger_count_avg, timepoint_avg)``, where the time
    point is expressed in seconds since the Unix epoch (UTC). Both averages
    are ``NaN`` when no element was combined.
    """

    # Both accepted timestamp formats end with the literal "-04:00", so the
    # parsed (naive) datetime is wall-clock time four hours behind UTC.
    _UTC_OFFSET_SECONDS = -4 * 3600

    def create_accumulator(self):
        """Return the empty accumulator ``(0.0, 0.0, 0)``."""
        return (0.0, 0.0, 0)

    def add_input(self, sum_count, input):
        """Fold one element into the accumulator.

        Args:
            sum_count: Accumulator tuple as described on the class.
            input: Mapping with a numeric ``"passenger_count"`` and a
                ``"timestamp"`` string shaped like
                ``YYYY-MM-DDTHH:MM:SS[.ffffff]-04:00``.

        Returns:
            The updated ``(passenger_count_sum, timepoint_sum, count)`` tuple.

        Raises:
            ValueError: If the timestamp matches neither accepted format.
        """
        from datetime import datetime

        (passenger_count_sum, timepoint_sum, count) = sum_count

        try:
            timestamp = datetime.strptime(
                input["timestamp"], "%Y-%m-%dT%H:%M:%S.%f-04:00"
            )
        except ValueError:
            # Fall back to the variant without fractional seconds.
            timestamp = datetime.strptime(input["timestamp"], "%Y-%m-%dT%H:%M:%S-04:00")

        # Convert explicitly instead of strftime("%s"): that directive is a
        # platform extension and reads the naive datetime as *local* time,
        # ignoring the -04:00 offset matched above. Subtracting the (negative)
        # offset turns the wall-clock value into UTC epoch seconds, and the
        # fractional seconds are kept.
        naive_seconds = (timestamp - datetime(1970, 1, 1)).total_seconds()
        epoch_seconds = naive_seconds - self._UTC_OFFSET_SECONDS

        passenger_count_sum += input["passenger_count"]
        timepoint_sum += epoch_seconds
        count += 1
        return passenger_count_sum, timepoint_sum, count

    def merge_accumulators(self, accumulators):
        """Sum the accumulators component-wise."""
        passenger_count_sums, timepoint_sums, counts = zip(*accumulators)
        return sum(passenger_count_sums), sum(timepoint_sums), sum(counts)

    def extract_output(self, sum_count):
        """Return ``(passenger_count_avg, timepoint_avg)``; NaN when empty."""
        (passenger_count_sum, timepoint_sum, count) = sum_count
        passenger_count_avg = passenger_count_sum / count if count else float("NaN")
        timepoint_avg = timepoint_sum / count if count else float("NaN")
        return passenger_count_avg, timepoint_avg

In [11]:
def update_plot(values):
    """Stream a small batch of points from *values* into the notebook plot.

    Elements are read until 11 of them have supplied both an ``"x"`` and a
    ``"y"`` value (the loop stops once the count exceeds 10). Elements that
    cannot be indexed that way are appended to the module-level ``output``
    list instead. The collected points are streamed into the module-level
    ``source`` and the notebook handle ``t`` is refreshed.

    Args:
        values: Iterable of elements, normally mappings with ``"x"``/``"y"``.
    """
    x_lst = []
    y_lst = []
    for element in values:
        try:
            # Read both coordinates before storing either one, so that an
            # element with "x" but no "y" cannot leave the two columns with
            # different lengths.
            x, y = element["x"], element["y"]
        except (KeyError, TypeError):
            output.append(element)
            continue
        x_lst.append(x)
        y_lst.append(y)
        if len(x_lst) > 10:
            break
    source.stream({"x": x_lst, "y": y_lst})
    push_notebook(handle=t)

In [12]:
def increment_counter(element):
    """Add one to the module-level ``counter`` and return *element* unchanged.

    ``counter`` must already exist as a global before the first call.
    """
    global counter
    counter += 1
    return element

In [13]:
def tee_to_output(element):
    """Append *element* to the module-level ``output`` and return it unchanged."""
    output.append(element)
    return element

In [14]:
def load_json(element):
    """Parse a JSON object and return a dict using the native ``str`` type.

    On Python 2, top-level keys and values that are ``unicode`` are converted
    to UTF-8 encoded ``str``; on Python 3 ``json`` already yields ``str`` and
    nothing is converted. Nested containers are left untouched.

    Args:
        element: JSON text whose top-level value is an object.

    Returns:
        A new ``dict`` with the converted keys and values.
    """
    try:
        text_type = unicode  # Python 2 only
    except NameError:
        text_type = None  # Python 3: there is no separate unicode type

    def to_native(value):
        # Encoding (rather than str()) also works for non-ASCII text, where
        # str(u"...") raises UnicodeEncodeError on Python 2.
        if text_type is not None and isinstance(value, text_type):
            return value.encode("utf-8")
        return value

    return {to_native(k): to_native(v) for k, v in json.loads(element).items()}


assert load_json(u'{"a": 10, "b": "20"}') == {"a": 10, "b": "20"}

In [15]:
def dump_json(element):
    """Serialize *element* to JSON and return it as UTF-8 encoded bytes."""
    element_str = json.dumps(element).encode("utf-8")
    return element_str


# The result is bytes, so compare against a bytes literal: this holds on both
# Python 2 (where bytes is str) and Python 3 (where bytes never equals text).
assert dump_json({"a": 10, "b": "20"}) == b'{"a": 10, "b": "20"}'

In [16]:
def geographic_to_utm(longitude, latitude):
    """Project a longitude/latitude pair from EPSG:4326 to EPSG:3857.

    Note that, despite the ``utm`` in the name, the target projection is
    EPSG:3857 (the web mercator coordinates used for the plot ranges).

    Returns:
        The projected ``(x, y)`` pair.
    """
    from pyproj import Proj, transform

    source_projection = Proj(init='epsg:4326')
    target_projection = Proj(init='epsg:3857')
    easting, northing = transform(
        source_projection, target_projection, longitude, latitude
    )
    return easting, northing

### DoFns

In [17]:
class DecodeTaxiMessage(beam.DoFn):
    """Emit each message payload stamped with the time in its ``ts`` attribute."""

    def process(self, message):
        """Yield ``message.data`` as a ``TimestampedValue``.

        Args:
            message: Object exposing ``data`` and an ``attributes`` mapping
                whose ``"ts"`` entry is a timestamp string with a UTC offset.

        Yields:
            A ``beam.window.TimestampedValue`` wrapping the payload, with the
            timestamp taken from the ``ts`` attribute.
        """
        from datetime import datetime

        # Import the submodule explicitly: a bare `import dateutil` does not
        # guarantee that `dateutil.parser` has been loaded.
        import dateutil.parser
        import pytz
        from apache_beam.utils.timestamp import Timestamp

        # Parse the message's own timestamp (previously a hard-coded literal
        # was parsed here, so every element received the same event time).
        event_time = dateutil.parser.parse(message.attributes["ts"])
        epoch = datetime.utcfromtimestamp(0).replace(tzinfo=pytz.UTC)
        since_epoch = event_time.astimezone(pytz.UTC) - epoch
        timestamp = Timestamp(seconds=since_epoch.total_seconds())

        yield beam.window.TimestampedValue(message.data, timestamp)


message = PubsubMessage(
    data=b"hello", attributes={"ts": "2019-06-27T20:36:35.4972-04:00"}
)
assert(next(DecodeTaxiMessage().process(message)).value == b"hello")
assert(next(DecodeTaxiMessage().process(message)).timestamp.micros == 1561682195497200)
# A second, different timestamp proves the attribute is actually being read.
assert (
    next(
        DecodeTaxiMessage().process(
            PubsubMessage(data=b"hello", attributes={"ts": "2019-06-27T20:36:36-04:00"})
        )
    ).timestamp.micros
    == 1561682196000000
)

In [18]:
class SelectWithinGeographicRange(beam.DoFn):
    """Pass through only elements located inside a longitude/latitude box.

    Both ranges are ``(lower, upper)`` pairs and both bounds are inclusive.
    """

    def __init__(self, longitude_range, latitude_range):
        self.longitude_range = longitude_range
        self.latitude_range = latitude_range

    def process(self, element):
        """Return ``[element]`` when it lies inside the box, else ``[]``."""
        # Longitude is tested first; latitude is only looked at when it passes.
        if not (
            self.longitude_range[0] <= element["longitude"] <= self.longitude_range[1]
        ):
            return []
        if not (
            self.latitude_range[0] <= element["latitude"] <= self.latitude_range[1]
        ):
            return []
        return [element]


# Inside the box (lower bounds are inclusive).
el = {"longitude": 0, "latitude": 0}
assert SelectWithinGeographicRange((0, 1), (0, 1)).process(el) == [el]

# On the upper latitude bound (inclusive as well).
el = {"longitude": 0, "latitude": 1}
assert SelectWithinGeographicRange((0, 1), (0, 1)).process(el) == [el]

# Just below the lower latitude bound.
el = {"longitude": 0, "latitude": -0.1}
assert SelectWithinGeographicRange((0, 1), (0, 1)).process(el) == []

In [19]:
class Limit(beam.CombineFn):
    """Collect at most ``limit`` of the combined elements into a list."""

    def __init__(self, limit=1000):
        """Store the maximum number of elements to keep.

        Args:
            limit: Upper bound on the length of the output list.
        """
        # Honor the argument (it was previously ignored in favor of 1000).
        self.limit = limit

    def create_accumulator(self):
        """Return a new empty list accumulator."""
        return []

    def add_input(self, lst, input):
        """Append *input* to *lst* unless it already holds ``limit`` items."""
        if len(lst) < self.limit:
            lst.append(input)
        return lst

    def merge_accumulators(self, accumulators):
        """Concatenate the accumulators in order and truncate to ``limit``."""
        merged = [item for accumulator in accumulators for item in accumulator]
        return merged[:self.limit]

    def extract_output(self, lst):
        """Return the accumulated list as the result."""
        return lst

In [20]:
def add_mercator_coords(element):
    """Return a copy of *element* with projected ``"x"`` and ``"y"`` added.

    The ``"longitude"`` and ``"latitude"`` entries are projected from
    EPSG:4326 to EPSG:3857. The input mapping itself is left unmodified,
    because Beam requires that a transform does not mutate its input element.

    Args:
        element: Mapping with ``"longitude"`` and ``"latitude"`` entries.

    Returns:
        A new ``dict`` holding the original entries plus ``"x"`` and ``"y"``.
    """
    # NOTE(review): duplicates the module-level helper of the same name,
    # presumably so this function is self-contained when shipped to workers.
    def geographic_to_utm(longitude, latitude):
        from pyproj import Proj, transform

        x, y = transform(
            Proj(init="epsg:4326"), Proj(init="epsg:3857"), longitude, latitude
        )
        return x, y

    projected = dict(element)
    projected["x"], projected["y"] = geographic_to_utm(
        element["longitude"], element["latitude"]
    )
    return projected

## Workflow

In [21]:
# (lower, upper) longitude and latitude bounds; used by the coordinate filter
# in the pipeline and, once projected, as the axis ranges of the plot.
LONGITUDE_RANGE = (-74.07, -73.90)
LATITUDE_RANGE = (40.74, 40.76)

In [22]:
# Project the lower and upper corners of the bounding box to EPSG:3857.
x_min, y_min = geographic_to_utm(longitude=LONGITUDE_RANGE[0], latitude=LATITUDE_RANGE[0])
x_max, y_max = geographic_to_utm(longitude=LONGITUDE_RANGE[1], latitude=LATITUDE_RANGE[1])

# Axis ranges for the plot, in projected coordinates.
MERCATOR_X_RANGE = (x_min, x_max)
MERCATOR_Y_RANGE = (y_min, y_max)

### Create a subscription (`taxirides-realtime-sub`) to the public `taxirides-realtime` topic

In [23]:
# Pub/Sub subscriber client used below to create and delete the subscription.
sub_client = pubsub.SubscriberClient()

In [24]:
# Full resource name of the subscription, inside the user's own project.
subscription_name = "projects/{}/subscriptions/taxirides-realtime-sub".format(project_id)

# Subscribe to the public taxi rides topic. If a subscription with this name
# already exists, delete it and create it again (presumably to start from a
# fresh subscription without a backlog -- confirm).
try:
    sub_client.create_subscription(
        subscription_name,
        "projects/pubsub-public-data/topics/taxirides-realtime",
    )
except gexc.AlreadyExists:
    sub_client.delete_subscription(subscription_name)
    sub_client.create_subscription(
        subscription_name,
        "projects/pubsub-public-data/topics/taxirides-realtime",
    )

### Process data from subscription to cache

In [25]:
# Cache object for the `temp` topic of the project; the pipeline writes to it
# and the dashboard reads from it.
# NOTE(review): mode="overwrite" presumably replaces an existing topic -- confirm.
temp = caching.PubSubBasedCache(
    "projects/{}/topics/temp".format(project_id), mode="overwrite"
)

In [26]:
# Build the pipeline on the bundle-based direct runner with the options above.
p = beam.Pipeline(runner=BundleBasedDirectRunner(), options=options)

# Read messages (with attributes) from the subscription, timestamp and decode
# them, keep only rides inside the bounding box, add projected coordinates and
# write the result to the cache.
out = (
    p
    | "Read" >> beam.io.ReadFromPubSub(subscription=subscription_name, with_attributes=True)
    | "Decode PubSub message" >> beam.ParDo(DecodeTaxiMessage())
    | "Load JSON" >> beam.Map(load_json)
    | "Filter coords" >> beam.ParDo(SelectWithinGeographicRange(LONGITUDE_RANGE, LATITUDE_RANGE))
    # Despite the step label, the coordinates added here are EPSG:3857.
    | "Add UTM coords" >> beam.Map(add_mercator_coords)
    | "Write" >> temp.writer()
)

# Start the pipeline and keep the result handle.
p_result = p.run()

In [None]:
# Fetch a single element from the cache to check that data is arriving.
next(temp.read())

### Interactive dashboard

In [28]:
from bokeh.io import push_notebook, show, output_notebook
from bokeh.layouts import row
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure, show
from bokeh.tile_providers import Vendors, get_provider

In [None]:
output_notebook()

# Map figure; the range bounds are supplied in web mercator coordinates.
# NOTE: this rebinds `p`, which previously referred to the Beam pipeline.
p = figure(
    x_range=MERCATOR_X_RANGE,
    y_range=MERCATOR_Y_RANGE,
    x_axis_type="mercator",
    y_axis_type="mercator",
)
p.add_tile(get_provider(Vendors.CARTODBPOSITRON))

# Data source that receives the streamed ride positions.
source = ColumnDataSource(data=dict(x=[], y=[]))

p.circle(x="x", y="y", size=2, fill_color="blue", fill_alpha=0.8, source=source)

# Keep the notebook handle so that the displayed plot can be updated later.
t = show(p, notebook_handle=True)

# Runs until the kernel is interrupted.
while True:
    for element in temp.read():
        source.stream({"x": [element["x"]], "y": [element["y"]]})
        # Without this push the new point stays in the Python-side data
        # source and the plot already shown in the notebook never changes.
        push_notebook(handle=t)

KeyboardInterrupt: 

Exception KeyboardInterrupt in 'grpc._cython.cygrpc.ReceiveInitialMetadataOperation.un_c' ignored


In [None]:
# Scratch cell: prints a fixed string.
print("hello")