# 3. Repartition

In [None]:
from typing import Optional

from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql import functions as f

In [2]:
# Get (or create) the session for this notebook: local master with 4 threads,
# Hive support enabled.
spark = (
    SparkSession.builder
    .appName("Repartition")
    .master("local[4]")
    .enableHiveSupport()
    .getOrCreate()
)

# SparkContext handle, used further down to label jobs.
sc = spark.sparkContext

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/11 20:34:13 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
def sdf_generator(num_rows: int, num_partitions: Optional[int] = None) -> DataFrame:
    """Build a synthetic DataFrame of sequential ids plus a few derived columns.

    Columns produced:
        id        -- values from ``spark.range(num_rows)``
        date      -- ``current_date()``
        timestamp -- ``current_timestamp()``
        idstring  -- ``id`` cast to string
        idfirst   -- first character of ``idstring``
        idlast    -- last character of ``idstring``

    Args:
        num_rows: Number of rows to generate (ids ``0 .. num_rows - 1``).
        num_partitions: Forwarded to ``spark.range`` as ``numPartitions``;
            ``None`` leaves the choice to Spark.

    Returns:
        The generated (lazily evaluated) DataFrame.
    """
    id_string = f.col("idstring")
    return (
        spark.range(num_rows, numPartitions=num_partitions)
        .withColumn("date", f.current_date())
        .withColumn("timestamp", f.current_timestamp())
        .withColumn("idstring", f.col("id").cast("string"))
        # Spark string positions are 1-based, so (1, 1) is the first character.
        .withColumn("idfirst", id_string.substr(1, 1))
        # A negative start counts from the end, so (-1, 1) is the last character.
        .withColumn("idlast", id_string.substr(-1, 1))
    )

In [4]:
# Small 20-row sample to sanity-check the generator before scaling up.
sdf_gen = sdf_generator(num_rows=20)
sdf_gen.count()

20

In [5]:
# Preview the generated columns; long values such as the timestamp are truncated in the output.
sdf_gen.show()

+---+----------+--------------------+--------+-------+------+
| id|      date|           timestamp|idstring|idfirst|idlast|
+---+----------+--------------------+--------+-------+------+
|  0|2024-12-11|2024-12-11 20:34:...|       0|      0|     0|
|  1|2024-12-11|2024-12-11 20:34:...|       1|      1|     1|
|  2|2024-12-11|2024-12-11 20:34:...|       2|      2|     2|
|  3|2024-12-11|2024-12-11 20:34:...|       3|      3|     3|
|  4|2024-12-11|2024-12-11 20:34:...|       4|      4|     4|
|  5|2024-12-11|2024-12-11 20:34:...|       5|      5|     5|
|  6|2024-12-11|2024-12-11 20:34:...|       6|      6|     6|
|  7|2024-12-11|2024-12-11 20:34:...|       7|      7|     7|
|  8|2024-12-11|2024-12-11 20:34:...|       8|      8|     8|
|  9|2024-12-11|2024-12-11 20:34:...|       9|      9|     9|
| 10|2024-12-11|2024-12-11 20:34:...|      10|      1|     0|
| 11|2024-12-11|2024-12-11 20:34:...|      11|      1|     1|
| 12|2024-12-11|2024-12-11 20:34:...|      12|      1|     2|
| 13|202

In [6]:
def rows_per_partition(sdf: "DataFrame", num_rows: int) -> None:
    """Print how many rows each Spark partition of ``sdf`` holds.

    Each row is tagged with its ``spark_partition_id()``, rows are counted per
    partition, and a ``count_perc`` column (``100 * count / num_rows``) is added.
    The result is shown ordered by ``partition_id``.

    Args:
        sdf: DataFrame whose partitioning is inspected.
        num_rows: Denominator for ``count_perc``; callers pass the total row count.
    """
    share_of_total = 100 * f.col("count") / num_rows
    (
        sdf.withColumn("partition_id", f.spark_partition_id())
        .groupBy("partition_id")
        .count()
        .withColumn("count_perc", share_of_total)
        .orderBy("partition_id")
        .show()
    )

# Take the denominator from the DataFrame itself instead of repeating the literal
# row count, so the percentages stay correct if the sample size is changed.
rows_per_partition(sdf=sdf_gen, num_rows=sdf_gen.count())

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           0|    5|      25.0|
|           1|    5|      25.0|
|           2|    5|      25.0|
|           3|    5|      25.0|
+------------+-----+----------+



