In [1]:
from functools import partial
import jax
import jax.numpy as jnp
from jax.ops import index, index_update
import numpy as onp
import matplotlib.pyplot as plt

In [2]:
@jax.jit
def initialize(X):
    """Build a regular prototype grid spanning the bounding box of ``X``.

    Each feature axis is sampled at 10 evenly spaced points between that
    feature's minimum and maximum, and the axes are combined with
    ``jnp.meshgrid``. This works for any number of features ``D``, but the
    grid has ``10 ** D`` points, so it is only practical for small ``D``.

    Args:
        X: data matrix of shape (N, D).

    Returns:
        prototypes: grid of shape (1, D, 10, ..., 10) with ``D`` grid axes;
            the leading size-1 axis broadcasts against the data axis.
        X: the input reshaped to (N, D, 1, ..., 1) so that ``prototypes - X``
            broadcasts to (N, D, 10, ..., 10).
    """
    num_points = 10  # grid resolution per feature
    lows = X.min(axis=0)
    highs = X.max(axis=0)

    # One linspace per feature; X.shape[1] is static under jit.
    axes = [jnp.linspace(lows[d], highs[d], num_points) for d in range(X.shape[1])]
    # Default 'xy' indexing keeps the 2-D layout identical to meshgrid(xx, yy).
    prototypes = jnp.stack(jnp.meshgrid(*axes))
    num_dims = prototypes.shape[0]

    prototypes = prototypes[jnp.newaxis, ...]
    X = jnp.expand_dims(X, axis=tuple(range(2, num_dims + 2)))

    return prototypes, X

In [3]:
@partial(jax.jit, static_argnums=(2, 3,))
def mountain_function(prototypes, x, norm_ord, sigma):
    """Mountain (density) value of each prototype, based on cosine similarity.

    For every prototype, the cosine similarity to each data point is taken
    along the feature axis (axis 1), mapped through
    ``exp(cos / (2 * sigma ** 2))`` and summed over the data axis (axis 0).

    Args:
        prototypes: prototype array with features on axis 1, e.g. the
            (1, D, 10, ..., 10) grid from ``initialize``.
        x: data broadcastable against ``prototypes`` with features on axis 1,
            e.g. the (N, D, 1, ..., 1) array from ``initialize``.
        norm_ord: static. NOTE(review): currently unused -- the cosine always
            uses L2 norms; kept for interface compatibility.
        sigma: static bandwidth of the exponential kernel.

    Returns:
        Density array with the broadcast shape minus the data and feature
        axes (e.g. (10, 10) for a 2-D grid, (1, 1) for a single centre).
        A zero-norm prototype or point yields NaN (0 / 0 in the cosine).
    """
    points_norm = jnp.linalg.norm(x, ord=2, axis=1)
    # Was ``np.linalg.norm``: ``np`` is never imported in this notebook.
    grid_norm = jnp.linalg.norm(prototypes, ord=2, axis=1)

    cosine = (x * prototypes).sum(axis=1) / (points_norm * grid_norm)

    return jnp.exp(cosine / (2 * sigma ** 2)).sum(axis=0)

In [18]:
# Debug cell: print intermediate shapes of a distance-based mountain density
# (exp(-||prototype - x|| / (2 sigma^2)) summed over data points).
# NOTE(review): relies on hidden kernel state -- `cluster`, `norm_ord` and
# `sigma` are not defined by any earlier cell, and `X` must already be the
# expanded (N, D, 1, 1) array returned by `initialize` (see the printed
# `x` shape), not the raw (N, D) data. This fails after Restart & Run All.
prototypes, x, norm_ord, sigma = cluster, X, norm_ord, sigma

# Vector norm of order `norm_ord`; the axis is supplied at each call below.
norm = partial(jnp.linalg.norm, ord=norm_ord)
print('proto', prototypes.shape)
print('x', x.shape)
print('dif', (prototypes - x).shape)
# Norm over axis 1 (the feature axis), leaving one distance per data point.
print('norm', norm((prototypes - x), axis=1).shape)
# Kernel value per data point, summed over the data axis (axis 0).
print('res', jnp.exp(-norm((prototypes - x), axis=1) / (2 * sigma ** 2)).sum(axis=0).shape)

proto (1, 2, 1, 1)
x (5, 2, 1, 1)
dif (5, 2, 1, 1)
norm (5, 1, 1)
res (1, 1)


