In [1]:
pip install pyspark

Note: you may need to restart the kernel to use updated packages.


In [2]:
import pyspark
import os



from pyspark.sql import SparkSession
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql.types import *
import sys

# Point both PySpark interpreter variables at the Python executable that is
# running this kernel, so Spark does not pick up a different interpreter.
os.environ.update({
    "PYSPARK_PYTHON": sys.executable,
    "PYSPARK_DRIVER_PYTHON": sys.executable,
})

In [3]:
# Build (or reuse) a local SparkSession named "prajwal" with master "local[*]".
# The MySQL JDBC connector jar is added to the driver classpath. Its location
# can be overridden with the MYSQL_CONNECTOR_JAR environment variable so the
# notebook is not tied to one machine's directory layout; when the variable is
# unset, the original Windows path is used.
mysql_connector_jar = os.environ.get(
    "MYSQL_CONNECTOR_JAR",
    "C:\\my_sql_jar\\mysql-connector-java-8.0.26.jar",
)
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("prajwal")
    .config("spark.driver.extraClassPath", mysql_connector_jar)
    .getOrCreate()
)
print(spark)

<pyspark.sql.session.SparkSession object at 0x00000156D368C280>


In [5]:
# Load the flight summary CSV into a DataFrame:
#   header      -> take column names from the first line of the file
#   inferschema -> let Spark infer column types instead of reading all strings
#   mode        -> FAILFAST aborts the read on a malformed record
flight_df = (
    spark.read
    .format("csv")
    .option("header", "true")
    .option("inferschema", "true")
    .option("mode", "FAILFAST")
    .load("flight_data.csv")
)
# Preview the first five rows.
flight_df.show(5)

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    United States|            Romania|    1|
|    United States|            Ireland|  264|
|    United States|              India|   69|
|            Egypt|      United States|   24|
|Equatorial Guinea|      United States|    1|
+-----------------+-------------------+-----+
only showing top 5 rows



## Checking the row count

In [6]:
# Total number of rows in the loaded DataFrame.
flight_df.count()

255

In [8]:
# Check how many partitions Spark created by default when reading the file.
flight_df.rdd.getNumPartitions()

1

In [9]:
# Redistribute the rows of flight_df across 4 partitions. The result is a new
# DataFrame; flight_df itself keeps its original partitioning.
partitioned_flight_df = flight_df.repartition(4)

In [14]:
# flight_df has only one partition, so all 255 records sit in that single
# partition. Tag each row with its partition id and count rows per partition.
(
    flight_df
    .withColumn("partitionId", spark_partition_id())
    .groupBy("partitionId")
    .count()
    .show()
)

+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|  255|
+-----------+-----+



In [19]:
# Confirm that the repartitioned DataFrame has the 4 partitions requested.
partitioned_flight_df.rdd.getNumPartitions()

4

In [13]:
# After repartitioning flight_df into 4 partitions, the records are spread
# almost evenly across them (63-64 rows per partition).
(
    partitioned_flight_df
    .withColumn("partitionId", spark_partition_id())
    .groupBy("partitionId")
    .count()
    .show()
)

+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|   63|
|          1|   64|
|          2|   64|
|          3|   64|
+-----------+-----+



In [15]:
# Repartition by the ORIGIN_COUNTRY_NAME column into 300 partitions. That is
# more partitions than the 255 rows in the data, so the per-partition counts
# are tiny and some partition ids never appear (those partitions are empty).
partitioned_col = flight_df.repartition(300, "ORIGIN_COUNTRY_NAME")

In [18]:
# Confirm that the requested partition count (300) was applied.
partitioned_col.rdd.getNumPartitions()

300

In [21]:
# Count rows per partition after repartitioning on ORIGIN_COUNTRY_NAME.
# show() prints only the first 20 partition ids.
(
    partitioned_col
    .withColumn("partitionId", spark_partition_id())
    .groupBy("partitionId")
    .count()
    .show()
)

+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|    1|
|          2|    1|
|          7|    1|
|         10|    1|
|         13|    1|
|         15|    2|
|         16|    2|
|         19|    1|
|         21|    1|
|         22|    1|
|         28|    1|
|         31|    1|
|         39|    1|
|         42|    1|
|         43|    1|
|         44|    1|
|         45|    2|
|         48|    1|
|         53|    1|
|         54|    1|
+-----------+-----+
only showing top 20 rows



In [22]:
# Starting point for the coalesce experiments that follow: despite its name,
# this DataFrame is built with repartition(8), giving 8 partitions that
# coalesce() can then be applied to.
coalesce_df = flight_df.repartition(8)

In [23]:
# Confirm that coalesce_df starts out with 8 partitions.
coalesce_df.rdd.getNumPartitions()

8

In [24]:
# Count rows per partition for the 8-way repartitioned DataFrame
# (31-32 rows in each partition).
(
    coalesce_df
    .withColumn("partitionId", spark_partition_id())
    .groupBy("partitionId")
    .count()
    .show()
)

+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|   32|
|          1|   31|
|          2|   32|
|          3|   32|
|          4|   32|
|          5|   32|
|          6|   32|
|          7|   32|
+-----------+-----+



In [25]:
# Reduce the 8 partitions of coalesce_df to 3 with coalesce(). The resulting
# partitions are not evenly sized (64 / 95 / 96 rows in the counts shown for
# three_coal_df), unlike the near-even split produced by repartition().
three_coal_df = coalesce_df.coalesce(3)

In [26]:
# Confirm that the coalesced DataFrame has 3 partitions.
three_coal_df.rdd.getNumPartitions()

3

In [27]:
# Count rows per partition after coalescing from 8 partitions down to 3.
(
    three_coal_df
    .withColumn("partitionId", spark_partition_id())
    .groupBy("partitionId")
    .count()
    .show()
)

+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|   64|
|          1|   95|
|          2|   96|
+-----------+-----+



In [28]:
# Try to increase the number of partitions with coalesce(): coalesce_df has
# 8 partitions and we ask for 10. The partition count stays at 8, showing
# that coalesce() does not increase the number of partitions.
coalesce_df.coalesce(10).rdd.getNumPartitions()

8