In [1]:
import findspark
findspark.init()
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import IntegerType, FloatType

In [20]:
spark = SparkSession.builder\
      .master("local[2]")\
      .appName("SparkFirst")\
      .config("spark.executor.memory", "32g")\
      .config("spark.driver.memory", "32g")\
      .config("spark.executor.cores", 4)\
      .config("spark.dynamicAllocation.enabled", "true")\
      .config("spark.dynamicAllocation.maxExecutors", 4)\
      .config("spark.shuffle.service.enabled", "true")\
      .getOrCreate()
df = spark.read.option("header", True)\
               .option("inferSchema", "true")\
               .csv("input_data/yellow_tripdata_2020-01.csv")\
               .dropna()
df = df.filter((to_date(df.tpep_pickup_datetime) <= '2020-01-31') &
               (to_date(df.tpep_dropoff_datetime) >= '2020-01-01') &
               (df.total_amount > 0))
df = df.drop("VendorID", 
             "tpep_dropoff_datetime", 
             "RatecodeID", 
             "store_and_fwd_flag", 
             "PULocationID", 
             "DOLocationID", 
             "payment_type",
             "fare_amount",
             "extra",
             "mta_tax",
             "tip_amount",
             "tolls_amount",
             "improvement_surcharge",
             "congestion_surcharge")
@udf(returnType = IntegerType())
def case_s(n, x): return 1 if x == n else 0

@udf(returnType = IntegerType())
def case_s4(x): return 1 if x > 3 else 0

@udf(returnType = FloatType())
def case_m(n, n_pass, x): return x if n_pass == n else None

@udf(returnType = FloatType())
def case_m4(n_pass, x): return x if n_pass > 3 else None

df = df.select(to_date(df.tpep_pickup_datetime).alias("pickup_date"), df.passenger_count, df.total_amount)\
  .groupBy("pickup_date")\
  .agg(sum(case_s(lit(0), df.passenger_count)).alias("n_0"),
       sum(case_s(lit(1), df.passenger_count)).alias("n_1"),
       sum(case_s(lit(2), df.passenger_count)).alias("n_2"),
       sum(case_s(lit(3), df.passenger_count)).alias("n_3"),
       sum(case_s4(df.passenger_count)).alias("n_4"),
       min(case_m(lit(0), df.passenger_count, df.total_amount)).alias("min_0p"),
       max(case_m(lit(0), df.passenger_count, df.total_amount)).alias("max_0p"),
       min(case_m(lit(1), df.passenger_count, df.total_amount)).alias("min_1p"),
       max(case_m(lit(1), df.passenger_count, df.total_amount)).alias("max_1p"),
       min(case_m(lit(2), df.passenger_count, df.total_amount)).alias("min_2p"),
       max(case_m(lit(2), df.passenger_count, df.total_amount)).alias("max_2p"),
       min(case_m(lit(3), df.passenger_count, df.total_amount)).alias("min_3p"),
       max(case_m(lit(3), df.passenger_count, df.total_amount)).alias("max_3p"),
       min(case_m4(df.passenger_count, df.total_amount)).alias("min_4p"),
       max(case_m4(df.passenger_count, df.total_amount)).alias("max_4p"))
df = df.select(df.pickup_date, 
       round((df.n_0 * 100.0 / (df.n_0 + df.n_1 + df.n_2 + df.n_3 + df.n_4)), 2).alias("per_0p"),
       round((df.n_1 * 100.0 / (df.n_0 + df.n_1 + df.n_2 + df.n_3 + df.n_4)), 2).alias("per_1p"),
       round((df.n_2 * 100.0 / (df.n_0 + df.n_1 + df.n_2 + df.n_3 + df.n_4)), 2).alias("per_2p"),
       round((df.n_3 * 100.0 / (df.n_0 + df.n_1 + df.n_2 + df.n_3 + df.n_4)), 2).alias("per_3p"),
       round((df.n_4 * 100.0 / (df.n_0 + df.n_1 + df.n_2 + df.n_3 + df.n_4)), 2).alias("per_4p+"),
       df.min_0p, df.max_0p, df.min_1p, df.max_1p, df.min_2p, df.max_2p, df.min_3p, df.max_3p, 
       df.min_4p.alias("min_4p+"), df.max_4p.alias("max_4p+"))\
  .orderBy(df.pickup_date).cache()
df.show(40)

+-----------+------+------+------+------+-------+------+------+------+-------+------+------+------+------+-------+-------+
|pickup_date|per_0p|per_1p|per_2p|per_3p|per_4p+|min_0p|max_0p|min_1p| max_1p|min_2p|max_2p|min_3p|max_3p|min_4p+|max_4p+|
+-----------+------+------+------+------+-------+------+------+------+-------+------+------+------+------+-------+-------+
| 2019-12-31|   0.0| 57.14| 17.46|  4.76|  20.63|  NULL|  NULL|   5.8|  71.62|   9.8|  41.8|  16.0| 34.42|    6.8|  54.36|
| 2020-01-01|  1.52| 63.51| 19.48|  5.75|   9.74|   0.3|145.55|   0.3|  465.3|   0.3|281.42|   0.3|433.04|    0.3| 350.42|
| 2020-01-02|   1.7| 68.81| 16.01|  4.74|   8.74|   0.3|174.36|   0.3|  492.8|   0.3|328.04|   0.3|215.54|    3.3|  352.3|
| 2020-01-03|  1.77| 68.87| 16.31|  4.58|   8.47|   0.3|187.42|   0.3| 1242.3|   0.3| 370.3|  0.31|409.59|    0.3|  348.3|
| 2020-01-04|  1.69| 65.65| 18.38|  5.14|   9.14|   0.3|152.54|   0.3|  965.8|   0.3| 481.3|   0.3| 313.3|    0.3|  577.8|
| 2020-01-05|   

In [21]:
df.write.csv("output_data_spark/datamart_spark.csv", mode="overwrite", header=True)

In [22]:
df.write.parquet("output_data_spark/datamart_spark.parquet", mode="overwrite")

In [23]:
spark.catalog.clearCache()