In [1]:
from pyspark.sql import SparkSession
from pyspark import SparkContext
from pyspark.sql import functions as F
from pyspark.sql.functions import hour
import matplotlib.pyplot as plt
import requests
from bs4 import BeautifulSoup
import datetime
import time
from pyspark.sql.types import *

## 환경준비
- 디멘션 테이블을 팩트 테이블과 join하여 맨허튼 시의 특정 기간동안 count를 세보기

In [2]:
if SparkContext._active_spark_context:
    SparkContext._active_spark_context.stop()

spark = SparkSession.builder \
        .appName("TLC BroadCast + DPP Demo") \
        .config("spark.driver.bindAddress", "127.0.0.1") \
        .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/20 19:15:25 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# 팩트 데이터 로딩

# schema = StructType([
#     StructField("VendorID",               IntegerType(),  True),
#     StructField("tpep_pickup_datetime",   TimestampType(),True),
#     StructField("tpep_dropoff_datetime",  TimestampType(),True),
#     StructField("passenger_count",        IntegerType(),  True),
#     StructField("trip_distance",          DoubleType(),   True),
#     StructField("RatecodeID",             IntegerType(),  True),
#     StructField("store_and_fwd_flag",     StringType(),   True),
#     StructField("PULocationID",           LongType(),     True),   # ← INT32·INT64 섞임 대비
#     StructField("DOLocationID",           LongType(),     True),
#     StructField("payment_type",           IntegerType(),  True),
#     StructField("fare_amount",            DoubleType(),   True),
#     StructField("extra",                  DoubleType(),   True),
#     StructField("mta_tax",                DoubleType(),   True),
#     StructField("tip_amount",             DoubleType(),   True),
#     StructField("tolls_amount",           DoubleType(),   True),
#     StructField("improvement_surcharge",  DoubleType(),   True),
#     StructField("total_amount",           DoubleType(),   True),
#     StructField("congestion_surcharge",   DoubleType(),   True),
#     StructField("airport_fee",   DoubleType(),   True),
# ])
trips = spark.read.parquet("data/yellow_tripdata_2024-*.parquet")

                                                                                

In [4]:
# 디멘션 데이터 로딩
zone_path = "zone_data/taxi_zone_lookup.csv"
zones  = (
    spark.read.option("header", True).csv(zone_path)
         .selectExpr("LocationID", "Zone", "Borough")
         .cache()
)

In [5]:
print(trips.rdd.getNumPartitions())
print(trips.count())

11
37501349


In [6]:
trips.columns

['VendorID',
 'tpep_pickup_datetime',
 'tpep_dropoff_datetime',
 'passenger_count',
 'trip_distance',
 'RatecodeID',
 'store_and_fwd_flag',
 'PULocationID',
 'DOLocationID',
 'payment_type',
 'fare_amount',
 'extra',
 'mta_tax',
 'tip_amount',
 'tolls_amount',
 'improvement_surcharge',
 'total_amount',
 'congestion_surcharge',
 'Airport_fee']

## 최적화 OFF 버전
- Broadcast & DPP 비활성화
- 실험조건
  - 2024 택시 데이터
  - 디멘션 테이블과 Join 후에 맨허턴 Borugh에서의 택시 픽업 Zone 수 찾기
- 실험결과
  - all off : 5.04초(sort merge join)
  - AQE on / Broadcast join off / DPP off : 6.38초 (sort merge join 일어남)
  - AQE off / Broadcast join on / DPP on : 1.8초 (BroadcastHashJoin)
  - all on : 1.58초 (BroadcastHashJoin)
- 결론
  - DPP나 Broadcast join 기능을 꺼버리고 AQE만 실행해놓으면 조인전략개선 같은 기능을 사용할 수 없다. 그래서 위에 상황에서는 AQE 오버헤드만 늘어서 시간이 오히려 증가
  - AQE를 아무 상황에나 사용한다해서 나아지지 않는다. 파티션이 어느정도 많아야 적절한 파티션 병합기능이 활성화될 것이고, skew 혹은 join과정도 없다면 overhead만 늘어날 수 있다.
  - 데이터 쿼리가 아주 작고, 파티션 적으며, 플랜을 개발자가 알아야 할 때는 AQE 꺼도된다.

