diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index f6540591bc..dffaa018aa 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -12,6 +12,8 @@ - The JSON metadata codec now interprets the empty string as an empty object. This means that applying a schema to an existing table will no longer necessitate modifying the existing rows. (:user:`benjeffery`, :issue:`2064`, :pr:`2104`) +- ``tree.mrca`` now takes 2 or more arguments. + (:user:`savitakartik`, :issue:`1340`, :pr:`2121`) ---------------------- diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 4832e46031..9f29f19b30 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -851,6 +851,40 @@ def test_minlex_postorder_multiple_roots(self): assert tree_orders == expected_result +class TestMRCA: + t = tskit.Tree.generate_balanced(3) + # 4 + # ┏━┻┓ + # ┃ 3 + # ┃ ┏┻┓ + # 0 1 2 + + def test_two_or_more_args(self): + assert self.t.mrca(2, 1) == 3 + assert self.t.mrca(0, 1, 2) == 4 + + def test_less_than_two_args(self): + with pytest.raises(ValueError): + self.t.mrca(1) + + def test_no_args(self): + with pytest.raises(ValueError): + self.t.mrca() + + def test_same_args(self): + assert self.t.mrca(0, 0, 0, 0) == 0 + + def test_different_tree_levels(self): + assert self.t.mrca(0, 3) == 4 + + def test_out_of_bounds_args(self): + with pytest.raises(ValueError): + self.t.mrca(0, 6) + + def test_virtual_root_arg(self): + assert self.t.mrca(0, 5) == 5 + + class TestMRCACalculator: """ Class to test the Schieber-Vishkin algorithm. diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 51c5f38d97..336efce07b 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1017,16 +1017,17 @@ def get_mrca(self, u, v): # Deprecated alias for mrca return self.mrca(u, v) - def mrca(self, u, v): + def mrca(self, *args): """ Returns the most recent common ancestor of the specified nodes. - :param int u: The first node. - :param int v: The second node. - :return: The most recent common ancestor of u and v. + :param int `*args`: input node IDs, must be at least 2. + :return: The most recent common ancestor of input nodes. :rtype: int """ - return self._ll_tree.get_mrca(u, v) + if len(args) < 2: + raise ValueError("Must supply at least two arguments") + return functools.reduce(self._ll_tree.get_mrca, args) def get_tmrca(self, u, v): # Deprecated alias for tmrca