In [None]:
from pyspark.sql import SparkSession

# Build (or reuse, via getOrCreate) a local-mode SparkSession with 3 cores;
# every later cell refers to it as `spark`.
spark = (
    SparkSession.builder
    .appName("MyPySparkApp")
    .master("local[3]")
    .getOrCreate()
)

In [7]:


import string

# 20 rows spread over 3 shard values (0, 1, 2); ascii_letters[0..19] gives each
# row a distinct letter ('a'..'t').
df = spark.createDataFrame(
    [(i % 3, string.ascii_letters[i], i) for i in range(20)],
    ['shard', 'letter', 'id'],
)

In [None]:
# Number of partitions backing df straight out of createDataFrame (no shuffle
# applied yet); the following cells show how the rows are spread across them.
df.rdd.getNumPartitions()

In [None]:
# Row count of each partition, in partition order (0 for an empty partition).
df.rdd.mapPartitions(lambda rows: [sum(1 for _ in rows)]).collect()

In [None]:
from pyspark.sql.functions import spark_partition_id

# Rows per physical partition of df, tagging each row with the partition it lives in.
(
    df.withColumn("partitionId", spark_partition_id())
    .groupBy("partitionId")
    .count()
    .show()
)

In [None]:
# Rows per (partition, shard) after hash-repartitioning df on shard into 3
# partitions: all rows sharing a shard value land in the same partition.
(
    df.repartition(3, 'shard')
    .withColumn("partitionId", spark_partition_id())
    .groupBy("partitionId", "shard")
    .count()
    .show()
)

In [None]:
# Partition count of df after requesting 3 hash partitions on 'shard'.
(
    df.repartition(3, 'shard')
    .rdd
    .getNumPartitions()
)

In [39]:
# 200 rows spread over shards 0-4, plus 1000 extra rows with id 20
# (20 % 5 == 0, so all of them fall into shard 0): a deliberately skewed dataset.
df_large = spark.createDataFrame(
    [
        (i % 5, string.ascii_letters[i % 24], i)
        for i in [*range(200), *([20] * 1000)]
    ],
    ['shard', 'letter', 'id'],
)

In [None]:
# Logical skew: rows per shard value.
df_large.groupBy('shard').count().show()

# Physical layout: rows per partition as created, then after hash-repartitioning
# on shard into 10 partitions (the skewed shard ends up in a single partition).
for frame in (df_large, df_large.repartition(10, 'shard')):
    frame.withColumn("partitionId", spark_partition_id()).groupBy("partitionId").count().show()

In [None]:
import string

# Rebuild df with 9 rows (3 per shard value) and tag every row with the id of
# the partition it currently lives in.
df = spark.createDataFrame(
    [(i % 3, string.ascii_letters[i], i) for i in range(9)],
    ['shard', 'letter', 'id'],
)
df = df.withColumn("partitionId", spark_partition_id())
df.show()


In [None]:
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import col
from time import sleep
import logging


def string_length(name, partitionId, delay=0.5):
    """Return ``len(name)`` (0 for a falsy name) after logging which partition ran it.

    Used as the body of ``string_length_udf``. The log line plus the artificial
    delay presumably exist to make visible, in the stderr of whichever process
    runs the UDF, how rows are distributed over partitions.

    Args:
        name: string whose length is returned; ``None``/empty yields 0.
        partitionId: only echoed in the log message.
        delay: seconds to sleep per call. Defaults to 0.5, the previously
            hard-coded value; pass 0 to skip the wait.

    Returns:
        int: ``len(name)`` or 0.
    """
    # basicConfig is a no-op once the root logger has handlers, but the
    # StreamHandler passed to it would still be built and thrown away on every
    # row. Checking first builds it only on the first call in each process.
    if not logging.getLogger().handlers:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            handlers=[logging.StreamHandler()],  # log to stderr
        )
    logging.getLogger('SPARK').info(f'{name} - partition_id: {partitionId}')
    sleep(delay)
    return len(name) if name else 0
string_length_udf = udf(string_length, IntegerType())

# Baseline: reuse the partitionId column df already carries (its current layout).
# collect() is only there to force the UDF to run; watch the worker log output.
x = df.withColumn("name_length", string_length_udf(col("letter"), col("partitionId"))).collect()

# Same UDF after collapsing df into 1 and then 2 partitions. partitionId is
# recomputed after each repartition. The sleep separates the runs' log timestamps.
for n_partitions, label in [(1, "One partition:"), (2, "Two partitions:")]:
    sleep(2)
    print(label)
    x = (
        df.repartition(n_partitions)
        .withColumn("partitionId", spark_partition_id())
        .withColumn("name_length", string_length_udf(col("letter"), col("partitionId")))
        .collect()
    )