In [7]:
# Broadcast join, DPP 꺼버리기 
# 예상 풀스캔 후 Shuffle Hash join 발생할듯 하다.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "false")

# AQE도 꺼야 join 전략을 중간에 수정안할듯.
spark.conf.set("spark.sql.adaptive.enabled", "true")  

In [8]:
# 특정기간에서 맨허튼 픽업 건수 구하기
start = time.time()
result_no_option = (
    trips
        # .filter("tpep_pickup_datetime BETWEEN '2024-06-15' AND '2024-12-31'")
        .join(zones, trips.PULocationID == zones.LocationID, "inner")
        .filter(F.col("Borough") == "Manhattan")
        .groupBy("Zone")
        .count()
        .orderBy(F.col("count").desc())
)
print(f"[NO OPT] Rows = {result_no_option.count()}, elapsed = {time.time()-start:.2f}s")
# explain 함수는 Catalyst optimizer가 하는 실행 계획을 보여줌
result_no_option.explain(mode='formatted')

25/04/20 19:15:36 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:36 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:36 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:36 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:36 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:36 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:36 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:37 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:37 WARN RowBasedKeyValueBatch: Calling spill() on



25/04/20 19:15:38 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:38 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:38 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:38 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:38 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:38 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/04/20 19:15:39 WARN RowBasedKeyValueBatch: Calling spill() on

[NO OPT] Rows = 67, elapsed = 6.63s
== Physical Plan ==
AdaptiveSparkPlan (20)
+- Sort (19)
   +- Exchange (18)
      +- HashAggregate (17)
         +- Exchange (16)
            +- HashAggregate (15)
               +- Project (14)
                  +- SortMergeJoin Inner (13)
                     :- Sort (4)
                     :  +- Exchange (3)
                     :     +- Filter (2)
                     :        +- Scan parquet  (1)
                     +- Sort (12)
                        +- Exchange (11)
                           +- Project (10)
                              +- Filter (9)
                                 +- InMemoryTableScan (5)
                                       +- InMemoryRelation (6)
                                             +- * Project (8)
                                                +- Scan csv  (7)


(1) Scan parquet 
Output [1]: [PULocationID#7]
Batched: true
Location: InMemoryFileIndex [file:/Users/seonwoo/Documents/GitHub/data_engineering_co

                                                                                

In [10]:
result_no_option.explain(mode="extended") 

AttributeError: 'DataFrame' object has no attribute 'queryExecution'

## 최적화 ON 버전 

In [7]:
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 100*1024*1024)
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", "true")
spark.conf.set("spark.sql.adaptive.enabled", "true")  

In [9]:
# 특정기간에서 맨허튼 픽업 건수 구하기
start = time.time()
result_with_option = (
       trips
        # .filter("tpep_pickup_datetime BETWEEN '2024-06-15' AND '2024-12-31'")
        .join(zones, trips.PULocationID == zones.LocationID, "inner")
        .filter(F.col("Borough") == "Manhattan")
        .groupBy("Zone")
        .count()
        .orderBy(F.col("count").desc())
)
print(f"[WITH OPT] Rows = {result_with_option.count()}, elapsed = {time.time()-start:.2f}s")
# explain 함수는 Catalyst optimizer가 하는 실행 계획을 보여줌
result_with_option.explain(mode='formatted')

[WITH OPT] Rows = 67, elapsed = 0.84s
== Physical Plan ==
AdaptiveSparkPlan (17)
+- Sort (16)
   +- Exchange (15)
      +- HashAggregate (14)
         +- Exchange (13)
            +- HashAggregate (12)
               +- Project (11)
                  +- BroadcastHashJoin Inner BuildRight (10)
                     :- Filter (2)
                     :  +- Scan parquet  (1)
                     +- BroadcastExchange (9)
                        +- Project (8)
                           +- Filter (7)
                              +- InMemoryTableScan (3)
                                    +- InMemoryRelation (4)
                                          +- * Project (6)
                                             +- Scan csv  (5)


(1) Scan parquet 
Output [1]: [PULocationID#7]
Batched: true
Location: InMemoryFileIndex [file:/Users/seonwoo/Documents/GitHub/data_engineering_course_materials/missions/W4/M2/data/yellow_tripdata_2024-01.parquet, ... 10 entries]
PushedFilters: [IsNotNull(PULoca