# Performance Tuning

**Spark configuration:**
- Executor memory: `(1G - 200M (reserved)) * 0.6 (fraction) * 0.9 (1 - storageFraction) = 432MB`
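
The estimate above can be reproduced as plain arithmetic. This is a minimal sketch using only the figures quoted in the bullet (1G taken as 1000MB, 200MB reserved, fraction 0.6, multiplier 0.9); the function and parameter names are illustrative and not part of any Spark API.

```python
def usable_execution_memory_mb(heap_mb: float, reserved_mb: float,
                               memory_fraction: float, storage_fraction: float) -> float:
    """Apply the estimate (heap - reserved) * fraction * (1 - storageFraction)."""
    unified_mb = (heap_mb - reserved_mb) * memory_fraction
    return unified_mb * (1.0 - storage_fraction)


# Figures taken from the bullet above; a multiplier of 0.9 corresponds to storage_fraction = 0.1.
estimate_mb = usable_execution_memory_mb(heap_mb=1000, reserved_mb=200,
                                         memory_fraction=0.6, storage_fraction=0.1)
print(f"{estimate_mb:.0f}MB")
```

The subtraction has to happen before the multiplications, which is why the parentheses around `heap - reserved` matter.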

**Skewed input data:**
- Transactions:
  - Largest partition has 3.8M records (grouped by `instrument_id`)
  - Columns: 3x Integer (4B), 1x Timestamp (8B)
  - Size of largest partition: 76MB
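
The 76MB figure follows from the column widths listed above. The sketch below only multiplies raw column widths by the record count; it is an assumption-laden lower bound, since Spark's in-memory row format adds per-row overhead that is not modelled here. All names are illustrative.

```python
INT_BYTES = 4        # width of an Integer column
TIMESTAMP_BYTES = 8  # width of a Timestamp column


def estimated_partition_bytes(records: int, int_columns: int = 3,
                              timestamp_columns: int = 1) -> int:
    """Raw payload size: records * (sum of fixed column widths)."""
    row_bytes = int_columns * INT_BYTES + timestamp_columns * TIMESTAMP_BYTES
    return records * row_bytes


# 3.8M records in the largest partition, as stated above.
largest_partition_bytes = estimated_partition_bytes(3_800_000)
print(f"{largest_partition_bytes / 1_000_000:.0f}MB")
```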

Metrics to inspect in the Spark UI for each run:
- Distribution of task runtimes (straggler tasks)
- Disk spill
- Shuffle Read & Write
- Shuffle Partitions (stragglers after exchange)


In [1]:
from pyspark.sql import SparkSession, DataFrame, Window
from pyspark.sql.functions import *
import datetime

In [2]:
spark = SparkSession.builder \
    .appName("performance-optimizations") \
    .config("spark.sql.adaptive.enabled", "false") \
    .config("spark.executor.memory", "4g") \
    .getOrCreate()

In [3]:
tx_df = spark.read.parquet('/data/gen/transaction')
tx_uniform_df = spark.read.parquet('/data/gen/transaction-uniform')

In [4]:
tx_df.groupBy('instrument_id').count().sort(col('count').desc()).show()
tx_uniform_df.groupBy('instrument_id').count().sort(col('count').desc()).show()

+-------------+-------+
|instrument_id|  count|
+-------------+-------+
|            2|3887000|
|            1|3875000|
|            3|1399000|
|            4| 517000|
|            5| 218000|
|            6|  59000|
|            7|  24000|
|            8|  13000|
|           10|   4000|
|            9|   4000|
+-------------+-------+

+-------------+-----+
|instrument_id|count|
+-------------+-----+
|         9760| 1124|
|         8261| 1122|
|         3496| 1121|
|         4740| 1120|
|         6322| 1118|
|          326| 1109|
|         8838| 1108|
|         3431| 1108|
|         6366| 1107|
|         1598| 1105|
|          473| 1102|
|         4399| 1099|
|         9975| 1098|
|         6633| 1095|
|         6738| 1094|
|         9379| 1093|
|         3662| 1093|
|         7037| 1093|
|          279| 1093|
|         2896| 1093|
+-------------+-----+
only showing top 20 rows



In [9]:
tx_df.select('instrument_id').distinct().count()

10
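
To put a number on the skew, the per-key counts printed above can be summarised outside Spark. This is a small sketch in plain Python; the counts are copied from the `tx_df` output, while the metric names (`skew_ratio`, `top2_share`) are my own choice and not a standard Spark metric.

```python
from statistics import median

# Record counts per instrument_id, copied from the groupBy output of tx_df.
counts = {2: 3_887_000, 1: 3_875_000, 3: 1_399_000, 4: 517_000, 5: 218_000,
          6: 59_000, 7: 24_000, 8: 13_000, 10: 4_000, 9: 4_000}

total = sum(counts.values())
largest = max(counts.values())

# How much bigger the largest key is than a "typical" key.
skew_ratio = largest / median(counts.values())

# Fraction of all records that belong to the two largest keys.
top2_share = sum(sorted(counts.values(), reverse=True)[:2]) / total

print(f"total records: {total}")
print(f"largest key / median key: {skew_ratio:.1f}x")
print(f"share of the two largest keys: {top2_share:.1%}")
```

Any task that has to process one whole key (as in the examples that follow) inherits this imbalance.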

## Example 1: Expensive window function

### AQE enabled

Observations:

- Stage 1: Read parquet & exchange
  - 8 Tasks, even record distribution, no stragglers
  - Shuffle write 18MB / task
- Stage 2: AQEShuffleRead, Window, Write
  - Straggler tasks, uneven distribution of records
  - `AQEShuffleRead` coalesces the shuffle output into 4 partitions
  - Disk spill (80MB) in only a few partitions
  - No serialization or shuffle write (the result is written directly to file)
 
### AQE disabled

Observations:

- Stage 1: Same execution as with AQE enabled
- Stage 2:
  - 200 tasks, uneven distribution of records
  - Disk spill (80MB) in a few partitions

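**The query itself leaves little to optimize apart from the number of shuffle partitions:** the window needs all rows of one `instrument_id` in the same task, so the skew carries over into Stage 2 with or without AQE.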
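
The effect of the partition count can be illustrated without Spark. The sketch below assigns the 10 keys of `tx_df` to 200 partitions (the default of `spark.sql.shuffle.partitions`). The hash function is a stand-in chosen for this example; Spark uses its own hash, so the actual partition ids will differ, but the upper bound on non-empty partitions holds for any hash.

```python
# Record counts per instrument_id, copied from the groupBy output of tx_df.
counts = {2: 3_887_000, 1: 3_875_000, 3: 1_399_000, 4: 517_000, 5: 218_000,
          6: 59_000, 7: 24_000, 8: 13_000, 10: 4_000, 9: 4_000}

NUM_SHUFFLE_PARTITIONS = 200  # default value of spark.sql.shuffle.partitions


def partition_for(key: int, num_partitions: int) -> int:
    """Stand-in hash partitioner (multiplicative hash); NOT Spark's implementation."""
    return (key * 2654435761 % 2**32) % num_partitions


# All rows of one key go to the same partition, as required by the window.
partition_sizes = [0] * NUM_SHUFFLE_PARTITIONS
for key, records in counts.items():
    partition_sizes[partition_for(key, NUM_SHUFFLE_PARTITIONS)] += records

non_empty = sum(1 for size in partition_sizes if size > 0)
print(f"non-empty partitions: {non_empty} of {NUM_SHUFFLE_PARTITIONS}")
print(f"largest partition: {max(partition_sizes)} records")
```

With 10 distinct keys, at most 10 partitions can receive data, and the largest key always ends up in a single task. Changing the partition count removes empty tasks but cannot split a key.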

In [4]:
spark.conf.set("spark.sql.adaptive.enabled", "true")

In [5]:
default_ts = datetime.datetime.fromisoformat('2023-01-01T00:00:00')
window = Window \
    .partitionBy('instrument_id') \
    .orderBy(col('transaction_ts'))
expensive_to_calculate = tx_df \
    .withColumn('lag_date', lag('transaction_ts', 1, default_ts).over(window))

In [6]:
expensive_to_calculate.write.mode("overwrite").parquet('/tmp/unused-result.parquet')

In [8]:
expensive_to_calculate.explain("extended")

== Parsed Logical Plan ==
'Project [transaction_id#0, instrument_id#1, trader_id#2, transaction_ts#3, lag('transaction_ts, -1, 2023-01-01 00:00:00) windowspecdefinition('instrument_id, 'transaction_ts ASC NULLS FIRST, unspecifiedframe$()) AS lag_date#56]
+- Relation [transaction_id#0,instrument_id#1,trader_id#2,transaction_ts#3] parquet

== Analyzed Logical Plan ==
transaction_id: int, instrument_id: int, trader_id: int, transaction_ts: timestamp, lag_date: timestamp
Project [transaction_id#0, instrument_id#1, trader_id#2, transaction_ts#3, lag_date#56]
+- Project [transaction_id#0, instrument_id#1, trader_id#2, transaction_ts#3, lag_date#56, lag_date#56]
   +- Window [lag(transaction_ts#3, -1, 2023-01-01 00:00:00) windowspecdefinition(instrument_id#1, transaction_ts#3 ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, -1)) AS lag_date#56], [instrument_id#1], [transaction_ts#3 ASC NULLS FIRST]
      +- Project [transaction_id#0, instrument_id#1, trader_id#2, transaction_ts#3]
        

## Example 2: Skewed join

- Join of two datasets
  - Transaction: highly-skewed on `instrument_id`, 10M records
  - Instrument: uniformly distributed over `instrument_id`

### AQE & BroadcastJoin disabled

Observations:
- execution time 11s
- Stage 1: Read transactions
  - 8 tasks, even distribution, 17MB shuffle write
- Stage 2: Read instruments
  - 8 tasks, even distribution, 35KB shuffle write
  - high serialization time (why?)
- Stage 3:
  - A few straggler tasks on the critical path, highly uneven distribution of records
  - Disk spill: up to 45MB in a few partitions

### AQE & BroadcastJoin enabled

Observations:
- execution time 5s
- optimized shuffle partitions (8)
- only two stages
  - Stage 1: Load instruments
  - Stage 2: Load transactions, join & write to parquet
    - even distribution of records, no stragglers
    - 8 tasks
    - no disk spill

**Optimizes very well because of the nature of the input data: the Instrument DataFrame is smaller than `spark.sql.autoBroadcastJoinThreshold`, so Spark can broadcast it instead of shuffling both sides.**

To request a broadcast join explicitly, independent of the threshold, use `instrument_df.hint("broadcast")`.
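
Why the broadcast join shows no stragglers can be sketched in plain Python: every task keeps its own input split and probes a full copy of the small table, so no rows are regrouped by the skewed key. This is a conceptual illustration only, not Spark's implementation; the sample rows and the instrument types other than `ETF` are made up.

```python
# Small side: instrument_id -> instrument_type (hypothetical sample values).
instruments = {1: "ETF", 2: "Stock", 3: "Bond"}

# Large side: (transaction_id, instrument_id), heavily skewed toward key 1.
transactions = [(tx_id, 1 if tx_id % 4 else 2) for tx_id in range(16)]


def split_evenly(rows, num_tasks):
    """Input splits are sized by position in the file, independent of the join key."""
    return [rows[i::num_tasks] for i in range(num_tasks)]


def broadcast_hash_join(task_rows, broadcast_table):
    """Each task looks up its rows in a local copy of the small table (no shuffle)."""
    return [(tx_id, inst_id, broadcast_table[inst_id])
            for tx_id, inst_id in task_rows
            if inst_id in broadcast_table]


tasks = split_evenly(transactions, num_tasks=4)
joined_per_task = [broadcast_hash_join(rows, instruments) for rows in tasks]
print([len(rows) for rows in joined_per_task])  # rows per task stay balanced
```

The trade-off is memory: every task needs the whole small table, which is why the approach only applies while that table stays below the broadcast threshold.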

In [27]:
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10485760)  # default: 10485760 bytes (10 MB)
spark.conf.set("spark.sql.shuffle.partitions", 10)

In [28]:
instrument_df = spark.read.parquet('/data/gen/instrument')

In [18]:
tx_df.select('instrument_id').distinct().count()

10

In [33]:
tx_df.join(instrument_df, 'instrument_id') \
    .explain("extended")

== Parsed Logical Plan ==
'Join UsingJoin(Inner, [instrument_id])
:- Relation [transaction_id#0,instrument_id#1,trader_id#2,transaction_ts#3] parquet
+- Relation [instrument_id#139,instrument_type#140,registered_ts#141] parquet

== Analyzed Logical Plan ==
instrument_id: int, transaction_id: int, trader_id: int, transaction_ts: timestamp, instrument_type: string, registered_ts: timestamp
Project [instrument_id#1, transaction_id#0, trader_id#2, transaction_ts#3, instrument_type#140, registered_ts#141]
+- Join Inner, (instrument_id#1 = instrument_id#139)
   :- Relation [transaction_id#0,instrument_id#1,trader_id#2,transaction_ts#3] parquet
   +- Relation [instrument_id#139,instrument_type#140,registered_ts#141] parquet

== Optimized Logical Plan ==
Project [instrument_id#1, transaction_id#0, trader_id#2, transaction_ts#3, instrument_type#140, registered_ts#141]
+- Join Inner, (instrument_id#1 = instrument_id#139)
   :- Filter isnotnull(instrument_id#1)
   :  +- Relation [transaction_id#0

In [29]:
tx_df.join(instrument_df, 'instrument_id') \
    .write.mode('overwrite').parquet('/tmp/unused-result.parquet')

## Salting implementation

In [32]:
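# Salt the skewed key with transaction_id % 10 so each hot key splits into 10 sub-keys
tx_df \
    .withColumn('instrument_id_salted',
                concat(col('instrument_id'), lit('_'), col('transaction_id') % 10)) \
    .groupBy('instrument_id_salted') \
    .count() \
    .orderBy(col('count').desc()) \
    .show()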

+--------------------+------+
|instrument_id_salted| count|
+--------------------+------+
|                 2_2|415000|
|                 1_1|406000|
|                 1_6|403000|
|                 1_4|401000|
|                 1_9|400000|
|                 1_7|399000|
|                 2_6|396000|
|                 1_3|391000|
|                 2_4|391000|
|                 2_8|391000|
|                 2_3|387000|
|                 2_7|387000|
|                 2_0|386000|
|                 2_9|382000|
|                 2_5|380000|
|                 1_5|376000|
|                 2_1|372000|
|                 1_2|369000|
|                 1_8|368000|
|                 1_0|362000|
+--------------------+------+
only showing top 20 rows
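
For a join, the salted key has to exist on both sides: the skewed side derives a salt from each row, and the small side is replicated once per salt value. The plain-Python sketch below checks, on made-up sample data, that the salted join returns the same rows as the unsalted one while the largest group shrinks. It illustrates the idea only and is not the notebook's Spark code.

```python
from collections import Counter

NUM_SALTS = 10  # same modulus as the `% 10` in the cell above

# Hypothetical sample data: (transaction_id, instrument_id); key 1 is hot, key 2 is rare.
transactions = [(tx_id, 1 if tx_id < 95 else 2) for tx_id in range(100)]
instruments = {1: "ETF", 2: "ETF"}  # instrument_id -> instrument_type (made-up values)

# Skewed side: build the salted key from instrument_id and transaction_id % NUM_SALTS.
salted_tx = [(f"{inst_id}_{tx_id % NUM_SALTS}", tx_id) for tx_id, inst_id in transactions]

# Small side: replicate every row once per salt so each salted key finds its match.
salted_instruments = {f"{inst_id}_{salt}": inst_type
                      for inst_id, inst_type in instruments.items()
                      for salt in range(NUM_SALTS)}

# Both joins must produce the same (transaction_id, instrument_type) rows.
plain_join = sorted((tx_id, instruments[inst_id]) for tx_id, inst_id in transactions)
salted_join = sorted((tx_id, salted_instruments[key]) for key, tx_id in salted_tx)

rows_per_key = Counter(inst_id for _, inst_id in transactions)
rows_per_salted_key = Counter(key for key, _ in salted_tx)
print("same result:", plain_join == salted_join)
print("largest group before salting:", max(rows_per_key.values()))
print("largest group after salting: ", max(rows_per_salted_key.values()))
```

The price of salting is the replication of the small side by a factor of `NUM_SALTS`, so the salt count is a trade-off between balance and extra data.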



In [41]:
# Replicate every instrument row once per salt value (0..9)
salting_keys = spark.range(0, 10)
instrument_df \
    .join(salting_keys, how='cross') \
    .withColumn('instrument_id_salted', concat(col('instrument_id'), lit('_'), col('id'))) \
    .filter(col('instrument_id') == 1250) \
    .show()

+-------------+---------------+-------------------+---+--------------------+
|instrument_id|instrument_type|      registered_ts| id|instrument_id_salted|
+-------------+---------------+-------------------+---+--------------------+
|         1250|            ETF|2017-10-23 23:00:00|  0|              1250_0|
|         1250|            ETF|2017-10-23 23:00:00|  1|              1250_1|
|         1250|            ETF|2017-10-23 23:00:00|  2|              1250_2|
|         1250|            ETF|2017-10-23 23:00:00|  3|              1250_3|
|         1250|            ETF|2017-10-23 23:00:00|  4|              1250_4|
|         1250|            ETF|2017-10-23 23:00:00|  5|              1250_5|
|         1250|            ETF|2017-10-23 23:00:00|  6|              1250_6|
|         1250|            ETF|2017-10-23 23:00:00|  7|              1250_7|
|         1250|            ETF|2017-10-23 23:00:00|  8|              1250_8|
|         1250|            ETF|2017-10-23 23:00:00|  9|              1250_9|