<a href="https://colab.research.google.com/github/pejmanrasti/Big_Data/blob/main/03_PySpark_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PySpark Tutorial

## 0. Environment Setup

In [None]:
# Install PySpark from PyPI (the package bundles Spark for local mode; no separate Hadoop or Spark download is needed)
!pip install -q pyspark

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("PySpark_Complete_Tutorial")
    .master("local[*]")  # local mode, all cores
    .getOrCreate()
)

spark

## 1. Load Dataset
We’ll use the classic 2015 flight summary dataset (origin, destination, number of flights), hosted in the Databricks “Spark: The Definitive Guide” repo.

In [None]:
# Download dataset from URL
!wget -q https://raw.githubusercontent.com/databricks/Spark-The-Definitive-Guide/master/data/flight-data/csv/2015-summary.csv -O flights.csv

# Load into a DataFrame
flights = spark.read.csv(
    "flights.csv",
    header=True,        # first row is header
    inferSchema=True    # infer column types
)

# Peek at the data
flights.show(10)
flights.printSchema()

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    United States|            Romania|   15|
|    United States|            Croatia|    1|
|    United States|            Ireland|  344|
|            Egypt|      United States|   15|
|    United States|              India|   62|
|    United States|          Singapore|    1|
|    United States|            Grenada|   62|
|       Costa Rica|      United States|  588|
|          Senegal|      United States|   40|
|          Moldova|      United States|    1|
+-----------------+-------------------+-----+
only showing top 10 rows

root
 |-- DEST_COUNTRY_NAME: string (nullable = true)
 |-- ORIGIN_COUNTRY_NAME: string (nullable = true)
 |-- count: integer (nullable = true)



## 2. DataFrame Operations

**2.1 Transformations vs Actions**

- Transformations are lazy: they only describe a computation (e.g., select, filter, withColumn, groupBy).
- Actions trigger execution and return or display a result (e.g., show, count, collect, take).
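
As a loose analogy in plain Python (this is not Spark code), generators behave in a similar way: building the pipeline does no work, and consuming it does. The `log` list and the three sample rows are illustrative assumptions.

```python
# Plain-Python analogy for lazy transformations vs. actions (not Spark code).
log = []  # records when rows are actually read

def read_rows():
    # Sample rows taken from the flight data shown above.
    for row in [("United States", "Romania", 15),
                ("Egypt", "United States", 15),
                ("United States", "Ireland", 344)]:
        log.append("read")
        yield row

# "Transformation": describes the filter but processes nothing yet.
us_dest = (r for r in read_rows() if r[0] == "United States")
work_before_action = len(log)

# "Action": consuming the generator runs the whole pipeline.
result = list(us_dest)
work_after_action = len(log)

print(work_before_action, work_after_action, len(result))  # 0 3 2
```

Spark goes further than this analogy: because it sees the whole chain before running anything, it can reorder and combine steps before execution.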

In [None]:
# Example of an action
total_rows = flights.count()
print("Number of rows:", total_rows)

# Example of a lazy transformation (not executed until an action)
only_us_dest = flights.filter(flights["DEST_COUNTRY_NAME"] == "United States")

# Action triggers execution
only_us_dest.show(5)

Number of rows: 256
+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    United States|            Romania|   15|
|    United States|            Croatia|    1|
|    United States|            Ireland|  344|
|    United States|              India|   62|
|    United States|          Singapore|    1|
+-----------------+-------------------+-----+
only showing top 5 rows



**2.2 Basic column selection and renaming**

In [None]:
from pyspark.sql.functions import col

# Select specific columns
flights.select("DEST_COUNTRY_NAME", "ORIGIN_COUNTRY_NAME", "count").show(5)

# Using col() and alias()
flights.select(
    col("DEST_COUNTRY_NAME").alias("dest"),
    col("ORIGIN_COUNTRY_NAME").alias("origin"),
    col("count").alias("num_flights")
).show(5)

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    United States|            Romania|   15|
|    United States|            Croatia|    1|
|    United States|            Ireland|  344|
|            Egypt|      United States|   15|
|    United States|              India|   62|
+-----------------+-------------------+-----+
only showing top 5 rows

+-------------+-------------+-----------+
|         dest|       origin|num_flights|
+-------------+-------------+-----------+
|United States|      Romania|         15|
|United States|      Croatia|          1|
|United States|      Ireland|        344|
|        Egypt|United States|         15|
|United States|        India|         62|
+-------------+-------------+-----------+
only showing top 5 rows



**2.3 Filtering rows (where / filter)**

In [None]:
# Simple equality filter
flights.filter(flights["DEST_COUNTRY_NAME"] == "United States").show(5)

# Multiple conditions with &, | (AND / OR)
flights.filter(
    (col("DEST_COUNTRY_NAME") == "United States") &
    (col("ORIGIN_COUNTRY_NAME") == "Canada")
).show(5)

# Using .where() – same as filter
flights.where(col("count") > 10000).show(5)

# IN-style filter using isin() (prefix the condition with ~ for NOT IN)
flights.filter(col("DEST_COUNTRY_NAME").isin("United States", "Canada", "Mexico")).show(10)

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    United States|            Romania|   15|
|    United States|            Croatia|    1|
|    United States|            Ireland|  344|
|    United States|              India|   62|
|    United States|          Singapore|    1|
+-----------------+-------------------+-----+
only showing top 5 rows

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    United States|             Canada| 8483|
+-----------------+-------------------+-----+

+-----------------+-------------------+------+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME| count|
+-----------------+-------------------+------+
|    United States|      United States|370002|
+-----------------+-------------------+------+

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-

**2.4 Sorting and limiting**

In [None]:
# Top 10 routes by number of flights
flights.orderBy(col("count").desc()).show(10)

# Sort by multiple columns
flights.orderBy(col("DEST_COUNTRY_NAME").asc(), col("count").desc()).show(10)

# randomSplit for random train/test splits (the weights are targets, so split sizes are approximate)
train_df, test_df = flights.randomSplit([0.7, 0.3], seed=42)
print("Train size:", train_df.count(), "Test size:", test_df.count())

+-----------------+-------------------+------+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME| count|
+-----------------+-------------------+------+
|    United States|      United States|370002|
|    United States|             Canada|  8483|
|           Canada|      United States|  8399|
|    United States|             Mexico|  7187|
|           Mexico|      United States|  7140|
|   United Kingdom|      United States|  2025|
|    United States|     United Kingdom|  1970|
|            Japan|      United States|  1548|
|    United States|              Japan|  1496|
|          Germany|      United States|  1468|
+-----------------+-------------------+------+
only showing top 10 rows

+-------------------+-------------------+-----+
|  DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-------------------+-------------------+-----+
|            Algeria|      United States|    4|
|             Angola|      United States|   15|
|           Anguilla|      United States|   41|
|Antigua and Barbuda|      U

**2.5 Derived columns: withColumn**

In [None]:
from pyspark.sql.functions import lit, expr

# Add a constant column
flights_with_year = flights.withColumn("year", lit(2015))
flights_with_year.show(5)

# Add a column with a simple expression
flights_double = flights.withColumn("count_x2", col("count") * 2)
flights_double.show(5)

# Conditional column using when/otherwise
from pyspark.sql.functions import when

flights_flag = flights.withColumn(
    "is_big_route",
    when(col("count") > 10000, lit(1)).otherwise(lit(0))
)
flights_flag.show(10)

# Using expr() for SQL-like expressions
flights_expr = flights.withColumn(
    "count_log",
    expr("log10(count + 1)")
)
flights_expr.select("DEST_COUNTRY_NAME", "count", "count_log").show(5)

+-----------------+-------------------+-----+----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|year|
+-----------------+-------------------+-----+----+
|    United States|            Romania|   15|2015|
|    United States|            Croatia|    1|2015|
|    United States|            Ireland|  344|2015|
|            Egypt|      United States|   15|2015|
|    United States|              India|   62|2015|
+-----------------+-------------------+-----+----+
only showing top 5 rows

+-----------------+-------------------+-----+--------+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|count_x2|
+-----------------+-------------------+-----+--------+
|    United States|            Romania|   15|      30|
|    United States|            Croatia|    1|       2|
|    United States|            Ireland|  344|     688|
|            Egypt|      United States|   15|      30|
|    United States|              India|   62|     124|
+-----------------+-------------------+-----+--------+
only showing top 5 ro

**2.6 Handling missing values**

This particular dataset is clean, but the API for handling missing values is worth showing anyway:

In [None]:
# Show how to drop rows with any nulls
flights_no_nulls = flights.na.drop()

# Fill nulls (example only)
flights_filled = flights.na.fill({
    "DEST_COUNTRY_NAME": "Unknown",
    "count": 0
})

# Replace specific values
flights_replaced = flights.replace("United States", "USA", subset=["DEST_COUNTRY_NAME"])
flights_replaced.show(5)

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|              USA|            Romania|   15|
|              USA|            Croatia|    1|
|              USA|            Ireland|  344|
|            Egypt|      United States|   15|
|              USA|              India|   62|
+-----------------+-------------------+-----+
only showing top 5 rows



**2.7 Distinct and deduplication**

In [None]:
# Unique destination countries
dest_countries = flights.select("DEST_COUNTRY_NAME").distinct()
print("Number of destination countries:", dest_countries.count())
dest_countries.show(20)

# Drop duplicates on multiple columns
unique_routes = flights.dropDuplicates(["DEST_COUNTRY_NAME", "ORIGIN_COUNTRY_NAME"])
print("Unique routes:", unique_routes.count())

Number of destination countries: 132
+--------------------+
|   DEST_COUNTRY_NAME|
+--------------------+
|            Anguilla|
|              Russia|
|            Paraguay|
|             Senegal|
|              Sweden|
|            Kiribati|
|              Guyana|
|         Philippines|
|            Djibouti|
|            Malaysia|
|           Singapore|
|                Fiji|
|              Turkey|
|                Iraq|
|             Germany|
|              Jordan|
|               Palau|
|Turks and Caicos ...|
|              France|
|              Greece|
+--------------------+
only showing top 20 rows

Unique routes: 256


## 3. Aggregations

**3.1 Simple aggregations**

In [None]:
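# Note: importing sum, min and max this way shadows Python's built-ins of the same name in this notebook
from pyspark.sql.functions import sum, avg, min, max, countDistinct, approx_count_distinct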

# Total flights in dataset
flights.select(sum("count").alias("total_flights")).show()

# Average flights per route
flights.select(avg("count").alias("avg_flights_per_route")).show()

# Min and max number of flights for a route
flights.select(
    min("count").alias("min_flights"),
    max("count").alias("max_flights")
).show()

+-------------+
|total_flights|
+-------------+
|       453316|
+-------------+

+---------------------+
|avg_flights_per_route|
+---------------------+
|          1770.765625|
+---------------------+

+-----------+-----------+
|min_flights|max_flights|
+-----------+-----------+
|          1|     370002|
+-----------+-----------+



**3.2 GroupBy aggregations**

In [None]:
# Total flights per destination country
flights.groupBy("DEST_COUNTRY_NAME") \
       .agg(sum("count").alias("total_dest_flights")) \
       .orderBy(col("total_dest_flights").desc()) \
       .show(20)

# Number of distinct destinations and total outgoing flights per origin country
flights.groupBy("ORIGIN_COUNTRY_NAME") \
       .agg(
           countDistinct("DEST_COUNTRY_NAME").alias("num_destinations"),
           sum("count").alias("total_outgoing_flights")
       ) \
       .orderBy(col("total_outgoing_flights").desc()) \
       .show(20)

+------------------+------------------+
| DEST_COUNTRY_NAME|total_dest_flights|
+------------------+------------------+
|     United States|            411352|
|            Canada|              8399|
|            Mexico|              7140|
|    United Kingdom|              2025|
|             Japan|              1548|
|           Germany|              1468|
|Dominican Republic|              1353|
|       South Korea|              1048|
|       The Bahamas|               955|
|            France|               935|
|          Colombia|               873|
|            Brazil|               853|
|       Netherlands|               776|
|             China|               772|
|           Jamaica|               666|
|        Costa Rica|               588|
|       El Salvador|               561|
|            Panama|               510|
|              Cuba|               466|
|             Spain|               420|
+------------------+------------------+
only showing top 20 rows

+-------------

**3.3 More complex aggregations: rollup and cube (optional)**

In [None]:
from pyspark.sql.functions import sum

# Total flights per origin, plus an overall total (the row where ORIGIN_COUNTRY_NAME is NULL)
flights_rollup = flights.rollup("ORIGIN_COUNTRY_NAME").agg(
    sum("count").alias("total_flights")
).orderBy("ORIGIN_COUNTRY_NAME")

flights_rollup.show(50)

+--------------------+-------------+
| ORIGIN_COUNTRY_NAME|total_flights|
+--------------------+-------------+
|                NULL|       453316|
|              Angola|           13|
|            Anguilla|           38|
| Antigua and Barbuda|          117|
|           Argentina|          141|
|               Aruba|          342|
|           Australia|          258|
|             Austria|           63|
|          Azerbaijan|           21|
|             Bahrain|            1|
|            Barbados|          130|
|             Belgium|          228|
|              Belize|          193|
|             Bermuda|          193|
|             Bolivia|           13|
|Bonaire, Sint Eus...|           59|
|              Brazil|          619|
|British Virgin Is...|           80|
|            Bulgaria|            1|
|              Canada|         8483|
|          Cape Verde|           14|
|      Cayman Islands|          310|
|               Chile|          185|
|               China|          920|
|

## 4. Spark SQL – A More “Relational” View

**4.1 Create a temporary view**

In [None]:
flights.createOrReplaceTempView("flights_table")

**4.2 Simple SQL queries**

In [None]:
# Top 10 routes by flights
spark.sql("""
SELECT
  DEST_COUNTRY_NAME,
  ORIGIN_COUNTRY_NAME,
  count
FROM flights_table
ORDER BY count DESC
LIMIT 10
""").show()

# Total flights per destination
spark.sql("""
SELECT
  DEST_COUNTRY_NAME,
  SUM(count) AS total_flights
FROM flights_table
GROUP BY DEST_COUNTRY_NAME
ORDER BY total_flights DESC
LIMIT 20
""").show()

+-----------------+-------------------+------+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME| count|
+-----------------+-------------------+------+
|    United States|      United States|370002|
|    United States|             Canada|  8483|
|           Canada|      United States|  8399|
|    United States|             Mexico|  7187|
|           Mexico|      United States|  7140|
|   United Kingdom|      United States|  2025|
|    United States|     United Kingdom|  1970|
|            Japan|      United States|  1548|
|    United States|              Japan|  1496|
|          Germany|      United States|  1468|
+-----------------+-------------------+------+

+------------------+-------------+
| DEST_COUNTRY_NAME|total_flights|
+------------------+-------------+
|     United States|       411352|
|            Canada|         8399|
|            Mexico|         7140|
|    United Kingdom|         2025|
|             Japan|         1548|
|           Germany|         1468|
|Dominican Republic|      

**4.3 Filters and CASE WHEN logic**

In [None]:
spark.sql("""
SELECT
  DEST_COUNTRY_NAME,
  ORIGIN_COUNTRY_NAME,
  count,
  CASE
    WHEN count > 10000 THEN 'HIGH'
    WHEN count BETWEEN 3000 AND 10000 THEN 'MEDIUM'
    ELSE 'LOW'
  END AS traffic_category
FROM flights_table
ORDER BY count DESC
LIMIT 20
""").show()

+------------------+-------------------+------+----------------+
| DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME| count|traffic_category|
+------------------+-------------------+------+----------------+
|     United States|      United States|370002|            HIGH|
|     United States|             Canada|  8483|          MEDIUM|
|            Canada|      United States|  8399|          MEDIUM|
|     United States|             Mexico|  7187|          MEDIUM|
|            Mexico|      United States|  7140|          MEDIUM|
|    United Kingdom|      United States|  2025|             LOW|
|     United States|     United Kingdom|  1970|             LOW|
|             Japan|      United States|  1548|             LOW|
|     United States|              Japan|  1496|             LOW|
|           Germany|      United States|  1468|             LOW|
|     United States| Dominican Republic|  1420|             LOW|
|Dominican Republic|      United States|  1353|             LOW|
|     United States|     

**4.4 Window functions in SQL**

In [None]:
spark.sql("""
SELECT
  DEST_COUNTRY_NAME,
  ORIGIN_COUNTRY_NAME,
  count,
  ROW_NUMBER() OVER (PARTITION BY DEST_COUNTRY_NAME ORDER BY count DESC) AS rn
FROM flights_table
""").createOrReplaceTempView("flights_ranked")

# Top 3 routes per destination
spark.sql("""
SELECT *
FROM flights_ranked
WHERE rn <= 3
ORDER BY DEST_COUNTRY_NAME, count DESC
""").show(50)

+--------------------+-------------------+-----+---+
|   DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count| rn|
+--------------------+-------------------+-----+---+
|             Algeria|      United States|    4|  1|
|              Angola|      United States|   15|  1|
|            Anguilla|      United States|   41|  1|
| Antigua and Barbuda|      United States|  126|  1|
|           Argentina|      United States|  180|  1|
|               Aruba|      United States|  346|  1|
|           Australia|      United States|  329|  1|
|             Austria|      United States|   62|  1|
|          Azerbaijan|      United States|   21|  1|
|             Bahrain|      United States|   19|  1|
|            Barbados|      United States|  154|  1|
|             Belgium|      United States|  259|  1|
|              Belize|      United States|  188|  1|
|             Bermuda|      United States|  183|  1|
|             Bolivia|      United States|   30|  1|
|Bonaire, Sint Eus...|      United States|   5

## 5. Joins and Multi-Dataset Analytics

We’ll create a small countries lookup DataFrame to join with.

In [None]:
from pyspark.sql import Row

countries_data = [
    Row(country="United States", iso_code="USA", region="North America"),
    Row(country="Canada", iso_code="CAN", region="North America"),
    Row(country="Mexico", iso_code="MEX", region="North America"),
    Row(country="France", iso_code="FRA", region="Europe"),
    Row(country="United Kingdom", iso_code="GBR", region="Europe"),
    Row(country="China", iso_code="CHN", region="Asia"),
    Row(country="Japan", iso_code="JPN", region="Asia"),
]

countries_df = spark.createDataFrame(countries_data)
countries_df.show()

+--------------+--------+-------------+
|       country|iso_code|       region|
+--------------+--------+-------------+
| United States|     USA|North America|
|        Canada|     CAN|North America|
|        Mexico|     MEX|North America|
|        France|     FRA|       Europe|
|United Kingdom|     GBR|       Europe|
|         China|     CHN|         Asia|
|         Japan|     JPN|         Asia|
+--------------+--------+-------------+



**5.1 Left join: annotate destination with region**

In [None]:
joined_dest = flights.join(
    countries_df,
    flights["DEST_COUNTRY_NAME"] == countries_df["country"],
    how="left"
)

joined_dest.select(
    "DEST_COUNTRY_NAME", "ORIGIN_COUNTRY_NAME", "count", "region"
).show(20)

+--------------------+-------------------+-----+-------------+
|   DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|       region|
+--------------------+-------------------+-----+-------------+
|            Anguilla|      United States|   41|         NULL|
|             Senegal|      United States|   40|         NULL|
|              Guyana|      United States|   64|         NULL|
|Turks and Caicos ...|      United States|  230|         NULL|
|             Algeria|      United States|    4|         NULL|
|       United States|            Romania|   15|North America|
|       United States|            Croatia|    1|North America|
|       United States|            Ireland|  344|North America|
|       United States|              India|   62|North America|
|       United States|          Singapore|    1|North America|
|       United States|            Grenada|   62|North America|
|       United States|       Sint Maarten|  325|North America|
|       United States|   Marshall Islands|   39|North A

**5.2 Inner join: only routes where destination is in countries_df**

In [None]:
inner_dest = flights.join(
    countries_df,
    flights["DEST_COUNTRY_NAME"] == countries_df["country"],
    how="inner"
)

inner_dest.select(
    "DEST_COUNTRY_NAME", "ORIGIN_COUNTRY_NAME", "count", "region"
).orderBy(col("count").desc()).show(20)

+-----------------+-------------------+------+-------------+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME| count|       region|
+-----------------+-------------------+------+-------------+
|    United States|      United States|370002|North America|
|    United States|             Canada|  8483|North America|
|           Canada|      United States|  8399|North America|
|    United States|             Mexico|  7187|North America|
|           Mexico|      United States|  7140|North America|
|   United Kingdom|      United States|  2025|       Europe|
|    United States|     United Kingdom|  1970|North America|
|            Japan|      United States|  1548|         Asia|
|    United States|              Japan|  1496|North America|
|    United States| Dominican Republic|  1420|North America|
|    United States|            Germany|  1336|North America|
|    United States|        The Bahamas|   986|North America|
|    United States|             France|   952|North America|
|           France|     

**5.3 Left anti join: routes with destinations NOT in lookup**

In [None]:
not_matched_dest = flights.join(
    countries_df,
    flights["DEST_COUNTRY_NAME"] == countries_df["country"],
    how="left_anti"
)

print("Routes with destination not in countries_df:", not_matched_dest.count())
not_matched_dest.show(20)

Routes with destination not in countries_df: 125
+--------------------+-------------------+-----+
|   DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+--------------------+-------------------+-----+
|            Anguilla|      United States|   41|
|              Russia|      United States|  176|
|            Paraguay|      United States|   60|
|             Senegal|      United States|   40|
|              Sweden|      United States|  118|
|            Kiribati|      United States|   26|
|              Guyana|      United States|   64|
|         Philippines|      United States|  134|
|            Djibouti|      United States|    1|
|            Malaysia|      United States|    2|
|           Singapore|      United States|    3|
|                Fiji|      United States|   24|
|              Turkey|      United States|  138|
|                Iraq|      United States|    1|
|             Germany|      United States| 1468|
|              Jordan|      United States|   44|
|               Pala
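
To make the difference between the three join types concrete, here is a plain-Python illustration (not Spark code) on four sample rows and a two-entry lookup. Both are assumptions chosen for brevity, and the sketch relies on each country appearing at most once in the lookup.

```python
# Plain-Python illustration of left, inner and left_anti joins (not Spark code).
routes = [
    ("United States", "Romania", 15),
    ("Egypt", "United States", 15),
    ("Canada", "United States", 8399),
    ("Germany", "United States", 1468),
]
region_of = {"United States": "North America", "Canada": "North America"}

# left: keep every route; region is None when the lookup has no match
left = [(dest, origin, n, region_of.get(dest)) for dest, origin, n in routes]

# inner: keep only routes whose destination is in the lookup
inner = [(dest, origin, n, region_of[dest])
         for dest, origin, n in routes if dest in region_of]

# left_anti: keep only routes with no match, and only the left-side columns
left_anti = [(dest, origin, n) for dest, origin, n in routes if dest not in region_of]

print(len(left), len(inner), len(left_anti))  # 4 2 2
```

The row counts show the relationship directly: every left-join row ends up in either the inner result or the left-anti result.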

**5.4 Aggregation with join: total flights per region**

In [None]:
from pyspark.sql.functions import sum

region_stats = flights.join(
    countries_df,
    flights["DEST_COUNTRY_NAME"] == countries_df["country"],
    how="inner"
).groupBy("region") \
 .agg(sum("count").alias("total_flights_to_region")) \
 .orderBy(col("total_flights_to_region").desc())

region_stats.show()

+-------------+-----------------------+
|       region|total_flights_to_region|
+-------------+-----------------------+
|North America|                 426891|
|       Europe|                   2960|
|         Asia|                   2320|
+-------------+-----------------------+



## 6. MLlib – Simple Binary Classification

We’ll build a simple model to classify whether a route is “high traffic” (count > 1000).

**6.1 Prepare labeled data**

In [None]:
from pyspark.sql.functions import when

ml_data = flights.withColumn(
    "label",
    when(col("count") > 1000, 1).otherwise(0)
)

ml_data.select("DEST_COUNTRY_NAME", "ORIGIN_COUNTRY_NAME", "count", "label").show(10)

+-----------------+-------------------+-----+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|label|
+-----------------+-------------------+-----+-----+
|    United States|            Romania|   15|    0|
|    United States|            Croatia|    1|    0|
|    United States|            Ireland|  344|    0|
|            Egypt|      United States|   15|    0|
|    United States|              India|   62|    0|
|    United States|          Singapore|    1|    0|
|    United States|            Grenada|   62|    0|
|       Costa Rica|      United States|  588|    0|
|          Senegal|      United States|   40|    0|
|          Moldova|      United States|    1|    0|
+-----------------+-------------------+-----+-----+
only showing top 10 rows



**6.2 Feature engineering**

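We’ll use count and log10(count + 1) as features. Because the label is itself derived from count, the classes are trivially separable; the goal here is to demonstrate the MLlib API, not to build a useful model.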

In [None]:
from pyspark.ml.feature import VectorAssembler

ml_data = ml_data.withColumn("count_log", expr("log10(count + 1)"))

assembler = VectorAssembler(
    inputCols=["count", "count_log"],
    outputCol="features"
)

data_ml = assembler.transform(ml_data).select("features", "label")
data_ml.show(10, truncate=False)

+--------------------------+-----+
|features                  |label|
+--------------------------+-----+
|[15.0,1.2041199826559248] |0    |
|[1.0,0.3010299956639812]  |0    |
|[344.0,2.537819095073274] |0    |
|[15.0,1.2041199826559248] |0    |
|[62.0,1.7993405494535817] |0    |
|[1.0,0.3010299956639812]  |0    |
|[62.0,1.7993405494535817] |0    |
|[588.0,2.7701152947871015]|0    |
|[40.0,1.6127838567197355] |0    |
|[1.0,0.3010299956639812]  |0    |
+--------------------------+-----+
only showing top 10 rows



**6.3 Train/test split and model training**

In [None]:
train_df, test_df = data_ml.randomSplit([0.7, 0.3], seed=42)
print("Train size:", train_df.count(), "Test size:", test_df.count())

from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(featuresCol="features", labelCol="label")
lr_model = lr.fit(train_df)

print("Coefficients:", lr_model.coefficients)
print("Intercept:", lr_model.intercept)

Train size: 190 Test size: 66
Coefficients: [0.010942358438595649,878.5610282504983]
Intercept: -2653.4227379957615
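
To see what these numbers mean, a prediction can be reproduced by hand: logistic regression computes a linear score from the features and predicts class 1 when that score is positive, which corresponds to a probability above 0.5. The sketch below uses the coefficients printed above; the helper names `margin` and `predict` are hypothetical, and the default 0.5 threshold is assumed.

```python
import math

# Coefficients and intercept as printed above.
w_count, w_log = 0.010942358438595649, 878.5610282504983
intercept = -2653.4227379957615

def margin(count):
    """Linear score w . x + b for features [count, log10(count + 1)]."""
    return w_count * count + w_log * math.log10(count + 1) + intercept

def predict(count):
    # A positive margin corresponds to P(label = 1) > 0.5.
    return 1 if margin(count) > 0 else 0

# Flight counts taken from routes shown earlier in the notebook.
for c in (15, 588, 2025, 8483):
    print(c, round(margin(c), 1), predict(c))
```

The margins are very large in magnitude, which is typical when a model is fitted to perfectly separable data.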


**6.4 Evaluation**

In [None]:
predictions = lr_model.transform(test_df)
predictions.select("features", "label", "prediction", "probability").show(20, truncate=False)

+-------------------------+-----+----------+-----------+
|features                 |label|prediction|probability|
+-------------------------+-----+----------+-----------+
|[1.0,0.3010299956639812] |0    |0.0       |[1.0,0.0]  |
|[1.0,0.3010299956639812] |0    |0.0       |[1.0,0.0]  |
|[1.0,0.3010299956639812] |0    |0.0       |[1.0,0.0]  |
|[1.0,0.3010299956639812] |0    |0.0       |[1.0,0.0]  |
|[1.0,0.3010299956639812] |0    |0.0       |[1.0,0.0]  |
|[1.0,0.3010299956639812] |0    |0.0       |[1.0,0.0]  |
|[1.0,0.3010299956639812] |0    |0.0       |[1.0,0.0]  |
|[1.0,0.3010299956639812] |0    |0.0       |[1.0,0.0]  |
|[1.0,0.3010299956639812] |0    |0.0       |[1.0,0.0]  |
|[1.0,0.3010299956639812] |0    |0.0       |[1.0,0.0]  |
|[1.0,0.3010299956639812] |0    |0.0       |[1.0,0.0]  |
|[2.0,0.47712125471966244]|0    |0.0       |[1.0,0.0]  |
|[2.0,0.47712125471966244]|0    |0.0       |[1.0,0.0]  |
|[2.0,0.47712125471966244]|0    |0.0       |[1.0,0.0]  |
|[2.0,0.47712125471966244]|0   

Compute AUC:

In [None]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction")
auc = evaluator.evaluate(predictions)
print("AUC:", auc)

AUC: 1.0


**6.5 Simple confusion matrix**

In [None]:
# Count TP, TN, FP, FN
confusion = predictions.select("label", "prediction")

tp = confusion.filter("label = 1 AND prediction = 1").count()
tn = confusion.filter("label = 0 AND prediction = 0").count()
fp = confusion.filter("label = 0 AND prediction = 1").count()
fn = confusion.filter("label = 1 AND prediction = 0").count()

print("TP:", tp, "TN:", tn, "FP:", fp, "FN:", fn)

accuracy = (tp + tn) / (tp + tn + fp + fn)
print("Accuracy:", accuracy)

TP: 1 TN: 65 FP: 0 FN: 0
Accuracy: 1.0
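
Accuracy alone says little when one class is rare, so it helps to derive precision, recall and F1 from the same four counts and compare against a trivial baseline. A sketch in plain Python using the counts printed above; `safe_div` is a hypothetical helper that avoids division by zero.

```python
# Counts printed above.
tp, tn, fp, fn = 1, 65, 0, 0

def safe_div(num, den):
    # Return 0.0 instead of raising when a denominator is zero.
    return num / den if den else 0.0

precision = safe_div(tp, tp + fp)   # of predicted positives, how many were right
recall = safe_div(tp, tp + fn)      # of actual positives, how many were found
f1 = safe_div(2 * precision * recall, precision + recall)

# Accuracy of a model that always predicts 0 (correct on every actual negative).
baseline_accuracy = safe_div(tn + fp, tp + tn + fp + fn)

print(precision, recall, f1, round(baseline_accuracy, 3))  # 1.0 1.0 1.0 0.985
```

With a single positive example in the test set, always predicting 0 would already reach about 98.5% accuracy, so these scores rest on very little evidence.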


## 7. Performance & Execution Concepts

**7.1 Number of partitions**

In [None]:
num_partitions = flights.rdd.getNumPartitions()
print("Default partitions:", num_partitions)

Default partitions: 1


**7.2 Repartition and coalesce**

In [None]:
flights_8 = flights.repartition(8)
print("After repartition(8):", flights_8.rdd.getNumPartitions())

flights_2 = flights_8.coalesce(2)
print("After coalesce(2):", flights_2.rdd.getNumPartitions())

After repartition(8): 8
After coalesce(2): 2


**7.3 Caching**

In [None]:
from pyspark.sql.functions import sum

# Cache a repartitioned copy of the DataFrame for repeated use
flights_cached = flights.repartition(4).cache()
flights_cached.count()  # materialize cache

# Reuse the cached DataFrame in two separate aggregations
agg1 = flights_cached.groupBy("DEST_COUNTRY_NAME").agg(sum("count").alias("total_flights"))
agg2 = flights_cached.groupBy("ORIGIN_COUNTRY_NAME").agg(sum("count").alias("total_flights"))

agg1.show(5)
agg2.show(5)

+-----------------+-------------+
|DEST_COUNTRY_NAME|total_flights|
+-----------------+-------------+
|         Anguilla|           41|
|          Senegal|           40|
|           Sweden|          118|
|      Philippines|          134|
|           Jordan|           44|
+-----------------+-------------+
only showing top 5 rows

+-------------------+-------------+
|ORIGIN_COUNTRY_NAME|total_flights|
+-------------------+-------------+
|           Anguilla|           38|
|           Kiribati|           35|
|        Philippines|          126|
|           Malaysia|            3|
|             Turkey|          129|
+-------------------+-------------+
only showing top 5 rows



**7.4 Explain plans (logical vs physical)**

In [None]:
# Example query
query_df = flights.groupBy("DEST_COUNTRY_NAME").agg(sum("count").alias("total_flights"))

# Show the execution plan
query_df.explain(True)

== Parsed Logical Plan ==
'Aggregate ['DEST_COUNTRY_NAME], ['DEST_COUNTRY_NAME, sum('count) AS total_flights#2639]
+- Relation [DEST_COUNTRY_NAME#1340,ORIGIN_COUNTRY_NAME#1341,count#1342] csv

== Analyzed Logical Plan ==
DEST_COUNTRY_NAME: string, total_flights: bigint
Aggregate [DEST_COUNTRY_NAME#1340], [DEST_COUNTRY_NAME#1340, sum(count#1342) AS total_flights#2639L]
+- Relation [DEST_COUNTRY_NAME#1340,ORIGIN_COUNTRY_NAME#1341,count#1342] csv

== Optimized Logical Plan ==
Aggregate [DEST_COUNTRY_NAME#1340], [DEST_COUNTRY_NAME#1340, sum(count#1342) AS total_flights#2639L]
+- Project [DEST_COUNTRY_NAME#1340, count#1342]
   +- Relation [DEST_COUNTRY_NAME#1340,ORIGIN_COUNTRY_NAME#1341,count#1342] csv

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[DEST_COUNTRY_NAME#1340], functions=[sum(count#1342)], output=[DEST_COUNTRY_NAME#1340, total_flights#2639L])
   +- Exchange hashpartitioning(DEST_COUNTRY_NAME#1340, 200), ENSURE_REQUIREMENTS, [plan_id=5288]
      +

## 8. Stop Spark

In [None]:
spark.stop()