In [None]:
import pyspark

from pyspark.sql import SparkSession

In [None]:
# Create the Spark session for this notebook, or reuse one that already
# exists. The driver memory setting and application name are passed to the
# builder before the session is requested.
spark = (
    SparkSession.builder
    .config("spark.driver.memory", "8g")
    .appName("DecisionTree")
    .getOrCreate()
)

# Preparing the Data

In [None]:
# Read the raw CSV with header handling disabled, so the first line is
# treated as data, and with schema inference enabled so column types are
# derived from the file contents.
data_without_header = (
    spark.read
    .option("inferSchema", True)
    .option("header", False)
    .csv("data/covtype.data")
)

# Show the column names and types Spark came up with.
data_without_header.printSchema()

In [None]:
from pyspark.sql.types import DoubleType
from pyspark.sql.functions import col


# Column names to apply to the headerless input, in positional order:
# ten named columns, then four Wilderness_Area_* and forty Soil_Type_*
# columns (both numbered from 0), and finally the Cover_Type column.
colnames = [
    "Elevation",
    "Aspect",
    "Slope",
    "Horizontal_Distance_To_Hydrology",
    "Vertical_Distance_To_Hydrology",
    "Horizontal_Distance_To_Roadways",
    "Hillshade_9am",
    "Hillshade_Noon",
    "Hillshade_3pm",
    "Horizontal_Distance_To_Fire_Points",
    *(f"Wilderness_Area_{i}" for i in range(4)),
    *(f"Soil_Type_{i}" for i in range(40)),
    "Cover_Type",
]

# Apply the column names positionally, then cast Cover_Type to double.
# Cover_Type is the column later passed to the classifier as labelCol.
data = (
    data_without_header
    .toDF(*colnames)
    .withColumn("Cover_Type", col("Cover_Type").cast(DoubleType()))
)

# Peek at the first row to confirm the renaming and the cast.
data.head()

# Our First Decision Tree

In [None]:
# Hold out roughly 10% of the rows for testing. The split is seeded so the
# same rows land in each subset on every run (given the same input and
# partitioning); without a seed the partition changes on each execution and
# the results of later cells cannot be reproduced. The value matches the seed
# used for the classifier.
train_data, test_data = data.randomSplit([0.9, 0.1], seed=1234)

# Mark both subsets for caching so repeated use in later cells does not
# recompute the read and the split.
train_data.cache()
test_data.cache()

In [None]:
from pyspark.ml.feature import VectorAssembler

# Every column except the last one (Cover_Type, the label) is a feature.
input_cols = colnames[:-1]

# Combine the individual feature columns into one vector column named
# "featureVector", which the classifier is configured to read.
vector_assembler = VectorAssembler(
    inputCols=input_cols,
    outputCol="featureVector",
)
assembled_train_data = vector_assembler.transform(train_data)

# Display the assembled vectors without truncating them.
assembled_train_data.select("featureVector").show(truncate=False)

In [None]:
from pyspark.ml.classification import DecisionTreeClassifier

# Decision tree with otherwise default settings: it reads features from
# "featureVector", uses "Cover_Type" as the label, writes its output to
# "prediction", and is seeded for repeatable training.
classifier = DecisionTreeClassifier(
    featuresCol="featureVector",
    labelCol="Cover_Type",
    predictionCol="prediction",
    seed=1234,
)

# Fit on the assembled training data and print the model's debug string.
model = classifier.fit(assembled_train_data)
print(model.toDebugString)

In [None]:
import pandas as pd

# Tabulate the model's feature importances, one row per input column,
# ordered from most to least important. The bare expression is the cell's
# displayed output.
(
    pd.DataFrame(
        {"importance": model.featureImportances.toArray()},
        index=input_cols,
    )
    .sort_values(by="importance", ascending=False)
)

In [None]:
# Run the fitted model over the training data it was fitted on. This is a
# sanity check of the output columns, not a held-out evaluation.
predictions = model.transform(assembled_train_data)

# Show the label next to the predicted class and the probability vector
# for the first ten rows, untruncated.
(
    predictions
    .select("Cover_Type", "prediction", "probability")
    .show(10, truncate=False)
)