From 590e4d714f73d6596cd14614c93b1c15e7426c51 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 11:43:10 +0100
Subject: [PATCH 01/23] ot.lp reorganise to avoid def in __init__

---
 CONTRIBUTORS.md          |   2 +-
 RELEASES.md              |   2 +
 ot/lp/__init__.py        | 876 +--------------------------------------
 ot/lp/barycenter.py      | 266 ++++++++++++
 ot/lp/network_simplex.py | 612 +++++++++++++++++++++++++++
 5 files changed, 887 insertions(+), 871 deletions(-)
 create mode 100644 ot/lp/barycenter.py
 create mode 100644 ot/lp/network_simplex.py

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 39f0b23d4..6f6a72737 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -48,7 +48,7 @@ The contributors to this library are:
 * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
 * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
 * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
-* [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein)
+* [Clément Bonet](https://clbonet.github.io) (Wasserstein on circle, Spherical Sliced-Wasserstein)
 * [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
 * [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
 * [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
diff --git a/RELEASES.md b/RELEASES.md
index 0ddac599b..e29be544e 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -6,6 +6,8 @@
 - Implement CG solvers for partial FGW (PR #687)
 - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
 - Automatic PR labeling and release file update check (PR #704)
+- Implement fixed-point solver for OT barycenters with generic cost functions
+  (generalizes `ot.lp.free_support_barycenter`). (PR #???)
 
 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 2b93e84f3..d11a5ee41 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -8,15 +8,17 @@
 #
 # License: MIT License
 
-import numpy as np
-import warnings
-
 from . import cvx
 from .cvx import barycenter
 from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize
+from .network_simplex import emd, emd2
+from .barycenter import (
+    free_support_barycenter, 
+    generalized_free_support_barycenter
+)
 
 # import compiled emd
-from .emd_wrap import emd_c, check_result, emd_1d_sorted
+from .emd_wrap import emd_1d_sorted
 from .solver_1d import (
     emd_1d,
     emd2_1d,
@@ -26,9 +28,6 @@
     semidiscrete_wasserstein2_unif_circle,
 )
 
-from ..utils import dist, list_to_array
-from ..backend import get_backend
-
 __all__ = [
     "emd",
     "emd2",
@@ -46,866 +45,3 @@
     "dmmot_monge_1dgrid_loss",
     "dmmot_monge_1dgrid_optimize",
 ]
-
-
-def check_number_threads(numThreads):
-    """Checks whether or not the requested number of threads has a valid value.
-
-    Parameters
-    ----------
-    numThreads : int or str
-        The requested number of threads, should either be a strictly positive integer or "max" or None
-
-    Returns
-    -------
-    numThreads : int
-        Corrected number of threads
-    """
-    if (numThreads is None) or (
-        isinstance(numThreads, str) and numThreads.lower() == "max"
-    ):
-        return -1
-    if (not isinstance(numThreads, int)) or numThreads < 1:
-        raise ValueError(
-            'numThreads should either be "max" or a strictly positive integer'
-        )
-    return numThreads
-
-
-def center_ot_dual(alpha0, beta0, a=None, b=None):
-    r"""Center dual OT potentials w.r.t. their weights
-
-    The main idea of this function is to find unique dual potentials
-    that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). It will help having
-    stability when multiple calling of the OT solver with small changes.
-
-    Basically we add another constraint to the potential that will not
-    change the objective value but will ensure unicity. The constraint
-    is the following:
-
-    .. math::
-        \alpha^T \mathbf{a} = \beta^T \mathbf{b}
-
-    in addition to the OT problem constraints.
-
-    since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing
-    a constant from both  :math:`\alpha_0` and :math:`\beta_0`.
-
-    .. math::
-        c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}}
-
-        \alpha &= \alpha_0 + c
-
-        \beta &= \beta_0 + c
-
-    Parameters
-    ----------
-    alpha0 : (ns,) numpy.ndarray, float64
-        Source dual potential
-    beta0 : (nt,) numpy.ndarray, float64
-        Target dual potential
-    a : (ns,) numpy.ndarray, float64
-        Source histogram (uniform weight if empty list)
-    b : (nt,) numpy.ndarray, float64
-        Target histogram (uniform weight if empty list)
-
-    Returns
-    -------
-    alpha : (ns,) numpy.ndarray, float64
-        Source centered dual potential
-    beta : (nt,) numpy.ndarray, float64
-        Target centered dual potential
-
-    """
-    # if no weights are provided, use uniform
-    if a is None:
-        a = np.ones(alpha0.shape[0]) / alpha0.shape[0]
-    if b is None:
-        b = np.ones(beta0.shape[0]) / beta0.shape[0]
-
-    # compute constant that balances the weighted sums of the duals
-    c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum())
-
-    # update duals
-    alpha = alpha0 + c
-    beta = beta0 - c
-
-    return alpha, beta
-
-
-def estimate_dual_null_weights(alpha0, beta0, a, b, M):
-    r"""Estimate feasible values for 0-weighted dual potentials
-
-    The feasible values are computed efficiently but rather coarsely.
-
-    .. warning::
-        This function is necessary because the C++ solver in `emd_c`
-        discards all samples in the distributions with
-        zeros weights. This means that while the primal variable (transport
-        matrix) is exact, the solver only returns feasible dual potentials
-        on the samples with weights different from zero.
-
-    First we compute the constraints violations:
-
-    .. math::
-        \mathbf{V} = \alpha + \beta^T - \mathbf{M}
-
-    Next we compute the max amount of violation per row (:math:`\alpha`) and
-    columns (:math:`beta`)
-
-    .. math::
-        \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j}
-
-        \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j}
-
-    Finally we update the dual potential with 0 weights if a
-    constraint is violated
-
-    .. math::
-        \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0
-
-        \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0
-
-    In the end the dual potentials are centered using function
-    :py:func:`ot.lp.center_ot_dual`.
-
-    Note that all those updates do not change the objective value of the
-    solution but provide dual potentials that do not violate the constraints.
-
-    Parameters
-    ----------
-    alpha0 : (ns,) numpy.ndarray, float64
-        Source dual potential
-    beta0 : (nt,) numpy.ndarray, float64
-        Target dual potential
-    alpha0 : (ns,) numpy.ndarray, float64
-        Source dual potential
-    beta0 : (nt,) numpy.ndarray, float64
-        Target dual potential
-    a : (ns,) numpy.ndarray, float64
-        Source distribution (uniform weights if empty list)
-    b : (nt,) numpy.ndarray, float64
-        Target distribution (uniform weights if empty list)
-    M : (ns,nt) numpy.ndarray, float64
-        Loss matrix (c-order array with type float64)
-
-    Returns
-    -------
-    alpha : (ns,) numpy.ndarray, float64
-        Source corrected dual potential
-    beta : (nt,) numpy.ndarray, float64
-        Target corrected dual potential
-
-    """
-
-    # binary indexing of non-zeros weights
-    asel = a != 0
-    bsel = b != 0
-
-    # compute dual constraints violation
-    constraint_violation = alpha0[:, None] + beta0[None, :] - M
-
-    # Compute largest violation per line and columns
-    aviol = np.max(constraint_violation, 1)
-    bviol = np.max(constraint_violation, 0)
-
-    # update corrects violation of
-    alpha_up = -1 * ~asel * np.maximum(aviol, 0)
-    beta_up = -1 * ~bsel * np.maximum(bviol, 0)
-
-    alpha = alpha0 + alpha_up
-    beta = beta0 + beta_up
-
-    return center_ot_dual(alpha, beta, a, b)
-
-
-def emd(
-    a,
-    b,
-    M,
-    numItermax=100000,
-    log=False,
-    center_dual=True,
-    numThreads=1,
-    check_marginals=True,
-):
-    r"""Solves the Earth Movers distance problem and returns the OT matrix
-
-
-    .. math::
-        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
-
-        s.t. \ \gamma \mathbf{1} = \mathbf{a}
-
-             \gamma^T \mathbf{1} = \mathbf{b}
-
-             \gamma \geq 0
-
-    where :
-
-    - :math:`\mathbf{M}` is the metric cost matrix
-    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
-
-    .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order
-        numpy.array in float64 format. It will be converted if not in this
-        format
-
-    .. note:: This function is backend-compatible and will work on arrays
-        from all compatible backends. But the algorithm uses the C++ CPU backend
-        which can lead to copy overhead on GPU arrays.
-
-    .. note:: This function will cast the computed transport plan to the data type
-        of the provided input with the following priority: :math:`\mathbf{a}`,
-        then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided.
-        Casting to an integer tensor might result in a loss of precision.
-        If this behaviour is unwanted, please make sure to provide a
-        floating point input.
-
-    .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
-
-    Uses the algorithm proposed in :ref:`[1] <references-emd>`.
-
-    Parameters
-    ----------
-    a : (ns,) array-like, float
-        Source histogram (uniform weight if empty list)
-    b : (nt,) array-like, float
-        Target histogram (uniform weight if empty list)
-    M : (ns,nt) array-like, float
-        Loss matrix (c-order array in numpy with type float64)
-    numItermax : int, optional (default=100000)
-        The maximum number of iterations before stopping the optimization
-        algorithm if it has not converged.
-    log: bool, optional (default=False)
-        If True, returns a dictionary containing the cost and dual variables.
-        Otherwise returns only the optimal transportation matrix.
-    center_dual: boolean, optional (default=True)
-        If True, centers the dual potential using function
-        :py:func:`ot.lp.center_ot_dual`.
-    numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
-        If compiled with OpenMP, chooses the number of threads to parallelize.
-        "max" selects the highest number possible.
-    check_marginals: bool, optional (default=True)
-        If True, checks that the marginals mass are equal. If False, skips the
-        check.
-
-
-    Returns
-    -------
-    gamma: array-like, shape (ns, nt)
-        Optimal transportation matrix for the given
-        parameters
-    log: dict, optional
-        If input log is true, a dictionary containing the
-        cost and dual variables and exit status
-
-
-    Examples
-    --------
-
-    Simple example with obvious solution. The function emd accepts lists and
-    perform automatic conversion to numpy arrays
-
-    >>> import ot
-    >>> a=[.5,.5]
-    >>> b=[.5,.5]
-    >>> M=[[0.,1.],[1.,0.]]
-    >>> ot.emd(a, b, M)
-    array([[0.5, 0. ],
-           [0. , 0.5]])
-
-
-    .. _references-emd:
-    References
-    ----------
-    .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
-        December).  Displacement interpolation using Lagrangian mass transport.
-        In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
-
-    See Also
-    --------
-    ot.bregman.sinkhorn : Entropic regularized OT
-    ot.optim.cg : General regularized OT
-    """
-
-    a, b, M = list_to_array(a, b, M)
-    nx = get_backend(M, a, b)
-
-    if len(a) != 0:
-        type_as = a
-    elif len(b) != 0:
-        type_as = b
-    else:
-        type_as = M
-
-    # if empty array given then use uniform distributions
-    if len(a) == 0:
-        a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
-    if len(b) == 0:
-        b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]
-
-    # convert to numpy
-    M, a, b = nx.to_numpy(M, a, b)
-
-    # ensure float64
-    a = np.asarray(a, dtype=np.float64)
-    b = np.asarray(b, dtype=np.float64)
-    M = np.asarray(M, dtype=np.float64, order="C")
-
-    # if empty array given then use uniform distributions
-    if len(a) == 0:
-        a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
-    if len(b) == 0:
-        b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
-
-    assert (
-        a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]
-    ), "Dimension mismatch, check dimensions of M with a and b"
-
-    # ensure that same mass
-    if check_marginals:
-        np.testing.assert_almost_equal(
-            a.sum(0),
-            b.sum(0),
-            err_msg="a and b vector must have the same sum",
-            decimal=6,
-        )
-    b = b * a.sum() / b.sum()
-
-    asel = a != 0
-    bsel = b != 0
-
-    numThreads = check_number_threads(numThreads)
-
-    G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
-
-    if center_dual:
-        u, v = center_ot_dual(u, v, a, b)
-
-    if np.any(~asel) or np.any(~bsel):
-        u, v = estimate_dual_null_weights(u, v, a, b, M)
-
-    result_code_string = check_result(result_code)
-    if not nx.is_floating_point(type_as):
-        warnings.warn(
-            "Input histogram consists of integer. The transport plan will be "
-            "casted accordingly, possibly resulting in a loss of precision. "
-            "If this behaviour is unwanted, please make sure your input "
-            "histogram consists of floating point elements.",
-            stacklevel=2,
-        )
-    if log:
-        log = {}
-        log["cost"] = cost
-        log["u"] = nx.from_numpy(u, type_as=type_as)
-        log["v"] = nx.from_numpy(v, type_as=type_as)
-        log["warning"] = result_code_string
-        log["result_code"] = result_code
-        return nx.from_numpy(G, type_as=type_as), log
-    return nx.from_numpy(G, type_as=type_as)
-
-
-def emd2(
-    a,
-    b,
-    M,
-    processes=1,
-    numItermax=100000,
-    log=False,
-    return_matrix=False,
-    center_dual=True,
-    numThreads=1,
-    check_marginals=True,
-):
-    r"""Solves the Earth Movers distance problem and returns the loss
-
-    .. math::
-        \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
-
-        s.t. \ \gamma \mathbf{1} = \mathbf{a}
-
-             \gamma^T \mathbf{1} = \mathbf{b}
-
-             \gamma \geq 0
-
-    where :
-
-    - :math:`\mathbf{M}` is the metric cost matrix
-    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
-
-    .. note:: This function is backend-compatible and will work on arrays
-        from all compatible backends. But the algorithm uses the C++ CPU backend
-        which can lead to copy overhead on GPU arrays.
-
-    .. note:: This function will cast the computed transport plan and
-        transportation loss to the data type of the provided input with the
-        following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`,
-        then :math:`\mathbf{M}` if marginals are not provided.
-        Casting to an integer tensor might result in a loss of precision.
-        If this behaviour is unwanted, please make sure to provide a
-        floating point input.
-
-    .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
-
-    Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
-
-    Parameters
-    ----------
-    a : (ns,) array-like, float64
-        Source histogram (uniform weight if empty list)
-    b : (nt,) array-like, float64
-        Target histogram (uniform weight if empty list)
-    M : (ns,nt) array-like, float64
-        Loss matrix (for numpy c-order array with type float64)
-    processes : int, optional (default=1)
-        Nb of processes used for multiple emd computation (deprecated)
-    numItermax : int, optional (default=100000)
-        The maximum number of iterations before stopping the optimization
-        algorithm if it has not converged.
-    log: boolean, optional (default=False)
-        If True, returns a dictionary containing dual
-        variables. Otherwise returns only the optimal transportation cost.
-    return_matrix: boolean, optional (default=False)
-        If True, returns the optimal transportation matrix in the log.
-    center_dual: boolean, optional (default=True)
-        If True, centers the dual potential using function
-        :py:func:`ot.lp.center_ot_dual`.
-    numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
-        If compiled with OpenMP, chooses the number of threads to parallelize.
-        "max" selects the highest number possible.
-    check_marginals: bool, optional (default=True)
-        If True, checks that the marginals mass are equal. If False, skips the
-        check.
-
-
-    Returns
-    -------
-    W: float, array-like
-        Optimal transportation loss for the given parameters
-    log: dict
-        If input log is true, a dictionary containing dual
-        variables and exit status
-
-
-    Examples
-    --------
-
-    Simple example with obvious solution. The function emd accepts lists and
-    perform automatic conversion to numpy arrays
-
-
-    >>> import ot
-    >>> a=[.5,.5]
-    >>> b=[.5,.5]
-    >>> M=[[0.,1.],[1.,0.]]
-    >>> ot.emd2(a,b,M)
-    0.0
-
-
-    .. _references-emd2:
-    References
-    ----------
-    .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
-        (2011, December).  Displacement interpolation using Lagrangian mass
-        transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
-        158). ACM.
-
-    See Also
-    --------
-    ot.bregman.sinkhorn : Entropic regularized OT
-    ot.optim.cg : General regularized OT
-    """
-
-    a, b, M = list_to_array(a, b, M)
-    nx = get_backend(M, a, b)
-
-    if len(a) != 0:
-        type_as = a
-    elif len(b) != 0:
-        type_as = b
-    else:
-        type_as = M
-
-    # if empty array given then use uniform distributions
-    if len(a) == 0:
-        a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
-    if len(b) == 0:
-        b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]
-
-    # store original tensors
-    a0, b0, M0 = a, b, M
-
-    # convert to numpy
-    M, a, b = nx.to_numpy(M, a, b)
-
-    a = np.asarray(a, dtype=np.float64)
-    b = np.asarray(b, dtype=np.float64)
-    M = np.asarray(M, dtype=np.float64, order="C")
-
-    assert (
-        a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]
-    ), "Dimension mismatch, check dimensions of M with a and b"
-
-    # ensure that same mass
-    if check_marginals:
-        np.testing.assert_almost_equal(
-            a.sum(0),
-            b.sum(0, keepdims=True),
-            err_msg="a and b vector must have the same sum",
-            decimal=6,
-        )
-    b = b * a.sum(0) / b.sum(0, keepdims=True)
-
-    asel = a != 0
-
-    numThreads = check_number_threads(numThreads)
-
-    if log or return_matrix:
-
-        def f(b):
-            bsel = b != 0
-
-            G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
-
-            if center_dual:
-                u, v = center_ot_dual(u, v, a, b)
-
-            if np.any(~asel) or np.any(~bsel):
-                u, v = estimate_dual_null_weights(u, v, a, b, M)
-
-            result_code_string = check_result(result_code)
-            log = {}
-            if not nx.is_floating_point(type_as):
-                warnings.warn(
-                    "Input histogram consists of integer. The transport plan will be "
-                    "casted accordingly, possibly resulting in a loss of precision. "
-                    "If this behaviour is unwanted, please make sure your input "
-                    "histogram consists of floating point elements.",
-                    stacklevel=2,
-                )
-            G = nx.from_numpy(G, type_as=type_as)
-            if return_matrix:
-                log["G"] = G
-            log["u"] = nx.from_numpy(u, type_as=type_as)
-            log["v"] = nx.from_numpy(v, type_as=type_as)
-            log["warning"] = result_code_string
-            log["result_code"] = result_code
-            cost = nx.set_gradients(
-                nx.from_numpy(cost, type_as=type_as),
-                (a0, b0, M0),
-                (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G),
-            )
-            return [cost, log]
-    else:
-
-        def f(b):
-            bsel = b != 0
-            G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
-
-            if center_dual:
-                u, v = center_ot_dual(u, v, a, b)
-
-            if np.any(~asel) or np.any(~bsel):
-                u, v = estimate_dual_null_weights(u, v, a, b, M)
-
-            if not nx.is_floating_point(type_as):
-                warnings.warn(
-                    "Input histogram consists of integer. The transport plan will be "
-                    "casted accordingly, possibly resulting in a loss of precision. "
-                    "If this behaviour is unwanted, please make sure your input "
-                    "histogram consists of floating point elements.",
-                    stacklevel=2,
-                )
-            G = nx.from_numpy(G, type_as=type_as)
-            cost = nx.set_gradients(
-                nx.from_numpy(cost, type_as=type_as),
-                (a0, b0, M0),
-                (
-                    nx.from_numpy(u - np.mean(u), type_as=type_as),
-                    nx.from_numpy(v - np.mean(v), type_as=type_as),
-                    G,
-                ),
-            )
-
-            check_result(result_code)
-            return cost
-
-    if len(b.shape) == 1:
-        return f(b)
-    nb = b.shape[1]
-
-    if processes > 1:
-        warnings.warn(
-            "The 'processes' parameter has been deprecated. "
-            "Multiprocessing should be done outside of POT."
-        )
-    res = list(map(f, [b[:, i].copy() for i in range(nb)]))
-
-    return res
-
-
-def free_support_barycenter(
-    measures_locations,
-    measures_weights,
-    X_init,
-    b=None,
-    weights=None,
-    numItermax=100,
-    stopThr=1e-7,
-    verbose=False,
-    log=None,
-    numThreads=1,
-):
-    r"""
-    Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally:
-
-    .. math::
-        \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)
-
-    where :
-
-    - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
-    - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
-    - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
-    - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
-
-    This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
-    There are two differences with the following codes:
-
-    - we do not optimize over the weights
-    - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
-      :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
-      implementation of the fixed-point algorithm of
-      :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
-
-    Parameters
-    ----------
-    measures_locations : list of N (k_i,d) array-like
-        The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
-        (:math:`k_i` can be different for each element of the list)
-    measures_weights : list of N (k_i,) array-like
-        Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
-        representing the weights of each discrete input measure
-
-    X_init : (k,d) array-like
-        Initialization of the support locations (on `k` atoms) of the barycenter
-    b : (k,) array-like
-        Initialization of the weights of the barycenter (non-negatives, sum to 1)
-    weights : (N,) array-like
-        Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
-
-    numItermax : int, optional
-        Max number of iterations
-    stopThr : float, optional
-        Stop threshold on error (>0)
-    verbose : bool, optional
-        Print information along iterations
-    log : bool, optional
-        record log if True
-    numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
-        If compiled with OpenMP, chooses the number of threads to parallelize.
-        "max" selects the highest number possible.
-
-
-    Returns
-    -------
-    X : (k,d) array-like
-        Support locations (on k atoms) of the barycenter
-
-
-    .. _references-free-support-barycenter:
-
-    References
-    ----------
-    .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
-
-    .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
-
-    """
-
-    nx = get_backend(*measures_locations, *measures_weights, X_init)
-
-    iter_count = 0
-
-    N = len(measures_locations)
-    k = X_init.shape[0]
-    d = X_init.shape[1]
-    if b is None:
-        b = nx.ones((k,), type_as=X_init) / k
-    if weights is None:
-        weights = nx.ones((N,), type_as=X_init) / N
-
-    X = X_init
-
-    log_dict = {}
-    displacement_square_norms = []
-
-    displacement_square_norm = stopThr + 1.0
-
-    while displacement_square_norm > stopThr and iter_count < numItermax:
-        T_sum = nx.zeros((k, d), type_as=X_init)
-
-        for measure_locations_i, measure_weights_i, weight_i in zip(
-            measures_locations, measures_weights, weights
-        ):
-            M_i = dist(X, measure_locations_i)
-            T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
-            T_sum = T_sum + weight_i * 1.0 / b[:, None] * nx.dot(
-                T_i, measure_locations_i
-            )
-
-        displacement_square_norm = nx.sum((T_sum - X) ** 2)
-        if log:
-            displacement_square_norms.append(displacement_square_norm)
-
-        X = T_sum
-
-        if verbose:
-            print(
-                "iteration %d, displacement_square_norm=%f\n",
-                iter_count,
-                displacement_square_norm,
-            )
-
-        iter_count += 1
-
-    if log:
-        log_dict["displacement_square_norms"] = displacement_square_norms
-        return X, log_dict
-    else:
-        return X
-
-
-def generalized_free_support_barycenter(
-    X_list,
-    a_list,
-    P_list,
-    n_samples_bary,
-    Y_init=None,
-    b=None,
-    weights=None,
-    numItermax=100,
-    stopThr=1e-7,
-    verbose=False,
-    log=None,
-    numThreads=1,
-    eps=0,
-):
-    r"""
-    Solves the free support generalized Wasserstein barycenter problem: finding a barycenter (a discrete measure with
-    a fixed amount of points of uniform weights) whose respective projections fit the input measures.
-    More formally:
-
-    .. math::
-        \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma)
-
-    where :
-
-    - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d`
-    - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter
-    - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}`
-    - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex)
-    - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations
-    - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex)
-    - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d, d_i}`, and :math:`P_i\#\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{P_ix_{i,j}}`
-
-    As show by :ref:`[42] <references-generalized-free-support-barycenter>`,
-    this problem can be re-written as a Wasserstein Barycenter problem,
-    which we solve using the free support method :ref:`[20] <references-generalized-free-support-barycenter>`
-    (Algorithm 2).
-
-    Parameters
-    ----------
-    X_list : list of p (k_i,d_i) array-like
-        Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space
-        (:math:`k_i` can be different for each element of the list)
-    a_list : list of p (k_i,) array-like
-        Measure weights: each element is a vector (k_i) on the simplex
-    P_list : list of p (d_i,d) array-like
-        Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}`
-    n_samples_bary : int
-        Number of barycenter points
-    Y_init : (n_samples_bary,d) array-like
-        Initialization of the support locations (on `k` atoms) of the barycenter
-    b : (n_samples_bary,) array-like
-        Initialization of the weights of the barycenter measure (on the simplex)
-    weights : (p,) array-like
-        Initialization of the coefficients of the barycenter (on the simplex)
-    numItermax : int, optional
-        Max number of iterations
-    stopThr : float, optional
-        Stop threshold on error (>0)
-    verbose : bool, optional
-        Print information along iterations
-    log : bool, optional
-        record log if True
-    numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
-        If compiled with OpenMP, chooses the number of threads to parallelize.
-        "max" selects the highest number possible.
-    eps: Stability coefficient for the change of variable matrix inversion
-        If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix
-        inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense)
-
-
-    Returns
-    -------
-    Y : (n_samples_bary,d) array-like
-        Support locations (on n_samples_bary atoms) of the barycenter
-
-
-    .. _references-generalized-free-support-barycenter:
-    References
-    ----------
-    .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
-
-    .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021.
-
-    """
-    nx = get_backend(*X_list, *a_list, *P_list)
-    d = P_list[0].shape[1]
-    p = len(X_list)
-
-    if weights is None:
-        weights = nx.ones(p, type_as=X_list[0]) / p
-
-    # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB)
-    A = eps * nx.eye(
-        d, type_as=X_list[0]
-    )  # if eps nonzero: will force the invertibility of A
-    for P_i, lambda_i in zip(P_list, weights):
-        A = A + lambda_i * P_i.T @ P_i
-    B = nx.inv(nx.sqrtm(A))
-
-    Z_list = [
-        x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list)
-    ]  # change of variables -> (WB) problem on Z
-
-    if Y_init is None:
-        Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0])
-
-    if b is None:
-        b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary  # not optimized
-
-    out = free_support_barycenter(
-        Z_list,
-        a_list,
-        Y_init,
-        b,
-        numItermax=numItermax,
-        stopThr=stopThr,
-        verbose=verbose,
-        log=log,
-        numThreads=numThreads,
-    )
-
-    if log:  # unpack
-        Y, log_dict = out
-    else:
-        Y = out
-        log_dict = None
-    Y = Y @ B.T  # return to the Generalized WB formulation
-
-    if log:
-        return Y, log_dict
-    else:
-        return Y
diff --git a/ot/lp/barycenter.py b/ot/lp/barycenter.py
new file mode 100644
index 000000000..5468fb4eb
--- /dev/null
+++ b/ot/lp/barycenter.py
@@ -0,0 +1,266 @@
+
+def free_support_barycenter(
+    measures_locations,
+    measures_weights,
+    X_init,
+    b=None,
+    weights=None,
+    numItermax=100,
+    stopThr=1e-7,
+    verbose=False,
+    log=None,
+    numThreads=1,
+):
+    r"""
+    Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally:
+
+    .. math::
+        \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_2^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)
+
+    where :
+
+    - :math:`w \in (0, 1)^{N}` are the barycenter coefficients (they sum to one)
+    - `measures_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
+    - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
+    - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
+
+    This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
+    There are two differences between this implementation and that algorithm:
+
+    - we do not optimize over the weights
+    - we do not do line search for the locations updates, i.e. we use :math:`\theta = 1` in
+      :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
+      implementation of the fixed-point algorithm of
+      :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
+
+    Parameters
+    ----------
+    measures_locations : list of N (k_i,d) array-like
+        The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
+        (:math:`k_i` can be different for each element of the list)
+    measures_weights : list of N (k_i,) array-like
+        Numpy arrays where each numpy array has :math:`k_i` non-negative values summing to one
+        representing the weights of each discrete input measure
+
+    X_init : (k,d) array-like
+        Initialization of the support locations (on `k` atoms) of the barycenter
+    b : (k,) array-like
+        Initialization of the weights of the barycenter (non-negative, sum to 1)
+    weights : (N,) array-like
+        Initialization of the coefficients of the barycenter (non-negative, sum to 1)
+
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+    numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+        If compiled with OpenMP, chooses the number of threads to parallelize.
+        "max" selects the highest number possible.
+
+
+    Returns
+    -------
+    X : (k,d) array-like
+        Support locations (on k atoms) of the barycenter
+
+
+    .. _references-free-support-barycenter:
+
+    References
+    ----------
+    .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+    .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+    """
+
+    nx = get_backend(*measures_locations, *measures_weights, X_init)
+
+    iter_count = 0
+
+    N = len(measures_locations)
+    k = X_init.shape[0]
+    d = X_init.shape[1]
+    if b is None:
+        b = nx.ones((k,), type_as=X_init) / k
+    if weights is None:
+        weights = nx.ones((N,), type_as=X_init) / N
+
+    X = X_init
+
+    log_dict = {}
+    displacement_square_norms = []
+
+    displacement_square_norm = stopThr + 1.0
+
+    while displacement_square_norm > stopThr and iter_count < numItermax:
+        T_sum = nx.zeros((k, d), type_as=X_init)
+
+        for measure_locations_i, measure_weights_i, weight_i in zip(
+            measures_locations, measures_weights, weights
+        ):
+            M_i = dist(X, measure_locations_i)
+            T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads)
+            T_sum = T_sum + weight_i * 1.0 / b[:, None] * nx.dot(
+                T_i, measure_locations_i
+            )
+
+        displacement_square_norm = nx.sum((T_sum - X) ** 2)
+        if log:
+            displacement_square_norms.append(displacement_square_norm)
+
+        X = T_sum
+
+        if verbose:
+            print(
+                "iteration %d, displacement_square_norm=%f\n",
+                iter_count,
+                displacement_square_norm,
+            )
+
+        iter_count += 1
+
+    if log:
+        log_dict["displacement_square_norms"] = displacement_square_norms
+        return X, log_dict
+    else:
+        return X
+
+
+def generalized_free_support_barycenter(
+    X_list,
+    a_list,
+    P_list,
+    n_samples_bary,
+    Y_init=None,
+    b=None,
+    weights=None,
+    numItermax=100,
+    stopThr=1e-7,
+    verbose=False,
+    log=None,
+    numThreads=1,
+    eps=0,
+):
+    r"""
+    Solves the free support generalized Wasserstein barycenter problem: finding a barycenter (a discrete measure with
+    a fixed amount of points of uniform weights) whose respective projections fit the input measures.
+    More formally:
+
+    .. math::
+        \min_\gamma \quad \sum_{i=1}^p w_i W_2^2(\nu_i, \mathbf{P}_i\#\gamma)
+
+    where :
+
+    - :math:`\gamma = \sum_{l=1}^n b_l\delta_{y_l}` is the desired barycenter with each :math:`y_l \in \mathbb{R}^d`
+    - :math:`\mathbf{b} \in \mathbb{R}^{n}` is the desired weights vector of the barycenter
+    - The input measures are :math:`\nu_i = \sum_{j=1}^{k_i}a_{i,j}\delta_{x_{i,j}}`
+    - The :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}` are the respective empirical measures weights (on the simplex)
+    - The :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d_i}` are the respective empirical measures atoms locations
+    - :math:`w = (w_1, \cdots w_p)` are the barycenter coefficients (on the simplex)
+    - Each :math:`\mathbf{P}_i \in \mathbb{R}^{d_i, d}`, and :math:`\mathbf{P}_i\#\gamma = \sum_{l=1}^{n}b_l\delta_{\mathbf{P}_iy_l}`
+
+    As shown by :ref:`[42] <references-generalized-free-support-barycenter>`,
+    this problem can be re-written as a Wasserstein Barycenter problem,
+    which we solve using the free support method :ref:`[20] <references-generalized-free-support-barycenter>`
+    (Algorithm 2).
+
+    Parameters
+    ----------
+    X_list : list of p (k_i,d_i) array-like
+        Discrete supports of the input measures: each consists of :math:`k_i` locations of a `d_i`-dimensional space
+        (:math:`k_i` can be different for each element of the list)
+    a_list : list of p (k_i,) array-like
+        Measure weights: each element is a vector (k_i) on the simplex
+    P_list : list of p (d_i,d) array-like
+        Each :math:`P_i` is a linear map :math:`\mathbb{R}^{d} \rightarrow \mathbb{R}^{d_i}`
+    n_samples_bary : int
+        Number of barycenter points
+    Y_init : (n_samples_bary,d) array-like
+        Initialization of the support locations (on `n_samples_bary` atoms) of the barycenter
+    b : (n_samples_bary,) array-like
+        Initialization of the weights of the barycenter measure (on the simplex)
+    weights : (p,) array-like
+        Initialization of the coefficients of the barycenter (on the simplex)
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+    numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+        If compiled with OpenMP, chooses the number of threads to parallelize.
+        "max" selects the highest number possible.
+    eps: Stability coefficient for the change of variable matrix inversion
+        If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix
+        inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense)
+
+
+    Returns
+    -------
+    Y : (n_samples_bary,d) array-like
+        Support locations (on n_samples_bary atoms) of the barycenter
+
+
+    .. _references-generalized-free-support-barycenter:
+    References
+    ----------
+    .. [20] Cuturi, M. and Doucet, A.. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+    .. [42] Delon, J., Gozlan, N., and Saint-Dizier, A.. Generalized Wasserstein barycenters between probability measures living on different subspaces. arXiv preprint arXiv:2105.09755, 2021.
+
+    """
+    nx = get_backend(*X_list, *a_list, *P_list)
+    d = P_list[0].shape[1]
+    p = len(X_list)
+
+    if weights is None:
+        weights = nx.ones(p, type_as=X_list[0]) / p
+
+    # variable change matrix to reduce the problem to a Wasserstein Barycenter (WB)
+    A = eps * nx.eye(
+        d, type_as=X_list[0]
+    )  # if eps nonzero: will force the invertibility of A
+    for P_i, lambda_i in zip(P_list, weights):
+        A = A + lambda_i * P_i.T @ P_i
+    B = nx.inv(nx.sqrtm(A))
+
+    Z_list = [
+        x @ Pi @ B.T for (x, Pi) in zip(X_list, P_list)
+    ]  # change of variables -> (WB) problem on Z
+
+    if Y_init is None:
+        Y_init = nx.randn(n_samples_bary, d, type_as=X_list[0])
+
+    if b is None:
+        b = nx.ones(n_samples_bary, type_as=X_list[0]) / n_samples_bary  # not optimized
+
+    out = free_support_barycenter(
+        Z_list,
+        a_list,
+        Y_init,
+        b,
+        numItermax=numItermax,
+        stopThr=stopThr,
+        verbose=verbose,
+        log=log,
+        numThreads=numThreads,
+    )
+
+    if log:  # unpack
+        Y, log_dict = out
+    else:
+        Y = out
+        log_dict = None
+    Y = Y @ B.T  # return to the Generalized WB formulation
+
+    if log:
+        return Y, log_dict
+    else:
+        return Y
diff --git a/ot/lp/network_simplex.py b/ot/lp/network_simplex.py
new file mode 100644
index 000000000..0e820fec6
--- /dev/null
+++ b/ot/lp/network_simplex.py
@@ -0,0 +1,612 @@
+# -*- coding: utf-8 -*-
+"""
+Solvers for the original linear program OT problem.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import warnings
+
+from ..utils import list_to_array
+from ..backend import get_backend
+from .emd_wrap import emd_c, check_result
+
+
+def check_number_threads(numThreads):
+    """Checks whether or not the requested number of threads has a valid value.
+
+    Parameters
+    ----------
+    numThreads : int or str
+        The requested number of threads, should either be a strictly positive integer or "max" or None
+
+    Returns
+    -------
+    numThreads : int
+        Corrected number of threads
+    """
+    if (numThreads is None) or (
+        isinstance(numThreads, str) and numThreads.lower() == "max"
+    ):
+        return -1
+    if (not isinstance(numThreads, int)) or numThreads < 1:
+        raise ValueError(
+            'numThreads should either be "max" or a strictly positive integer'
+        )
+    return numThreads
+
+
+def center_ot_dual(alpha0, beta0, a=None, b=None):
+    r"""Center dual OT potentials w.r.t. their weights
+
+    The main idea of this function is to find unique dual potentials
+    that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). This helps
+    stability when the OT solver is called multiple times with small changes.
+
+    Basically we add another constraint to the potential that will not
+    change the objective value but will ensure uniqueness. The constraint
+    is the following:
+
+    .. math::
+        \alpha^T \mathbf{a} = \beta^T \mathbf{b}
+
+    in addition to the OT problem constraints.
+
+    Since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing
+    a constant from both  :math:`\alpha_0` and :math:`\beta_0`.
+
+    .. math::
+        c &= \frac{\beta_0^T \mathbf{b} - \alpha_0^T \mathbf{a}}{\mathbf{1}^T \mathbf{b} + \mathbf{1}^T \mathbf{a}}
+
+        \alpha &= \alpha_0 + c
+
+        \beta &= \beta_0 - c
+
+    Parameters
+    ----------
+    alpha0 : (ns,) numpy.ndarray, float64
+        Source dual potential
+    beta0 : (nt,) numpy.ndarray, float64
+        Target dual potential
+    a : (ns,) numpy.ndarray, float64
+        Source histogram (uniform weight if empty list)
+    b : (nt,) numpy.ndarray, float64
+        Target histogram (uniform weight if empty list)
+
+    Returns
+    -------
+    alpha : (ns,) numpy.ndarray, float64
+        Source centered dual potential
+    beta : (nt,) numpy.ndarray, float64
+        Target centered dual potential
+
+    """
+    # if no weights are provided, use uniform
+    if a is None:
+        a = np.ones(alpha0.shape[0]) / alpha0.shape[0]
+    if b is None:
+        b = np.ones(beta0.shape[0]) / beta0.shape[0]
+
+    # compute constant that balances the weighted sums of the duals
+    c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum())
+
+    # update duals
+    alpha = alpha0 + c
+    beta = beta0 - c
+
+    return alpha, beta
+
+
+def estimate_dual_null_weights(alpha0, beta0, a, b, M):
+    r"""Estimate feasible values for 0-weighted dual potentials
+
+    The feasible values are computed efficiently but rather coarsely.
+
+    .. warning::
+        This function is necessary because the C++ solver in `emd_c`
+        discards all samples in the distributions with
+        zero weights. This means that while the primal variable (transport
+        matrix) is exact, the solver only returns feasible dual potentials
+        on the samples with weights different from zero.
+
+    First we compute the constraints violations:
+
+    .. math::
+        \mathbf{V} = \alpha + \beta^T - \mathbf{M}
+
+    Next we compute the max amount of violation per row (:math:`\alpha`) and
+    column (:math:`\beta`)
+
+    .. math::
+        \mathbf{v^a}_i = \max_j \mathbf{V}_{i,j}
+
+        \mathbf{v^b}_j = \max_i \mathbf{V}_{i,j}
+
+    Finally we update the dual potential with 0 weights if a
+    constraint is violated
+
+    .. math::
+        \alpha_i = \alpha_i - \mathbf{v^a}_i \quad \text{ if } \mathbf{a}_i=0 \text{ and } \mathbf{v^a}_i>0
+
+        \beta_j = \beta_j - \mathbf{v^b}_j \quad \text{ if } \mathbf{b}_j=0 \text{ and } \mathbf{v^b}_j > 0
+
+    In the end the dual potentials are centered using function
+    :py:func:`ot.lp.center_ot_dual`.
+
+    Note that all those updates do not change the objective value of the
+    solution but provide dual potentials that do not violate the constraints.
+
+    Parameters
+    ----------
+    alpha0 : (ns,) numpy.ndarray, float64
+        Source dual potential
+    beta0 : (nt,) numpy.ndarray, float64
+        Target dual potential
+    alpha0 : (ns,) numpy.ndarray, float64
+        Source dual potential
+    beta0 : (nt,) numpy.ndarray, float64
+        Target dual potential
+    a : (ns,) numpy.ndarray, float64
+        Source distribution (uniform weights if empty list)
+    b : (nt,) numpy.ndarray, float64
+        Target distribution (uniform weights if empty list)
+    M : (ns,nt) numpy.ndarray, float64
+        Loss matrix (c-order array with type float64)
+
+    Returns
+    -------
+    alpha : (ns,) numpy.ndarray, float64
+        Source corrected dual potential
+    beta : (nt,) numpy.ndarray, float64
+        Target corrected dual potential
+
+    """
+
+    # binary indexing of non-zeros weights
+    asel = a != 0
+    bsel = b != 0
+
+    # compute dual constraints violation
+    constraint_violation = alpha0[:, None] + beta0[None, :] - M
+
+    # Compute largest violation per line and columns
+    aviol = np.max(constraint_violation, 1)
+    bviol = np.max(constraint_violation, 0)
+
+    # update corrects violation of
+    alpha_up = -1 * ~asel * np.maximum(aviol, 0)
+    beta_up = -1 * ~bsel * np.maximum(bviol, 0)
+
+    alpha = alpha0 + alpha_up
+    beta = beta0 + beta_up
+
+    return center_ot_dual(alpha, beta, a, b)
+
+
+def emd(
+    a,
+    b,
+    M,
+    numItermax=100000,
+    log=False,
+    center_dual=True,
+    numThreads=1,
+    check_marginals=True,
+):
+    r"""Solves the Earth Movers distance problem and returns the OT matrix
+
+
+    .. math::
+        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
+
+        s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+             \gamma^T \mathbf{1} = \mathbf{b}
+
+             \gamma \geq 0
+
+    where :
+
+    - :math:`\mathbf{M}` is the metric cost matrix
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+
+    .. warning:: Note that the :math:`\mathbf{M}` matrix in numpy needs to be a C-order
+        numpy.array in float64 format. It will be converted if not in this
+        format
+
+    .. note:: This function is backend-compatible and will work on arrays
+        from all compatible backends. But the algorithm uses the C++ CPU backend
+        which can lead to copy overhead on GPU arrays.
+
+    .. note:: This function will cast the computed transport plan to the data type
+        of the provided input with the following priority: :math:`\mathbf{a}`,
+        then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided.
+        Casting to an integer tensor might result in a loss of precision.
+        If this behaviour is unwanted, please make sure to provide a
+        floating point input.
+
+    .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
+
+    Uses the algorithm proposed in :ref:`[1] <references-emd>`.
+
+    Parameters
+    ----------
+    a : (ns,) array-like, float
+        Source histogram (uniform weight if empty list)
+    b : (nt,) array-like, float
+        Target histogram (uniform weight if empty list)
+    M : (ns,nt) array-like, float
+        Loss matrix (c-order array in numpy with type float64)
+    numItermax : int, optional (default=100000)
+        The maximum number of iterations before stopping the optimization
+        algorithm if it has not converged.
+    log: bool, optional (default=False)
+        If True, returns a dictionary containing the cost and dual variables.
+        Otherwise returns only the optimal transportation matrix.
+    center_dual: boolean, optional (default=True)
+        If True, centers the dual potential using function
+        :py:func:`ot.lp.center_ot_dual`.
+    numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+        If compiled with OpenMP, chooses the number of threads to parallelize.
+        "max" selects the highest number possible.
+    check_marginals: bool, optional (default=True)
+        If True, checks that the marginal masses are equal. If False, skips the
+        check.
+
+
+    Returns
+    -------
+    gamma: array-like, shape (ns, nt)
+        Optimal transportation matrix for the given
+        parameters
+    log: dict, optional
+        If input log is true, a dictionary containing the
+        cost and dual variables and exit status
+
+
+    Examples
+    --------
+
+    Simple example with obvious solution. The function emd accepts lists and
+    performs automatic conversion to numpy arrays
+
+    >>> import ot
+    >>> a=[.5,.5]
+    >>> b=[.5,.5]
+    >>> M=[[0.,1.],[1.,0.]]
+    >>> ot.emd(a, b, M)
+    array([[0.5, 0. ],
+           [0. , 0.5]])
+
+
+    .. _references-emd:
+    References
+    ----------
+    .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
+        December).  Displacement interpolation using Lagrangian mass transport.
+        In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
+
+    See Also
+    --------
+    ot.bregman.sinkhorn : Entropic regularized OT
+    ot.optim.cg : General regularized OT
+    """
+
+    a, b, M = list_to_array(a, b, M)
+    nx = get_backend(M, a, b)
+
+    if len(a) != 0:
+        type_as = a
+    elif len(b) != 0:
+        type_as = b
+    else:
+        type_as = M
+
+    # if empty array given then use uniform distributions
+    if len(a) == 0:
+        a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
+    if len(b) == 0:
+        b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]
+
+    # convert to numpy
+    M, a, b = nx.to_numpy(M, a, b)
+
+    # ensure float64
+    a = np.asarray(a, dtype=np.float64)
+    b = np.asarray(b, dtype=np.float64)
+    M = np.asarray(M, dtype=np.float64, order="C")
+
+    # if empty array given then use uniform distributions
+    if len(a) == 0:
+        a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
+    if len(b) == 0:
+        b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
+
+    assert (
+        a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]
+    ), "Dimension mismatch, check dimensions of M with a and b"
+
+    # ensure that same mass
+    if check_marginals:
+        np.testing.assert_almost_equal(
+            a.sum(0),
+            b.sum(0),
+            err_msg="a and b vector must have the same sum",
+            decimal=6,
+        )
+    b = b * a.sum() / b.sum()
+
+    asel = a != 0
+    bsel = b != 0
+
+    numThreads = check_number_threads(numThreads)
+
+    G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
+
+    if center_dual:
+        u, v = center_ot_dual(u, v, a, b)
+
+    if np.any(~asel) or np.any(~bsel):
+        u, v = estimate_dual_null_weights(u, v, a, b, M)
+
+    result_code_string = check_result(result_code)
+    if not nx.is_floating_point(type_as):
+        warnings.warn(
+            "Input histogram consists of integer. The transport plan will be "
+            "casted accordingly, possibly resulting in a loss of precision. "
+            "If this behaviour is unwanted, please make sure your input "
+            "histogram consists of floating point elements.",
+            stacklevel=2,
+        )
+    if log:
+        log = {}
+        log["cost"] = cost
+        log["u"] = nx.from_numpy(u, type_as=type_as)
+        log["v"] = nx.from_numpy(v, type_as=type_as)
+        log["warning"] = result_code_string
+        log["result_code"] = result_code
+        return nx.from_numpy(G, type_as=type_as), log
+    return nx.from_numpy(G, type_as=type_as)
+
+
+def emd2(
+    a,
+    b,
+    M,
+    processes=1,
+    numItermax=100000,
+    log=False,
+    return_matrix=False,
+    center_dual=True,
+    numThreads=1,
+    check_marginals=True,
+):
+    r"""Solves the Earth Movers distance problem and returns the loss
+
+    .. math::
+        \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F
+
+        s.t. \ \gamma \mathbf{1} = \mathbf{a}
+
+             \gamma^T \mathbf{1} = \mathbf{b}
+
+             \gamma \geq 0
+
+    where :
+
+    - :math:`\mathbf{M}` is the metric cost matrix
+    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
+
+    .. note:: This function is backend-compatible and will work on arrays
+        from all compatible backends. But the algorithm uses the C++ CPU backend
+        which can lead to copy overhead on GPU arrays.
+
+    .. note:: This function will cast the computed transport plan and
+        transportation loss to the data type of the provided input with the
+        following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`,
+        then :math:`\mathbf{M}` if marginals are not provided.
+        Casting to an integer tensor might result in a loss of precision.
+        If this behaviour is unwanted, please make sure to provide a
+        floating point input.
+
+    .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.
+
+    Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
+
+    Parameters
+    ----------
+    a : (ns,) array-like, float64
+        Source histogram (uniform weight if empty list)
+    b : (nt,) array-like, float64
+        Target histogram (uniform weight if empty list)
+    M : (ns,nt) array-like, float64
+        Loss matrix (for numpy c-order array with type float64)
+    processes : int, optional (default=1)
+        Number of processes used for multiple emd computation (deprecated)
+    numItermax : int, optional (default=100000)
+        The maximum number of iterations before stopping the optimization
+        algorithm if it has not converged.
+    log: boolean, optional (default=False)
+        If True, returns a dictionary containing dual
+        variables. Otherwise returns only the optimal transportation cost.
+    return_matrix: boolean, optional (default=False)
+        If True, returns the optimal transportation matrix in the log.
+    center_dual: boolean, optional (default=True)
+        If True, centers the dual potential using function
+        :py:func:`ot.lp.center_ot_dual`.
+    numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
+        If compiled with OpenMP, chooses the number of threads to parallelize.
+        "max" selects the highest number possible.
+    check_marginals: bool, optional (default=True)
+        If True, checks that the marginal masses are equal. If False, skips the
+        check.
+
+
+    Returns
+    -------
+    W: float, array-like
+        Optimal transportation loss for the given parameters
+    log: dict
+        If input log is true, a dictionary containing dual
+        variables and exit status
+
+
+    Examples
+    --------
+
+    Simple example with obvious solution. The function emd2 accepts lists and
+    performs automatic conversion to numpy arrays
+
+
+    >>> import ot
+    >>> a=[.5,.5]
+    >>> b=[.5,.5]
+    >>> M=[[0.,1.],[1.,0.]]
+    >>> ot.emd2(a,b,M)
+    0.0
+
+
+    .. _references-emd2:
+    References
+    ----------
+    .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
+        (2011, December).  Displacement interpolation using Lagrangian mass
+        transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
+        158). ACM.
+
+    See Also
+    --------
+    ot.bregman.sinkhorn : Entropic regularized OT
+    ot.optim.cg : General regularized OT
+    """
+
+    a, b, M = list_to_array(a, b, M)
+    nx = get_backend(M, a, b)
+
+    if len(a) != 0:
+        type_as = a
+    elif len(b) != 0:
+        type_as = b
+    else:
+        type_as = M
+
+    # if empty array given then use uniform distributions
+    if len(a) == 0:
+        a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
+    if len(b) == 0:
+        b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]
+
+    # store original tensors
+    a0, b0, M0 = a, b, M
+
+    # convert to numpy
+    M, a, b = nx.to_numpy(M, a, b)
+
+    a = np.asarray(a, dtype=np.float64)
+    b = np.asarray(b, dtype=np.float64)
+    M = np.asarray(M, dtype=np.float64, order="C")
+
+    assert (
+        a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]
+    ), "Dimension mismatch, check dimensions of M with a and b"
+
+    # ensure that same mass
+    if check_marginals:
+        np.testing.assert_almost_equal(
+            a.sum(0),
+            b.sum(0, keepdims=True),
+            err_msg="a and b vector must have the same sum",
+            decimal=6,
+        )
+    b = b * a.sum(0) / b.sum(0, keepdims=True)
+
+    asel = a != 0
+
+    numThreads = check_number_threads(numThreads)
+
+    if log or return_matrix:
+
+        def f(b):
+            bsel = b != 0
+
+            G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
+
+            if center_dual:
+                u, v = center_ot_dual(u, v, a, b)
+
+            if np.any(~asel) or np.any(~bsel):
+                u, v = estimate_dual_null_weights(u, v, a, b, M)
+
+            result_code_string = check_result(result_code)
+            log = {}
+            if not nx.is_floating_point(type_as):
+                warnings.warn(
+                    "Input histogram consists of integer. The transport plan will be "
+                    "casted accordingly, possibly resulting in a loss of precision. "
+                    "If this behaviour is unwanted, please make sure your input "
+                    "histogram consists of floating point elements.",
+                    stacklevel=2,
+                )
+            G = nx.from_numpy(G, type_as=type_as)
+            if return_matrix:
+                log["G"] = G
+            log["u"] = nx.from_numpy(u, type_as=type_as)
+            log["v"] = nx.from_numpy(v, type_as=type_as)
+            log["warning"] = result_code_string
+            log["result_code"] = result_code
+            cost = nx.set_gradients(
+                nx.from_numpy(cost, type_as=type_as),
+                (a0, b0, M0),
+                (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G),
+            )
+            return [cost, log]
+    else:
+
+        def f(b):
+            bsel = b != 0
+            G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
+
+            if center_dual:
+                u, v = center_ot_dual(u, v, a, b)
+
+            if np.any(~asel) or np.any(~bsel):
+                u, v = estimate_dual_null_weights(u, v, a, b, M)
+
+            if not nx.is_floating_point(type_as):
+                warnings.warn(
+                    "Input histogram consists of integer. The transport plan will be "
+                    "casted accordingly, possibly resulting in a loss of precision. "
+                    "If this behaviour is unwanted, please make sure your input "
+                    "histogram consists of floating point elements.",
+                    stacklevel=2,
+                )
+            G = nx.from_numpy(G, type_as=type_as)
+            cost = nx.set_gradients(
+                nx.from_numpy(cost, type_as=type_as),
+                (a0, b0, M0),
+                (
+                    nx.from_numpy(u - np.mean(u), type_as=type_as),
+                    nx.from_numpy(v - np.mean(v), type_as=type_as),
+                    G,
+                ),
+            )
+
+            check_result(result_code)
+            return cost
+
+    if len(b.shape) == 1:
+        return f(b)
+    nb = b.shape[1]
+
+    if processes > 1:
+        warnings.warn(
+            "The 'processes' parameter has been deprecated. "
+            "Multiprocessing should be done outside of POT."
+        )
+    res = list(map(f, [b[:, i].copy() for i in range(nb)]))
+
+    return res

From 109edb7534653c767d490703cfd631aad55a6592 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 11:53:38 +0100
Subject: [PATCH 02/23] pr number + enabled pre-commit

---
 RELEASES.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/RELEASES.md b/RELEASES.md
index e29be544e..2eae33215 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -7,7 +7,7 @@
 - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
 - Automatic PR labeling and release file update check (PR #704)
 - Implement fixed-point solver for OT barycenters with generic cost functions
-  (generalizes `ot.lp.free_support_barycenter`). (PR #???)
+  (generalizes `ot.lp.free_support_barycenter`). (PR #714)
 
 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)

From 0957904c9d4fb2bdba58a357899077192c1ee52d Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 11:57:45 +0100
Subject: [PATCH 03/23] added barycenter.py imports

---
 ot/lp/barycenter.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/ot/lp/barycenter.py b/ot/lp/barycenter.py
index 5468fb4eb..b1411abe1 100644
--- a/ot/lp/barycenter.py
+++ b/ot/lp/barycenter.py
@@ -1,3 +1,7 @@
+from ..backend import get_backend
+from ..utils import dist
+from .network_simplex import emd
+
 
 def free_support_barycenter(
     measures_locations,

From 818b3e7a278af75ad5a95c50f3a599775193a768 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 12:10:21 +0100
Subject: [PATCH 04/23] fixed wrong import in ot.gmm

---
 ot/gmm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ot/gmm.py b/ot/gmm.py
index cde2f8bbd..5c7a4c287 100644
--- a/ot/gmm.py
+++ b/ot/gmm.py
@@ -12,7 +12,7 @@
 from .backend import get_backend
 from .lp import emd2, emd
 import numpy as np
-from .lp import dist
+from .utils import dist
 from .gaussian import bures_wasserstein_mapping
 
 

From 08c2285cafe4a1ee6517e799a043af3251031a6e Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 12:24:20 +0100
Subject: [PATCH 05/23] ruff fix attempt

---
 README.md                                      |  7 ++++++-
 ot/gromov/_partial.py                          |  6 +++---
 ot/gromov/_quantized.py                        |  6 +++---
 ot/lp/__init__.py                              |  6 +++---
 ot/lp/{barycenter.py => barycenter_solvers.py} |  0
 ot/partial.py                                  | 14 +++++++-------
 ot/utils.py                                    |  4 ++--
 7 files changed, 24 insertions(+), 19 deletions(-)
 rename ot/lp/{barycenter.py => barycenter_solvers.py} (100%)

diff --git a/README.md b/README.md
index 7bbae9e8a..dd9622d9d 100644
--- a/README.md
+++ b/README.md
@@ -51,10 +51,11 @@ POT provides the following generic OT solvers (links to examples):
 * [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
 * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with  [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
 * [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) [58], with an extension to bounding potentials using [59].
-* Gaussian Mixture Model OT [69]
+* [Gaussian Mixture Model OT](https://pythonot.github.io/auto_examples/others/plot_GMMOT_plan.html#sphx-glr-auto-examples-others-plot-gmmot-plan-py) [69].
 * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and
 [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71].
 * Fused unbalanced Gromov-Wasserstein [70].
+* OT Barycenters for generic transport costs [].
 
 POT provides the following Machine Learning related solvers:
 
@@ -391,3 +392,7 @@ Artificial Intelligence.
 [72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).
 
 [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
+
+[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing
+Barycentres of Measures for Generic Transport
+Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
\ No newline at end of file
diff --git a/ot/gromov/_partial.py b/ot/gromov/_partial.py
index c6837f1d3..6672240d0 100644
--- a/ot/gromov/_partial.py
+++ b/ot/gromov/_partial.py
@@ -185,7 +185,7 @@ def partial_gromov_wasserstein(
     if m is None:
         m = min(np.sum(p), np.sum(q))
     elif m < 0:
-        raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.")
+        raise ValueError("Problem infeasible. Parameter m should be greater than 0.")
     elif m > min(np.sum(p), np.sum(q)):
         raise ValueError(
             "Problem infeasible. Parameter m should lower or"
@@ -654,7 +654,7 @@ def partial_fused_gromov_wasserstein(
     if m is None:
         m = min(np.sum(p), np.sum(q))
     elif m < 0:
-        raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.")
+        raise ValueError("Problem infeasible. Parameter m should be greater than 0.")
     elif m > min(np.sum(p), np.sum(q)):
         raise ValueError(
             "Problem infeasible. Parameter m should lower or"
@@ -1213,7 +1213,7 @@ def entropic_partial_gromov_wasserstein(
     if m is None:
         m = min(nx.sum(p), nx.sum(q))
     elif m < 0:
-        raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.")
+        raise ValueError("Problem infeasible. Parameter m should be greater than 0.")
     elif m > min(nx.sum(p), nx.sum(q)):
         raise ValueError(
             "Problem infeasible. Parameter m should lower or"
diff --git a/ot/gromov/_quantized.py b/ot/gromov/_quantized.py
index ac2db5d2d..f4a8fafa7 100644
--- a/ot/gromov/_quantized.py
+++ b/ot/gromov/_quantized.py
@@ -375,7 +375,7 @@ def get_graph_partition(
         raise ValueError(
             f"""
             Unknown `part_method='{part_method}'`. Use one of:
-            {'random', 'louvain', 'fluid', 'spectral', 'GW', 'FGW'}.
+            {"random", "louvain", "fluid", "spectral", "GW", "FGW"}.
             """
         )
     return nx.from_numpy(part, type_as=C0)
@@ -447,7 +447,7 @@ def get_graph_representants(C, part, rep_method="pagerank", random_state=0, nx=N
         raise ValueError(
             f"""
             Unknown `rep_method='{rep_method}'`. Use one of:
-            {'random', 'pagerank'}.
+            {"random", "pagerank"}.
             """
         )
 
@@ -953,7 +953,7 @@ def get_partition_and_representants_samples(
     else:
         raise ValueError(
             f"""
-            Unknown `method='{method}'`. Use one of: {'random', 'kmeans'}
+            Unknown `method='{method}'`. Use one of: {"random", "kmeans"}
             """
         )
 
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index d11a5ee41..b29029243 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -12,9 +12,9 @@
 from .cvx import barycenter
 from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize
 from .network_simplex import emd, emd2
-from .barycenter import (
-    free_support_barycenter, 
-    generalized_free_support_barycenter
+from .barycenter_solvers import (
+    free_support_barycenter,
+    generalized_free_support_barycenter,
 )
 
 # import compiled emd
diff --git a/ot/lp/barycenter.py b/ot/lp/barycenter_solvers.py
similarity index 100%
rename from ot/lp/barycenter.py
rename to ot/lp/barycenter_solvers.py
diff --git a/ot/partial.py b/ot/partial.py
index c11ab228a..6b2304e08 100755
--- a/ot/partial.py
+++ b/ot/partial.py
@@ -126,7 +126,7 @@ def partial_wasserstein_lagrange(
     nx = get_backend(a, b, M)
 
     if nx.sum(a) > 1 + 1e-15 or nx.sum(b) > 1 + 1e-15:  # 1e-15 for numerical errors
-        raise ValueError("Problem infeasible. Check that a and b are in the " "simplex")
+        raise ValueError("Problem infeasible. Check that a and b are in the simplex")
 
     if reg_m is None:
         reg_m = float(nx.max(M)) + 1
@@ -171,7 +171,7 @@ def partial_wasserstein_lagrange(
 
     if log_emd["warning"] is not None:
         raise ValueError(
-            "Error in the EMD resolution: try to increase the" " number of dummy points"
+            "Error in the EMD resolution: try to increase the number of dummy points"
         )
     log_emd["cost"] = nx.sum(gamma * M0)
     log_emd["u"] = nx.from_numpy(log_emd["u"], type_as=a0)
@@ -287,7 +287,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
     if m is None:
         return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs)
     elif m < 0:
-        raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.")
+        raise ValueError("Problem infeasible. Parameter m should be greater than 0.")
     elif m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
         raise ValueError(
             "Problem infeasible. Parameter m should lower or"
@@ -315,7 +315,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
 
     if log_emd["warning"] is not None:
         raise ValueError(
-            "Error in the EMD resolution: try to increase the" " number of dummy points"
+            "Error in the EMD resolution: try to increase the number of dummy points"
         )
     log_emd["partial_w_dist"] = nx.sum(M * gamma)
     log_emd["u"] = log_emd["u"][: len(a)]
@@ -522,7 +522,7 @@ def entropic_partial_wasserstein(
     if m is None:
         m = nx.min(nx.stack((nx.sum(a), nx.sum(b)))) * 1.0
     if m < 0:
-        raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.")
+        raise ValueError("Problem infeasible. Parameter m should be greater than 0.")
     if m > nx.min(nx.stack((nx.sum(a), nx.sum(b)))):
         raise ValueError(
             "Problem infeasible. Parameter m should lower or"
@@ -780,7 +780,7 @@ def partial_gromov_wasserstein(
     if m is None:
         m = np.min((np.sum(p), np.sum(q)))
     elif m < 0:
-        raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.")
+        raise ValueError("Problem infeasible. Parameter m should be greater than 0.")
     elif m > np.min((np.sum(p), np.sum(q))):
         raise ValueError(
             "Problem infeasible. Parameter m should lower or"
@@ -1132,7 +1132,7 @@ def entropic_partial_gromov_wasserstein(
     if m is None:
         m = np.min((np.sum(p), np.sum(q)))
     elif m < 0:
-        raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.")
+        raise ValueError("Problem infeasible. Parameter m should be greater than 0.")
     elif m > np.min((np.sum(p), np.sum(q))):
         raise ValueError(
             "Problem infeasible. Parameter m should lower or"
diff --git a/ot/utils.py b/ot/utils.py
index a2d328484..42673ecd6 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -517,7 +517,7 @@ def check_random_state(seed):
     if isinstance(seed, np.random.RandomState):
         return seed
     raise ValueError(
-        "{} cannot be used to seed a numpy.random.RandomState" " instance".format(seed)
+        "{} cannot be used to seed a numpy.random.RandomState instance".format(seed)
     )
 
 
@@ -787,7 +787,7 @@ def _update_doc(self, olddoc):
 def _is_deprecated(func):
     r"""Helper to check if func is wrapped by our deprecated decorator"""
     if sys.version_info < (3, 5):
-        raise NotImplementedError("This is only available for python3.5 " "or above")
+        raise NotImplementedError("This is only available for python3.5 or above")
     closures = getattr(func, "__closure__", [])
     if closures is None:
         closures = []

From f26851586a7c03d4707a8ed710b8047f9acfc78c Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 13:33:23 +0100
Subject: [PATCH 06/23] removed ot bar contribs -> only o.lp reorganisation in
 this PR

---
 README.md   | 5 -----
 RELEASES.md | 3 +--
 2 files changed, 1 insertion(+), 7 deletions(-)

diff --git a/README.md b/README.md
index dd9622d9d..f64db8f56 100644
--- a/README.md
+++ b/README.md
@@ -55,7 +55,6 @@ POT provides the following generic OT solvers (links to examples):
 * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and
 [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71].
 * Fused unbalanced Gromov-Wasserstein [70].
-* OT Barycenters for generic transport costs [].
 
 POT provides the following Machine Learning related solvers:
 
@@ -392,7 +391,3 @@ Artificial Intelligence.
 [72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).
 
 [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
-
-[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing
-Barycentres of Measures for Generic Transport
-Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
\ No newline at end of file
diff --git a/RELEASES.md b/RELEASES.md
index 2eae33215..1550b479f 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -6,8 +6,7 @@
 - Implement CG solvers for partial FGW (PR #687)
 - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
 - Automatic PR labeling and release file update check (PR #704)
-- Implement fixed-point solver for OT barycenters with generic cost functions
-  (generalizes `ot.lp.free_support_barycenter`). (PR #714)
+- Moved functions from `ot/lp/__init__.py` to separate modules. (PR #714)
 
 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)

From 8f24cb95f28e8c1e3f80cb6e72e768f1b45cc2dc Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 13:39:19 +0100
Subject: [PATCH 07/23] add check_number_threads to ot/lp/__init__.py __all__

---
 ot/lp/__init__.py        |  2 ++
 ot/lp/network_simplex.py | 26 +-------------------------
 ot/utils.py              | 24 ++++++++++++++++++++++++
 3 files changed, 27 insertions(+), 25 deletions(-)

diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index b29029243..548200123 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -16,6 +16,7 @@
     free_support_barycenter,
     generalized_free_support_barycenter,
 )
+from ..utils import check_number_threads
 
 # import compiled emd
 from .emd_wrap import emd_1d_sorted
@@ -44,4 +45,5 @@
     "semidiscrete_wasserstein2_unif_circle",
     "dmmot_monge_1dgrid_loss",
     "dmmot_monge_1dgrid_optimize",
+    "check_number_threads",
 ]
diff --git a/ot/lp/network_simplex.py b/ot/lp/network_simplex.py
index 0e820fec6..492e4c7ac 100644
--- a/ot/lp/network_simplex.py
+++ b/ot/lp/network_simplex.py
@@ -11,35 +11,11 @@
 import numpy as np
 import warnings
 
-from ..utils import list_to_array
+from ..utils import list_to_array, check_number_threads
 from ..backend import get_backend
 from .emd_wrap import emd_c, check_result
 
 
-def check_number_threads(numThreads):
-    """Checks whether or not the requested number of threads has a valid value.
-
-    Parameters
-    ----------
-    numThreads : int or str
-        The requested number of threads, should either be a strictly positive integer or "max" or None
-
-    Returns
-    -------
-    numThreads : int
-        Corrected number of threads
-    """
-    if (numThreads is None) or (
-        isinstance(numThreads, str) and numThreads.lower() == "max"
-    ):
-        return -1
-    if (not isinstance(numThreads, int)) or numThreads < 1:
-        raise ValueError(
-            'numThreads should either be "max" or a strictly positive integer'
-        )
-    return numThreads
-
-
 def center_ot_dual(alpha0, beta0, a=None, b=None):
     r"""Center dual OT potentials w.r.t. their weights
 
diff --git a/ot/utils.py b/ot/utils.py
index 42673ecd6..66ff7e354 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -1341,3 +1341,27 @@ def proj_SDP(S, nx=None, vmin=0.0):
         Q = nx.einsum("ijk,ik->ijk", P, w)  # Q[i] = P[i] @ diag(w[i])
         # R[i] = Q[i] @ P[i].T
         return nx.einsum("ijk,ikl->ijl", Q, nx.transpose(P, (0, 2, 1)))
+
+
+def check_number_threads(numThreads):
+    """Checks whether or not the requested number of threads has a valid value.
+
+    Parameters
+    ----------
+    numThreads : int or str
+        The requested number of threads, should either be a strictly positive integer or "max" or None
+
+    Returns
+    -------
+    numThreads : int
+        Corrected number of threads
+    """
+    if (numThreads is None) or (
+        isinstance(numThreads, str) and numThreads.lower() == "max"
+    ):
+        return -1
+    if (not isinstance(numThreads, int)) or numThreads < 1:
+        raise ValueError(
+            'numThreads should either be "max" or a strictly positive integer'
+        )
+    return numThreads

From 3e3b4445f4c1edf588c8d58bb218ccadd5ad0111 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 13:41:29 +0100
Subject: [PATCH 08/23] update releases

---
 RELEASES.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/RELEASES.md b/RELEASES.md
index 1550b479f..7d138c9c6 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -6,7 +6,7 @@
 - Implement CG solvers for partial FGW (PR #687)
 - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
 - Automatic PR labeling and release file update check (PR #704)
-- Moved functions from `ot/lp/__init__.py` to separate modules. (PR #714)
+- Reorganize sub-module  `ot/lp/__init__.py` into separate files. (PR #714) (PR #714)
 
 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)

From 566a0fc1e3171cd16cd22b58b926a58cc3c9a2cb Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 15:07:46 +0100
Subject: [PATCH 09/23] made barycenter_solvers and network_simplex hidden +
 deprecated ot.lp.cvx

---
 RELEASES.md                                   |   2 +-
 ot/lp/__init__.py                             |   6 +-
 ...nter_solvers.py => _barycenter_solvers.py} | 156 +++++++++++++++++-
 ...network_simplex.py => _network_simplex.py} |   0
 ot/lp/cvx.py                                  | 148 +----------------
 5 files changed, 163 insertions(+), 149 deletions(-)
 rename ot/lp/{barycenter_solvers.py => _barycenter_solvers.py} (69%)
 rename ot/lp/{network_simplex.py => _network_simplex.py} (100%)

diff --git a/RELEASES.md b/RELEASES.md
index 7d138c9c6..a0474eda0 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -6,7 +6,7 @@
 - Implement CG solvers for partial FGW (PR #687)
 - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
 - Automatic PR labeling and release file update check (PR #704)
-- Reorganize sub-module  `ot/lp/__init__.py` into separate files. (PR #714) (PR #714)
+- Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
 
 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 548200123..e3cfce0fd 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -9,10 +9,10 @@
 # License: MIT License
 
 from . import cvx
-from .cvx import barycenter
 from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize
-from .network_simplex import emd, emd2
-from .barycenter_solvers import (
+from ._network_simplex import emd, emd2
+from ._barycenter_solvers import (
+    barycenter,
     free_support_barycenter,
     generalized_free_support_barycenter,
 )
diff --git a/ot/lp/barycenter_solvers.py b/ot/lp/_barycenter_solvers.py
similarity index 69%
rename from ot/lp/barycenter_solvers.py
rename to ot/lp/_barycenter_solvers.py
index b1411abe1..8b64214d9 100644
--- a/ot/lp/barycenter_solvers.py
+++ b/ot/lp/_barycenter_solvers.py
@@ -1,6 +1,160 @@
+# -*- coding: utf-8 -*-
+"""
+OT Barycenter Solvers
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#         Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
+#
+# License: MIT License
+
 from ..backend import get_backend
 from ..utils import dist
-from .network_simplex import emd
+from ._network_simplex import emd
+
+import numpy as np
+import scipy as sp
+import scipy.sparse as sps
+
+try:
+    import cvxopt  # for cvxopt barycenter solver
+    from cvxopt import solvers, matrix, spmatrix
+except ImportError:
+    cvxopt = False
+
+
+def scipy_sparse_to_spmatrix(A):
+    """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix"""
+    coo = A.tocoo()
+    SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape)
+    return SP
+
+
+def barycenter(A, M, weights=None, verbose=False, log=False, solver="highs-ipm"):
+    r"""Compute the Wasserstein barycenter of distributions A
+
+     The function solves the following optimization problem [16]:
+
+    .. math::
+       \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i)
+
+    where :
+
+    - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn)
+    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+
+    The linear program is solved using the interior point solver from scipy.optimize.
+    If the cvxopt solver is installed, it can be used instead.
+
+    Note that this problem does not scale well (both in memory and computational time).
+
+    Parameters
+    ----------
+    A : np.ndarray (d,n)
+        n training distributions a_i of size d
+    M : np.ndarray (d,d)
+        loss matrix for OT
+    reg : float
+        Regularization term >0
+    weights : np.ndarray (n,)
+        Weights of each histogram a_i on the simplex (barycentric coordinates)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+    solver : string, optional
+        the solver used, default 'highs-ipm'. 'interior-point', 'highs', 'highs-ipm'
+        and 'highs-ds' use scipy.optimize.linprog; None, 'glpk' or 'mosek' use cvxopt.
+
+    Returns
+    -------
+    a : (d,) ndarray
+        Wasserstein barycenter
+    log : dict
+        log dictionary return only if log==True in parameters
+
+
+    References
+    ----------
+
+    .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924.
+
+
+    """
+
+    if weights is None:
+        weights = np.ones(A.shape[1]) / A.shape[1]
+    else:
+        assert len(weights) == A.shape[1]
+
+    n_distributions = A.shape[1]
+    n = A.shape[0]
+
+    n2 = n * n
+    c = np.zeros((0))
+    b_eq1 = np.zeros((0))
+    for i in range(n_distributions):
+        c = np.concatenate((c, M.ravel() * weights[i]))
+        b_eq1 = np.concatenate((b_eq1, A[:, i]))
+    c = np.concatenate((c, np.zeros(n)))
+
+    lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)]
+    #  row constraints
+    A_eq1 = sps.hstack(
+        (sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n)))
+    )
+
+    # columns constraints
+    lst_idiag2 = []
+    lst_eye = []
+    for i in range(n_distributions):
+        if i == 0:
+            lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n)))
+            lst_eye.append(-sps.eye(n))
+        else:
+            lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n)))
+            lst_eye.append(-sps.eye(n - 1, n))
+
+    A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye)))
+    b_eq2 = np.zeros((A_eq2.shape[0]))
+
+    # full problem
+    A_eq = sps.vstack((A_eq1, A_eq2))
+    b_eq = np.concatenate((b_eq1, b_eq2))
+
+    if not cvxopt or solver in ["interior-point", "highs", "highs-ipm", "highs-ds"]:
+        # cvxopt not installed or interior point
+
+        if solver is None:
+            solver = "interior-point"
+
+        options = {"disp": verbose}
+        sol = sp.optimize.linprog(
+            c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options
+        )
+        x = sol.x
+        b = x[-n:]
+
+    else:
+        h = np.zeros((n_distributions * n2 + n))
+        G = -sps.eye(n_distributions * n2 + n)
+
+        sol = solvers.lp(
+            matrix(c),
+            scipy_sparse_to_spmatrix(G),
+            matrix(h),
+            A=scipy_sparse_to_spmatrix(A_eq),
+            b=matrix(b_eq),
+            solver=solver,
+        )
+
+        x = np.array(sol["x"])
+        b = x[-n:].ravel()
+
+    if log:
+        return b, sol
+    else:
+        return b
 
 
 def free_support_barycenter(
diff --git a/ot/lp/network_simplex.py b/ot/lp/_network_simplex.py
similarity index 100%
rename from ot/lp/network_simplex.py
rename to ot/lp/_network_simplex.py
diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 01f5e5d87..b2269b8b4 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -1,152 +1,12 @@
 # -*- coding: utf-8 -*-
 """
-LP solvers for optimal transport using cvxopt
+(DEPRECATED) LP solvers for optimal transport using cvxopt
 """
 
 # Author: Remi Flamary <remi.flamary@unice.fr>
 #
 # License: MIT License
 
-import numpy as np
-import scipy as sp
-import scipy.sparse as sps
-
-try:
-    import cvxopt
-    from cvxopt import solvers, matrix, spmatrix
-except ImportError:
-    cvxopt = False
-
-
-def scipy_sparse_to_spmatrix(A):
-    """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix"""
-    coo = A.tocoo()
-    SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape)
-    return SP
-
-
-def barycenter(A, M, weights=None, verbose=False, log=False, solver="highs-ipm"):
-    r"""Compute the Wasserstein barycenter of distributions A
-
-     The function solves the following optimization problem [16]:
-
-    .. math::
-       \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i)
-
-    where :
-
-    - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn)
-    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
-
-    The linear program is solved using the interior point solver from scipy.optimize.
-    If cvxopt solver if installed it can use cvxopt
-
-    Note that this problem do not scale well (both in memory and computational time).
-
-    Parameters
-    ----------
-    A : np.ndarray (d,n)
-        n training distributions a_i of size d
-    M : np.ndarray (d,d)
-        loss matrix   for OT
-    reg : float
-        Regularization term >0
-    weights : np.ndarray (n,)
-        Weights of each histogram a_i on the simplex (barycentric coordinates)
-    verbose : bool, optional
-        Print information along iterations
-    log : bool, optional
-        record log if True
-    solver : string, optional
-        the solver used, default 'interior-point' use the lp solver from
-        scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt.
-
-    Returns
-    -------
-    a : (d,) ndarray
-        Wasserstein barycenter
-    log : dict
-        log dictionary return only if log==True in parameters
-
-
-    References
-    ----------
-
-    .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924.
-
-
-    """
-
-    if weights is None:
-        weights = np.ones(A.shape[1]) / A.shape[1]
-    else:
-        assert len(weights) == A.shape[1]
-
-    n_distributions = A.shape[1]
-    n = A.shape[0]
-
-    n2 = n * n
-    c = np.zeros((0))
-    b_eq1 = np.zeros((0))
-    for i in range(n_distributions):
-        c = np.concatenate((c, M.ravel() * weights[i]))
-        b_eq1 = np.concatenate((b_eq1, A[:, i]))
-    c = np.concatenate((c, np.zeros(n)))
-
-    lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)]
-    #  row constraints
-    A_eq1 = sps.hstack(
-        (sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n)))
-    )
-
-    # columns constraints
-    lst_idiag2 = []
-    lst_eye = []
-    for i in range(n_distributions):
-        if i == 0:
-            lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n)))
-            lst_eye.append(-sps.eye(n))
-        else:
-            lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n)))
-            lst_eye.append(-sps.eye(n - 1, n))
-
-    A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye)))
-    b_eq2 = np.zeros((A_eq2.shape[0]))
-
-    # full problem
-    A_eq = sps.vstack((A_eq1, A_eq2))
-    b_eq = np.concatenate((b_eq1, b_eq2))
-
-    if not cvxopt or solver in ["interior-point", "highs", "highs-ipm", "highs-ds"]:
-        # cvxopt not installed or interior point
-
-        if solver is None:
-            solver = "interior-point"
-
-        options = {"disp": verbose}
-        sol = sp.optimize.linprog(
-            c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options
-        )
-        x = sol.x
-        b = x[-n:]
-
-    else:
-        h = np.zeros((n_distributions * n2 + n))
-        G = -sps.eye(n_distributions * n2 + n)
-
-        sol = solvers.lp(
-            matrix(c),
-            scipy_sparse_to_spmatrix(G),
-            matrix(h),
-            A=scipy_sparse_to_spmatrix(A_eq),
-            b=matrix(b_eq),
-            solver=solver,
-        )
-
-        x = np.array(sol["x"])
-        b = x[-n:].ravel()
-
-    if log:
-        return b, sol
-    else:
-        return b
+print(
+    "The module ot.lp.cvx is deprecated and will be removed in future versions. The function `barycenter` was moved to ot.lp._barycenter_solvers and can be importer via ot.lp."
+)

From 5c35d586ef1b6adf3b5b7d77edb8d90a504904bd Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 15:10:34 +0100
Subject: [PATCH 10/23] fix ref to lp.cvx in test

---
 test/test_ot.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_ot.py b/test/test_ot.py
index da0ec746e..f84f8773a 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -395,7 +395,7 @@ def test_generalised_free_support_barycenter_backends(nx):
     np.testing.assert_allclose(Y, nx.to_numpy(Y2))
 
 
-@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
+@pytest.mark.skipif(not ot.lp._barycenter_solvers.cvxopt, reason="No cvxopt available")
 def test_lp_barycenter_cvxopt():
     a1 = np.array([1.0, 0, 0])[:, None]
     a2 = np.array([0, 0, 1.0])[:, None]

From 8ffb06190ce085af685676ac3072335ef5364680 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 15:23:50 +0100
Subject: [PATCH 11/23] lp.cvx now imports barycenter and gives a
 warnings.warning

---
 ot/lp/cvx.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index b2269b8b4..4f7846341 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -7,6 +7,11 @@
 #
 # License: MIT License
 
-print(
-    "The module ot.lp.cvx is deprecated and will be removed in future versions. The function `barycenter` was moved to ot.lp._barycenter_solvers and can be importer via ot.lp."
+import warnings
+
+
+warnings.warn(
+    "The module ot.lp.cvx is deprecated and will be removed in future versions. "
+    "The function `barycenter` was moved to ot.lp._barycenter_solvers and can "
+    "be imported via ot.lp."
 )

From 26748eb0602305ed5d115ad1d7a3b43f352ff06c Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 15:28:04 +0100
Subject: [PATCH 12/23] cvx import barycenter

---
 ot/lp/cvx.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py
index 4f7846341..e88d15375 100644
--- a/ot/lp/cvx.py
+++ b/ot/lp/cvx.py
@@ -8,6 +8,10 @@
 # License: MIT License
 
 import warnings
+from ._barycenter_solvers import barycenter
+
+
+__all__ = ["barycenter"]
 
 
 warnings.warn(

From 081e4eb14285a50f23891cb398472d42da70e724 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 16:52:19 +0100
Subject: [PATCH 13/23] added fixed-point barycenter function to
 ot.lp._barycenter_solvers_

---
 CONTRIBUTORS.md              |  2 +-
 README.md                    |  4 ++
 RELEASES.md                  |  2 +
 ot/lp/_barycenter_solvers.py | 87 ++++++++++++++++++++++++++++++++++++
 4 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 6f6a72737..fc1ecc313 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -44,7 +44,7 @@ The contributors to this library are:
 * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW,
   semi-relaxed FGW, quantized FGW, partial FGW)
 * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein
-  Barycenters, GMMOT)
+  Barycenters, GMMOT, Barycenters for General Transport Costs)
 * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
 * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
 * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
diff --git a/README.md b/README.md
index f64db8f56..9a8e5b371 100644
--- a/README.md
+++ b/README.md
@@ -391,3 +391,7 @@ Artificial Intelligence.
 [72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).
 
 [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
+
+[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing
+Barycentres of Measures for Generic Transport
+Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
diff --git a/RELEASES.md b/RELEASES.md
index a0474eda0..2a6867484 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -7,6 +7,8 @@
 - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
 - Automatic PR labeling and release file update check (PR #704)
 - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
+- Implement fixed-point solver for OT barycenters with generic cost functions
+  (generalizes `ot.lp.free_support_barycenter`). (PR #715)
 
 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py
index 8b64214d9..7e801caa6 100644
--- a/ot/lp/_barycenter_solvers.py
+++ b/ot/lp/_barycenter_solvers.py
@@ -422,3 +422,90 @@ def generalized_free_support_barycenter(
         return Y, log_dict
     else:
         return Y
+
+
+class StoppingCriterionReached(Exception):
+    pass
+
+
+def solve_OT_barycenter_fixed_point(
+    X_init,
+    Y_list,
+    b_list,
+    cost_list,
+    B,
+    max_its=5,
+    stop_threshold=1e-5,
+    log=False,
+):
+    """
+    Solves the OT barycenter problem using the fixed point algorithm, iterating
+    the function B on plans between the current barycentre and the measures.
+
+    Parameters
+    ----------
+    X_init : array-like
+        Array of shape (n, d) representing initial barycentre points.
+    Y_list : list of array-like
+        List of K arrays of measure positions, each of shape (m_k, d_k).
+    b_list : list of array-like
+        List of K arrays of measure weights, each of shape (m_k).
+    cost_list : list of callable
+        List of K cost functions R^(n, d) x R^(m_k, d_k) -> R_+^(n, m_k).
+    B : callable
+        Function from R^d_1 x ... x R^d_K to R^d accepting a list of K arrays of shape (n, d_K), computing the ground barycentre.
+    max_its : int, optional
+        Maximum number of iterations (default is 5).
+    stop_threshold : float, optional
+        If the iterations move less than this, terminate (default is 1e-5).
+    log : bool, optional
+        Whether to return the log dictionary (default is False).
+
+    Returns
+    -------
+    X : array-like
+        Array of shape (n, d) representing barycentre points.
+    log_dict : list of array-like, optional
+        log containing the exit status, list of iterations and list of
+        displacements if log is True.
+    """
+    nx = get_backend(X_init, Y_list[0])
+    K = len(Y_list)
+    n = X_init.shape[0]
+    a = nx.ones(n) / n
+    X_list = [X_init] if log else []  # store the iterations
+    X = X_init
+    dX_list = []  # store the displacement squared norms
+    exit_status = "Unknown"
+
+    try:
+        for _ in range(max_its):
+            pi_list = [  # compute the pairwise transport plans
+                emd(a, b_list[k], cost_list[k](X, Y_list[k])) for k in range(K)
+            ]
+            Y_perm = []
+            for k in range(K):  # compute barycentric projections
+                Y_perm.append(n * pi_list[k] @ Y_list[k])
+            X_next = B(Y_perm)
+
+            if log:
+                X_list.append(X_next)
+
+            # stationary criterion: move less than the threshold
+            dX = nx.sum((X - X_next) ** 2)
+            X = X_next
+
+            if log:
+                dX_list.append(dX)
+
+            if dX < stop_threshold:
+                exit_status = "Stationary Point"
+                raise StoppingCriterionReached
+
+        exit_status = "Max iterations reached"
+        raise StoppingCriterionReached
+
+    except StoppingCriterionReached:
+        if log:
+            return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list}
+        return X

From 59520198b25a6dd3e2c9f8a403e1846bd77e0995 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 20 Jan 2025 18:06:32 +0100
Subject: [PATCH 14/23] ot bar demo

---
 .../plot_barycenter_generic_cost.py           | 167 ++++++++++++++++++
 ...lot_generalized_free_support_barycenter.py |   2 +-
 examples/others/plot_GMMOT_plan.py            |   2 +-
 examples/others/plot_GMM_flow.py              |   2 +-
 examples/others/plot_SSNB.py                  |   2 +-
 ot/gmm.py                                     |   4 +-
 ot/lp/__init__.py                             |   3 +-
 ot/lp/_barycenter_solvers.py                  |   2 +-
 ot/mapping.py                                 |   2 +-
 9 files changed, 177 insertions(+), 9 deletions(-)
 create mode 100644 examples/barycenters/plot_barycenter_generic_cost.py

diff --git a/examples/barycenters/plot_barycenter_generic_cost.py b/examples/barycenters/plot_barycenter_generic_cost.py
new file mode 100644
index 000000000..14779fdff
--- /dev/null
+++ b/examples/barycenters/plot_barycenter_generic_cost.py
@@ -0,0 +1,167 @@
+# -*- coding: utf-8 -*-
+"""
+=====================================
+OT Barycenter with Generic Costs Demo
+=====================================
+
+This example illustrates the computation of an Optimal Transport for a ground
+cost that is not a power of a norm. We take the example of ground costs
+:math:`c_k(x, y) = |P_k(x)-y|^2`, where :math:`P_k` is the (non-linear)
+projection onto a circle k. This is an example of the fixed-point barycenter
+solver introduced in [74] which generalises [20].
+
+The ground barycenter function :math:`B(y_1, ..., y_K)` = \mathrm{argmin}_{x \in
+\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k) is computed by gradient descent over
+:math:`x` with Pytorch.
+
+[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing
+Barycentres of Measures for Generic Transport
+Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
+
+[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein
+Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International
+Conference in Machine Learning
+
+"""
+
+# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+# %% Generate data
+import torch
+from torch.optim import Adam
+from ot.utils import dist
+import numpy as np
+from ot.lp import free_support_barycenter_generic_costs
+import matplotlib.pyplot as plt
+
+
+torch.manual_seed(42)
+
+n = 100  # number of points of the of the barycentre
+d = 2  # dimensions of the original measure
+K = 4  # number of measures to barycentre
+m = 50  # number of points of the measures
+b_list = [torch.ones(m) / m] * K  # weights of the 4 measures
+weights = torch.ones(K) / K  # weights for the barycentre
+stop_threshold = 1e-20  # stop threshold for B and for fixed-point algo
+
+
+# map R^2 -> R^2 projection onto circle
+def proj_circle(X, origin, radius):
+    diffs = X - origin[None, :]
+    norms = torch.norm(diffs, dim=1)
+    return origin[None, :] + radius * diffs / norms[:, None]
+
+
+# circles on which to project
+origin1 = torch.tensor([-1.0, -1.0])
+origin2 = torch.tensor([-1.0, 2.0])
+origin3 = torch.tensor([2.0, 2.0])
+origin4 = torch.tensor([2.0, -1.0])
+r = np.sqrt(2)
+P_list = [
+    lambda X: proj_circle(X, origin1, r),
+    lambda X: proj_circle(X, origin2, r),
+    lambda X: proj_circle(X, origin3, r),
+    lambda X: proj_circle(X, origin4, r),
+]
+
+# measures to barycentre are projections of different random circles
+# onto the K circles
+Y_list = []
+for k in range(K):
+    t = torch.rand(m) * 2 * np.pi
+    X_temp = 0.5 * torch.stack([torch.cos(t), torch.sin(t)], axis=1)
+    X_temp = X_temp + torch.tensor([0.5, 0.5])[None, :]
+    Y_list.append(P_list[k](X_temp))
+
+
+# %% Define costs and ground barycenter function
+# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a
+# (n, n_k) matrix of costs
+def c1(x, y):
+    return dist(P_list[0](x), y)
+
+
+def c2(x, y):
+    return dist(P_list[1](x), y)
+
+
+def c3(x, y):
+    return dist(P_list[2](x), y)
+
+
+def c4(x, y):
+    return dist(P_list[3](x), y)
+
+
+cost_list = [c1, c2, c3, c4]
+
+
+# batched total ground cost function for candidate points x (n, d)
+# for computation of the ground barycenter B with gradient descent
+def C(x, y):
+    """
+    Computes the barycenter cost for candidate points x (n, d) and
+    measure supports y: List(n, d_k).
+    """
+    n = x.shape[0]
+    K = len(y)
+    out = torch.zeros(n)
+    for k in range(K):
+        out += (1 / K) * torch.sum((P_list[k](x) - y[k]) ** 2, axis=1)
+    return out
+
+
+# ground barycenter function
+def B(y, its=150, lr=1, stop_threshold=stop_threshold):
+    """
+    Computes the ground barycenter for measure supports y: List(n, d_k).
+    Output: (n, d) array
+    """
+    x = torch.randn(n, d)
+    x.requires_grad_(True)
+    opt = Adam([x], lr=lr)
+    for _ in range(its):
+        x_prev = x.data.clone()
+        opt.zero_grad()
+        loss = torch.sum(C(x, y))
+        loss.backward()
+        opt.step()
+        diff = torch.sum((x.data - x_prev) ** 2)
+        if diff < stop_threshold:
+            break
+    return x
+
+
+# %% Compute the barycenter measure
+fixed_point_its = 10
+X_init = torch.rand(n, d)
+X_bar = free_support_barycenter_generic_costs(
+    X_init,
+    Y_list,
+    b_list,
+    cost_list,
+    B,
+    max_its=fixed_point_its,
+    stop_threshold=stop_threshold,
+)
+
+# %% Plot Barycenter (Iteration 10)
+alpha = 0.5
+labels = ["circle 1", "circle 2", "circle 3", "circle 4"]
+for Y, label in zip(Y_list, labels):
+    plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label)
+plt.scatter(*(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha)
+plt.axis("equal")
+plt.xlim(-0.3, 1.3)
+plt.ylim(-0.3, 1.3)
+plt.axis("off")
+plt.legend()
+plt.tight_layout()
+
+# %%
diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py
index 5b3572bd4..b21c66f13 100644
--- a/examples/barycenters/plot_generalized_free_support_barycenter.py
+++ b/examples/barycenters/plot_generalized_free_support_barycenter.py
@@ -14,7 +14,7 @@
 
 """
 
-# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu>
+# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
 #
 # License: MIT License
 
diff --git a/examples/others/plot_GMMOT_plan.py b/examples/others/plot_GMMOT_plan.py
index 7742d496e..4964ddd66 100644
--- a/examples/others/plot_GMMOT_plan.py
+++ b/examples/others/plot_GMMOT_plan.py
@@ -16,7 +16,7 @@
 
 """
 
-# Author: Eloi Tanguy <eloi.tanguy@u-paris>
+# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
 #         Remi Flamary <remi.flamary@polytehnique.edu>
 #         Julie Delon <julie.delon@math.cnrs.fr>
 #
diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py
index beb675755..dc26ff3ce 100644
--- a/examples/others/plot_GMM_flow.py
+++ b/examples/others/plot_GMM_flow.py
@@ -10,7 +10,7 @@
 
 """
 
-# Author: Eloi Tanguy <eloi.tanguy@u-paris>
+# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
 #         Remi Flamary <remi.flamary@polytehnique.edu>
 #         Julie Delon <julie.delon@math.cnrs.fr>
 #
diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py
index fbc343a8a..e167b1ee4 100644
--- a/examples/others/plot_SSNB.py
+++ b/examples/others/plot_SSNB.py
@@ -38,7 +38,7 @@
         2017.
 """
 
-# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr>
+# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
 # License: MIT License
 
 # sphinx_gallery_thumbnail_number = 3
diff --git a/ot/gmm.py b/ot/gmm.py
index 5c7a4c287..d99d4e5db 100644
--- a/ot/gmm.py
+++ b/ot/gmm.py
@@ -3,8 +3,8 @@
 Optimal transport for Gaussian Mixtures
 """
 
-# Author: Eloi Tanguy <eloi.tanguy@u-paris>
-#         Remi Flamary <remi.flamary@polytehnique.edu>
+# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
+#         Remi Flamary <remi.flamary@polytechnique.edu>
 #         Julie Delon <julie.delon@math.cnrs.fr>
 #
 # License: MIT License
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index e3cfce0fd..974679440 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -8,13 +8,13 @@
 #
 # License: MIT License
 
-from . import cvx
 from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize
 from ._network_simplex import emd, emd2
 from ._barycenter_solvers import (
     barycenter,
     free_support_barycenter,
     generalized_free_support_barycenter,
+    free_support_barycenter_generic_costs,
 )
 from ..utils import check_number_threads
 
@@ -46,4 +46,5 @@
     "dmmot_monge_1dgrid_loss",
     "dmmot_monge_1dgrid_optimize",
     "check_number_threads",
+    "free_support_barycenter_generic_costs",
 ]
diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py
index 7e801caa6..e45092caa 100644
--- a/ot/lp/_barycenter_solvers.py
+++ b/ot/lp/_barycenter_solvers.py
@@ -428,7 +428,7 @@ class StoppingCriterionReached(Exception):
     pass
 
 
-def solve_OT_barycenter_fixed_point(
+def free_support_barycenter_generic_costs(
     X_init,
     Y_list,
     b_list,
diff --git a/ot/mapping.py b/ot/mapping.py
index 1ec55cb95..d2a05809c 100644
--- a/ot/mapping.py
+++ b/ot/mapping.py
@@ -7,7 +7,7 @@
     use it you need to explicitly import :mod:`ot.mapping`
 """
 
-# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr>
+# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
 #         Remi Flamary <remi.flamary@unice.fr>
 #
 # License: MIT License

From 3e8421eb6dca94900bbca636a3594ff413cf5925 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Tue, 21 Jan 2025 11:35:53 +0100
Subject: [PATCH 15/23] ot bar doc

---
 .../plot_barycenter_generic_cost.py           |  10 +-
 ot/lp/_barycenter_solvers.py                  | 100 ++++++++++++++----
 2 files changed, 87 insertions(+), 23 deletions(-)

diff --git a/examples/barycenters/plot_barycenter_generic_cost.py b/examples/barycenters/plot_barycenter_generic_cost.py
index 14779fdff..3e5ba38fe 100644
--- a/examples/barycenters/plot_barycenter_generic_cost.py
+++ b/examples/barycenters/plot_barycenter_generic_cost.py
@@ -6,9 +6,9 @@
 
 This example illustrates the computation of an Optimal Transport for a ground
 cost that is not a power of a norm. We take the example of ground costs
-:math:`c_k(x, y) = |P_k(x)-y|^2`, where :math:`P_k` is the (non-linear)
+:math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear)
 projection onto a circle k. This is an example of the fixed-point barycenter
-solver introduced in [74] which generalises [20].
+solver introduced in [74] which generalises [20] and [43].
 
 The ground barycenter function :math:`B(y_1, ..., y_K)` = \mathrm{argmin}_{x \in
 \mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k) is computed by gradient descent over
@@ -22,6 +22,8 @@
 Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International
 Conference in Machine Learning
 
+[43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
 """
 
 # Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
@@ -147,8 +149,8 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold):
     b_list,
     cost_list,
     B,
-    max_its=fixed_point_its,
-    stop_threshold=stop_threshold,
+    numItermax=fixed_point_its,
+    stopThr=stop_threshold,
 )
 
 # %% Plot Barycenter (Iteration 10)
diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py
index e45092caa..a04d4de05 100644
--- a/ot/lp/_barycenter_solvers.py
+++ b/ot/lp/_barycenter_solvers.py
@@ -430,33 +430,78 @@ class StoppingCriterionReached(Exception):
 
 def free_support_barycenter_generic_costs(
     X_init,
-    Y_list,
-    b_list,
+    measure_locations,
+    measure_weights,
     cost_list,
     B,
-    max_its=5,
-    stop_threshold=1e-5,
+    numItermax=5,
+    stopThr=1e-5,
     log=False,
 ):
-    """
-    Solves the OT barycenter problem using the fixed point algorithm, iterating
-    the function B on plans between the current barycentre and the measures.
+    r"""
+    Solves the OT barycenter problem for generic costs using the fixed point
+    algorithm, iterating the ground barycenter function B on transport plans
+    between the current barycentre and the measures.
+
+    The problem finds an optimal barycenter support `X` of given size (n, d)
+    (enforced by the initialisation), minimising a sum of pairwise transport
+    costs for the costs :math:`c_k`:
+
+    .. math::
+        \min_{X} \sum_{k=1}^K \mathcal{T}_{c_k}(X, a, Y_k, b_k),
+
+    where:
+
+    - :math:`X` (n, d) is the barycentre support,
+    - :math:`a` (n) is the (fixed) barycentre weights,
+    - :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`),
+    - :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`),
+    - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix)
+    - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycentre measure and the k-th measure with respect to the cost :math:`c_k`:
+
+    .. math::
+        \mathcal{T}_{c_k}(X, a, Y_k, b_k) = \min_\pi \quad \langle \pi, c_k(X, Y_k) \rangle_F
+
+        s.t. \ \pi \mathbf{1} = \mathbf{a}
+
+             \pi^T \mathbf{1} = \mathbf{b_k}
+
+             \pi \geq 0
+
+    in other words, :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is `ot.emd2(a, b_k,
+    c_k(X, Y_k))`.
+
+    The algorithm requires a given ground barycentre function `B` which computes
+    a solution of the following minimisation problem given :math:`(y_1, \cdots,
+    y_K) \in \mathbb{R}^{d_1}\times\cdots\times\mathbb{R}^{d_K}`:
+
+    .. math::
+        B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k),
+
+    where :math:`c_k(x, y_k) \in \mathbb{R}_+` is the cost between the points
+    :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times
+    \cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to
+    this function, and for certain costs it can be computed explicitly of
+    through a numerical solver.
+
+    This function implements [74] Algorithm 2, which generalises [20] and [43]
+    to general costs and includes convergence guarantees, including for discrete measures.
 
     Parameters
     ----------
     X_init : array-like
         Array of shape (n, d) representing initial barycentre points.
-    Y_list : list of array-like
+    measure_locations : list of array-like
         List of K arrays of measure positions, each of shape (m_k, d_k).
-    b_list : list of array-like
+    measure_weights : list of array-like
         List of K arrays of measure weights, each of shape (m_k).
     cost_list : list of callable
-        List of K cost functions R^(n, d) x R^(m_k, d_k) -> R_+^(n, m_k).
+        List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}`.
     B : callable
