In [3]:
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType
from pyspark.sql.functions import col
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Create (or reuse) the Spark session used by the rest of the script
spark = (
    SparkSession.builder
    .config("spark.driver.memory", "16g")
    .appName('DecisionTreeExample')
    .getOrCreate()
)

# Read the raw CSV; it carries no header row, and column types are inferred
data_without_header = (
    spark.read
    .option("inferSchema", True)
    .option("header", False)
    .csv("covtype.data")
)

# Column names in file order: 10 named measurements, 4 wilderness-area
# columns, 40 soil-type columns, and finally the label column
colnames = [
    "Elevation",
    "Aspect",
    "Slope",
    "Horizontal_Distance_To_Hydrology",
    "Vertical_Distance_To_Hydrology",
    "Horizontal_Distance_To_Roadways",
    "Hillshade_9am",
    "Hillshade_Noon",
    "Hillshade_3pm",
    "Horizontal_Distance_To_Fire_Points",
]
colnames.extend(f"Wilderness_Area_{i}" for i in range(4))
colnames.extend(f"Soil_Type_{i}" for i in range(40))
colnames.append("Cover_Type")

# Attach the names, then cast the label column to double
label_as_double = col("Cover_Type").cast(DoubleType())
data = data_without_header.toDF(*colnames).withColumn("Cover_Type", label_as_double)

# Show the resulting schema followed by summary statistics
print("Schema:")
data.printSchema()
print("Basic Statistics:")
summary_stats = data.describe()
summary_stats.show()

# Feature columns: every column except the label column 'Cover_Type'
input_cols = colnames[:-1]

# Handle missing values. Only the feature columns are zero-filled; a row whose
# label is missing is dropped instead, because filling the label with 0 would
# fabricate a class value rather than reflect an observed one.
data_filled = data.fillna(0, subset=input_cols).dropna(subset=["Cover_Type"])

# Assemble features into a single vector column named 'featureVector'
vector_assembler = VectorAssembler(inputCols=input_cols, outputCol="featureVector")
data_assembled = vector_assembler.transform(data_filled)

# Split data into training (90%) and test (10%) sets. The split is seeded
# (same seed as the classifier below) so the reported metrics are
# reproducible from run to run instead of varying with a random split.
(train_data, test_data) = data_assembled.randomSplit([0.9, 0.1], seed=1234)

# Both splits are reused (training, prediction, several evaluations), so mark
# them for caching to avoid recomputing the pipeline above each time.
train_data.cache()
test_data.cache()

# Configure a decision tree that reads 'featureVector', learns 'Cover_Type',
# and writes its output to 'prediction'; the seed is fixed at 1234
classifier = DecisionTreeClassifier(
    seed=1234,
    labelCol="Cover_Type",
    featuresCol="featureVector",
    predictionCol="prediction",
)

# Fit the tree on the training split
model = classifier.fit(train_data)

# Print the learned tree in its textual debug form, under a heading
print("Model Debug String:", model.toDebugString, sep="\n")

# Score the held-out test split with the fitted model
predictions = model.transform(test_data)

# One evaluator is reused for both metrics; only its metric name is switched
evaluator = MulticlassClassificationEvaluator(
    labelCol="Cover_Type",
    predictionCol="prediction",
)

# Accuracy on the test predictions
evaluator.setMetricName("accuracy")
accuracy = evaluator.evaluate(predictions)
print(f"Accuracy: {accuracy}")

# F1 score on the test predictions
evaluator.setMetricName("f1")
f1 = evaluator.evaluate(predictions)
print(f"F1 Score: {f1}")

# Confusion matrix: one row per actual Cover_Type, one column per predicted
# value in 1..7, with empty cells filled by zero and rows sorted by label
counts_by_label = predictions.groupBy("Cover_Type").pivot("prediction", range(1, 8)).count()
confusion_matrix = counts_by_label.na.fill(0.0).orderBy("Cover_Type")
print("Confusion Matrix:")
confusion_matrix.show()

# Tabulate the model's per-feature importances, labelled by feature column
# name and listed from most to least important
import pandas as pd
importance_values = model.featureImportances.toArray()
feature_importances = pd.DataFrame(
    {"importance": importance_values},
    index=input_cols,
).sort_values(by="importance", ascending=False)
print("Feature Importances:")
print(feature_importances)


                                                                                