In [7]:
def rows_per_partition_col(sdf: "DataFrame", num_rows: int, col: str) -> None:
    """Print row counts per (Spark partition, ``col`` value) pair of ``sdf``.

    Each row is tagged with its ``spark_partition_id()``, rows are counted per
    partition and per value of ``col``, and a ``count_perc`` column
    (``100 * count / num_rows``) is added. The result is shown ordered by
    ``partition_id`` and then ``col``.

    Args:
        sdf: DataFrame whose partitioning is inspected.
        num_rows: Denominator for ``count_perc``; callers pass the total row count.
        col: Name of the column to break each partition down by.
    """
    share_of_total = 100 * f.col("count") / num_rows
    (
        sdf.withColumn("partition_id", f.spark_partition_id())
        .groupBy("partition_id", col)
        .count()
        .withColumn("count_perc", share_of_total)
        .orderBy("partition_id", col)
        .show()
    )

# Take the denominator from the DataFrame itself instead of repeating the literal
# row count, so the percentages stay correct if the sample size is changed.
rows_per_partition_col(sdf=sdf_gen, num_rows=sdf_gen.count(), col="idfirst")

+------------+-------+-----+----------+
|partition_id|idfirst|count|count_perc|
+------------+-------+-----+----------+
|           0|      0|    1|       5.0|
|           0|      1|    1|       5.0|
|           0|      2|    1|       5.0|
|           0|      3|    1|       5.0|
|           0|      4|    1|       5.0|
|           1|      5|    1|       5.0|
|           1|      6|    1|       5.0|
|           1|      7|    1|       5.0|
|           1|      8|    1|       5.0|
|           1|      9|    1|       5.0|
|           2|      1|    5|      25.0|
|           3|      1|    5|      25.0|
+------------+-------+-----+----------+



In [8]:
# Size of the larger DataFrame used for the repartition experiments below.
num_rows = 20_000

In [9]:
# Baseline DataFrame, generated with an explicit partition count of 4.
sdf1 = sdf_generator(num_rows=num_rows, num_partitions=4)
# Confirm the partition count of the underlying RDD.
sdf1.rdd.getNumPartitions()

4

In [10]:
# Count the generated rows and check the result matches the requested size:
# num_rows is the denominator for every percentage column computed below.
row_count = sdf1.count()
assert row_count == num_rows, f"expected {num_rows} rows, got {row_count}"
print(row_count)

20000


In [11]:
# Row distribution across the 4 baseline partitions.
rows_per_partition(sdf=sdf1, num_rows=num_rows)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           0| 5000|      25.0|
|           1| 5000|      25.0|
|           2| 5000|      25.0|
|           3| 5000|      25.0|
+------------+-----+----------+



In [12]:
# Same baseline, broken down by the first character of the id within each partition.
rows_per_partition_col(sdf=sdf1, num_rows=num_rows, col="idfirst")

+------------+-------+-----+----------+
|partition_id|idfirst|count|count_perc|
+------------+-------+-----+----------+
|           0|      0|    1|     0.005|
|           0|      1| 1111|     5.555|
|           0|      2| 1111|     5.555|
|           0|      3| 1111|     5.555|
|           0|      4| 1111|     5.555|
|           0|      5|  111|     0.555|
|           0|      6|  111|     0.555|
|           0|      7|  111|     0.555|
|           0|      8|  111|     0.555|
|           0|      9|  111|     0.555|
|           1|      5| 1000|       5.0|
|           1|      6| 1000|       5.0|
|           1|      7| 1000|       5.0|
|           1|      8| 1000|       5.0|
|           1|      9| 1000|       5.0|
|           2|      1| 5000|      25.0|
|           3|      1| 5000|      25.0|
+------------+-------+-----+----------+



In [13]:
# Label the job so it is easy to find in the Spark UI, then execute the whole
# plan against the "noop" sink (rows are computed but nothing is written).
sc.setJobDescription("Baseline 4 partitions")
try:
    sdf1.write.format("noop").mode("overwrite").save()
finally:
    # Pass None (not the string "None") to clear the description; otherwise
    # every later job would be labelled "None" in the UI.
    sc.setJobDescription(None)

