In [47]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructField, StructType, IntegerType, FloatType, DoubleType, StringType
from sedona.spark import *
from pyspark.sql.functions import col, row_number, desc
from pyspark.sql.window import Window
import time

# Initialize spark session
spark = SparkSession \
    .builder \
    .appName("Dataframe query 1 execution (no UDF)") \
    .getOrCreate()

# Initialize sedone context
sedona = SedonaContext.create(spark)

# Load police stations
police_station_schema = StructType([
    StructField("X", DoubleType()),
    StructField("Y", DoubleType()),
    StructField("FID", IntegerType()),
    StructField("Division", StringType()),
    StructField("Location", StringType()),
    StructField("PREC", IntegerType()),
])

police_stations_df = spark.read.csv("s3://initial-notebook-data-bucket-dblab-905418150721/project_data/LA_Police_Stations.csv", \
    header=True, \
    schema=police_station_schema)

# Create geometry row for coords
police_stations_df = police_stations_df.withColumn("station_geom", ST_Point("X", "Y"))

# Load crime data
crime_schema = StructType([
    StructField("DR_NO", IntegerType()),
    StructField("Date Rptd", StringType()),
    StructField("DATE OCC", StringType()),
    StructField("TIME OCC", IntegerType()),
    StructField("AREA", IntegerType()),
    StructField("AREA NAME", StringType()),
    StructField("Rpt Dist No", IntegerType()),
    StructField("Part 1-2", IntegerType()),
    StructField("Crm Cd", IntegerType()),
    StructField("Crm Cd Desc", StringType()),
    StructField("Mocodes", StringType()),
    StructField("Vict Age", IntegerType()),
    StructField("Vict Sex", StringType()),
    StructField("Vict Descent", StringType()),
    StructField("Premis Cd", IntegerType()),
    StructField("Premis Desc", StringType()),
    StructField("Weapon Used Cd", IntegerType()),
    StructField("Weapon Desc", StringType()),
    StructField("Status", StringType()),
    StructField("Status Descent", StringType()),
    StructField("Crm Cd 1", IntegerType()),
    StructField("Crm Cd 2", IntegerType()),
    StructField("Crm Cd 3", IntegerType()),
    StructField("Crm Cd 4", IntegerType()),
    StructField("LOCATION", StringType()),
    StructField("Cross Street", StringType()),
    StructField("LAT", DoubleType()),
    StructField("LON", DoubleType()),
])

data1 = spark.read.csv("s3://initial-notebook-data-bucket-dblab-905418150721/project_data/LA_Crime_Data/LA_Crime_Data_2010_2019.csv", \
    header=True, \
    schema=crime_schema)

data2 = spark.read.csv("s3://initial-notebook-data-bucket-dblab-905418150721/project_data/LA_Crime_Data/LA_Crime_Data_2020_2025.csv", \
    header=True, \
    schema=crime_schema)

crime_df = data1.union(data2)

# Create geometry row for coords
crime_df = crime_df.withColumn("crime_geom", ST_Point("LON", "LAT"))

# Start timer
start = time.time()

# Get the cartesian product of crimes x police stations
# then compute the distance of each crime from each station
cartesian_prod = crime_df.crossJoin(police_stations_df).select("DR_NO","Division","crime_geom","station_geom")
cartesian_prod_with_dist = cartesian_prod.withColumn("Distance",ST_DistanceSphere("crime_geom","station_geom")).drop("station_geom")

# Partition cartesian product over crimes to find the nearest station fro each
window = Window.partitionBy("DR_NO").orderBy(col("Distance"))
closest_station_df = cartesian_prod_with_dist.withColumn("min_dist",row_number().over(window)).filter(col("min_dist") <= 1)
closest_station_df = closest_station_df.drop("min_dist").drop("crime_geom")

# Get the average distance and total number of crimes per station
avg_crimes_per_station = closest_station_df.groupBy(col("Division")).avg("Distance")
tot_crimes_per_station = closest_station_df.groupBy(col("Division")).count()

# Join the above results to get the final one
results = avg_crimes_per_station.join(tot_crimes_per_station, avg_crimes_per_station.Division==tot_crimes_per_station.Division)\
    .orderBy(desc("count"))
results.show(21)

end = time.time()
print("Elapsed time: ",end-start)
results.explain(mode="formatted")

# SQL Implementation
# Create views
start = time.time()
police_stations_df.createOrReplaceTempView("stations")
crime_df.createOrReplaceTempView("crimes")

# First query: calculate distance of each crime from each station
# then partition over crimes and find the nearest station
query1 = "SELECT Division, Distance FROM \
    (SELECT *, ROW_NUMBER() OVER (PARTITION BY DR_NO ORDER BY Distance) AS row_num FROM \
        (SELECT DR_NO, Division,\
        ST_DistanceSphere(ST_Point(LON, LAT),ST_Point(X, Y)) AS Distance \
        FROM crimes,stations) AS res1 \
    ) WHERE row_num=1"

intermediate_res = spark.sql(query1)
intermediate_res.createOrReplaceTempView("closest_stations")

# Second query: get average distance and total count of crimes per station
query2 = "SELECT Division, ROUND(AVG(Distance),2) AS average_distance, COUNT(*) AS count \
    FROM closest_stations GROUP BY Division ORDER BY count DESC"
res = spark.sql(query2)
res.show(21)

end = time.time()
print("Elapsed time: ",end-start)
res.explain(mode="formatted")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),â€¦

+----------------+------------------+----------------+------+
|        Division|     avg(Distance)|        Division| count|
+----------------+------------------+----------------+------+
|       HOLLYWOOD| 2073.575473968219|       HOLLYWOOD|224124|
|        VAN NUYS| 2939.256360612656|        VAN NUYS|208129|
|       SOUTHWEST|2191.4685251016463|       SOUTHWEST|189119|
|        WILSHIRE| 2593.298374272885|        WILSHIRE|186383|
|     77TH STREET| 1717.114916334095|     77TH STREET|170620|
| NORTH HOLLYWOOD| 2642.499207526158| NORTH HOLLYWOOD|168096|
|         OLYMPIC|1728.9319104161402|         OLYMPIC|162805|
|         PACIFIC| 3853.497411898416|         PACIFIC|162027|
|         CENTRAL| 993.3242673630771|         CENTRAL|154689|
|         RAMPART|1534.2201910926965|         RAMPART|153204|
|       SOUTHEAST| 2443.914918878561|       SOUTHEAST|143803|
|     WEST VALLEY|3021.5716977222737|     WEST VALLEY|136622|
|        FOOTHILL|4260.0997597283695|        FOOTHILL|132482|
|       