## Initializing Environment and Loading Data

In [1]:
import os

import findspark

# Locate the Spark installation through the SPARK_HOME environment variable so
# the notebook is not tied to one machine's directory layout. When the variable
# is unset, fall back to the original install path so existing setups keep working.
findspark.init(os.environ.get('SPARK_HOME', '/home/shahayush954/spark-3.4.1-bin-hadoop3'))

In [2]:
from pyspark.sql import SparkSession

In [3]:
# Start (or reuse) the Spark session used by every cell below.
spark = (
    SparkSession.builder
    .appName('k_means')
    .getOrCreate()
)

23/08/20 14:45:59 WARN Utils: Your hostname, ubuntu-22 resolves to a loopback address: 127.0.1.1; using 10.0.2.15 instead (on interface enp0s3)
23/08/20 14:45:59 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/08/20 14:46:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
# Load the seeds measurements; the first row holds the column names and the
# column types are inferred from the file contents.
data = spark.read.csv(
    path='seeds_dataset.csv',
    header=True,
    inferSchema=True,
)

                                                                                

In [5]:
# Print the inferred schema to check how each column was typed on load.
data.printSchema()

root
 |-- area: double (nullable = true)
 |-- perimeter: double (nullable = true)
 |-- compactness: double (nullable = true)
 |-- length_of_kernel: double (nullable = true)
 |-- width_of_kernel: double (nullable = true)
 |-- asymmetry_coefficient: double (nullable = true)
 |-- length_of_groove: double (nullable = true)



## Data Preparation

In [6]:
from pyspark.ml.feature import VectorAssembler

In [7]:
# List the column names; all of them are passed to the VectorAssembler as inputs.
data.columns

['area',
 'perimeter',
 'compactness',
 'length_of_kernel',
 'width_of_kernel',
 'asymmetry_coefficient',
 'length_of_groove']

In [8]:
# Combine every column of `data` into a single vector column named `features`.
assembler = VectorAssembler(
    inputCols=data.columns,
    outputCol='features',
)

In [9]:
# Apply the assembler: keeps the original columns and appends the `features` vector.
final_data = assembler.transform(data)

In [10]:
# Confirm that the `features` vector column is now part of the schema.
final_data.printSchema()

root
 |-- area: double (nullable = true)
 |-- perimeter: double (nullable = true)
 |-- compactness: double (nullable = true)
 |-- length_of_kernel: double (nullable = true)
 |-- width_of_kernel: double (nullable = true)
 |-- asymmetry_coefficient: double (nullable = true)
 |-- length_of_groove: double (nullable = true)
 |-- features: vector (nullable = true)



## Data Scaling

In [11]:
from pyspark.ml.feature import StandardScaler

In [13]:
# Scale each feature to unit standard deviation without centring it on the mean.
# withMean/withStd are spelled out at their default values so the behaviour is
# visible to the reader (the scaled values stay positive because no mean is removed).
scaler = StandardScaler(
    inputCol='features',
    outputCol='scaledFeatures',
    withMean=False,
    withStd=True,
)

In [14]:
# Fit the scaler on the assembled `features` column of the dataset.
scaler_model = scaler.fit(final_data)

                                                                                

In [15]:
# Add the `scaledFeatures` column. Any existing `scaledFeatures` column is
# dropped first so this cell can be re-run safely: `final_data` is reassigned
# from itself, and a second run would otherwise fail because the scaler's
# output column already exists. `drop` does nothing when the column is absent,
# so the first run behaves exactly as before.
final_data = scaler_model.transform(final_data.drop('scaledFeatures'))

In [16]:
# Inspect the first row to compare the raw `features` vector with `scaledFeatures`.
final_data.head(1)

                                                                                

[Row(area=15.26, perimeter=14.84, compactness=0.871, length_of_kernel=5.763, width_of_kernel=3.312, asymmetry_coefficient=2.221, length_of_groove=5.22, features=DenseVector([15.26, 14.84, 0.871, 5.763, 3.312, 2.221, 5.22]), scaledFeatures=DenseVector([5.2445, 11.3633, 36.8608, 13.0072, 8.7685, 1.4772, 10.621]))]

## KMeans Clustering

In [17]:
from pyspark.ml.clustering import KMeans

In [18]:
# Cluster on the scaled feature vectors.
# k=3: presumably one cluster per seed variety in the dataset -- TODO confirm.
# The seed is fixed because k-means initialisation is random; without it the
# cluster centres and the numbering of the cluster labels can change between runs.
kmeans = KMeans(featuresCol='scaledFeatures', k=3, seed=42)

In [19]:
# Fit k-means on the `scaledFeatures` column of the prepared data.
model = kmeans.fit(final_data)

                                                                                

In [20]:
# Show the fitted cluster centres (one array per cluster). The model was trained
# on `scaledFeatures`, so the coordinates are in the scaled feature space.
model.clusterCenters()

[array([ 6.32636687, 12.38115343, 37.39222755, 13.9206997 ,  9.75485787,
         2.41428142, 12.28078861]),
 array([ 4.90993613, 10.92295738, 37.28032496, 12.38401355,  8.5873381 ,
         1.7739463 , 10.35147469]),
 array([ 4.06818854, 10.13938448, 35.87110297, 11.81191124,  7.52564313,
         3.24585755, 10.40780927])]

In [23]:
# Assign every row to a cluster and preview the scaled vectors next to the
# predicted cluster index.
(
    model
    .transform(final_data)
    .select(['scaledFeatures', 'prediction'])
    .show()
)

+--------------------+----------+
|      scaledFeatures|prediction|
+--------------------+----------+
|[5.24452795332028...|         1|
|[5.11393027165175...|         1|
|[4.91116018695588...|         1|
|[4.75650503761158...|         1|
|[5.54696468981581...|         1|
|[4.94209121682475...|         1|
|[5.04863143081749...|         1|
|[4.84929812721816...|         1|
|[5.71536696354628...|         0|
|[5.65006812271202...|         1|
|[5.24452795332028...|         1|
|[4.82180387844584...|         1|
|[4.77368894309428...|         1|
|[4.73588435103234...|         1|
|[4.72213722664617...|         1|
|[5.01426361985209...|         1|
|[4.80805675405968...|         1|
|[5.39230954047151...|         1|
|[5.05206821191403...|         1|
|[4.37158555479908...|         2|
+--------------------+----------+
only showing top 20 rows

