In [23]:
import pyspark
import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql import types as T
from pyspark.sql import functions as F

In [3]:
# Create (or reuse) the Spark session for this notebook: local master,
# application name 'test', driver memory set to 2g.
spark = (
    SparkSession.builder
    .master('local[*]')
    .appName('test')
    .config('spark.driver.memory', '2g')
    .getOrCreate()
)

23/03/06 15:56:05 WARN Utils: Your hostname, Zambo-ROG resolves to a loopback address: 127.0.1.1; using 172.30.104.214 instead (on interface eth0)
23/03/06 15:56:05 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
23/03/06 15:56:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [4]:
# Echo the session object so the notebook displays it.
spark

In [5]:
# One-off download of the fhvhv 2021-01 trip-data archive used by the cells
# below; uncomment to fetch it.
# !wget https://github.com/DataTalksClub/nyc-tlc-data/releases/download/fhvhv/fhvhv_tripdata_2021-01.csv.gz

In [6]:
# First look at the data: read the gzipped CSV with its header row and
# without supplying a schema.
df = (
    spark.read
    .option('header', 'true')
    .csv('fhvhv_tripdata_2021-01.csv.gz')
)

In [7]:
# Write the first 1001 lines of the archive (header + 1000 rows) to head.csv.
# `head` stops reading once it has its lines, so gunzip prints
# "Broken pipe" -- that message is expected here.
!gunzip -c fhvhv_tripdata_2021-01.csv.gz | head -n 1001 > head.csv


gzip: stdout: Broken pipe


In [8]:
# Load the sample with pandas. The two timestamp columns are parsed by name
# rather than by position, so the cell documents itself and keeps working if
# the column order of the extract changes.
df_pandas = pd.read_csv(
    'head.csv',
    parse_dates=['pickup_datetime', 'dropoff_datetime'],
)

In [9]:
# Show the dtypes pandas inferred for the sample.
df_pandas.dtypes

hvfhs_license_num               object
dispatching_base_num            object
pickup_datetime         datetime64[ns]
dropoff_datetime        datetime64[ns]
PULocationID                     int64
DOLocationID                     int64
SR_Flag                        float64
dtype: object

In [11]:
# Build a Spark DataFrame from the pandas sample only to display the schema
# Spark derives from it.
spark.createDataFrame(df_pandas).schema

StructType(List(StructField(hvfhs_license_num,StringType,true),StructField(dispatching_base_num,StringType,true),StructField(pickup_datetime,TimestampType,true),StructField(dropoff_datetime,TimestampType,true),StructField(PULocationID,LongType,true),StructField(DOLocationID,LongType,true),StructField(SR_Flag,DoubleType,true)))

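`IntegerType` - 4 bytes<br>
`LongType` - 8 bytes<br>
The inferred schema above uses `LongType` for the location IDs; the explicit schema below declares them as `IntegerType` instead, which needs half the space per value.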

In [12]:
# Explicit schema for the trip CSV, built from (column name, Spark type)
# pairs; every field is declared nullable.
schema = T.StructType([
    T.StructField(column_name, column_type, True)
    for column_name, column_type in (
        ('hvfhs_license_num', T.StringType()),
        ('dispatching_base_num', T.StringType()),
        ('pickup_datetime', T.TimestampType()),
        ('dropoff_datetime', T.TimestampType()),
        ('PULocationID', T.IntegerType()),
        ('DOLocationID', T.IntegerType()),
        ('SR_Flag', T.StringType()),
    )
])

In [13]:
# Re-read the trip data, this time applying the explicit schema.
# The gzipped archive is read directly: it is the only copy of the data that
# the earlier cells download and use, whereas an uncompressed
# 'fhvhv_tripdata_2021-01.csv' is never created by this notebook and would
# make the cell fail on a fresh run.
df = (
    spark.read
    .schema(schema)
    .option('header', 'true')
    .csv('fhvhv_tripdata_2021-01.csv.gz')
)

In [14]:
# Reshuffle the data into 24 partitions before it is written out.
df = df.repartition(24)

In [15]:
# Save the DataFrame as Parquet under fhvhv/2021/01, replacing any output
# already at that path.
(
    df.write
    .mode('overwrite')
    .parquet('fhvhv/2021/01')
)

23/03/06 16:04:37 WARN MemoryManager: Total allocation exceeds 95.00% (2,040,109,440 bytes) of heap memory
Scaling row group sizes to 95.00% for 16 writers
                                                                                

## Next steps

In [16]:
# Load the trip data back from the Parquet directory.
df = spark.read.parquet('fhvhv/2021/01')

In [20]:
# Preview the first five rows.
df.show(5)

+-----------------+--------------------+-------------------+-------------------+------------+------------+-------+
|hvfhs_license_num|dispatching_base_num|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|SR_Flag|
+-----------------+--------------------+-------------------+-------------------+------------+------------+-------+
|           HV0005|              B02510|2021-01-01 00:20:57|2021-01-01 00:41:24|          69|          41|   null|
|           HV0005|              B02510|2021-01-01 02:17:52|2021-01-01 02:37:16|         243|         116|   null|
|           HV0003|              B02878|2021-01-01 21:18:35|2021-01-01 21:32:32|          18|         259|   null|
|           HV0003|              B02872|2021-01-01 23:25:49|2021-01-01 23:36:12|          94|          32|   null|
|           HV0003|              B02764|2021-01-02 19:51:39|2021-01-02 20:07:16|         236|         143|   null|
+-----------------+--------------------+-------------------+-------------------+------------+------------+-------+

In [29]:
# Show trips for licence 'HV0003' only.
# The filter runs before the projection: the original order filtered on
# hvfhs_license_num after select() had already dropped that column, which
# only works because Spark resolves the missing reference behind the scenes.
(
    df
    .filter(df.hvfhs_license_num == 'HV0003')
    .select('pickup_datetime', 'dropoff_datetime', 'PULocationID', 'DOLocationID')
    .show(5)
)

+-------------------+-------------------+------------+------------+
|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|
+-------------------+-------------------+------------+------------+
|2021-01-01 21:18:35|2021-01-01 21:32:32|          18|         259|
|2021-01-01 23:25:49|2021-01-01 23:36:12|          94|          32|
|2021-01-02 19:51:39|2021-01-02 20:07:16|         236|         143|
|2021-01-02 00:02:37|2021-01-02 00:09:32|         216|         216|
|2021-01-01 16:34:07|2021-01-01 17:02:33|         181|          75|
+-------------------+-------------------+------------+------------+
only showing top 5 rows



In [30]:
# Truncate both timestamps to calendar dates and preview them next to the
# pickup / drop-off location IDs.
pickup_date = F.to_date(df.pickup_datetime).alias('pickup_date')
dropoff_date = F.to_date(df.dropoff_datetime).alias('dropoff_date')

df.select(pickup_date, dropoff_date, 'PULocationID', 'DOLocationID').show(5)

+-----------+------------+------------+------------+
|pickup_date|dropoff_date|PULocationID|DOLocationID|
+-----------+------------+------------+------------+
| 2021-01-01|  2021-01-01|          69|          41|
| 2021-01-01|  2021-01-01|         243|         116|
| 2021-01-01|  2021-01-01|          18|         259|
| 2021-01-01|  2021-01-01|          94|          32|
| 2021-01-02|  2021-01-02|         236|         143|
+-----------+------------+------------+------------+
only showing top 5 rows



In [97]:
def crazy_stuff(base_num):
    """Build a short bucketed identifier from a dispatching base number.

    The first character of ``base_num`` is dropped and the remainder is
    parsed as a decimal integer. The result has the form
    ``'<prefix>/<hex>'``, where ``<hex>`` is that integer in lower-case
    hexadecimal, zero-padded to at least three digits, and ``<prefix>`` is
    ``'s'`` when the integer is divisible by 7, ``'a'`` when it is divisible
    by 3 (and not by 7), and ``'e'`` otherwise.

    Args:
        base_num: String such as ``'B02510'``, or ``None``.

    Returns:
        The identifier string, or ``None`` when ``base_num`` is ``None``.

    Raises:
        ValueError: If the text after the first character is not an integer.
    """
    # The source column is declared nullable; without this guard a single
    # null value would raise TypeError when the function runs as a Spark UDF.
    if base_num is None:
        return None

    num = int(base_num[1:])
    if num % 7 == 0:
        prefix = 's'
    elif num % 3 == 0:
        prefix = 'a'
    else:
        prefix = 'e'
    return f'{prefix}/{num:03x}'

In [98]:
# Wrap the plain Python function as a Spark UDF that yields strings.
crazy_stuff_udf = F.udf(crazy_stuff, T.StringType())

In [100]:
# Apply the UDF to the dispatching base number and preview the result
# alongside the trip dates and location IDs.
base_id = crazy_stuff_udf(df.dispatching_base_num).alias('base_id')
pickup_date = F.to_date(df.pickup_datetime).alias('pickup_date')
dropoff_date = F.to_date(df.dropoff_datetime).alias('dropoff_date')

df.select(base_id, pickup_date, dropoff_date, 'PULocationID', 'DOLocationID').show(5)

+-------+-----------+------------+------------+------------+
|base_id|pickup_date|dropoff_date|PULocationID|DOLocationID|
+-------+-----------+------------+------------+------------+
|  e/9ce| 2021-01-01|  2021-01-01|          69|          41|
|  e/9ce| 2021-01-01|  2021-01-01|         243|         116|
|  e/b3e| 2021-01-01|  2021-01-01|          18|         259|
|  e/b38| 2021-01-01|  2021-01-01|          94|          32|
|  e/acc| 2021-01-02|  2021-01-02|         236|         143|
+-------+-----------+------------+------------+------------+
only showing top 5 rows

