In [2]:
from functools import partial
import jax
import jax.numpy as jnp
from jax.ops import index, index_update
import numpy as onp
import matplotlib.pyplot as plt

In [3]:
@partial(jax.jit, static_argnums=(2, 3,))
def mountain_function(x, d, norm_ord, r_a):
    """Sum of exponentially decaying distance kernels between ``x`` and ``d``.

    The difference ``d - x`` is formed by broadcasting, its ``norm_ord`` norm
    is taken along axis 1, and ``exp(-distance / (r_a / 2) ** 2)`` is summed
    over the last axis. Note that the distance itself is not squared.

    Args:
        x: Evaluation points. ``subtractive_run`` passes shape (n, dim, 1).
        d: Reference points. ``subtractive_run`` passes shape (1, dim, m).
        norm_ord: ``ord`` argument forwarded to ``jnp.linalg.norm`` (static).
        r_a: Neighbourhood radius controlling the kernel width (static).

    Returns:
        The kernel sum reduced over the last axis; shape (n,) for the shapes
        listed above.
    """
    kernel_width = (r_a / 2) ** 2
    distances = jnp.linalg.norm(d - x, ord=norm_ord, axis=1)
    kernel = jnp.exp(-distances / kernel_width)
    return kernel.sum(axis=-1)

In [4]:
@jax.jit
def get_cluster(x, density):
    """Pick the entry of ``x`` that has the largest density.

    Args:
        x: Candidate points, indexed along axis 0.
        density: One density value per candidate.

    Returns:
        A pair ``(x[peak], density[peak])`` where ``peak`` is the argmax of
        ``density``.
    """
    peak = jnp.argmax(density)
    return x[peak], density[peak]

In [5]:
@partial(jax.jit, static_argnums=(0,))
def subtractive_update(params, val):
    """Run one iteration of the subtractive clustering loop.

    The influence of the most recently selected centre is subtracted from the
    density, and the point with the highest remaining density becomes the next
    candidate centre.

    Args:
        params: Static tuple ``(norm_ord, r_a)``.
        val: Loop state ``(idx, x, density, clusters, cluster_density)`` where
            ``x`` has a trailing unit axis (shape (n, dim, 1)), ``density``
            has one value per point, ``clusters`` is the buffer of centres
            found so far and ``idx`` is the row of the latest centre.

    Returns:
        The next loop state with the same layout. ``density`` is the reduced
        density, so the subtraction accumulates across iterations.
    """
    idx, x, density, clusters, cluster_density = val
    norm_ord, r_a = params
    r_b = 1.5 * r_a
    idx += 1

    # Centre chosen on the previous iteration.
    cluster = clusters[idx - 1]

    # Drop only the trailing unit axis; a blanket squeeze would also remove
    # the point or feature axis whenever n == 1 or dim == 1.
    points = x[..., 0]

    cluster_mass = jnp.exp(
        -jnp.linalg.norm(
            points - jnp.reshape(cluster, (1, -1)),
            ord=norm_ord, axis=1
        ) / ((r_b / 2) ** 2)
    )

    # mountain_function takes its norm over axis 1, so the centre has to be
    # laid out as (1, dim, 1). A (1, dim) centre would broadcast against the
    # (n, dim, 1) points into (n, dim, dim) and mix unrelated coordinates.
    near_cluster_density = mountain_function(
        x, jnp.reshape(cluster, (1, -1, 1)), norm_ord, r_a
    ) * cluster_mass

    new_density = density - near_cluster_density
    new_cluster, cluster_density = get_cluster(x, new_density)

    # Functional indexed update; jax.ops.index_update is no longer available
    # in current JAX releases.
    clusters = clusters.at[idx].set(
        jnp.reshape(new_cluster, clusters.shape[1:])
    )

    # Carry the reduced density forward. Passing the old density back made
    # every iteration start again from the original values.
    return (idx, x, new_density, clusters, cluster_density)

