In [1]:

from pyspark.sql import SparkSession
import pyspark.sql.types as st
import pyspark.sql.functions as sf




In [2]:

# Start (or reuse, via getOrCreate) a local Spark session using all available cores.
session_builder = SparkSession.builder.appName("Pyspark tutorial").master("local[*]")
spark = session_builder.getOrCreate()



In [3]:
# Small in-memory sample of (id, name, age) rows.
data = [
    (1, "Alice", 20),
    (2, "Bob", 25),
    (3, "Charlie", 30),
    (4, "David", 35),
]
columns = ["id", "name", "age"]

# Only column names are given, so Spark infers the types (Python ints become long).
df = spark.createDataFrame(data, schema=columns)

In [4]:
# Print the rows; n=20 and truncate=True are show()'s defaults.
df.show(n=20, truncate=True)

+---+-------+---+
| id|   name|age|
+---+-------+---+
|  1|  Alice| 20|
|  2|    Bob| 25|
|  3|Charlie| 30|
|  4|  David| 35|
+---+-------+---+



In [5]:
# Show the schema Spark inferred from the Python tuples.
df.printSchema()

root
 |-- id: long (nullable = true)
 |-- name: string (nullable = true)
 |-- age: long (nullable = true)



In [6]:
# Explicit schema: declare IntegerType for id/age instead of letting Spark infer long.
schema = (
    st.StructType()
    .add("id", st.IntegerType(), True)
    .add("name", st.StringType(), True)
    .add("age", st.IntegerType(), True)
)

df = spark.createDataFrame(data, schema=schema)

In [7]:
# Confirm the explicit schema was applied: id and age are now integer, not long.
df.printSchema()

root
 |-- id: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- age: integer (nullable = true)



In [8]:
# Shortcut reader: header=True takes the column names from the first line.
# inferSchema is not set, so every column (including count) is read as a string.
csv_path = "data/2010-summary.csv"
df_read_csv = spark.read.csv(csv_path, header=True)
df_read_csv.show(n=10)

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    United States|            Romania|    1|
|    United States|            Ireland|  264|
|    United States|              India|   69|
|            Egypt|      United States|   24|
|Equatorial Guinea|      United States|    1|
|    United States|          Singapore|   25|
|    United States|            Grenada|   54|
|       Costa Rica|      United States|  477|
|          Senegal|      United States|   29|
|    United States|   Marshall Islands|   44|
+-----------------+-------------------+-----+
only showing top 10 rows



In [16]:
# Generic reader: the same CSV, configured through format() and options().
df_read_csv = (
    spark.read
    .format("csv")
    .options(header="true", delimiter=",")
    .load("data/2010-summary.csv")
)

df_read_csv.show(n=2)

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    United States|            Romania|    1|
|    United States|            Ireland|  264|
+-----------------+-------------------+-----+
only showing top 2 rows



In [17]:
# Parquet files store their own schema, so no header/type options are needed.
parquet_path = "data/yellow_tripdata_2024-09.parquet"
df_read_parquet = spark.read.format("parquet").load(parquet_path)
df_read_parquet.show(n=20)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:05:51|  2024-09-01 00:45:03|              1|          9.8|         1|                 N|         138|          48|           1|       47.8|10.25|    0.5|      13.

In [18]:
# Schema read from the Parquet file itself (pickup/dropoff come back as timestamp_ntz).
df_read_parquet.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)



#### Reading in data

In [19]:
# count() is an action, so this actually runs a Spark job over the file.
n_trips = df_read_parquet.count()
n_trips

3633030

#### Writing Parquet

In [20]:
# Write a copy of the trip data as Parquet.
# mode("overwrite") keeps the cell re-runnable: the default save mode
# ("errorifexists") raises once the target directory already exists.
df_read_parquet.write.mode("overwrite").parquet("data/yellow_tripdata_2024-09_copy.parquet")

### Transformation

In [46]:
# Load the trip data into the DataFrame used by the transformation examples.
taxi_df = spark.read.parquet("data/yellow_tripdata_2024-09.parquet")

