In [1]:
from pyspark.sql import SparkSession 

# Start (or reuse) a local Spark session configured with master "local[3]".
spark = (
    SparkSession.builder
    .appName("MyPySparkApp")
    .master("local[3]")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/18 18:40:28 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [7]:
# Create a simple 20-row DataFrame: `shard` cycles through 0..2,
# `letter` is the i-th ASCII letter and `id` is the row index.
import string

df = spark.createDataFrame(
    [(i % 3, string.ascii_letters[i], i) for i in range(0, 20)],
    ['shard', 'letter', 'id'],
)
df.show(6)

+-----+------+---+
|shard|letter| id|
+-----+------+---+
|    0|     a|  0|
|    1|     b|  1|
|    2|     c|  2|
|    0|     d|  3|
|    1|     e|  4|
|    2|     f|  5|
+-----+------+---+
only showing top 6 rows



## What is a partition?
Each dataframe is split into partitions. Each partition allows for independent processing if no grouping is required. So when you call `df.withColumn('new_column', do_something)`, the execution might look like this:

```
Partition 0 - Calculated on first core:
+-----+------+---+
|shard|letter| id|
+-----+------+---+
|    0|     a|  0|
|    1|     b|  1|
|    2|     c|  2|
+-----+------+---+
Partition 1 - Calculated on second core:
+-----+------+---+
|shard|letter| id|
+-----+------+---+
|    0|     d|  3|
|    1|     e|  4|
|    2|     f|  5|
+-----+------+---+
...
```
### How many partitions should you have?

It depends. In general, one partition should be about 100-200MB to keep the processor busy when executing. Also, there should not be too many of them, to avoid scheduling overhead.
Also, a rule of thumb is that there should be 2-4x more partitions than execution cores. Otherwise we are wasting resources.

How many partitions do we have here?

In [13]:
# Number of partitions backing the DataFrame.
partition_count = df.rdd.getNumPartitions()
print(f'Number of partitions: {partition_count}')

# RDD-based (older) way of counting the records held by each partition.
df.rdd.glom().map(len).collect()

Number of partitions: 3


[6, 6, 8]

In [None]:
from pyspark.sql.functions  import spark_partition_id
# spark_partition_id() yields the id of the partition each row is stored in;
# grouping by it gives the number of rows per partition.
(
    df.withColumn("partitionId", spark_partition_id())
    .groupBy("partitionId")
    .count()
    .show()
)

+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|    6|
|          1|    6|
|          2|    8|
+-----------+-----+



In [15]:
# Repartitioning by a column places all rows that share a value of that
# column in the same partition.
(
    df.repartition(3, 'shard')
    .withColumn("partitionId", spark_partition_id())
    .groupBy("partitionId", "shard")
    .count()
    .show()
)

+-----------+-----+-----+
|partitionId|shard|count|
+-----------+-----+-----+
|          1|    0|    7|
|          2|    1|    7|
|          2|    2|    6|
+-----------+-----+-----+



In [16]:
# The DataFrame still reports 3 partitions after repartitioning by 'shard',
# although the counts in the previous cell only list partition ids 1 and 2.
df.repartition(3, 'shard').rdd.getNumPartitions()

3

In [17]:
# A bigger DataFrame for more interesting experiments: ids 0..199 followed by
# 1000 extra copies of id 20, which all land in shard 0 (20 % 5 == 0).
large_ids = list(range(0, 200)) + [20] * 1000
df_large = spark.createDataFrame(
    [(i % 5, string.ascii_letters[i % 24], i) for i in large_ids],
    ['shard', 'letter', 'id'],
)

In [23]:
# Initial layout of the data: row counts per (shard, partition) pair.
(
    df_large.withColumn("partitionId", spark_partition_id())
    .groupBy("shard", "partitionId")
    .count()
    .show()
)

+-----+-----------+-----+
|shard|partitionId|count|
+-----+-----------+-----+
|    3|          0|   40|
|    2|          0|   40|
|    0|          0|  240|
|    1|          0|   40|
|    4|          0|   40|
|    0|          1|  400|
|    0|          2|  400|
+-----+-----------+-----+



In [None]:
# Rows per shard: shard 0 dominates because id 20 (20 % 5 == 0) was added 1000 times.
df_large.groupBy('shard').count().show()

# Rows per partition as the data was initially distributed.
df_large.withColumn("partitionId", spark_partition_id()).groupBy("partitionId").count().show()

# Rows per partition after repartitioning by 'shard' into 10 partitions.
(
    df_large.repartition(10, 'shard')
    .withColumn("partitionId", spark_partition_id())
    .groupBy("partitionId")
    .count()
    .show()
)

+-----+-----+
|shard|count|
+-----+-----+
|    0| 1040|
|    1|   40|
|    3|   40|
|    2|   40|
|    4|   40|
+-----+-----+

+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|  400|
|          1|  400|
|          2|  400|
+-----------+-----+

+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|   40|
|          5| 1040|
|          7|   40|
|          8|   40|
|          9|   40|
+-----------+-----+



In [26]:
# Row counts per (partition, shard) pair after repartitioning by 'shard' into 10 partitions.
(
    df_large.repartition(10, 'shard')
    .withColumn("partitionId", spark_partition_id())
    .groupBy("partitionId", "shard")
    .count()
    .show()
)

+-----------+-----+-----+
|partitionId|shard|count|
+-----------+-----+-----+
|          0|    4|   40|
|          5|    0| 1040|
|          7|    3|   40|
|          8|    2|   40|
|          9|    1|   40|
+-----------+-----+-----+



As we can see above, all data associated with shard `0` was placed on one partition. Because of that we have a huge imbalance, which is called data skew (skewed data). That is not good, because during processing the other partitions will most probably be finished way sooner than this one partition. Usually we try to avoid such cases. But this is a topic for another presentation. Just be careful.

## But what is RDD?

An RDD (Resilient Distributed Dataset) is a data structure used by Spark to allow recovering from executor failure. In fact, because of RDD Spark is lazy. The operations are not executed until the very last moment.

In [None]:
# This cell does no Spark processing: the DataFrame and the extra column are
# only defined here and are evaluated lazily once an action is called on them.
import string

df = spark.createDataFrame(
    [(i % 3, string.ascii_letters[i], i) for i in range(0, 9)],
    ['shard', 'letter', 'id'],
)
df = df.withColumn("partitionId", spark_partition_id())


In [38]:

# show() is an action, so this is where the lazily defined DataFrame is actually computed.
df.show()

+-----+------+---+-----------+
|shard|letter| id|partitionId|
+-----+------+---+-----------+
|    0|     a|  0|          0|
|    1|     b|  1|          0|
|    2|     c|  2|          0|
|    0|     d|  3|          1|
|    1|     e|  4|          1|
|    2|     f|  5|          1|
|    0|     g|  6|          2|
|    1|     h|  7|          2|
|    2|     i|  8|          2|
+-----+------+---+-----------+



In [30]:
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import col
from time import sleep
import logging

def string_length(name, partitionId, delay=0.5):
    """Return the length of ``name`` while logging which partition handled it.

    The function is wrapped in a Spark UDF below. It sleeps for ``delay``
    seconds per call so that the order and timestamps of the log lines make
    the per-partition execution visible.

    Args:
        name: The string to measure; ``None`` or an empty string yields 0.
        partitionId: Identifier of the partition the row belongs to. It is
            only written to the log.
        delay: Seconds to sleep before returning. Defaults to 0.5, the value
            that was previously hard-coded; pass 0 to skip the pause.

    Returns:
        ``len(name)``, or 0 when ``name`` is falsy.
    """
    # basicConfig() does nothing once the root logger already has handlers, so
    # calling it on every row is safe. It stays inside the function so that it
    # is also applied in whichever process ends up executing the UDF.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[logging.StreamHandler()],  # send log records to the console
    )
    logger = logging.getLogger('SPARK')
    # Lazy %-style arguments: the message is only built if the record is emitted.
    logger.info('%s - partition_id: %s', name, partitionId)
    if delay > 0:
        sleep(delay)
    return len(name) if name else 0
