<a href="https://colab.research.google.com/github/ronglu-stanford/RL_reference_public/blob/main/Copy_of_1_k_means.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import plotly.graph_objects as go
import numpy as np

In [None]:
from sklearn.datasets import make_blobs
# Build the toy data set used throughout the notebook: 300 samples drawn
# around 4 blob centers. `random_state` is fixed so every run yields the
# same points. `X` holds the sample coordinates, `y_true` the generating
# blob index of each sample.
X, y_true = make_blobs(
    n_samples=300,
    centers=4,
    cluster_std=0.60,
    random_state=0,
)

In [None]:
# Scatter plot of the unlabeled samples: first column of X on the x-axis,
# second column on the y-axis, rendered in a fixed 700x500 canvas.
fig = go.Figure(
    data=[
        go.Scatter(x=X[:, 0], y=X[:, 1], mode="markers"),
    ]
)
fig.update_layout(autosize=False, width=700, height=500)
fig.show()

In [None]:
from sklearn.cluster import KMeans
# Fit scikit-learn's k-means with 4 clusters.
# - random_state is pinned so the clustering (and therefore the colours in the
#   plot below) is identical on every Restart & Run All.
# - n_init is spelled out because its default differs between scikit-learn
#   releases; 10 restarts keeps the best of several initialisations.
kmeans = KMeans(n_clusters=4, n_init=10, random_state=0)

# fit_predict fits the model and returns the cluster index of every sample in
# one call, instead of a separate predict pass over the same data.
y_kmeans = kmeans.fit_predict(X)

In [None]:
# Show the fitted clustering: samples coloured by their assigned cluster
# index, with the fitted cluster centers overlaid as large red markers.
centers = kmeans.cluster_centers_

sample_trace = go.Scatter(
    x=X[:, 0],
    y=X[:, 1],
    mode="markers",
    marker={"size": 8, "color": y_kmeans},
)
center_trace = go.Scatter(
    x=centers[:, 0],
    y=centers[:, 1],
    mode="markers",
    marker={"size": 16, "color": "red"},
)

# Trace order matters for draw order: samples first, centers on top.
fig = go.Figure(data=[sample_trace, center_trace])
fig.update_layout(autosize=False, width=700, height=500)
fig.show()

In [None]:
from sklearn.metrics import pairwise_distances_argmin
from plotly.subplots import make_subplots

def _add_cluster_panel(fig, points, point_labels, panel_centers, col):
    """Draw one k-means snapshot into row 1, column ``col`` of ``fig``.

    Adds two scatter traces: the samples in ``points`` (first column on x,
    second on y) coloured by ``point_labels`` with marker size 8, and
    ``panel_centers`` as red markers of size 16.
    """
    fig.add_trace(
        go.Scatter(x=points[:, 0], y=points[:, 1], mode="markers",
                   marker=dict(size=8, color=point_labels)),
        row=1, col=col)
    fig.add_trace(
        go.Scatter(x=panel_centers[:, 0], y=panel_centers[:, 1], mode="markers",
                   marker=dict(size=16, color="red")),
        row=1, col=col)


# Hand-written k-means, plotting the assignment ("expectation") and the
# center update ("maximization") of every iteration side by side.

# 1. Randomly choose clusters: pick n_clusters distinct samples as the
#    initial centers. The RandomState seed is fixed so the run is reproducible.
rng = np.random.RandomState(2)
n_clusters = 4
i = rng.permutation(X.shape[0])[:n_clusters]
centers = X[i]

# Safety bound so a non-converging run cannot loop (and plot) forever.
max_iter = 100

# The counter is named `iteration` rather than `round` so the builtin
# round() is not shadowed for the rest of the notebook.
for iteration in range(max_iter):
    # 2a. Assign labels based on closest center
    labels = pairwise_distances_argmin(X, centers)

    # 2b. Find new centers from means of points. A cluster that received no
    #     points keeps its previous center; the mean of an empty selection
    #     would otherwise be NaN.
    old_centers = centers
    centers = np.array([
        X[labels == k].mean(0) if np.any(labels == k) else old_centers[k]
        for k in range(n_clusters)
    ])

    # 2c. Check for convergence. Exact comparison is intentional: once the
    #     assignments stop changing, the means are recomputed from the same
    #     points and come out identical. Checking before building the figure
    #     avoids constructing a plot that is never shown.
    if np.array_equal(centers, old_centers):
        break

    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=(f"Expectation step; iteration {iteration}",
                        f"Maximization step; iteration {iteration}"))
    fig.update_layout(autosize=False, width=1200, height=500)

    # Left panel: new labels with the centers they were assigned against.
    _add_cluster_panel(fig, X, labels, old_centers, col=1)
    # Right panel: same labels with the freshly updated centers.
    _add_cluster_panel(fig, X, labels, centers, col=2)

    fig.show()
else:
    # Only reached when the loop ran out of iterations without a `break`.
    print(f"k-means did not converge within {max_iter} iterations")