In [4]:
@jax.jit
def get_cluster(prototypes, prototypes_density):
    """Return the highest-density prototype as a broadcastable cluster centre.

    Args:
        prototypes: array of shape (1, D, *prototypes_density.shape).
        prototypes_density: density value for every grid position.

    Returns:
        centre: the winning prototype reshaped to (1, D, 1, ..., 1) with ``D``
            trailing unit axes (matches the grid built by ``initialize``,
            which has one grid axis per feature).
        density: the density value at the winning position.
    """
    n_features = prototypes.shape[1]
    best_flat = jnp.argmax(prototypes_density)
    best_pos = jnp.unravel_index(best_flat, prototypes_density.shape)

    # All features of the winning grid point -> shape (D,).
    centre = prototypes[(0, slice(None)) + tuple(best_pos)]
    # Same layout as expand_dims over D axes followed by a leading newaxis.
    centre = centre.reshape((1, n_features) + (1,) * n_features)

    return centre, prototypes_density[best_pos]

In [5]:
@partial(jax.jit, static_argnums=(0,))
def stop(thresh, initial_state, state):
    """While-loop condition on the ratio of cluster densities.

    Returns True while the latest cluster density (last element of ``state``)
    is more than ``thresh`` times the initial cluster density (last element
    of ``initial_state``). Despite the name, ``True`` means "keep iterating",
    as required by ``jax.lax.while_loop``.

    Only ``thresh`` is static: ``initial_state`` holds JAX arrays, which are
    not hashable and therefore cannot be static jit arguments.
    """
    return (state[-1] / initial_state[-1]) > thresh

In [19]:
@partial(jax.jit, static_argnums=(0,))
def mountain_update(params, val):
    """One revision step of the mountain method.

    Subtracts a "mountain" centred on the most recently selected cluster from
    the grid density, picks the arg-max of the revised density as the next
    cluster and records it in ``clusters``.

    Args:
        params: static tuple ``(norm_ord, sigma, beta)``. ``norm_ord`` is the
            vector-norm order for the grid-to-centre distance and ``beta`` the
            width of the subtracted mountain. ``sigma`` is not needed by this
            step; it is accepted so the tuple matches ``mountain_run``.
        val: loop state ``(prototypes, clusters, X, idx, prototypes_density,
            cluster_density)`` where ``clusters[idx]`` is the latest centre
            and ``cluster_density`` its peak value in ``prototypes_density``.

    Returns:
        The state tuple with the same structure: ``idx`` advanced by one, the
        revised density, the new centre written at ``clusters[idx + 1]`` and
        its density as the last element.
    """
    prototypes, clusters, X, idx, prototypes_density, cluster_density = val
    norm_ord, _sigma, beta = params
    num_dims = prototypes.shape[1]

    # Latest centre as (1, D, 1, ..., 1) so it broadcasts against the grid.
    cluster = jnp.expand_dims(
        clusters[idx],
        tuple(range(1, num_dims + 1))
    )[jnp.newaxis, ...]

    idx += 1

    # 1 at the centre, decaying with the norm_ord-distance to it.
    cluster_mass = jnp.exp(
        -jnp.linalg.norm(
            prototypes[0] - cluster[0],
            ord=norm_ord, axis=0
        ) / (2 * beta ** 2)
    )

    # Mountain-method revision: subtract the *current* (revised) peak value so
    # the selected centre drops to exactly zero. Re-evaluating the original
    # density at the centre over-subtracts on every step after the first.
    new_prototypes_density = prototypes_density - cluster_density * cluster_mass
    new_cluster, new_cluster_density = get_cluster(prototypes, new_prototypes_density)
    # `.at[...].set` replaces the removed `jax.ops.index_update`.
    clusters = clusters.at[idx].set(jnp.squeeze(new_cluster))

    return (prototypes, clusters, X, idx, new_prototypes_density, new_cluster_density)

In [20]:
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def mountain_run(X, norm_ord, sigma, beta, thresh):
    """Run mountain clustering on the (N, D) data matrix ``X``.

    Builds the prototype grid, evaluates the initial mountain density, takes
    its peak as the first centre, then repeatedly applies ``mountain_update``
    while the newest peak density exceeds ``thresh`` times the first peak and
    the centre buffer still has room.

    Args:
        X: data of shape (N, D).
        norm_ord, sigma, beta: static hyper-parameters forwarded to
            ``mountain_function`` and ``mountain_update``.
        thresh: static stopping ratio for the peak densities.

    Returns:
        Final loop state ``(prototypes, clusters, X, idx, prototypes_density,
        cluster_density)``, where ``X`` is the expanded (N, D, 1, ..., 1)
        data. ``clusters`` has one row per data point; rows ``0..idx`` were
        written (the last one may be the peak that failed the threshold) and
        the remaining rows are NaN.
    """
    prototypes, X = initialize(X)
    prototypes_density = mountain_function(prototypes, X, norm_ord, sigma)
    cluster, cluster_density = get_cluster(prototypes, prototypes_density)

    params = (norm_ord, sigma, beta)

    # One NaN row per data point, built from the shape so that N == 1 or
    # D == 1 is not collapsed away by a squeeze.
    num_slots, num_features = X.shape[0], X.shape[1]
    clusters = jnp.full((num_slots, num_features), jnp.nan, dtype=prototypes.dtype)
    clusters = clusters.at[0].set(jnp.squeeze(cluster))

    first_density = cluster_density

    def keep_going(state):
        # Same ratio test as `stop`, plus a bound so `mountain_update` never
        # writes past the last row of `clusters`.
        idx, latest_density = state[3], state[-1]
        above_thresh = (latest_density / first_density) > thresh
        return jnp.logical_and(above_thresh, idx + 1 < num_slots)

    val = (prototypes, clusters, X, 0, prototypes_density, cluster_density)

    return jax.lax.while_loop(keep_going, partial(mountain_update, params), val)

