In [18]:
## Use this section to suppress warnings generated by the code:
def warn(*args, **kwargs):
    pass
import warnings
warnings.warn = warn
warnings.filterwarnings('ignore')

In [19]:
## Step 1. Import the necessary libraries
import findspark
findspark.init()

from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import VectorAssembler

from pyspark.sql import SparkSession

In [20]:
## Step 2. Create a SparkSession
spark = SparkSession.builder \
        .appName("Clustering with Spark ML") \
        .getOrCreate()

In [21]:
## Step 3. Read the data from a CSV file
sdf = spark.read.csv("sources/seeds.csv", header=True, inferSchema=True)

# print the schema and the number of records
sdf.printSchema()
print(sdf.count())

root
 |-- area: double (nullable = true)
 |-- perimeter: double (nullable = true)
 |-- compactness: double (nullable = true)
 |-- length of kernel: double (nullable = true)
 |-- width of kernel: double (nullable = true)
 |-- asymmetry coefficient: double (nullable = true)
 |-- length of kernel groove: double (nullable = true)

213


In [22]:
## Step 4. Data Preview - Show the top 5 rows
sdf.show(n=5, truncate=False, vertical=True)

-RECORD 0-------------------------
 area                    | 15.26  
 perimeter               | 14.84  
 compactness             | 0.871  
 length of kernel        | 5.763  
 width of kernel         | 3.312  
 asymmetry coefficient   | 2.221  
 length of kernel groove | 5.22   
-RECORD 1-------------------------
 area                    | 14.88  
 perimeter               | 14.57  
 compactness             | 0.8811 
 length of kernel        | 5.554  
 width of kernel         | 3.333  
 asymmetry coefficient   | 1.018  
 length of kernel groove | 4.956  
-RECORD 2-------------------------
 area                    | 14.29  
 perimeter               | 14.09  
 compactness             | 0.905  
 length of kernel        | 5.291  
 width of kernel         | 3.337  
 asymmetry coefficient   | 2.699  
 length of kernel groove | 4.825  
-RECORD 3-------------------------
 area                    | 13.84  
 perimeter               | 13.94  
 compactness             | 0.8955 
 length of kernel   

In [23]:
## Step 5. Use Spark SQL to check if there are any null values
# Get the list of all columns
columns = sdf.columns

# Create a temporary view
sdf.createOrReplaceTempView("seeds")

# Find and show rows with any null values
rows_with_nulls = spark.sql("SELECT * FROM seeds WHERE " + " OR ".join([f"`{col}` IS NULL" for col in columns]))
rows_with_nulls.show()

# Count the number of rows with any null values
count_rows_with_nulls = spark.sql("SELECT COUNT(*) FROM seeds WHERE " + " OR ".join([f"`{col}` IS NULL" for col in columns]))
count_rows_with_nulls.show()

+----+---------+-----------+----------------+---------------+---------------------+-----------------------+
|area|perimeter|compactness|length of kernel|width of kernel|asymmetry coefficient|length of kernel groove|
+----+---------+-----------+----------------+---------------+---------------------+-----------------------+
|null|     null|     0.8099|            null|          2.641|                 null|                  5.185|
|12.2|     null|     0.8874|            null|           null|                 null|                    5.0|
|12.3|    13.34|       null|            null|           null|                5.637|                   null|
+----+---------+-----------+----------------+---------------+---------------------+-----------------------+

+--------+
|count(1)|
+--------+
|       3|
+--------+



In [24]:
## Step 6. Data Cleaning
# Previous step showed that there are several rows with many null values, so we need to remove them
sdf = sdf.dropna()

In [25]:
## Step 7. Verify Data Cleaning
# Check the number of records after removing the rows with null values (should be 210)
print(sdf.count())

210


In [26]:
## Step 8. Assemble all columns into a single vector
feature_cols = columns

assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
sdf_transformed = assembler.transform(sdf)

In [27]:
## Step 9. Create 4 Clusters
number_of_clusters = 4

In [28]:
## Step 10. Create a K-Means clustering model
kmeans = KMeans(k = number_of_clusters)

In [29]:
## Step 11. Train the model on the transformed data
model = kmeans.fit(sdf_transformed)

In [30]:
## Step 12. Make predictions on the dataset
predictions = model.transform(sdf_transformed)

In [31]:
## Step 13. Use Spark SQL to display the prediction results
# Create a temporary view
predictions.createOrReplaceTempView("seeds_predictions")

# Check two rows from each cluster
spark.sql("SELECT * FROM seeds_predictions WHERE prediction = 0 LIMIT 2").show()
spark.sql("SELECT * FROM seeds_predictions WHERE prediction = 1 LIMIT 2").show()
spark.sql("SELECT * FROM seeds_predictions WHERE prediction = 2 LIMIT 2").show()
spark.sql("SELECT * FROM seeds_predictions WHERE prediction = 3 LIMIT 2").show()

+-----+---------+-----------+----------------+---------------+---------------------+-----------------------+--------------------+----------+
| area|perimeter|compactness|length of kernel|width of kernel|asymmetry coefficient|length of kernel groove|            features|prediction|
+-----+---------+-----------+----------------+---------------+---------------------+-----------------------+--------------------+----------+
|15.26|    14.84|      0.871|           5.763|          3.312|                2.221|                   5.22|[15.26,14.84,0.87...|         0|
|14.88|    14.57|     0.8811|           5.554|          3.333|                1.018|                  4.956|[14.88,14.57,0.88...|         0|
+-----+---------+-----------+----------------+---------------+---------------------+-----------------------+--------------------+----------+

+-----+---------+-----------+----------------+---------------+---------------------+-----------------------+--------------------+----------+
| area|perim

In [32]:
## Step 14. Display how many records are in each cluster
predictions.groupBy('prediction').count().show()

+----------+-----+
|prediction|count|
+----------+-----+
|         1|   56|
|         3|   39|
|         2|   55|
|         0|   60|
+----------+-----+



In [33]:
## Step 15. Stop the SparkSession
spark.stop()