In [1]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import *

import os

# Root of the book's data checkout (it must contain the data/ folder read later).
# Validated before the Spark session starts so a missing or empty value fails
# fast with a clear message instead of a bare KeyError after JVM startup.
SPARK_BOOK_DATA_PATH = os.environ.get('SPARK_BOOK_DATA_PATH')
if not SPARK_BOOK_DATA_PATH:
    raise RuntimeError(
        "SPARK_BOOK_DATA_PATH is not set; point it at the directory that "
        "contains the book's data/ folder."
    )

spark = (
    SparkSession.builder
    .appName("chapter-14-broadcast-vars")
    .getOrCreate()
)

# SparkContext is the entry point for the RDD, broadcast and accumulator APIs used below.
sc = spark.sparkContext

In [2]:
# Display the SparkContext handle taken from the session.
sc

In [3]:
# Split on single spaces only: punctuation stays attached ("Simple,", "Park,")
# and ":" becomes a token of its own.
my_collection = (
    "Spark The Definitive Guide : Big Data Processing Made Simple, Spark in the Park, very powerful"
    .split(" ")
)
# Distribute the words over 2 partitions.
words = sc.parallelize(my_collection, numSlices=2)

In [4]:
# parallelize() on a list gives a plain RDD (see output).
type(words)

pyspark.rdd.RDD

In [5]:
# Pull every element back to the driver; acceptable here because the list is tiny.
words.collect()

['Spark',
 'The',
 'Definitive',
 'Guide',
 ':',
 'Big',
 'Data',
 'Processing',
 'Made',
 'Simple,',
 'Spark',
 'in',
 'the',
 'Park,',
 'very',
 'powerful']

In [6]:
# glom() turns each partition into a list, showing how the 16 words were split
# across the 2 partitions.
words.glom().collect()

[['Spark', 'The', 'Definitive', 'Guide', ':', 'Big', 'Data', 'Processing'],
 ['Made', 'Simple,', 'Spark', 'in', 'the', 'Park,', 'very', 'powerful']]

### Broadcast

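Broadcast variables push a small, read-only shared dataset (here, a lookup dictionary) to every worker node once, so tasks can use it locally instead of joining against it, which would require a shuffle.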

In [7]:
# Lookup table to broadcast: a few selected words mapped to arbitrary numbers.
# Words missing from it receive a sentinel value in the lookup cell.
supplementalData = dict(
    Spark=1000,
    Definitive=200,
    Big=-300,
    Simple=100,
    Data=99,
)

In [8]:
# Wrap the dict in a read-only broadcast variable; tasks read it through .value.
suppBroadcast = sc.broadcast(supplementalData)

In [9]:
# sc.broadcast() returns a pyspark Broadcast wrapper, not the dict itself.
type(suppBroadcast)

pyspark.broadcast.Broadcast

In [10]:
# Access the broadcast variable's payload via .value (the original dict).
suppBroadcast.value

{'Spark': 1000, 'Definitive': 200, 'Big': -300, 'Simple': 100, 'Data': 99}

In [11]:
# Each task reads the broadcast dict through .value. Words with no entry get the
# sentinel -999999999, so they sort first. Because the sentence was split on
# spaces only, "Simple," keeps its comma and never matches the "Simple" key,
# so it also gets the sentinel (see output).
(
    words
    .map(lambda w: (w, suppBroadcast.value.get(w, -999999999)))
    .sortBy(lambda pair: pair[1])
    .collect()
)

[('The', -999999999),
 ('Guide', -999999999),
 (':', -999999999),
 ('Processing', -999999999),
 ('Made', -999999999),
 ('Simple,', -999999999),
 ('in', -999999999),
 ('the', -999999999),
 ('Park,', -999999999),
 ('very', -999999999),
 ('powerful', -999999999),
 ('Big', -300),
 ('Data', 99),
 ('Definitive', 200),
 ('Spark', 1000),
 ('Spark', 1000)]

### Accumulator


Creating a named accumulator is not possible in PySpark; the feature request is tracked in [SPARK-2868](https://issues.apache.org/jira/browse/SPARK-2868).


Use an accumulator as a global counter: sum the `count` of every flight to or from China.

In [3]:
# Parquet copy of the 2010 flight summary from the book's data directory.
file_path = f"{SPARK_BOOK_DATA_PATH}/data/flight-data/parquet/2010-summary.parquet"
flights = spark.read.format("parquet").load(file_path)

In [5]:
# Unnamed accumulator starting at 0; tasks add to it and the driver reads .value.
accChina = sc.accumulator(0)

Registering an accumulator under a name only works in Scala:
```
spark.sparkContext.register(accChina, "accChina")
```

In [6]:
# sc.accumulator() returns a pyspark Accumulator object.
type(accChina)

pyspark.accumulators.Accumulator

In [7]:
def accChinaFunc(flight_row, acc=None, country="China"):
    """Add a flight row's ``count`` to an accumulator when ``country`` is an endpoint.

    Meant to be passed to ``DataFrame.foreach`` so it is called once per row.

    Args:
        flight_row: Row supporting ``row["DEST_COUNTRY_NAME"]``,
            ``row["ORIGIN_COUNTRY_NAME"]`` and ``row["count"]`` lookups
            (a ``pyspark.sql.Row`` from ``flights`` here, or a plain dict).
        acc: Object with an ``add`` method to accumulate into. Defaults to the
            notebook-level ``accChina`` accumulator, so the original
            one-argument call keeps working.
        country: Country name matched against the destination or the origin
            column. Defaults to ``"China"``.

    Returns:
        None; the only effect is the ``acc.add`` call.
    """
    target = accChina if acc is None else acc
    # Added once per row, even if both endpoints match.
    if country in (flight_row["DEST_COUNTRY_NAME"], flight_row["ORIGIN_COUNTRY_NAME"]):
        target.add(flight_row["count"])

Use `foreach()` to call `accChinaFunc` on every row of the DataFrame.

In [8]:
# Reset on the driver first so re-running this cell does not double-count;
# foreach() is an action, so it runs the job that performs the adds.
accChina.value = 0
flights.foreach(accChinaFunc)

In [9]:
accChina.value # expected 953 for 2010-summary.parquet; cross-checked below with the DataFrame API and SQL

953

#### Verify the accumulator via the DataFrame API

In [18]:
# Rows where China is either endpoint; their counts should add up to accChina.value.
(
    flights
    .filter("DEST_COUNTRY_NAME == 'China' OR ORIGIN_COUNTRY_NAME == 'China'")
    .show()
)

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    United States|              China|  505|
|            China|      United States|  448|
+-----------------+-------------------+-----+



In [19]:
# Same predicate, aggregated: the total should equal the accumulator's value.
(
    flights
    .filter("DEST_COUNTRY_NAME='China' or ORIGIN_COUNTRY_NAME='China'")
    .selectExpr("sum(count)")
    .show()
)

+----------+
|sum(count)|
+----------+
|       953|
+----------+



#### Verify the accumulator via SQL

In [20]:
# Register the DataFrame as a temporary view named "flights" so spark.sql can query it.
flights.createOrReplaceTempView("flights")

In [21]:
# SQL version of the same check: sum of `count` over flights to or from China.
sql_stmt = """
select sum(count) as accChina
from flights 
where DEST_COUNTRY_NAME='China' or ORIGIN_COUNTRY_NAME='China' 
"""
spark.sql(sql_stmt).show()

+--------+
|accChina|
+--------+
|     953|
+--------+



### RDD.glom()

Return an RDD created by coalescing all elements within each partition
into a list.

Use it to examine how data is distributed across partitions.

In [22]:
# 15 integers (0-14) split into 4 partitions.
rdd = sc.parallelize(range(15), 4)

In [23]:
# Note: the output is PipelinedRDD here, whereas type(words) above was a plain RDD.
type(rdd)

pyspark.rdd.PipelinedRDD

In [24]:
# Confirms the 4 partitions requested in parallelize().
rdd.getNumPartitions()

4

In [25]:
# Flattened view: partition boundaries are not visible here.
rdd.collect()

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

In [26]:
# One list per partition: 15 elements over 4 partitions gives sizes 3, 4, 4, 4.
rdd.glom().collect()

[[0, 1, 2], [3, 4, 5, 6], [7, 8, 9, 10], [11, 12, 13, 14]]

In [27]:
# 6 elements into 5 slices: every slice gets one element except the last, which gets two.
sc.parallelize([0, 2, 3, 4, 6, 7], 5).glom().collect()
# [[0], [2], [3], [4], [6, 7]]

[[0], [2], [3], [4], [6, 7]]

In [28]:
# Shut down the session (and its SparkContext); keep this as the last cell.
spark.stop()