## MeanShift with scikit-learn

* Please click [here](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html) for the reference
* Use the iris dataset to demonstrate how to cluster with meanshift
* Remember to test the hyper-parameters
* For more references between R and Python, click [here](https://towardsdatascience.com/are-you-bilingual-be-fluent-in-r-and-python-7cb1533ff99f)

In [8]:
from sklearn import datasets
from sklearn.cluster import MeanShift
import pandas as pd
import numpy as np

In [45]:
# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
X_pd = pd.DataFrame(X)
X_pd.columns = ["Sepal.Length","Sepal.Width","Petal.Length","Petal.Width"]
X_pd.head()

Unnamed: 0,Sepal.Length,Sepal.Width,Petal.Length,Petal.Width
0,5.1,3.5,1.4,0.2
1,4.9,3.0,1.4,0.2
2,4.7,3.2,1.3,0.2
3,4.6,3.1,1.5,0.2
4,5.0,3.6,1.4,0.2


In [138]:
ms = MeanShift(bandwidth=1,
               seeds=None, 
               bin_seeding=False, 
               min_bin_freq=1, 
               cluster_all=True, 
               n_jobs=None)

In [139]:
clustering = ms.fit(X_pd)

In [140]:
clustering

MeanShift(bandwidth=1, bin_seeding=False, cluster_all=True, min_bin_freq=1,
          n_jobs=None, seeds=None)

In [141]:
clustering.labels_

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 2, 0, 0, 0, 2, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [142]:
for_pred = X_pd
for_pred.head()

Unnamed: 0,Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,label
0,5.1,3.5,1.4,0.2,1
1,4.9,3.0,1.4,0.2,1
2,4.7,3.2,1.3,0.2,1
3,4.6,3.1,1.5,0.2,1
4,5.0,3.6,1.4,0.2,1


In [143]:
y = clustering.predict(for_pred)
y

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 2, 0, 0, 0, 2, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

## Show the summary statistics by cluster

In [144]:
for_pred['label'] = y

In [145]:
stat = for_pred.groupby('label').mean()

In [146]:
stat

Unnamed: 0_level_0,Sepal.Length,Sepal.Width,Petal.Length,Petal.Width
label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,6.359341,2.912088,5.043956,1.732967
1,5.006,3.428,1.462,0.246
2,5.277778,2.466667,3.511111,1.1
