# In [568]:
import torch
import numpy as np


# In [569]:
class PytorchKmeans():
    """
    K-means clustering (Lloyd's algorithm) for PyTorch tensors.

    parameters
    k - number of clusters (must be >= 1)
    tol - error tolerance; fitting stops once the norm of the centroid
          movement between two consecutive iterations drops below it
    iteration - maximum number of assign/update iterations

    attributes
    cluster_centers_ - (k, n_features) float tensor with the cluster centers
                       of the fitted dataset (empty until ``fit`` is called)
    labels_ - (n_samples,) index tensor giving, for every fitted sample, the
              index of its nearest center in ``cluster_centers_``
    """

    def __init__(self, k, tol=1e-5, iteration=300):
        self.k = k
        self.tol = tol
        self.iteration = iteration
        self.cluster_centers_ = torch.Tensor([])
        self.labels_ = torch.Tensor([])

    def fit(self, X):
        """
        Fit the cluster centers to ``X``.

        parameters
        X - 2-D tensor of shape (n_samples, n_features); it is converted to
            float32 internally, so integer tensors are accepted

        returns
        The (k, n_features) tensor of cluster centers (also stored in
        ``cluster_centers_``).

        raises
        ValueError - if ``k`` < 1, ``X`` is not 2-D, or ``X`` has no samples
        """
        if self.k < 1:
            raise ValueError("k must be at least 1")

        # Work on a detached float copy so means/distances are well defined
        # for integer inputs and no autograd history is recorded.
        X = X.detach().float()
        m, n = X.size()
        if m == 0:
            raise ValueError("X must contain at least one sample")

        X_mean = torch.mean(X, dim=0)
        # The (unbiased) standard deviation of a single sample is NaN, which
        # would poison every centroid; fall back to zero spread in that case.
        X_std = torch.std(X, dim=0) if m > 1 else torch.zeros_like(X_mean)

        # Random initialisation around the data mean, created on X's device.
        noise = torch.randn(size=(self.k, n), dtype=X.dtype, device=X.device)
        centroids = X_mean + X_std * noise
        old_centroids = centroids.detach().clone()

        for _ in range(self.iteration):
            distances = self._distances(X, centroids)
            nearest_dist, cluster = torch.min(distances, dim=1)

            # Samples ordered from worst-fitted to best-fitted; used to
            # re-seed clusters that received no samples.
            farthest = torch.argsort(nearest_dist, descending=True)
            reseeded = 0

            for j in range(self.k):
                members = X[cluster == j]
                if members.size(0) > 0:
                    centroids[j, :] = torch.mean(members, dim=0)
                else:
                    # The mean of an empty selection is NaN. Move the empty
                    # cluster onto the next worst-fitted sample instead.
                    centroids[j, :] = X[farthest[reseeded % m]]
                    reseeded += 1

            error = torch.norm(old_centroids - centroids)
            if error < self.tol:
                break
            old_centroids = centroids.detach().clone()

        self.cluster_centers_ = centroids
        # Label against the final centers so labels_ agrees with predict(X)
        # and is defined even when no iteration was run.
        self.labels_ = torch.argmin(self._distances(X, centroids), dim=1)

        return centroids

    def pairwise_distance(self, data1, data2):
        """
        Euclidean (L2) distance between each row of ``data1`` and ``data2``.

        ``data2`` is broadcast against ``data1``; the norm is taken over
        dimension 1, so ``data1`` must be at least 2-D.
        """
        return torch.norm(data1 - data2, 2, dim=1)

    def _distances(self, X, centroids):
        """Return the (n_samples, n_centroids) matrix of sample-to-centroid distances."""
        return torch.stack(
            [self.pairwise_distance(X, center) for center in centroids], dim=1
        )

    def predict(self, x):
        """
        Assign each row of ``x`` to its nearest fitted cluster center.

        parameters
        x - 2-D tensor of shape (n_samples, n_features)

        returns
        1-D tensor of cluster indices, one per row of ``x``.

        raises
        RuntimeError - if called before ``fit``
        """
        if self.cluster_centers_.numel() == 0:
            raise RuntimeError("fit must be called before predict")
        distances = self._distances(x.detach().float(), self.cluster_centers_)
        return torch.argmin(distances, dim=1)


        
            
        