In [12]:
# Peek at the first two trips.
taxi_df.show(n=2)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:05:51|  2024-09-01 00:45:03|              1|          9.8|         1|                 N|         138|          48|           1|       47.8|10.25|    0.5|      13.

In [14]:
# Per-column summary statistics; the output covers the numeric and string columns
# (the timestamp columns do not appear in the result).
summary_df = taxi_df.describe()
summary_df.show(n=20)

+-------+------------------+------------------+-----------------+------------------+------------------+-----------------+-----------------+------------------+------------------+------------------+-------------------+-----------------+------------------+---------------------+------------------+--------------------+-------------------+
|summary|          VendorID|   passenger_count|    trip_distance|        RatecodeID|store_and_fwd_flag|     PULocationID|     DOLocationID|      payment_type|       fare_amount|             extra|            mta_tax|       tip_amount|      tolls_amount|improvement_surcharge|      total_amount|congestion_surcharge|        Airport_fee|
+-------+------------------+------------------+-----------------+------------------+------------------+-----------------+-----------------+------------------+------------------+------------------+-------------------+-----------------+------------------+---------------------+------------------+--------------------+-------------

In [17]:
# Select columns by their name strings (select also accepts a single list).
taxi_df.select(["passenger_count", "VendorID"]).show(n=5)

+---------------+--------+
|passenger_count|VendorID|
+---------------+--------+
|              1|       1|
|              1|       1|
|              2|       2|
|              1|       2|
|              2|       2|
+---------------+--------+
only showing top 5 rows



In [18]:
# Select columns through DataFrame attribute access.
attr_columns = [taxi_df.passenger_count, taxi_df.VendorID]
taxi_df.select(*attr_columns).show(n=5)

+---------------+--------+
|passenger_count|VendorID|
+---------------+--------+
|              1|       1|
|              1|       1|
|              2|       2|
|              1|       2|
|              2|       2|
+---------------+--------+
only showing top 5 rows



In [27]:
# Select columns through sf.col() expressions.
col_exprs = [sf.col(name) for name in ("VendorID", "passenger_count")]
taxi_df.select(*col_exprs).show(n=5)


+--------+---------------+
|VendorID|passenger_count|
+--------+---------------+
|       1|              1|
|       1|              1|
|       2|              2|
|       2|              1|
|       2|              2|
+--------+---------------+
only showing top 5 rows



##### Filtering