string_length_udf = udf(string_length, IntegerType())


def _collect_with_udf(frame, num_partitions=None):
    """Add ``name_length`` via ``string_length_udf`` and collect the rows.

    Args:
        frame: DataFrame that has ``letter`` and ``partitionId`` columns.
        num_partitions: When given, ``frame`` is first repartitioned into this
            many partitions and ``partitionId`` is recomputed from the new layout.

    Returns:
        The list of collected rows, including the ``name_length`` column.
    """
    if num_partitions is not None:
        frame = frame.repartition(num_partitions).withColumn("partitionId", spark_partition_id())
    return frame.withColumn(
        "name_length", string_length_udf(col("letter"), col("partitionId"))
    ).collect()


# Run the UDF on the DataFrame with its current partitioning.
x = _collect_with_udf(df)
sleep(2)  # pause between runs (presumably to keep their log output apart)
print("One partition:")
x = _collect_with_udf(df, 1)
sleep(2)
print("Two partition:")
x = _collect_with_udf(df, 2)

2025-01-18 19:11:24,202 - SPARK - INFO - d - partition_id: 1        (0 + 3) / 3]
2025-01-18 19:11:24,208 - SPARK - INFO - a - partition_id: 0
2025-01-18 19:11:24,224 - SPARK - INFO - g - partition_id: 2
2025-01-18 19:11:24,703 - SPARK - INFO - e - partition_id: 1
2025-01-18 19:11:24,708 - SPARK - INFO - b - partition_id: 0
2025-01-18 19:11:24,725 - SPARK - INFO - h - partition_id: 2
2025-01-18 19:11:25,204 - SPARK - INFO - f - partition_id: 1
2025-01-18 19:11:25,209 - SPARK - INFO - c - partition_id: 0
2025-01-18 19:11:25,225 - SPARK - INFO - i - partition_id: 2
                                                                                