-        Function from R^d_1 x ... x R^d_K to R^d accepting a list of K arrays of shape (n, d_K), computing the ground barycentre.
-    max_its : int, optional
+        Function from :math:`\mathbb{R}^{d_1} \times\cdots \times \mathbb{R}^{d_K}` to :math:`\mathbb{R}^d` accepting a list of K arrays of shape (n\times d_K), computing the ground barycentre.
+    numItermax : int, optional
         Maximum number of iterations (default is 5).
-    stop_threshold : float, optional
+    stopThr : float, optional
         If the iterations move less than this, terminate (default is 1e-5).
     log : bool, optional
         Whether to return the log dictionary (default is False).
@@ -468,9 +513,25 @@ def free_support_barycenter_generic_costs(
     log_dict : list of array-like, optional
         log containing the exit status, list of iterations and list of
         displacements if log is True.
+
+    .. _references-free-support-barycenter-generic-costs:
+
+    References
+    ----------
+    .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
+
+    .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+
+    .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+
+    See Also
+    --------
+    ot.lp.free_support_barycenter : Free support solver for the case where
+    :math:`c_k(x,y) = \|x-y\|_2^2`.
+    ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear.
     """
-    nx = get_backend(X_init, Y_list[0])
-    K = len(Y_list)
+    nx = get_backend(X_init, measure_locations[0])
+    K = len(measure_locations)
     n = X_init.shape[0]
     a = nx.ones(n) / n
     X_list = [X_init] if log else []  # store the iterations
@@ -479,13 +540,14 @@ def free_support_barycenter_generic_costs(
     exit_status = "Unknown"
 
     try:
-        for _ in range(max_its):
+        for _ in range(numItermax):
             pi_list = [  # compute the pairwise transport plans
-                emd(a, b_list[k], cost_list[k](X, Y_list[k])) for k in range(K)
+                emd(a, measure_weights[k], cost_list[k](X, measure_locations[k]))
+                for k in range(K)
             ]
             Y_perm = []
             for k in range(K):  # compute barycentric projections
-                Y_perm.append(n * pi_list[k] @ Y_list[k])
+                Y_perm.append(n * pi_list[k] @ measure_locations[k])
             X_next = B(Y_perm)
 
             if log:
@@ -498,7 +560,7 @@ def free_support_barycenter_generic_costs(
             if log:
                 dX_list.append(dX)
 
-            if dX < stop_threshold:
+            if dX < stopThr:
                 exit_status = "Stationary Point"
                 raise StoppingCriterionReached
 

From ccf608a19e515b8f3b664792532f6c1b5136ca5f Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Tue, 21 Jan 2025 15:08:00 +0100
Subject: [PATCH 16/23] doc fixes + ot bar coverage

---
 .../plot_barycenter_generic_cost.py           | 46 +++++----
 ot/lp/_barycenter_solvers.py                  | 61 +++++++-----
 test/test_ot.py                               | 95 +++++++++++++++++++
 3 files changed, 161 insertions(+), 41 deletions(-)

diff --git a/examples/barycenters/plot_barycenter_generic_cost.py b/examples/barycenters/plot_barycenter_generic_cost.py
index 3e5ba38fe..e5e5af73a 100644
--- a/examples/barycenters/plot_barycenter_generic_cost.py
+++ b/examples/barycenters/plot_barycenter_generic_cost.py
@@ -10,19 +10,20 @@
 projection onto a circle k. This is an example of the fixed-point barycenter
 solver introduced in [74] which generalises [20] and [43].
 
-The ground barycenter function :math:`B(y_1, ..., y_K)` = \mathrm{argmin}_{x \in
-\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k) is computed by gradient descent over
+The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in
+\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over
 :math:`x` with Pytorch.
 
-[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing
-Barycentres of Measures for Generic Transport
-Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
+[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
+Barycentres of Measures for Generic Transport Costs.
+arXiv preprint 2501.04016 (2024)
 
-[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein
-Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International
-Conference in Machine Learning
+[20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein
+Barycenters. International Conference on Machine Learning
 
-[43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+[43] Álvarez-Esteban, Pedro C., et al. A fixed-point approach to barycenters in
+Wasserstein space. Journal of Mathematical Analysis and Applications 441.2
+(2016): 744-762.
 
 """
 
@@ -32,7 +33,8 @@
 
 # sphinx_gallery_thumbnail_number = 1
 
-# %% Generate data
+# %%
+# Generate data
 import torch
 from torch.optim import Adam
 from ot.utils import dist
@@ -43,7 +45,7 @@
 
 torch.manual_seed(42)
 
-n = 100  # number of points of the of the barycentre
+n = 200  # number of points of the of the barycentre
 d = 2  # dimensions of the original measure
 K = 4  # number of measures to barycentre
 m = 50  # number of points of the measures
@@ -82,7 +84,8 @@ def proj_circle(X, origin, radius):
     Y_list.append(P_list[k](X_temp))
 
 
-# %% Define costs and ground barycenter function
+# %%
+# Define costs and ground barycenter function
 # cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a
 # (n, n_k) matrix of costs
 def c1(x, y):
@@ -140,25 +143,30 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold):
     return x
 
 
