# **1. Download and Import necessary libraries** #

In [1]:
# Download pyspark
!pip install pyspark



In [2]:
# Import SparkSession
from pyspark.sql import SparkSession

# Import libraries for Clustering
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import VectorAssembler

In [3]:
# Create Spark Session
spark = SparkSession.builder.appName("KMeanswithSpark").getOrCreate()
spark.version

'3.5.0'

# **2. Read and Explore dataset** #

In [4]:
# Connect Google Drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
# Read dataset
df_customer = spark.read.csv("/content/drive/MyDrive/Machine_Learning_wtih_Spark/dataset/customers.csv",
                             header = True, inferSchema = True)

In [6]:
# Print schema of dataframe
df_customer.printSchema()

root
 |-- Fresh_Food: integer (nullable = true)
 |-- Milk: integer (nullable = true)
 |-- Grocery: integer (nullable = true)
 |-- Frozen_Food: integer (nullable = true)



In [7]:
# Show first 5 rows
df_customer.head(5)

[Row(Fresh_Food=12669, Milk=9656, Grocery=7561, Frozen_Food=214),
 Row(Fresh_Food=7057, Milk=9810, Grocery=9568, Frozen_Food=1762),
 Row(Fresh_Food=6353, Milk=8808, Grocery=7684, Frozen_Food=2405),
 Row(Fresh_Food=13265, Milk=1196, Grocery=4221, Frozen_Food=6404),
 Row(Fresh_Food=22615, Milk=5410, Grocery=7198, Frozen_Food=3915)]

# **3. Create features vector** #

In [8]:
# Create features column
feature_cols = ["Fresh_Food", "Milk", "Grocery", "Frozen_Food"]
assembler = VectorAssembler(inputCols = feature_cols, outputCol = "features")
df_customer_transformed = assembler.transform(df_customer)

In [9]:
df_customer_transformed.show()

+----------+-----+-------+-----------+--------------------+
|Fresh_Food| Milk|Grocery|Frozen_Food|            features|
+----------+-----+-------+-----------+--------------------+
|     12669| 9656|   7561|        214|[12669.0,9656.0,7...|
|      7057| 9810|   9568|       1762|[7057.0,9810.0,95...|
|      6353| 8808|   7684|       2405|[6353.0,8808.0,76...|
|     13265| 1196|   4221|       6404|[13265.0,1196.0,4...|
|     22615| 5410|   7198|       3915|[22615.0,5410.0,7...|
|      9413| 8259|   5126|        666|[9413.0,8259.0,51...|
|     12126| 3199|   6975|        480|[12126.0,3199.0,6...|
|      7579| 4956|   9426|       1669|[7579.0,4956.0,94...|
|      5963| 3648|   6192|        425|[5963.0,3648.0,61...|
|      6006|11093|  18881|       1159|[6006.0,11093.0,1...|
|      3366| 5403|  12974|       4400|[3366.0,5403.0,12...|
|     13146| 1124|   4523|       1420|[13146.0,1124.0,4...|
|     31714|12319|  11757|        287|[31714.0,12319.0,...|
|     21217| 6208|  14982|       3095|[2

In [10]:
# Define number of Cluster
number_of_cluster = 3

# **4. Create and Train KMeans model** #

In [11]:
# Crete KMeans model
kmeans = KMeans(k = number_of_cluster)

# Train model
model = kmeans.fit(df_customer_transformed)

# **5. Print Cluster detail** #

In [12]:
# Make predictions on the dataset
predictions = model.transform(df_customer_transformed)

In [13]:
# Display the results
predictions.show(5)

+----------+----+-------+-----------+--------------------+----------+
|Fresh_Food|Milk|Grocery|Frozen_Food|            features|prediction|
+----------+----+-------+-----------+--------------------+----------+
|     12669|9656|   7561|        214|[12669.0,9656.0,7...|         1|
|      7057|9810|   9568|       1762|[7057.0,9810.0,95...|         1|
|      6353|8808|   7684|       2405|[6353.0,8808.0,76...|         1|
|     13265|1196|   4221|       6404|[13265.0,1196.0,4...|         1|
|     22615|5410|   7198|       3915|[22615.0,5410.0,7...|         0|
+----------+----+-------+-----------+--------------------+----------+
only showing top 5 rows



In [14]:
predictions.groupBy('prediction').count().show()

+----------+-----+
|prediction|count|
+----------+-----+
|         1|  331|
|         2|   49|
|         0|   60|
+----------+-----+



# **6. Stop Spark Session** #

In [15]:
spark.stop()