# Schema setup for NY taxi dataset

1. Download green and yellow, years 2020-2021
    - `data/taxi_ingest_data/raw/<type>/<filename>`
    - use `download_taxi.sh` script
        - arg1: type
        - arg2: year
    - e.g. `data/taxi_ingest_data/raw/fhv/fhv_tripdata_2020-02.parquet`
    - gzip in place
1. Read raw data file from gzip and set schema
    - read from above
    - set schema
    - repartition
    - write to `data/taxi_ingest_data/parts/<type>/<year>/<month>`

I don't think casting parquet files is a good idea. If we need to recast, cast from `csv`.

Alternatively, read into pandas before writing back to parquet for partitioning.

In [1]:
import pyspark
from pyspark.sql import SparkSession
import pandas as pd
from pathlib import Path
from pyspark.sql import types

In [2]:
# Start (or reuse, via getOrCreate) a local-mode Spark session;
# `local[*]` runs one worker thread per available core.
spark = (
    SparkSession.builder
    .master('local[*]')
    .appName('test')
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/02/21 05:20:13 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
hv_path = Path("../data/taxi_ingest_data/raw/fhvhv/fhvhv_tripdata_2021-01.parquet")

# Parquet files embed their own schema. `header` is a CSV-only reader option
# that the parquet source ignores, so it is not passed here.
df = spark.read.parquet(str(hv_path))

                                                                                

In [7]:
# Inspect the column types Spark read from the raw fhvhv parquet file.
df.dtypes

[('hvfhs_license_num', 'string'),
 ('dispatching_base_num', 'string'),
 ('originating_base_num', 'string'),
 ('request_datetime', 'timestamp'),
 ('on_scene_datetime', 'timestamp'),
 ('pickup_datetime', 'timestamp'),
 ('dropoff_datetime', 'timestamp'),
 ('PULocationID', 'bigint'),
 ('DOLocationID', 'bigint'),
 ('trip_miles', 'double'),
 ('trip_time', 'bigint'),
 ('base_passenger_fare', 'double'),
 ('tolls', 'double'),
 ('bcf', 'double'),
 ('sales_tax', 'double'),
 ('congestion_surcharge', 'double'),
 ('airport_fee', 'double'),
 ('tips', 'double'),
 ('driver_pay', 'double'),
 ('shared_request_flag', 'string'),
 ('shared_match_flag', 'string'),
 ('access_a_ride_flag', 'string'),
 ('wav_request_flag', 'string'),
 ('wav_match_flag', 'string')]

In [9]:
# Spot-check trip_time against the timestamps: in the sample rows below it
# equals dropoff_datetime - pickup_datetime expressed in seconds.
df.select('pickup_datetime','dropoff_datetime','trip_time').show(5)



+-------------------+-------------------+---------+
|    pickup_datetime|   dropoff_datetime|trip_time|
+-------------------+-------------------+---------+
|2021-01-01 00:33:44|2021-01-01 00:49:07|      923|
|2021-01-01 00:55:19|2021-01-01 01:18:21|     1382|
|2021-01-01 00:23:56|2021-01-01 00:38:05|      849|
|2021-01-01 00:42:51|2021-01-01 00:45:50|      179|
|2021-01-01 00:48:14|2021-01-01 01:08:42|     1228|
+-------------------+-------------------+---------+
only showing top 5 rows



                                                                                

In [12]:
g_path = Path('../data/taxi_ingest_data/raw/green/green_tripdata_2020-01.parquet')
# No `header` option: it only applies to CSV sources; parquet carries its own schema.
green = spark.read.parquet(str(g_path))
# Left as the last expression so the notebook displays the inferred dtypes.
green.dtypes


[('VendorID', 'bigint'),
 ('lpep_pickup_datetime', 'timestamp'),
 ('lpep_dropoff_datetime', 'timestamp'),
 ('store_and_fwd_flag', 'string'),
 ('RatecodeID', 'double'),
 ('PULocationID', 'bigint'),
 ('DOLocationID', 'bigint'),
 ('passenger_count', 'double'),
 ('trip_distance', 'double'),
 ('fare_amount', 'double'),
 ('extra', 'double'),
 ('mta_tax', 'double'),
 ('tip_amount', 'double'),
 ('tolls_amount', 'double'),
 ('ehail_fee', 'int'),
 ('improvement_surcharge', 'double'),
 ('total_amount', 'double'),
 ('payment_type', 'double'),
 ('trip_type', 'double'),
 ('congestion_surcharge', 'double')]

In [16]:
# Summary statistics for a few columns. In the output below ehail_fee has a
# non-null count of 0, i.e. it is entirely null in this file.
green.select('VendorID','ehail_fee','PULocationID').summary().show()



+-------+-------------------+---------+------------------+
|summary|           VendorID|ehail_fee|      PULocationID|
+-------+-------------------+---------+------------------+
|  count|             447770|        0|            447770|
|   mean| 1.8742948388681688|     null|108.12123634901847|
| stddev|0.33151714743469723|     null|  71.1659562223991|
|    min|                  1|     null|                 1|
|    25%|                  2|     null|                52|
|    50%|                  2|     null|                82|
|    75%|                  2|     null|               166|
|    max|                  2|     null|               265|
+-------+-------------------+---------+------------------+



                                                                                

In [51]:
def get_col_types(df) -> dict:
    """
    Classify a DataFrame's columns into type categories by column name.

    Matching is a case-sensitive substring test. Each column is placed in
    exactly one category, tested in this order:
        timestamp: name contains "_datetime"
        string:    name contains "_num" or "_flag"
        int:       name contains "_time", "ID", "_type" or "_count"
        float:     everything else

    Args:
        df: any object with a ``columns`` list of names (e.g. a Spark DataFrame).

    Returns:
        dict with keys 'int', 'string', 'float', 'timestamp' (in that order),
        each mapping to the matching column names in their original order.
    """
    int_tags = ("_time", "ID", "_type", "_count")
    # Insertion order of this dict is the order callers iterate the categories in.
    buckets = {'int': [], 'string': [], 'float': [], 'timestamp': []}
    for col in df.columns:
        # An if/elif chain makes the categories mutually exclusive, so a name
        # matching several patterns (e.g. "trip_time_flag") is not cast twice.
        # String wins over int because casting text to an integer is lossy.
        if "_datetime" in col:
            buckets['timestamp'].append(col)
        elif "_num" in col or "_flag" in col:
            buckets['string'].append(col)
        elif any(tag in col for tag in int_tags):
            buckets['int'].append(col)
        else:
            buckets['float'].append(col)
    return buckets

# Classify the green-taxi columns by name pattern (category -> column names).
dtypes = get_col_types(green)
                
    

In [22]:
# One line per type category: "<category>: [column names]".
for type_name, cols in dtypes.items():
    print(f'{type_name}: {cols}')

int: ['VendorID', 'RatecodeID', 'PULocationID', 'DOLocationID', 'payment_type', 'trip_type']
string: ['store_and_fwd_flag']
float: ['passenger_count', 'trip_distance', 'fare_amount', 'extra', 'mta_tax', 'tip_amount', 'tolls_amount', 'ehail_fee', 'improvement_surcharge', 'total_amount', 'congestion_surcharge']
timestamp: ['lpep_pickup_datetime', 'lpep_dropoff_datetime']


In [23]:
# Same classification for the fhvhv DataFrame read earlier.
fhvhv_dtypes = get_col_types(df)
for type_name, cols in fhvhv_dtypes.items():
    print(f'{type_name}: {cols}')

int: ['PULocationID', 'DOLocationID', 'trip_time']
string: ['hvfhs_license_num', 'dispatching_base_num', 'originating_base_num', 'shared_request_flag', 'shared_match_flag', 'access_a_ride_flag', 'wav_request_flag', 'wav_match_flag']
float: ['trip_miles', 'base_passenger_fare', 'tolls', 'bcf', 'sales_tax', 'congestion_surcharge', 'airport_fee', 'tips', 'driver_pay']
timestamp: ['request_datetime', 'on_scene_datetime', 'pickup_datetime', 'dropoff_datetime']


## Schema

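Column types are assigned by a case-sensitive substring match on the column name:

- contains `_datetime`: timestamp
- contains `_num` or `_flag`: string
- contains `_time`, `ID`, `_type` or `_count`: integer
- all others: float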

In [28]:
from pyspark.sql import functions as F

In [29]:
# Exploratory single-column cast. withColumn returns a new DataFrame; the
# result is only displayed and is not assigned back to df.
df.withColumn('PULocationID', F.col('PULocationID').cast(types.IntegerType()))

DataFrame[hvfhs_license_num: string, dispatching_base_num: string, originating_base_num: string, request_datetime: timestamp, on_scene_datetime: timestamp, pickup_datetime: timestamp, dropoff_datetime: timestamp, PULocationID: int, DOLocationID: bigint, trip_miles: double, trip_time: bigint, base_passenger_fare: double, tolls: double, bcf: double, sales_tax: double, congestion_surcharge: double, airport_fee: double, tips: double, driver_pay: double, shared_request_flag: string, shared_match_flag: string, access_a_ride_flag: string, wav_request_flag: string, wav_match_flag: string]

In [30]:
# Casting multiple columns in one withColumns() call (dict of name -> Column).
# The result is only displayed; df itself keeps its bigint location IDs.
# ShortType (smallint) appears to be wide enough for location IDs: the green
# summary above shows a max PULocationID of 265. NOTE(review): confirm for fhvhv.
df.withColumns({
    'PULocationID': F.col('PULocationID').cast(types.ShortType()),
    'DOLocationID': F.col('DOLocationID').cast(types.ShortType())
})

DataFrame[hvfhs_license_num: string, dispatching_base_num: string, originating_base_num: string, request_datetime: timestamp, on_scene_datetime: timestamp, pickup_datetime: timestamp, dropoff_datetime: timestamp, PULocationID: smallint, DOLocationID: smallint, trip_miles: double, trip_time: bigint, base_passenger_fare: double, tolls: double, bcf: double, sales_tax: double, congestion_surcharge: double, airport_fee: double, tips: double, driver_pay: double, shared_request_flag: string, shared_match_flag: string, access_a_ride_flag: string, wav_request_flag: string, wav_match_flag: string]

In [33]:
# Peek at the location ID columns of the (unmodified) df.
df.select('PULocationID','DOLocationID').show(5)

+------------+------------+
|PULocationID|DOLocationID|
+------------+------------+
|         230|         166|
|         152|         167|
|         233|         142|
|         142|         143|
|         143|          78|
+------------+------------+
only showing top 5 rows



In [34]:
# Explicit, all-nullable StructType schemas for the three datasets, declared
# as (column name, type code) tables.
# NOTE(review): none of these schemas is used by cast_schema below. Several
# types also differ from what the raw green parquet reports (passenger_count,
# payment_type and trip_type are double and ehail_fee is int in the dtypes
# listed earlier), so applying them when reading those files would likely
# conflict. Confirm their intended source (CSV?) before using them.


def _nullable_struct(fields):
    """Build a StructType of nullable fields from (name, type_code) pairs."""
    spark_types = {
        'long': types.LongType,
        'double': types.DoubleType,
        'string': types.StringType,
        'timestamp': types.TimestampType,
    }
    return types.StructType(
        [types.StructField(name, spark_types[code](), True) for name, code in fields]
    )


green_schema = _nullable_struct([
    ("VendorID", "long"),
    ("lpep_pickup_datetime", "timestamp"),
    ("lpep_dropoff_datetime", "timestamp"),
    ("store_and_fwd_flag", "string"),
    ("RatecodeID", "double"),
    ("PULocationID", "long"),
    ("DOLocationID", "long"),
    ("passenger_count", "long"),
    ("trip_distance", "double"),
    ("fare_amount", "double"),
    ("extra", "double"),
    ("mta_tax", "double"),
    ("tip_amount", "double"),
    ("tolls_amount", "double"),
    ("ehail_fee", "double"),
    ("improvement_surcharge", "double"),
    ("total_amount", "double"),
    ("payment_type", "long"),
    ("trip_type", "long"),
    ("congestion_surcharge", "double"),
])

yellow_schema = _nullable_struct([
    ("VendorID", "long"),
    ("tpep_pickup_datetime", "timestamp"),
    ("tpep_dropoff_datetime", "timestamp"),
    ("passenger_count", "long"),
    ("trip_distance", "double"),
    ("RatecodeID", "double"),
    ("store_and_fwd_flag", "string"),
    ("PULocationID", "long"),
    ("DOLocationID", "long"),
    ("payment_type", "long"),
    ("fare_amount", "double"),
    ("extra", "double"),
    ("mta_tax", "double"),
    ("tip_amount", "double"),
    ("tolls_amount", "double"),
    ("improvement_surcharge", "double"),
    ("total_amount", "double"),
    ("congestion_surcharge", "double"),
])

hvfhv_schema = _nullable_struct([
    ('hvfhs_license_num', 'string'),
    ('dispatching_base_num', 'string'),
    ('originating_base_num', 'string'),
    ('request_datetime', 'timestamp'),
    ('on_scene_datetime', 'timestamp'),
    ('pickup_datetime', 'timestamp'),
    ('dropoff_datetime', 'timestamp'),
    ('PULocationID', 'long'),
    ('DOLocationID', 'long'),
    ('trip_miles', 'double'),
    ('trip_time', 'long'),
    ('base_passenger_fare', 'double'),
    ('tolls', 'double'),
    ('bcf', 'double'),
    ('sales_tax', 'double'),
    ('congestion_surcharge', 'double'),
    ('airport_fee', 'double'),
    ('tips', 'double'),
    ('driver_pay', 'double'),
    ('shared_request_flag', 'string'),
    ('shared_match_flag', 'string'),
    ('access_a_ride_flag', 'string'),
    ('wav_request_flag', 'string'),
    ('wav_match_flag', 'string'),
])

In [45]:
def cast_schema(spark_client, raw_path: Path, parts_dir: Path, num_parts: int = 4):
    """
    Read a raw parquet file, cast its columns by name pattern, and write the
    result as a repartitioned, gzip-compressed parquet dataset.

    Column categories come from get_col_types() and map to Spark types as:
    'int' -> IntegerType, 'string' -> StringType, 'float' -> FloatType,
    'timestamp' -> TimestampType.
    NOTE: FloatType is single precision, so double columns (fares, distances,
    amounts) are narrowed and may lose precision.

    Args:
        spark_client: active SparkSession used for reading.
        raw_path: path to the source parquet file.
        parts_dir: output directory; written with mode='overwrite'.
        num_parts: number of partitions to repartition into before writing.

    Returns:
        The cast DataFrame. It is not cached, so further actions on it
        recompute it from raw_path.

    Raises:
        ValueError: if get_col_types() reports a category with no Spark type.
    """
    spark_types = {
        'int': types.IntegerType(),
        'string': types.StringType(),
        'float': types.FloatType(),
        'timestamp': types.TimestampType(),
    }
    # Parquet embeds its own schema; the CSV-only `header` option does not apply.
    df = spark_client.read.parquet(str(raw_path))

    col_map = {}
    for dtype, cols in get_col_types(df).items():
        # Explicit lookup instead of a match without a default arm, which would
        # silently reuse a stale (or unbound) spark_type for an unknown category.
        if dtype not in spark_types:
            raise ValueError(f"No Spark type mapped for column category {dtype!r}")
        for col in cols:
            col_map[col] = F.col(col).cast(spark_types[dtype])

    # A single withColumns() call applies every cast in one projection;
    # it returns a new DataFrame rather than modifying df in place.
    df = df.withColumns(col_map)

    df.repartition(num_parts).write.parquet(
        str(parts_dir),
        mode='overwrite',
        compression='gzip',
    )
    return df


In [49]:
# Cast fhvhv 2021-01 into parts/. No explicit mkdir is done here; the output
# directory is left to Spark's parquet writer.
hv_path = Path("../data/taxi_ingest_data/raw/fhvhv/fhvhv_tripdata_2021-01.parquet")
out_dir = Path("../data/taxi_ingest_data/parts/fhvhv/2021/01/")
hv = cast_schema(spark, hv_path, out_dir)
hv.printSchema()



root
 |-- hvfhs_license_num: string (nullable = true)
 |-- dispatching_base_num: string (nullable = true)
 |-- originating_base_num: string (nullable = true)
 |-- request_datetime: timestamp (nullable = true)
 |-- on_scene_datetime: timestamp (nullable = true)
 |-- pickup_datetime: timestamp (nullable = true)
 |-- dropoff_datetime: timestamp (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- trip_miles: float (nullable = true)
 |-- trip_time: integer (nullable = true)
 |-- base_passenger_fare: float (nullable = true)
 |-- tolls: float (nullable = true)
 |-- bcf: float (nullable = true)
 |-- sales_tax: float (nullable = true)
 |-- congestion_surcharge: float (nullable = true)
 |-- airport_fee: float (nullable = true)
 |-- tips: float (nullable = true)
 |-- driver_pay: float (nullable = true)
 |-- shared_request_flag: string (nullable = true)
 |-- shared_match_flag: string (nullable = true)
 |-- access_a_ride_flag: string (nul

                                                                                

In [50]:
# Read the written partitions back to confirm the cast types were persisted.
# No `header` option: it only applies to CSV sources.
df_read = spark.read.parquet(str(out_dir))
df_read.printSchema()

root
 |-- hvfhs_license_num: string (nullable = true)
 |-- dispatching_base_num: string (nullable = true)
 |-- originating_base_num: string (nullable = true)
 |-- request_datetime: timestamp (nullable = true)
 |-- on_scene_datetime: timestamp (nullable = true)
 |-- pickup_datetime: timestamp (nullable = true)
 |-- dropoff_datetime: timestamp (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- trip_miles: float (nullable = true)
 |-- trip_time: integer (nullable = true)
 |-- base_passenger_fare: float (nullable = true)
 |-- tolls: float (nullable = true)
 |-- bcf: float (nullable = true)
 |-- sales_tax: float (nullable = true)
 |-- congestion_surcharge: float (nullable = true)
 |-- airport_fee: float (nullable = true)
 |-- tips: float (nullable = true)
 |-- driver_pay: float (nullable = true)
 |-- shared_request_flag: string (nullable = true)
 |-- shared_match_flag: string (nullable = true)
 |-- access_a_ride_flag: string (nul

In [55]:
# hvfhv 2021-02: same pipeline as 2021-01, with the month as a variable.
fmonth = "02"
out_dir = Path("../data/taxi_ingest_data/parts/fhvhv/2021") / fmonth
hv_path = Path("../data/taxi_ingest_data/raw/fhvhv") / f"fhvhv_tripdata_2021-{fmonth}.parquet"
# Kept as the last expression so the notebook displays the returned DataFrame.
cast_schema(spark, hv_path, out_dir)

                                                                                

DataFrame[hvfhs_license_num: string, dispatching_base_num: string, originating_base_num: string, request_datetime: timestamp, on_scene_datetime: timestamp, pickup_datetime: timestamp, dropoff_datetime: timestamp, PULocationID: int, DOLocationID: int, trip_miles: float, trip_time: int, base_passenger_fare: float, tolls: float, bcf: float, sales_tax: float, congestion_surcharge: float, airport_fee: float, tips: float, driver_pay: float, shared_request_flag: string, shared_match_flag: string, access_a_ride_flag: string, wav_request_flag: string, wav_match_flag: string]

In [52]:
# Single green-taxi month (2020-01), built from the same path template the
# batch loop uses.
taxi_type = 'green'
year, month = 2020, 1
raw_path = Path("../data/taxi_ingest_data/raw") / taxi_type / f"{taxi_type}_tripdata_{year}-{month:02d}.parquet"
out_dir = Path("../data/taxi_ingest_data/parts") / taxi_type / str(year) / f"{month:02d}"
g = cast_schema(spark, raw_path, out_dir)

                                                                                

In [53]:
# Schema of the DataFrame returned by cast_schema (before re-reading from disk).
g.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- lpep_pickup_datetime: timestamp (nullable = true)
 |-- lpep_dropoff_datetime: timestamp (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: float (nullable = true)
 |-- fare_amount: float (nullable = true)
 |-- extra: float (nullable = true)
 |-- mta_tax: float (nullable = true)
 |-- tip_amount: float (nullable = true)
 |-- tolls_amount: float (nullable = true)
 |-- ehail_fee: float (nullable = true)
 |-- improvement_surcharge: float (nullable = true)
 |-- total_amount: float (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- trip_type: integer (nullable = true)
 |-- congestion_surcharge: float (nullable = true)



In [54]:
# Re-read the written green partitions to verify the on-disk schema.
# No `header` option: it only applies to CSV sources.
g = spark.read.parquet(str(out_dir))
g.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- lpep_pickup_datetime: timestamp (nullable = true)
 |-- lpep_dropoff_datetime: timestamp (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: float (nullable = true)
 |-- fare_amount: float (nullable = true)
 |-- extra: float (nullable = true)
 |-- mta_tax: float (nullable = true)
 |-- tip_amount: float (nullable = true)
 |-- tolls_amount: float (nullable = true)
 |-- ehail_fee: float (nullable = true)
 |-- improvement_surcharge: float (nullable = true)
 |-- total_amount: float (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- trip_type: integer (nullable = true)
 |-- congestion_surcharge: float (nullable = true)



In [56]:
# Cast every month of every (taxi type, year) combination into parts/.
# Nested loops are deliberate: zip(taxi_types, years) would only yield
# (green, 2020) and (yellow, 2021), silently skipping green 2021 and yellow 2020.
taxi_types = ['green', 'yellow']
years = list(range(2020, 2022))
for taxi_type in taxi_types:
    for year in years:
        for month in range(1, 13):
            raw_path = Path(f"../data/taxi_ingest_data/raw/{taxi_type}/{taxi_type}_tripdata_{year}-{month:02d}.parquet")
            out_dir = Path(f"../data/taxi_ingest_data/parts/{taxi_type}/{year}/{month:02d}/")
            cast_schema(spark, raw_path=raw_path, parts_dir=out_dir)


                                                                                