In [None]:
import numpy as np

In [None]:
class KMeans:
    """K-means clustering with random or k-means++ centroid initialisation.

    Parameters
    ----------
    k : int, default 2
        Number of clusters.
    centeroids_initializer : {'random', 'k++'}, default 'random'
        Strategy used by ``fit`` to pick the starting centroids.
        (The parameter name keeps its historical spelling so existing
        callers continue to work.)
    max_iter : int, default 100
        Upper bound on the number of assign/update iterations in ``fit``.

    Attributes
    ----------
    centroids : numpy.ndarray or None
        Array of shape ``(k, n_features)`` once centroids have been
        initialised; ``None`` before that.
    initialize_centroids : callable
        Bound method (``random_centroids`` or ``kpp_centroids``) selected
        by ``centeroids_initializer``.

    Notes
    -----
    All random draws go through ``np.random.choice``, i.e. NumPy's global
    random state; call ``np.random.seed`` beforehand for reproducibility.
    """

    def __init__(self, k=2, centeroids_initializer='random', max_iter=100):
        self.k = k
        if centeroids_initializer == 'random':
            self.initialize_centroids = self.random_centroids
        elif centeroids_initializer == 'k++':
            self.initialize_centroids = self.kpp_centroids
        else:
            # Fail here rather than with an AttributeError later in fit().
            raise ValueError(
                "centeroids_initializer must be 'random' or 'k++', got "
                f"{centeroids_initializer!r}"
            )
        self.max_iter = max_iter
        self.centroids = None

    def fit(self, X):
        """Run the k-means iterations on ``X`` and store the centroids.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples; converted to a float array so that integer
            (in particular unsigned) input cannot wrap around when
            differences to the centroids are taken.

        Returns
        -------
        None
            The result is stored in ``self.centroids``.
        """
        X = np.asarray(X, dtype=float)
        self.initialize_centroids(X)
        for _ in range(self.max_iter):
            # Assign every sample to its closest centroid.
            clusters = self._create_clusters(X)
            # _update_centroids returns a fresh array, so the old one can
            # be kept by reference for the convergence check.
            prev_centroids = self.centroids
            self.centroids = self._update_centroids(X, clusters)

            # Early stop: centroids did not move in this iteration.
            if np.array_equal(self.centroids, prev_centroids):
                break

    def predict(self, X):
        """Group the rows of ``X`` by their closest centroid.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        list of list of int
            ``k`` lists; list ``i`` holds the row indices of ``X`` whose
            closest centroid is ``self.centroids[i]``.
        """
        return self._create_clusters(X)

    def kpp_centroids(self, X):
        """Initialise ``self.centroids`` with k-means++ seeding.

        The first centroid is a uniformly drawn sample. Each further
        centroid is drawn with probability proportional to the squared
        distance between a sample and its nearest already-chosen centroid.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        """
        X = np.asarray(X, dtype=float)
        n_samples = np.shape(X)[0]
        chosen = [X[np.random.choice(n_samples)]]
        # Squared distance from each sample to its nearest chosen centroid.
        # Only the most recently added centroid can lower it, so a running
        # minimum replaces recomputing against every previous centroid.
        nearest_sq = np.full(n_samples, np.inf)
        for _ in range(1, self.k):
            sq_to_last = np.sum((X - chosen[-1]) ** 2, axis=1)
            nearest_sq = np.minimum(nearest_sq, sq_to_last)
            total = nearest_sq.sum()
            # If every sample coincides with a chosen centroid the weights
            # are all zero; fall back to a uniform draw instead of 0/0.
            p = nearest_sq / total if total > 0 else None
            chosen.append(X[np.random.choice(n_samples, p=p)])
        # Shape (k, n_features) for any feature count.
        self.centroids = np.array(chosen).reshape(len(chosen), -1)

    def _create_clusters(self, X):
        """Return, for each centroid, the indices of its closest samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        list of list of int
            ``k`` lists of row indices into ``X`` (in increasing order).
            Ties go to the centroid with the lowest index.
        """
        X = np.asarray(X, dtype=float)
        centroids = np.asarray(self.centroids, dtype=float)
        # One vectorised pass per centroid keeps memory at O(n_samples * k).
        distances = np.empty((X.shape[0], len(centroids)))
        for j, centroid in enumerate(centroids):
            distances[:, j] = np.linalg.norm(X - centroid, axis=1)
        labels = np.argmin(distances, axis=1)

        clusters = [[] for _ in range(self.k)]
        for index, label in enumerate(labels.tolist()):
            clusters[label].append(index)
        return clusters

    def _update_centroids(self, X, clusters):
        """Compute new centroids as the mean of each cluster's samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        clusters : sequence of sequences of int
            Row indices of ``X`` per cluster, as produced by
            ``_create_clusters``.

        Returns
        -------
        numpy.ndarray of shape (k, n_features)
            A cluster with no samples keeps its current centroid from
            ``self.centroids`` (taking the mean of nothing would yield NaN).
        """
        X = np.asarray(X, dtype=float)
        n_features = np.shape(X)[1]
        centroids = np.zeros((self.k, n_features))
        for index, cluster in enumerate(clusters):
            if len(cluster) > 0:
                centroids[index] = np.mean(X[cluster], axis=0)
            else:
                centroids[index] = self.centroids[index]
        return centroids

    def random_centroids(self, X):
        """Initialise ``self.centroids`` with ``k`` distinct random samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Raises
        ------
        ValueError
            Raised by ``np.random.choice`` when ``k`` exceeds the number of
            samples (sampling is without replacement).
        """
        X = np.asarray(X, dtype=float)
        n_samples = np.shape(X)[0]
        self.centroids = X[np.random.choice(n_samples, self.k, replace=False)]