From 43785e76b21f7742a83a74fd96df1b99efca7728 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 29 Sep 2020 12:05:49 +0100 Subject: [PATCH] Add nbytes to Table and TableCollection. Closes #54 --- python/CHANGELOG.rst | 4 + python/tests/test_highlevel.py | 167 +++++++++++++++++---------------- python/tests/test_tables.py | 93 +++++++++--------- python/tskit/tables.py | 43 +++++++++ python/tskit/trees.py | 13 ++- python/tskit/util.py | 4 +- 6 files changed, 193 insertions(+), 131 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index abb512cd69..857b35e1e6 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -38,6 +38,10 @@ collections that are not valid tree sequences to be manipulated. (:user:`benjeffery`, :issue:`14`, :pr:`986`) +- Added ``nbytes`` method to tables, ``TableCollection`` and ``TreeSequence`` which + reports the size in bytes of those objects. + (:user:`jeromekelleher`, :user:`benjeffery`, :issue:`54`, :pr:`871`) + **Breaking changes** - The argument to ``ts.dump`` and ``tskit.load`` has been renamed `file` from `path`. diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index b0caed0e5d..52782e848d 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -35,13 +35,13 @@ import platform import random import re -import shutil import tempfile import unittest import uuid as _uuid import warnings import attr +import kastore import msprime import networkx as nx import numpy as np @@ -360,18 +360,11 @@ def test_all_oriented_forests(self): assert mrca == sv.get_mrca(j, k) -class HighLevelTestCase(unittest.TestCase): +class HighLevelTestCase: """ Superclass of tests on the high level interface. """ - def setUp(self): - self.temp_dir = tempfile.mkdtemp(prefix="tsk_hl_testcase_") - self.temp_file = os.path.join(self.temp_dir, "generic") - - def tearDown(self): - shutil.rmtree(self.temp_dir) - def verify_tree_mrcas(self, st): # Check the mrcas oriented_forest = [st.get_parent(j) for j in range(st.num_nodes)] @@ -486,7 +479,7 @@ def verify_trees(self, ts): assert not (st2 != st1) left, right = st1.get_interval() breakpoints.append(right) - self.assertAlmostEqual(left, length) + assert left == pytest.approx(length) assert left >= 0 assert right > left assert right <= ts.get_sequence_length() @@ -499,7 +492,7 @@ def verify_trees(self, ts): next(iter2) assert ts.get_num_trees() == num_trees assert breakpoints == list(ts.breakpoints()) - self.assertAlmostEqual(length, ts.get_sequence_length()) + assert length == ts.get_sequence_length() class TestNumpySamples: @@ -606,7 +599,7 @@ def verify_pairwise_diversity(self, ts): haplotypes = ts.genotype_matrix(isolated_as_missing=False).T pi1 = ts.get_pairwise_diversity() pi2 = simple_get_pairwise_diversity(haplotypes) - self.assertAlmostEqual(pi1, pi2) + assert pi1 == pytest.approx(pi2) assert pi1 >= 0.0 assert not math.isnan(pi1) # Check for a subsample. @@ -614,7 +607,7 @@ def verify_pairwise_diversity(self, ts): samples = list(ts.samples())[:num_samples] pi1 = ts.get_pairwise_diversity(samples) pi2 = simple_get_pairwise_diversity([haplotypes[j] for j in range(num_samples)]) - self.assertAlmostEqual(pi1, pi2) + assert pi1 == pytest.approx(pi2) assert pi1 >= 0.0 assert not math.isnan(pi1) @@ -1262,14 +1255,13 @@ def test_removed_methods(self): with pytest.raises(NotImplementedError): ts.newick_trees() - def test_dump_pathlib(self): - ts = msprime.simulate(2, random_seed=42) - path = pathlib.Path(self.temp_dir) / "tmp.trees" + def test_dump_pathlib(self, ts_fixture, tmp_path): + path = tmp_path / "tmp.trees" assert path.exists assert path.is_file - ts.dump(path) + ts_fixture.dump(path) other_ts = tskit.load(path) - assert ts.tables == other_ts.tables + assert ts_fixture.tables == other_ts.tables @pytest.mark.skipif(platform.system() == "Windows", reason="Windows doesn't raise") def test_dump_load_errors(self): @@ -1472,6 +1464,20 @@ def test_repr(self): for table in ts.tables.name_map: assert re.search(rf"║{table.capitalize()} *│", s) + def test_nbytes(self, tmp_path, ts_fixture): + ts_fixture.dump(tmp_path / "tables") + store = kastore.load(tmp_path / "tables") + for v in store.values(): + # Check we really have data in every field + assert v.nbytes > 0 + nbytes = sum( + array.nbytes + for name, array in store.items() + # nbytes is the size of asdict, so exclude file format items + if name not in ["format/version", "format/name", "uuid"] + ) + assert nbytes == ts_fixture.nbytes + class TestTreeSequenceMethodSignatures: ts = msprime.simulate(10, random_seed=1234) @@ -1683,42 +1689,44 @@ class TestFileUuid(HighLevelTestCase): """ def validate(self, ts): - assert ts.file_uuid is None - ts.dump(self.temp_file) - other_ts = tskit.load(self.temp_file) - assert other_ts.file_uuid is not None - assert len(other_ts.file_uuid), 36 - uuid = other_ts.file_uuid - other_ts = tskit.load(self.temp_file) - assert other_ts.file_uuid == uuid - assert ts.tables == other_ts.tables - - # Check that the UUID is well-formed. - parsed = _uuid.UUID("{" + uuid + "}") - assert str(parsed) == uuid - - # Save the same tree sequence to the file. We should get a different UUID. - ts.dump(self.temp_file) - other_ts = tskit.load(self.temp_file) - assert other_ts.file_uuid is not None - assert other_ts.file_uuid != uuid - - # Even saving a ts that has a UUID to another file changes the UUID - old_uuid = other_ts.file_uuid - other_ts.dump(self.temp_file) - assert other_ts.file_uuid == old_uuid - other_ts = tskit.load(self.temp_file) - assert other_ts.file_uuid is not None - assert other_ts.file_uuid != old_uuid - - # Tables dumped from this ts are a deep copy, so they don't have - # the file_uuid. - tables = other_ts.dump_tables() - assert tables.file_uuid is None - - # For now, ts.tables also returns a deep copy. This will hopefully - # change in the future though. - assert ts.tables.file_uuid is None + with tempfile.TemporaryDirectory() as tempdir: + temp_file = pathlib.Path(tempdir) / "tmp.trees" + assert ts.file_uuid is None + ts.dump(temp_file) + other_ts = tskit.load(temp_file) + assert other_ts.file_uuid is not None + assert len(other_ts.file_uuid), 36 + uuid = other_ts.file_uuid + other_ts = tskit.load(temp_file) + assert other_ts.file_uuid == uuid + assert ts.tables == other_ts.tables + + # Check that the UUID is well-formed. + parsed = _uuid.UUID("{" + uuid + "}") + assert str(parsed) == uuid + + # Save the same tree sequence to the file. We should get a different UUID. + ts.dump(temp_file) + other_ts = tskit.load(temp_file) + assert other_ts.file_uuid is not None + assert other_ts.file_uuid != uuid + + # Even saving a ts that has a UUID to another file changes the UUID + old_uuid = other_ts.file_uuid + other_ts.dump(temp_file) + assert other_ts.file_uuid == old_uuid + other_ts = tskit.load(temp_file) + assert other_ts.file_uuid is not None + assert other_ts.file_uuid != old_uuid + + # Tables dumped from this ts are a deep copy, so they don't have + # the file_uuid. + tables = other_ts.dump_tables() + assert tables.file_uuid is None + + # For now, ts.tables also returns a deep copy. This will hopefully + # change in the future though. + assert ts.tables.file_uuid is None def test_simple_simulation(self): ts = msprime.simulate(2, random_seed=1) @@ -1859,7 +1867,7 @@ def verify_approximate_equality(self, ts1, ts2): equal, taking into account the error incurred in exporting to text. """ assert ts1.sample_size == ts2.sample_size - self.assertAlmostEqual(ts1.sequence_length, ts2.sequence_length) + assert ts1.sequence_length == ts2.sequence_length assert ts1.num_nodes == ts2.num_nodes assert ts1.num_edges == ts2.num_edges assert ts1.num_sites == ts2.num_sites @@ -1869,15 +1877,15 @@ def verify_approximate_equality(self, ts1, ts2): for n1, n2 in zip(ts1.nodes(), ts2.nodes()): assert n1.population == n2.population assert n1.metadata == n2.metadata - self.assertAlmostEqual(n1.time, n2.time) + assert n1.time == pytest.approx(n2.time) checked += 1 assert checked == ts1.num_nodes checked = 0 for r1, r2 in zip(ts1.edges(), ts2.edges()): checked += 1 - self.assertAlmostEqual(r1.left, r2.left) - self.assertAlmostEqual(r1.right, r2.right) + assert r1.left == pytest.approx(r2.left) + assert r1.right == pytest.approx(r2.right) assert r1.parent == r2.parent assert r1.child == r2.child assert ts1.num_edges == checked @@ -1885,8 +1893,8 @@ def verify_approximate_equality(self, ts1, ts2): checked = 0 for s1, s2 in zip(ts1.sites(), ts2.sites()): checked += 1 - self.assertAlmostEqual(s1.position, s2.position) - self.assertAlmostEqual(s1.ancestral_state, s2.ancestral_state) + assert s1.position == pytest.approx(s2.position) + assert s1.ancestral_state == s2.ancestral_state assert s1.metadata == s2.metadata assert s1.mutations == s2.mutations assert ts1.num_sites == checked @@ -1897,7 +1905,7 @@ def verify_approximate_equality(self, ts1, ts2): assert s1.site == s2.site assert s1.node == s2.node if not (math.isnan(s1.time) and math.isnan(s2.time)): - self.assertAlmostEqual(s1.time, s2.time) + assert s1.time == pytest.approx(s2.time) assert s1.derived_state == s2.derived_state assert s1.parent == s2.parent assert s1.metadata == s2.metadata @@ -2237,11 +2245,10 @@ def preorder_dist(tree, root): for u, v in itertools.combinations(nx.descendants(g, root), 2): mrca = tree.mrca(u, v) tmrca = tree.time(mrca) - self.assertAlmostEqual( - tree.time(root) - tmrca, + assert tree.time(root) - tmrca == pytest.approx( nx.shortest_path_length( g, source=root, target=mrca, weight="branch_length" - ), + ) ) def verify_nx_nearest_neighbor_search(self): @@ -2352,7 +2359,7 @@ def test_total_branch_length(self): if node != root: bl += t1.get_branch_length(node) assert bl > 0 - self.assertAlmostEqual(t1.get_total_branch_length(), bl) + assert t1.get_total_branch_length() == pytest.approx(bl) def test_branch_length_empty_tree(self): tables = tskit.TableCollection(1) @@ -2524,12 +2531,12 @@ def test_interval(self): assert breakpoints[0] == 0 assert breakpoints[-1] == ts.sequence_length for i, tree in enumerate(ts.trees()): - self.assertAlmostEqual(tree.interval[0], breakpoints[i]) - self.assertAlmostEqual(tree.interval.left, breakpoints[i]) - self.assertAlmostEqual(tree.interval[1], breakpoints[i + 1]) - self.assertAlmostEqual(tree.interval.right, breakpoints[i + 1]) - self.assertAlmostEqual( - tree.interval.span, breakpoints[i + 1] - breakpoints[i] + assert tree.interval[0] == pytest.approx(breakpoints[i]) + assert tree.interval.left == pytest.approx(breakpoints[i]) + assert tree.interval[1] == pytest.approx(breakpoints[i + 1]) + assert tree.interval.right == pytest.approx(breakpoints[i + 1]) + assert tree.interval.span == pytest.approx( + breakpoints[i + 1] - breakpoints[i] ) def verify_empty_tree(self, tree): @@ -2703,7 +2710,7 @@ class TestNodeOrdering(HighLevelTestCase): num_random_permutations = 10 - def verify_tree_sequences_equal(self, ts1, ts2, approx=False): + def verify_tree_sequences_equal(self, ts1, ts2, approximate=False): assert ts1.get_num_trees() == ts2.get_num_trees() assert ts1.get_sample_size() == ts2.get_sample_size() assert ts1.get_num_nodes() == ts2.get_num_nodes() @@ -2711,9 +2718,9 @@ def verify_tree_sequences_equal(self, ts1, ts2, approx=False): for r1, r2 in zip(ts1.edges(), ts2.edges()): assert r1.parent == r2.parent assert r1.child == r2.child - if approx: - self.assertAlmostEqual(r1.left, r2.left) - self.assertAlmostEqual(r1.right, r2.right) + if approximate: + assert r1.left == pytest.approx(r2.left) + assert r1.right == pytest.approx(r2.right) else: assert r1.left == r2.left assert r1.right == r2.right @@ -2723,8 +2730,8 @@ def verify_tree_sequences_equal(self, ts1, ts2, approx=False): for n1, n2 in zip(ts1.nodes(), ts2.nodes()): assert n1.metadata == n2.metadata assert n1.population == n2.population - if approx: - self.assertAlmostEqual(n1.time, n2.time) + if approximate: + assert n1.time == pytest.approx(n2.time) else: assert n1.time == n2.time j += 1 @@ -2780,8 +2787,10 @@ def verify_random_permutation(self, ts): j += 1 assert j == ts.get_num_trees() # Verify we can dump this new tree sequence OK. - other_ts.dump(self.temp_file) - ts3 = tskit.load(self.temp_file) + with tempfile.TemporaryDirectory() as tempdir: + temp_file = pathlib.Path(tempdir) / "tmp.trees" + other_ts.dump(temp_file) + ts3 = tskit.load(temp_file) self.verify_tree_sequences_equal(other_ts, ts3) nodes_file = io.StringIO() edges_file = io.StringIO() diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 93704bb8b9..da6dba68f2 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -36,6 +36,7 @@ import warnings import attr +import kastore import msprime import numpy as np import pytest @@ -589,6 +590,15 @@ def test_equality(self): assert t1 is not None assert t1 != [] + def test_nbytes(self): + for num_rows in [0, 10, 100]: + input_data = self.make_input_data(num_rows) + table = self.table_class() + table.set_columns(**input_data) + # We don't have any metadata_schema here, so we can sum over the + # columns directly. + assert sum(col.nbytes for col in input_data.values()) == table.nbytes + def test_bad_offsets(self): for num_rows in [10, 100]: input_data = self.make_input_data(num_rows) @@ -2259,10 +2269,23 @@ def test_str(self): s = str(tables) assert len(s) > 0 - def test_asdict(self): - ts = msprime.simulate(10, mutation_rate=1, random_seed=1) - t = ts.tables - self.add_metadata(t) + def test_nbytes(self, tmp_path, ts_fixture): + tables = ts_fixture.dump_tables() + tables.dump(tmp_path / "tables") + store = kastore.load(tmp_path / "tables") + for v in store.values(): + # Check we really have data in every field + assert v.nbytes > 0 + nbytes = sum( + array.nbytes + for name, array in store.items() + # nbytes is the size of asdict, so exclude file format items + if name not in ["format/version", "format/name", "uuid"] + ) + assert nbytes == tables.nbytes + + def test_asdict(self, ts_fixture): + t = ts_fixture.dump_tables() d1 = { "encoding_version": (1, 2), "sequence_length": t.sequence_length, @@ -2286,10 +2309,8 @@ def test_asdict(self): assert t1.has_index() assert t2.has_index() - def test_from_dict(self): - ts = msprime.simulate(10, mutation_rate=1, random_seed=1) - t1 = ts.tables - self.add_metadata(t1) + def test_from_dict(self, ts_fixture): + t1 = ts_fixture.tables d = { "encoding_version": (1, 1), "sequence_length": t1.sequence_length, @@ -2308,19 +2329,13 @@ def test_from_dict(self): t2 = tskit.TableCollection.fromdict(d) assert t1 == t2 - def test_roundtrip_dict(self): - ts = msprime.simulate(10, mutation_rate=1, random_seed=1) - t1 = ts.tables - t2 = tskit.TableCollection.fromdict(t1.asdict()) - assert t1 == t2 - - self.add_metadata(t1) + def test_roundtrip_dict(self, ts_fixture): + t1 = ts_fixture.tables t2 = tskit.TableCollection.fromdict(t1.asdict()) assert t1 == t2 - def test_name_map(self): - ts = msprime.simulate(10, mutation_rate=1, random_seed=1) - tables = ts.tables + def test_name_map(self, ts_fixture): + tables = ts_fixture.tables td1 = { "individuals": tables.individuals, "populations": tables.populations, @@ -2346,16 +2361,8 @@ def test_equals_sequence_length(self): sequence_length=2 ) - def test_copy(self): - pop_configs = [msprime.PopulationConfiguration(5) for _ in range(2)] - migration_matrix = [[0, 1], [1, 0]] - t1 = msprime.simulate( - population_configurations=pop_configs, - migration_matrix=migration_matrix, - mutation_rate=1, - record_migrations=True, - random_seed=100, - ).dump_tables() + def test_copy(self, ts_fixture): + t1 = ts_fixture.dump_tables() t2 = t1.copy() assert t1 is not t2 assert t1 == t2 @@ -2421,16 +2428,8 @@ def test_equals(self): t2.populations.clear() assert t1 == t2 - def test_equals_options(self): - pop_configs = [msprime.PopulationConfiguration(5) for _ in range(2)] - migration_matrix = [[0, 1], [1, 0]] - t1 = msprime.simulate( - population_configurations=pop_configs, - migration_matrix=migration_matrix, - mutation_rate=1, - record_migrations=True, - random_seed=1, - ).dump_tables() + def test_equals_options(self, ts_fixture): + t1 = ts_fixture.dump_tables() t2 = t1.copy() t1.provenances.add_row("random stuff") @@ -2472,9 +2471,8 @@ def test_sequence_length(self): tables = tskit.TableCollection(sequence_length=sequence_length) assert tables.sequence_length == sequence_length - def test_uuid_simulation(self): - ts = msprime.simulate(10, random_seed=1) - tables = ts.tables + def test_uuid_simulation(self, ts_fixture): + tables = ts_fixture.tables assert tables.file_uuid is None, None def test_uuid_empty(self): @@ -2511,9 +2509,8 @@ def test_index_unsorted(self): ts = tables.tree_sequence() assert ts.tables == tables - def test_index_from_ts(self): - ts = msprime.simulate(10, random_seed=1) - tables = ts.dump_tables() + def test_index_from_ts(self, ts_fixture): + tables = ts_fixture.dump_tables() assert tables.has_index() tables.drop_index() assert not tables.has_index() @@ -2535,8 +2532,8 @@ def test_set_sequence_length(self): tables.sequence_length = value assert tables.sequence_length == value - def test_bad_sequence_length(self): - tables = msprime.simulate(10, random_seed=1).dump_tables() + def test_bad_sequence_length(self, ts_fixture): + tables = ts_fixture.dump_tables() assert tables.sequence_length == 1 for value in [-1, 0, -0.99, 0.9999]: tables.sequence_length = value @@ -2552,8 +2549,8 @@ def test_bad_sequence_length(self): tables.simplify() assert tables.sequence_length == value - def test_sequence_length_longer_than_edges(self): - tables = msprime.simulate(10, random_seed=1).dump_tables() + def test_sequence_length_longer_than_edges(self, ts_fixture): + tables = ts_fixture.dump_tables() tables.sequence_length = 2 ts = tables.tree_sequence() assert ts.sequence_length == 2 diff --git a/python/tskit/tables.py b/python/tskit/tables.py index adf80b628b..f1a3a09796 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -148,6 +148,10 @@ class TableCollectionIndexes: def asdict(self): return attr.asdict(self, filter=lambda k, v: v is not None) + @property + def nbytes(self): + return self.edge_insertion_order.nbytes + self.edge_removal_order.nbytes + def keep_with_offset(keep, data, offset): """ @@ -196,6 +200,28 @@ def max_rows(self): def max_rows_increment(self): return self.ll_table.max_rows_increment + @property + def nbytes(self) -> int: + """ + Returns the total number of bytes required to store the data + in this table. Note that this may not be equal to + the actual memory footprint. + """ + # It's not ideal that we run asdict() here to do this as we're + # currently creating copies of the column arrays, so it would + # be more efficient to have dedicated low-level methods. However, + # if we do have read-only views on the underlying memory for the + # column arrays then this will be a perfectly good way of + # computing the nbytes values and the overhead minimal. + d = self.asdict() + nbytes = 0 + # Some tables don't have a metadata_schema + metadata_schema = d.pop("metadata_schema", None) + if metadata_schema is not None: + nbytes += len(metadata_schema.encode()) + nbytes += sum(col.nbytes for col in d.values()) + return nbytes + def equals(self, other, ignore_metadata=False): """ Returns True if `self` and `other` are equal. By default, two tables @@ -2174,6 +2200,23 @@ def name_map(self): "sites": self.sites, } + @property + def nbytes(self) -> int: + """ + Returns the total number of bytes required to store the data + in this table collection. Note that this may not be equal to + the actual memory footprint. + """ + return sum( + ( + 8, # sequence_length takes 8 bytes + len(self.metadata_bytes), + len(str(self.metadata_schema).encode()), + self.indexes.nbytes, + sum(table.nbytes for table in self.name_map.values()), + ) + ) + def __banner(self, title): width = 60 line = "#" * width diff --git a/python/tskit/trees.py b/python/tskit/trees.py index bfe802eac6..fa4b57683d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -3098,6 +3098,15 @@ def tables(self): """ return self.dump_tables() + @property + def nbytes(self): + """ + Returns the total number of bytes required to store the data + in this tree sequence. Note that this may not be equal to + the actual memory footprint. + """ + return self.tables.nbytes + def dump_tables(self): """ A copy of the tables defining this tree sequence. @@ -3290,7 +3299,7 @@ def __repr__(self): ["Trees", str(self.num_trees)], ["Sequence Length", str(self.sequence_length)], ["Sample Nodes", str(self.num_samples)], - ["Total Size TODO", util.naturalsize(99999)], + ["Total Size", util.naturalsize(self.nbytes)], ] header = ["Table", "Rows", "Size", "Has Metadata"] table_rows = [] @@ -3301,7 +3310,7 @@ def __repr__(self): for s in [ name.capitalize(), table.num_rows, - "TODO", + util.naturalsize(table.nbytes), "Yes" if hasattr(table, "metadata") and len(table.metadata) > 0 else "No", diff --git a/python/tskit/util.py b/python/tskit/util.py index b1296fc97f..f16c0bb4ee 100644 --- a/python/tskit/util.py +++ b/python/tskit/util.py @@ -377,7 +377,7 @@ def tree_sequence_html(ts): {name.capitalize()} {table.num_rows} - TODO! {naturalsize(99999)} + {naturalsize(table.nbytes)} {'✅' if hasattr(table, "metadata") and len(table.metadata) > 0 else ''} @@ -412,7 +412,7 @@ def tree_sequence_html(ts): Trees{ts.num_trees} Sequence Length{ts.sequence_length} Sample Nodes{ts.num_samples} - Total SizeTODO! {naturalsize(99999)} + Total Size{naturalsize(ts.nbytes)} Metadata{obj_to_collapsed_html(ts.metadata, None, 1) if len(ts.tables.metadata_bytes) > 0 else "No Metadata"}