# Spark training: joining skewed data (sort-merge vs. broadcast vs. salting)

## Get Spark's configuration

In [0]:
# Snapshot of the cluster and Spark resource settings. The SparkConf is fetched
# once and every key is read from that single copy.
_spark_conf = spark.sparkContext.getConf()

# Databricks cluster usage tags: scaling type and worker counts
scaling_type = _spark_conf.get('spark.databricks.clusterUsageTags.clusterScalingType')
min_executors = _spark_conf.get('spark.databricks.clusterUsageTags.clusterMinWorkers')
max_executors = _spark_conf.get('spark.databricks.clusterUsageTags.clusterMaxWorkers')
target_executors = _spark_conf.get('spark.databricks.clusterUsageTags.clusterWorkers')

# Executor sizing
executor_instances = _spark_conf.get('spark.executor.instances')
executor_cores = _spark_conf.get('spark.executor.cores')
executor_memory = _spark_conf.get('spark.executor.memory')

# Driver sizing
driver_cores = _spark_conf.get('spark.driver.cores')
driver_memory = _spark_conf.get('spark.driver.memory')

print(f'''
Scaling type: {scaling_type}
Min executors: {min_executors}
Max workers: {max_executors}
Current workers: {target_executors}
----------------------------------------
Executor instances: {executor_instances}
Executor cores: {executor_cores}
Executor memory: {executor_memory}
----------------------------------------
Driver cores: {driver_cores}
Driver memory: {driver_memory}
''')

### Import libraries

In [0]:
import pyspark.sql.functions as F
import pyspark.sql.types as T
import datetime

In [0]:
# Load the customers dataset (Parquet) and preview it.
c_df = spark.read.parquet('/mnt/dls/data/small/customers')
c_df.display()

In [0]:
# Load the transactions dataset (Parquet) and preview it.
t_df = spark.read.parquet('/mnt/dls/data/small/transactions')
t_df.display()

### Skewed data

In [0]:
# Transactions per customer, busiest customers first: exposes the skewed key(s).
(
    t_df
    .groupBy('cust_id')
    .agg(F.count('txn_id').alias('count'))
    .orderBy(F.col('count').desc())
    .display()
)

Databricks visualization. Run in Databricks to view.

### AQE Disabled, sort merge join

In [0]:
# TEMPORARY: get a plain sort-merge join for the baseline timing by switching off
# automatic broadcast joins (threshold -1) and adaptive query execution.
# NOTE(review): 'spark.sql.adaptive.join.enabled' does not appear among the
# documented Spark 3.x AQE settings - confirm it has any effect on this runtime.
for _conf_key, _conf_value in {
    'spark.sql.autoBroadcastJoinThreshold': -1,
    'spark.sql.adaptive.enabled': 'false',
    'spark.sql.adaptive.join.enabled': 'false',
}.items():
    spark.conf.set(_conf_key, _conf_value)

In [0]:
# Inner equi-join of transactions to customers on the shared cust_id column.
joined_df = t_df.join(c_df, 'cust_id')

In [0]:
# Time the baseline join end to end. Writing with the 'noop' format forces the
# whole plan to execute while discarding the rows, so the elapsed time is the
# cost of computing joined_df.
ts = datetime.datetime.now()
joined_df.write.format('noop').mode('overwrite').save('/mnt/dls/results/x')
# total_seconds() is the full duration; timedelta.seconds is only the seconds
# *component* (it ignores days and truncates the fractional part).
pt_01 = (datetime.datetime.now() - ts).total_seconds()
print(f'The processing time was {pt_01:.2f} seconds')

### AQE Enabled, broadcast join

In [0]:
# RE-ENABLE automatic broadcast joins and adaptive query execution.
# NOTE(review): 10485760 bytes (10 MiB) is hard-coded; it is Spark's default
# broadcast threshold, but it will not restore a cluster-specific value.
for _conf_key, _conf_value in {
    'spark.sql.autoBroadcastJoinThreshold': 10485760,
    'spark.sql.adaptive.enabled': 'true',
    'spark.sql.adaptive.join.enabled': 'true',
}.items():
    spark.conf.set(_conf_key, _conf_value)

In [0]:
# Same inner join, re-planned now that broadcast joins and AQE are enabled.
joined_df = t_df.join(c_df, 'cust_id')

In [0]:
# Time the join with broadcast/AQE enabled. The 'noop' format forces full
# execution of the plan while discarding the rows.
ts = datetime.datetime.now()
joined_df.write.format('noop').mode('overwrite').save('/mnt/dls/results/x')
# total_seconds() is the full duration; timedelta.seconds is only the seconds
# *component* (it ignores days and truncates the fractional part).
pt_02 = (datetime.datetime.now() - ts).total_seconds()
print(f'The processing time was {pt_02:.2f} seconds')

### Salting

In [0]:
# Number of partitions backing the transactions DataFrame.
t_num_partitions = t_df.rdd.getNumPartitions()
t_num_partitions

In [0]:
# Number of partitions backing the customers DataFrame.
c_num_partitions = c_df.rdd.getNumPartitions()
c_num_partitions

In [0]:
# Salting parameters. A skewed key is spread over `num_salts` buckets
# (0 .. num_salts-1); every other key stays in the single bucket `default_salt`.
# (Alternative bucket count: t_df.rdd.getNumPartitions().)
num_salts, default_salt = 10, 1

