In [2]:
import time
import warnings

from IPython.core.interactiveshell import InteractiveShell

# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = 'all'

# Silence all Python warnings for this session. One blanket 'ignore' filter
# already covers every category and module, so a single call is sufficient.
warnings.filterwarnings('ignore')

In [3]:
import os

# Use the "python" executable for both the PySpark workers and the driver.
os.environ.update(
    {
        "PYSPARK_PYTHON": "python",
        "PYSPARK_DRIVER_PYTHON": "python",
    }
)

In [4]:
from pyspark.sql.types import IntegerType
import pyspark.sql.functions as F
from pyspark.sql import SparkSession

In [5]:
# Local mode: run Spark inside this process with 4 worker threads.
local_builder = SparkSession.builder.master("local[4]")
spark = local_builder.getOrCreate()

# Keep a handle on the SparkContext and limit log output to errors.
sc = spark.sparkContext
sc.setLogLevel("ERROR")

In [8]:
# Shut down the current session so a new one can be created against a different master.
spark.stop()

In [9]:
# Cluster mode: connect to a standalone Spark master so the job runs on
# separate driver and worker processes.
# The master URL can be overridden with the SPARK_MASTER_URL environment
# variable; when it is unset, the address this notebook was written against
# is used.
spark = (
    SparkSession
    .builder
    .appName("Testing Salting")
    .master(os.environ.get("SPARK_MASTER_URL", "spark://192.168.1.15:7077"))
    .config("spark.hadoop.hadoop.native.io", "false")
    .getOrCreate()
)

In [10]:
# Optional override of the shuffle partition count (left disabled here):
# spark.conf.set("spark.sql.shuffle.partitions", "3")
# Turn Adaptive Query Execution off so the skewed join is measured without runtime re-optimization.
spark.conf.set("spark.sql.adaptive.enabled", "false")

# Join Skews

In [None]:
# Input locations, relative to the notebook's working directory.
transactions_file, customers_file = (
    "../../data/transactions.parquet",
    "../../data/customers.parquet",
)

# Read both Parquet datasets; transactions first, customers second.
df_transactions, df_customers = (
    spark.read.parquet(path) for path in (transactions_file, customers_file)
)

In [12]:
# Print the schema of each input: transactions first, then customers.
for input_df in (df_transactions, df_customers):
    input_df.printSchema()

root
 |-- cust_id: string (nullable = true)
 |-- start_date: string (nullable = true)
 |-- end_date: string (nullable = true)
 |-- txn_id: string (nullable = true)
 |-- date: string (nullable = true)
 |-- year: string (nullable = true)
 |-- month: string (nullable = true)
 |-- day: string (nullable = true)
 |-- expense_type: string (nullable = true)
 |-- amt: string (nullable = true)
 |-- city: string (nullable = true)

root
 |-- cust_id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- age: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- birthday: string (nullable = true)
 |-- zip: string (nullable = true)
 |-- city: string (nullable = true)



In [13]:
# Distinct transaction count per customer; the largest counts reveal any skewed key.
txn_count_per_customer = (
    df_transactions
    .groupBy("cust_id")
    .agg(F.countDistinct("txn_id").alias("ct"))
)
txn_count_per_customer.orderBy(F.col("ct").desc()).show(5, False)

+----------+--------+
|cust_id   |ct      |
+----------+--------+
|C0YDPQWPBJ|17539732|
|C89FCEGPJP|7999    |
|CBW3FMEAU7|7999    |
|C3KUDEN3KO|7999    |
|CHNFNR89ZV|7998    |
+----------+--------+
only showing top 5 rows


In [14]:
# Turn automatic broadcast joins off so the join below is not planned as a broadcast join.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)  # -1 disables the size threshold

In [15]:
# Baseline (unsalted) inner join of every transaction to its customer on cust_id.
df_txn_details = df_transactions.join(df_customers, "cust_id", "inner")

In [16]:
# Time a full count of the joined data; count() is the action that forces the
# join to execute.
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# elapsed time, whereas time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
df_txn_details.count()
print(f"time taken: {time.perf_counter() - start_time}")

