In [None]:
## Import local library
import os
from datetime import datetime

## Import GeoPandas
import geopandas as gpd
import matplotlib.pyplot as plt

## Import PySpark
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import col, expr, broadcast, udf, lit, struct

## Import Apache Sedona
from sedona.register import SedonaRegistrator
from sedona.utils import SedonaKryoRegistrator, KryoSerializer
from sedona.core.formatMapper.shapefileParser import ShapefileReader
from sedona.utils.adapter import Adapter

# Define spark session if not defined yet
There is no need to define `spark` yourself when the notebook runs in an external cloud environment that already provides it.

In [None]:
try:
    # Reuse the session when the hosting environment already defines one.
    spark
except NameError:
    # No session in scope: build a local one that uses Sedona's Kryo
    # serializer/registrator and anonymous S3A credentials.
    spark = (
        SparkSession.builder
        .appName('appName')
        .master('local[*]')
        .config("spark.serializer", KryoSerializer.getName)
        .config("spark.kryo.registrator", SedonaKryoRegistrator.getName)
        .config("fs.s3a.aws.credentials.provider", "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider")
        .config("spark.hadoop.fs.s3a.aws.credentials.provider", "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider")
        .getOrCreate()
    )

In [None]:
# Register Sedona's functions and types on the active session.
SedonaRegistrator.registerAll(spark)

sc = spark.sparkContext

# Set Sedona's global charset system property to utf8.
sc.setSystemProperty("sedona.global.charset", "utf8")

# Use the anonymous credentials provider for s3a:// access on the live
# Hadoop configuration as well.
sc._jsc.hadoopConfiguration().set(
    "fs.s3a.aws.credentials.provider",
    "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider",
)

def delete_path(sc, path):
    """Recursively delete ``path`` through the Hadoop FileSystem API.

    The filesystem is resolved from the path itself rather than taken from
    the default filesystem, so a path that carries a scheme (for example
    ``s3a://bucket/key``) is deleted on the filesystem it names. A path
    without a scheme still resolves to the configured default filesystem.

    Args:
        sc: SparkContext whose JVM gateway (``_jvm``) and Hadoop
            configuration (``_jsc.hadoopConfiguration()``) are used.
        path: Location to delete, as a string.
    """
    hadoop_conf = sc._jsc.hadoopConfiguration()
    target = sc._jvm.org.apache.hadoop.fs.Path(path)
    # FileSystem.get(conf) would always hand back the default filesystem and
    # reject paths that belong to a different scheme; ask the path instead.
    fs = target.getFileSystem(hadoop_conf)
    # Second argument True requests a recursive delete.
    fs.delete(target, True)

# Use the prefix in all your EMR paths

EMR requires all paths to be relative. If you run on EMR, use the variable below as the prefix for every path.

In [None]:
from pathlib import Path

# Prefix every output path with the user's home directory (plus a trailing
# slash) when ENV_WB is set to 'true'; otherwise use no prefix at all.
if os.environ.get('ENV_WB', 'false') == 'true':
    PATH_PREFIX = str(Path.home()) + '/'
else:
    PATH_PREFIX = ''

print(PATH_PREFIX)

# Load taxi pickup records to Sedona

In [None]:
import shutil

# Read the taxi CSV (with header), build a point geometry from the pickup
# coordinates, and keep only the rows where that geometry is not null.
taxidf = (
    spark.read.format('csv')
    .option("header", "true")
    .option("delimiter", ",")
    .load("s3a://wherobots-examples/data/nyc-taxi-data.csv")
    .selectExpr(
        'ST_Point(CAST(Start_Lon AS Decimal(24,20)), CAST(Start_Lat AS Decimal(24,20))) AS pickup',
        'Trip_Pickup_DateTime',
        'Payment_Type',
        'Fare_Amt',
    )
    .filter(col("pickup").isNotNull())
)
taxidf.show()
taxidf.createOrReplaceTempView('taxiDf')

# Convert to a spatial RDD keyed on the pickup geometry and export it.
taxiRdd = Adapter.toSpatialRdd(taxidf, "pickup")
taxi_geojson_path = PATH_PREFIX + "taxi-pickup.geojson"
# Remove any previous output, both on the local disk and through the
# Hadoop FileSystem, before writing.
shutil.rmtree(taxi_geojson_path, ignore_errors=True)
delete_path(sc, taxi_geojson_path)
taxiRdd.saveAsGeoJSON(taxi_geojson_path)

