# Unsupervised Machine Learning with PySpark: Clustering

In this example we turn to **unsupervised machine learning**: given a dataset without predefined labels, our task is to find structure in the data. A typical approach to this problem is **clustering**, that is, finding **clusters** of similar data points that group together.

## Preamble

In [None]:
import findspark
findspark.init()
import pyspark

In [None]:
spark = pyspark.sql.SparkSession \
    .builder \
    .appName("Clustering Example") \
    .getOrCreate()


## The Iris data set

![Iris versicolor](https://upload.wikimedia.org/wikipedia/commons/d/db/Iris_versicolor_4.jpg)

> The Iris flower data set or "Fisher's Iris data set" is a multivariate data set introduced by the British statistician and biologist Ronald Fisher in his 1936 paper _The use of multiple measurements in taxonomic problems_ as an example of linear discriminant analysis.
> 
> The data set consists of 50 samples from each of three species of Iris (_Iris setosa_, _Iris virginica_ and _Iris versicolor_). Four features were measured from each sample: the length and the width of the sepals and petals, in centimetres. Based on the combination of these four features, Fisher developed a linear discriminant model to distinguish the species from each other.
> &mdash; ["Iris flower data set," Wikipedia](https://en.wikipedia.org/wiki/Iris_flower_data_set)

## Loading the Data

In [None]:
import pandas
import sklearn

In [None]:
data = spark.read \
    .format("csv") \
    .option("header", "true") \
    .schema("sepal_length DOUBLE, sepal_width DOUBLE, petal_length DOUBLE, petal_width DOUBLE, species STRING") \
    .load("../.assets/data/iris/iris.csv")


In [None]:
data.show()

In [None]:
data.schema

## Applying the Clustering Algorithm

In the following we apply the **[k-means clustering algorithm](https://en.m.wikipedia.org/wiki/K-means_clustering)**, a popular choice for getting started with clustering on a new dataset.

In [None]:
from pyspark.ml.clustering import KMeans

As the name says, _k-means clustering_ has one central **parameter**, $k$, which is the number of clusters to be detected in the data. The algorithm places $k$ cluster centers and assigns each data point to the closest one. In our example, we know that we are looking for three different species of Iris, and in the hope that they can be found as clusters in the measurements, we set $k = 3$:
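To make the "place centers, assign points to the closest one" idea concrete, here is a minimal pure-Python sketch of the classic iteration behind k-means. It is not how Spark implements `KMeans`; the function names `closest_center` and `kmeans_sketch`, the toy points, and the starting centers are all made up for illustration.

```python
import math

def closest_center(point, centers):
    """Return the index of the center nearest to `point` (Euclidean distance)."""
    return min(range(len(centers)), key=lambda i: math.dist(point, centers[i]))

def kmeans_sketch(points, centers, iterations=10):
    """Alternate between assigning points to centers and recomputing the centers."""
    for _ in range(iterations):
        # Assignment step: each point joins the cluster of its closest center.
        labels = [closest_center(p, centers) for p in points]
        # Update step: each center moves to the mean of its assigned points.
        new_centers = []
        for i, old_center in enumerate(centers):
            members = [p for p, label in zip(points, labels) if label == i]
            if members:
                new_centers.append(tuple(sum(c) / len(members) for c in zip(*members)))
            else:
                new_centers.append(old_center)  # an empty cluster keeps its center
        centers = new_centers
    # Final assignment against the last set of centers.
    labels = [closest_center(p, centers) for p in points]
    return centers, labels

# Hypothetical toy data: two well-separated groups in the plane.
toy_points = [(1.0, 1.0), (1.0, 2.0), (2.0, 1.0), (8.0, 8.0), (9.0, 8.0), (8.0, 9.0)]
toy_centers, toy_labels = kmeans_sketch(toy_points, centers=[(0.0, 0.0), (10.0, 10.0)])
print(toy_labels)
print(toy_centers)
```

The sketch uses fixed starting centers so that the outcome is reproducible; real implementations choose the initial centers randomly or with a seeding heuristic, which is why results can differ between runs.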

In [None]:
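# Configure a k-means estimator with k = 3 clusters.
# No training happens here; the model is trained later by calling fit().
kmeans = KMeans().setK(3)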

The rest of the workflow is standard procedure for ML with Spark: collect the feature columns into a single feature vector column, fit the estimator (here, our k-means clustering algorithm), and transform the data with the resulting trained model. The result is a cluster label for each data point, stored by default in the `prediction` column.

In [None]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline

In [None]:
feature_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

In [None]:
assemble_features = VectorAssembler(inputCols=feature_cols, 
                                    outputCol="features")

In [None]:
data = assemble_features.transform(data)

In [None]:
kmeans_model = kmeans.fit(data)

In [None]:
labelled = kmeans_model.transform(data)
labelled.show()

## Interpretation

Computing a clustering is the easy part. The bigger data analysis task is still ahead of us: how do we evaluate and interpret the result? What does a cluster represent? How well does it match the partition of the data we were trying to find? There are many strategies we can apply, and we only briefly look at some of them.

### Cluster Centers

Like any **centroid clustering** algorithm, the k-means algorithm defines a cluster as the set of data points close to a central point (the centroid, which need not itself be a data point), and iteratively tries to find good centers so that the total squared distance between the data points and their assigned centers is minimized. This enables one way of interpreting the result: after the algorithm terminates, the cluster centers can be treated as "typical" and "representative" for their respective cluster.
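The quantity being minimized can be written out explicitly. The following is a sketch, assuming squared Euclidean distance (the usual choice for k-means); the symbols $J$, $C_j$ (the set of data points assigned to cluster $j$) and $\mu_j$ (the center of cluster $j$) are notation introduced here for illustration:

$$
J = \sum_{j=1}^{k} \sum_{x \in C_j} \lVert x - \mu_j \rVert^2,
\qquad
\mu_j = \frac{1}{|C_j|} \sum_{x \in C_j} x
$$

The second expression says that each center is the mean of the points assigned to it, which is why a center can be read as the "average" member of its cluster.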

In [None]:

centers = kmeans_model.clusterCenters()
print("Cluster Centers: ")
for center in centers:
    print(list(zip(feature_cols, center)))

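In this case, we can say, for example, that a specimen of the Iris plant from cluster 1 typically has a sepal length of around 5.8 cm, and so on. Keep in mind that the cluster indices are arbitrary and need not line up with the species labels. Let us compare the centers with the **ground truth** given by the species label: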

In [None]:
# Average of each feature per species. Grouping by the species column works
# regardless of how the labels are encoded in the CSV (numbers or names).
data.groupBy("species") \
    .avg(*feature_cols) \
    .orderBy("species") \
    .show()
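Comparing averages is only one way to check the clusters against the ground truth. Another is to count how often each species lands in each cluster. The sketch below does this in plain Python; the helper names `contingency_table` and `purity` are hypothetical, and the labels are made up for illustration rather than taken from the Iris data.

```python
from collections import Counter

def contingency_table(true_labels, cluster_labels):
    """Count how often each (true label, cluster label) pair occurs."""
    return Counter(zip(true_labels, cluster_labels))

def purity(true_labels, cluster_labels):
    """Fraction of points that carry the majority true label of their cluster."""
    table = contingency_table(true_labels, cluster_labels)
    best_per_cluster = {}
    for (truth, cluster), count in table.items():
        best_per_cluster[cluster] = max(best_per_cluster.get(cluster, 0), count)
    return sum(best_per_cluster.values()) / len(true_labels)

# Made-up labels for illustration only, not results from the Iris data.
toy_species = ["a", "a", "a", "b", "b", "b"]
toy_clusters = [0, 0, 1, 1, 1, 1]

toy_table = contingency_table(toy_species, toy_clusters)
toy_purity = purity(toy_species, toy_clusters)
print(toy_table)
print(toy_purity)
```

A purity of 1.0 would mean every cluster contains a single species. On the Spark side, a comparable table can be obtained with `labelled.crosstab("species", "prediction")`, assuming the default `prediction` output column of `KMeans`.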

### Visualization

In order to inspect the clusters found, we can use data visualization techniques such as **scatter plots**. The Python ecosystem provides powerful plotting tools such as [`matplotlib`](https://matplotlib.org/) and [`seaborn`](https://seaborn.pydata.org/index.html). For most visualization purposes it makes sense to pull a small data sample computed with Spark, then leave the Spark territory and convert to a `pandas.DataFrame` to work with those tools.

In [None]:
import seaborn

In [None]:
labelled_pd = labelled.toPandas()

In [None]:
# Pairwise scatter plots of all four features, colored by the true species.
seaborn.pairplot(labelled_pd, vars=feature_cols, hue="species")

---
_This notebook is licensed under a [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/). Copyright © 2018-2025 [Point 8 GmbH](https://point-8.de)_