39790092

time taken: 21.429769277572632


### SALTING ON SKEWED DATA

In [17]:
# Number of salt buckets: salts are integers in [0, SALT_NUM).
SALT_NUM = 5

In [18]:
# Tag every transaction with a random integer salt in [0, SALT_NUM).
salt_transactions = df_transactions.withColumn(
    "salt",
    (F.rand() * F.lit(SALT_NUM)).cast("int"),
)

In [19]:
# Preview the join key next to its newly assigned salt.
salt_transactions.select(["cust_id", "salt"]).show(n=5, truncate=False)

+----------+----+
|cust_id   |salt|
+----------+----+
|C0YDPQWPBJ|2   |
|C0YDPQWPBJ|2   |
|C0YDPQWPBJ|0   |
|C0YDPQWPBJ|3   |
|C0YDPQWPBJ|1   |
+----------+----+
only showing top 5 rows


In [20]:
# Keep only the transaction columns used downstream, plus the salt.
salt_transactions = salt_transactions.select(
    ["cust_id", "txn_id", "expense_type", "amt", "city", "salt"]
)
salt_transactions.printSchema()

root
 |-- cust_id: string (nullable = true)
 |-- txn_id: string (nullable = true)
 |-- expense_type: string (nullable = true)
 |-- amt: string (nullable = true)
 |-- city: string (nullable = true)
 |-- salt: integer (nullable = true)



In [21]:
# Re-check the customer columns before building the salted customer table.
df_customers.printSchema()

root
 |-- cust_id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- age: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- birthday: string (nullable = true)
 |-- zip: string (nullable = true)
 |-- city: string (nullable = true)



In [22]:
# Replicate each customer once per salt value so that every salted
# transaction key (cust_id, salt) finds a matching customer row.
salt_customers = (
    df_customers
    .withColumn("salt_values", F.array(*(F.lit(salt) for salt in range(SALT_NUM))))
    .withColumn("salt", F.explode("salt_values"))
    .select("cust_id", "salt", "salt_values")
)

In [23]:
# Confirm the salted customer table exposes cust_id, salt and the helper array column.
salt_customers.printSchema()

root
 |-- cust_id: string (nullable = true)
 |-- salt: integer (nullable = false)
 |-- salt_values: array (nullable = false)
 |    |-- element: integer (containsNull = false)



In [24]:
# Show the first replicated customer rows (one row per salt value).
salt_customers.show(n=8, truncate=False)

+----------+----+---------------+
|cust_id   |salt|salt_values    |
+----------+----+---------------+
|C007YEYTX9|0   |[0, 1, 2, 3, 4]|
|C007YEYTX9|1   |[0, 1, 2, 3, 4]|
|C007YEYTX9|2   |[0, 1, 2, 3, 4]|
|C007YEYTX9|3   |[0, 1, 2, 3, 4]|
|C007YEYTX9|4   |[0, 1, 2, 3, 4]|
|C00B971T1J|0   |[0, 1, 2, 3, 4]|
|C00B971T1J|1   |[0, 1, 2, 3, 4]|
|C00B971T1J|2   |[0, 1, 2, 3, 4]|
+----------+----+---------------+
only showing top 8 rows


In [25]:
# Salted join: match on the composite (cust_id, salt) key instead of cust_id alone.
df_joined = salt_transactions.join(salt_customers, on=["cust_id", "salt"], how="inner")

In [26]:
# Average transaction amount per customer, largest average first.
avg_txn_amt = (
    df_joined
    .groupBy("cust_id")
    .agg(F.avg("amt").alias("avg_txn_amt"))
    .orderBy(F.col("avg_txn_amt").desc())
)

In [27]:
# Top 5 customers by average transaction amount (AQE disabled).
avg_txn_amt.show(n=5, truncate=False)

+----------+------------------+
|cust_id   |avg_txn_amt       |
+----------+------------------+
|CRBRTDCWB5|274.74398429833167|
|CA9UYOQ5DA|257.0569479285446 |
|CQYO6YFE5T|256.60914331896555|
|CGN9VRRD9S|254.3261152684782 |
|CMWM4NK1DP|253.45855328620053|
+----------+------------------+
only showing top 5 rows


