In [0]:
from pyspark.ml.feature import Bucketizer
import pyspark.sql.functions as sf
import numpy as np
from pyspark.sql.window import Window
import math
from pyspark.sql.types import *
from pyspark.sql import DataFrame

In [0]:
def aggregate_by_zip_codes(
    trips_df: DataFrame,
    bin_size: int
) -> DataFrame:
    """Attach pickup-zip-bin statistics to every trip.

    Pickup zip codes are grouped into fixed-width bins of ``bin_size``. The
    bins span from the minimum ``pickup_zip`` floored to the nearest thousand
    up to the maximum ``pickup_zip`` rounded up to the nearest thousand.
    Per-bin statistics are computed with window functions, so the result
    keeps one row per input trip (it is NOT grouped).

    Args:
        trips_df: Trips with a numeric ``pickup_zip`` column and a
            ``trip_duration_mins`` column.
        bin_size: Width of each zip-code bin; must be positive.

    Returns:
        All columns of ``trips_df`` plus:
          * ``pu_zip_ranks`` - 0-based bin index from Bucketizer (null when
            ``pickup_zip`` is null),
          * ``trip_count`` - number of trips in the same bin (0 for rows
            whose bin is null),
          * ``avg_duration_mins`` - mean ``trip_duration_mins`` over the bin,
            rounded to a whole number,
          * ``zip_bin_lower_bound`` / ``zip_bin_upper_bound`` - bin edges.

    Raises:
        ValueError: if ``bin_size`` is not positive, or if ``pickup_zip`` has
            no non-null values to bin.
    """
    if bin_size <= 0:
        raise ValueError(f"bin_size must be positive, got {bin_size}")

    # A single aggregation pass; an ungrouped aggregate always yields one row.
    bounds = (
        trips_df
        .select(
            sf.min(sf.col("pickup_zip")).alias("min_pu_zip"),
            sf.max(sf.col("pickup_zip")).alias("max_pu_zip")
        )
        .first()
    )

    if bounds is None or bounds["min_pu_zip"] is None:
        raise ValueError("trips_df has no non-null pickup_zip values to bin")

    # Floor the minimum / round the maximum up to the nearest thousand so the
    # bins start and end on round zip-code boundaries.
    min_pu_zip = math.floor(bounds["min_pu_zip"] / 1000) * 1000
    max_pu_zip = math.ceil(bounds["max_pu_zip"] / 1000) * 1000

    split_list = [float(i) for i in np.arange(min_pu_zip, max_pu_zip + bin_size, bin_size)]
    # Bucketizer requires at least 3 split points (2 buckets). A narrow zip
    # range with a large bin_size can produce fewer, so extend the same
    # arithmetic progression with empty trailing bins.
    while len(split_list) < 3:
        split_list.append(split_list[-1] + bin_size)
    num_buckets = len(split_list) - 1

    bucketizer = Bucketizer(
        splits=split_list,
        inputCol="pickup_zip",
        outputCol="pu_zip_ranks"
    )

    # Bucketizer outputs the bucket index (rank), not the bucket range.
    # Bucket k covers split_list[k] .. split_list[k + 1], and the splits form
    # an arithmetic progression, so the upper edge is
    # split_list[0] + (k + 1) * bin_size. This is computed natively instead of
    # through a Python UDF. With handleInvalid="keep", invalid inputs (e.g.
    # NaN) land in the extra index num_buckets; those get a null bin.
    zip_bins_col = sf.when(
        sf.col("pu_zip_ranks") < num_buckets,
        sf.lit(split_list[0]) + (sf.col("pu_zip_ranks") + 1) * bin_size
    )

    # Bucketize only the distinct non-null zips, then join back so trips with
    # a null pickup_zip are kept (with null rank/bin columns).
    pu_ranks_df = (
        bucketizer
        .setHandleInvalid("keep")
        .transform(trips_df.select("pickup_zip").dropna())
        .dropDuplicates()
        .withColumn("zip_bins", zip_bins_col)
    )

    trips_df_zip_ranks = (
        trips_df.alias("lt")
        .join(
            pu_ranks_df.alias("rt"),
            [sf.col("lt.pickup_zip") == sf.col("rt.pickup_zip")],
            "leftouter"
        )
        .select(
            sf.col("lt.*"), sf.col("rt.pu_zip_ranks"), sf.col("rt.zip_bins")
        )
    )

    win = Window.partitionBy(sf.col("pu_zip_ranks"))

    trips_df_agg = (
        trips_df_zip_ranks
        # count() ignores nulls, so rows in the null-rank partition get 0.
        .withColumn("trip_count", sf.count(sf.col("pu_zip_ranks")).over(win))
        .withColumn(
            "avg_duration_mins",
            sf.round(sf.avg(sf.col("trip_duration_mins")).over(win))
        )
        .withColumn(
            "zip_bin_lower_bound",
            sf.col("zip_bins") - bin_size
        )
        .withColumn("zip_bin_upper_bound", sf.col("zip_bins"))
        .drop("zip_bins")
    )

    return trips_df_agg