In [28]:
# Keep only trips with more than two passengers.
more_than_two_passengers = taxi_df.passenger_count > 2
taxi_df.where(more_than_two_passengers).show(n=5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:08:28|  2024-09-01 00:39:06|              4|          9.8|         1|                 N|          93|         161|           1|       44.3|  3.5|    0.5|      9.8

In [32]:
# Combine conditions with & (each side must be its own Column expression).
many_passengers = sf.col("passenger_count") > 2
expensive_trip = sf.col("total_amount") > 100
taxi_df.where(many_passengers & expensive_trip).show(n=5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       2| 2024-09-01 00:27:46|  2024-09-01 01:20:29|              3|        27.07|         1|                 N|         132|         181|           1|      105.9|  1.0|    0.5|     21.6

In [42]:
# Negate a condition with ~. Rows whose flag is NULL are dropped too:
# NOT (NULL = 'N') evaluates to NULL, and where() keeps only rows that are true.
flag_is_n = sf.col("store_and_fwd_flag") == 'N'
taxi_df.where(~flag_is_n).show(n=5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:23:40|  2024-09-01 00:43:46|              1|          3.9|         1|                 Y|         142|          42|           1|       20.5|  3.5|    0.5|      6.3

##### withColumn

In [47]:
# Add a column that sums total_amount with the two surcharge columns.
# All columns are nullable per the schema, and `+` yields NULL if any operand is
# NULL, so missing surcharges are treated as 0 instead of nulling the whole sum.
# NOTE(review): confirm against the TLC data dictionary whether total_amount
# already includes these surcharges; if it does, this column double-counts them.
taxi_df = (
    taxi_df
    .withColumn(
        "total_amount_with_all_tax",
        sf.col("total_amount")
        + sf.coalesce(sf.col("congestion_surcharge"), sf.lit(0.0))
        + sf.coalesce(sf.col("Airport_fee"), sf.lit(0.0)),
    )
)

(
    taxi_df
    .select(
        sf.col("total_amount_with_all_tax"),
        sf.col("congestion_surcharge"),
        sf.col("Airport_fee"),
        sf.col("total_amount"),
    )
    .show(5)
)

+-------------------------+--------------------+-----------+------------+
|total_amount_with_all_tax|congestion_surcharge|Airport_fee|total_amount|
+-------------------------+--------------------+-----------+------------+
|                    84.04|                 2.5|       1.75|       79.79|
|                     15.6|                 2.5|        0.0|        13.1|
|                     16.0|                 0.0|        0.0|        16.0|
|                    31.75|                 0.0|        0.0|       31.75|
|                     28.9|                 2.5|        0.0|        26.4|
+-------------------------+--------------------+-----------+------------+
only showing top 5 rows



##### withColumnRenamed

In [48]:
# Rename the derived column (withColumnRenamed is a no-op if the old name is absent).
old_name, new_name = "total_amount_with_all_tax", "total_amount_with_taxes"
taxi_df = taxi_df.withColumnRenamed(old_name, new_name)

In [49]:
# The renamed column now appears as the last column.
taxi_df.show(n=3)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-----------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|total_amount_with_taxes|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-----------------------+
|       1| 2024-09-01 00:05:51|  2024-09-01 00:45:03|              1|          9.8|         1|                 N|  

##### Drop

In [50]:
# Drop the derived column again, restoring the original set of columns.
taxi_df = taxi_df.drop("total_amount_with_taxes")
taxi_df.show(n=3)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:05:51|  2024-09-01 00:45:03|              1|          9.8|         1|                 N|         138|          48|           1|       47.8|10.25|    0.5|      13.

##### Group By (order by)

In [55]:
# Trips per pickup location, busiest locations first.
pickups_per_location = taxi_df.groupBy("PULocationID").count()
pickups_per_location.orderBy(sf.desc("count")).show(n=10)

+------------+------+
|PULocationID| count|
+------------+------+
|         132|188147|
|         237|166033|
|         161|153988|
|         236|149284|
|         186|120393|
|         162|117163|
|         230|113531|
|         138|111112|
|         142|107537|
|          68|104585|
+------------+------+
only showing top 10 rows



In [60]:
# Sum and average of total_amount per payment type, largest total first.
# Sorting happens before rounding, so the order is based on the unrounded values.
# Large sums still print in scientific notation; sf.format_number would fix that
# but turns the columns into strings.
amount_by_payment_type = (
    taxi_df
    .groupBy("payment_type")
    .agg(
        sf.sum("total_amount").alias("total_amount"),
        sf.avg("total_amount").alias("avg_amount"),
    )
    .orderBy(sf.col("total_amount").desc())
)

(
    amount_by_payment_type
    .withColumn("total_amount", sf.round(sf.col("total_amount"), 2))
    .withColumn("avg_amount", sf.round(sf.col("avg_amount"), 2))
    .show(n=5)
)


+------------+-------------+----------+
|payment_type| total_amount|avg_amount|
+------------+-------------+----------+
|           1|8.109027601E7|     31.13|
|           0|1.157718759E7|     23.93|
|           2|1.068214365E7|     23.99|
|           3|    198527.68|      7.84|
|           4|    135263.93|      1.83|
+------------+-------------+----------+



In [63]:
# DataFrame 1: employees. department_id 9 (Hannah) has no matching department,
# which makes the left/outer join examples below differ from the inner join.
employees_data = [
    (1, "Alice", 1),
    (2, "Bob", 2),
    (3, "Charlie", 4),
    (4, "David", 5),
    (5, "Eve", 5),
    (6, "Frank", 3),
    (7, "George", 7),
    (8, "Hannah", 9),
    (9, "Ivan", 5),
    (10, "John", 4),
]
employees_columns = ["employee_id", "name", "department_id"]
employees_df = spark.createDataFrame(employees_data, employees_columns)

# DataFrame 2: departments. Legal (id 6) has no employees.
departments_data = [
    (1, "IT"),
    (2, "HR"),
    (3, "Sales"),
    (4, "Finance"),
    (5, "Marketing"),
    (6, "Legal"),
    (7, "Operations"),
]
departments_columns = ["department_id", "department_name"]
departments_df = spark.createDataFrame(departments_data, departments_columns)

# show() prints the table itself and returns None; wrapping it in print()
# would only add a stray "None" line after each table.
employees_df.show()
departments_df.show()

+-----------+-------+-------------+
|employee_id|   name|department_id|
+-----------+-------+-------------+
|          1|  Alice|            1|
|          2|    Bob|            2|
|          3|Charlie|            4|
|          4|  David|            5|
|          5|    Eve|            5|
|          6|  Frank|            3|
|          7| George|            7|
|          8| Hannah|            9|
|          9|   Ivan|            5|
|         10|   John|            4|
+-----------+-------+-------------+

None
+-------------+---------------+
|department_id|department_name|
+-------------+---------------+
|            1|             IT|
|            2|             HR|
|            3|          Sales|
|            4|        Finance|
|            5|      Marketing|
|            6|          Legal|
|            7|     Operations|
+-------------+---------------+

None


In [64]:
# Inner join on the shared department_id column: only employees whose department
# exists are kept. employee_id breaks ties so the displayed order is reproducible.
inner_join_df = (
    employees_df
    .join(departments_df, "department_id")
    .sort(sf.col("department_id"), sf.col("employee_id"))
)

inner_join_df.show()

+-------------+-----------+-------+---------------+
|department_id|employee_id|   name|department_name|
+-------------+-----------+-------+---------------+
|            1|          1|  Alice|             IT|
|            2|          2|    Bob|             HR|
|            3|          6|  Frank|          Sales|
|            4|          3|Charlie|        Finance|
|            4|         10|   John|        Finance|
|            5|          4|  David|      Marketing|
|            5|          5|    Eve|      Marketing|
|            5|          9|   Ivan|      Marketing|
|            7|          7| George|     Operations|
+-------------+-----------+-------+---------------+



In [65]:
# Left join: every employee is kept; department_name is NULL when there is no match.
# employee_id breaks ties so the displayed order is reproducible.
left_join_df = (
    employees_df
    .join(departments_df, "department_id", how="left")
    .sort(sf.col("department_id"), sf.col("employee_id"))
)

left_join_df.show()

+-------------+-----------+-------+---------------+
|department_id|employee_id|   name|department_name|
+-------------+-----------+-------+---------------+
|            1|          1|  Alice|             IT|
|            2|          2|    Bob|             HR|
|            3|          6|  Frank|          Sales|
|            4|          3|Charlie|        Finance|
|            4|         10|   John|        Finance|
|            5|          4|  David|      Marketing|
|            5|          5|    Eve|      Marketing|
|            5|          9|   Ivan|      Marketing|
|            7|          7| George|     Operations|
|            9|          8| Hannah|           NULL|
+-------------+-----------+-------+---------------+



In [66]:
# Full outer join: unmatched rows from both sides are kept, with NULLs filling the
# missing side. employee_id breaks ties so the displayed order is reproducible.
outer_join_df = (
    employees_df
    .join(departments_df, "department_id", how="outer")
    .sort(sf.col("department_id"), sf.col("employee_id"))
)

outer_join_df.show()

+-------------+-----------+-------+---------------+
|department_id|employee_id|   name|department_name|
+-------------+-----------+-------+---------------+
|            1|          1|  Alice|             IT|
|            2|          2|    Bob|             HR|
|            3|          6|  Frank|          Sales|
|            4|          3|Charlie|        Finance|
|            4|         10|   John|        Finance|
|            5|          4|  David|      Marketing|
|            5|          5|    Eve|      Marketing|
|            5|          9|   Ivan|      Marketing|
|            6|       NULL|   NULL|          Legal|
|            7|          7| George|     Operations|
|            9|          8| Hannah|           NULL|
+-------------+-----------+-------+---------------+



##### Spark SQL

In [69]:
# Reload the raw trip data so the SQL examples start from the file's original columns.
taxi_df = spark.read.parquet("data/yellow_tripdata_2024-09.parquet")

In [70]:
# Register the DataFrame as a temp view (tied to this SparkSession) so SQL can query it by name.
temp_view_name = "taxi_temp_view"
taxi_df.createOrReplaceTempView(temp_view_name)

In [71]:
# spark.sql returns a DataFrame; show() then prints it.
preview_df = spark.sql("select * from taxi_temp_view limit 10")
preview_df.show(n=20)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:05:51|  2024-09-01 00:45:03|              1|          9.8|         1|                 N|         138|          48|           1|       47.8|10.25|    0.5|      13.

In [72]:
# Trips with trip_distance above 10. LIMIT without ORDER BY returns an arbitrary
# set of 10 matching rows, not a well-defined "first" 10.
long_trips_query = """
    SELECT
        *
    FROM
        taxi_temp_view
    WHERE
        trip_distance > 10
    Limit 10
    """

long_trips_df = spark.sql(long_trips_query)
long_trips_df.show(n=20)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:57:30|  2024-09-01 01:40:20|              2|         14.5|         1|                 N|         132|          91|           1|       64.6| 2.75|    0.5|     13.7

In [73]:
# Average total_amount per payment type, written in SQL against the temp view.
avg_amount_query = """
    SELECT
        payment_type,
        avg(total_amount) as avg_amount
    FROM
        taxi_temp_view
    GROUP BY
        payment_type
    ORDER BY
        avg_amount DESC
    """

avg_amount_by_payment_type = spark.sql(avg_amount_query)
avg_amount_by_payment_type.show(n=10)

+------------+------------------+
|payment_type|        avg_amount|
+------------+------------------+
|           1|31.132636716912728|
|           2|23.992416616129827|
|           0|23.933110737169443|
|           3| 7.835484864032835|
|           4|1.8263854120252212|
+------------+------------------+



##### Caching

In [74]:
import time

In [77]:
# 999,999 rows (range excludes its end): number = 1..999999, doubled = 2 * number.
data = [(i, i * 2) for i in range(1, 1000000)]
df = spark.createDataFrame(data, ["number", "doubled"])
df.show(5)

# Baseline without caching: each count() action recomputes the filter from scratch.
# perf_counter() is a monotonic clock intended for measuring elapsed time;
# time.time() can jump if the system clock is adjusted.
# NOTE(review): this cell runs first, so its timing may include one-off warm-up cost.
start_time = time.perf_counter()

df_filtered = df.filter(df["number"] % 2 == 0)
print(df_filtered.count())
print(df_filtered.count())

end_time = time.perf_counter()
print("Execution time: ", end_time - start_time)


+------+-------+
|number|doubled|
+------+-------+
|     1|      2|
|     2|      4|
|     3|      6|
|     4|      8|
|     5|     10|
+------+-------+
only showing top 5 rows

499999
499999
Execution time:  15.572580575942993


In [78]:
# Same data as the uncached run, so only the cache() call differs.
data = [(i, i * 2) for i in range(1, 1000000)]
df_2 = spark.createDataFrame(data, ["number", "doubled"])
df_2.show(5)

# perf_counter() is a monotonic clock intended for measuring elapsed time.
start_time = time.perf_counter()

# cache() is lazy: the first action computes the filtered rows and stores them,
# and later actions reuse the stored data instead of recomputing the filter.
df_filtered = df_2.filter(df_2["number"] % 2 == 0).cache()
print(df_filtered.count())
print(df_filtered.count())

end_time = time.perf_counter()

# Release the cached data once the demo is done, so it does not keep occupying
# storage for the rest of the session (done outside the timed section).
df_filtered.unpersist()

print("Execution time: ", end_time - start_time)

+------+-------+
|number|doubled|
+------+-------+
|     1|      2|
|     2|      4|
|     3|      6|
|     4|      8|
|     5|     10|
+------+-------+
only showing top 5 rows

499999
499999
Execution time:  7.967309951782227