In [None]:
# Collapse df into a single partition and display the rows with their new
# partition id. Without an action (show), the cell would only render the lazy
# DataFrame's schema repr and nothing would be computed.
df.repartition(1).withColumn("partitionId", spark_partition_id()).show()

In [None]:
# Skewed input: ids 0-8 over shards 0-2 plus ten extra rows with id 0 (shard 0),
# hash-partitioned on shard into 5 partitions before the UDF runs.
df = spark.createDataFrame(
    [(i % 3, string.ascii_letters[i], i) for i in [*range(9), *([0] * 10)]],
    ['shard', 'letter', 'id'],
)
df = df.repartition(5, 'shard').withColumn("partitionId", spark_partition_id())
x = df.withColumn(
    "name_length",
    string_length_udf(df["letter"], df['partitionId']),
).collect()

In [68]:
# Lazily defined df (no action here): 9 rows with their partition id and the
# UDF-computed letter length.
df = (
    spark.createDataFrame(
        [(i % 3, string.ascii_letters[i], i) for i in range(9)],
        ['shard', 'letter', 'id'],
    )
    .withColumn("partitionId", spark_partition_id())
)
df = df.withColumn("name_length", string_length_udf(df["letter"], df['partitionId']))

In [None]:
# Lookup table for the joins below: letters 'a'..'e' mapped to 0, 100, ..., 400.
df2 = spark.createDataFrame(
    [(letter, i * 100) for i, letter in enumerate(string.ascii_letters[:5])],
    ['letter', 'id100'],
)


In [None]:
# Inner join on letter keeps only the df rows whose letter also appears in df2.
df.join(df2, on='letter', how='inner').show()
df2.groupBy('letter').count().show()

In [None]:
# Counts by letter alone, then by (letter, name_length); the second aggregation
# needs the UDF column, so it triggers the UDF again.
df.groupBy('letter').count().show()
df.groupBy('letter', 'name_length').count().show()

In [None]:
# Mark df for caching (cache() returns the same DataFrame and is lazy: the first
# action populates it), so later actions can reuse the rows instead of
# re-running the UDF.
df = df.cache()
df.join(df2, on='letter', how='inner').show()
(
    df.groupBy('letter', 'name_length')
    .count()
    .show()
)

In [None]:
# Physical plan of the inner join of df with df2 on 'letter'.
(
    df.join(df2, on='letter', how='inner')
    .explain()
)

In [None]:
# Physical plan of the (letter, name_length) aggregation.
(
    df.groupBy('letter', 'name_length')
    .count()
    .explain()
)

In [94]:
# Hash-repartition on shard into 3 partitions, then write Parquet partitioned by
# shard (one shard=<value> directory each) under tables/df, replacing prior output.
(
    df.repartition(3, 'shard')
    .write
    .mode('overwrite')
    .partitionBy('shard')
    .parquet('tables/df')
)

In [None]:
# Partition count of the in-memory df (not of the Parquet table written above).
# NOTE(review): if the intent was to inspect the written table, this probably
# should be spark.read.parquet('tables/df').rdd.getNumPartitions() — confirm.
df.rdd.getNumPartitions()

In [None]:

TABLE_PATH = 'tables/df'


def with_udf_columns(frame):
    """Return ``frame`` with a fresh ``partitionId`` and the UDF-derived ``name_length``."""
    return (
        frame
        .withColumn("partitionId", spark_partition_id())
        .withColumn("name_length", string_length_udf(col("letter"), col('partitionId')))
    )


# Whole table through the UDF. collect() only forces execution; watch the worker logs.
x = with_udf_columns(spark.read.parquet(TABLE_PATH)).collect()
sleep(2)

print("Filter:")
# Filter on the partition column first: Spark can prune to the shard=1 directory.
x = with_udf_columns(spark.read.parquet(TABLE_PATH).filter(col('shard') == 1)).collect()
sleep(2)  # keep the runs' log timestamps apart, like the first run

print("Filter after:")
# Same filter written after the UDF columns; compare the logged rows with the run above.
x = with_udf_columns(spark.read.parquet(TABLE_PATH)).filter(col('shard') == 1).collect()


In [None]:
# Extended plans with the shard filter placed before the UDF columns; compare
# with the next cell, where the filter is written after them.
(
    spark.read.parquet('tables/df')
    .filter(col('shard') == 1)
    .withColumn("partitionId", spark_partition_id())
    .withColumn("name_length", string_length_udf(col("letter"), col('partitionId')))
    .explain(mode="extended")
)


In [None]:
# Extended plans with the shard filter written after the partitionId/UDF columns.
(
    spark.read.parquet('tables/df')
    .withColumn("partitionId", spark_partition_id())
    .withColumn("name_length", string_length_udf(col("letter"), col('partitionId')))
    .filter(col('shard') == 1)
    .explain(mode="extended")
)