In [None]:
import pandas as pd
import seaborn as sns
from pyspark.sql import SparkSession

# Load the Iris dataset via seaborn and keep only the two petal measurements
# as clustering features. The full `iris` frame (including `species`) is kept
# so the unsupervised result can be compared against the true labels later.
iris = sns.load_dataset('iris')
petal_df = iris[['petal_length', 'petal_width']]

# Create (or reuse) the SparkSession that drives the ML pipeline.
spark = SparkSession.builder.appName("KMeans-Iris").getOrCreate()
# Build the Spark DataFrame straight from the pandas frame (no re-wrap needed).
# Every column of `sdf` becomes a feature in the VectorAssembler downstream,
# so only feature columns belong here.
sdf = spark.createDataFrame(petal_df)

In [None]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import KMeans

# Stage 1: pack every column of `sdf` into one vector column named 'features',
# the column name KMeans reads by default.
vectorAssembler = VectorAssembler(
    inputCols=sdf.columns,
    outputCol='features',
)

# Stage 2: k-means with 3 clusters and a fixed seed so the initialisation is repeatable.
clustering = KMeans(k=3, seed=1)

# Chain both stages so one fit/transform call handles assembly and clustering.
pipeline = Pipeline(stages=[vectorAssembler, clustering])

In [None]:
# Train the assembler + KMeans stages on the Spark frame, then apply the fitted
# pipeline to the same rows to obtain a cluster assignment for each one.
model = pipeline.fit(dataset=sdf)
predictions = model.transform(dataset=sdf)

# Preview the first 10 rows of the result.
predictions.show(n=10)

In [None]:
# Collect the clustered rows back into pandas for plotting.
df = predictions.toPandas()

# Re-attach the true species labels by position. This assumes Spark kept the
# original row order and count through createDataFrame -> transform -> toPandas.
# The visible pipeline stages are row-wise, but
# NOTE(review): revisit if a repartition/sort is ever added upstream.
assert len(df) == len(iris), "row count changed between pandas and Spark"
df['species'] = iris['species'].to_numpy()

# Preview rather than dumping the whole frame.
df.head(10)

In [None]:
# Cluster ids are labels, not magnitudes: cast them to a categorical dtype so
# seaborn uses a qualitative palette with one legend entry per cluster, instead
# of the continuous colour ramp it applies to numeric hue values.
ax = sns.scatterplot(
    x=df.petal_length,
    y=df.petal_width,
    hue=df.prediction.astype('category'),
)
ax.set_title('KMeans cluster assignments on petal measurements');

In [None]:
# Same axes coloured by the true species, for visual comparison with the
# KMeans cluster plot. `species` holds string labels, so seaborn already treats
# it as categorical.
ax = sns.scatterplot(x=df.petal_length, y=df.petal_width, hue=df.species)
ax.set_title('True species on petal measurements');