-# %% Compute the barycenter measure
-fixed_point_its = 10
+# %%
+# Compute the barycenter measure
+fixed_point_its = 3
 X_init = torch.rand(n, d)
 X_bar = free_support_barycenter_generic_costs(
-    X_init,
     Y_list,
     b_list,
+    X_init,
     cost_list,
     B,
     numItermax=fixed_point_its,
     stopThr=stop_threshold,
 )
 
-# %% Plot Barycenter (Iteration 10)
-alpha = 0.5
+# %%
+# Plot Barycenter (Iteration 3)
+alpha = 0.4
+s = 80
 labels = ["circle 1", "circle 2", "circle 3", "circle 4"]
 for Y, label in zip(Y_list, labels):
-    plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label)
-plt.scatter(*(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha)
+    plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s)
+plt.scatter(
+    *(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha, s=s
+)
 plt.axis("equal")
 plt.xlim(-0.3, 1.3)
 plt.ylim(-0.3, 1.3)
diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py
index a04d4de05..445a996df 100644
--- a/ot/lp/_barycenter_solvers.py
+++ b/ot/lp/_barycenter_solvers.py
@@ -429,11 +429,12 @@ class StoppingCriterionReached(Exception):
 
 
 def free_support_barycenter_generic_costs(
-    X_init,
     measure_locations,
     measure_weights,
+    X_init,
     cost_list,
     B,
+    a=None,
     numItermax=5,
     stopThr=1e-5,
     log=False,
@@ -441,7 +442,7 @@ def free_support_barycenter_generic_costs(
     r"""
     Solves the OT barycenter problem for generic costs using the fixed point
     algorithm, iterating the ground barycenter function B on transport plans
-    between the current barycentre and the measures.
+    between the current barycenter and the measures.
 
     The problem finds an optimal barycenter support `X` of given size (n, d)
     (enforced by the initialisation), minimising a sum of pairwise transport
@@ -452,12 +453,13 @@ def free_support_barycenter_generic_costs(
 
     where:
 
-    - :math:`X` (n, d) is the barycentre support,
-    - :math:`a` (n) is the (fixed) barycentre weights,
-    - :math:`Y_k` (m_k, d_k) is the k-th measure support (`measure_locations[k]`),
+    - :math:`X` (n, d) is the barycenter support,
+    - :math:`a` (n) is the (fixed) barycenter weights,
+    - :math:`Y_k` (m_k, d_k) is the k-th measure support
+      (`measure_locations[k]`),
     - :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`),
     - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix)
-    - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycentre measure and the k-th measure with respect to the cost :math:`c_k`:
+    - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycenter measure and the k-th measure with respect to the cost :math:`c_k`:
 
     .. math::
         \mathcal{T}_{c_k}(X, a, Y_k, b_k) = \min_\pi \quad \langle \pi, c_k(X, Y_k) \rangle_F
@@ -471,9 +473,10 @@ def free_support_barycenter_generic_costs(
     in other words, :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is `ot.emd2(a, b_k,
     c_k(X, Y_k))`.
 
-    The algorithm requires a given ground barycentre function `B` which computes
-    a solution of the following minimisation problem given :math:`(y_1, \cdots,
-    y_K) \in \mathbb{R}^{d_1}\times\cdots\times\mathbb{R}^{d_K}`:
+    The algorithm requires a given ground barycenter function `B` which computes
+    (broadcasted over `n`) solutions of the following minimisation problem given
+    :math:`(Y_1, \cdots, Y_K) \in
+    \mathbb{R}^{n\times d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`:
 
     .. math::
         B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k),
@@ -482,23 +485,32 @@ def free_support_barycenter_generic_costs(
     :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times
     \cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to
     this function, and for certain costs it can be computed explicitly of
-    through a numerical solver.
+    through a numerical solver. The input function B takes a list of K arrays of
+    shape (n, d_k) and returns an array of shape (n, d).
 
     This function implements [74] Algorithm 2, which generalises [20] and [43]
-    to general costs and includes convergence guarantees, including for discrete measures.
+    to general costs and includes convergence guarantees, including for discrete
+    measures.
 
     Parameters
     ----------
-    X_init : array-like
-        Array of shape (n, d) representing initial barycentre points.
     measure_locations : list of array-like
         List of K arrays of measure positions, each of shape (m_k, d_k).
     measure_weights : list of array-like
         List of K arrays of measure weights, each of shape (m_k).
+    X_init : array-like
+        Array of shape (n, d) representing initial barycenter points.
     cost_list : list of callable
-        List of K cost functions :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}`.
+        List of K cost functions :math:`c_k: \mathbb{R}^{n\times
+        d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times
+        m_k}`.
     B : callable
-        Function from :math:`\mathbb{R}^{d_1} \times\cdots \times \mathbb{R}^{d_K}` to :math:`\mathbb{R}^d` accepting a list of K arrays of shape (n\times d_K), computing the ground barycentre.
+        Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays
+        of shape (n, d_k), computing the ground barycenters (broadcasted
+        over n).
+    a : array-like, optional
+        Array of shape (n,) representing weights of the barycenter
+        measure. Defaults to uniform.
     numItermax : int, optional
         Maximum number of iterations (default is 5).
     stopThr : float, optional
@@ -509,7 +521,7 @@ def free_support_barycenter_generic_costs(
     Returns
     -------
     X : array-like
-        Array of shape (n, d) representing barycentre points.
+        Array of shape (n, d) representing barycenter points.
     log_dict : list of array-like, optional
         log containing the exit status, list of iterations and list of
         displacements if log is True.
@@ -518,22 +530,27 @@ def free_support_barycenter_generic_costs(
 
     References
     ----------
-    .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
+    .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
+        Barycentres of Measures for Generic Transport Costs. arXiv preprint
+        2501.04016 (2024)
 
-    .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
+    .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein
+        barycenters." International Conference on Machine Learning. 2014.
 
-    .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
+    .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to
+        barycenters in Wasserstein space." Journal of Mathematical Analysis and
+        Applications 441.2 (2016): 744-762.
 
     See Also
     --------
-    ot.lp.free_support_barycenter : Free support solver for the case where
-    :math:`c_k(x,y) = \|x-y\|_2^2`.
+    ot.lp.free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|x-y\|_2^2`.
     ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear.
     """
     nx = get_backend(X_init, measure_locations[0])
     K = len(measure_locations)
     n = X_init.shape[0]
-    a = nx.ones(n) / n
+    if a is None:
+        a = nx.ones(n, type_as=X_init) / n
     X_list = [X_init] if log else []  # store the iterations
     X = X_init
     dX_list = []  # store the displacement squared norms
diff --git a/test/test_ot.py b/test/test_ot.py
index f84f8773a..4916d71aa 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -13,6 +13,8 @@
 from ot.datasets import make_1D_gauss as gauss
 from ot.backend import torch, tf
 
+# import ot.lp._barycenter_solvers  # TODO: remove this import
+
 
 def test_emd_dimension_and_mass_mismatch():
     # test emd and emd2 for dimension mismatch
@@ -395,6 +397,99 @@ def test_generalised_free_support_barycenter_backends(nx):
     np.testing.assert_allclose(Y, nx.to_numpy(Y2))
 
 
+def test_free_support_barycenter_generic_costs():
+    measures_locations = [
+        np.array([-1.0]).reshape((1, 1)),
+        np.array([1.0]).reshape((1, 1)),
+    ]
+    measures_weights = [np.array([1.0]), np.array([1.0])]
+
+    X_init = np.array([-12.0]).reshape((1, 1))
+
+    # obvious barycenter location between two Diracs
+    bar_locations = np.array([0.0]).reshape((1, 1))
+
+    def cost(x, y):
+        return ot.dist(x, y)
+
+    cost_list = [cost, cost]
+
+    def B(y):
+        out = 0
+        for yk in y:
+            out += yk / len(y)
+        return out
+
+    X = ot.lp.free_support_barycenter_generic_costs(
+        measures_locations, measures_weights, X_init, cost_list, B
+    )
+
+    np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
+
+    # test with log and specific weights
+    X2, log = ot.lp.free_support_barycenter_generic_costs(
+        measures_locations,
+        measures_weights,
+        X_init,
+        cost_list,
+        B,
+        a=ot.unif(1),
+        log=True,
+    )
+
+    assert "X_list" in log
+    assert "exit_status" in log
+    assert "dX_list" in log
+
+    np.testing.assert_allclose(X, X2, rtol=1e-5, atol=1e-7)
+
+    # test with one iteration for Max Iterations Reached
+    X3, log2 = ot.lp.free_support_barycenter_generic_costs(
+        measures_locations,
+        measures_weights,
+        X_init,
+        cost_list,
+        B,
+        numItermax=1,
+        log=True,
+    )
+    assert log2["exit_status"] == "Max iterations reached"
+
+
+def test_free_support_barycenter_generic_costs_backends(nx):
+    measures_locations = [
+        np.array([-1.0]).reshape((1, 1)),
+        np.array([1.0]).reshape((1, 1)),
+    ]
+    measures_weights = [np.array([1.0]), np.array([1.0])]
+    X_init = np.array([-12.0]).reshape((1, 1))
+
+    def cost(x, y):
+        return ot.dist(x, y)
+
+    cost_list = [cost, cost]
+
+    def B(y):
+        out = 0
+        for yk in y:
+            out += yk / len(y)
+        return out
+
+    X = ot.lp.free_support_barycenter_generic_costs(
+        measures_locations, measures_weights, X_init, cost_list, B
+    )
+
+    measures_locations2 = nx.from_numpy(*measures_locations)
+    measures_weights2 = nx.from_numpy(*measures_weights)
+    X_init2 = nx.from_numpy(X_init)
+
+    X2 = ot.lp.free_support_barycenter_generic_costs(
+        measures_locations2, measures_weights2, X_init2, cost_list, B
+    )
+
+    np.testing.assert_allclose(X, nx.to_numpy(X2))
+
+
 @pytest.mark.skipif(not ot.lp._barycenter_solvers.cvxopt, reason="No cvxopt available")
 def test_lp_barycenter_cvxopt():
     a1 = np.array([1.0, 0, 0])[:, None]

From 37b9c80cad43f3b71768a265a4c57ef57734e06c Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Tue, 21 Jan 2025 15:46:20 +0100
Subject: [PATCH 17/23] python 3.13 in test workflow + added ggmot barycenter
 (WIP)

---
 .github/workflows/build_tests.yml |   2 +-
 ot/gmm.py                         | 114 +++++++++++++++++++++++++++++-
 2 files changed, 114 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index 4356daa2b..52b4e1d99 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -47,7 +47,7 @@ jobs:
     strategy:
       max-parallel: 4
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.9", "3.10", "3.11", "3.12, "3.13"]
 
     steps:
     - uses: actions/checkout@v4
diff --git a/ot/gmm.py b/ot/gmm.py
index d99d4e5db..bf4e700d3 100644
--- a/ot/gmm.py
+++ b/ot/gmm.py
@@ -13,7 +13,7 @@
 from .lp import emd2, emd
 import numpy as np
 from .utils import dist
-from .gaussian import bures_wasserstein_mapping
+from .gaussian import bures_wasserstein_mapping, bures_wasserstein_barycenter
 
 
 def gaussian_logpdf(x, m, C):
@@ -440,3 +440,115 @@ def Tk0k1(k0, k1):
         ]
     )
     return nx.sum(mat, axis=(0, 1))
