In [None]:
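# Optional sanity check (disabled): uncomment to see which PySpark version, if any, is already installed.
# !pip show pyspark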

In [None]:
# Refresh the apt package index before installing anything.
!apt-get update
# Install the OpenJDK 8 JDK; -y auto-confirms apt's prompt so the cell runs unattended.
!apt-get install openjdk-8-jdk -y

In [None]:
# Remove Spark artifacts from the current working directory:
# the unpacked 3.2.0 and 3.0.3 distribution directories, then any spark-3.* tarball.
# The -f flag makes rm ignore paths that do not exist, so this cell is safe to re-run.
!rm -rf spark-3.2.0-bin-hadoop3.2
!rm -rf spark-3.0.3-bin-hadoop3.2
!rm -f spark-3.*.tgz

In [None]:
# Print the version of the `java` executable found on PATH.
!java -version

In [None]:
# 1. Install Dependencies
# Install JDK 8 (still necessary for the JVM)
!apt-get install openjdk-8-jdk-headless -qq > /dev/null

# Install the correct PySpark version (3.5.1) and a utility library
# The 'pyspark' package contains all necessary Java binaries, simplifying setup.
!pip install -q pyspark==3.5.1

# 2. Set Java Home
import os
from pyspark.sql import SparkSession

# Set JAVA_HOME, which is often the final piece needed for the JVM
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"

# 3. Create the SparkSession
# PySpark will now use its internal libraries and findspark is not strictly needed.
spark = SparkSession.builder \
    .appName("ColabSparkAuto") \
    .master("local[*]") \
    .config("spark.driver.memory", "6g") \
    .getOrCreate()

print("\n---")
print("Spark initialized successfully! Spark Version:", spark.version)

In [None]:
import os

# Environment configuration for the JVM that PySpark launches.
java_home = "/usr/lib/jvm/java-8-openjdk-amd64"
spark_home = "/content/spark-3.1.2-bin-hadoop3.2"

os.environ["JAVA_HOME"] = java_home

# Export SPARK_HOME only when that standalone distribution is actually on disk.
# PySpark trusts SPARK_HOME when it is set, so pointing it at a missing
# directory would break the next JVM launch (e.g. after spark.stop()) even
# though a pip-installed PySpark works fine without it.
if os.path.isdir(spark_home):
    os.environ["SPARK_HOME"] = spark_home

# Put the JDK's bin directory first on PATH. Any earlier occurrence is dropped
# first, so re-running this cell does not keep growing PATH.
java_bin = java_home + "/bin"
path_entries = [
    entry
    for entry in os.environ.get("PATH", "").split(os.pathsep)
    if entry != java_bin
]
os.environ["PATH"] = os.pathsep.join([java_bin] + path_entries)


In [None]:
# Broadcast-join demo: join a large fact table against a tiny dimension table,
# hinting Spark to ship the small side to every executor.
# `time` and `col` are imported here because this cell uses both; without these
# imports the cell raises NameError on a fresh kernel.
import time

from pyspark.sql.functions import broadcast, col

# Small dimension table: five (category_id, category_name) rows.
category_data = [
    (1, "Electronics"),
    (2, "Apparel"),
    (3, "Home Goods"),
    (4, "Media"),
    (5, "Books")
]
category_cols = ["category_id", "category_name"]
small_dim_df = spark.createDataFrame(category_data, category_cols)

# Large DF: one million rows, each assigned a category_id by cycling through 1..5.
large_transaction_df = spark.range(1000000) \
    .withColumnRenamed("id","transaction_id") \
    .withColumn("category_id", (col("transaction_id") % 5) + 1) # Add IDs 1 to 5

print(f"Small DF Count: {small_dim_df.count()}, Large DF Count: {large_transaction_df.count()}")

# NOTE: DataFrame transformations are lazy, so the join only executes when
# show() is called. The timer therefore covers planning plus producing the
# first 10 rows, not a full materialisation of the joined result.
start_time = time.time()
united_small_large_df = large_transaction_df.join(
    broadcast(small_dim_df),  # explicit hint: broadcast the small side
    on = "category_id",
    how = "inner"
)
united_small_large_df.show(10)
time_broadcast_join = time.time() - start_time

print(f"time with broadcast join: {time_broadcast_join}")

Small DF Count: 5, Large DF Count: 1000000
+-----------+--------------+-------------+
|category_id|transaction_id|category_name|
+-----------+--------------+-------------+
|          1|             0|  Electronics|
|          2|             1|      Apparel|
|          3|             2|   Home Goods|
|          4|             3|        Media|
|          5|             4|        Books|
|          1|             5|  Electronics|
|          2|             6|      Apparel|
|          3|             7|   Home Goods|
|          4|             8|        Media|
|          5|             9|        Books|
+-----------+--------------+-------------+
only showing top 10 rows

time with broadcast join: 0.5385096073150635
