Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,18 @@ In development
children on an interval. Previously an error was thrown when some operation
building the trees was attempted. (:user:`jeromekelleher`, :pr:`709`).

- The TableCollection object no longer implements the iterator protocol.
Previously ``list(tables)`` returned a sequence of (table_name, table_instance)
tuples. This has been replaced with the more intuitive and future-proof
TableCollection.name_map and TreeSequence.tables_dict attributes, which
perform the same function. (:user:`jeromekelleher`, :issue:`500`,
:pr:`694`)

**New features**

- New methods to perform set operations on TableCollections and TreeSequences.
``TableCollection.subset`` subsets and reorders table collections by nodes
(:user:`mufernando`, :user:`petrelharp`, :pr:`663`, :pr:`690`).
(:user:`mufernando`, :user:`petrelharp`, :pr:`663`, :pr:`690`).
``TableCollection.union`` forms the node-wise union of two table collections
(:user:`mufernando`, :user:`petrelharp`, :issue:`381` :pr:`623`).

Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,7 +1328,7 @@ def test_at(self):

def test_sequence_iteration(self):
for ts in get_example_tree_sequences():
for table_name, _ in ts.tables:
for table_name in ts.tables_dict.keys():
sequence = getattr(ts, table_name)()
length = getattr(ts, "num_" + table_name)
# Test __iter__
Expand Down
39 changes: 20 additions & 19 deletions python/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
between simulations and the tree sequence.
"""
import io
import itertools
import json
import math
import pickle
Expand Down Expand Up @@ -1590,15 +1589,11 @@ class TestTablesToTreeSequence(unittest.TestCase):
Tests for the .tree_sequence() method of a TableCollection.
"""

random_seed = 42

def test_round_trip(self):
a = msprime.simulate(5, mutation_rate=1, random_seed=self.random_seed)
a = msprime.simulate(5, mutation_rate=1, random_seed=42)
tables = a.dump_tables()
b = tables.tree_sequence()
self.assertTrue(
all(a == b for a, b in zip(a.tables, b.tables) if a[0] != "provenances")
)
self.assertEqual(a.tables, b.tables)


class TestMutationTimeErrors(unittest.TestCase):
Expand Down Expand Up @@ -2040,19 +2035,25 @@ def test_roundtrip_dict(self):
t2 = tskit.TableCollection.fromdict(t1.asdict())
self.assertEqual(t1, t2)

def test_iter(self):
def test_iter(table_collection):
table_names = [
attr_name
for attr_name in sorted(dir(table_collection))
if isinstance(getattr(table_collection, attr_name), tskit.BaseTable)
]
for n in table_names:
yield n, getattr(table_collection, n)

def test_name_map(self):
ts = msprime.simulate(10, mutation_rate=1, random_seed=1)
for t1, t2 in itertools.zip_longest(test_iter(ts.tables), ts.tables):
self.assertEquals(t1, t2)
tables = ts.tables
td1 = {
"individuals": tables.individuals,
"populations": tables.populations,
"nodes": tables.nodes,
"edges": tables.edges,
"sites": tables.sites,
"mutations": tables.mutations,
"migrations": tables.migrations,
"provenances": tables.provenances,
}
td2 = tables.name_map
self.assertIsInstance(td2, dict)
self.assertEqual(set(td1.keys()), set(td2.keys()))
for name in td2.keys():
self.assertEqual(td1[name], td2[name])
self.assertEqual(td1, td2)

def test_equals_empty(self):
self.assertEqual(tskit.TableCollection(), tskit.TableCollection())
Expand Down
8 changes: 6 additions & 2 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,12 @@ def tables_equal(table_collection_1, table_collection_2, compare_provenances=Tru
"""
Check equality of tables, ignoring provenance timestamps (but not contents)
"""
for (_, table_1), (_, table_2) in zip(table_collection_1, table_collection_2):
if isinstance(table_1, tskit.ProvenanceTable):
tc_dict_1 = table_collection_1.name_map
tc_dict_2 = table_collection_2.name_map
for table_name in tc_dict_1.keys():
table_1 = tc_dict_1[table_name]
table_2 = tc_dict_2[table_name]
if table_name == "provenances":
if compare_provenances:
if np.any(table_1.record != table_2.record):
return False
Expand Down
34 changes: 19 additions & 15 deletions python/tskit/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1966,7 +1966,7 @@ def asdict(self):
"""
Returns a dictionary representation of this TableCollection.

Note: the semantics of this method changed at tskit 1.0.0. Previously a
Note: the semantics of this method changed at tskit 0.1.0. Previously a
map of table names to the tables themselves was returned.
"""
return {
Expand All @@ -1984,6 +1984,24 @@ def asdict(self):
"provenances": self.provenances.asdict(),
}

@property
def name_map(self):
"""
Returns a dictionary mapping table names to the corresponding
table instances. For example, the returned dictionary will contain the
key "edges" that maps to an :class:`.EdgeTable` instance.
"""
return {
"edges": self.edges,
"individuals": self.individuals,
"migrations": self.migrations,
"mutations": self.mutations,
"nodes": self.nodes,
"populations": self.populations,
"provenances": self.provenances,
"sites": self.sites,
}

def __banner(self, title):
width = 60
line = "#" * width
Expand All @@ -1992,20 +2010,6 @@ def __banner(self, title):
title_line += "#"
return line + "\n" + title_line + "\n" + line + "\n"

def __iter__(self):
"""
Iterate over all the tables in this TableCollection, ordered by table name
(i.e. deterministically), returning a tuple of (table_name, table_object)
"""
yield "edges", self.edges
yield "individuals", self.individuals
yield "migrations", self.migrations
yield "mutations", self.mutations
yield "nodes", self.nodes
yield "populations", self.populations
yield "provenances", self.provenances
yield "sites", self.sites

def __str__(self):
s = self.__banner("Individuals")
s += str(self.individuals) + "\n"
Expand Down
9 changes: 9 additions & 0 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -2940,6 +2940,15 @@ def dump(self, path, zlib_compression=False):
# Convert the path to str to allow us use Pathlib inputs
self._ll_tree_sequence.dump(str(path))

@property
def tables_dict(self):
"""
Returns a dictionary mapping names to tables in the
underlying :class:`.TableCollection`. Equivalent to calling
``ts.tables.name_map``.
"""
return self.tables.name_map

@property
def tables(self):
"""
Expand Down