# Fuzzy c-means
* Initialize U :: (c, n) s.t. U.sum(axis=0) == 1
* Calculate C :: (c, d)
* Compute the cost
* New U

In [1]:
from functools import partial

import jax.numpy as jnp
import jax
import numpy as onp
from sklearn import datasets

In [9]:
@partial(jax.jit, static_argnums=(2,))
def get_cluster(points, membership, m):
    return jnp.sum(
        membership[..., jnp.newaxis] ** m * points, axis=1
    ) / jnp.sum(membership, axis=1)[0]

In [10]:
@partial(jax.jit, static_argnums=(0,))
def improve_membership(params, val):
    membership, points, _, old_cost = val    
    norm_ord, m = params
    
    cluster = get_cluster(points, membership, m)

    c, num_dims = cluster.shape

    dist = jnp.linalg.norm(
        (cluster.reshape(c, 1, num_dims) - x),
        ord=norm_ord, axis=2
    )

    new_cost = jnp.sum(membership ** m * dist ** 2)

    tmp_dist = dist ** (2 / (m - 1))
    tmp_dist /= tmp_dist.sum(axis=0)
    new_membership = 1 / tmp_dist
    new_membership /= new_membership.sum(axis=0)

    return new_membership, points, old_cost, new_cost

In [14]:
@partial(jax.jit, static_argnums=(2, 4, 3))
def cmeans(key, points, c, norm_ord=2, tol=1e-5):
    num_samples, num_dims = points.shape
    
    membership = jax.random.uniform(key, (c, num_samples))
    membership /= membership.sum(axis=0)
    
    initial_val = (membership, points, jnp.inf, 0.)
    params = (c, norm_ord)
    
    state = jax.lax.while_loop(
        lambda val: abs(val[-1] - val[-2]) > tol,
        partial(improve_membership, params),
        initial_val
    )
    
    membership, points, other, cost = state
    cluster = get_cluster(points, membership, m)
    
    return cluster, membership, points, cost

In [None]:
X1 = onp.random.normal([1, 10], 0.05, size=(2, 2))
y1 = onp.zeros(2)
X2 = onp.random.normal([3, 5], 0.01, size=(3, 2))
y2 = onp.ones(3)

X = onp.vstack([X1, X2])
y = onp.concatenate([y1, y2])

idxs = onp.arange(5)
onp.random.shuffle(idxs)
X = X[idxs]
y = y[idxs]

x = jnp.array(X)
y = jnp.array(y)

In [13]:
c = 2
m = 4

key = jax.random.PRNGKey(onp.random.randint(0, 1000))

cmeans(key, x, c, 2, 1e-5)

(DeviceArray([[1.7760717e+00, 4.9881673e+00],
              [1.8960817e-04, 1.5658026e-03]], dtype=float32),
 DeviceArray([[0.8619992 , 0.8659331 , 0.9348949 , 0.9366297 , 0.9362057 ],
              [0.13800086, 0.13406691, 0.06510515, 0.06337035, 0.06379431]],            dtype=float32),
 DeviceArray([[ 1.0000064, 10.024741 ],
              [ 1.0691454,  9.945861 ],
              [ 3.0004535,  4.983425 ],
              [ 2.9855025,  4.99777  ],
              [ 3.0103474,  5.004288 ]], dtype=float32),
 DeviceArray(33.22462, dtype=float32),
 DeviceArray(33.224625, dtype=float32))