In [None]:
import jax.numpy as jnp
import jax
from jax import vmap

from ott.geometry import segment, pointcloud, geometry
from ott.problems.quadratic import gw_barycenter as gw_prob
from ott.solvers.quadratic import gw_barycenter as gw_solver
from ott.tools import k_means
from ott.problems.quadratic import quadratic_problem
from ott.solvers.quadratic import gromov_wasserstein

import numpy as np

In [None]:
# internal distance functions for kmeans
def dist_function(xx, yy, xww, yww, solver, pointpoint):
    """
    Solve one quadratic (GW-style) problem between `xx` and `yy` and return its cost.

    `xx` is always wrapped in a `pointcloud.PointCloud`. How `yy` is wrapped depends
    on the `pointpoint` flag:

    - `pointpoint = True`:  `yy` is raw point-cloud data, just like `xx`.
    - `pointpoint = False`: `yy` is something output by `compute_center()` and is
      handed to `geometry.Geometry` unchanged.

    `xww` and `yww` become the `a` and `b` weights of the `QuadraticProblem`. The
    value returned is the `primal_cost` attribute of whatever `solver` returns.
    """
    source_geom = pointcloud.PointCloud(xx)
    target_geom = pointcloud.PointCloud(yy) if pointpoint else geometry.Geometry(yy)
    problem = quadratic_problem.QuadraticProblem(source_geom, target_geom, a=xww, b=yww)
    solution = solver(problem)
    return solution.primal_cost
# Inner stage: hold `xx`/`xww` fixed and map over the leading axis of `yy`/`yww`.
_dist_one_to_many = vmap(
    dist_function,
    in_axes=(None, 0, None, 0, None, None),
    out_axes=0,
)
# Outer stage: map over the leading axis of `xx`/`xww`, so entry [i, j] of the
# result is dist_function(xx[i], yy[j], xww[i], yww[j], solver, pointpoint).
vectorized_dist_function = vmap(
    _dist_one_to_many,
    in_axes=(0, None, 0, None, None, None),
    out_axes=0,
)

# internal center computation for kmeans
def compute_center(ii, xx, xww, barsize, solver):
    """
    Compute a single cluster center from the point clouds selected by `ii`.

    `ii` is used to index the leading axis of both `xx` and `xww` (the caller in
    `compute_centers()` passes a boolean mask). The selection is wrapped in a
    `GWBarycenterProblem`, solved by calling `solver` with `bar_size=barsize`, and
    the `cost` attribute of the solver output is returned.
    """
    member_clouds = xx[ii]
    member_weights = xww[ii]
    problem = gw_prob.GWBarycenterProblem(member_clouds, member_weights)
    solution = solver(problem, bar_size=barsize)
    return solution.cost

def compute_centers(assignments, xx, xww, barsize, solver):
    """
    Recompute one center per cluster label that occurs in `assignments`.

    For every distinct label (in the sorted order produced by `np.unique`) a boolean
    membership mask is built and passed to `compute_center()`. The per-cluster
    results are stacked into a single array. Labels that do not occur in
    `assignments` produce no center at all.

    BUGS:
    - The centers are computed one after another in a plain Python loop, which is
      slow. This should become a jitted jax map, or be farmed out to a
      multiprocessing pool.
    """
    labels = np.unique(assignments)
    # Row r of `membership` is True exactly where `assignments == labels[r]`.
    membership = assignments[None, :] == labels[:, None]
    centers = []
    for mask in membership:
        centers.append(compute_center(mask, xx, xww, barsize, solver))
    return jnp.array(centers)

# def compute_centers(xx, xww, assignments, barsize, solver):
#     idxs = jnp.array([(assignments == t) for t in np.unique(assignments)])
#     return jnp.array([compute_center(xx[ii], xww[ii], barsize, solver) for ii in idxs])

# def compute_centers(points, assignments, barsize, solver):
#     """
#     BUGS: This code needs to use an asynchronous map, not a for loop.
#     """
#     # foo = jax.jit(partial(compute_one_center, points, assignments, barsize, solver))
#     def foo(t):
#         return compute_one_center(t, points, assignments, barsize, solver)
#     return jnp.array([foo(t) for t in np.unique(assignments)])
#     # return jnp.array([compute_one_center(t, points, assignments, barsize, solver) for t in np.unique(assignments)])