In [28]:
# Switch Adaptive Query Execution back on, together with its skew-join handling,
# so the same aggregation can be re-run with AQE enabled.
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

In [29]:
# Rebuild the per-customer average now that AQE is enabled, then show the top 5.
avg_per_customer = df_joined.groupBy("cust_id").agg(
    F.avg("amt").alias("avg_txn_amt")
)
avg_txn_amt = avg_per_customer.orderBy(F.col("avg_txn_amt").desc())

avg_txn_amt.show(n=5, truncate=False)

+----------+------------------+
|cust_id   |avg_txn_amt       |
+----------+------------------+
|CRBRTDCWB5|274.7439842983318 |
|CA9UYOQ5DA|257.0569479285441 |
|CQYO6YFE5T|256.60914331896566|
|CGN9VRRD9S|254.3261152684782 |
|CMWM4NK1DP|253.45855328620053|
+----------+------------------+
only showing top 5 rows


In [30]:
# Row count per customer in the salted transactions (default 20-row preview).
(
    salt_transactions
    .groupBy("cust_id")
    .count()
    .show()
)

+----------+-----+
|   cust_id|count|
+----------+-----+
|CEEPXNQ9NQ| 7625|
|C00WRSJF1Q| 7777|
|C9UFHRMRQE| 7214|
|C8EKBL7G9T| 7370|
|CF4T0Z3WJ2| 7152|
|COEAGSIT29| 7214|
|C4P7AEI6CC| 7713|
|CP2D91H1S8| 7782|
|C8HE0MGBDU| 7990|
|CS86SPLSV8| 7255|
|CK9WUHS2O1| 7431|
|CNERMUJRNI| 7705|
|C48FW6NX52| 7482|
|CV3G07K3WC| 7600|
|CA58S3H7V3| 7458|
|CAZMBDSMA8| 7231|
|CVV3KZW395| 7133|
|CEL4JSH0AT| 7810|
|CE9AFTP6UO| 7270|
|CW1X61XN86| 7757|
+----------+-----+
only showing top 20 rows


### INCREASE SALT NUMBER

**Increasing the salt number is not always better: `explode()` replicates every customer row once per salt value, so a larger salt number adds more redundant data.**

In [None]:
# Repeat the experiment with 10 salt buckets: salts are integers in [0, SALT_NUM).
SALT_NUM = 10
salt_transactions = df_transactions.withColumn(
    "salt",
    (F.rand() * F.lit(SALT_NUM)).cast("int"),
)

In [None]:
# Schema before pruning.
salt_transactions.printSchema()

# Keep only the transaction columns used downstream, plus the salt, and preview them.
salt_transactions = salt_transactions.select(
    ["cust_id", "txn_id", "expense_type", "amt", "city", "salt"]
)
salt_transactions.show(n=5, truncate=False)

In [None]:
# Replicate each customer once per salt value, this time keeping age and city as well.
salt_customers = (
    df_customers
    .withColumn("salt_values", F.array(*(F.lit(salt) for salt in range(SALT_NUM))))
    .withColumn("salt", F.explode("salt_values"))
    .select("cust_id", "age", "city", "salt", "salt_values")
)

In [None]:
# Inspect the replicated customer table: schema first, then the first 10 rows untruncated.
salt_customers.printSchema()
salt_customers.show(n=10, truncate=False)

In [None]:
# DataFrames are immutable: drop() returns a new DataFrame, so the result must
# be assigned back for the helper array column to actually be removed.
salt_customers = salt_customers.drop("salt_values")

In [None]:
# Salted join for the SALT_NUM = 10 run, again on the composite (cust_id, salt) key.
df_joined_2 = salt_transactions.join(salt_customers, on=["cust_id", "salt"], how="inner")

