### SparkSession

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql import types

In [None]:
# Build (or reuse, if one is already running) a local Spark session that uses
# every available core; "schema" is the application name reported by Spark.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName('schema')
    .getOrCreate()
)

### Green taxi

In [None]:
# Re-type the raw green-taxi parquet files month by month.
# For every year in `years` and every month 1-12:
#   read   data/parquet_raw/green/<year>/<MM>/
#   cast   the columns listed in `green_casts` to the given Spark types
#   write  data/parquet/green/<year>/<MM>/ (overwriting any previous output)
# NOTE(review): spark.read.parquet raises if input_path does not exist, which
# stops the whole loop -- confirm raw data exists for every year/month here.
years = [2020, 2021]

# Target type for each column whose type is normalised; all other columns are
# passed through unchanged. Listing each column name exactly once keeps the
# output column and the source column of each cast from drifting apart.
green_casts = {
    'VendorID': types.IntegerType(),
    'RatecodeID': types.IntegerType(),
    'PULocationID': types.IntegerType(),
    'DOLocationID': types.IntegerType(),
    'passenger_count': types.IntegerType(),
    'ehail_fee': types.DoubleType(),
    'payment_type': types.IntegerType(),
    'trip_type': types.IntegerType(),
}

for year in years:
    for month in range(1, 13):

        print(f'Processing data for {year}/{month}')

        input_path = f'data/parquet_raw/green/{year}/{month:02d}/'
        output_path = f'data/parquet/green/{year}/{month:02d}/'

        df_green = spark.read.parquet(input_path)

        # withColumn with an existing column name replaces that column, so
        # each cast rewrites the column under its original name.
        for column, target_type in green_casts.items():
            df_green = df_green.withColumn(column, df_green[column].cast(target_type))

        # Shuffle into 4 partitions (at most four part files per month) and
        # replace whatever was previously written to output_path.
        df_green \
            .repartition(4) \
            .write.mode("overwrite").parquet(output_path)

### Yellow taxi

In [None]:
# Re-type the raw yellow-taxi parquet files month by month.
# For every year in `years` and every month 1-12:
#   read   data/parquet_raw/yellow/<year>/<MM>/
#   cast   the columns listed in `yellow_casts` to the given Spark types
#   write  data/parquet/yellow/<year>/<MM>/ (overwriting any previous output)
# NOTE(review): spark.read.parquet raises if input_path does not exist, which
# stops the whole loop -- confirm raw data exists for every year/month here.
years = [2020, 2021]

# Target type for each column whose type is normalised; all other columns are
# passed through unchanged. Listing each column name exactly once keeps the
# output column and the source column of each cast from drifting apart.
yellow_casts = {
    'VendorID': types.IntegerType(),
    'passenger_count': types.IntegerType(),
    'RatecodeID': types.IntegerType(),
    'PULocationID': types.IntegerType(),
    'DOLocationID': types.IntegerType(),
    'payment_type': types.IntegerType(),
}

for year in years:
    for month in range(1, 13):

        print(f'Processing data for {year}/{month}')

        input_path = f'data/parquet_raw/yellow/{year}/{month:02d}/'
        output_path = f'data/parquet/yellow/{year}/{month:02d}/'

        df_yellow = spark.read.parquet(input_path)

        # withColumn with an existing column name replaces that column, so
        # each cast rewrites the column under its original name.
        for column, target_type in yellow_casts.items():
            df_yellow = df_yellow.withColumn(column, df_yellow[column].cast(target_type))

        # Shuffle into 4 partitions (at most four part files per month) and
        # replace whatever was previously written to output_path.
        df_yellow \
            .repartition(4) \
            .write.mode("overwrite").parquet(output_path)