In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.stat import Correlation
import pyspark.sql.functions as F

In [2]:
# Attach to the active Spark session, or start a new one if none exists.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

In [3]:
# Load the iris CSV: the first row supplies the column names and the
# column types are inferred from the data.
df = spark.read.csv(
    "iris_dt.csv",
    header=True,
    inferSchema=True,
)

In [4]:
# Preview the first 20 rows (the default page size of show()).
df.show(n=20)

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|species|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2|      0|
|         4.9|        3.0|         1.4|        0.2|      0|
|         4.7|        3.2|         1.3|        0.2|      0|
|         4.6|        3.1|         1.5|        0.2|      0|
|         5.0|        3.6|         1.4|        0.2|      0|
|         5.4|        3.9|         1.7|        0.4|      0|
|         4.6|        3.4|         1.4|        0.3|      0|
|         5.0|        3.4|         1.5|        0.2|      0|
|         4.4|        2.9|         1.4|        0.2|      0|
|         4.9|        3.1|         1.5|        0.1|      0|
|         5.4|        3.7|         1.5|        0.2|      0|
|         4.8|        3.4|         1.6|        0.2|      0|
|         4.8|        3.0|         1.4|        0.1|      0|
|         4.3|        3.0|         1.1| 

In [5]:
# Look at the label column on its own.
(
    df
    .select("species")
    .show()
)

+-------+
|species|
+-------+
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
|      0|
+-------+
only showing top 20 rows



In [6]:
# Total number of rows in the loaded DataFrame.
df.count()

150

In [7]:
# Number of columns in the DataFrame.
len(df.columns)

5

In [8]:
# Print each column's name, inferred type and nullability.
df.printSchema()

root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- species: integer (nullable = true)



In [9]:
# Summary statistics (count, mean, stddev, min, max) for every column.
df.describe().show()

+-------+------------------+-------------------+------------------+------------------+------------------+
|summary|      sepal_length|        sepal_width|      petal_length|       petal_width|           species|
+-------+------------------+-------------------+------------------+------------------+------------------+
|  count|               150|                150|               150|               150|               150|
|   mean| 5.843333333333335|  3.057333333333334|3.7580000000000027| 1.199333333333334|               1.0|
| stddev|0.8280661279778637|0.43586628493669793|1.7652982332594662|0.7622376689603467|0.8192319205190406|
|    min|               4.3|                2.0|               1.0|               0.1|                 0|
|    max|               7.9|                4.4|               6.9|               2.5|                 2|
+-------+------------------+-------------------+------------------+------------------+------------------+



In [10]:
# First five rows, returned as a list of Row objects.
df.head(5)

[Row(sepal_length=5.1, sepal_width=3.5, petal_length=1.4, petal_width=0.2, species=0),
 Row(sepal_length=4.9, sepal_width=3.0, petal_length=1.4, petal_width=0.2, species=0),
 Row(sepal_length=4.7, sepal_width=3.2, petal_length=1.3, petal_width=0.2, species=0),
 Row(sepal_length=4.6, sepal_width=3.1, petal_length=1.5, petal_width=0.2, species=0),
 Row(sepal_length=5.0, sepal_width=3.6, petal_length=1.4, petal_width=0.2, species=0)]

In [11]:
# Class balance: number of rows per species label.
(
    df
    .groupBy('species')
    .count()
    .show()
)

+-------+-----+
|species|count|
+-------+-----+
|      1|   50|
|      2|   50|
|      0|   50|
+-------+-----+



In [12]:
# How often each distinct sepal_length value occurs (first 20 groups shown).
(
    df
    .groupBy('sepal_length')
    .count()
    .show()
)

+------------+-----+
|sepal_length|count|
+------------+-----+
|         5.4|    6|
|         7.0|    1|
|         6.1|    6|
|         7.7|    4|
|         6.6|    2|
|         4.5|    1|
|         5.7|    8|
|         6.7|    8|
|         7.4|    1|
|         6.5|    5|
|         4.9|    6|
|         6.2|    4|
|         5.1|    9|
|         7.3|    1|
|         4.3|    1|
|         7.9|    1|
|         4.7|    2|
|         5.3|    1|
|         7.2|    3|
|         7.6|    1|
+------------+-----+
only showing top 20 rows



In [13]:
from pyspark.ml.feature import VectorAssembler

In [14]:
# List the column names; the four measurement columns are fed to the assembler next.
df.columns

['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']

In [15]:
# Pack the four measurement columns into a single vector column named 'features'.
feature_columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
assembler = VectorAssembler(
    inputCols=feature_columns,
    outputCol='features',
)

In [16]:
# Display the assembler object; the notebook shows its identifier.
assembler

VectorAssembler_2cdee12c0324

In [17]:
# Apply the assembler: the result is df plus the new 'features' vector column.
output = assembler.transform(df)

In [18]:
# Display the transformed DataFrame's column/type summary; 'features' is the added vector column.
output

