In [0]:
# Check Spark version and cluster info
print(f"Spark version: {spark.version}")
print(f"Cluster configured successfully!")

# Check available datasets
dbutils.fs.ls("/databricks-datasets/")

Spark version: 4.0.0
Cluster configured successfully!


[FileInfo(path='dbfs:/databricks-datasets/COVID/', name='COVID/', size=0, modificationTime=1762778067866),
 FileInfo(path='dbfs:/databricks-datasets/README.md', name='README.md', size=976, modificationTime=1596557781000),
 FileInfo(path='dbfs:/databricks-datasets/Rdatasets/', name='Rdatasets/', size=0, modificationTime=1762778067866),
 FileInfo(path='dbfs:/databricks-datasets/SPARK_README.md', name='SPARK_README.md', size=3359, modificationTime=1596557823000),
 FileInfo(path='dbfs:/databricks-datasets/adult/', name='adult/', size=0, modificationTime=1762778067866),
 FileInfo(path='dbfs:/databricks-datasets/airlines/', name='airlines/', size=0, modificationTime=1762778067866),
 FileInfo(path='dbfs:/databricks-datasets/amazon/', name='amazon/', size=0, modificationTime=1762778067866),
 FileInfo(path='dbfs:/databricks-datasets/asa/', name='asa/', size=0, modificationTime=1762778067866),
 FileInfo(path='dbfs:/databricks-datasets/atlas_higgs/', name='atlas_higgs/', size=0, modificationTime=

In [0]:
import time
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, avg, sum, max, min, to_date, desc, when

# 1. Data Processing Pipeline

## Load data using PySpark (Parquet, or any format)

In [0]:
display(dbutils.fs.ls("/databricks-datasets/nyctaxi/tripdata/yellow/"))

path,name,size,modificationTime
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-01.csv.gz,yellow_tripdata_2009-01.csv.gz,504262564,1596568279000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-02.csv.gz,yellow_tripdata_2009-02.csv.gz,480034681,1596568279000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-03.csv.gz,yellow_tripdata_2009-03.csv.gz,521102719,1596568279000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-04.csv.gz,yellow_tripdata_2009-04.csv.gz,515435466,1596568279000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-05.csv.gz,yellow_tripdata_2009-05.csv.gz,531133739,1596568279000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-06.csv.gz,yellow_tripdata_2009-06.csv.gz,508802995,1596568313000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-07.csv.gz,yellow_tripdata_2009-07.csv.gz,487731497,1596568318000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-08.csv.gz,yellow_tripdata_2009-08.csv.gz,490825210,1596568318000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-09.csv.gz,yellow_tripdata_2009-09.csv.gz,503121179,1596568318000
dbfs:/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2009-10.csv.gz,yellow_tripdata_2009-10.csv.gz,567109604,1596568319000


In [0]:
df = spark.read.csv(
    "/databricks-datasets/nyctaxi/tripdata/yellow/yellow_tripdata_2019-12.csv.gz",
    header=True,
    inferSchema=True
)

# Verify it works
print(f"Rows: {df.count()}")
df.show(5)

Rows: 6896317
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|       1| 2019-12-01 00:26:58|  2019-12-01 00:41:45|              1|          4.2|         1|                 N|         142|         116|           2|       14.5|  3.0|    0.5|       0.0|         0.0|       

In [0]:
print("Schema:")
df.printSchema()

Schema:
root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)



In [0]:
df.select("payment_type").distinct().show()
df.select("VendorID").distinct().show()

+------------+
|payment_type|
+------------+
|        NULL|
|           1|
|           4|
|           3|
|           2|
|           5|
+------------+

+--------+
|VendorID|
+--------+
|    NULL|
|       1|
|       2|
+--------+



## Apply transformations: 

### 2+ filter operations

In [0]:
# 1. Filter trips with distance > 10 miles
df_long_trips = df.filter(col("trip_distance") > 10)
df_long_trips.show(5)


+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|       2| 2019-12-01 00:43:02|  2019-12-01 01:11:18|              1|        13.07|         1|                 N|          41|          51|           2|       38.5|  0.5|    0.5|       0.0|         0.0|                  0.3

In [0]:
# 2. Filter trips with passengers between 3 and 6
df_3p_6p_trips = df.where((col("passenger_count") > 2) & (col("passenger_count") < 7))
df_3p_6p_trips.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|       1| 2019-12-01 00:36:16|  2019-12-01 00:53:42|              3|          5.5|         1|                 N|          79|         226|           1|       18.0|  3.0|    0.5|      4.35|         0.0|                  0.3

In [0]:
# 3. Filter trips with payment_type = 5
df_payment_type_5 = df.filter(col("payment_type") == 5)
df_payment_type_5.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+
|       1| 2019-12-06 12:50:55|  2019-12-06 13:05:31|              1|          1.3|         1|                 N|         246|         161|           5|       10.0|  2.5|    0.5|       2.0|         0.0|                  0.3

### 1+ join operation (if using multiple datasets) OR complex aggregation & 1+ groupBy with aggregations

In [0]:
# 1. Average fare and tip per day and payment type
df_daily_payment = (
    df.groupBy(to_date(col("tpep_pickup_datetime")).alias("date"), col("payment_type"))
      .agg(
          avg("fare_amount").alias("avg_fare"),
          avg("tip_amount").alias("avg_tip"),
          sum("tip_amount").alias("total_tip"),
          count("*").alias("num_trips")
      )
      .orderBy("date", "payment_type")
)
df_daily_payment.show(10)

+----------+------------+------------------+--------------------+------------------+---------+
|      date|payment_type|          avg_fare|             avg_tip|         total_tip|num_trips|
+----------+------------+------------------+--------------------+------------------+---------+
|2008-12-31|           1|              24.3|               7.586|             37.93|        5|
|2008-12-31|           2|             14.56|                 0.0|               0.0|       25|
|2009-01-01|           1|           45.6875|              12.705|            101.64|        8|
|2009-01-01|           2|21.928571428571427|                 0.0|               0.0|       21|
|2019-11-30|           1|            13.055|  3.1666999999999996|316.66999999999996|      100|
|2019-11-30|           2|12.658333333333333|                 0.0|               0.0|       60|
|2019-12-01|        NULL|33.344640957446835| 0.16043882978723403|120.64999999999999|      752|
|2019-12-01|           1|14.524030870220837|   3.2

In [0]:
# 2. Trip distance and total amount statistics per vendor
df_vendor_stats = (
    df.groupBy("VendorID")
      .agg(
          min("trip_distance").alias("min_distance"),
          max("trip_distance").alias("max_distance"),
          avg("trip_distance").alias("avg_distance"),
          avg("total_amount").alias("avg_total_amount"),
          sum("total_amount").alias("sum_total_amount")
      )
)
df_vendor_stats.show()

+--------+------------+------------+-----------------+------------------+-------------------+
|VendorID|min_distance|max_distance|     avg_distance|  avg_total_amount|   sum_total_amount|
+--------+------------+------------+-----------------+------------------+-------------------+
|    NULL|   -37264.53|       40.75|6.552147673370255| 38.10385452193327| 1943982.4499999916|
|       1|         0.0|       300.8|2.806885461733137|19.289672665686457|4.404221920964065E7|
|       2|      -23.01|    19130.18|3.016746226570609|19.611339948811544| 8.94688351464533E7|
+--------+------------+------------+-----------------+------------------+-------------------+



### Column transformations using withColumn

In [0]:
# Add new column
df = df.withColumn("half_total_amount", col("total_amount") / 2)

print("Schema:")
df.printSchema()

Schema:
root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- half_total_amount: double (nullable = true)



## 2+ SQL queries

In [0]:
# Create temporary view
df.createOrReplaceTempView("nyc_taxi")

 1. Top 5 days with most trips

In [0]:
%sql
SELECT DATE(tpep_pickup_datetime) AS pickup_date,
       COUNT(*) AS num_trips
FROM nyc_taxi
GROUP BY DATE(tpep_pickup_datetime)
ORDER BY num_trips DESC
LIMIT 5

pickup_date,num_trips
2019-12-19,286929
2019-12-12,278481
2019-12-13,278114
2019-12-06,274919
2019-12-20,274892


2. Total tips and trips per PULocationID

In [0]:
%sql
SELECT PULocationID,
       SUM(tip_amount) AS total_tips,
       COUNT(*) AS num_trips
FROM nyc_taxi
GROUP BY PULocationID
ORDER BY total_tips DESC

PULocationID,total_tips,num_trips
132,1287255.1300006406,229808
138,902563.5899999272,163829
161,572637.0300000147,265351
237,572564.7199999827,323953
162,544770.4500000151,246367
186,543600.3100000146,253599
236,528229.1900000226,287876
230,515870.1100000116,232984
142,440299.0500000266,223620
170,435853.3900000138,203205


3. Trips with more than 2 passengers and fare > $50

In [0]:
%sql
SELECT *
FROM nyc_taxi
WHERE passenger_count > 2 
  AND fare_amount > 50

VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,RatecodeID,store_and_fwd_flag,PULocationID,DOLocationID,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount,congestion_surcharge,half_total_amount
2,2019-12-01T00:37:17.000Z,2019-12-01T01:07:39.000Z,5,19.98,2,N,132,238,1,52.0,0.0,0.5,14.73,6.12,0.3,73.65,0.0,36.825
1,2019-12-01T00:03:14.000Z,2019-12-01T00:36:32.000Z,4,18.3,2,N,132,90,1,52.0,2.5,0.5,12.28,6.12,0.3,73.7,2.5,36.85
1,2019-12-01T00:03:31.000Z,2019-12-01T00:32:20.000Z,4,21.0,2,N,132,164,2,52.0,2.5,0.5,0.0,6.12,0.3,61.42,2.5,30.71
2,2019-12-01T00:55:33.000Z,2019-12-01T01:09:12.000Z,3,3.68,2,N,145,48,1,52.0,0.0,0.5,18.43,6.12,0.3,79.85,2.5,39.925
2,2019-12-01T00:43:35.000Z,2019-12-01T00:43:43.000Z,3,0.0,5,N,265,265,1,70.0,0.0,0.5,5.0,0.0,0.3,75.8,0.0,37.9
2,2019-12-01T00:21:56.000Z,2019-12-01T00:55:12.000Z,3,19.5,2,N,132,249,1,52.0,0.0,0.5,5.5,0.0,0.3,60.8,2.5,30.4
2,2019-12-01T00:53:05.000Z,2019-12-01T01:27:06.000Z,6,23.09,1,N,132,18,2,62.5,0.5,0.5,0.0,6.12,0.3,69.92,0.0,34.96
2,2019-12-01T00:16:50.000Z,2019-12-01T00:44:07.000Z,5,16.97,3,N,100,1,1,63.5,0.5,0.0,5.0,10.5,0.3,79.8,0.0,39.9
2,2019-12-01T00:39:23.000Z,2019-12-01T01:07:40.000Z,6,18.87,1,N,132,65,1,51.0,0.5,0.5,13.08,0.0,0.3,65.38,0.0,32.69
2,2019-12-01T00:48:45.000Z,2019-12-01T00:51:57.000Z,7,1.31,5,N,265,265,2,70.0,0.0,0.0,0.0,0.0,0.3,70.3,0.0,35.15


## Optimize your queries

In [0]:
# Filter out invalid or unnecessary rows before aggregations
df_filtered = df.filter(
    (col("passenger_count") > 0) &
    (col("fare_amount") > 0)
)

df_filtered.count()

6696315

In [0]:
df.count()

6896317

## Write results to a destination (Parquet files, database, etc.)

In [0]:
df.write.mode("overwrite").parquet("/tmp/yellow_tax_modified.parquet")

[0;31m---------------------------------------------------------------------------[0m
[0;31mUnsupportedOperationException[0m             Traceback (most recent call last)
File [0;32m<command-6728280540013920>, line 1[0m
[0;32m----> 1[0m df[38;5;241m.[39mwrite[38;5;241m.[39mmode([38;5;124m"[39m[38;5;124moverwrite[39m[38;5;124m"[39m)[38;5;241m.[39mparquet([38;5;124m"[39m[38;5;124m/tmp/yellow_tax_modified.parquet[39m[38;5;124m"[39m)

File [0;32m/databricks/python/lib/python3.12/site-packages/pyspark/sql/connect/readwriter.py:779[0m, in [0;36mDataFrameWriter.parquet[0;34m(self, path, mode, partitionBy, compression)[0m
[1;32m    777[0m     [38;5;28mself[39m[38;5;241m.[39mpartitionBy(partitionBy)
[1;32m    778[0m [38;5;28mself[39m[38;5;241m.[39m_set_opts(compression[38;5;241m=[39mcompression)
[0;32m--> 779[0m [38;5;28mself[39m[38;5;241m.[39mformat([38;5;124m"[39m[38;5;124mparquet[39m[38;5;124m"[39m)[38;5;241m.[39msave(path)

File [0;

# 2. Performance Analysis 

## Use .explain() to show the physical execution plan

In [0]:
print("\n4. GROUP BY - Daily Payment Stats:")
df_daily_payment.explain()


4. GROUP BY - Daily Payment Stats:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- == Initial Plan ==
   ColumnarToRow
   +- PhotonResultStage
      +- PhotonSort [date#11531 ASC NULLS FIRST, payment_type#11100 ASC NULLS FIRST]
         +- PhotonShuffleExchangeSource
            +- PhotonShuffleMapStage ENSURE_REQUIREMENTS, [id=#9869]
               +- PhotonShuffleExchangeSink rangepartitioning(date#11531 ASC NULLS FIRST, payment_type#11100 ASC NULLS FIRST, 1024)
                  +- PhotonGroupingAgg(keys=[_groupingexpression#12072, payment_type#11100], functions=[finalmerge_avg(merge sum#12077, count#12078L) AS avg(fare_amount)#12069, finalmerge_sum(merge sum#12080) AS sum(tip_amount)#12073, finalmerge_count(merge count#12082L) AS count(tip_amount)#12074L, finalmerge_count(merge count#12084L) AS count(1)#12068L])
                     +- PhotonShuffleExchangeSource
                        +- PhotonShuffleMapStage ENSURE_REQUIREMENTS, [id=#9863]
                           

## How you optimized the pipeline (e.g., filter ordering, column pruning)

In [0]:
df_optimized = df.filter(
    (col("fare_amount") > 0) & 
    (col("tip_amount") >= 0) &
    (col("payment_type").isin([1, 2, 3, 4, 5])) 
)

df_daily_payment_optimized = (
    df_optimized.groupBy(to_date(col("tpep_pickup_datetime")).alias("date"), 
                        col("payment_type"))
      .agg(
          avg("fare_amount").alias("avg_fare"),
          avg("tip_amount").alias("avg_tip"),
          sum("tip_amount").alias("total_tip"),
          count("*").alias("num_trips")
      )
      .orderBy("date", "payment_type")
)

print("OPTIMIZED VERSION:")
df_daily_payment_optimized.explain()

OPTIMIZED VERSION:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- == Initial Plan ==
   ColumnarToRow
   +- PhotonResultStage
      +- PhotonSort [date#12126 ASC NULLS FIRST, payment_type#11100 ASC NULLS FIRST]
         +- PhotonShuffleExchangeSource
            +- PhotonShuffleMapStage ENSURE_REQUIREMENTS, [id=#10044]
               +- PhotonShuffleExchangeSink rangepartitioning(date#12126 ASC NULLS FIRST, payment_type#11100 ASC NULLS FIRST, 1024)
                  +- PhotonGroupingAgg(keys=[_groupingexpression#12135, payment_type#11100], functions=[finalmerge_avg(merge sum#12140, count#12141L) AS avg(fare_amount)#12132, finalmerge_sum(merge sum#12143) AS sum(tip_amount)#12136, finalmerge_count(merge count#12145L) AS count(tip_amount)#12137L, finalmerge_count(merge count#12147L) AS count(1)#12131L])
                     +- PhotonShuffleExchangeSource
                        +- PhotonShuffleMapStage ENSURE_REQUIREMENTS, [id=#10038]
                           +- PhotonShuffl

## Cache

In [0]:
print("=== WITHOUT CACHING ===")

# First action - full computation
start_time = time.time()
df_long_trips = df.filter(col("trip_distance") > 10)
count1 = df_long_trips.count()
first_run = time.time() - start_time
print(f"First count: {count1} - Time: {first_run:.2f}s")

# Second action - recomputes everything
start_time = time.time()
count2 = df_long_trips.count()  
second_run = time.time() - start_time
print(f"Second count: {count2} - Time: {second_run:.2f}s")

print(f"Without caching - Second run was {first_run/second_run:.1f}x faster")


print("\n\n=== WITH CACHING ===")

# Cache the DataFrame
df_cached = df.filter(col("trip_distance") > 10).cache()

# First action - computation + caching
start_time = time.time()
count1 = df_cached.count()
first_run_cached = time.time() - start_time
print(f"First count (with cache): {count1} - Time: {first_run_cached:.2f}s")

# Second action - uses cache
start_time = time.time()
count2 = df_cached.count()
second_run_cached = time.time() - start_time
print(f"Second count (with cache): {count2} - Time: {second_run_cached:.2f}s")

print(f"With caching - Second run was {first_run_cached/second_run_cached:.1f}x faster")
print(f"Caching made subsequent runs {second_run/second_run_cached:.1f}x faster")

=== WITHOUT CACHING ===
First count: 420053 - Time: 7.79s
Second count: 420053 - Time: 7.77s
Without caching - Second run was 1.0x faster


=== WITH CACHING ===


[0;31m---------------------------------------------------------------------------[0m
[0;31mAnalysisException[0m                         Traceback (most recent call last)
File [0;32m<command-7544636203783356>, line 22[0m
[1;32m     19[0m [38;5;28mprint[39m([38;5;124m"[39m[38;5;130;01m\n[39;00m[38;5;130;01m\n[39;00m[38;5;124m=== WITH CACHING ===[39m[38;5;124m"[39m)
[1;32m     21[0m [38;5;66;03m# Cache the DataFrame[39;00m
[0;32m---> 22[0m df_cached [38;5;241m=[39m df[38;5;241m.[39mfilter(col([38;5;124m"[39m[38;5;124mtrip_distance[39m[38;5;124m"[39m) [38;5;241m>[39m [38;5;241m10[39m)[38;5;241m.[39mcache()
[1;32m     24[0m [38;5;66;03m# First action - computation + caching[39;00m
[1;32m     25[0m start_time [38;5;241m=[39m time[38;5;241m.[39mtime()

File [0;32m/databricks/python/lib/python3.12/site-packages/pyspark/sql/connect/dataframe.py:2093[0m, in [0;36mDataFrame.cache[0;34m(self)[0m
[1;32m   2092[0m [38;5;28;01mdef[39;00m [

# 3. Actions vs Transformations

## Difference between transformations (lazy) and actions (eager)

In [0]:
# 1. Transformations
print("Transformations (lazy):")
df_filtered = df.filter(col("trip_distance") > 5)
df_with_column = df_filtered.withColumn("double_distance", col("trip_distance") * 2)
df_grouped = df_with_column.groupBy("VendorID").agg(avg("trip_distance"))

print("Transformations created but NOT executed yet")
print("No work has been done by Spark!")
print("I've just built an execution plan")

# 2. Actions
print("\nActions (eager):")
print("About to trigger execution with count()")

start_time = time.time()
count_result = df_grouped.count()
execution_time = time.time() - start_time

print(f"Action completed! Count: {count_result}")
print(f"Execution time: {execution_time:.2f}s")
print("Now Spark executed the entire pipeline I built")

Transformations (lazy):
Transformations created but NOT executed yet
No work has been done by Spark!
I've just built an execution plan

Actions (eager):
About to trigger execution with count()
Action completed! Count: 3
Execution time: 8.13s
Now Spark executed the entire pipeline I built


# Machine Learning

In [0]:
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.regression import LinearRegression
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.functions import *

In [0]:
print("Preprocessing")

df_ml = df.filter(
        (col("trip_distance") > 0) &
        (col("fare_amount") > 0) &
        (col("passenger_count") > 0) &
        (col("tpep_pickup_datetime").isNotNull()) &
        (col("tpep_dropoff_datetime").isNotNull())
    ).withColumn(
        "trip_duration_minutes", 
        (col("tpep_dropoff_datetime").cast("long") - col("tpep_pickup_datetime").cast("long")) / 60
    ).filter(
        col("trip_duration_minutes") > 0
    ).select(
        "fare_amount",           
        "trip_distance",         
        "trip_duration_minutes", 
        "passenger_count",       
        "PULocationID"          
    ).limit(10000)

print(f"Dataset size for ML: {df_ml.count()} rows")
df_ml.show(5)


Preprocessing
Dataset size for ML: 10000 rows
+-----------+-------------+---------------------+---------------+------------+
|fare_amount|trip_distance|trip_duration_minutes|passenger_count|PULocationID|
+-----------+-------------+---------------------+---------------+------------+
|       14.5|          4.2|   14.783333333333333|              1|         142|
|       28.5|          9.4|   21.266666666666666|              2|         138|
|        9.0|          1.6|   11.083333333333334|              2|         161|
|        6.5|          1.0|    9.766666666666667|              2|         161|
|        5.5|          0.9|                  4.5|              1|         148|
+-----------+-------------+---------------------+---------------+------------+
only showing top 5 rows


In [0]:
print("Feature Engineering")
location_indexer = StringIndexer(
    inputCol="PULocationID", 
    outputCol="location_index",
    handleInvalid="keep"
)

feature_assembler = VectorAssembler(
    inputCols=["trip_distance", "trip_duration_minutes", "passenger_count", "location_index"],
    outputCol="features"
)

Feature Engineering


In [0]:
print("Train-test split")
train_data, test_data = df_ml.randomSplit([0.8, 0.2], seed=2025)
print(f"Training data: {train_data.count()} rows")
print(f"Testing data: {test_data.count()} rows")

Train-test split
Training data: 7987 rows
Testing data: 2013 rows


In [0]:
print("Pipeline")

lr = LinearRegression(
    featuresCol="features",
    labelCol="fare_amount",
    predictionCol="prediction"
)

pipeline = Pipeline(stages=[location_indexer, feature_assembler, lr])

Pipeline


In [0]:
print("Training the model")
model = pipeline.fit(train_data)

print("\nMaking predictions")
predictions = model.transform(test_data)

print("\nPrediction results")
predictions.select("fare_amount", "prediction", "features").show(10)

Training the model

Making predictions

Prediction results
+-----------+------------------+--------------------+
|fare_amount|        prediction|            features|
+-----------+------------------+--------------------+
|        2.5| 3.992500482960982|[0.1,0.4666666666...|
|        2.5| 4.365149776991796|[0.15,0.566666666...|
|        2.5| 4.611382939456696| [0.16,0.8,1.0,43.0]|
|        2.5|  9.22153515092417|[1.9,0.1833333333...|
|        2.5|29.653544811505967|[9.4,0.8166666666...|
|        3.0| 4.277449081602778|[0.08,1.533333333...|
|        3.0|  4.48868927524401|[0.16,0.916666666...|
|        3.0|  4.16997480403122|[0.17,0.933333333...|
|        3.0| 4.403027026392697|[0.2,1.3666666666...|
|        3.0| 4.647223294582929|[0.21,0.933333333...|
+-----------+------------------+--------------------+
only showing top 10 rows


In [0]:
print("\nEvaluation")
evaluator = RegressionEvaluator(
    labelCol="fare_amount",
    predictionCol="prediction", 
    metricName="rmse"
)

rmse = evaluator.evaluate(predictions)
r2_evaluator = RegressionEvaluator(labelCol="fare_amount", predictionCol="prediction", metricName="r2")
r2 = r2_evaluator.evaluate(predictions)

print(f"Root Mean Squared Error (RMSE): {rmse:.2f}")
print(f"R² Score: {r2:.2f}")


Evaluation
Root Mean Squared Error (RMSE): 2.74
R² Score: 0.92