In [14]:
# Repartition by target count only (no partitioning column): 4 -> 3 partitions.
sdf_3 = sdf1.repartition(3)
sdf_3.rdd.getNumPartitions()

3

In [15]:
# Row distribution after repartitioning the baseline into 3 partitions.
rows_per_partition(sdf=sdf_3, num_rows=num_rows)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           0| 6667|    33.335|
|           1| 6667|    33.335|
|           2| 6666|     33.33|
+------------+-----+----------+



In [16]:
# Repartition by target count only (no partitioning column): 4 -> 12 partitions.
sdf_12 = sdf1.repartition(12)
sdf_12.rdd.getNumPartitions()

12

In [17]:
# Row distribution after repartitioning the baseline into 12 partitions.
rows_per_partition(sdf=sdf_12, num_rows=num_rows)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           0| 1667|     8.335|
|           1| 1666|      8.33|
|           2| 1666|      8.33|
|           3| 1666|      8.33|
|           4| 1667|     8.335|
|           5| 1667|     8.335|
|           6| 1667|     8.335|
|           7| 1667|     8.335|
|           8| 1666|      8.33|
|           9| 1667|     8.335|
|          10| 1667|     8.335|
|          11| 1667|     8.335|
+------------+-----+----------+



In [18]:
# Turn Adaptive Query Execution off for the rest of the notebook.
# NOTE(review): presumably so the shuffle partition counts observed below are
# not adjusted at runtime -- confirm intent.
spark.conf.set("spark.sql.adaptive.enabled", "False")

In [19]:
# Set the shuffle partition count explicitly (200 is Spark's default) so this
# cell always produces 200 partitions, even when it is re-run after the setting
# has been changed elsewhere in the session.
spark.conf.set("spark.sql.shuffle.partitions", 200)
# Repartitioning by column without a target count uses spark.sql.shuffle.partitions.
sdf_col_200 = sdf1.repartition("idfirst")
sdf_col_200.rdd.getNumPartitions()

200

In [20]:
# Row distribution after repartitioning on "idfirst"; only partitions that
# actually hold rows appear in the grouped output.
rows_per_partition(sdf=sdf_col_200, num_rows=num_rows)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           3| 1111|     5.555|
|          18| 1111|     5.555|
|          26| 1111|     5.555|
|          35|    1|     0.005|
|          49| 1111|     5.555|
|          75| 1111|     5.555|
|         139| 1111|     5.555|
|         144|11111|    55.555|
|         166| 1111|     5.555|
|         189| 1111|     5.555|
+------------+-----+----------+



In [21]:
# Which "idfirst" values ended up in which partition after the column repartition.
rows_per_partition_col(sdf=sdf_col_200, num_rows=num_rows, col="idfirst")


+------------+-------+-----+----------+
|partition_id|idfirst|count|count_perc|
+------------+-------+-----+----------+
|           3|      7| 1111|     5.555|
|          18|      3| 1111|     5.555|
|          26|      8| 1111|     5.555|
|          35|      0|    1|     0.005|
|          49|      5| 1111|     5.555|
|          75|      6| 1111|     5.555|
|         139|      9| 1111|     5.555|
|         144|      1|11111|    55.555|
|         166|      4| 1111|     5.555|
|         189|      2| 1111|     5.555|
+------------+-------+-----+----------+



In [22]:
# Lower the shuffle partition setting to 20, then repartition by column again
# without giving an explicit target count.
spark.conf.set("spark.sql.shuffle.partitions", 20)
sdf_col_20 = sdf1.repartition("idfirst")
# Check how many partitions the column repartition produced under the new setting.
sdf_col_20.rdd.getNumPartitions()

20

In [23]:
# Row distribution for the column repartition with the shuffle setting at 20.
rows_per_partition(sdf=sdf_col_20, num_rows=num_rows)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           3| 1111|     5.555|
|           4|11111|    55.555|
|           6| 2222|     11.11|
|           9| 2222|     11.11|
|          15| 1112|      5.56|
|          18| 1111|     5.555|
|          19| 1111|     5.555|
+------------+-----+----------+



In [24]:
# Breakdown by "idfirst" shows which key values share a partition.
rows_per_partition_col(sdf=sdf_col_20, num_rows=num_rows, col="idfirst")

