In [2]:
from pyspark.sql import SparkSession

In [3]:
MAX_MEMORY = "5g"
spark = SparkSession.builder.appName("taxi-fare-prediction")\
                    .config("spark.executor.memory", MAX_MEMORY)\
                    .config("spark.driver.memory", MAX_MEMORY)\
                    .getOrCreate()

22/07/31 22:42:28 WARN Utils: Your hostname, singyeongdeog-ui-Macmini.local resolves to a loopback address: 127.0.0.1; using 222.98.22.103 instead (on interface en0)
22/07/31 22:42:28 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/07/31 22:42:29 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
22/07/31 22:42:30 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [4]:
trip_files = "/Users/singyeongdeog/Documents/github_code/data-engineering/01-spark/data/yellow/*"

In [5]:
trips_df = spark.read.csv(f"file:///{trip_files}", inferSchema=True, header=True)

                                                                                

In [6]:
trips_df.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: double (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: double (nullable = true)



In [7]:
trips_df.createOrReplaceTempView("trips")

In [9]:
query = """
SELECT
    passenger_count,
    PULocationID as pickup_location_id,
    DOLocationID as dropoff_location_id,
    trip_distance,
    HOUR(tpep_pickup_datetime) as pickup_time,
    DATE_FORMAT(TO_DATE(tpep_pickup_datetime), 'EEEE') as day_of_week,
    total_amount
FROM
    trips
WHERE
    total_amount < 5000
    AND total_amount > 0
    AND trip_distance > 0
    AND trip_distance < 500
    AND passenger_count < 4
    AND TO_DATE(tpep_pickup_datetime) >= '2021-01-01'
    AND TO_DATE(tpep_dropoff_datetime) < '2021-08-01'
"""
data_df = spark.sql(query)
data_df.createOrReplaceTempView("data")

In [10]:
data_df.show()

+---------------+------------------+-------------------+-------------+-----------+-----------+------------+
|passenger_count|pickup_location_id|dropoff_location_id|trip_distance|pickup_time|day_of_week|total_amount|
+---------------+------------------+-------------------+-------------+-----------+-----------+------------+
|            1.0|               142|                 43|          2.1|          0|     Friday|        11.8|
|            1.0|               238|                151|          0.2|          0|     Friday|         4.3|
|            1.0|               132|                165|         14.7|          0|     Friday|       51.95|
|            0.0|               138|                132|         10.6|          0|     Friday|       36.35|
|            1.0|                68|                 33|         4.94|          0|     Friday|       24.36|
|            1.0|               224|                 68|          1.6|          0|     Friday|       14.15|
|            1.0|           

In [11]:
data_df.printSchema()

root
 |-- passenger_count: double (nullable = true)
 |-- pickup_location_id: integer (nullable = true)
 |-- dropoff_location_id: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- pickup_time: integer (nullable = true)
 |-- day_of_week: string (nullable = true)
 |-- total_amount: double (nullable = true)



In [13]:
train_df, test_df = data_df.randomSplit([0.8,0.2],seed=1)

In [14]:
data_dir = "/Users/singyeongdeog/Documents/github_code/data-engineering/01-spark/data/data/"

In [15]:
train_df.write.format("parquet").save(f"{data_dir}/train/")
test_df.write.format("parquet").save(f"{data_dir}/test/")

                                                                                

In [16]:
train_df = spark.read.parquet(f"{data_dir}/train/")
test_df = spark.read.parquet(f"{data_dir}/test/")

In [17]:
train_df.printSchema()

root
 |-- passenger_count: double (nullable = true)
 |-- pickup_location_id: integer (nullable = true)
 |-- dropoff_location_id: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- pickup_time: integer (nullable = true)
 |-- day_of_week: string (nullable = true)
 |-- total_amount: double (nullable = true)



In [22]:
# Wednesday -> 3 -> [0,0,0,1,0,0] one-hot-encoding
from pyspark.ml.feature import OneHotEncoder, StringIndexer

cat_feats = [
    "pickup_location_id",
    "dropoff_location_id",
    "day_of_week"
]

stages = []

for c in cat_feats:
    cat_indexer = StringIndexer(inputCol=c, outputCol= c + "_idx").setHandleInvalid("keep")
    onehot_encoder = OneHotEncoder(inputCols=[cat_indexer.getOutputCol()], outputCols=[c + "_onehot"])
    stages += [cat_indexer, onehot_encoder]

In [23]:
stages

[StringIndexer_bec00fe4b4d4,
 OneHotEncoder_0ea42f81beb4,
 StringIndexer_6acb76e33543,
 OneHotEncoder_6b772b0e9170,
 StringIndexer_8f133e623a4c,
 OneHotEncoder_ec83ab4fee65]

In [24]:
from pyspark.ml.feature import VectorAssembler, StandardScaler

num_feats = [
    "passenger_count",
    "trip_distance",
    "pickup_time"
]

for n in num_feats:
    num_assembler = VectorAssembler(inputCols=[n], outputCol= n + "_vector")
    num_scaler = StandardScaler(inputCol=num_assembler.getOutputCol(), outputCol= n + "_scaled")
    stages += [num_assembler, num_scaler]

In [25]:
stages

[StringIndexer_bec00fe4b4d4,
 OneHotEncoder_0ea42f81beb4,
 StringIndexer_6acb76e33543,
 OneHotEncoder_6b772b0e9170,
 StringIndexer_8f133e623a4c,
 OneHotEncoder_ec83ab4fee65,
 VectorAssembler_2c3d71333502,
 StandardScaler_59c60a5b8b59,
 VectorAssembler_de553c584d35,
 StandardScaler_9495c34dda86,
 VectorAssembler_4661affdfa2a,
 StandardScaler_28c509e3fe39]

In [29]:
assembler_inputs = [c + "_onehot" for c in cat_feats] + [n + "_scaled" for n in num_feats]
print(assembler_inputs)

['pickup_location_id_onehot', 'dropoff_location_id_onehot', 'day_of_week_onehot', 'passenger_count_scaled', 'trip_distance_scaled', 'pickup_time_scaled']


In [28]:
assembler = VectorAssembler(inputCols=assembler_inputs, outputCol="feature_vector")

In [54]:
stages += [assembler]
print(stages)

[StringIndexer_bec00fe4b4d4, OneHotEncoder_0ea42f81beb4, StringIndexer_6acb76e33543, OneHotEncoder_6b772b0e9170, StringIndexer_8f133e623a4c, OneHotEncoder_ec83ab4fee65, VectorAssembler_2c3d71333502, StandardScaler_59c60a5b8b59, VectorAssembler_de553c584d35, StandardScaler_9495c34dda86, VectorAssembler_4661affdfa2a, StandardScaler_28c509e3fe39, VectorAssembler_1f80931a6d4c, VectorAssembler_1f80931a6d4c]


In [38]:
from pyspark.ml import Pipeline

transform_stages = stages
pipeline = Pipeline(stages = transform_stages)
fitted_transformer = pipeline.fit(train_df)

                                                                                

In [39]:
vtrain_df = fitted_transformer.transform(train_df)

In [40]:
from pyspark.ml.regression import LinearRegression

lr = LinearRegression(
    maxIter = 5,
    solver = "normal",
    labelCol = "total_amount",
    featuresCol = "feature_vector"
)

In [41]:
model = lr.fit(vtrain_df)

22/07/31 23:04:31 WARN Instrumentation: [25f1fcb6] regParam is zero, which might cause numerical instability and overfitting.


[Stage 43:>                                                         (0 + 8) / 9]

22/07/31 23:04:44 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
22/07/31 23:04:44 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.ForeignLinkerBLAS


                                                                                

22/07/31 23:04:54 WARN InstanceBuilder$NativeLAPACK: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK


                                                                                

In [44]:
vtest_df = fitted_transformer.transform(test_df)

In [45]:
predictions = model.transform(vtest_df)

In [46]:
predictions.cache()

DataFrame[passenger_count: double, pickup_location_id: int, dropoff_location_id: int, trip_distance: double, pickup_time: int, day_of_week: string, total_amount: double, pickup_location_id_idx: double, pickup_location_id_onehot: vector, dropoff_location_id_idx: double, dropoff_location_id_onehot: vector, day_of_week_idx: double, day_of_week_onehot: vector, passenger_count_vector: vector, passenger_count_scaled: vector, trip_distance_vector: vector, trip_distance_scaled: vector, pickup_time_vector: vector, pickup_time_scaled: vector, feature_vector: vector, prediction: double]

In [51]:
predictions.select(["trip_distance","day_of_week","total_amount","prediction"]).show()

+-------------+-----------+------------+------------------+
|trip_distance|day_of_week|total_amount|        prediction|
+-------------+-----------+------------+------------------+
|          0.7|   Saturday|       12.35|12.689527507405728|
|          1.5|     Friday|        11.8| 14.50426259211963|
|          2.9|     Sunday|        15.8| 16.30302554462105|
|          2.1|   Saturday|       15.35|  16.9425655126106|
|          1.7|   Saturday|        13.3|14.493378049111925|
|          0.4|   Thursday|         4.8|  9.53814848314505|
|          1.4|     Friday|         8.3|11.998519125627354|
|          2.2|    Tuesday|        13.3|13.397050101542845|
|          3.8|    Tuesday|       27.25|17.712843503279146|
|          1.7|    Tuesday|        11.8|12.840609883249272|
|          4.5|  Wednesday|       27.65|19.465887664817274|
|         13.4|     Monday|       66.35|62.293303116515744|
|         16.2|     Monday|       82.37| 68.89146993128816|
|          7.2|  Wednesday|       32.75|

In [52]:
model.summary.rootMeanSquaredError

5.676293357333079

In [53]:
model.summary.r2 # 정확도

0.8064767086817035