In [2]:
from pyspark.sql import SparkSession 
spark= SparkSession.builder.appName("taxi_price").getOrCreate()

# 1. 데이터 준비 

In [3]:
# 파일경로 설정 
import os
trip_files= '/trip/*'
zone_file= 'taxi+_zone_lookup.csv'
directory = os.path.join(os.getcwd(), 'data')

In [4]:
# 데이터 로드 
trips_df = spark.read.csv(f'file:///{directory}/{trip_files}', inferSchema=True, header=True)
zone_df= spark.read.csv(f'file:///{directory}/{zone_file}', inferSchema=True, header=True)

                                                                                

In [5]:
trips_df.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: string (nullable = true)
 |-- tpep_dropoff_datetime: string (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)



In [6]:
trips_df.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|       2| 2021-03-01 00:22:02|  2021-03-01 00:23:22|              1|          0.0|         1|                 N|         264|         264|           2|        3.0|  0.5|    0.5|       0.0|         0.0|                  0.3

# 2. 데이터 전처리 

In [7]:
# 컬럼 선택
selected_columns = ["passenger_count", "trip_distance", "RatecodeID", "PULocationID", 
    "DOLocationID", "fare_amount", "extra", "mta_tax", "tip_amount", 
    "tolls_amount", "improvement_surcharge", "congestion_surcharge", "total_amount"
]
trips_df= trips_df.select(selected_columns)

In [8]:
trips_df= trips_df.select(selected_columns)

In [9]:
trips_df = trips_df.fillna(0)

In [10]:
label_column = "total_amount"

# 피처 

In [11]:
from pyspark.ml.feature import VectorAssembler 

#독립 변수 (피처) 컬럼 정의 

feature_columns = [
    "passenger_count", "trip_distance", "RatecodeID", "PULocationID", 
    "DOLocationID", "fare_amount", "extra", "mta_tax", "tip_amount", 
    "tolls_amount", "improvement_surcharge", "congestion_surcharge"
]

assembler = VectorAssembler(inputCols=feature_columns, outputCol= "features")
trips_df = assembler.transform(trips_df)

In [12]:
trips_df.select("features", label_column).show(5, truncate=False)

+----------------------------------------------------------+------------+
|features                                                  |total_amount|
+----------------------------------------------------------+------------+
|[1.0,0.0,1.0,264.0,264.0,3.0,0.5,0.5,0.0,0.0,0.3,0.0]     |4.3         |
|[1.0,0.0,1.0,152.0,152.0,2.5,0.5,0.5,0.0,0.0,0.3,0.0]     |3.8         |
|[1.0,0.0,1.0,152.0,152.0,3.5,0.5,0.5,0.0,0.0,0.3,0.0]     |4.8         |
|[0.0,16.5,4.0,138.0,265.0,51.0,0.5,0.5,11.65,6.12,0.3,0.0]|70.07       |
|[1.0,1.13,1.0,68.0,264.0,5.5,0.5,0.5,1.86,0.0,0.3,2.5]    |11.16       |
+----------------------------------------------------------+------------+
only showing top 5 rows



In [13]:
train_df, test_df = trips_df.randomSplit([0.8, 0.2], seed=42)

In [14]:
from pyspark.ml.regression import LinearRegression

# LinearRegression 모델 생성
lr = LinearRegression(featuresCol="features", labelCol=label_column, predictionCol="prediction")

# 모델 학습
lr_model = lr.fit(train_df)

24/12/13 17:22:46 WARN Instrumentation: [465a7462] regParam is zero, which might cause numerical instability and overfitting.
24/12/13 17:23:02 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
24/12/13 17:23:02 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
24/12/13 17:24:12 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeSystemLAPACK
24/12/13 17:24:12 WARN LAPACK: Failed to load implementation from: com.github.fommil.netlib.NativeRefLAPACK
                                                                                

In [15]:
from pyspark.ml.evaluation import RegressionEvaluator

# 테스트 데이터에서 예측 수행
predictions = lr_model.transform(test_df)

# RMSE 평가
evaluator = RegressionEvaluator(
    metricName="rmse", labelCol=label_column, predictionCol="prediction"
)
rmse = evaluator.evaluate(predictions)
print(f"Root Mean Square Error (RMSE): {rmse}")

# 모델 요약 정보 출력
print(f"Coefficients: {lr_model.coefficients}")
print(f"Intercept: {lr_model.intercept}")



Root Mean Square Error (RMSE): 0.5400656154366438
Coefficients: [0.012160979142100167,6.506610272354602e-06,-0.03454818574468199,-0.0003065896187344202,2.4886780039517477e-05,1.0000271713377529,0.22626341393251398,2.2498439879669663,1.0141778236708519,1.0551089219411338,1.5397333669115845,0.7283117087374027]
Intercept: 0.006045128786987593


                                                                                

In [16]:
predictions.select("total_amount", "prediction", "features").show(5, truncate=False)

[Stage 12:>                                                         (0 + 1) / 1]

+------------+-----------------+---------------------------------------------------+
|total_amount|prediction       |features                                           |
+------------+-----------------+---------------------------------------------------+
|3.3         |4.051646007314969|(12,[2,3,4,5,7,10],[1.0,24.0,24.0,2.5,0.5,0.3])    |
|4.3         |4.273120472989669|[0.0,0.0,1.0,41.0,41.0,2.5,1.0,0.5,0.0,0.0,0.3,0.0]|
|3.3         |4.046575356218461|(12,[2,3,4,5,7,10],[1.0,42.0,42.0,2.5,0.5,0.3])    |
|4.3         |4.272838770150974|[0.0,0.0,1.0,42.0,42.0,2.5,1.0,0.5,0.0,0.0,0.3,0.0]|
|6.8         |6.657586359793597|[0.0,0.0,1.0,48.0,48.0,2.5,3.5,0.5,0.0,0.0,0.3,2.5]|
+------------+-----------------+---------------------------------------------------+
only showing top 5 rows



                                                                                