+------------+-------+-----+----------+
|partition_id|idfirst|count|count_perc|
+------------+-------+-----+----------+
|           3|      7| 1111|     5.555|
|           4|      1|11111|    55.555|
|           6|      4| 1111|     5.555|
|           6|      8| 1111|     5.555|
|           9|      2| 1111|     5.555|
|           9|      5| 1111|     5.555|
|          15|      0|    1|     0.005|
|          15|      6| 1111|     5.555|
|          18|      3| 1111|     5.555|
|          19|      9| 1111|     5.555|
+------------+-------+-----+----------+



In [25]:
# Repartition by column with an explicit target of 10 partitions.
sdf_col_10 = sdf1.repartition(10, "idfirst")
sdf_col_10.rdd.getNumPartitions()

10

In [26]:
# Row distribution for the 10-partition column repartition.
rows_per_partition(sdf=sdf_col_10, num_rows=num_rows)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           3| 1111|     5.555|
|           4|11111|    55.555|
|           5| 1112|      5.56|
|           6| 2222|     11.11|
|           8| 1111|     5.555|
|           9| 3333|    16.665|
+------------+-----+----------+



In [27]:
# Breakdown by "idfirst" for the 10-partition column repartition.
rows_per_partition_col(sdf=sdf_col_10, num_rows=num_rows, col="idfirst")

+------------+-------+-----+----------+
|partition_id|idfirst|count|count_perc|
+------------+-------+-----+----------+
|           3|      7| 1111|     5.555|
|           4|      1|11111|    55.555|
|           5|      0|    1|     0.005|
|           5|      6| 1111|     5.555|
|           6|      4| 1111|     5.555|
|           6|      8| 1111|     5.555|
|           8|      3| 1111|     5.555|
|           9|      2| 1111|     5.555|
|           9|      5| 1111|     5.555|
|           9|      9| 1111|     5.555|
+------------+-------+-----+----------+



In [28]:
# Repartition by column with an explicit target of 5 partitions.
sdf_col_5 = sdf1.repartition(5, "idfirst")
sdf_col_5.rdd.getNumPartitions()

5

In [29]:
# Row distribution for the 5-partition column repartition.
rows_per_partition(sdf=sdf_col_5, num_rows=num_rows)

+------------+-----+----------+
|partition_id|count|count_perc|
+------------+-----+----------+
|           0| 1112|      5.56|
|           1| 2222|     11.11|
|           3| 2222|     11.11|
|           4|14444|     72.22|
+------------+-----+----------+



In [30]:
# Breakdown by "idfirst" for the 5-partition column repartition.
rows_per_partition_col(sdf=sdf_col_5, num_rows=num_rows, col="idfirst")

+------------+-------+-----+----------+
|partition_id|idfirst|count|count_perc|
+------------+-------+-----+----------+
|           0|      0|    1|     0.005|
|           0|      6| 1111|     5.555|
|           1|      4| 1111|     5.555|
|           1|      8| 1111|     5.555|
|           3|      3| 1111|     5.555|
|           3|      7| 1111|     5.555|
|           4|      1|11111|    55.555|
|           4|      2| 1111|     5.555|
|           4|      5| 1111|     5.555|
|           4|      9| 1111|     5.555|
+------------+-------+-----+----------+



In [31]:
# Label the job for the Spark UI and execute the 3-partition plan against the
# "noop" sink (rows are computed but nothing is written).
sc.setJobDescription("Repartition from 4 to 3")
try:
    sdf_3.write.format("noop").mode("overwrite").save()
finally:
    # Pass None (not the string "None") to clear the description; otherwise
    # every later job would be labelled "None" in the UI.
    sc.setJobDescription(None)

In [32]:
# Label the job for the Spark UI and execute the 12-partition plan against the
# "noop" sink (rows are computed but nothing is written).
sc.setJobDescription("Repartition from 4 to 12")
try:
    sdf_12.write.format("noop").mode("overwrite").save()
finally:
    # Pass None (not the string "None") to clear the description; otherwise
    # every later job would be labelled "None" in the UI.
    sc.setJobDescription(None)

In [33]:
# Label the job for the Spark UI and execute the 5-partition column-repartition
# plan against the "noop" sink (rows are computed but nothing is written).
sc.setJobDescription("Repartition from 4 to 5 with col")
try:
    sdf_col_5.write.format("noop").mode("overwrite").save()
finally:
    # Pass None (not the string "None") to clear the description; otherwise
    # every later job would be labelled "None" in the UI.
    sc.setJobDescription(None)