diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index e038b49a1..4974302d9 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -42,7 +42,7 @@ The contributors to this library are: * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW, quantized FGW, partial FGW) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein - Barycenters, GMMOT) + Barycenters, GMMOT, Barycenters for General Transport Costs) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) diff --git a/README.md b/README.md index 8b4cca7f7..f0e256eb0 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,8 @@ POT provides the following generic OT solvers (links to examples): * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. * Fused unbalanced Gromov-Wasserstein [70]. +* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [76] +* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 76] POT provides the following Machine Learning related solvers: @@ -389,3 +391,5 @@ Artificial Intelligence. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145. + +[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) diff --git a/RELEASES.md b/RELEASES.md index 62240fa77..060186176 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,6 +7,9 @@ - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) +- Implement fixed-point solver for OT barycenters with generic cost functions + (generalizes `ot.lp.free_support_barycenter`), with example. (PR #715) +- Implement fixed-point solver for barycenters between GMMs (PR #715), with example. - Fix warning raise when import the library (PR #716) - Implement projected gradient descent solvers for entropic partial FGW (PR #702) - Fix documentation in the module `ot.gaussian` (PR #718) diff --git a/examples/barycenters/plot_free_support_barycenter_generic_cost.py b/examples/barycenters/plot_free_support_barycenter_generic_cost.py new file mode 100644 index 000000000..47e2c9236 --- /dev/null +++ b/examples/barycenters/plot_free_support_barycenter_generic_cost.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +""" +===================================== +OT Barycenter with Generic Costs Demo +===================================== + +This example illustrates the computation of an Optimal Transport Barycenter for +a ground cost that is not a power of a norm. We take the example of ground costs +:math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear) +projection onto a circle k. This is an example of the fixed-point barycenter +solver introduced in [76] which generalises [20] and [43]. + +The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in +\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over +:math:`x` with Pytorch. + +[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing +Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 +(2024) + +[20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein +Barycenters. InternationalConference in Machine Learning + +[43] Álvarez-Esteban, Pedro C., et al. A fixed-point approach to barycenters in +Wasserstein space. Journal of Mathematical Analysis and Applications 441.2 +(2016): 744-762. + +""" + +# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +# %% +# Generate data +import torch +from torch.optim import Adam +from ot.utils import dist +import numpy as np +from ot.lp import free_support_barycenter_generic_costs +import matplotlib.pyplot as plt + + +torch.manual_seed(42) + +n = 200 # number of points of the of the barycentre +d = 2 # dimensions of the original measure +K = 4 # number of measures to barycentre +m = 50 # number of points of the measures +b_list = [torch.ones(m) / m] * K # weights of the 4 measures +weights = torch.ones(K) / K # weights for the barycentre +stop_threshold = 1e-20 # stop threshold for B and for fixed-point algo + + +# map R^2 -> R^2 projection onto circle +def proj_circle(X, origin, radius): + diffs = X - origin[None, :] + norms = torch.norm(diffs, dim=1) + return origin[None, :] + radius * diffs / norms[:, None] + + +# circles on which to project +origin1 = torch.tensor([-1.0, -1.0]) +origin2 = torch.tensor([-1.0, 2.0]) +origin3 = torch.tensor([2.0, 2.0]) +origin4 = torch.tensor([2.0, -1.0]) +r = np.sqrt(2) +P_list = [ + lambda X: proj_circle(X, origin1, r), + lambda X: proj_circle(X, origin2, r), + lambda X: proj_circle(X, origin3, r), + lambda X: proj_circle(X, origin4, r), +] + +# measures to barycentre are projections of different random circles +# onto the K circles +Y_list = [] +for k in range(K): + t = torch.rand(m) * 2 * np.pi + X_temp = 0.5 * torch.stack([torch.cos(t), torch.sin(t)], axis=1) + X_temp = X_temp + torch.tensor([0.5, 0.5])[None, :] + Y_list.append(P_list[k](X_temp)) + + +# %% +# Define costs and ground barycenter function +# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a +# (n, n_k) matrix of costs +def c1(x, y): + return dist(P_list[0](x), y) + + +def c2(x, y): + return dist(P_list[1](x), y) + + +def c3(x, y): + return dist(P_list[2](x), y) + + +def c4(x, y): + return dist(P_list[3](x), y) + + +cost_list = [c1, c2, c3, c4] + + +# batched total ground cost function for candidate points x (n, d) +# for computation of the ground barycenter B with gradient descent +def C(x, y): + """ + Computes the barycenter cost for candidate points x (n, d) and + measure supports y: List(n, d_k). + """ + n = x.shape[0] + K = len(y) + out = torch.zeros(n) + for k in range(K): + out += (1 / K) * torch.sum((P_list[k](x) - y[k]) ** 2, axis=1) + return out + + +# ground barycenter function +def B(y, its=150, lr=1, stop_threshold=stop_threshold): + """ + Computes the ground barycenter for measure supports y: List(n, d_k). + Output: (n, d) array + """ + x = torch.randn(n, d) + x.requires_grad_(True) + opt = Adam([x], lr=lr) + for _ in range(its): + x_prev = x.data.clone() + opt.zero_grad() + loss = torch.sum(C(x, y)) + loss.backward() + opt.step() + diff = torch.sum((x.data - x_prev) ** 2) + if diff < stop_threshold: + break + return x + + +# %% +# Compute the barycenter measure +fixed_point_its = 3 +X_init = torch.rand(n, d) +X_bar = free_support_barycenter_generic_costs( + Y_list, + b_list, + X_init, + cost_list, + B, + numItermax=fixed_point_its, + stopThr=stop_threshold, +) + +# %% +# Plot Barycenter (Iteration 3) +alpha = 0.4 +s = 80 +labels = ["circle 1", "circle 2", "circle 3", "circle 4"] +for Y, label in zip(Y_list, labels): + plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s) +plt.scatter( + *(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha, s=s +) +plt.axis("equal") +plt.xlim(-0.3, 1.3) +plt.ylim(-0.3, 1.3) +plt.axis("off") +plt.legend() +plt.tight_layout() + +# %% diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py index 5b3572bd4..b21c66f13 100644 --- a/examples/barycenters/plot_generalized_free_support_barycenter.py +++ b/examples/barycenters/plot_generalized_free_support_barycenter.py @@ -14,7 +14,7 @@ """ -# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu> +# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr> # # License: MIT License diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py new file mode 100644 index 000000000..f379a9914 --- /dev/null +++ b/examples/barycenters/plot_gmm_barycenter.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +""" +===================================== +Gaussian Mixture Model OT Barycenters +===================================== + +This example illustrates the computation of a barycenter between Gaussian +Mixtures in the sense of GMM-OT [69]. This computation is done using the +fixed-point method for OT barycenters with generic costs [76], for which POT +provides a general solver, and a specific GMM solver. Note that this is a +'free-support' method, implying that the number of components of the barycenter +GMM and their weights are fixed. + +The idea behind GMM-OT barycenters is to see the GMMs as discrete measures over +the space of Gaussian distributions :math:`\mathcal{N}` (or equivalently the +Bures-Wasserstein manifold), and to compute barycenters with respect to the +2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a +gaussian mixture is a finite combination of Diracs on specific gaussians, and +two mixtures are compared with the 2-Wasserstein distance on this space, where +ground cost the squared Bures distance between gaussians. + +[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space +of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. + +[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing +Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 +(2024) + +""" + +# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr> +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +# %% +# Generate data +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patches import Ellipse +import ot +from ot.gmm import gmm_barycenter_fixed_point + + +K = 3 # number of GMMs +d = 2 # dimension +n = 6 # number of components of the desired barycenter + + +def get_random_gmm(K, d, seed=0, min_cov_eig=1, cov_scale=1e-2): + rng = np.random.RandomState(seed=seed) + means = rng.randn(K, d) + P = rng.randn(K, d, d) * cov_scale + # C[k] = P[k] @ P[k]^T + min_cov_eig * I + covariances = np.einsum("kab,kcb->kac", P, P) + covariances += min_cov_eig * np.array([np.eye(d) for _ in range(K)]) + weights = rng.random(K) + weights /= np.sum(weights) + return means, covariances, weights + + +m_list = [5, 6, 7] # number of components in each GMM +offsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])] +means_list = [] # list of means for each GMM +covs_list = [] # list of covariances for each GMM +w_list = [] # list of weights for each GMM + +# generate GMMs +for k in range(K): + means, covs, b = get_random_gmm( + m_list[k], d, seed=k, min_cov_eig=0.25, cov_scale=0.5 + ) + means = means / 2 + offsets[k][None, :] + means_list.append(means) + covs_list.append(covs) + w_list.append(b) + +# %% +# Compute the barycenter using the fixed-point method +init_means, init_covs, _ = get_random_gmm(n, d, seed=0) +weights = ot.unif(K) # barycenter coefficients +means_bar, covs_bar, log = gmm_barycenter_fixed_point( + means_list, + covs_list, + w_list, + init_means, + init_covs, + weights, + iterations=3, + log=True, +) + + +# %% +# Define plotting functions + + +# draw a covariance ellipse +def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None): + def eigsorted(cov): + vals, vecs = np.linalg.eigh(cov) + order = vals.argsort()[::-1].copy() + return vals[order], vecs[:, order] + + vals, vecs = eigsorted(C) + theta = np.degrees(np.arctan2(*vecs[:, 0][::-1])) + w, h = 2 * nstd * np.sqrt(vals) + ell = Ellipse( + xy=(mu[0], mu[1]), + width=w, + height=h, + alpha=alpha, + angle=theta, + facecolor=color, + edgecolor=color, + label=label, + fill=True, + ) + if ax is None: + ax = plt.gca() + ax.add_artist(ell) + + +# draw a gmm as a set of ellipses with weights shown in alpha value +def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None): + for k in range(ms.shape[0]): + draw_cov( + ms[k], Cs[k], color, label if k == 0 else None, nstd, alpha * ws[k], ax=ax + ) + + +# %% +# Plot the results +fig, ax = plt.subplots(figsize=(6, 6)) +axis = [-4, 4, -2, 6] +ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16) +for k in range(K): + draw_gmm(means_list[k], covs_list[k], w_list[k], color="C0", ax=ax) +draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax) +ax.axis(axis) +ax.axis("off") diff --git a/examples/others/plot_GMMOT_plan.py b/examples/others/plot_GMMOT_plan.py index 7742d496e..4964ddd66 100644 --- a/examples/others/plot_GMMOT_plan.py +++ b/examples/others/plot_GMMOT_plan.py @@ -16,7 +16,7 @@ """ -# Author: Eloi Tanguy <eloi.tanguy@u-paris> +# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr> # Remi Flamary <remi.flamary@polytehnique.edu> # Julie Delon <julie.delon@math.cnrs.fr> # diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py index beb675755..dc26ff3ce 100644 --- a/examples/others/plot_GMM_flow.py +++ b/examples/others/plot_GMM_flow.py @@ -10,7 +10,7 @@ """ -# Author: Eloi Tanguy <eloi.tanguy@u-paris> +# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr> # Remi Flamary <remi.flamary@polytehnique.edu> # Julie Delon <julie.delon@math.cnrs.fr> # diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py index fbc343a8a..e167b1ee4 100644 --- a/examples/others/plot_SSNB.py +++ b/examples/others/plot_SSNB.py @@ -38,7 +38,7 @@ 2017. """ -# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr> +# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr> # License: MIT License # sphinx_gallery_thumbnail_number = 3 diff --git a/ot/gmm.py b/ot/gmm.py index 5c7a4c287..a065c73b0 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -3,8 +3,8 @@ Optimal transport for Gaussian Mixtures """ -# Author: Eloi Tanguy <eloi.tanguy@u-paris> -# Remi Flamary <remi.flamary@polytehnique.edu> +# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr> +# Remi Flamary <remi.flamary@polytechnique.edu> # Julie Delon <julie.delon@math.cnrs.fr> # # License: MIT License @@ -13,7 +13,7 @@ from .lp import emd2, emd import numpy as np from .utils import dist -from .gaussian import bures_wasserstein_mapping +from .gaussian import bures_wasserstein_mapping, bures_wasserstein_barycenter def gaussian_logpdf(x, m, C): @@ -440,3 +440,148 @@ def Tk0k1(k0, k1): ] ) return nx.sum(mat, axis=(0, 1)) + + +def gmm_barycenter_fixed_point( + means_list, + covs_list, + w_list, + means_init, + covs_init, + weights, + w_bar=None, + iterations=100, + log=False, + barycentric_proj_method="euclidean", +): + r""" + Solves the Gaussian Mixture Model OT barycenter problem (defined in [69]) + using the fixed point algorithm (proposed in [76]). The + weights of the barycenter are not optimized, and stay the same as the input + `w_list` or are initialized to uniform. + + The algorithm uses barycentric projections of GMM-OT plans, and these can be + computed either through Bures Barycenters (slow but accurate, + barycentric_proj_method='bures') or by convex combination (fast, + barycentric_proj_method='euclidean', default). + + This is a special case of the generic free-support barycenter solver + `ot.lp.free_support_barycenter_generic_costs`. + + Parameters + ---------- + means_list : list of array-like + List of K (m_k, d) GMM means. + covs_list : list of array-like + List of K (m_k, d, d) GMM covariances. + w_list : list of array-like + List of K (m_k) arrays of weights. + means_init : array-like + Initial (n, d) GMM means. + covs_init : array-like + Initial (n, d, d) GMM covariances. + weights : array-like + Array (K,) of the barycentre coefficients. + w_bar : array-like, optional + Initial weights (n) of the barycentre GMM. If None, initialized to uniform. + iterations : int, optional + Number of iterations (default is 100). + log : bool, optional + Whether to return the list of iterations (default is False). + barycentric_proj_method : str, optional + Method to project the barycentre weights: 'euclidean' (default) or 'bures'. + + Returns + ------- + means : array-like + (n, d) barycentre GMM means. + covs : array-like + (n, d, d) barycentre GMM covariances. + log_dict : dict, optional + Dictionary containing the list of iterations if log is True. + + References + ---------- + .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. + + .. [76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) + + See Also + -------- + ot.lp.free_support_barycenter_generic_costs : Compute barycenter of measures for generic transport costs. + """ + nx = get_backend( + means_init, covs_init, means_list[0], covs_list[0], w_list[0], weights + ) + K = len(means_list) + n = means_init.shape[0] + d = means_init.shape[1] + means_its = [nx.copy(means_init)] + covs_its = [nx.copy(covs_init)] + means, covs = means_init, covs_init + + if w_bar is None: + w_bar = nx.ones(n, type_as=means) / n + + for _ in range(iterations): + pi_list = [ + gmm_ot_plan(means, means_list[k], covs, covs_list[k], w_bar, w_list[k]) + for k in range(K) + ] + + # filled in the euclidean case + means_selection, covs_selection = None, None + + # in the euclidean case, the selection of Gaussians from each K sources + # comes from a barycentric projection: it is a convex combination of the + # selected means and covariances, which can be computed without a + # for loop on i = 0, ..., n -1 + if barycentric_proj_method == "euclidean": + means_selection = nx.zeros((n, K, d), type_as=means) + covs_selection = nx.zeros((n, K, d, d), type_as=means) + for k in range(K): + means_selection[:, k, :] = n * pi_list[k] @ means_list[k] + covs_selection[:, k, :, :] = ( + nx.einsum("ij,jab->iab", pi_list[k], covs_list[k]) * n + ) + + # each component i of the barycentre will be a Bures barycentre of the + # selected components of the K GMMs. In the 'bures' barycentric + # projection option, the selected components are also Bures barycentres. + for i in range(n): + # means_selection_i (K, d) is the selected means, each comes from a + # Gaussian barycentre along the disintegration of pi_k at i + # covs_selection_i (K, d, d) are the selected covariances + means_selection_i = None + covs_selection_i = None + + # use previous computation (convex combination) + if barycentric_proj_method == "euclidean": + means_selection_i = means_selection[i] + covs_selection_i = covs_selection[i] + + # compute Bures barycentre of certain components to get the + # selection at i + elif barycentric_proj_method == "bures": + means_selection_i = nx.zeros((K, d), type_as=means) + covs_selection_i = nx.zeros((K, d, d), type_as=means) + for k in range(K): + w = (1 / w_bar[i]) * pi_list[k][i, :] + m, C = bures_wasserstein_barycenter(means_list[k], covs_list[k], w) + means_selection_i[k] = m + covs_selection_i[k] = C + + else: + raise ValueError("Unknown barycentric_proj_method") + + means[i], covs[i] = bures_wasserstein_barycenter( + means_selection_i, covs_selection_i, weights + ) + + if log: + means_its.append(nx.copy(means)) + covs_its.append(nx.copy(covs)) + + if log: + return means, covs, {"means_its": means_its, "covs_its": covs_its} + return means, covs diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 932b261df..974679440 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -14,6 +14,7 @@ barycenter, free_support_barycenter, generalized_free_support_barycenter, + free_support_barycenter_generic_costs, ) from ..utils import check_number_threads @@ -45,4 +46,5 @@ "dmmot_monge_1dgrid_loss", "dmmot_monge_1dgrid_optimize", "check_number_threads", + "free_support_barycenter_generic_costs", ] diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 4779662e9..725af26c4 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -199,14 +199,12 @@ def free_support_barycenter( measures_weights : list of N (k_i,) array-like Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one representing the weights of each discrete input measure - X_init : (k,d) array-like Initialization of the support locations (on `k` atoms) of the barycenter b : (k,) array-like Initialization of the weights of the barycenter (non-negatives, sum to 1) weights : (N,) array-like Initialization of the coefficients of the barycenter (non-negatives, sum to 1) - numItermax : int, optional Max number of iterations stopThr : float, optional @@ -219,13 +217,11 @@ def free_support_barycenter( If compiled with OpenMP, chooses the number of threads to parallelize. "max" selects the highest number possible. - Returns ------- X : (k,d) array-like Support locations (on k atoms) of the barycenter - .. _references-free-support-barycenter: References ---------- @@ -426,3 +422,219 @@ def generalized_free_support_barycenter( return Y, log_dict else: return Y + + +def free_support_barycenter_generic_costs( + measure_locations, + measure_weights, + X_init, + cost_list, + ground_bary=None, + a=None, + numItermax=100, + stopThr=1e-5, + log=False, + ground_bary_lr=1e-2, + ground_bary_numItermax=100, + ground_bary_stopThr=1e-5, + ground_bary_solver="SGD", +): + r""" + Solves the OT barycenter problem for generic costs using the fixed point + algorithm, iterating the ground barycenter function B on transport plans + between the current barycenter and the measures. + + The problem finds an optimal barycenter support `X` of given size (n, d) + (enforced by the initialisation), minimising a sum of pairwise transport + costs for the costs :math:`c_k`: + + .. math:: + \min_{X} \sum_{k=1}^K \mathcal{T}_{c_k}(X, a, Y_k, b_k), + + where: + + - :math:`X` (n, d) is the barycenter support, + - :math:`a` (n) is the (fixed) barycenter weights, + - :math:`Y_k` (m_k, d_k) is the k-th measure support + (`measure_locations[k]`), + - :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`), + - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} + \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function + (which computes the pairwise cost matrix) + - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycenter measure and the k-th measure with respect to the cost :math:`c_k`: + + .. math:: + \mathcal{T}_{c_k}(X, a, Y_k, b_k) = \min_\pi \quad \langle \pi, c_k(X, Y_k) \rangle_F + + s.t. \ \pi \mathbf{1} = \mathbf{a} + + \pi^T \mathbf{1} = \mathbf{b_k} + + \pi \geq 0 + + in other words, :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is `ot.emd2(a, b_k, + c_k(X, Y_k))`. + + The algorithm requires a given ground barycenter function `B` which computes + (broadcasted of `n`) solutions of the following minimisation problem given + :math:`(Y_1, \cdots, Y_K) \in \mathbb{R}^{n\times + d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`: + + .. math:: + B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k), + + where :math:`c_k(x, y_k) \in \mathbb{R}_+` is the cost between the points + :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{n\times + d_1}\times \cdots\times\mathbb{R}^{n\times d_K} \longrightarrow + \mathbb{R}^{n\times d}` is an input to this function, and for certain costs + it can be computed explicitly of through a numerical solver. The input + function B takes a list of K arrays of shape (n, d_k) and returns an array + of shape (n, d). + + This function implements [76] Algorithm 2, which generalises [20] and [43] + to general costs and includes convergence guarantees, including for discrete + measures. + + Parameters + ---------- + measure_locations : list of array-like + List of K arrays of measure positions, each of shape (m_k, d_k). + measure_weights : list of array-like + List of K arrays of measure weights, each of shape (m_k). + X_init : array-like + Array of shape (n, d) representing initial barycenter points. + cost_list : list of callable or callable + List of K cost functions :math:`c_k: \mathbb{R}^{n\times + d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times + m_k}`. If cost_list is a single callable, the same cost is used K times. + ground_bary : callable or None, optional + Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays + of shape (n\times d_K), computing the ground barycenters (broadcasted + over n). If not provided, done with Adam on PyTorch (requires PyTorch + backend) + a : array-like, optional + Array of shape (n,) representing weights of the barycenter + measure.Defaults to uniform. + numItermax : int, optional + Maximum number of iterations (default is 100). + stopThr : float, optional + If the iterations move less than this, terminate (default is 1e-5). + log : bool, optional + Whether to return the log dictionary (default is False). + ground_bary_lr : float, optional + Learning rate for the ground barycenter solver (if auto is used). + ground_bary_numItermax : int, optional + Maximum number of iterations for the ground barycenter solver (if auto + is used). + ground_bary_stopThr : float, optional + Stop threshold for the ground barycenter solver (if auto is used). + ground_bary_solver : str, optional + Solver for auto ground bary solver (torch SGD or Adam). Default is + "SGD". + + Returns + ------- + X : array-like + Array of shape (n, d) representing barycenter points. + log_dict : list of array-like, optional + log containing the exit status, list of iterations and list of + displacements if log is True. + + References + ---------- + .. [76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing + barycenters of Measures for Generic Transport Costs. arXiv preprint + 2501.04016 (2024) + + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein + barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to + barycenters in Wasserstein space." Journal of Mathematical Analysis and + Applications 441.2 (2016): 744-762. + + See Also + -------- + ot.lp.free_support_barycenter : Free support solver for the case where + :math:`c_k(x,y) = \lambda_k\|x-y\|_2^2`. + ot.lp.generalized_free_support_barycenter : Free support solver for the case + where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear. + """ + nx = get_backend(X_init, measure_locations[0]) + K = len(measure_locations) + n = X_init.shape[0] + if a is None: + a = nx.ones(n, type_as=X_init) / n + if callable(cost_list): # use the given cost for all K pairs + cost_list = [cost_list] * K + auto_ground_bary = False + + if ground_bary is None: + auto_ground_bary = True + assert str(nx) == "torch", ( + f"Backend {str(nx)} is not compatible with ground_bary=None, it" + "must be provided if not using PyTorch backend" + ) + try: + import torch + from torch.optim import Adam, SGD + + def ground_bary(y, x_init): + x = x_init.clone().detach().requires_grad_(True) + solver = Adam if ground_bary_solver == "Adam" else SGD + opt = solver([x], lr=ground_bary_lr) + for _ in range(ground_bary_numItermax): + x_prev = x.data.clone() + opt.zero_grad() + # inefficient cost computation but compatible + # with the choice of cost_list[k] giving the cost matrix + loss = torch.sum( + torch.stack( + [torch.diag(cost_list[k](x, y[k])) for k in range(K)] + ) + ) + loss.backward() + opt.step() + diff = torch.sum((x.data - x_prev) ** 2) + if diff < ground_bary_stopThr: + break + return x.detach() + + except ImportError: + raise ImportError("PyTorch is required to use ground_bary=None") + + X_list = [X_init] if log else [] # store the iterations + X = X_init + dX_list = [] # store the displacement squared norms + exit_status = "Max iterations reached" + + for _ in range(numItermax): + pi_list = [ # compute the pairwise transport plans + emd(a, measure_weights[k], cost_list[k](X, measure_locations[k])) + for k in range(K) + ] + Y_perm = [] + for k in range(K): # compute barycentric projections + Y_perm.append(n * pi_list[k] @ measure_locations[k]) + if auto_ground_bary: # use previous position as initialization + X_next = ground_bary(Y_perm, X) + else: + X_next = ground_bary(Y_perm) + + if log: + X_list.append(X_next) + + # stationary criterion: move less than the threshold + dX = nx.sum((X - X_next) ** 2) + X = X_next + + if log: + dX_list.append(dX) + + if dX < stopThr: + exit_status = "Stationary Point" + break + + if log: + return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list} + return X diff --git a/ot/mapping.py b/ot/mapping.py index ea1917772..cc3e6cd57 100644 --- a/ot/mapping.py +++ b/ot/mapping.py @@ -7,7 +7,7 @@ use it you need to explicitly import :mod:`ot.mapping` """ -# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr> +# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr> # Remi Flamary <remi.flamary@unice.fr> # # License: MIT License diff --git a/test/test_gmm.py b/test/test_gmm.py index 5f1a92965..629a68d57 100644 --- a/test/test_gmm.py +++ b/test/test_gmm.py @@ -1,6 +1,6 @@ """Tests for module gaussian""" -# Author: Eloi Tanguy <eloi.tanguy@u-paris> +# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr> # Remi Flamary <remi.flamary@polytehnique.edu> # Julie Delon <julie.delon@math.cnrs.fr> # @@ -17,6 +17,7 @@ gmm_ot_plan, gmm_ot_apply_map, gmm_ot_plan_density, + gmm_barycenter_fixed_point, ) try: @@ -193,3 +194,54 @@ def test_gmm_ot_plan_density(nx): with pytest.raises(AssertionError): gmm_ot_plan_density(x[:, 1:], y, m_s, m_t, C_s, C_t, w_s, w_t) + + +@pytest.skip_backend("tf") # skips because of array assignment +@pytest.skip_backend("jax") +def test_gmm_barycenter_fixed_point(nx): + m_s, m_t, C_s, C_t, w_s, w_t = get_gmms(nx) + means_list = [m_s, m_t] + covs_list = [C_s, C_t] + w_list = [w_s, w_t] + n_iter = 3 + n = m_s.shape[0] # number of components of barycenter + means_init = m_s + covs_init = C_s + weights = nx.ones(2, type_as=m_s) / 2 # barycenter coefficients + + # with euclidean barycentric projections + means, covs = gmm_barycenter_fixed_point( + means_list, covs_list, w_list, means_init, covs_init, weights, iterations=n_iter + ) + + # with bures barycentric projections and assigned weights to uniform + means_bures_proj, covs_bures_proj, log = gmm_barycenter_fixed_point( + means_list, + covs_list, + w_list, + means_init, + covs_init, + weights, + iterations=n_iter, + w_bar=nx.ones(n, type_as=m_s) / n, + barycentric_proj_method="bures", + log=True, + ) + + assert "means_its" in log + assert "covs_its" in log + + assert np.allclose(means, means_bures_proj, atol=1e-6) + assert np.allclose(covs, covs_bures_proj, atol=1e-6) + + with pytest.raises(ValueError): + gmm_barycenter_fixed_point( + means_list, + covs_list, + w_list, + means_init, + covs_init, + weights, + iterations=n_iter, + barycentric_proj_method="unknown", + ) diff --git a/test/test_ot.py b/test/test_ot.py index f84f8773a..22612fa4a 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -395,6 +395,182 @@ def test_generalised_free_support_barycenter_backends(nx): np.testing.assert_allclose(Y, nx.to_numpy(Y2)) +def test_free_support_barycenter_generic_costs(): + measures_locations = [ + np.array([-1.0]).reshape((1, 1)), + np.array([1.0]).reshape((1, 1)), + ] + measures_weights = [np.array([1.0]), np.array([1.0])] + + X_init = np.array([-12.0]).reshape((1, 1)) + + # obvious barycenter location between two Diracs + bar_locations = np.array([0.0]).reshape((1, 1)) + + def cost(x, y): + return ot.dist(x, y) + + cost_list = [cost, cost] + + def ground_bary(y): + out = 0 + for yk in y: + out += yk / len(y) + return out + + X = ot.lp.free_support_barycenter_generic_costs( + measures_locations, measures_weights, X_init, cost_list, ground_bary + ) + + np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) + + # test with log and specific weights + X2, log = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary, + a=ot.unif(1), + log=True, + ) + + assert "X_list" in log + assert "exit_status" in log + assert "dX_list" in log + + np.testing.assert_allclose(X, X2, rtol=1e-5, atol=1e-7) + + # test with one iteration for Max Iterations Reached + X3, log2 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary, + numItermax=1, + log=True, + ) + assert log2["exit_status"] == "Max iterations reached" + + # test with a single callable cost + X3, log3 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost, + ground_bary, + numItermax=1, + log=True, + ) + + # test with no ground_bary but in numpy: requires pytorch backend + with pytest.raises(AssertionError): + ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary=None, + numItermax=1, + ) + + +@pytest.mark.skipif(not torch, reason="No torch available") +def test_free_support_barycenter_generic_costs_auto_ground_bary(): + measures_locations = [ + torch.tensor([1.0]).reshape((1, 1)), + torch.tensor([2.0]).reshape((1, 1)), + ] + measures_weights = [torch.tensor([1.0]), torch.tensor([1.0])] + + X_init = torch.tensor([1.2]).reshape((1, 1)) + + def cost(x, y): + return ot.dist(x, y) + + cost_list = [cost, cost] + + def ground_bary(y): + out = 0 + for yk in y: + out += yk / len(y) + return out + + X = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary, + numItermax=1, + ) + + X2, log2 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary=None, + ground_bary_lr=1e-2, + ground_bary_stopThr=1e-20, + ground_bary_numItermax=50, + numItermax=10, + log=True, + ) + + np.testing.assert_allclose(X2.numpy(), X.numpy(), rtol=1e-4, atol=1e-4) + + X3 = ot.lp.free_support_barycenter_generic_costs( + measures_locations, + measures_weights, + X_init, + cost_list, + ground_bary=None, + ground_bary_lr=1e-2, + ground_bary_stopThr=1e-20, + ground_bary_numItermax=50, + numItermax=10, + ground_bary_solver="Adam", + ) + + np.testing.assert_allclose(X2.numpy(), X3.numpy(), rtol=1e-3, atol=1e-3) + + +def test_free_support_barycenter_generic_costs_backends(nx): + measures_locations = [ + np.array([-1.0]).reshape((1, 1)), + np.array([1.0]).reshape((1, 1)), + ] + measures_weights = [np.array([1.0]), np.array([1.0])] + X_init = np.array([-12.0]).reshape((1, 1)) + + def cost(x, y): + return ot.dist(x, y) + + cost_list = [cost, cost] + + def ground_bary(y): + out = 0 + for yk in y: + out += yk / len(y) + return out + + X = ot.lp.free_support_barycenter_generic_costs( + measures_locations, measures_weights, X_init, cost_list, ground_bary + ) + + measures_locations2 = nx.from_numpy(*measures_locations) + measures_weights2 = nx.from_numpy(*measures_weights) + X_init2 = nx.from_numpy(X_init) + + X2 = ot.lp.free_support_barycenter_generic_costs( + measures_locations2, measures_weights2, X_init2, cost_list, ground_bary + ) + + np.testing.assert_allclose(X, nx.to_numpy(X2)) + + @pytest.mark.skipif(not ot.lp._barycenter_solvers.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None]