In [1]:
import os

# 1. Install OpenJDK 21 (if not already done in a previous cell)
!apt-get update -qq
!apt-get install -qq openjdk-21-jdk-headless

# 2. Verify where it landed (if needed)
!ls /usr/lib/jvm | grep 21

# 3. Point to JDK 21
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-21-openjdk-amd64"
os.environ["PATH"] = os.environ["JAVA_HOME"] + "/bin:" + os.environ["PATH"]

# 4. Install PySpark via pip (make sure this happens AFTER setting JAVA_HOME)
!pip install pyspark --quiet

# 5. Import and start Spark
from pyspark.sql import SparkSession
spark = (
    SparkSession.builder
      .master("local[*]")
      .appName("PySpark-DecisionTreeClassifier_Iris")
      .getOrCreate()
)


W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)
Selecting previously unselected package openjdk-21-jre-headless:amd64.
(Reading database ... 126102 files and directories currently installed.)
Preparing to unpack .../openjdk-21-jre-headless_21.0.7+6~us1-0ubuntu1~22.04_amd64.deb ...
Unpacking openjdk-21-jre-headless:amd64 (21.0.7+6~us1-0ubuntu1~22.04) ...
Selecting previously unselected package openjdk-21-jdk-headless:amd64.
Preparing to unpack .../openjdk-21-jdk-headless_21.0.7+6~us1-0ubuntu1~22.04_amd64.deb ...
Unpacking openjdk-21-jdk-headless:amd64 (21.0.7+6~us1-0ubuntu1~22.04) ...
Setting up openjdk-21-jre-headless:amd64 (21.0.7+6~us1-0ubuntu1~22.04) ...
update-alternatives: using /usr/lib/jvm/java-21-openjdk-amd64/bin/java to provide /usr/bin/java (java) in auto mode
update-alternatives: using /usr/lib/jvm/java-21-openjdk-amd64/bin/jpackage to

In [2]:
# Mount Google Drive at /content/drive so Drive files are readable from this notebook;
# force_remount re-mounts even if a previous mount is still active.
from google.colab import drive

drive.mount(mountpoint='/content/drive', force_remount=True)

Mounted at /content/drive


# 1. Import libraries

In [3]:
# Import necessary libraries (all Spark ML classes used below are imported here)
from pyspark.ml.classification import DecisionTreeClassifier
# NOTE(review): BinaryClassificationEvaluator is not used below; the label has 3 classes.
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.sql import SparkSession


# 2. Load the dataset

In [6]:
# Read the Iris CSV: the first row holds the column names, and Spark infers column types from the data.
# NOTE(review): Drive is mounted at /content/drive, but this path is outside the mount;
# if the file lives in Drive, point this at the corresponding /content/drive/... path.
iris_data = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("/content/iris-data.csv")
)

In [7]:
# The "class" column is the target; every remaining column is treated as a feature
# (assumed numeric via inferSchema — TODO confirm for new datasets).
# Columns are excluded by name instead of slicing off the last one. This keeps the
# result independent of column order. It also stays correct if this cell is re-run
# after the StringIndexer cell has appended a "label" column.
assert "class" in iris_data.columns, "expected a 'class' target column in the CSV"
feature_cols = [c for c in iris_data.columns if c not in ("class", "label")]

In [8]:
# Convert the string "class" values into numeric indices in a new "label" column.
# With the default ordering (descending frequency), the index assigned to each class
# depends on how often it occurs.
indexer = StringIndexer(inputCol="class", outputCol="label")
# StringIndexer refuses to write an output column that already exists, so re-running
# this cell used to fail. Drop any "label" left by an earlier run first.
# DataFrame.drop is a no-op when the column is absent.
iris_data = iris_data.drop("label")
iris_data = indexer.fit(iris_data).transform(iris_data)

In [9]:
# Pack the individual feature columns into a single vector column named "features",
# the input layout Spark ML estimators read from.
assembler = VectorAssembler(
    outputCol="features",
    inputCols=feature_cols,
)
data = assembler.transform(iris_data)

In [10]:
# Hold out roughly 20% of the rows for testing. The fixed seed makes the split
# repeatable across runs on the same data.
training_data, testing_data = data.randomSplit(weights=[0.8, 0.2], seed=123)

# 3. Create, train, and apply the decision tree classifier

In [11]:
# Configure a single decision tree:
#   maxDepth=5         -> the tree is at most 5 levels deep
#   minInfoGain=0.001  -> a split is kept only if it gains at least this much information
#   impurity="entropy" -> candidate splits are scored by entropy
dt = DecisionTreeClassifier(
    featuresCol="features",
    labelCol="label",
    impurity="entropy",
    maxDepth=5,
    minInfoGain=0.001,
)
# Fit on the training split, then score the held-out rows.
# This adds a "prediction" column, among others.
model = dt.fit(training_data)
predictions = model.transform(testing_data)

# 4. Evaluate on the test set

In [12]:
# Accuracy = fraction of test rows whose "prediction" equals the true "label".
# The multiclass evaluator is used because the label has more than two classes.
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy",
)
accuracy = evaluator.evaluate(predictions)

In [13]:
# Report test-set accuracy rounded to two decimal places
print("Test Accuracy: {:.2f}".format(accuracy))

Test Accuracy: 0.93


# 5. Inspect the learned decision tree (printed as text rules)

In [14]:
# Spark's rule dump names split features by their index in the assembled "features"
# vector ("feature 2", ...). VectorAssembler fills that vector from feature_cols in order
# (one slot per scalar column), so print the index -> column-name key first.
for feature_idx, feature_name in enumerate(feature_cols):
    print(f"feature {feature_idx} = {feature_name}")
print(model.toDebugString)

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_95b6da7ee484, depth=4, numNodes=15, numClasses=3, numFeatures=4
  If (feature 2 <= 2.45)
   Predict: 0.0
  Else (feature 2 > 2.45)
   If (feature 3 <= 1.75)
    If (feature 2 <= 4.95)
     If (feature 3 <= 1.65)
      Predict: 1.0
     Else (feature 3 > 1.65)
      Predict: 2.0
    Else (feature 2 > 4.95)
     If (feature 0 <= 6.35)
      Predict: 2.0
     Else (feature 0 > 6.35)
      Predict: 1.0
   Else (feature 3 > 1.75)
    If (feature 2 <= 4.85)
     If (feature 0 <= 5.95)
      Predict: 1.0
     Else (feature 0 > 5.95)
      Predict: 2.0
    Else (feature 2 > 4.85)
     Predict: 2.0



In [15]:
# Stop the Spark session started in the setup cell.
# Cells that use `spark` after this point need the setup cell re-run first.
spark.stop()