+
+
+def solve_gmm_barycenter_fixed_point(
+    means,
+    covs,
+    means_list,
+    covs_list,
+    b_list,
+    weights,
+    max_its=300,
+    log=False,
+    barycentric_proj_method="euclidean",
+):
+    r"""
+    Solves the GMM OT barycenter problem using the fixed point algorithm.
+
+    Parameters
+    ----------
+    means : array-like
+        Initial (n, d) GMM means.
+    covs : array-like
+        Initial (n, d, d) GMM covariances.
+    means_list : list of array-like
+        List of K (m_k, d) GMM means.
+    covs_list : list of array-like
+        List of K (m_k, d, d) GMM covariances.
+    b_list : list of array-like
+        List of K (m_k) arrays of weights.
+    weights : array-like
+        Array (K,) of the barycentre coefficients.
+    max_its : int, optional
+        Maximum number of iterations (default is 300).
+    log : bool, optional
+        Whether to return the list of iterations (default is False).
+    barycentric_proj_method : str, optional
+        Barycentric projection method: 'euclidean' (default) or 'bures'.
+
+    Returns
+    -------
+    means : array-like
+        (n, d) barycentre GMM means.
+    covs : array-like
+        (n, d, d) barycentre GMM covariances.
+    log_dict : dict, optional
+        Dictionary containing the list of iterations if log is True.
+    """
+    nx = get_backend(means, covs[0], means_list[0], covs_list[0])
+    K = len(means_list)
+    n = means.shape[0]
+    d = means.shape[1]
+    means_its = [means.copy()]
+    covs_its = [covs.copy()]
+    a = nx.ones(n, type_as=means) / n
+
+    for _ in range(max_its):
+        pi_list = [
+            gmm_ot_plan(means, means_list[k], covs, covs_list[k], a, b_list[k])
+            for k in range(K)
+        ]
+
+        means_selection, covs_selection = None, None
+        # in the euclidean case, the selection of Gaussians from each K sources
+        # comes from a  barycentric projection is a convex combination of the
+        # selected means and  covariances, which can be computed without a
+        # for loop on i
+        if barycentric_proj_method == "euclidean":
+            means_selection = nx.zeros((n, K, d), type_as=means)
+            covs_selection = nx.zeros((n, K, d, d), type_as=means)
+
+            for k in range(K):
+                means_selection[:, k, :] = n * pi_list[k] @ means_list[k]
+                covs_selection[:, k, :, :] = (
+                    nx.einsum("ij,jab->iab", pi_list[k], covs_list[k]) * n
+                )
+
+        # each component i of the barycentre will be a Bures barycentre of the
+        # selected components of the K GMMs. In the 'bures' barycentric
+        # projection option, the selected components are also Bures barycentres.
+        for i in range(n):
+            # means_slice_i (K, d) is the selected means, each comes from a
+            # Gaussian barycentre along the disintegration of pi_k at i
+            # covs_slice_i (K, d, d) are the selected covariances
+            means_selection_i = []
+            covs_selection_i = []
+
+            # use previous computation (convex combination)
+            if barycentric_proj_method == "euclidean":
+                means_selection_i = means_selection[i]
+                covs_selection_i = covs_selection[i]
+
+            # compute Bures barycentre of the selected components
+            elif barycentric_proj_method == "bures":
+                w = (1 / a[i]) * pi_list[k][i, :]
+                for k in range(K):
+                    m, C = bures_wasserstein_barycenter(means_list[k], covs_list[k], w)
+                    means_selection_i.append(m)
+                    covs_selection_i.append(C)
+
+            else:
+                raise ValueError("Unknown barycentric_proj_method")
+
+            means[i], covs[i] = bures_wasserstein_barycenter(
+                means_selection_i, covs_selection_i, weights
+            )
+
+        if log:
+            means_its.append(means.copy())
+            covs_its.append(covs.copy())
+
+    if log:
+        return means, covs, {"means_its": means_its, "covs_its": covs_its}
+    return means, covs