One partition:


2025-01-18 19:11:27,883 - SPARK - INFO - a - partition_id: 0
2025-01-18 19:11:28,384 - SPARK - INFO - b - partition_id: 0
2025-01-18 19:11:28,884 - SPARK - INFO - c - partition_id: 0        (0 + 1) / 1]
2025-01-18 19:11:29,385 - SPARK - INFO - d - partition_id: 0
2025-01-18 19:11:29,886 - SPARK - INFO - e - partition_id: 0
2025-01-18 19:11:30,386 - SPARK - INFO - f - partition_id: 0
2025-01-18 19:11:30,887 - SPARK - INFO - g - partition_id: 0
2025-01-18 19:11:31,387 - SPARK - INFO - h - partition_id: 0
2025-01-18 19:11:31,888 - SPARK - INFO - i - partition_id: 0
                                                                                

Two partition:


2025-01-18 19:11:34,559 - SPARK - INFO - b - partition_id: 1
2025-01-18 19:11:34,561 - SPARK - INFO - a - partition_id: 0
2025-01-18 19:11:35,060 - SPARK - INFO - c - partition_id: 1
2025-01-18 19:11:35,061 - SPARK - INFO - f - partition_id: 0
2025-01-18 19:11:35,561 - SPARK - INFO - e - partition_id: 1        (0 + 2) / 2]
2025-01-18 19:11:35,562 - SPARK - INFO - i - partition_id: 0
2025-01-18 19:11:36,062 - SPARK - INFO - d - partition_id: 1
2025-01-18 19:11:36,062 - SPARK - INFO - g - partition_id: 0
2025-01-18 19:11:36,562 - SPARK - INFO - h - partition_id: 1
                                                                                

In [34]:
# Skewed data in action: ids 0..8 plus 15 extra copies of id 0, so shard 0
# holds far more rows than the other shards.
skewed_ids = list(range(0, 9)) + [0] * 15
df = spark.createDataFrame(
    [(i % 3, string.ascii_letters[i], i) for i in skewed_ids],
    ['shard', 'letter', 'id'],
)
df = df.repartition(5, 'shard').withColumn("partitionId", spark_partition_id())
x = df.withColumn(
    "name_length",
    string_length_udf(df["letter"], df['partitionId']),
).collect()

