In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql import Row
from pyspark.sql.types import IntegerType

In [2]:
spark = SparkSession.builder.appName("practice").getOrCreate()

2021-09-26 08:52:51,918 WARN util.Utils: Your hostname, tb-LinuxBox resolves to a loopback address: 127.0.1.1; using 10.0.2.15 instead (on interface enp0s3)
2021-09-26 08:52:51,920 WARN util.Utils: Set SPARK_LOCAL_IP if you need to bind to another address
2021-09-26 08:52:54,661 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [3]:
orders = spark.read.format("parquet")\
                    .options(header=True)\
                    .load("file:///home/tamaghna/big_data_spark/sales_parquet")

products = spark.read.format("parquet")\
                        .options(header=True)\
                        .load("file:///home/tamaghna/big_data_spark/products_parquet")

                                                                                

In [4]:
orders.printSchema()

root
 |-- order_id: string (nullable = true)
 |-- product_id: string (nullable = true)
 |-- seller_id: string (nullable = true)
 |-- date: string (nullable = true)
 |-- num_pieces_sold: string (nullable = true)
 |-- bill_raw_text: string (nullable = true)



In [5]:
orders = orders.drop("bill_raw_text")

In [6]:
orders.count()

                                                                                

20000040

In [7]:
products.count()

75000000

In [8]:
# products.count() >> orders.count()
# Joining these tables could skew the result and slow down processing.

In [9]:
# Apply key-salting

In [10]:
result_list = orders.groupBy("product_id").count().sort(col("count").desc()).limit(100).collect()

                                                                                

In [11]:
result_list[:10]

[Row(product_id='0', count=19000000),
 Row(product_id='32602520', count=3),
 Row(product_id='57735075', count=3),
 Row(product_id='20774718', count=3),
 Row(product_id='28183035', count=3),
 Row(product_id='36269838', count=3),
 Row(product_id='18182299', count=3),
 Row(product_id='52606213', count=3),
 Row(product_id='14542470', count=3),
 Row(product_id='40579633', count=3)]

In [12]:
repeated_products = []
l = []
replication_factor = 101

In [13]:
for row_items in result_list:
    repeated_products.append(row_items["product_id"])
    for i in range(replication_factor):
        l.append((row_items["product_id"],i))
l[:10]

[('0', 0),
 ('0', 1),
 ('0', 2),
 ('0', 3),
 ('0', 4),
 ('0', 5),
 ('0', 6),
 ('0', 7),
 ('0', 8),
 ('0', 9)]

In [14]:
rdd = spark.sparkContext.parallelize(l)

In [15]:
repeated_products[:10]

['0',
 '32602520',
 '57735075',
 '20774718',
 '28183035',
 '36269838',
 '18182299',
 '52606213',
 '14542470',
 '40579633']

In [16]:
rdd.take(10)

                                                                                

[('0', 0),
 ('0', 1),
 ('0', 2),
 ('0', 3),
 ('0', 4),
 ('0', 5),
 ('0', 6),
 ('0', 7),
 ('0', 8),
 ('0', 9)]

In [17]:
rdd = rdd.map(lambda tup:Row(product_id=tup[0],replication=int(tup[1])))

In [18]:
rdd.take(5)

[Row(product_id='0', replication=0),
 Row(product_id='0', replication=1),
 Row(product_id='0', replication=2),
 Row(product_id='0', replication=3),
 Row(product_id='0', replication=4)]

In [19]:
replicatedDF = spark.createDataFrame(rdd)

In [20]:
replicatedDF.show()

+----------+-----------+
|product_id|replication|
+----------+-----------+
|         0|          0|
|         0|          1|
|         0|          2|
|         0|          3|
|         0|          4|
|         0|          5|
|         0|          6|
|         0|          7|
|         0|          8|
|         0|          9|
|         0|         10|
|         0|         11|
|         0|         12|
|         0|         13|
|         0|         14|
|         0|         15|
|         0|         16|
|         0|         17|
|         0|         18|
|         0|         19|
+----------+-----------+
only showing top 20 rows



In [21]:
salted_products = products.join(broadcast(replicatedDF),products.product_id==replicatedDF.product_id,"left")\
                                .withColumn("salted_join_key",when(replicatedDF.replication.isNull()\
                                                                  ,products.product_id).otherwise(\
                                                                     concat(replicatedDF.product_id,lit("-"),replicatedDF.replication)
                                                                                                )
                                                               )

In [22]:
salted_products.show()

[Stage 13:>                                                         (0 + 1) / 1]

+----------+------------+-----+----------+-----------+---------------+
|product_id|product_name|price|product_id|replication|salted_join_key|
+----------+------------+-----+----------+-----------+---------------+
|         0|   product_0|   22|         0|        100|          0-100|
|         0|   product_0|   22|         0|         99|           0-99|
|         0|   product_0|   22|         0|         98|           0-98|
|         0|   product_0|   22|         0|         97|           0-97|
|         0|   product_0|   22|         0|         96|           0-96|
|         0|   product_0|   22|         0|         95|           0-95|
|         0|   product_0|   22|         0|         94|           0-94|
|         0|   product_0|   22|         0|         93|           0-93|
|         0|   product_0|   22|         0|         92|           0-92|
|         0|   product_0|   22|         0|         91|           0-91|
|         0|   product_0|   22|         0|         90|           0-90|
|     

                                                                                

In [32]:
orders.withColumn("salted_join_key",when(orders.product_id.isin([1,22.])\
                                        ,concat(orders.product_id,lit("-")\
                                        ,round(rand()*(replication_factor-1),0).cast(IntegerType())))\
                                     .oterwise(orders.product_id)).show()

AttributeError: 'DataFrame' object has no attribute 'isin'