From a20d3f0656e0e64c0dc4b7a74e94cc9a407c9bd9 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Tue, 21 Jan 2025 16:06:43 +0100
Subject: [PATCH 18/23] fixed github action file

---
 .github/workflows/build_tests.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index 52b4e1d99..a8e27b323 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -47,7 +47,7 @@ jobs:
     strategy:
       max-parallel: 4
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12, "3.13"]
+        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
 
     steps:
     - uses: actions/checkout@v4

From 0b6217b00188f4f01bc80f5de7ba838e039cb39e Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Tue, 21 Jan 2025 17:19:56 +0100
Subject: [PATCH 19/23] ot bar doc + test coverage

---
 .github/workflows/build_tests.yml |   2 +-
 ot/gmm.py                         | 103 ++++++++++++++++++++----------
 ot/lp/_barycenter_solvers.py      |   4 +-
 test/test_gmm.py                  |  54 +++++++++++++++-
 4 files changed, 124 insertions(+), 39 deletions(-)

diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index a8e27b323..4356daa2b 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -47,7 +47,7 @@ jobs:
     strategy:
       max-parallel: 4
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
 
     steps:
     - uses: actions/checkout@v4
diff --git a/ot/gmm.py b/ot/gmm.py
index bf4e700d3..214720d1e 100644
--- a/ot/gmm.py
+++ b/ot/gmm.py
@@ -442,36 +442,50 @@ def Tk0k1(k0, k1):
     return nx.sum(mat, axis=(0, 1))
 
 