In [21]:
# Toy data: two tight Gaussian blobs (label 0 around (1, 10), label 1 around
# (3, 5)), shuffled. Seeded so Restart & Run All reproduces the same points.
rng = onp.random.default_rng(0)

X1 = rng.normal([1, 10], 0.05, size=(2, 2))
y1 = onp.zeros(len(X1))
X2 = rng.normal([3, 5], 0.01, size=(3, 2))
y2 = onp.ones(len(X2))

X = onp.vstack([X1, X2])
y = onp.concatenate([y1, y2])

# Shuffle points and labels together; length comes from the data, not a constant.
idxs = rng.permutation(len(X))
X = X[idxs]
y = y[idxs]

X = jnp.array(X)
y = jnp.array(y)

In [22]:
# Hyper-parameters: distance norm order, density width, mountain width, stop ratio.
norm_ord, sigma, beta, thresh = 1, 0.1, 0.1, 0.1
params = (norm_ord, sigma, beta)

# `mountain_run` returns the full 6-tuple loop state; unpacking it into three
# names raised ValueError.
state = mountain_run(X, norm_ord, sigma, beta, thresh)
prototypes, clusters, X_grid, last_idx, prototypes_density, cluster_density = state

# First selected centre as (1, D, 1, ..., 1) so it broadcasts against `prototypes`.
cluster = clusters[0].reshape((1, -1) + (1,) * clusters.shape[1])

In [36]:
# Experiment: cosine-based mountain "mass" around `cluster`.
# Uses cosine *distance* (1 - cos) so the mass is exactly 1 at the centre and
# decays for prototypes pointing away from it, like the distance-based mass in
# `mountain_update`. The previous exp(-cos / (2 beta^2)) was smallest at the
# centre itself (about e^-50 for beta = 0.1).
prototype_norm = jnp.linalg.norm(prototypes[0], ord=2, axis=0)
cluster_norm = jnp.linalg.norm(cluster[0], ord=2, axis=0)
cosine = jnp.sum(prototypes[0] * cluster[0], axis=0) / (cluster_norm * prototype_norm)

cluster_mass = jnp.exp(-(1 - cosine) / (2 * beta ** 2))

cluster_mass.shape

(10, 10)

In [11]:
# Scratch cell: drive the mountain update loop directly from an existing state.
# NOTE(review): `state` and `params` are not defined by any earlier cell shown
# here -- this depends on hidden kernel state and fails after Restart & Run All.
# If `stop` marks `initial_state` as static, JAX rejects the array-valued
# `state` bound here (non-hashable static argument) -- confirm before re-running.
jax.lax.while_loop(
    partial(stop, 0.3, state),
    partial(mountain_update, params),
    state
)

Traced<ShapedArray(float32[]):JaxprTrace(level=0/1)>


(DeviceArray([[[[ 0.9700386,  1.1959361,  1.4218336,  1.6477311,
                  1.8736286,  2.0995262,  2.3254237,  2.551321 ,
                  2.7772188,  3.0031161],
                [ 0.9700386,  1.1959361,  1.4218336,  1.6477311,
                  1.8736286,  2.0995262,  2.3254237,  2.551321 ,
                  2.7772188,  3.0031161],
                [ 0.9700386,  1.1959361,  1.4218336,  1.6477311,
                  1.8736286,  2.0995262,  2.3254237,  2.551321 ,
                  2.7772188,  3.0031161],
                [ 0.9700386,  1.1959361,  1.4218336,  1.6477311,
                  1.8736286,  2.0995262,  2.3254237,  2.551321 ,
                  2.7772188,  3.0031161],
                [ 0.9700386,  1.1959361,  1.4218336,  1.6477311,
                  1.8736286,  2.0995262,  2.3254237,  2.551321 ,
                  2.7772188,  3.0031161],
                [ 0.9700386,  1.1959361,  1.4218336,  1.6477311,
                  1.8736286,  2.0995262,  2.3254237,  2.551321 ,
          