Replace lap.lapjv() with scipy.optimize.linear_sum_assignment() (#3267)
* Replace lap.lapjv() with scipy.optimize.linear_sum_assignment()

* Fix silently failing test
fritzo committed Sep 20, 2023
1 parent 99633ae commit 3b6cae3
Showing 7 changed files with 12 additions and 28 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -7,4 +7,4 @@ opt_einsum>=2.3.2
 pyro-api>=0.1.1
 tqdm>=4.36
 funsor[torch]
-setuptools<60
+setuptools
10 changes: 3 additions & 7 deletions pyro/distributions/one_one_matching.py
@@ -160,19 +160,15 @@ def sample(self, sample_shape=torch.Size()):
     def mode(self):
         """
         Computes a maximum probability matching.
-        .. note:: This requires the `lap <https://pypi.org/project/lap/>`_
-            package and runs on CPU.
         """
         return maximum_weight_matching(self.logits)
 
 
 @torch.no_grad()
 def maximum_weight_matching(logits):
-    with warnings.catch_warnings():
-        warnings.filterwarnings("ignore", category=ImportWarning)
-        import lap
+    from scipy.optimize import linear_sum_assignment
 
     cost = -logits.cpu()
-    value = lap.lapjv(cost.numpy())[1]
+    value = linear_sum_assignment(cost.numpy())[1]
     value = torch.tensor(value, dtype=torch.long, device=logits.device)
     return value
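
Note on the API change: `lap.lapjv(cost)` returns a `(cost, x, y)` tuple whose second element `x` is the column assigned to each row, while `scipy.optimize.linear_sum_assignment(cost)` returns `(row_ind, col_ind)`; for a square cost matrix `row_ind` is simply `0..n-1`, so `col_ind` is the same row-to-column assignment and the `[1]` index carries over unchanged. A minimal standalone sketch (not part of this commit; the example matrix is made up):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

logits = np.array([[4.0, 1.0, 3.0],
                   [2.0, 0.0, 5.0],
                   [3.0, 2.0, 2.0]])
cost = -logits                                 # maximizing logits == minimizing cost
row_ind, col_ind = linear_sum_assignment(cost)
assert (row_ind == np.arange(3)).all()         # rows come back in order for a square matrix
assignment = col_ind                           # column matched to each row, like lapjv's x
print(assignment, logits[row_ind, col_ind].sum())  # [0 2 1] with total weight 11.0
```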
10 changes: 3 additions & 7 deletions pyro/distributions/one_two_matching.py
@@ -169,9 +169,6 @@ def sample(self, sample_shape=torch.Size()):
     def mode(self):
         """
         Computes a maximum probability matching.
-        .. note:: This requires the `lap <https://pypi.org/project/lap/>`_
-            package and runs on CPU.
         """
         return maximum_weight_matching(self.logits)
 
@@ -204,12 +201,11 @@ def enumerate_one_two_matchings(num_destins):
 
 @torch.no_grad()
 def maximum_weight_matching(logits):
-    with warnings.catch_warnings():
-        warnings.filterwarnings("ignore", category=ImportWarning)
-        import lap
+    from scipy.optimize import linear_sum_assignment
 
     cost = -logits.cpu()
     cost = torch.cat([cost, cost], dim=-1)  # Duplicate destinations.
-    value = lap.lapjv(cost.numpy())[1]
+    value = linear_sum_assignment(cost.numpy())[1]
     value = torch.tensor(value, dtype=torch.long, device=logits.device)
     value %= logits.size(1)
     return value
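
The one-two variant reduces to an ordinary square assignment by duplicating the destination columns: each destination appears twice, so it can absorb up to two sources, and the final `value %= logits.size(1)` folds the duplicated column indices back onto the real destinations. A small sketch of that trick with made-up sizes (not code from the repository):

```python
import torch
from scipy.optimize import linear_sum_assignment

num_sources, num_destins = 4, 2
logits = torch.randn(num_sources, num_destins)
cost = -logits                                # maximizing logits == minimizing cost
cost = torch.cat([cost, cost], dim=-1)        # 4 x 4: destination columns duplicated
col = linear_sum_assignment(cost.numpy())[1]  # square assignment over duplicated columns
value = torch.as_tensor(col) % num_destins    # fold duplicates back onto real destinations
assert (torch.bincount(value, minlength=num_destins) == 2).all()
print(value)                                  # e.g. tensor([0, 1, 1, 0])
```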
4 changes: 1 addition & 3 deletions setup.py
@@ -75,7 +75,7 @@
     "scikit-learn",
     "seaborn>=0.11.0",
     "wget",
-    "lap",  # Requires setuptools<60
+    "scipy>=1.1",
     # 'biopython>=1.54',
     # 'scanpy>=1.4',  # Requires HDF5
     # 'scvi>=0.6',  # Requires loopy and other fragile packages
@@ -115,7 +115,6 @@
             "pytest-xdist",
             "pytest>=5.0",
             "ruff",
-            "scipy>=1.1",
         ],
         "profile": ["prettytable", "pytest-benchmark", "snakeviz"],
         "dev": EXTRAS_REQUIRE
@@ -131,7 +130,6 @@
             "pytest-xdist",
             "pytest>=5.0",
             "ruff",
-            "scipy>=1.1",
             "sphinx",
             "sphinx_rtd_theme",
             "yapf",
4 changes: 3 additions & 1 deletion tests/common.py
@@ -164,12 +164,14 @@ def assert_tensors_equal(a, b, prec=0.0, msg=""):
     assert a.size() == b.size(), msg
     if isinstance(prec, numbers.Number) and prec == 0:
         assert (a == b).all(), msg
+        return
     if a.numel() == 0 and b.numel() == 0:
         return
     b = b.type_as(a)
     b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu()
     if not a.dtype.is_floating_point:
-        return (a == b).all()
+        assert (a == b).all(), msg
+        return
     # check that NaNs are in the same locations
     nan_mask = a != a
     assert torch.equal(nan_mask, b != b), msg
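
The "Fix silently failing test" item in the commit message refers to this helper: the old non-floating-point branch returned the comparison result instead of asserting it, so an integer-tensor mismatch never raised, and the new early `return` also keeps the exact-equality branch from falling through. A stripped-down sketch of the old failure mode (helper names here are hypothetical):

```python
import torch

def old_check(a, b):
    # Old behavior: the boolean result is returned, never asserted.
    return (a == b).all()

def new_check(a, b, msg=""):
    # New behavior: a mismatch raises immediately.
    assert (a == b).all(), msg

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 4])
old_check(a, b)      # returns tensor(False); a caller that ignores it "passes"
# new_check(a, b)    # would raise AssertionError
```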
4 changes: 1 addition & 3 deletions tests/distributions/test_one_one_matching.py
@@ -112,20 +112,19 @@ def test_grad_hard(num_nodes):
 @pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
 @pytest.mark.parametrize("num_nodes", [1, 2, 3, 4, 5, 6, 7, 8])
 def test_mode(num_nodes, dtype):
-    pytest.importorskip("lap")
     logits = torch.randn(num_nodes, num_nodes, dtype=dtype) * 10
     d = dist.OneOneMatching(logits)
     values = d.enumerate_support()
     i = d.log_prob(values).max(0).indices.item()
     expected = values[i]
     actual = d.mode()
-    assert_equal(actual, expected)
+    assert (actual == expected).all()
 
 
 @pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
 @pytest.mark.parametrize("num_nodes", [3, 5, 8, 13, 100, 1000])
 def test_mode_smoke(num_nodes, dtype):
-    pytest.importorskip("lap")
     logits = torch.randn(num_nodes, num_nodes, dtype=dtype) * 10
     d = dist.OneOneMatching(logits)
     value = d.mode()
@@ -136,7 +135,6 @@ def test_mode_smoke(num_nodes, dtype):
 @pytest.mark.parametrize("num_nodes", [2, 3, 4, 5, 6])
 @pytest.mark.parametrize("bp_iters", [None, BP_ITERS], ids=["exact", "bp"])
 def test_sample(num_nodes, dtype, bp_iters):
-    pytest.importorskip("lap")
     logits = torch.randn(num_nodes, num_nodes, dtype=dtype) * 10
     d = dist.OneOneMatching(logits, bp_iters=bp_iters)

6 changes: 0 additions & 6 deletions tests/distributions/test_one_two_matching.py
@@ -175,7 +175,6 @@ def test_grad_phylo(num_leaves):
 @pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
 @pytest.mark.parametrize("num_destins", [1, 2, 3, 4, 5])
 def test_mode_full(num_destins, dtype):
-    pytest.importorskip("lap")
     num_sources = 2 * num_destins
     logits = torch.randn(num_sources, num_destins, dtype=dtype) * 10
     d = dist.OneTwoMatching(logits)
@@ -189,7 +188,6 @@ def test_mode_full(num_destins, dtype):
 @pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
 @pytest.mark.parametrize("num_leaves", [2, 3, 4, 5, 6])
 def test_mode_phylo(num_leaves, dtype):
-    pytest.importorskip("lap")
     logits, times = random_phylo_logits(num_leaves, dtype)
     d = dist.OneTwoMatching(logits)
     values = d.enumerate_support()
@@ -202,7 +200,6 @@ def test_mode_phylo(num_leaves, dtype):
 @pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
 @pytest.mark.parametrize("num_destins", [3, 5, 8, 13, 100, 1000])
 def test_mode_full_smoke(num_destins, dtype):
-    pytest.importorskip("lap")
     num_sources = 2 * num_destins
     logits = torch.randn(num_sources, num_destins, dtype=dtype) * 10
     d = dist.OneTwoMatching(logits)
@@ -213,7 +210,6 @@ def test_mode_full_smoke(num_destins, dtype):
 @pytest.mark.parametrize("dtype", [torch.float, torch.double], ids=str)
 @pytest.mark.parametrize("num_leaves", [3, 5, 8, 13, 100, 1000])
 def test_mode_phylo_smoke(num_leaves, dtype):
-    pytest.importorskip("lap")
     logits, times = random_phylo_logits(num_leaves, dtype)
     d = dist.OneTwoMatching(logits, bp_iters=10)
     value = d.mode()
@@ -224,7 +220,6 @@ def test_mode_phylo_smoke(num_leaves, dtype):
 @pytest.mark.parametrize("num_destins", [2, 3, 4])
 @pytest.mark.parametrize("bp_iters", [None, BP_ITERS], ids=["exact", "bp"])
 def test_sample_full(num_destins, dtype, bp_iters):
-    pytest.importorskip("lap")
     num_sources = 2 * num_destins
     logits = torch.randn(num_sources, num_destins, dtype=dtype) * 10
     d = dist.OneTwoMatching(logits, bp_iters=bp_iters)
@@ -251,7 +246,6 @@ def test_sample_full(num_destins, dtype, bp_iters):
 @pytest.mark.parametrize("num_leaves", [3, 4, 5])
 @pytest.mark.parametrize("bp_iters", [None, BP_ITERS], ids=["exact", "bp"])
 def test_sample_phylo(num_leaves, dtype, bp_iters):
-    pytest.importorskip("lap")
     logits, times = random_phylo_logits(num_leaves, dtype)
     num_sources, num_destins = logits.shape
     d = dist.OneTwoMatching(logits, bp_iters=bp_iters)
