diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 652d2fed3..b105c6c81 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -215,8 +215,7 @@ def __call__( key = jax.random.PRNGKey(0) key1, key2 = jax.random.split(key, 2) - # consider converting problem first if using low-rank solver - if self.is_low_rank and prob._is_low_rank_convertible: + if prob._is_low_rank_convertible: prob = prob.to_low_rank() if init is None: