diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index f1b1ce75ee..809c96c002 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -19,11 +19,18 @@ In development children on an interval. Previously an error was thrown when some operation building the trees was attempted. (:user:`jeromekelleher`, :pr:`709`). +- The TableCollection object no longer implements the iterator protocol. + Previously ``list(tables)`` returned a sequence of (table_name, table_instance) + tuples. This has been replaced with the more intuitive and future-proof + TableCollection.name_map and TreeSequence.tables_dict attributes, which + perform the same function. (:user:`jeromekelleher`, :issue:`500`, + :pr:`694`) + **New features** - New methods to perform set operations on TableCollections and TreeSequences. ``TableCollection.subset`` subsets and reorders table collections by nodes - (:user:`mufernando`, :user:`petrelharp`, :pr:`663`, :pr:`690`). + (:user:`mufernando`, :user:`petrelharp`, :pr:`663`, :pr:`690`). ``TableCollection.union`` forms the node-wise union of two table collections (:user:`mufernando`, :user:`petrelharp`, :issue:`381` :pr:`623`). diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 84ce5421d6..92f45154ad 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1328,7 +1328,7 @@ def test_at(self): def test_sequence_iteration(self): for ts in get_example_tree_sequences(): - for table_name, _ in ts.tables: + for table_name in ts.tables_dict.keys(): sequence = getattr(ts, table_name)() length = getattr(ts, "num_" + table_name) # Test __iter__ diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 9ae67e386a..992b63b3ac 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -25,7 +25,6 @@ between simulations and the tree sequence. """ import io -import itertools import json import math import pickle @@ -1590,15 +1589,11 @@ class TestTablesToTreeSequence(unittest.TestCase): Tests for the .tree_sequence() method of a TableCollection. """ - random_seed = 42 - def test_round_trip(self): - a = msprime.simulate(5, mutation_rate=1, random_seed=self.random_seed) + a = msprime.simulate(5, mutation_rate=1, random_seed=42) tables = a.dump_tables() b = tables.tree_sequence() - self.assertTrue( - all(a == b for a, b in zip(a.tables, b.tables) if a[0] != "provenances") - ) + self.assertEqual(a.tables, b.tables) class TestMutationTimeErrors(unittest.TestCase): @@ -2040,19 +2035,25 @@ def test_roundtrip_dict(self): t2 = tskit.TableCollection.fromdict(t1.asdict()) self.assertEqual(t1, t2) - def test_iter(self): - def test_iter(table_collection): - table_names = [ - attr_name - for attr_name in sorted(dir(table_collection)) - if isinstance(getattr(table_collection, attr_name), tskit.BaseTable) - ] - for n in table_names: - yield n, getattr(table_collection, n) - + def test_name_map(self): ts = msprime.simulate(10, mutation_rate=1, random_seed=1) - for t1, t2 in itertools.zip_longest(test_iter(ts.tables), ts.tables): - self.assertEquals(t1, t2) + tables = ts.tables + td1 = { + "individuals": tables.individuals, + "populations": tables.populations, + "nodes": tables.nodes, + "edges": tables.edges, + "sites": tables.sites, + "mutations": tables.mutations, + "migrations": tables.migrations, + "provenances": tables.provenances, + } + td2 = tables.name_map + self.assertIsInstance(td2, dict) + self.assertEqual(set(td1.keys()), set(td2.keys())) + for name in td2.keys(): + self.assertEqual(td1[name], td2[name]) + self.assertEqual(td1, td2) def test_equals_empty(self): self.assertEqual(tskit.TableCollection(), tskit.TableCollection()) diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 6424fc6a94..39be7b800a 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -53,8 +53,12 @@ def tables_equal(table_collection_1, table_collection_2, compare_provenances=Tru """ Check equality of tables, ignoring provenance timestamps (but not contents) """ - for (_, table_1), (_, table_2) in zip(table_collection_1, table_collection_2): - if isinstance(table_1, tskit.ProvenanceTable): + tc_dict_1 = table_collection_1.name_map + tc_dict_2 = table_collection_2.name_map + for table_name in tc_dict_1.keys(): + table_1 = tc_dict_1[table_name] + table_2 = tc_dict_2[table_name] + if table_name == "provenances": if compare_provenances: if np.any(table_1.record != table_2.record): return False diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 00b6b22d5a..54b317d2d5 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -1966,7 +1966,7 @@ def asdict(self): """ Returns a dictionary representation of this TableCollection. - Note: the semantics of this method changed at tskit 1.0.0. Previously a + Note: the semantics of this method changed at tskit 0.1.0. Previously a map of table names to the tables themselves was returned. """ return { @@ -1984,6 +1984,24 @@ def asdict(self): "provenances": self.provenances.asdict(), } + @property + def name_map(self): + """ + Returns a dictionary mapping table names to the corresponding + table instances. For example, the returned dictionary will contain the + key "edges" that maps to an :class:`.EdgeTable` instance. + """ + return { + "edges": self.edges, + "individuals": self.individuals, + "migrations": self.migrations, + "mutations": self.mutations, + "nodes": self.nodes, + "populations": self.populations, + "provenances": self.provenances, + "sites": self.sites, + } + def __banner(self, title): width = 60 line = "#" * width @@ -1992,20 +2010,6 @@ def __banner(self, title): title_line += "#" return line + "\n" + title_line + "\n" + line + "\n" - def __iter__(self): - """ - Iterate over all the tables in this TableCollection, ordered by table name - (i.e. deterministically), returning a tuple of (table_name, table_object) - """ - yield "edges", self.edges - yield "individuals", self.individuals - yield "migrations", self.migrations - yield "mutations", self.mutations - yield "nodes", self.nodes - yield "populations", self.populations - yield "provenances", self.provenances - yield "sites", self.sites - def __str__(self): s = self.__banner("Individuals") s += str(self.individuals) + "\n" diff --git a/python/tskit/trees.py b/python/tskit/trees.py index e0a0617f87..c9efeb9696 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -2940,6 +2940,15 @@ def dump(self, path, zlib_compression=False): # Convert the path to str to allow us use Pathlib inputs self._ll_tree_sequence.dump(str(path)) + @property + def tables_dict(self): + """ + Returns a dictionary mapping names to tables in the + underlying :class:`.TableCollection`. Equivalent to calling + ``ts.tables.name_map``. + """ + return self.tables.name_map + @property def tables(self): """