<a href="https://colab.research.google.com/github/parag2489/Algorithms/blob/master/kmeans.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np

In [10]:
class KMeans:
  """Lloyd's k-means clustering for a 2-D array of samples.

  An instance is callable: ``memberships, centroids = KMeans(k)(x)``.

  Iteration stops as soon as the change in the residual sum of squares
  between two consecutive iterations is at most ``tolerance``, or after
  ``max_iter`` iterations, whichever happens first.

  Attributes:
    k: Number of clusters.
    tolerance: Convergence threshold on the absolute change in RSS.
    max_iter: Upper bound on the number of update iterations.
  """

  def __init__(self, k: int, tolerance: float = 1e-3, max_iter: int = 1000):
    self.k = k
    self.tolerance = tolerance
    self.max_iter = max_iter

  def init_centroids(self, x: np.ndarray) -> np.ndarray:
    """Pick up to ``k`` distinct rows of ``x`` at random as initial centroids.

    Args:
      x: Samples, shape ``(n_samples, n_features)``.

    Returns:
      A copy of the selected rows, shape ``(min(k, n_samples), n_features)``.
    """
    indices = np.random.permutation(x.shape[0])[:self.k]
    return x[indices, :]

  def compute_memberships(self, x: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Assign every sample to its nearest centroid (squared Euclidean).

    Args:
      x: Samples, shape ``(n_samples, n_features)``.
      centroids: Cluster centres, shape ``(n_centroids, n_features)``.

    Returns:
      Integer array of shape ``(n_samples,)`` holding, for each sample, the
      row index of the closest centroid. Ties go to the lowest index.
    """
    # Broadcast (n, 1, d) against (1, c, d) to get all pairwise differences.
    x = x.reshape((-1, 1, x.shape[-1]))
    centroids = centroids.reshape((1, -1, centroids.shape[-1]))
    distance = np.sum((x - centroids) ** 2, axis=-1)
    return np.argmin(distance, axis=-1)

  def compute_centroids(self, x: np.ndarray, memberships: np.ndarray,
                        prev_centroids=None) -> np.ndarray:
    """Recompute each centroid as the mean of the samples assigned to it.

    Args:
      x: Samples, shape ``(n_samples, n_features)``.
      memberships: Cluster index of every sample, shape ``(n_samples,)``.
      prev_centroids: Optional centroids from the previous iteration. When
        given, a cluster that received no samples keeps its previous
        centroid; otherwise an empty cluster's centroid is all zeros.

    Returns:
      Float array of shape ``(k, n_features)``.
    """
    memberships = np.asarray(memberships)
    centroids = np.zeros((self.k, x.shape[-1]))
    if prev_centroids is not None:
      # Seed with the old positions so empty clusters do not jump to the
      # origin; non-empty clusters are overwritten below.
      rows = min(self.k, prev_centroids.shape[0])
      centroids[:rows] = prev_centroids[:rows]
    for i in range(self.k):
      mask = memberships == i
      if mask.any():
        centroids[i, :] = np.mean(x[mask], axis=0)
    return centroids

  def compute_rss(self, x: np.ndarray, memberships: np.ndarray,
                  centroids: np.ndarray):
    """Return the residual sum of squares of the clustering.

    This is the sum, over all samples, of the squared Euclidean distance
    between the sample and the centroid of the cluster it is assigned to,
    i.e. the objective that k-means minimises.

    Args:
      x: Samples, shape ``(n_samples, n_features)``.
      memberships: Cluster index of every sample, shape ``(n_samples,)``.
      centroids: Cluster centres, shape ``(n_centroids, n_features)``.

    Returns:
      The scalar RSS.
    """
    residuals = x - centroids[np.asarray(memberships)]
    return np.sum(residuals ** 2)

  def __call__(self, x):
    """Cluster ``x`` and return ``(memberships, centroids)``.

    Args:
      x: Samples, array-like of shape ``(n_samples, n_features)``.

    Returns:
      A tuple ``(memberships, centroids)``: the cluster index of every
      sample and the final centroids of shape ``(k, n_features)``.
    """
    x = np.asarray(x)
    centroids = self.init_centroids(x)
    memberships = None
    prev_rss = float("inf")  # guarantees the first iteration never "converges"
    for _ in range(self.max_iter):
      memberships = self.compute_memberships(x, centroids)
      centroids = self.compute_centroids(x, memberships, centroids)
      rss = self.compute_rss(x, memberships, centroids)
      converged = abs(rss - prev_rss) <= self.tolerance
      prev_rss = rss
      if converged:
        break
    if memberships is None:
      # max_iter <= 0: no update ran, report assignments to the initial picks.
      memberships = self.compute_memberships(x, centroids)
    return memberships, centroids

In [13]:
# Demo: four groups of five 4-D points, each drawn from its own disjoint
# integer range, stacked into one (20, 4) float32 array and clustered with k=4.
data = np.vstack(
    [np.random.randint(low, high, (5, 4))
     for low, high in ((0, 10), (16, 18), (20, 25), (32, 36))]
).astype(np.float32)
kmeans = KMeans(4)
memberships, centroids = kmeans(data)
# One value per line: the samples, their cluster indices, the final centroids.
print(data, memberships, centroids, sep="\n")

[[ 6.  0.  7.  4.]
 [ 0.  7.  9.  7.]
 [ 6.  1.  4.  5.]
 [ 4.  4.  0.  2.]
 [ 1.  9.  2.  8.]
 [17. 16. 17. 16.]
 [16. 17. 17. 17.]
 [16. 17. 17. 16.]
 [17. 16. 17. 17.]
 [17. 17. 17. 16.]
 [21. 24. 23. 24.]
 [21. 20. 20. 24.]
 [23. 23. 21. 21.]
 [22. 23. 23. 21.]
 [23. 24. 24. 20.]
 [32. 33. 34. 35.]
 [33. 32. 35. 34.]
 [35. 33. 32. 35.]
 [35. 34. 35. 33.]
 [33. 35. 34. 35.]]
[0 0 0 0 0 3 3 3 3 3 2 2 2 2 2 1 1 1 1 1]
[[ 3.4000001   4.19999981  4.4000001   5.19999981]
 [33.59999847 33.40000153 34.         34.40000153]
 [22.         22.79999924 22.20000076 22.        ]
 [16.60000038 16.60000038 17.         16.39999962]]