Schema:
root
 |-- Elevation: integer (nullable = true)
 |-- Aspect: integer (nullable = true)
 |-- Slope: integer (nullable = true)
 |-- Horizontal_Distance_To_Hydrology: integer (nullable = true)
 |-- Vertical_Distance_To_Hydrology: integer (nullable = true)
 |-- Horizontal_Distance_To_Roadways: integer (nullable = true)
 |-- Hillshade_9am: integer (nullable = true)
 |-- Hillshade_Noon: integer (nullable = true)
 |-- Hillshade_3pm: integer (nullable = true)
 |-- Horizontal_Distance_To_Fire_Points: integer (nullable = true)
 |-- Wilderness_Area_0: integer (nullable = true)
 |-- Wilderness_Area_1: integer (nullable = true)
 |-- Wilderness_Area_2: integer (nullable = true)
 |-- Wilderness_Area_3: integer (nullable = true)
 |-- Soil_Type_0: integer (nullable = true)
 |-- Soil_Type_1: integer (nullable = true)
 |-- Soil_Type_2: integer (nullable = true)
 |-- Soil_Type_3: integer (nullable = true)
 |-- Soil_Type_4: integer (nullable = true)
 |-- Soil_Type_5: integer (nullable = true)
 |-- S

                                                                                

+-------+-----------------+------------------+------------------+--------------------------------+------------------------------+-------------------------------+------------------+------------------+------------------+----------------------------------+------------------+--------------------+-------------------+-------------------+--------------------+-------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-------------------+--------------------+--------------------+--------------------+-------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-------------------+-------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+------

24/09/09 09:23:20 WARN MemoryStore: Not enough space to cache rdd_65_6 in memory! (computed 2.3 MiB so far)
24/09/09 09:23:20 WARN MemoryStore: Not enough space to cache rdd_65_11 in memory! (computed 2.3 MiB so far)
24/09/09 09:23:20 WARN MemoryStore: Not enough space to cache rdd_65_9 in memory! (computed 2.2 MiB so far)
24/09/09 09:23:20 WARN MemoryStore: Not enough space to cache rdd_65_8 in memory! (computed 2.2 MiB so far)
24/09/09 09:23:20 WARN MemoryStore: Not enough space to cache rdd_65_10 in memory! (computed 2.3 MiB so far)
24/09/09 09:23:20 WARN MemoryStore: Not enough space to cache rdd_65_4 in memory! (computed 2.2 MiB so far)
24/09/09 09:23:20 WARN MemoryStore: Not enough space to cache rdd_65_7 in memory! (computed 2.3 MiB so far)
24/09/09 09:23:20 WARN MemoryStore: Not enough space to cache rdd_65_12 in memory! (computed 2.3 MiB so far)
24/09/09 09:23:20 WARN MemoryStore: Not enough space to cache rdd_65_15 in memory! (computed 2.2 MiB so far)
24/09/09 09:23:20 WARN M

Model Debug String:
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_51b48f314af6, depth=5, numNodes=41, numClasses=8, numFeatures=54
  If (feature 0 <= 3047.5)
   If (feature 0 <= 2497.5)
    If (feature 3 <= 15.0)
     If (feature 12 <= 0.5)
      If (feature 23 <= 0.5)
       Predict: 4.0
      Else (feature 23 > 0.5)
       Predict: 3.0
     Else (feature 12 > 0.5)
      Predict: 6.0
    Else (feature 3 > 15.0)
     If (feature 16 <= 0.5)
      Predict: 3.0
     Else (feature 16 > 0.5)
      If (feature 9 <= 1327.0)
       Predict: 3.0
      Else (feature 9 > 1327.0)
       Predict: 4.0
   Else (feature 0 > 2497.5)
    If (feature 17 <= 0.5)
     If (feature 15 <= 0.5)
      Predict: 2.0
     Else (feature 15 > 0.5)
      Predict: 3.0
    Else (feature 17 > 0.5)
     If (feature 0 <= 2711.5)
      Predict: 3.0
     Else (feature 0 > 2711.5)
      If (feature 5 <= 1238.0)
       Predict: 5.0
      Else (feature 5 > 1238.0)
       Predict: 2.0
  Else (feature 0 > 3047.5)
 

                                                                                

Accuracy: 0.7021008910259725
F1 Score: 0.6850645240220538
Confusion Matrix:
+----------+-----+-----+----+---+---+---+----+
|Cover_Type|    1|    2|   3|  4|  5|  6|   7|
+----------+-----+-----+----+---+---+---+----+
|       1.0|13160| 7281|  14|  0|  4|  0| 612|
|       2.0| 4748|23189| 388|  6| 40|  4| 100|
|       3.0|    0|  426|3036| 36|  4| 10|   0|
|       4.0|    0|    1| 157|122|  0|  0|   0|
|       5.0|    0|  902|  29|  2| 46|  0|   0|
|       6.0|    0|  459|1137| 13|  0| 53|   0|
|       7.0|  886|   26|   0|  0|  0|  0|1132|
+----------+-----+-----+----+---+---+---+----+

Feature Importances:
                                    importance
Elevation                             0.832390
Soil_Type_3                           0.037451
Soil_Type_1                           0.031900
Hillshade_Noon                        0.026854
Horizontal_Distance_To_Hydrology      0.023093
Soil_Type_31                          0.018140
Wilderness_Area_2                     0.015855
Horizonta