-def solve_gmm_barycenter_fixed_point(
-    means,
-    covs,
+def gmm_barycenter_fixed_point(
     means_list,
     covs_list,
-    b_list,
+    w_list,
+    means_init,
+    covs_init,
     weights,
-    max_its=300,
+    w_bar=None,
+    iterations=100,
     log=False,
     barycentric_proj_method="euclidean",
 ):
     r"""
-    Solves the GMM OT barycenter problem using the fixed point algorithm.
+    Solves the Gaussian Mixture Model OT barycenter problem (defined in [69])
+    using the fixed point algorithm (proposed in [74]). The
+    weights of the barycenter are not optimized, and stay the same as the input
+    `w_bar` or are initialized to uniform.
+
+    The algorithm uses barycentric projections of GMM-OT plans, and these can be
+    computed either through Bures Barycenters (slow but accurate,
+    barycentric_proj_method='bures') or by convex combination (fast,
+    barycentric_proj_method='euclidean', default).
+
+    This is a special case of the generic free-support barycenter solver
+    `ot.lp.free_support_barycenter_generic_costs`.
 
     Parameters
     ----------
-    means : array-like
-        Initial (n, d) GMM means.
-    covs : array-like
-        Initial (n, d, d) GMM covariances.
     means_list : list of array-like
         List of K (m_k, d) GMM means.
     covs_list : list of array-like
         List of K (m_k, d, d) GMM covariances.
-    b_list : list of array-like
+    w_list : list of array-like
         List of K (m_k) arrays of weights.
+    means_init : array-like
+        Initial (n, d) GMM means.
+    covs_init : array-like
+        Initial (n, d, d) GMM covariances.
     weights : array-like
         Array (K,) of the barycentre coefficients.
-    max_its : int, optional
-        Maximum number of iterations (default is 300).
+    w_bar : array-like, optional
+        Initial weights (n) of the barycentre GMM. If None, initialized to uniform.
+    iterations : int, optional
+        Number of iterations (default is 100).
     log : bool, optional
         Whether to return the list of iterations (default is False).
     barycentric_proj_method : str, optional
@@ -485,30 +499,46 @@ def solve_gmm_barycenter_fixed_point(
         (n, d, d) barycentre GMM covariances.
     log_dict : dict, optional
         Dictionary containing the list of iterations if log is True.
+
+    References
+    ----------
+    .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
+
+    .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024)
+
+    See Also
+    --------
+    ot.lp.free_support_barycenter_generic_costs : Compute barycenter of measures for generic transport costs.
     """
-    nx = get_backend(means, covs[0], means_list[0], covs_list[0])
+    nx = get_backend(
+        means_init, covs_init, means_list[0], covs_list[0], w_list[0], weights
+    )
     K = len(means_list)
-    n = means.shape[0]
-    d = means.shape[1]
-    means_its = [means.copy()]
-    covs_its = [covs.copy()]
-    a = nx.ones(n, type_as=means) / n
+    n = means_init.shape[0]
+    d = means_init.shape[1]
+    means_its = [nx.copy(means_init)]
+    covs_its = [nx.copy(covs_init)]
+    means, covs = means_init, covs_init
+
+    if w_bar is None:
+        w_bar = nx.ones(n, type_as=means) / n
 
-    for _ in range(max_its):
+    for _ in range(iterations):
         pi_list = [
-            gmm_ot_plan(means, means_list[k], covs, covs_list[k], a, b_list[k])
+            gmm_ot_plan(means, means_list[k], covs, covs_list[k], w_bar, w_list[k])
             for k in range(K)
         ]
 
+        # filled in the euclidean case
         means_selection, covs_selection = None, None
+
         # in the euclidean case, the selection of Gaussians from each K sources
-        # comes from a  barycentric projection is a convex combination of the
-        # selected means and  covariances, which can be computed without a
-        # for loop on i
+        # comes from a barycentric projection: it is a convex combination of the
+        # selected means and covariances, which can be computed without a
+        # for loop on i = 0, ..., n -1
         if barycentric_proj_method == "euclidean":
             means_selection = nx.zeros((n, K, d), type_as=means)
             covs_selection = nx.zeros((n, K, d, d), type_as=means)
-
             for k in range(K):
                 means_selection[:, k, :] = n * pi_list[k] @ means_list[k]
                 covs_selection[:, k, :, :] = (
@@ -519,24 +549,27 @@ def solve_gmm_barycenter_fixed_point(
         # selected components of the K GMMs. In the 'bures' barycentric
         # projection option, the selected components are also Bures barycentres.
         for i in range(n):
-            # means_slice_i (K, d) is the selected means, each comes from a
+            # means_selection_i (K, d) is the selected means, each comes from a
             # Gaussian barycentre along the disintegration of pi_k at i
-            # covs_slice_i (K, d, d) are the selected covariances
-            means_selection_i = []
-            covs_selection_i = []
+            # covs_selection_i (K, d, d) are the selected covariances
+            means_selection_i = None
+            covs_selection_i = None
 
             # use previous computation (convex combination)
             if barycentric_proj_method == "euclidean":
                 means_selection_i = means_selection[i]
                 covs_selection_i = covs_selection[i]
 
-            # compute Bures barycentre of the selected components
+            # compute Bures barycentre of certain components to get the
+            # selection at i
             elif barycentric_proj_method == "bures":
-                w = (1 / a[i]) * pi_list[k][i, :]
+                means_selection_i = nx.zeros((K, d), type_as=means)
+                covs_selection_i = nx.zeros((K, d, d), type_as=means)
                 for k in range(K):
+                    w = (1 / w_bar[i]) * pi_list[k][i, :]
                     m, C = bures_wasserstein_barycenter(means_list[k], covs_list[k], w)
-                    means_selection_i.append(m)
-                    covs_selection_i.append(C)
+                    means_selection_i[k] = m
+                    covs_selection_i[k] = C
 
             else:
                 raise ValueError("Unknown barycentric_proj_method")
@@ -546,8 +579,8 @@ def solve_gmm_barycenter_fixed_point(
             )
 
         if log:
-            means_its.append(means.copy())
-            covs_its.append(covs.copy())
+            means_its.append(nx.copy(means))
+            covs_its.append(nx.copy(covs))
 
     if log:
         return means, covs, {"means_its": means_its, "covs_its": covs_its}
diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py
index 445a996df..9589121bd 100644
--- a/ot/lp/_barycenter_solvers.py
+++ b/ot/lp/_barycenter_solvers.py
@@ -435,7 +435,7 @@ def free_support_barycenter_generic_costs(
     cost_list,
     B,
     a=None,
-    numItermax=5,
+    numItermax=100,
     stopThr=1e-5,
     log=False,
 ):
@@ -512,7 +512,7 @@ def free_support_barycenter_generic_costs(
         Array of shape (n,) representing weights of the barycenter
         measure.Defaults to uniform.
     numItermax : int, optional
-        Maximum number of iterations (default is 5).
+        Maximum number of iterations (default is 100).
     stopThr : float, optional
         If the iterations move less than this, terminate (default is 1e-5).
     log : bool, optional
diff --git a/test/test_gmm.py b/test/test_gmm.py
index 5f1a92965..629a68d57 100644
--- a/test/test_gmm.py
+++ b/test/test_gmm.py
@@ -1,6 +1,6 @@
 """Tests for module gaussian"""
 
-# Author: Eloi Tanguy <eloi.tanguy@u-paris>
+# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
 #         Remi Flamary <remi.flamary@polytehnique.edu>
 #         Julie Delon <julie.delon@math.cnrs.fr>
 #
@@ -17,6 +17,7 @@
     gmm_ot_plan,
     gmm_ot_apply_map,
     gmm_ot_plan_density,
+    gmm_barycenter_fixed_point,
 )
 
 try:
@@ -193,3 +194,54 @@ def test_gmm_ot_plan_density(nx):
 
     with pytest.raises(AssertionError):
         gmm_ot_plan_density(x[:, 1:], y, m_s, m_t, C_s, C_t, w_s, w_t)
+
+
+@pytest.skip_backend("tf")  # skips because of array assignment
+@pytest.skip_backend("jax")
+def test_gmm_barycenter_fixed_point(nx):
+    m_s, m_t, C_s, C_t, w_s, w_t = get_gmms(nx)
+    means_list = [m_s, m_t]
+    covs_list = [C_s, C_t]
+    w_list = [w_s, w_t]
+    n_iter = 3
+    n = m_s.shape[0]  # number of components of barycenter
+    means_init = m_s
+    covs_init = C_s
+    weights = nx.ones(2, type_as=m_s) / 2  # barycenter coefficients
+
+    # with euclidean barycentric projections
+    means, covs = gmm_barycenter_fixed_point(
+        means_list, covs_list, w_list, means_init, covs_init, weights, iterations=n_iter
+    )
+
+    # with bures barycentric projections and assigned weights to uniform
+    means_bures_proj, covs_bures_proj, log = gmm_barycenter_fixed_point(
+        means_list,
+        covs_list,
+        w_list,
+        means_init,
+        covs_init,
+        weights,
+        iterations=n_iter,
+        w_bar=nx.ones(n, type_as=m_s) / n,
+        barycentric_proj_method="bures",
+        log=True,
+    )
+
+    assert "means_its" in log
+    assert "covs_its" in log
+
+    assert np.allclose(means, means_bures_proj, atol=1e-6)
+    assert np.allclose(covs, covs_bures_proj, atol=1e-6)
+
+    with pytest.raises(ValueError):
+        gmm_barycenter_fixed_point(
+            means_list,
+            covs_list,
+            w_list,
+            means_init,
+            covs_init,
+            weights,
+            iterations=n_iter,
+            barycentric_proj_method="unknown",
+        )

From 21bf86b944f2ce6cb71f381718c50095ca485850 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Tue, 21 Jan 2025 17:52:11 +0100
Subject: [PATCH 20/23] examples: ot bar with projections onto circles + gmm ot
 bar

---
 README.md                                     |   4 +-
 ...t_free_support_barycenter_generic_cost.py} |   8 +-
 examples/barycenters/plot_gmm_barycenter.py   | 144 ++++++++++++++++++
 3 files changed, 149 insertions(+), 7 deletions(-)
 rename examples/barycenters/{plot_barycenter_generic_cost.py => plot_free_support_barycenter_generic_cost.py} (96%)
 create mode 100644 examples/barycenters/plot_gmm_barycenter.py

diff --git a/README.md b/README.md
index 9a8e5b371..9266c99c6 100644
--- a/README.md
+++ b/README.md
@@ -392,6 +392,4 @@ Artificial Intelligence.
 
 [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
 
-[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing
-Barycentres of Measures for Generic Transport
-Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
+[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
diff --git a/examples/barycenters/plot_barycenter_generic_cost.py b/examples/barycenters/plot_free_support_barycenter_generic_cost.py
similarity index 96%
rename from examples/barycenters/plot_barycenter_generic_cost.py
rename to examples/barycenters/plot_free_support_barycenter_generic_cost.py
index e5e5af73a..55a75b157 100644
--- a/examples/barycenters/plot_barycenter_generic_cost.py
+++ b/examples/barycenters/plot_free_support_barycenter_generic_cost.py
@@ -4,8 +4,8 @@
 OT Barycenter with Generic Costs Demo
 =====================================
 
-This example illustrates the computation of an Optimal Transport for a ground
-cost that is not a power of a norm. We take the example of ground costs
+This example illustrates the computation of an Optimal Transport Barycenter for
+a ground cost that is not a power of a norm. We take the example of ground costs
 :math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear)
 projection onto a circle k. This is an example of the fixed-point barycenter
 solver introduced in [74] which generalises [20] and [43].
@@ -15,8 +15,8 @@
 :math:`x` with Pytorch.
 
 [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
-Barycentres of Measures for Generic Transport Costs.
-arXiv preprint 2501.04016 (2024)
+Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016
+(2024)
 
 [20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein
 Barycenters. InternationalConference in Machine Learning
diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py
new file mode 100644
index 000000000..07792c0dd
--- /dev/null
+++ b/examples/barycenters/plot_gmm_barycenter.py
@@ -0,0 +1,144 @@
+# -*- coding: utf-8 -*-
+"""
+=====================================
+Gaussian Mixture Model OT Barycenters
+=====================================
+
+This example illustrates the computation of a barycenter between Gaussian
+Mixtures in the sense of GMM-OT [69]. This computation is done using the
+fixed-point method for OT barycenters with generic costs [74], for which POT
+provides a general solver, and a specific GMM solver. Note that this is a
+'free-support' method, implying that the number of components of the barycenter
+GMM and their weights are fixed.
+
+The idea behind GMM-OT barycenters is to see the GMMs as discrete measures over
+the space of Gaussian distributions :math:`\mathcal{N}` (or equivalently the
+Bures-Wasserstein manifold), and to compute barycenters with respect to the
+2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a
+gaussian mixture is a finite combination of Diracs on specific gaussians, and
+two mixtures are compared with the 2-Wasserstein distance on this space with
+ground cost the squared Bures distance between gaussians.
+
+[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space
+of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
+
+[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
+Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016
+(2024)
+
+"""
+
+# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 1
+
+# %%
+# Generate data
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.patches import Ellipse
+import ot
+from ot.gmm import gmm_barycenter_fixed_point
+
+
+K = 3  # number of GMMs
+d = 2  # dimension
+n = 6  # number of components of the desired barycenter
+
+
+def get_random_gmm(K, d, seed=0, min_cov_eig=1, cov_scale=1e-2):
+    rng = np.random.RandomState(seed=seed)
+    means = rng.randn(K, d)
+    P = rng.randn(K, d, d) * cov_scale
+    # C[k] = P[k] @ P[k]^T + min_cov_eig * I
+    covariances = np.einsum("kab,kcb->kac", P, P)
+    covariances += min_cov_eig * np.array([np.eye(d) for _ in range(K)])
+    weights = rng.random(K)
+    weights /= np.sum(weights)
+    return means, covariances, weights
+
+
+m_list = [5, 6, 7]  # number of components in each GMM
+offsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])]
+means_list = []  # list of means for each GMM
+covs_list = []  # list of covariances for each GMM
+w_list = []  # list of weights for each GMM
+
+# generate GMMs
+for k in range(K):
+    means, covs, b = get_random_gmm(
+        m_list[k], d, seed=k, min_cov_eig=0.25, cov_scale=0.5
+    )
+    means = means / 2 + offsets[k][None, :]
+    means_list.append(means)
+    covs_list.append(covs)
+    w_list.append(b)
+
+# %%
+# Compute the barycenter using the fixed-point method
+init_means, init_covs, _ = get_random_gmm(n, d, seed=0)
+weights = ot.unif(K)  # barycenter coefficients
+means_bar, covs_bar, log = gmm_barycenter_fixed_point(
+    means_list,
+    covs_list,
+    w_list,
+    init_means,
+    init_covs,
+    weights,
+    iterations=3,
+    log=True,
+)
+
+
+# %%
+# Define plotting functions
+
+
+# draw a covariance ellipse
+def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None):
+    def eigsorted(cov):
+        vals, vecs = np.linalg.eigh(cov)
+        order = vals.argsort()[::-1].copy()
+        return vals[order], vecs[:, order]
+
+    vals, vecs = eigsorted(C)
+    theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
+    w, h = 2 * nstd * np.sqrt(vals)
+    ell = Ellipse(
+        xy=(mu[0], mu[1]),
+        width=w,
+        height=h,
+        alpha=alpha,
+        angle=theta,
+        facecolor=color,
+        edgecolor=color,
+        label=label,
+        fill=True,
+    )
+    if ax is None:
+        ax = plt.gca()
+    ax.add_artist(ell)
+
+
+# draw a gmm as a set of ellipses with weights shown in alpha value
+def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None):
+    for k in range(ms.shape[0]):
+        draw_cov(
+            ms[k], Cs[k], color, label if k == 0 else None, nstd, alpha * ws[k], ax=ax
+        )
+
+
+# %%
+# Plot the results
+fig, ax = plt.subplots(figsize=(6, 6))
+axis = [-4, 4, -2, 6]
+ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16)
+for k in range(K):
+    draw_gmm(means_list[k], covs_list[k], w_list[k], color="C0", ax=ax)
+draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax)
+ax.axis(axis)
+ax.axis("off")
+
+# %%

