In [1]:
# Read data - sample online retail data
df = spark.read.csv("./online_retail.csv", header=True, inferSchema=True)

In [2]:
# Prepare dataset for clustering - string indexing
from pyspark.ml.feature import StringIndexer

indexer1 = StringIndexer(inputCol="StockCode", outputCol="StockCodeIndex", handleInvalid="skip")
indexed1 = indexer1.fit(df).transform(df)

indexer2 = StringIndexer(inputCol="Description", outputCol="DescriptionIndex", handleInvalid="skip")
indexed_dataset = indexer2.fit(indexed1).transform(indexed1)

dataset = indexed_dataset["StockCode", "StockCodeIndex", "Description", "DescriptionIndex"]

In [3]:
# Prepare dataset for clustering - vectorization
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

input_columns = ["DescriptionIndex"]
vec_assembler = VectorAssembler(inputCols=input_columns, outputCol="features")
final_dataset = vec_assembler.transform(dataset)

In [4]:
# Prepare dataset for clustering - standard scaling
from pyspark.ml.feature import StandardScaler

scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", withStd=True, withMean=True)
scalerModel = scaler.fit(final_dataset)

final_dataset = scalerModel.transform(final_dataset)

In [5]:
# Train a k-means model
from pyspark.ml.clustering import KMeans

number_of_clusters = 2000
kmeans = KMeans(featuresCol="scaledFeatures", k=number_of_clusters)
model = kmeans.fit(final_dataset)

In [6]:
# Evaluate clustering by computing 'Within Set Sum of Squared Errors'
wssse = model.computeCost(final_dataset)
print("Within Set Sum of Squared Errors = " + str(wssse))

Within Set Sum of Squared Errors = 0.25809840866562794


In [7]:
# Look at the frequency distribution according to clusters
clustered_dataset = model.transform(final_dataset)
reduced_clustered_dataset = clustered_dataset["StockCode", "Description", "prediction"]
statistics = reduced_clustered_dataset.groupBy("prediction").count()
statistics.orderBy("count", ascending=True).show()

+----------+-----+
|prediction|count|
+----------+-----+
|       207|    8|
|      1577|    9|
|       332|    9|
|      1598|    9|
|       928|    9|
|      1338|    9|
|      1758|   10|
|      1889|   10|
|      1352|   10|
|       660|   11|
|      1849|   11|
|      1728|   11|
|       518|   12|
|      1624|   12|
|        23|   12|
|      1560|   12|
|      1921|   12|
|      1216|   12|
|      1752|   12|
|      1460|   12|
+----------+-----+
only showing top 20 rows



In [10]:
# Inspect a cluster
clusters = [1921]
reduced_clustered_dataset.filter(reduced_clustered_dataset.prediction.isin(clusters)).show()

+---------+--------------------+----------+
|StockCode|         Description|prediction|
+---------+--------------------+----------+
|   84612B|SET/4 BLACK  BARO...|      1921|
|   84612B|SET/4 BLACK  BARO...|      1921|
|    21310|    CAPIZ CHANDELIER|      1921|
|    21310|    CAPIZ CHANDELIER|      1921|
|   84612B|SET/4 BLACK  BARO...|      1921|
|    90049|IVORY GOLD METAL ...|      1921|
|    90049|IVORY GOLD METAL ...|      1921|
|    90049|IVORY GOLD METAL ...|      1921|
|    21310|    CAPIZ CHANDELIER|      1921|
|    21310|    CAPIZ CHANDELIER|      1921|
|   84612B|SET/4 BLACK  BARO...|      1921|
|    90049|IVORY GOLD METAL ...|      1921|
+---------+--------------------+----------+