2025-01-18 19:13:05,575 - SPARK - INFO - c - partition_id: 3
2025-01-18 19:13:05,576 - SPARK - INFO - a - partition_id: 0
2025-01-18 19:13:05,615 - SPARK - INFO - b - partition_id: 4
2025-01-18 19:13:06,076 - SPARK - INFO - f - partition_id: 3
2025-01-18 19:13:06,076 - SPARK - INFO - d - partition_id: 0
2025-01-18 19:13:06,116 - SPARK - INFO - e - partition_id: 4
2025-01-18 19:13:06,577 - SPARK - INFO - i - partition_id: 3        (0 + 3) / 5]
2025-01-18 19:13:06,577 - SPARK - INFO - g - partition_id: 0
2025-01-18 19:13:06,616 - SPARK - INFO - h - partition_id: 4
2025-01-18 19:13:07,078 - SPARK - INFO - a - partition_id: 0
2025-01-18 19:13:07,578 - SPARK - INFO - a - partition_id: 0        (4 + 1) / 5]
2025-01-18 19:13:08,079 - SPARK - INFO - a - partition_id: 0
2025-01-18 19:13:08,579 - SPARK - INFO - a - partition_id: 0
2025-01-18 19:13:09,080 - SPARK - INFO - a - partition_id: 0
2025-01-18 19:13:09,581 - SPARK - INFO - a - partition_id: 0
2025-01-18 19:13:10,081 - SPARK - INFO - a - 

In [36]:
# Rebuild the small 9-row DataFrame and attach the partition id and the
# UDF-computed name length as additional columns.
df = (
    spark.createDataFrame(
        [(i % 3, string.ascii_letters[i], i) for i in range(0, 9)],
        ['shard', 'letter', 'id'],
    )
    .withColumn("partitionId", spark_partition_id())
)
df = df.withColumn("name_length", string_length_udf(df["letter"], df['partitionId']))

In [None]:
# Small lookup table: the first five letters, each paired with its index times 100.
df2 = spark.createDataFrame(
    [(string.ascii_letters[i], i * 100) for i in range(0, 5)],
    ['letter', 'id100'],
)

In [None]:
# Inner join on `letter`, followed by a per-letter row count of the lookup table.
df.join(df2, on='letter', how='inner').show()
df2.groupBy('letter').count().show()

In [None]:
# Row counts grouped by letter only, then by letter together with the UDF column.
df.groupBy('letter').count().show()
df.groupBy('letter', 'name_length').count().show()

In [None]:
# Mark `df` for caching, then run the join and the two-column aggregation on it.
df.cache()
df.join(df2, on='letter', how='inner').show()
df.groupBy('letter', 'name_length').count().show()

In [None]:
# Print the execution plan Spark builds for the join.
df.join(df2, on='letter', how='inner').explain()

In [None]:
# Print the execution plan Spark builds for the two-column aggregation.
df.groupBy('letter', 'name_length').count().explain()

In [94]:
# Write `df` as parquet under tables/df, partitioned on disk by `shard`,
# replacing whatever was written there before.
(
    df.repartition(3, 'shard')
    .write
    .partitionBy('shard')
    .mode('overwrite')
    .parquet('tables/df')
)

In [None]:
# The repartition() result in the previous cell was not assigned back to `df`,
# so this is the partition count of `df` itself.
df.rdd.getNumPartitions()

In [None]:

# Baseline: read everything back and run the UDF without any filter.
x = (
    spark.read.parquet('tables/df')
    .withColumn("partitionId", spark_partition_id())
    .withColumn("name_length", string_length_udf(col("letter"), col('partitionId')))
    .collect()
)
sleep(2)

# Variant 1: filter on `shard` before the UDF column is added.
print("Filter:")
x = (
    spark.read.parquet('tables/df')
    .filter(col('shard') == 1)
    .withColumn("partitionId", spark_partition_id())
    .withColumn("name_length", string_length_udf(col("letter"), col('partitionId')))
    .collect()
)

# Variant 2: same filter, but written after the UDF column.
print("Filter after:")
x = (
    spark.read.parquet('tables/df')
    .withColumn("partitionId", spark_partition_id())
    .withColumn("name_length", string_length_udf(col("letter"), col('partitionId')))
    .filter(col('shard') == 1)
    .collect()
)


In [None]:
# Extended plan for the variant that filters on `shard` before the UDF column is added.
(
    spark.read.parquet('tables/df')
    .filter(col('shard') == 1)
    .withColumn("partitionId", spark_partition_id())
    .withColumn("name_length", string_length_udf(col("letter"), col('partitionId')))
    .explain(mode="extended")
)


In [None]:
# Extended plan for the variant that filters on `shard` after the UDF column is added.
(
    spark.read.parquet('tables/df')
    .withColumn("partitionId", spark_partition_id())
    .withColumn("name_length", string_length_udf(col("letter"), col('partitionId')))
    .filter(col('shard') == 1)
    .explain(mode="extended")
)