In [None]:
# -*- coding: utf-8 -*-
# Indentation: Jupyter Notebook

'''
Clustering using SparkML
'''

__version__ = 1.0
__author__ = "Sourav Raj"
__author_email__ = "souravraj.iitbbs@gmail.com"


In [1]:
import findspark
findspark.init()  # locate the local Spark installation so pyspark can be imported

from pyspark import SparkContext
from pyspark.sql.session import SparkSession

# getOrCreate() returns the already-running session if there is one, instead of
# constructing a second SparkContext. Calling SparkContext() directly makes this
# cell fail on re-run with "Cannot run multiple SparkContexts at once".
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext  # kept for code that expects a SparkContext handle

In [26]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import KMeans

In [27]:
# Read the dataset with its header row as column names. inferSchema=True makes
# Spark parse the numeric columns as numbers rather than leaving them as strings.
cluster_df = spark.read.csv(
    '../data/clustering_dataset.csv',
    header=True,
    inferSchema=True,
)

In [28]:
cluster_df.head(3)  # first three rows, returned as a list of Row objects

[Row(col1=7, col2=4, col3=1),
 Row(col1=7, col2=7, col3=9),
 Row(col1=7, col2=9, col3=6)]

In [29]:
# using vector assembler we convert these columns to features
# Combine the numeric columns into one vector column named 'features',
# the column the clustering estimators below read by default.
feature_cols = ['col1', 'col2', 'col3']
vec_assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')

In [30]:
# Append the assembled 'features' vector to every row, then preview five rows.
vec_df = vec_assembler.transform(cluster_df)
vec_df.head(5)

[Row(col1=7, col2=4, col3=1, features=DenseVector([7.0, 4.0, 1.0])),
 Row(col1=7, col2=7, col3=9, features=DenseVector([7.0, 7.0, 9.0])),
 Row(col1=7, col2=9, col3=6, features=DenseVector([7.0, 9.0, 6.0])),
 Row(col1=1, col2=6, col3=5, features=DenseVector([1.0, 6.0, 5.0])),
 Row(col1=6, col2=7, col3=7, features=DenseVector([6.0, 7.0, 7.0]))]

In [31]:
# k=3 clusters. A fixed seed makes the randomised centre initialisation
# (and therefore the fitted centres) reproducible across runs.
kmeans = KMeans(k=3, seed=1)

In [32]:
kmodel = kmeans.fit(dataset=vec_df)  # train k-means on the assembled feature vectors

In [33]:
# Coordinates of each fitted cluster centre
centers = kmodel.clusterCenters()

In [34]:
# Display the k-means centres: a list with one array of coordinates per cluster
centers

[array([ 35.88461538,  31.46153846,  34.42307692]),
 array([ 5.12,  5.84,  4.84]),
 array([ 80.        ,  79.20833333,  78.29166667])]

# Hierarchical clustering

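Bisecting k-means is a divisive (top-down) hierarchical clustering algorithm. All points start in a single cluster, and clusters are split in two with k-means until `k` clusters exist. Spark's documentation notes that it can often be much faster than regular k-means, which makes it attractive for large datasets, although it generally produces a different clustering.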

In [35]:
from pyspark.ml.clustering import BisectingKMeans

# k=3 clusters with a fixed seed so the bisecting splits are reproducible.
bkmeans = BisectingKMeans(k=3, seed=1)

In [36]:
bkmodel = bkmeans.fit(dataset=vec_df)  # train bisecting k-means on the same feature vectors

In [37]:
# Fitted centres of the bisecting k-means model, shown as the cell output
bkcenters = bkmodel.clusterCenters()
bkcenters

[array([ 5.12,  5.84,  4.84]),
 array([ 35.88461538,  31.46153846,  34.42307692]),
 array([ 80.        ,  79.20833333,  78.29166667])]