In [1]:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from sklearn.datasets import load_iris

# Module-level Spark session used by the helpers below.
# getOrCreate() hands back the already-active session when one exists.
# NOTE(review): the app name mentions a decision tree, while the notebook
# trains a RandomForestClassifier — confirm the name is intentional.
spark = SparkSession.builder.appName("Decision_tree_classification").getOrCreate()

# load iris data
def load_iris_data():
    """Load the scikit-learn iris dataset into a pandas DataFrame.

    Returns:
        pandas.DataFrame: one column per entry of ``iris.feature_names``
        filled from ``iris.data``, followed by a final ``label`` column
        filled from ``iris.target``.
    """
    iris = load_iris()

    # Feature columns come first and the label is appended last, because the
    # other helpers treat "every column except the last" as the features.
    pandas_df = pd.DataFrame(iris.data, columns=iris.feature_names)
    pandas_df["label"] = iris.target

    return pandas_df

# convert to spark
def convert_to_spark(df):
    """Build a Spark DataFrame from ``df`` using the module-level session.

    Args:
        df: the object handed to ``spark.createDataFrame`` (the notebook
            passes a pandas DataFrame).

    Returns:
        Whatever ``spark.createDataFrame(df)`` returns.
    """
    spark_frame = spark.createDataFrame(df)
    return spark_frame

# feature vector
def feature_vector(df, input_cols=None, output_col="features"):
    """Append an assembled feature-vector column to ``df``.

    Args:
        df: DataFrame to transform; must expose ``columns``.
        input_cols: names of the columns to assemble. When omitted, every
            column except the last one is used (the last column is assumed
            to be the label, as produced by ``load_iris_data``).
        output_col: name of the vector column to create.

    Returns:
        The result of ``VectorAssembler.transform(df)``.
    """
    if input_cols is None:
        # Default: everything except the trailing label column.
        input_cols = df.columns[:-1]

    assembler = VectorAssembler(
        inputCols=list(input_cols), outputCol=output_col)
    return assembler.transform(df)

# split data
def split_data(data, train_ratio=0.8,
               seed=87):
    """Randomly partition ``data`` into a training and a test portion.

    Args:
        data: object exposing ``randomSplit(weights, seed)`` (the notebook
            passes a Spark DataFrame).
        train_ratio: weight of the training portion; the test portion gets
            ``1 - train_ratio``.
        seed: seed forwarded to ``randomSplit``.

    Returns:
        tuple: ``(train, test)``.
    """
    test_ratio = 1 - train_ratio
    weights = [train_ratio, test_ratio]
    # Unpacking enforces that exactly two partitions come back.
    train_part, test_part = data.randomSplit(weights, seed)
    return train_part, test_part

# model training
def build_tree(train_data, seed=87):
    """Fit a random-forest classifier on ``train_data``.

    Despite the function's name, the estimator is a
    ``RandomForestClassifier`` reading the ``features`` column and
    predicting the ``label`` column.

    Args:
        train_data: DataFrame passed to ``fit``.
        seed: random seed given to the classifier so that repeated runs do
            not depend on the library's default seed. Pass ``None`` to
            leave the seed unset and use that default instead.

    Returns:
        The fitted model returned by ``RandomForestClassifier.fit``.
    """
    params = {"featuresCol": "features", "labelCol": "label"}
    if seed is not None:
        # Only forward the seed when one was requested.
        params["seed"] = seed

    classifier = RandomForestClassifier(**params)
    return classifier.fit(train_data)

# show feature importance
def feature_importance(model, df):
    """Print every feature's importance, largest first.

    Importances are read from ``model.featureImportances.toArray()`` and
    paired positionally with all columns of ``df`` except the last one.
    One ``name: importance`` line is printed per pair; nothing is returned.
    """
    scores = model.featureImportances.toArray()
    names = df.columns[:-1]

    # zip() stops at the shorter sequence, so surplus names or scores are
    # silently ignored.
    ranked = sorted(zip(names, scores),
                    key=lambda pair: pair[1],
                    reverse=True)

    for feat_name, score in ranked:
        print(f"{feat_name}: {score}")
    

# model evaluation
def evaluate_model(model, test_data):
    """Print the accuracy of ``model`` on ``test_data``.

    The model's ``transform`` output is scored by a
    ``MulticlassClassificationEvaluator`` that compares the ``prediction``
    column with the ``label`` column using the ``accuracy`` metric. The
    result is printed as ``Accuracy: <value>``; nothing is returned.
    """
    scored = model.transform(test_data)
    accuracy = MulticlassClassificationEvaluator(
        metricName="accuracy",
        labelCol="label",
        predictionCol="prediction",
    ).evaluate(scored)
    print(f"Accuracy: {accuracy}")


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/11/14 11:31:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# End-to-end run: load -> convert -> assemble features -> split -> fit -> report.

# Load iris into pandas, then hand it to Spark.
df = load_iris_data()
spark_df = convert_to_spark(df)

# Add the assembled "features" vector column.
transformed_data = feature_vector(spark_df)

# Split using split_data's defaults (train_ratio=0.8, seed=87).
train_data, test_data = split_data(transformed_data)

# Fit the classifier on the training portion only.
model = build_tree(train_data)

# spark_df (not transformed_data) is passed here: feature_importance drops
# the last column, which for spark_df is "label", leaving the feature names.
feature_importance(model, spark_df)

# Print accuracy on the held-out portion.
evaluate_model(model, test_data)

                                                                                

petal length (cm): 0.5187038160417065
petal width (cm): 0.3961069369929094
sepal length (cm): 0.06520078756189998
sepal width (cm): 0.01998845940348416
Accuracy: 1.0
