#### PySpark Configurations ####

In [1]:
# Import Libraries
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.storagelevel import StorageLevel

In [2]:
# Build (or reuse) the Spark session, pointing at the master at spark://spark-master:7077
spark = (
    SparkSession.builder
    .master("spark://spark-master:7077")
    .appName("Ansh-Lamba-Apache-Spark-Optimization")
    .config("spark.ui.port", "4040")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/01/21 11:40:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Adaptive Query Execution - AQE
# Switch AQE on explicitly, then read the setting back to confirm the effective value.
spark.conf.set("spark.sql.adaptive.enabled", True)
print(
    'Adaptive Query Execution (AQE) enabled:',
    spark.conf.get("spark.sql.adaptive.enabled"),
)

Adaptive Query Execution (AQE) enabled: true


In [4]:
# Dynamic Partition pruning
# Switch the optimizer rule on explicitly, then read the setting back to confirm it.
spark.conf.set("spark.sql.optimizer.dynamicPartitionPruning.enabled", True)
print(
    'Dynamic Partition pruning enabled:',
    spark.conf.get("spark.sql.optimizer.dynamicPartitionPruning.enabled"),
)

Dynamic Partition pruning enabled: true


In [5]:
# Auto-Broadcast JOIN
# The setting is a size threshold in bytes (5 * 1024 * 1024 = 5 MiB here), not an on/off flag;
# a value of -1 disables automatic broadcasting.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 5 * 1024 * 1024)
print(
    'Auto-Broadcast JOIN enabled:',
    spark.conf.get("spark.sql.autoBroadcastJoinThreshold"),
)   # prints the threshold in bytes

Auto-Broadcast JOIN enabled: 5242880


#### Reading data from CSV file ####

In [6]:
# Root directory that input files are read from (a path constant only - nothing is created on disk here)
INPUT_DATA_ROOT = "/opt/spark-data/input/ansh-lamba"

In [7]:
# Read CSV file with inferred schema; the first line of the file is treated as the header
df_big_mart_sales = (
    spark.read
    .format("csv")
    .option('inferSchema', True)
    .option("header", True)
    .load(f"{INPUT_DATA_ROOT}/BigMart Sales - Copy.csv")
)

                                                                                

In [8]:
# Preview the first 5 records, converted to a pandas DataFrame for display
df_big_mart_sales.limit(5).toPandas()

Unnamed: 0,Item_Identifier,Item_Weight,Item_Fat_Content,Item_Visibility,Item_Type,Item_MRP,Outlet_Identifier,Outlet_Establishment_Year,Outlet_Size,Outlet_Location_Type,Outlet_Type,Item_Outlet_Sales
0,DRA12,11.6,Low Fat,0.041178,Soft Drinks,140.3154,OUT017,2007,,Tier 2,Supermarket Type1,2552.6772
1,DRA12,11.6,Low Fat,0.0,Soft Drinks,141.6154,OUT045,2002,,Tier 2,Supermarket Type1,3829.0158
2,DRA12,11.6,Low Fat,0.040912,Soft Drinks,142.3154,OUT013,1987,High,Tier 3,Supermarket Type1,2552.6772
3,DRA12,11.6,LF,0.0,Soft Drinks,141.9154,OUT035,2004,Small,Tier 2,Supermarket Type1,992.7078
4,DRA12,11.6,Low Fat,0.041113,Soft Drinks,142.0154,OUT018,2009,Medium,Tier 3,Supermarket Type2,850.8924


In [9]:
# Print the column names, types and nullability that were inferred for the dataframe
df_big_mart_sales.printSchema()

root
 |-- Item_Identifier: string (nullable = true)
 |-- Item_Weight: double (nullable = true)
 |-- Item_Fat_Content: string (nullable = true)
 |-- Item_Visibility: double (nullable = true)
 |-- Item_Type: string (nullable = true)
 |-- Item_MRP: double (nullable = true)
 |-- Outlet_Identifier: string (nullable = true)
 |-- Outlet_Establishment_Year: integer (nullable = true)
 |-- Outlet_Size: string (nullable = true)
 |-- Outlet_Location_Type: string (nullable = true)
 |-- Outlet_Type: string (nullable = true)
 |-- Item_Outlet_Sales: double (nullable = true)



In [10]:
# Count the records in the dataframe and print the total with thousands separators
print(f'Total records: {df_big_mart_sales.count():,}')

[Stage 5:>                                                          (0 + 1) / 1]

Total records: 244


                                                                                

In [11]:
# Report how many partitions back the dataframe (via its underlying RDD)
print(f'Number of Partitions: {df_big_mart_sales.rdd.getNumPartitions():,}')

Number of Partitions: 1


In [12]:
# Maximum bytes packed into one partition when reading files.
# To experiment with small 128 KiB partitions, uncomment:
# spark.conf.set("spark.sql.files.maxPartitionBytes", 128 * 1024)

# Set the partition size back to 128 MiB (134,217,728 bytes)
spark.conf.set("spark.sql.files.maxPartitionBytes", 128 * 1024 * 1024)

In [13]:
# Repartition the dataframe into 10 partitions.
# The result is bound back to the same name, so every later cell sees the repartitioned dataframe.
df_big_mart_sales = df_big_mart_sales.repartition(10)

In [14]:
# Report the partition count again, now that the dataframe has been repartitioned
print(f'Number of Partitions: {df_big_mart_sales.rdd.getNumPartitions():,}')

[Stage 6:>                                                          (0 + 1) / 1]

Number of Partitions: 10


In [15]:
# Tag each record with the id of the partition that holds it, then preview 5 records
(
    df_big_mart_sales
    .withColumn('Partition_Id', spark_partition_id())
    .limit(5)
    .toPandas()
)

                                                                                

Unnamed: 0,Item_Identifier,Item_Weight,Item_Fat_Content,Item_Visibility,Item_Type,Item_MRP,Outlet_Identifier,Outlet_Establishment_Year,Outlet_Size,Outlet_Location_Type,Outlet_Type,Item_Outlet_Sales,Partition_Id
0,DRE15,13.35,Low Fat,0.017858,Dairy,77.5012,OUT018,2009,Medium,Tier 3,Supermarket Type2,1518.024,0
1,DRB13,6.115,Regular,0.007084,Soft Drinks,191.153,OUT017,2007,,Tier 2,Supermarket Type1,3415.554,0
2,DRD37,9.8,Low Fat,0.013842,Soft Drinks,45.206,OUT046,1997,Small,Tier 1,Supermarket Type1,1211.756,0
3,DRF23,4.61,Low Fat,0.122629,Hard Drinks,175.4396,OUT035,2004,Small,Tier 2,Supermarket Type1,2616.594,0
4,DRD25,6.135,Low Fat,0.132183,Soft Drinks,115.086,OUT010,1998,,Tier 3,Grocery Store,452.744,0


#### Writing to Parquet file ####

In [16]:
# Root directory that output files are written to (a path constant only - nothing is created here).
# No trailing slash: every path below is built as f"{OUTPUT_DATA_ROOT}/<name>", so a trailing
# slash would put "//" into the path and would be inconsistent with INPUT_DATA_ROOT.
OUTPUT_DATA_ROOT = "/opt/spark-data/output/ansh-lamba"

# Save mode passed to DataFrameWriter.mode(). Options: APPEND, OVERWRITE, ERROR, IGNORE.
# NOTE: with APPEND, re-running a write cell adds the same records to the output again.
MODE = "APPEND"

In [17]:
# Write the dataframe as Parquet without a partitionBy column, using the configured save MODE
(
    df_big_mart_sales
    .write
    .format("parquet")
    .mode(MODE)
    .save(f"{OUTPUT_DATA_ROOT}/big-mart-sales.parquet")
)

                                                                                

#### Reading from Parquet file ####

In [18]:
# Read non-partitioned dataframe from Parquet file
# Parquet files store their own schema and have no header row, so the CSV-only
# "inferSchema" / "header" options have been dropped - they had no effect on this read.
df_big_mart_sales_no_partition = (
    spark.read
    .format("parquet")
    .load(f"{OUTPUT_DATA_ROOT}/big-mart-sales.parquet")
)

In [19]:
# Preview the first 5 records read back from the non-partitioned Parquet output
df_big_mart_sales_no_partition.limit(5).toPandas()

Unnamed: 0,Item_Identifier,Item_Weight,Item_Fat_Content,Item_Visibility,Item_Type,Item_MRP,Outlet_Identifier,Outlet_Establishment_Year,Outlet_Size,Outlet_Location_Type,Outlet_Type,Item_Outlet_Sales
0,DRD27,18.75,Low Fat,0.0,Dairy,97.9042,OUT045,2002,,Tier 2,Supermarket Type1,1686.4714
1,DRE13,6.28,Low Fat,0.0277,Soft Drinks,87.9198,OUT035,2004,Small,Tier 2,Supermarket Type1,1221.0772
2,DRE27,11.85,Low Fat,0.13256,Dairy,97.2726,OUT013,1987,High,Tier 3,Supermarket Type1,782.9808
3,DRB13,6.115,Regular,0.007043,Soft Drinks,190.353,OUT035,2004,Small,Tier 2,Supermarket Type1,569.259
4,DRC12,17.85,Low Fat,0.037886,Soft Drinks,190.4188,OUT049,1999,Medium,Tier 1,Supermarket Type1,952.094


In [20]:
# Count the records read back from Parquet and print the total with thousands separators
print(f'Total records: {df_big_mart_sales_no_partition.count():,}')



Total records: 244


                                                                                

In [21]:
# Keep only the records whose Outlet_Location_Type is "Tier 1"
df_big_mart_sales_no_partition_filtered = df_big_mart_sales_no_partition.filter(
    col("Outlet_Location_Type") == "Tier 1"
)

In [22]:
# Preview the first 5 records of the Tier 1 subset
df_big_mart_sales_no_partition_filtered.limit(5).toPandas()

Unnamed: 0,Item_Identifier,Item_Weight,Item_Fat_Content,Item_Visibility,Item_Type,Item_MRP,Outlet_Identifier,Outlet_Establishment_Year,Outlet_Size,Outlet_Location_Type,Outlet_Type,Item_Outlet_Sales
0,DRC12,17.85,Low Fat,0.037886,Soft Drinks,190.4188,OUT049,1999,Medium,Tier 1,Supermarket Type1,952.094
1,DRE49,20.75,Low Fat,0.021283,Soft Drinks,153.5024,OUT049,1999,Medium,Tier 1,Supermarket Type1,2428.8384
2,DRC49,8.67,Low Fat,0.065437,Soft Drinks,142.9128,OUT046,1997,Small,Tier 1,Supermarket Type1,2013.3792
3,DRD25,6.135,Low Fat,0.079095,Soft Drinks,114.386,OUT049,1999,Medium,Tier 1,Supermarket Type1,1018.674
4,DRE12,4.59,Low Fat,0.070781,Soft Drinks,114.586,OUT046,1997,Small,Tier 1,Supermarket Type1,1245.046


In [23]:
# Count the records in the Tier 1 subset and print the total with thousands separators
print(f'Total records: {df_big_mart_sales_no_partition_filtered.count():,}')

[Stage 19:>                                                         (0 + 5) / 5]

Total records: 68


                                                                                

#### Scanning Optimization ####

In [24]:
# Write the dataframe as Parquet, partitioned on disk by the columns listed below
partition_by_columns = ["Item_Identifier"]

(
    df_big_mart_sales
    .write
    .format("parquet")
    .partitionBy(*partition_by_columns)
    .mode(MODE)
    .save(f"{OUTPUT_DATA_ROOT}/big-mart-sales-partitions.parquet")
)

                                                                                

In [25]:
# Read data frame from disk - With Partitions
# "inferSchema" / "header" are CSV options and have no effect on a Parquet read,
# so they have been removed.
#
# Alternative kept for reference - apply the Tier 1 filter directly on the read:
# df_big_mart_sales_partition = (
#     spark.read
#     .format("parquet")
#     .load(f"{OUTPUT_DATA_ROOT}/big-mart-sales-partitions.parquet")
#     .filter(col("Outlet_Location_Type") == "Tier 1")
# )

df_big_mart_sales_partition = (
    spark.read
    .format("parquet")
    .load(f"{OUTPUT_DATA_ROOT}/big-mart-sales-partitions.parquet")
)

                                                                                

In [26]:
# Preview the first 5 records read back from the partitioned Parquet output
df_big_mart_sales_partition.limit(5).toPandas()

                                                                                

Unnamed: 0,Item_Weight,Item_Fat_Content,Item_Visibility,Item_Type,Item_MRP,Outlet_Identifier,Outlet_Establishment_Year,Outlet_Size,Outlet_Location_Type,Outlet_Type,Item_Outlet_Sales,Item_Identifier
0,12.85,Low Fat,0.03322,Fruits and Vegetables,196.6768,OUT046,1997,Small,Tier 1,Supermarket Type1,2759.0752,FDF20
1,12.85,Low Fat,0.033193,Fruits and Vegetables,199.0768,OUT013,1987,High,Tier 3,Supermarket Type1,3153.2288,FDF20
2,12.85,Low Fat,0.033272,Fruits and Vegetables,196.5768,OUT049,1999,Medium,Tier 1,Supermarket Type1,5715.2272,FDF20
3,13.5,Regular,0.128792,Starchy Foods,95.4068,OUT035,2004,Small,Tier 2,Supermarket Type1,1944.136,FDH47
4,13.5,Regular,0.129016,Starchy Foods,98.4068,OUT049,1999,Medium,Tier 1,Supermarket Type1,1846.9292,FDH47


#### Optimized Joins ####

In [27]:
# Sample datasets for the join demos:
#   siblings  -> (Id, Name, Age, Gender, Salary, CountryId)
#   countries -> (CountryId, CountryName)
siblings = [
    (1, 'Kwaku Jude', 40, 'M', 10_090.50, 840),
    (2, 'Yaw David', 36, 'M', 9_001.10, 288),
    (3, 'Kofi Baffuor', 34, 'M', 8_200.99, 288),
    (4, 'Abena Salo', 32, 'F', 7_905.00, 288),
    (5, 'Abena Pat', 30, 'F', 7_005.19, 288),
]

countries = [
    (840, 'USA'),
    (288, 'GHANA'),
]

In [28]:
# DDL-style schema strings (column name followed by type) for the two sample datasets
siblings_schema = (
    'Id INT, Name STRING, Age INT, '
    'Gender STRING, Salary DOUBLE, CountryId INT'
)
countries_schema = (
    'CountryId INT, '
    'CountryName STRING'
)

In [29]:
# Turn the sample lists into Spark dataframes using the schema strings defined for them
df_siblings = spark.createDataFrame(siblings, siblings_schema)
df_countries = spark.createDataFrame(countries, countries_schema)

In [30]:
# Preview up to 5 records of the siblings dataframe
df_siblings.limit(5).toPandas()

                                                                                

Unnamed: 0,Id,Name,Age,Gender,Salary,CountryId
0,1,Kwaku Jude,40,M,10090.5,840
1,2,Yaw David,36,M,9001.1,288
2,3,Kofi Baffuor,34,M,8200.99,288
3,4,Abena Salo,32,F,7905.0,288
4,5,Abena Pat,30,F,7005.19,288


In [31]:
# Preview up to 5 records of the countries dataframe
df_countries.limit(5).toPandas()

Unnamed: 0,CountryId,CountryName
0,840,USA
1,288,GHANA


In [32]:
# Join Siblings & Countries dataframes - Merge JOIN
# Joining on an explicit equality expression keeps the CountryId column from both sides.
dfs_siblings = df_siblings.join(
    df_countries,
    on=df_siblings["CountryId"] == df_countries["CountryId"],
    how="inner",
)

In [33]:
# Preview up to 5 records of the joined siblings/countries result
dfs_siblings.limit(5).toPandas()

Unnamed: 0,Id,Name,Age,Gender,Salary,CountryId,CountryId.1,CountryName
0,2,Yaw David,36,M,9001.1,288,288,GHANA
1,3,Kofi Baffuor,34,M,8200.99,288,288,GHANA
2,4,Abena Salo,32,F,7905.0,288,288,GHANA
3,5,Abena Pat,30,F,7005.19,288,288,GHANA
4,1,Kwaku Jude,40,M,10090.5,840,840,USA


In [34]:
# Join Siblings & Countries dataframes - Broadcast JOIN
# broadcast() marks the countries dataframe as the side to broadcast for this join.
dfs_siblings_optimized = df_siblings.join(
    broadcast(df_countries),
    on=df_siblings["CountryId"] == df_countries["CountryId"],
    how="inner",
)

In [35]:
# Preview up to 5 records of the broadcast-join result
dfs_siblings_optimized.limit(5).toPandas()

Unnamed: 0,Id,Name,Age,Gender,Salary,CountryId,CountryId.1,CountryName
0,1,Kwaku Jude,40,M,10090.5,840,840,USA
1,2,Yaw David,36,M,9001.1,288,288,GHANA
2,3,Kofi Baffuor,34,M,8200.99,288,288,GHANA
3,4,Abena Salo,32,F,7905.0,288,288,GHANA
4,5,Abena Pat,30,F,7005.19,288,288,GHANA


#### Spark SQL Hints ####

In [36]:
# Register both dataframes as temporary views so they can be queried by name through spark.sql;
# createOrReplaceTempView replaces a view of the same name, so this cell is safe to re-run.
df_siblings.createOrReplaceTempView("tbl_siblings")
df_countries.createOrReplaceTempView("tbl_countries")

In [37]:
# Join Siblings & Countries dataframes - Merge JOIN SQL
# Baseline query with no join hint: inner join of the two temp views on CountryId.
sql_query = spark.sql("""
                      SELECT Id, Name, Age, Gender, Salary, CountryName 
                      FROM tbl_siblings sb 
                      INNER JOIN tbl_countries cs 
                      ON sb.CountryId = cs.CountryId 
                      """)

# Show data (first 5 records as a pandas DataFrame)
sql_query.limit(5).toPandas()

Unnamed: 0,Id,Name,Age,Gender,Salary,CountryName
0,2,Yaw David,36,M,9001.1,GHANA
1,3,Kofi Baffuor,34,M,8200.99,GHANA
2,4,Abena Salo,32,F,7905.0,GHANA
3,5,Abena Pat,30,F,7005.19,GHANA
4,1,Kwaku Jude,40,M,10090.5,USA


In [38]:
# Join Siblings & Countries dataframes - Broadcast JOIN SQL (join hint)
# A Spark SQL hint must be written as /*+ ... */ and placed directly after SELECT.
# The previous /* broadcast(cs) */ was an ordinary SQL comment, so no hint was applied.
sql_query_optimized = spark.sql("""
    SELECT /*+ BROADCAST(cs) */ Id, Name, Age, Gender, Salary, CountryName
    FROM tbl_siblings sb
    INNER JOIN tbl_countries cs
    ON sb.CountryId = cs.CountryId
""")

# Show data (first 5 records as a pandas DataFrame)
sql_query_optimized.limit(5).toPandas()

                                                                                

Unnamed: 0,Id,Name,Age,Gender,Salary,CountryName
0,2,Yaw David,36,M,9001.1,GHANA
1,3,Kofi Baffuor,34,M,8200.99,GHANA
2,4,Abena Salo,32,F,7905.0,GHANA
3,5,Abena Pat,30,F,7005.19,GHANA
4,1,Kwaku Jude,40,M,10090.5,USA


#### Caching & Persistence ####

In [39]:
# Mark the dataframe for caching; cache() returns the dataframe, which is what the cell displays
df_big_mart_sales.cache()

DataFrame[Item_Identifier: string, Item_Weight: double, Item_Fat_Content: string, Item_Visibility: double, Item_Type: string, Item_MRP: double, Outlet_Identifier: string, Outlet_Establishment_Year: int, Outlet_Size: string, Outlet_Location_Type: string, Outlet_Type: string, Item_Outlet_Sales: double]

In [40]:
# Keep only the records whose Outlet_Location_Type is "Tier 1"
df_big_mart_sales_tier_1 = df_big_mart_sales.filter(
    col("Outlet_Location_Type") == "Tier 1"
)

In [41]:
# Keep only the records whose Outlet_Location_Type is "Tier 2"
df_big_mart_sales_tier_2 = df_big_mart_sales.filter(
    col("Outlet_Location_Type") == "Tier 2"
)

In [42]:
# Keep only the records whose Outlet_Location_Type is "Tier 3"
df_big_mart_sales_tier_3 = df_big_mart_sales.filter(
    col("Outlet_Location_Type") == "Tier 3"
)

In [43]:
# Preview the first 5 Tier 1 records (of the three tier subsets, only this one is displayed)
df_big_mart_sales_tier_1.limit(5).toPandas()

Unnamed: 0,Item_Identifier,Item_Weight,Item_Fat_Content,Item_Visibility,Item_Type,Item_MRP,Outlet_Identifier,Outlet_Establishment_Year,Outlet_Size,Outlet_Location_Type,Outlet_Type,Item_Outlet_Sales
0,DRD37,9.8,Low Fat,0.013842,Soft Drinks,45.206,OUT046,1997,Small,Tier 1,Supermarket Type1,1211.756
1,DRD12,6.96,Low Fat,0.077194,Soft Drinks,89.9146,OUT046,1997,Small,Tier 1,Supermarket Type1,1277.0044
2,DRD60,15.7,Low Fat,0.037232,Soft Drinks,183.1634,OUT046,1997,Small,Tier 1,Supermarket Type1,5634.6654
3,DRE03,19.6,Low Fat,0.024227,Dairy,48.7718,OUT046,1997,Small,Tier 1,Supermarket Type1,236.359
4,DRE49,20.75,LF,0.02125,Soft Drinks,150.5024,OUT046,1997,Small,Tier 1,Supermarket Type1,2580.6408


In [44]:
# Remove the dataframe from the cache; unpersist() returns the dataframe, which the cell displays
df_big_mart_sales.unpersist()

DataFrame[Item_Identifier: string, Item_Weight: double, Item_Fat_Content: string, Item_Visibility: double, Item_Type: string, Item_MRP: double, Outlet_Identifier: string, Outlet_Establishment_Year: int, Outlet_Size: string, Outlet_Location_Type: string, Outlet_Type: string, Item_Outlet_Sales: double]

In [45]:
# Persist the dataframe with an explicit storage level (MEMORY_ONLY) instead of cache()'s default.
# NOTE(review): no matching unpersist() is visible after this cell - confirm that is intended.
df_big_mart_sales.persist(StorageLevel.MEMORY_ONLY)

DataFrame[Item_Identifier: string, Item_Weight: double, Item_Fat_Content: string, Item_Visibility: double, Item_Type: string, Item_MRP: double, Outlet_Identifier: string, Outlet_Establishment_Year: int, Outlet_Size: string, Outlet_Location_Type: string, Outlet_Type: string, Item_Outlet_Sales: double]

#### Adaptive Query Execution - AQE ####

In [46]:
# Group the dataframe by the Item_Fat_Content column and count the records in each group
df_items_fat_content = (
    df_big_mart_sales
    .groupby("Item_Fat_Content")
    .count()
)

In [47]:
# Preview up to 5 of the per-Item_Fat_Content counts
df_items_fat_content.limit(5).toPandas()

Unnamed: 0,Item_Fat_Content,count
0,Low Fat,181
1,LF,14
2,Regular,46
3,reg,1
4,low fat,2


#### Dynamic Partition pruning ####

In [52]:
# Join the partitioned dataframe with the non-partitioned one on Item_Identifier.
# The non-partitioned side is first narrowed to the single item "DRB01".
dfs_join_optimized = df_big_mart_sales_partition.join(
    df_big_mart_sales_no_partition.filter(col("Item_Identifier") == "DRB01"),
    df_big_mart_sales_partition["Item_Identifier"] == df_big_mart_sales_no_partition["Item_Identifier"],
    "inner",
)

In [53]:
# Preview the first 5 records of the join; columns from both sides appear in the result
dfs_join_optimized.limit(5).toPandas()

Unnamed: 0,Item_Weight,Item_Fat_Content,Item_Visibility,Item_Type,Item_MRP,Outlet_Identifier,Outlet_Establishment_Year,Outlet_Size,Outlet_Location_Type,Outlet_Type,...,Item_Fat_Content.1,Item_Visibility.1,Item_Type.1,Item_MRP.1,Outlet_Identifier.1,Outlet_Establishment_Year.1,Outlet_Size.1,Outlet_Location_Type.1,Outlet_Type.1,Item_Outlet_Sales
0,7.39,Low Fat,0.082367,Soft Drinks,187.753,OUT049,1999,Medium,Tier 1,Supermarket Type1,...,Low Fat,0.082171,Soft Drinks,190.953,OUT013,1987,High,Tier 3,Supermarket Type1,2466.789
1,7.39,Low Fat,0.082367,Soft Drinks,187.753,OUT049,1999,Medium,Tier 1,Supermarket Type1,...,Low Fat,0.081841,Soft Drinks,190.053,OUT027,1985,Medium,Tier 3,Supermarket Type3,569.259
2,7.39,Low Fat,0.082367,Soft Drinks,187.753,OUT049,1999,Medium,Tier 1,Supermarket Type1,...,Low Fat,0.082367,Soft Drinks,187.753,OUT049,1999,Medium,Tier 1,Supermarket Type1,1518.024
3,7.39,Low Fat,0.082171,Soft Drinks,190.953,OUT013,1987,High,Tier 3,Supermarket Type1,...,Low Fat,0.082171,Soft Drinks,190.953,OUT013,1987,High,Tier 3,Supermarket Type1,2466.789
4,7.39,Low Fat,0.082171,Soft Drinks,190.953,OUT013,1987,High,Tier 3,Supermarket Type1,...,Low Fat,0.081841,Soft Drinks,190.053,OUT027,1985,Medium,Tier 3,Supermarket Type3,569.259


#### Broadcast Variable ####

In [54]:
# Lookup table from numeric country id to country name
countries_dict = {840: "USA", 288: "GHANA"}

In [55]:
# Wrap the countries dictionary in a Spark broadcast variable via the session's SparkContext
broadcast_var = spark.sparkContext.broadcast(countries_dict)

In [56]:
# Display the object held by the broadcast variable (the countries dictionary)
broadcast_var.value

{840: 'USA', 288: 'GHANA'}

In [57]:
# Look up a single key (840) in the broadcast dictionary; dict.get returns None for a missing key
broadcast_var.value.get(840)

'USA'

In [59]:
# Mapper function
def mapper(x) -> str:
    """Look up the country name for a country id in the broadcast dictionary.

    Args:
        x: Country id used as the key into ``broadcast_var.value``.

    Returns:
        The value stored for ``x``, or ``None`` when ``x`` is not a key
        (``dict.get`` semantics).
    """
    # The return annotation is the type `str`. The previous `str()` called the
    # constructor, so the function was annotated with an empty string instead.
    return broadcast_var.value.get(x)

In [60]:
# Wrap mapper as a Spark UDF. No return type is passed, so udf() uses its default
# (StringType per the PySpark documentation).
mapper_udf = udf(mapper)

In [63]:
# Apply the UDF to the CountryId column and store the result in a new CountryName column
df_siblings_map = df_siblings.withColumn("CountryName", mapper_udf("CountryId"))

In [64]:
# Preview up to 5 records, including the CountryName column produced by the UDF
df_siblings_map.limit(5).toPandas()

Unnamed: 0,Id,Name,Age,Gender,Salary,CountryId,CountryName
0,1,Kwaku Jude,40,M,10090.5,840,USA
1,2,Yaw David,36,M,9001.1,288,GHANA
2,3,Kofi Baffuor,34,M,8200.99,288,GHANA
3,4,Abena Salo,32,F,7905.0,288,GHANA
4,5,Abena Pat,30,F,7005.19,288,GHANA


#### Salting ####

In [74]:
# Dummy (product_id, amount) pairs to demonstrate SALTING; key "A" appears far more often than "B" or "C"
salt_data = [
    ("A", 100), ("A", 200), ("A", 300),
    ("B", 100), ("C", 200), ("A", 100),
    ("B", 200), ("C", 100), ("B", 400),
    ("A", 100), ("C", 200), ("A", 500),
]

In [75]:
# Create a Spark dataframe from the dummy data; only column names are given, so column types are inferred
df_salt = spark.createDataFrame(salt_data, ["product_id","amount"])

In [76]:
# Add Salt column
# floor(rand * 3) assigns each record to one of three salt buckets: 0, 1 or 2.
# A fixed seed makes the assignment repeatable across re-runs for the same data layout,
# so the outputs shown below can be reproduced; an unseeded rand() changes on every run.
df_salt = df_salt.withColumn("salt_value", floor(rand(seed=42) * 3))

In [77]:
# Build the salted key by concatenating product_id, the literal " - ", and salt_value (e.g. "A - 1")
df_salt = df_salt.withColumn("product_id_salt", concat(col("product_id"), lit(" - "), col("salt_value")))

In [78]:
# Preview the first 5 records with their salt_value and product_id_salt columns
df_salt.limit(5).toPandas()

Unnamed: 0,product_id,amount,salt_value,product_id_salt
0,A,100,0,A - 0
1,A,200,1,A - 1
2,A,300,1,A - 1
3,B,100,0,B - 0
4,C,200,2,C - 2


In [80]:
# Sum amount per salted key and preview 5 of the groups.
# These are per-bucket subtotals: each product's total is still split across its salt values.
# `sum` here is the Spark column function brought in by the wildcard functions import, not the builtin.
(
    df_salt
    .groupBy("product_id_salt")
    .agg(sum("amount").alias("total_sum"))
    .limit(5)
    .toPandas()
)

Unnamed: 0,product_id_salt,total_sum
0,A - 0,100
1,A - 1,500
2,B - 0,100
3,C - 2,500
4,A - 2,700