In [None]:
# Convert amt from string to integer so it can be aggregated numerically.
# NOTE(review): if amt holds fractional values, an integer cast discards them — confirm.
df_joined_2 = df_joined_2.withColumn(
    "amt",
    F.col("amt").cast("int"),
)

In [None]:
# Verify the join result's columns, including the type of amt after the cast.
df_joined_2.printSchema()

In [None]:
# Total transaction amount per customer and expense type.
# Aggregates df_joined_2 — the SALT_NUM = 10 join whose amt column was cast to
# an integer for this purpose — rather than the earlier df_joined from the
# SALT_NUM = 5 run.
sum_txn_amt = (
    df_joined_2
    .groupBy("cust_id", "expense_type")
    .agg(F.sum("amt").alias("sum"))
)

In [None]:
# Preview the per-customer, per-expense-type totals without truncating values.
sum_txn_amt.show(n=20, truncate=False)

In [None]:
# List the distinct expense types present in the joined data.
(
    df_joined_2
    .select("expense_type")
    .distinct()
    .show()
)

### Trying Double Salting (Manually) without AQE

**Largely redundant here: the data has already been salted with 5 (or 10), so salting it again with another 5 is cumbersome.**

**It is worth trying only when the original salt number is small, say 2 or 3.**

In [31]:
# Turn Adaptive Query Execution and its skew-join handling off again, so the
# manual double-salting experiment runs without AQE.
spark.conf.set("spark.sql.adaptive.enabled", "false")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "false")

In [None]:
# Inspect the columns of the SALT_NUM = 10 join before salting a second time.
df_joined_2.printSchema()
# Optional row preview (left disabled):
#df_joined_2.show(5, truncate = False)

In [None]:
# Rebuild the replicated customer table without city, which the transactions
# side also carries, keeping age plus the salt columns.
salt_customers = (
    df_customers
    .drop("city")
    .withColumn("salt_values", F.array(*(F.lit(salt) for salt in range(SALT_NUM))))
    .withColumn("salt", F.explode("salt_values"))
    .select("cust_id", "age", "salt", "salt_values")
)

In [None]:
# Check both sides of the upcoming join.
salt_transactions.printSchema()
salt_customers.printSchema()

# Bare expression: echoes the salt number currently in effect.
SALT_NUM

In [None]:
# First-level salted join, used as the input to the double-salting experiment.
df_joined_3 = salt_transactions.join(salt_customers, on=["cust_id", "salt"], how="inner")

In [None]:
# Second salting pass: number of buckets for the additional salt.
SALT_NUM_2 = 5

# Add a second random integer salt in [0, SALT_NUM_2) to the already-joined rows.
salt_again_txn = df_joined_3.withColumn(
    "salt2",
    (F.rand() * F.lit(SALT_NUM_2)).cast("int"),
)

In [None]:
# Build the replicated side for the second salt.
# salt2 on the other side is drawn from [0, SALT_NUM_2), so exactly SALT_NUM_2
# values are exploded here; ranging over SALT_NUM instead would create replicas
# for salt2 values that can never match whenever SALT_NUM > SALT_NUM_2.
# NOTE(review): both sides of the double-salted join derive from df_joined_3,
# which makes it a self-join — confirm this is intended.
salt_again_customers = (
    df_joined_3
    .withColumn("salt_values2", F.array([F.lit(i) for i in range(SALT_NUM_2)]))
    .withColumn("salt2", F.explode(F.col("salt_values2")))
)

In [None]:
# Compare the schemas of the two double-salted inputs before joining them.
salt_again_txn.printSchema()
salt_again_customers.printSchema()

In [None]:
# Double-salted join on the three-part key (cust_id, salt, salt2).
# Note: this rebinds df_joined_3, replacing the single-salted join result.
df_joined_3 = salt_again_txn.join(
    salt_again_customers,
    on=["cust_id", "salt", "salt2"],
    how="inner",
)

In [None]:
# Inspect the result of the double-salted join built just above (df_joined_3),
# not the first single-salted join (df_joined).
df_joined_3.printSchema()

In [32]:
# Uncomment to shut down the Spark session when finished:
#spark.stop()