DataFrame[sepal_length: double, sepal_width: double, petal_length: double, petal_width: double, species: int, features: vector]

In [19]:
# Preview the assembled feature vectors next to their labels.
output.select(['features', 'species']).show(n=5)

+-----------------+-------+
|         features|species|
+-----------------+-------+
|[5.1,3.5,1.4,0.2]|      0|
|[4.9,3.0,1.4,0.2]|      0|
|[4.7,3.2,1.3,0.2]|      0|
|[4.6,3.1,1.5,0.2]|      0|
|[5.0,3.6,1.4,0.2]|      0|
+-----------------+-------+
only showing top 5 rows



In [20]:
# Re-check the source DataFrame: transform() returned a new DataFrame, so df
# itself still has only the original five columns.
df.columns

['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']

In [21]:
# Keep only what the model needs: the feature vector and the label.
model_df = output.select('features', 'species')

In [22]:
# Split roughly 70/30 into training and test sets. The seed is fixed so the
# split (and every metric computed from it) is reproducible across kernel
# restarts; without it each run produces a different split.
training_df, test_df = model_df.randomSplit([0.70, 0.30], seed=42)

In [23]:
# Number of rows that landed in the training split.
n_training_rows = training_df.count()
print(n_training_rows)

101


In [24]:
# Number of rows that landed in the test split.
n_test_rows = test_df.count()
print(n_test_rows)

49


In [25]:
from pyspark.ml.classification import DecisionTreeClassifier

In [26]:
# Configure the decision-tree estimator, then fit it on the training split.
tree_estimator = DecisionTreeClassifier(labelCol='species')
df_classifier = tree_estimator.fit(training_df)

In [27]:
# Predictions on the held-out test split (not the training data)
df_predictions=df_classifier.transform(test_df)

In [28]:
# Inspect the first 20 scored rows: raw counts, class probabilities and predicted label.
df_predictions.show(n=20)

+-----------------+-------+--------------+-------------+----------+
|         features|species| rawPrediction|  probability|prediction|
+-----------------+-------+--------------+-------------+----------+
|[4.3,3.0,1.1,0.1]|      0|[36.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|[4.4,3.2,1.3,0.2]|      0|[36.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|[4.7,3.2,1.3,0.2]|      0|[36.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|[4.8,3.0,1.4,0.1]|      0|[36.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|[4.8,3.1,1.6,0.2]|      0|[36.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|[4.8,3.4,1.6,0.2]|      0|[36.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|[4.9,2.4,3.3,1.0]|      1|[0.0,31.0,0.0]|[0.0,1.0,0.0]|       1.0|
|[4.9,2.5,4.5,1.7]|      2|[0.0,31.0,0.0]|[0.0,1.0,0.0]|       1.0|
|[4.9,3.1,1.5,0.2]|      0|[36.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|[5.0,2.0,3.5,1.0]|      1|[0.0,31.0,0.0]|[0.0,1.0,0.0]|       1.0|
|[5.0,3.0,1.6,0.2]|      0|[36.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|[5.1,3.5,1.4,0.2]|      0|[36.0,0.0,0.0]|[1.0,0

In [29]:
# How many test rows were assigned to each predicted class.
(
    df_predictions
    .groupBy('prediction')
    .count()
    .show()
)

+----------+-----+
|prediction|count|
+----------+-----+
|       0.0|   14|
|       1.0|   17|
|       2.0|   18|
+----------+-----+



In [30]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [31]:
# Accuracy of the predictions on the test split.
accuracy_evaluator = MulticlassClassificationEvaluator(
    labelCol='species',
    metricName='accuracy',
)
df_accuracy = accuracy_evaluator.evaluate(df_predictions)

In [32]:
# Accuracy computed from the test-split predictions.
print(df_accuracy)

0.9591836734693877


In [33]:
#precision
# Weighted precision of the predictions on the test split.
precision_evaluator = MulticlassClassificationEvaluator(
    labelCol='species',
    metricName='weightedPrecision',
)
df_precision = precision_evaluator.evaluate(df_predictions)

In [34]:
# Weighted precision computed from the test-split predictions.
print(df_precision)

0.9591836734693877


In [40]:
# NOTE: despite the variable name, this value is NOT an AUC.
# MulticlassClassificationEvaluator computes the F1 score when metricName is
# omitted, so the metric is now requested explicitly to make that visible.
# The name df_auc is kept because the following cell prints it.
df_auc = MulticlassClassificationEvaluator(labelCol='species',
                                           metricName='f1').evaluate(df_predictions)

In [41]:
# NOTE(review): df_auc comes from MulticlassClassificationEvaluator with no
# metricName, so this is presumably the evaluator's default metric (F1), not an AUC — confirm.
print(df_auc)

0.9591836734693877


In [None]:
# Feature importances of the fitted decision tree.
# Presumably one weight per assembled input, in the order the columns were
# given to the VectorAssembler (sepal_length, sepal_width, petal_length,
# petal_width) — confirm against the PySpark docs.
df_classifier.featureImportances