From 0820e513e3415a1aa03abb6cd6a9acb27a7096d9 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Tue, 21 Jan 2025 18:03:59 +0100
Subject: [PATCH 21/23] releases + readme + docs update

---
 README.md                                   |  2 ++
 RELEASES.md                                 |  3 ++-
 examples/barycenters/plot_gmm_barycenter.py |  2 +-
 ot/lp/_barycenter_solvers.py                | 27 ++++++++++++---------
 4 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 9266c99c6..48a4a87fe 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,8 @@ POT provides the following generic OT solvers (links to examples):
 * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and
 [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71].
 * Fused unbalanced Gromov-Wasserstein [70].
+* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [74]
+* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 74]
 
 POT provides the following Machine Learning related solvers:
 
diff --git a/RELEASES.md b/RELEASES.md
index ff8496bef..add09378c 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -8,7 +8,8 @@
 - Automatic PR labeling and release file update check (PR #704)
 - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
 - Implement fixed-point solver for OT barycenters with generic cost functions
-  (generalizes `ot.lp.free_support_barycenter`). (PR #715)
+  (generalizes `ot.lp.free_support_barycenter`), with example. (PR #715)
+- Implement fixed-point solver for barycenters between GMMs (PR #715), with example.
 
 #### Closed issues
 - Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py
index 07792c0dd..84d0ee638 100644
--- a/examples/barycenters/plot_gmm_barycenter.py
+++ b/examples/barycenters/plot_gmm_barycenter.py
@@ -16,7 +16,7 @@
 Bures-Wasserstein manifold), and to compute barycenters with respect to the
 2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a
 gaussian mixture is a finite combination of Diracs on specific gaussians, and
-two mixtures are compared with the 2-Wasserstein distance on this space with
+two mixtures are compared with the 2-Wasserstein distance on this space, using as
 ground cost the squared Bures distance between gaussians.
 
 [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space
diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py
index 9589121bd..5e53c66d2 100644
--- a/ot/lp/_barycenter_solvers.py
+++ b/ot/lp/_barycenter_solvers.py
@@ -458,7 +458,9 @@ def free_support_barycenter_generic_costs(
     - :math:`Y_k` (m_k, d_k) is the k-th measure support
       (`measure_locations[k]`),
     - :math:`b_k` (m_k) is the k-th measure weights (`measure_weights[k]`),
-    - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function (which computes the pairwise cost matrix)
+    - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k}
+         \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function
+         (which computes the pairwise cost matrix)
     - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b)` is the OT cost between the barycenter measure and the k-th measure with respect to the cost :math:`c_k`:
 
     .. math::
@@ -475,18 +477,19 @@ def free_support_barycenter_generic_costs(
 
     The algorithm requires a given ground barycenter function `B` which computes
     (broadcasted of `n`) solutions of the following minimisation problem given
-    :math:`(Y_1, \cdots, Y_K) \in
-    \mathbb{R}^{n\times d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`:
+    :math:`(Y_1, \cdots, Y_K) \in \mathbb{R}^{n\times
+    d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`:
 
     .. math::
         B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k),
 
     where :math:`c_k(x, y_k) \in \mathbb{R}_+` is the cost between the points
-    :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{d_1}\times
-    \cdots\times\mathbb{R}^{d_K} \longrightarrow \mathbb{R}^d` is an input to
-    this function, and for certain costs it can be computed explicitly of
-    through a numerical solver. The input function B takes a list of K arrays of
-    shape (n, d_k) and returns an array of shape (n, d).
+    :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{n\times
+    d_1}\times \cdots\times\mathbb{R}^{n\times d_K} \longrightarrow
+    \mathbb{R}^{n\times d}` is an input to this function, and for certain costs
+    it can be computed explicitly or through a numerical solver. The input
+    function B takes a list of K arrays of shape (n, d_k) and returns an array
+    of shape (n, d).
 
     This function implements [74] Algorithm 2, which generalises [20] and [43]
     to general costs and includes convergence guarantees, including for discrete
@@ -526,8 +529,6 @@ def free_support_barycenter_generic_costs(
         log containing the exit status, list of iterations and list of
         displacements if log is True.
 
-    .. _references-free-support-barycenter-generic-costs:
-
     References
     ----------
     .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
@@ -543,8 +544,10 @@ def free_support_barycenter_generic_costs(
 
     See Also
     --------
-    ot.lp.free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|x-y\|_2^2`.
-    ot.lp.generalized_free_support_barycenter : Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear.
+    ot.lp.free_support_barycenter : Free support solver for the case where
+    :math:`c_k(x,y) = \|x-y\|_2^2`. ot.lp.generalized_free_support_barycenter :
+    Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2`
+    with :math:`P_k` linear.
     """
     nx = get_backend(X_init, measure_locations[0])
     K = len(measure_locations)

From 6bd4af8b9c280798c2d5d8b617d611340589fdc7 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Wed, 12 Mar 2025 15:15:45 +0100
Subject: [PATCH 22/23] ref fix

---
 README.md                                                   | 4 ++--
 .../plot_free_support_barycenter_generic_cost.py            | 4 ++--
 examples/barycenters/plot_gmm_barycenter.py                 | 6 ++----
 ot/gmm.py                                                   | 4 ++--
 ot/lp/_barycenter_solvers.py                                | 4 ++--
 5 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 124c5d809..a7f1ff830 100644
--- a/README.md
+++ b/README.md
@@ -54,8 +54,8 @@ POT provides the following generic OT solvers (links to examples):
 * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and
 [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71].
 * Fused unbalanced Gromov-Wasserstein [70].
-* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [74]
-* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 74]
+* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [76]
+* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 76]
 
 POT provides the following Machine Learning related solvers:
 
diff --git a/examples/barycenters/plot_free_support_barycenter_generic_cost.py b/examples/barycenters/plot_free_support_barycenter_generic_cost.py
index 55a75b157..47e2c9236 100644
--- a/examples/barycenters/plot_free_support_barycenter_generic_cost.py
+++ b/examples/barycenters/plot_free_support_barycenter_generic_cost.py
@@ -8,13 +8,13 @@
 a ground cost that is not a power of a norm. We take the example of ground costs
 :math:`c_k(x, y) = \|P_k(x)-y\|_2^2`, where :math:`P_k` is the (non-linear)
 projection onto a circle k. This is an example of the fixed-point barycenter
-solver introduced in [74] which generalises [20] and [43].
+solver introduced in [76] which generalises [20] and [43].
 
 The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in
 \mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over
 :math:`x` with Pytorch.
 
-[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
+[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
 Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016
 (2024)
 
diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py
index 84d0ee638..f379a9914 100644
--- a/examples/barycenters/plot_gmm_barycenter.py
+++ b/examples/barycenters/plot_gmm_barycenter.py
@@ -6,7 +6,7 @@
 
 This example illustrates the computation of a barycenter between Gaussian
 Mixtures in the sense of GMM-OT [69]. This computation is done using the
-fixed-point method for OT barycenters with generic costs [74], for which POT
+fixed-point method for OT barycenters with generic costs [76], for which POT
 provides a general solver, and a specific GMM solver. Note that this is a
 'free-support' method, implying that the number of components of the barycenter
 GMM and their weights are fixed.
@@ -22,7 +22,7 @@
 [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space
 of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
 
-[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
+[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
 Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016
 (2024)
 
@@ -140,5 +140,3 @@ def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None):
 draw_gmm(means_bar, covs_bar, ot.unif(n), color="C1", ax=ax)
 ax.axis(axis)
 ax.axis("off")
-
-# %%
diff --git a/ot/gmm.py b/ot/gmm.py
index 214720d1e..a065c73b0 100644
--- a/ot/gmm.py
+++ b/ot/gmm.py
@@ -456,7 +456,7 @@ def gmm_barycenter_fixed_point(
 ):
     r"""
     Solves the Gaussian Mixture Model OT barycenter problem (defined in [69])
-    using the fixed point algorithm (proposed in [74]). The
+    using the fixed point algorithm (proposed in [76]). The
     weights of the barycenter are not optimized, and stay the same as the input
     `w_list` or are initialized to uniform.
 
@@ -504,7 +504,7 @@ def gmm_barycenter_fixed_point(
     ----------
     .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
 
-    .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024)
+    .. [76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024)
 
     See Also
     --------
diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py
index 61b4fce49..f803d23db 100644
--- a/ot/lp/_barycenter_solvers.py
+++ b/ot/lp/_barycenter_solvers.py
@@ -495,7 +495,7 @@ def free_support_barycenter_generic_costs(
     function B takes a list of K arrays of shape (n, d_k) and returns an array
     of shape (n, d).
 
-    This function implements [74] Algorithm 2, which generalises [20] and [43]
+    This function implements [76] Algorithm 2, which generalises [20] and [43]
     to general costs and includes convergence guarantees, including for discrete
     measures.
 
@@ -535,7 +535,7 @@ def free_support_barycenter_generic_costs(
 
     References
     ----------
-    .. [74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
+    .. [76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
         barycenters of Measures for Generic Transport Costs. arXiv preprint
         2501.04016 (2024)
 

From 51722bf65f1be26a453f5602f07d1ecb4752c896 Mon Sep 17 00:00:00 2001
From: eloitanguy <tanguy.eloi@gmail.com>
Date: Mon, 17 Mar 2025 19:54:14 +0100
Subject: [PATCH 23/23] implementation comments

---
 ot/lp/_barycenter_solvers.py | 133 +++++++++++++++++++++++------------
 test/test_ot.py              |  99 +++++++++++++++++++++++---
 2 files changed, 178 insertions(+), 54 deletions(-)

diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py
index f803d23db..725af26c4 100644
--- a/ot/lp/_barycenter_solvers.py
+++ b/ot/lp/_barycenter_solvers.py
@@ -199,14 +199,12 @@ def free_support_barycenter(
     measures_weights : list of N (k_i,) array-like
         Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
         representing the weights of each discrete input measure
-
     X_init : (k,d) array-like
         Initialization of the support locations (on `k` atoms) of the barycenter
     b : (k,) array-like
         Initialization of the weights of the barycenter (non-negatives, sum to 1)
     weights : (N,) array-like
         Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
-
     numItermax : int, optional
         Max number of iterations
     stopThr : float, optional
@@ -219,13 +217,11 @@ def free_support_barycenter(
         If compiled with OpenMP, chooses the number of threads to parallelize.
         "max" selects the highest number possible.
 
-
     Returns
     -------
     X : (k,d) array-like
         Support locations (on k atoms) of the barycenter
 
-
     .. _references-free-support-barycenter:
     References
     ----------
@@ -428,20 +424,20 @@ def generalized_free_support_barycenter(
         return Y
 
 
-class StoppingCriterionReached(Exception):
-    pass
-
-
 def free_support_barycenter_generic_costs(
     measure_locations,
     measure_weights,
     X_init,
     cost_list,
-    B,
+    ground_bary=None,
     a=None,
     numItermax=100,
     stopThr=1e-5,
     log=False,
+    ground_bary_lr=1e-2,
+    ground_bary_numItermax=100,
+    ground_bary_stopThr=1e-5,
+    ground_bary_solver="SGD",
 ):
     r"""
     Solves the OT barycenter problem for generic costs using the fixed point
@@ -507,14 +503,15 @@ def free_support_barycenter_generic_costs(
         List of K arrays of measure weights, each of shape (m_k).
     X_init : array-like
         Array of shape (n, d) representing initial barycenter points.
-    cost_list : list of callable
+    cost_list : list of callable or callable
         List of K cost functions :math:`c_k: \mathbb{R}^{n\times
         d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times
-        m_k}`.
-    B : callable
+        m_k}`. If cost_list is a single callable, the same cost is used K times.
+    ground_bary : callable or None, optional
         Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays
         of shape (n\times d_K), computing the ground barycenters (broadcasted
-        over n).
+        over n). If not provided, it is computed by gradient descent with a
+        PyTorch optimiser (see ground_bary_solver); requires the PyTorch backend.
     a : array-like, optional
         Array of shape (n,) representing weights of the barycenter
         measure.Defaults to uniform.
@@ -524,6 +521,16 @@ def free_support_barycenter_generic_costs(
         If the iterations move less than this, terminate (default is 1e-5).
     log : bool, optional
         Whether to return the log dictionary (default is False).
+    ground_bary_lr : float, optional
+        Learning rate of the automatic ground barycenter solver (ground_bary=None).
+    ground_bary_numItermax : int, optional
+        Maximum number of iterations of the automatic ground barycenter solver
+        (ground_bary=None).
+    ground_bary_stopThr : float, optional
+        Stop threshold on the squared iterate displacement (ground_bary=None).
+    ground_bary_solver : str, optional
+        Optimiser for the automatic ground barycenter solver: "Adam" selects
+        torch Adam, any other value selects torch SGD. Default is "SGD".
 
     Returns
     -------
@@ -549,49 +556,85 @@ def free_support_barycenter_generic_costs(
     See Also
     --------
     ot.lp.free_support_barycenter : Free support solver for the case where
-    :math:`c_k(x,y) = \|x-y\|_2^2`. ot.lp.generalized_free_support_barycenter :
-    Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2`
-    with :math:`P_k` linear.
+    :math:`c_k(x,y) = \lambda_k\|x-y\|_2^2`.
+    ot.lp.generalized_free_support_barycenter : Free support solver for the case
+    where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear.
     """
     nx = get_backend(X_init, measure_locations[0])
     K = len(measure_locations)
     n = X_init.shape[0]
     if a is None:
         a = nx.ones(n, type_as=X_init) / n
+    if callable(cost_list):  # use the given cost for all K pairs
+        cost_list = [cost_list] * K
+    auto_ground_bary = False
+
+    if ground_bary is None:
+        auto_ground_bary = True
+        assert str(nx) == "torch", (
+            f"Backend {str(nx)} is not compatible with ground_bary=None, it "
+            "must be provided if not using PyTorch backend"
+        )
+        try:
+            import torch
+            from torch.optim import Adam, SGD
+
+            def ground_bary(y, x_init):
+                x = x_init.clone().detach().requires_grad_(True)
+                solver = Adam if ground_bary_solver == "Adam" else SGD
+                opt = solver([x], lr=ground_bary_lr)
+                for _ in range(ground_bary_numItermax):
+                    x_prev = x.data.clone()
+                    opt.zero_grad()
+                    # inefficient cost computation but compatible
+                    # with the choice of cost_list[k] giving the cost matrix
+                    loss = torch.sum(
+                        torch.stack(
+                            [torch.diag(cost_list[k](x, y[k])) for k in range(K)]
+                        )
+                    )
+                    loss.backward()
+                    opt.step()
+                    diff = torch.sum((x.data - x_prev) ** 2)
+                    if diff < ground_bary_stopThr:
+                        break
+                return x.detach()
+
+        except ImportError:
+            raise ImportError("PyTorch is required to use ground_bary=None")
+
     X_list = [X_init] if log else []  # store the iterations
     X = X_init
     dX_list = []  # store the displacement squared norms
-    exit_status = "Unknown"
-
-    try:
-        for _ in range(numItermax):
-            pi_list = [  # compute the pairwise transport plans
-                emd(a, measure_weights[k], cost_list[k](X, measure_locations[k]))
-                for k in range(K)
-            ]
-            Y_perm = []
-            for k in range(K):  # compute barycentric projections
-                Y_perm.append(n * pi_list[k] @ measure_locations[k])
-            X_next = B(Y_perm)
-
-            if log:
-                X_list.append(X_next)
+    exit_status = "Max iterations reached"
+
+    for _ in range(numItermax):
+        pi_list = [  # compute the pairwise transport plans
+            emd(a, measure_weights[k], cost_list[k](X, measure_locations[k]))
+            for k in range(K)
+        ]
+        Y_perm = []
+        for k in range(K):  # compute barycentric projections
+            Y_perm.append(n * pi_list[k] @ measure_locations[k])
+        if auto_ground_bary:  # use previous position as initialization
+            X_next = ground_bary(Y_perm, X)
+        else:
+            X_next = ground_bary(Y_perm)
 
-            # stationary criterion: move less than the threshold
-            dX = nx.sum((X - X_next) ** 2)
-            X = X_next
+        if log:
+            X_list.append(X_next)
 
-            if log:
-                dX_list.append(dX)
+        # stationary criterion: move less than the threshold
+        dX = nx.sum((X - X_next) ** 2)
+        X = X_next
 
-            if dX < stopThr:
-                exit_status = "Stationary Point"
-                raise StoppingCriterionReached
+        if log:
+            dX_list.append(dX)
 
-        exit_status = "Max iterations reached"
-        raise StoppingCriterionReached
+        if dX < stopThr:
+            exit_status = "Stationary Point"
+            break
 
-    except StoppingCriterionReached:
-        if log:
-            return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list}
-        return X
+    if log:
+        return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list}
+    return X
diff --git a/test/test_ot.py b/test/test_ot.py
index 4916d71aa..22612fa4a 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -13,8 +13,6 @@
 from ot.datasets import make_1D_gauss as gauss
 from ot.backend import torch, tf
 
-# import ot.lp._barycenter_solvers  # TODO: remove this import
-
 
 def test_emd_dimension_and_mass_mismatch():
     # test emd and emd2 for dimension mismatch
@@ -414,14 +412,14 @@ def cost(x, y):
 
     cost_list = [cost, cost]
 
-    def B(y):
+    def ground_bary(y):
         out = 0
         for yk in y:
             out += yk / len(y)
         return out
 
     X = ot.lp.free_support_barycenter_generic_costs(
-        measures_locations, measures_weights, X_init, cost_list, B
+        measures_locations, measures_weights, X_init, cost_list, ground_bary
     )
 
     np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
@@ -432,7 +430,7 @@ def B(y):
         measures_weights,
         X_init,
         cost_list,
-        B,
+        ground_bary,
         a=ot.unif(1),
         log=True,
     )
@@ -449,12 +447,95 @@ def B(y):
         measures_weights,
         X_init,
         cost_list,
-        B,
+        ground_bary,
         numItermax=1,
         log=True,
     )
     assert log2["exit_status"] == "Max iterations reached"
 
+    # test with a single callable cost
+    X3, log3 = ot.lp.free_support_barycenter_generic_costs(
+        measures_locations,
+        measures_weights,
+        X_init,
+        cost,
+        ground_bary,
+        numItermax=1,
+        log=True,
+    )
+
+    # test with no ground_bary but in numpy: requires pytorch backend
+    with pytest.raises(AssertionError):
+        ot.lp.free_support_barycenter_generic_costs(
+            measures_locations,
+            measures_weights,
+            X_init,
+            cost_list,
+            ground_bary=None,
+            numItermax=1,
+        )
+
+
+@pytest.mark.skipif(not torch, reason="No torch available")
+def test_free_support_barycenter_generic_costs_auto_ground_bary():
+    measures_locations = [
+        torch.tensor([1.0]).reshape((1, 1)),
+        torch.tensor([2.0]).reshape((1, 1)),
+    ]
+    measures_weights = [torch.tensor([1.0]), torch.tensor([1.0])]
+
+    X_init = torch.tensor([1.2]).reshape((1, 1))
+
+    def cost(x, y):
+        return ot.dist(x, y)
+
+    cost_list = [cost, cost]
+
+    def ground_bary(y):
+        out = 0
+        for yk in y:
+            out += yk / len(y)
+        return out
+
+    X = ot.lp.free_support_barycenter_generic_costs(
+        measures_locations,
+        measures_weights,
+        X_init,
+        cost_list,
+        ground_bary,
+        numItermax=1,
+    )
+
+    X2, log2 = ot.lp.free_support_barycenter_generic_costs(
+        measures_locations,
+        measures_weights,
+        X_init,
+        cost_list,
+        ground_bary=None,
+        ground_bary_lr=1e-2,
+        ground_bary_stopThr=1e-20,
+        ground_bary_numItermax=50,
+        numItermax=10,
+        log=True,
+    )
+
+    np.testing.assert_allclose(X2.numpy(), X.numpy(), rtol=1e-4, atol=1e-4)
+
+    X3 = ot.lp.free_support_barycenter_generic_costs(
+        measures_locations,
+        measures_weights,
+        X_init,
+        cost_list,
+        ground_bary=None,
+        ground_bary_lr=1e-2,
+        ground_bary_stopThr=1e-20,
+        ground_bary_numItermax=50,
+        numItermax=10,
+        ground_bary_solver="Adam",
+    )
+
+    np.testing.assert_allclose(X2.numpy(), X3.numpy(), rtol=1e-3, atol=1e-3)
+
 
 def test_free_support_barycenter_generic_costs_backends(nx):
     measures_locations = [
@@ -469,14 +550,14 @@ def cost(x, y):
 
     cost_list = [cost, cost]
 
-    def B(y):
+    def ground_bary(y):
         out = 0
         for yk in y:
             out += yk / len(y)
         return out
 
     X = ot.lp.free_support_barycenter_generic_costs(
-        measures_locations, measures_weights, X_init, cost_list, B
+        measures_locations, measures_weights, X_init, cost_list, ground_bary
     )
 
     measures_locations2 = nx.from_numpy(*measures_locations)
@@ -484,7 +565,7 @@ def B(y):
     X_init2 = nx.from_numpy(X_init)
 
     X2 = ot.lp.free_support_barycenter_generic_costs(
-        measures_locations2, measures_weights2, X_init2, cost_list, B
+        measures_locations2, measures_weights2, X_init2, cost_list, ground_bary
     )
 
     np.testing.assert_allclose(X, nx.to_numpy(X2))