In [24]:
import sys
import os

In [25]:
# Display the JAVA_HOME setting (None if the variable is unset).
os.getenv('JAVA_HOME')

'C:\\Program Files\\Java\\jdk1.8.0_311'

In [26]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

In [27]:
# Create (or reuse) a local Spark session that uses all available cores.
spark = (
    SparkSession.builder
    .appName("BroadcastApp")
    .master("local[*]")
    .getOrCreate()
)

In [28]:
# Product reference data as (product code, product name) pairs.
prod_data = [("4123", "Prod1"), ("6124", "Prod2"), ("5125", "Prod3")]

In [29]:
# Build the code -> name lookup directly on the driver. prod_data is already a
# local list of (code, name) pairs, so dict() yields the same mapping that
# parallelize(...).collectAsMap() produced, without launching a Spark job.
prod_dict = dict(prod_data)

In [30]:
# Order records; every field is a string, in the order
# (code, order_date, price, qty).
data_list = [
    ("4123", "2022-01-01", "1200", "01"),
    ("6124", "2022-01-02", "2345", "01"),
    ("4123", "2022-02-03", "1200", "02"),
    ("5125", "2022-02-04", "2345", "02"),
    ("4123", "2022-02-05", "9812", "02"),
]

In [31]:
# Turn the order tuples into a DataFrame, naming the columns at creation time.
order_columns = ["code", "order_date", "price", "qty"]
df = spark.createDataFrame(data_list, order_columns)

In [32]:
# Preview the orders (show()'s defaults spelled out: 20 rows, truncated cells).
df.show(n=20, truncate=True)

+----+----------+-----+---+
|code|order_date|price|qty|
+----+----------+-----+---+
|4123|2022-01-01| 1200| 01|
|6124|2022-01-02| 2345| 01|
|4123|2022-02-03| 1200| 02|
|5125|2022-02-04| 2345| 02|
|4123|2022-02-05| 9812| 02|
+----+----------+-----+---+



In [33]:
# Show the lookup mapping and its Python type, one per line.
print(prod_dict, type(prod_dict), sep="\n")

{'4123': 'Prod1', '6124': 'Prod2', '5125': 'Prod3'}
<class 'dict'>


In [34]:
# Wrap the lookup dict in a broadcast variable; the UDF reads it via .value.
sc = spark.sparkContext
prod_bc = sc.broadcast(prod_dict)

In [35]:
def get_product(code: str) -> str:
    """Look up the product name for a product code in the broadcast map.

    Reads the dictionary held by the module-level broadcast variable
    ``prod_bc`` and returns its entry for ``code``. A code that is not in
    the map yields ``None`` (``dict.get`` semantics) instead of raising.
    """
    product_lookup = prod_bc.value
    return product_lookup.get(code)

In [36]:
# Register get_product under a SQL-callable name that returns a string column.
spark.udf.register(
    name="get_product_udf",
    f=get_product,
    returnType=StringType(),
)

<function __main__.get_product(code: str) -> str>

In [37]:
# Add the product name via the registered UDF, then display the result.
df_with_product = df.withColumn("product", expr("get_product_udf(code)"))
df_with_product.show()

+----+----------+-----+---+-------+
|code|order_date|price|qty|product|
+----+----------+-----+---+-------+
|4123|2022-01-01| 1200| 01|  Prod1|
|6124|2022-01-02| 2345| 01|  Prod2|
|4123|2022-02-03| 1200| 02|  Prod1|
|5125|2022-02-04| 2345| 02|  Prod3|
|4123|2022-02-05| 9812| 02|  Prod1|
+----+----------+-----+---+-------+