# Load Zones to Sedona

In [None]:
# The zone CSV is read without a header option, so its columns are addressed
# by Spark's default names: _c0 is parsed as WKT geometry, _c1 is kept as
# the zip code.
zoneDf = (
    spark.read.format('csv')
    .option("delimiter", ",")
    .load("s3a://wherobots-examples/data/TIGER2018_ZCTA5.csv")
    .selectExpr('ST_GeomFromWKT(_c0) as zone', '_c1 as zipcode')
)
zoneDf.show()
zoneDf.createOrReplaceTempView('zoneDf')

# Visualize Sedona Dataframes on maps

In [None]:
# Collect both Spark DataFrames to pandas and wrap them as GeoDataFrames,
# naming the column that holds each frame's geometry.
zoneGpd = gpd.GeoDataFrame(data=zoneDf.toPandas(), geometry="zone")
taxiGpd = gpd.GeoDataFrame(data=taxidf.toPandas(), geometry="pickup")

In [None]:
# Base layer: zone polygons.
zone = zoneGpd.plot(zorder=1, edgecolor='black', color='yellow')
zone.set(xlabel='Longitude (degrees)', ylabel='Latitude (degrees)')

# Local view
zone.set(xlim=(-74.1, -73.8), ylim=(40.65, 40.9))

# Overlay the pickups on the same axes; the very low alpha keeps dense
# areas readable.
taxi = taxiGpd.plot(ax=zone, zorder=3, color='red', alpha=0.01)

In [None]:
# NOTE(review): `%matplot` is not a standard IPython magic — presumably the
# Sparkmagic/EMR magic that renders the current pyplot figure; confirm that
# the target kernel provides it.
%matplot plt

# Find taxis in each zone

In [None]:
# Spatial join: pair every pickup point with the zone polygon containing it.
# The result is marked as cached *before* the first action so the work done
# for show() can be reused by later queries instead of being thrown away.
taxiVsZone = spark.sql('SELECT zone, zipcode, pickup, Fare_Amt FROM zoneDf, taxiDf WHERE ST_Contains(zone, pickup)').cache()
taxiVsZone.show()
# Expose the joined rows to Spark SQL under the name used by later queries.
taxiVsZone.createOrReplaceTempView("taxiVsZone")

# Count taxis per zone

In [None]:
import shutil

# Aggregate the joined rows: number of pickups and mean fare for every zone.
taxiPerZone = spark.sql(
    "SELECT zone, zipcode, count(*) as count, avg(Fare_Amt) as avg_fare FROM taxiVsZone c GROUP BY zone, zipcode"
)
taxiPerZone.show()

# Convert to a spatial RDD keyed on the zone geometry and export it.
taxiPerZoneRdd = Adapter.toSpatialRdd(taxiPerZone, "zone")
zone_geojson_path = PATH_PREFIX + "taxi-per-zone.geojson"
# Remove any previous output, both on the local disk and through the
# Hadoop FileSystem, before writing.
shutil.rmtree(zone_geojson_path, ignore_errors=True)
delete_path(sc, zone_geojson_path)
taxiPerZoneRdd.saveAsGeoJSON(zone_geojson_path)

# Visualize the result on a map

In [None]:
# Collect the per-zone aggregates to pandas and wrap them as a GeoDataFrame
# whose geometry is the zone polygon.
gdf = gpd.GeoDataFrame(taxiPerZone.toPandas(), geometry="zone")

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
fig, ax = plt.subplots(1, 1)
# Reserve a slim axes to the right of the map for the colour bar.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)

# Choropleth: colour each zone by its pickup count.
result = gdf.plot(
    column="count",
    legend=True,
    cmap='OrRd',
    cax=cax,
    ax=ax
)

# Label the axes the same way as the zone/pickup map so the figure can be
# read on its own.
result.set_xlabel('Longitude (degrees)')
result.set_ylabel('Latitude (degrees)')

# Local view
result.set_xlim(-74.1, -73.8)
result.set_ylim(40.65, 40.9)

In [None]:
# NOTE(review): `%matplot` is not a standard IPython magic — presumably the
# Sparkmagic/EMR magic that renders the current pyplot figure; confirm that
# the target kernel provides it.
%matplot plt