In [1]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
data = spark.read.option("inferSchema", True).option("header", False).csv('covtype.data')
data.summary()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/14 15:43:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
                                                                                

DataFrame[summary: string, _c0: string, _c1: string, _c2: string, _c3: string, _c4: string, _c5: string, _c6: string, _c7: string, _c8: string, _c9: string, _c10: string, _c11: string, _c12: string, _c13: string, _c14: string, _c15: string, _c16: string, _c17: string, _c18: string, _c19: string, _c20: string, _c21: string, _c22: string, _c23: string, _c24: string, _c25: string, _c26: string, _c27: string, _c28: string, _c29: string, _c30: string, _c31: string, _c32: string, _c33: string, _c34: string, _c35: string, _c36: string, _c37: string, _c38: string, _c39: string, _c40: string, _c41: string, _c42: string, _c43: string, _c44: string, _c45: string, _c46: string, _c47: string, _c48: string, _c49: string, _c50: string, _c51: string, _c52: string, _c53: string, _c54: string]

In [2]:
from pyspark.sql.functions import col
from pyspark.sql.types import DoubleType, IntegerType

colnames = ["Elevation", "Aspect", "Slope", "Horizontal_Distance_To_Hydrology", "Vertical_Distance_To_Hydrology",
            "Horizontal_Distance_To_Roadways", "Hillshade_9am", "Hillshade_noon", "Hillshade_3pm",
            "Horizontal_Distance_To_Fire_Points"] + \
           [f"Wilderness_Area_{i}" for i in range(4)] + [f"Soil_Type_{i}" for i in range(40)] + ["Cover_Type"]

data = data.toDF(*colnames)
data = data.withColumn("Cover_Type", col("Cover_Type").cast(DoubleType()))
for name in colnames[:-1]:
    data = data.withColumn(name, col(name).cast(IntegerType()))
data = data.na.drop()
data.printSchema()

root
 |-- Elevation: integer (nullable = true)
 |-- Aspect: integer (nullable = true)
 |-- Slope: integer (nullable = true)
 |-- Horizontal_Distance_To_Hydrology: integer (nullable = true)
 |-- Vertical_Distance_To_Hydrology: integer (nullable = true)
 |-- Horizontal_Distance_To_Roadways: integer (nullable = true)
 |-- Hillshade_9am: integer (nullable = true)
 |-- Hillshade_noon: integer (nullable = true)
 |-- Hillshade_3pm: integer (nullable = true)
 |-- Horizontal_Distance_To_Fire_Points: integer (nullable = true)
 |-- Wilderness_Area_0: integer (nullable = true)
 |-- Wilderness_Area_1: integer (nullable = true)
 |-- Wilderness_Area_2: integer (nullable = true)
 |-- Wilderness_Area_3: integer (nullable = true)
 |-- Soil_Type_0: integer (nullable = true)
 |-- Soil_Type_1: integer (nullable = true)
 |-- Soil_Type_2: integer (nullable = true)
 |-- Soil_Type_3: integer (nullable = true)
 |-- Soil_Type_4: integer (nullable = true)
 |-- Soil_Type_5: integer (nullable = true)
 |-- Soil_Type

In [3]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
train_data, test_data = data.randomSplit([0.9, 0.1])
train_data.cache()
test_data.cache()

input_cols = colnames[:-1]
vector_assembler = VectorAssembler(inputCols=input_cols, outputCol="featureVector")
assembled_train_data = vector_assembler.transform(train_data)

classifier = DecisionTreeClassifier(seed=1234, labelCol="Cover_Type", featuresCol="featureVector",
                                    predictionCol="prediction")
model = classifier.fit(assembled_train_data)
print(model.toDebugString)

24/09/14 15:43:50 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_b791574fb87d, depth=5, numNodes=43, numClasses=8, numFeatures=54
  If (feature 0 <= 3047.5)
   If (feature 0 <= 2500.5)
    If (feature 3 <= 15.0)
     If (feature 12 <= 0.5)
      If (feature 0 <= 2344.5)
       Predict: 4.0
      Else (feature 0 > 2344.5)
       Predict: 2.0
     Else (feature 12 > 0.5)
      Predict: 6.0
    Else (feature 3 > 15.0)
     If (feature 16 <= 0.5)
      Predict: 3.0
     Else (feature 16 > 0.5)
      If (feature 9 <= 1318.5)
       Predict: 3.0
      Else (feature 9 > 1318.5)
       Predict: 4.0
   Else (feature 0 > 2500.5)
    If (feature 17 <= 0.5)
     If (feature 0 <= 2952.5)
      If (feature 15 <= 0.5)
       Predict: 2.0
      Else (feature 15 > 0.5)
       Predict: 3.0
     Else (feature 0 > 2952.5)
      Predict: 2.0
    Else (feature 17 > 0.5)
     If (feature 0 <= 2711.5)
      Predict: 3.0
     Else (feature 0 > 2711.5)
      If (feature 5 <= 1228.0)
       Predict: 5.0
      Else (f

In [6]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

predictions = model.transform(assembled_train_data)
predictions.select("Cover_Type", "prediction", "probability").show(10, truncate=False)

evaluator = MulticlassClassificationEvaluator(labelCol="Cover_Type", predictionCol="prediction")

accuracy = evaluator.setMetricName("accuracy").evaluate(predictions)
print("Accuracy:", accuracy)

precision = evaluator.setMetricName("weightedPrecision").evaluate(predictions)
print("Precision:", precision)

recall = evaluator.setMetricName("weightedRecall").evaluate(predictions)
print("Recall:", recall)

+----------+----------+-----------------------------------------------------------------------------------------------------------------------------------+
|Cover_Type|prediction|probability                                                                                                                        |
+----------+----------+-----------------------------------------------------------------------------------------------------------------------------------+
|6.0       |3.0       |[0.0,3.152187618207036E-5,0.0683709494389106,0.6082461228092296,0.020016391375614676,0.0017021813138317992,0.30163283318623124,0.0]|
|6.0       |4.0       |[0.0,0.0,0.003931847968545216,0.30930537352555704,0.5943643512450852,0.0,0.09239842726081259,0.0]                                  |
|6.0       |3.0       |[0.0,3.152187618207036E-5,0.0683709494389106,0.6082461228092296,0.020016391375614676,0.0017021813138317992,0.30163283318623124,0.0]|
|6.0       |3.0       |[0.0,3.152187618207036E-5,0.0683709494389

[Stage 22:>                                                       (0 + 18) / 18]                                                                                

Accuracy: 0.7026128156944992
Precision: 0.7010685345593746
Recall: 0.7026128156944991