def kmeans_GW(points, k, rng, maxiter=10, epsilon=10., barsize=10, verbose=True):
    """
    Generalization of the k-means algorithm to point clouds, using the GW distance
    for the distances, and the GW barycenter as the mean (averaging) operation.

    Parameters:
    - `points`:  sequence of objects exposing `.x` and `.shape` (this notebook passes
      `pointcloud.PointCloud` instances); the clouds may have different sizes.
    - `k`:       number of clusters, `1 <= k <= len(points)`.
    - `rng`:     currently UNUSED. The initial centers are simply the first `k`
      clouds (the random choice is commented out below); the parameter is kept so
      that existing callers keep working.
    - `maxiter`: maximum number of center/assignment update rounds.
    - `epsilon`: value passed as `epsilon` to both the distance solver and the
      barycenter solver.
    - `barsize`: `bar_size` passed to the barycenter solver; each center is given
      `barsize` uniform weights when distances to it are computed.
    - `verbose`: when true (the default) print progress information.

    Returns `(centers, assignments, distances)`:
    - `centers`:     the output of the last `compute_centers()` call (if
      `maxiter == 0`, the padded initial point clouds instead).
    - `assignments`: for each cloud, the index of its closest center.
    - `distances`:   the last distance matrix, one row per cloud, one column per
      center.

    Raises `ValueError` if `k` is not in `1..len(points)`.

    BUGS:
    - This code is very brittle for many reasons.
    - Because the distances and the barycenters depend on `epsilon` and are approximate,
      the algorithm hits very bad conditions (like some clusters have no assignments!).
      An empty cluster gets no center, so fewer than `k` centers may be returned;
      from that point on the labels index the rows of the surviving centers.
    - The `compute_centers()` function is not properly asynchronous.
    """
    n_clouds = len(points)
    # Out-of-range indices are clamped by jax indexing, so an oversized `k` would
    # silently duplicate the last cloud as an initial center instead of failing.
    if not 1 <= k <= n_clouds:
        raise ValueError(
            f"k must be between 1 and the number of point clouds ({n_clouds}), got {k}"
        )

    #set up the solvers
    d_solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon)
    gwb_solver = jax.jit(gw_solver.GromovWassersteinBarycenter(epsilon=epsilon), static_argnames=["bar_size"])

    #unpacking to jax
    concat_pointcloud, p_weights = segment.segment_point_cloud(
        jnp.concatenate([point.x for point in points]),
        num_per_segment=[point.shape[0] for point in points],
    )

    #initialize centers as the first k clouds (random initialization is disabled)
    # idx = jax.random.choice(rng, jnp.arange(n_clouds), shape=(k,), replace=False)
    idx = jnp.arange(k)
    if verbose:
        print("initial clusters chosen as centers:", idx)
    centers, c_weights = concat_pointcloud[idx, ...], p_weights[idx, ...]

    #compute distances points to centers with jax
    distances = vectorized_dist_function(concat_pointcloud, centers, p_weights, c_weights, d_solver, True)
    if verbose:
        print(distances)

    #assign points to closest cluster
    assignments = jnp.argmin(distances, axis=1)
    if verbose:
        print("before start:", assignments, np.min(distances))

    for iteration in range(maxiter):
        new_centers = compute_centers(assignments, concat_pointcloud, p_weights, barsize, gwb_solver)
        # Uniform weights, one row per center that actually exists. A cluster with
        # no members yields no center, so sizing this by `k` would make the vmap
        # below fail on mismatched axis sizes.
        c_weights = jnp.zeros((new_centers.shape[0], barsize)) + 1. / barsize
        if verbose:
            print(concat_pointcloud.shape, new_centers.shape, p_weights.shape, c_weights.shape)
        distances = vectorized_dist_function(concat_pointcloud, new_centers, p_weights, c_weights, d_solver, False)
        new_assignments = jnp.argmin(distances, axis=1)
        if verbose:
            print("kmeans iteration:", iteration, new_assignments, np.min(distances))
        if np.all(assignments == new_assignments):
            # Converged: no cloud changed cluster in this round.
            return new_centers, assignments, distances
        assignments = new_assignments
        centers = new_centers
    return centers, assignments, distances

In [None]:
# Generate random point clouds of increasing size: cloud `i` has
# `points_per_cloud + i` points and is drawn with its own PRNG key.
number_of_clouds = 13
points_per_cloud = 11
dimensions = 3

clouds = []
for i in range(number_of_clouds):
    # The first six clouds (i <= 5) use the smaller scale, the rest the larger one.
    scale = 3.2 if i > 5 else 2.4
    key = jax.random.PRNGKey(i + 1)
    cloud_shape = (points_per_cloud + i, dimensions)
    x = scale * jax.random.normal(key, cloud_shape)
    clouds.append(x)

In [None]:
# Wrap every raw array in a PointCloud geometry and cluster into two groups.
cloud_geometries = [pointcloud.PointCloud(s) for s in clouds]
centers, assignments, distances = kmeans_GW(
    cloud_geometries,
    2,
    jax.random.PRNGKey(0),
)
print("final assignments:", assignments)

In [None]:
if False:
    """
    This block of code is very brittle; it can only work on a non-ragged `clouds` variable.
    """
    # Disabled sanity check: all-pairs distances between the raw clouds, followed
    # by a listing of every entry that came out negative.
    solver = gromov_wasserstein.GromovWasserstein(epsilon=10.)
    cloudpoints = jnp.array(clouds)
    N, M, D = cloudpoints.shape
    # Uniform weights over the M points of each of the N clouds.
    weights = jnp.zeros((N, M)) + 1. / M
    distances = vectorized_dist_function(
        cloudpoints, cloudpoints, weights, weights, solver, True
    )
    print(distances)
    negative = distances < 0.
    ii, jj = np.where(negative)
    for i, j in zip(ii, jj):
        print(i, j, distances[i, j])

In [None]:
if False:
    """
    This piece of code was designed to perform an experiment with zeroing out weights; it sucked.
    """
    points = [pointcloud.PointCloud(s) for s in clouds]
    stacked = jnp.concatenate([point.x for point in points])
    sizes = [point.shape[0] for point in points]
    cloudpoints, weights = segment.segment_point_cloud(stacked, num_per_segment=sizes)

    epsilon = 1.
    barsize = 10
    gwb_solver = jax.jit(gw_solver.GromovWassersteinBarycenter(epsilon=epsilon), static_argnames=["bar_size"])

    # Variant 1: center of the first three clouds, selected by index.
    c1 = compute_center(jnp.array([0, 1, 2]), cloudpoints, weights, barsize, gwb_solver)

    # Variant 2: pass every cloud, but multiply the weights of all clouds except
    # the first three by zero.
    aux = np.zeros((len(points),))
    aux[:3] = 1
    weights2 = jnp.array(aux[:,None] * weights)
    c2 = compute_center(jnp.arange(len(weights2)), cloudpoints, weights2, barsize, gwb_solver)

    print(c1)
    print(c2)