In [0]:
# Keys that need salting: presumably the hot cust_id(s) identified by the
# per-customer transaction count in the "Skewed data" section.
target_keys = [
    'C0YDPQWPBJ',
]

In [0]:
# Salt the transactions. Rows whose cust_id is one of `target_keys` get a
# pseudo-random salt in [0, num_salts), spreading that hot key over num_salts
# join partitions. Every other row gets the constant `default_salt`.
#
# One when/otherwise pass replaces the earlier filter + union, which scanned
# t_df twice. That version also silently dropped rows with a NULL cust_id,
# because a NULL fails both `isin` and `~isin`. Seeding rand makes the salt
# assignment repeatable across runs (given the same input partitioning).
salt_seed = 42
t_salted_df = t_df.withColumn(
    'salt',
    F.when(
        F.col('cust_id').isin(target_keys),
        (F.rand(salt_seed) * num_salts).cast('int'),
    ).otherwise(F.lit(default_salt)),
)

In [0]:
# Preview the salted transactions, including the new `salt` column.
t_salted_df.display()

In [0]:
# Replicate the customer rows for the skewed keys once per salt value, so every
# salted transaction row finds its match. Target keys get salts
# 0 .. num_salts-1; all other customers get only `default_salt`.
#
# Building the per-row salt array with when/otherwise and exploding it once
# replaces the earlier filter + union, which scanned c_df twice. That version
# also dropped customers with a NULL cust_id, since a NULL fails both `isin`
# and `~isin`. The helper column `salted_array` is kept so the resulting
# schema (c_df columns, salted_array, salt) is unchanged.
c_salted_df = (
    c_df
    .withColumn(
        'salted_array',
        F.when(
            F.col('cust_id').isin(target_keys),
            F.array([F.lit(i) for i in range(num_salts)]),
        ).otherwise(F.array([F.lit(default_salt)])),
    )
    .withColumn('salt', F.explode(F.col('salted_array')))
)

In [0]:
# Preview the salted customers (one row per (cust_id, salt) pair).
c_salted_df.display()

In [0]:
# Inspection key "<cust_id>_<salt>", used below to see how salting spread the rows.
t_salted_df = t_salted_df.withColumn(
    'new_key',
    F.concat_ws('_', 'cust_id', 'salt'),
)
t_salted_df.display()

In [0]:
# Transactions per salted key, largest buckets first.
x_df = (
    t_salted_df
    .groupBy('new_key')
    .agg(F.count('txn_id').alias('count'))
    .orderBy(F.col('count').desc())
)
# Optionally x_df.cache() here to reuse the aggregation in the cells below.
x_df.display()

In [0]:
# Print the query plan for x_df (group-by on new_key followed by a descending sort).
x_df.explain()

In [0]:
# Recover the original cust_id from the salted key by stripping the trailing
# "_<salt>" suffix. Only the LAST underscore is treated as the separator, so ids
# that themselves contain '_' survive intact; split(...)[0] would truncate them.
x_df = x_df.withColumn(
    'original_key',
    F.regexp_replace(F.col('new_key'), '_[^_]*$', ''),
)
x_df.display()

In [0]:
# How many salted buckets each original key was spread over.
(
    x_df
    .groupBy('original_key')
    .count()
    .display()
)

#### Testing the salted join (sort-merge, AQE disabled)

In [0]:
# TEMPORARY: time the salted join as a plain sort-merge join by switching off
# automatic broadcast joins (threshold -1) and adaptive query execution.
# NOTE(review): 'spark.sql.adaptive.join.enabled' does not appear among the
# documented Spark 3.x AQE settings - confirm it has any effect on this runtime.
for _conf_key, _conf_value in {
    'spark.sql.autoBroadcastJoinThreshold': -1,
    'spark.sql.adaptive.enabled': 'false',
    'spark.sql.adaptive.join.enabled': 'false',
}.items():
    spark.conf.set(_conf_key, _conf_value)

In [0]:
# Inner join on the composite (cust_id, salt) key, so the hot key is spread
# across num_salts partitions instead of one.
salted_joined_df = t_salted_df.join(c_salted_df, ['cust_id', 'salt'])

In [0]:
# Time the salted join. The 'noop' format forces full execution of the plan
# while discarding the rows.
ts = datetime.datetime.now()
salted_joined_df.write.format('noop').mode('overwrite').save('/mnt/dls/results/x')
# total_seconds() is the full duration; timedelta.seconds is only the seconds
# *component* (it ignores days and truncates the fractional part).
pt_04 = (datetime.datetime.now() - ts).total_seconds()
print(f'The processing time was {pt_04:.2f} seconds')

In [0]:
# RE-ENABLE automatic broadcast joins and adaptive query execution.
# NOTE(review): 10485760 bytes (10 MiB) is hard-coded; it is Spark's default
# broadcast threshold, but it will not restore a cluster-specific value.
for _conf_key, _conf_value in {
    'spark.sql.autoBroadcastJoinThreshold': 10485760,
    'spark.sql.adaptive.enabled': 'true',
    'spark.sql.adaptive.join.enabled': 'true',
}.items():
    spark.conf.set(_conf_key, _conf_value)