In [65]:
import cupy as cp
import numpy as np
from cuml import KMeans
from cuml.preprocessing import SimpleImputer
from scipy.sparse import issparse


def cp_kmeans(X, k, round_values=True, dense=False):
    """Summarise a dataset as ``k`` weighted k-means centroids.

    Parameters
    ----------
    X : cudf.DataFrame or array-like
        Data matrix with one row per sample and one column per feature.
        When the type name contains ``cudf`` the column labels are used as
        feature names and ``X.values`` is clustered; otherwise features are
        named ``"0"``, ``"1"``, ...
    k : int
        Number of clusters, i.e. number of summary rows returned.
    round_values : bool, default True
        If true, every coordinate of every centroid is replaced by the value
        in the same column of ``X`` that is closest to it, so the summary
        only contains values that actually occur in the data.
    dense : bool, default False
        If true, return a ``DenseData`` instance (defined in this notebook)
        wrapping the result instead of a plain tuple.

    Returns
    -------
    tuple or DenseData
        ``(centers, group_names, weights)`` where ``weights`` are the
        per-cluster sample counts divided by their sum, or the same three
        pieces wrapped in ``DenseData`` when ``dense`` is true.

    Raises
    ------
    AssertionError
        If the number of feature names or of weights does not match the
        shape of the centroid matrix.
    """
    group_names = [str(i) for i in range(X.shape[1])]
    if 'cudf' in str(type(X)):
        group_names = X.columns
        X = X.values

    # in case there are any missing values in data impute them
    imp = SimpleImputer(missing_values=np.nan, strategy='mean')
    X = imp.fit_transform(X)

    kmeans = KMeans(n_clusters=k, random_state=0).fit(X)
    # Fetch the centroid matrix once so that the in-place rounding below and
    # the returned value are guaranteed to refer to the same array.
    centers = kmeans.cluster_centers_

    if round_values:
        sparse_input = issparse(X)
        for j in range(X.shape[1]):
            # The column does not depend on the cluster index, so extract it
            # once per feature (sparse support courtesy of @PrimozGodec).
            xj = X[:, j].toarray().flatten() if sparse_input else X[:, j]
            for i in range(k):
                ind = np.argmin(np.abs(xj - centers[i, j]))
                centers[i, j] = X[ind, j]

    assert len(group_names) == centers.shape[1], \
        "# of names must match data matrix!"

    # minlength=k keeps one weight per centroid even when the
    # highest-numbered cluster(s) received no samples.
    weights = 1.0 * cp.bincount(kmeans.labels_, minlength=k)
    weights /= cp.sum(weights)
    assert len(weights) == centers.shape[0], \
        "# weights must match data matrix!"

    if dense:
        # None -> let DenseData build its default one-feature-per-group list.
        return DenseData(centers, group_names, None, weights)
    return centers, group_names, weights


class Data:
    """Empty base type shared by the data containers in this notebook."""

    def __init__(self):
        """Create the object; the base class stores no state."""
        return None


class DenseData(Data):
    """A dense data matrix bundled with feature groups and sample weights.

    Parameters
    ----------
    data : 2-D array
        The data matrix. It is treated as samples-by-features unless the
        total group size does not match ``data.shape[1]``, in which case it
        is treated as transposed (features-by-samples).
    group_names : sequence
        One name per feature group.
    *args
        ``args[0]`` (optional): list of index arrays, one per group. When
        missing or ``None`` every feature forms its own group.
        ``args[1]`` (optional): per-sample weights. When missing or ``None``
        uniform weights are used.

    Attributes
    ----------
    groups, group_names, data : as passed in (or defaulted).
    weights : the weights divided by their sum.
    transposed : bool, whether ``data`` was detected as features-by-samples.
    groups_size : number of groups.

    Raises
    ------
    AssertionError
        If the group sizes or the number of weights do not match ``data``.
    """

    def __init__(self, data, group_names, *args):
        groups = args[0] if len(args) > 0 else None
        if groups is None:
            groups = [np.array([i]) for i in range(len(group_names))]
        self.groups = groups

        # If the grouped feature count does not match the column count the
        # matrix is assumed to be stored features-by-samples.
        length = sum(len(g) for g in self.groups)
        t = length != data.shape[1]
        feature_axis, sample_axis = (0, 1) if t else (1, 0)
        assert length == data.shape[feature_axis], \
            "# of names must match data matrix!"

        weights = args[1] if len(args) > 1 else None
        if weights is None:
            weights = cp.ones(data.shape[sample_axis])
        # Normalise out-of-place: the caller's array is left untouched and
        # integer-typed weights (e.g. raw counts) are handled as well.
        self.weights = weights / cp.sum(weights)
        assert len(self.weights) == data.shape[sample_axis], \
            "# weights must match data matrix!"

        self.transposed = t
        self.group_names = group_names
        self.data = data
        self.groups_size = len(self.groups)


In [66]:
import cudf

# Toy frame: two identical columns holding the integers 0..99.
toy_columns = {
    "a": cp.arange(100),
    "b": cp.arange(100),
}
X = cudf.DataFrame(toy_columns)
X.head()

Unnamed: 0,a,b
0,0,0
1,1,1
2,2,2
3,3,3
4,4,4


In [67]:
# Cluster the toy frame into 3 centroids using the default options.
tmp = cp_kmeans(X, 3)

In [68]:
# Display the (centroids, feature names, weights) tuple returned above.
tmp

(array([[51., 51.],
        [17., 17.],
        [83., 83.]]),
 Index(['a', 'b'], dtype='object'),
 array([0.33, 0.35, 0.32]))

In [48]:
# Same call with dense=False spelled out (it is the default value).
# NOTE(review): the output recorded below lists raw cluster counts, whereas
# cp_kmeans as defined in this notebook divides the counts by their sum —
# the recorded output looks stale; re-run the cell to refresh it.
lgt = cp_kmeans(X, 3, dense=False)
lgt

(array([[51., 51.],
        [17., 17.],
        [83., 83.]]),
 Index(['a', 'b'], dtype='object'),
 array([33, 35, 32]))

In [50]:
# Request the dense form and show its three fields.
# NOTE(review): the attribute access below needs cp_kmeans to return an object
# exposing .data / .group_names / .weights (e.g. DenseData) when dense=True —
# confirm against the current cp_kmeans definition before re-running.
dns = cp_kmeans(X, 3, dense=True)
dns.data, dns.group_names, dns.weights

(array([[51., 51.],
        [17., 17.],
        [83., 83.]]),
 Index(['a', 'b'], dtype='object'),
 array([0.33, 0.35, 0.32]))