In [6]:
@partial(jax.jit, static_argnums=(0,))
def stop(thresh, initial_state, state):
    """Loop condition: report whether iteration should continue.

    Despite the name, a True result means "keep going". The last element of
    ``state`` is divided by the last element of ``initial_state`` and the
    ratio is compared against ``thresh``.

    Only ``thresh`` is a static argument. ``initial_state`` holds arrays,
    which are not hashable and therefore cannot be passed as a static
    argument to ``jax.jit``.

    Args:
        thresh: Ratio threshold (static).
        initial_state: State before the loop started; only its last element
            is read.
        state: Current loop state; only its last element is read.

    Returns:
        Boolean array that is True while the ratio is strictly above
        ``thresh``.
    """
    return (state[-1] / initial_state[-1]) > thresh

In [7]:
@partial(jax.jit, static_argnums=(1, 2, 3))
def subtractive_run(x, norm_ord, r_a, thresh):
    """Run subtractive clustering on a 2-D array of points.

    The first centre is the point with the highest mountain-function density.
    ``subtractive_update`` is then applied repeatedly while ``stop`` reports
    that the newest candidate's density, relative to the first centre's
    density, is still above ``thresh`` and there is a free row left in the
    centre buffer.

    Args:
        x: Points to cluster, shape (n, dim).
        norm_ord: ``ord`` passed to ``jnp.linalg.norm`` (static).
        r_a: Neighbourhood radius (static).
        thresh: Density-ratio threshold for continuing (static).

    Returns:
        The final loop state ``(idx, x, density, clusters, cluster_density)``.
        ``x`` is the input with a trailing unit axis added, ``clusters`` is an
        (n, dim) buffer whose rows ``0..idx`` hold the selected centres (the
        remaining rows are NaN), and ``cluster_density`` belongs to the most
        recently written row. When the loop ends on the threshold test, that
        last row is the candidate that failed it.
    """
    n_points = x.shape[0]

    # One row per potential centre, NaN until filled. Built from the 2-D
    # input so the shape stays (n, dim) even when n == 1 or dim == 1.
    clusters = jnp.zeros_like(x) + jnp.nan

    d = x.T[jnp.newaxis, ...]
    x = x[..., jnp.newaxis]

    density = mountain_function(x, d, norm_ord, r_a)
    cluster, cluster_density = get_cluster(x, density)

    idx = 0
    params = (norm_ord, r_a)

    # get_cluster returns the centre as (dim, 1); drop the trailing axis.
    clusters = clusters.at[idx].set(cluster[..., 0])

    val = (idx, x, density, clusters, cluster_density)

    def keep_going(state):
        # The body writes row idx + 1, so iterate only while that row exists.
        # Out-of-range writes would otherwise be dropped silently.
        has_room = state[0] < n_points - 1
        return jnp.logical_and(has_room, stop(thresh, val, state))

    state = jax.lax.while_loop(
        keep_going,
        partial(subtractive_update, params),
        val
    )

    return state

In [8]:
# Seed the global NumPy generator so the toy dataset (and therefore the
# clustering result) is identical on every Restart & Run All.
onp.random.seed(0)

# Two tight Gaussian blobs: two points around (1, 10) labelled 0 and three
# points around (3, 5) labelled 1.
X1 = onp.random.normal([1, 10], 0.05, size=(2, 2))
y1 = onp.zeros(len(X1))
X2 = onp.random.normal([3, 5], 0.01, size=(3, 2))
y2 = onp.ones(len(X2))

X = onp.vstack([X1, X2])
y = onp.concatenate([y1, y2])

# Shuffle points and labels together; the permutation length follows the data
# instead of a hard-coded 5.
idxs = onp.arange(len(X))
onp.random.shuffle(idxs)
X = X[idxs]
y = y[idxs]

X = jnp.array(X)
y = jnp.array(y)



In [None]:
# Cluster X with norm_ord=2, r_a=0.3 and thresh=0.01. The last three arguments
# are passed positionally because the jit wrapper marks positions 1-3 as
# static. The cell displays the returned loop state
# (idx, x, density, clusters, cluster_density).
subtractive_run(X, 2, 0.3, 0.01)

In [None]:
# Display the dataset. The lowercase name `x` only exists as a function
# parameter, so evaluating it at notebook scope raises NameError on a fresh
# kernel.
X