Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
collections that are not valid tree sequences to be manipulated.
(:user:`benjeffery`, :issue:`14`, :pr:`986`)

- Added an ``nbytes`` attribute to tables, ``TableCollection`` and ``TreeSequence``
that reports the size in bytes of those objects.
(:user:`jeromekelleher`, :user:`benjeffery`, :issue:`54`, :pr:`871`)

**Breaking changes**

- The argument to ``ts.dump`` and ``tskit.load`` has been renamed from ``path`` to ``file``.
Expand Down
167 changes: 88 additions & 79 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
import platform
import random
import re
import shutil
import tempfile
import unittest
import uuid as _uuid
import warnings

import attr
import kastore
import msprime
import networkx as nx
import numpy as np
Expand Down Expand Up @@ -360,18 +360,11 @@ def test_all_oriented_forests(self):
assert mrca == sv.get_mrca(j, k)


class HighLevelTestCase(unittest.TestCase):
class HighLevelTestCase:
"""
Superclass of tests on the high level interface.
"""

def setUp(self):
self.temp_dir = tempfile.mkdtemp(prefix="tsk_hl_testcase_")
self.temp_file = os.path.join(self.temp_dir, "generic")

def tearDown(self):
shutil.rmtree(self.temp_dir)

def verify_tree_mrcas(self, st):
# Check the mrcas
oriented_forest = [st.get_parent(j) for j in range(st.num_nodes)]
Expand Down Expand Up @@ -486,7 +479,7 @@ def verify_trees(self, ts):
assert not (st2 != st1)
left, right = st1.get_interval()
breakpoints.append(right)
self.assertAlmostEqual(left, length)
assert left == pytest.approx(length)
assert left >= 0
assert right > left
assert right <= ts.get_sequence_length()
Expand All @@ -499,7 +492,7 @@ def verify_trees(self, ts):
next(iter2)
assert ts.get_num_trees() == num_trees
assert breakpoints == list(ts.breakpoints())
self.assertAlmostEqual(length, ts.get_sequence_length())
assert length == ts.get_sequence_length()


class TestNumpySamples:
Expand Down Expand Up @@ -606,15 +599,15 @@ def verify_pairwise_diversity(self, ts):
haplotypes = ts.genotype_matrix(isolated_as_missing=False).T
pi1 = ts.get_pairwise_diversity()
pi2 = simple_get_pairwise_diversity(haplotypes)
self.assertAlmostEqual(pi1, pi2)
assert pi1 == pytest.approx(pi2)
assert pi1 >= 0.0
assert not math.isnan(pi1)
# Check for a subsample.
num_samples = ts.get_sample_size() // 2 + 1
samples = list(ts.samples())[:num_samples]
pi1 = ts.get_pairwise_diversity(samples)
pi2 = simple_get_pairwise_diversity([haplotypes[j] for j in range(num_samples)])
self.assertAlmostEqual(pi1, pi2)
assert pi1 == pytest.approx(pi2)
assert pi1 >= 0.0
assert not math.isnan(pi1)

Expand Down Expand Up @@ -1262,14 +1255,13 @@ def test_removed_methods(self):
with pytest.raises(NotImplementedError):
ts.newick_trees()

def test_dump_pathlib(self):
ts = msprime.simulate(2, random_seed=42)
path = pathlib.Path(self.temp_dir) / "tmp.trees"
def test_dump_pathlib(self, ts_fixture, tmp_path):
path = tmp_path / "tmp.trees"
assert path.exists
assert path.is_file
ts.dump(path)
ts_fixture.dump(path)
other_ts = tskit.load(path)
assert ts.tables == other_ts.tables
assert ts_fixture.tables == other_ts.tables

@pytest.mark.skipif(platform.system() == "Windows", reason="Windows doesn't raise")
def test_dump_load_errors(self):
Expand Down Expand Up @@ -1472,6 +1464,20 @@ def test_repr(self):
for table in ts.tables.name_map:
assert re.search(rf"║{table.capitalize()} *│", s)

def test_nbytes(self, tmp_path, ts_fixture):
    """
    Check that ``nbytes`` agrees with the array sizes in the dumped file.
    """
    tables_path = tmp_path / "tables"
    ts_fixture.dump(tables_path)
    store = kastore.load(tables_path)
    for column in store.values():
        # Check we really have data in every field
        assert column.nbytes > 0
    # nbytes is the size of asdict, so exclude file format items
    file_format_keys = ["format/version", "format/name", "uuid"]
    expected = 0
    for key, column in store.items():
        if key not in file_format_keys:
            expected += column.nbytes
    assert expected == ts_fixture.nbytes


class TestTreeSequenceMethodSignatures:
ts = msprime.simulate(10, random_seed=1234)
Expand Down Expand Up @@ -1683,42 +1689,44 @@ class TestFileUuid(HighLevelTestCase):
"""

def validate(self, ts):
assert ts.file_uuid is None
ts.dump(self.temp_file)
other_ts = tskit.load(self.temp_file)
assert other_ts.file_uuid is not None
assert len(other_ts.file_uuid), 36
uuid = other_ts.file_uuid
other_ts = tskit.load(self.temp_file)
assert other_ts.file_uuid == uuid
assert ts.tables == other_ts.tables

# Check that the UUID is well-formed.
parsed = _uuid.UUID("{" + uuid + "}")
assert str(parsed) == uuid

# Save the same tree sequence to the file. We should get a different UUID.
ts.dump(self.temp_file)
other_ts = tskit.load(self.temp_file)
assert other_ts.file_uuid is not None
assert other_ts.file_uuid != uuid

# Even saving a ts that has a UUID to another file changes the UUID
old_uuid = other_ts.file_uuid
other_ts.dump(self.temp_file)
assert other_ts.file_uuid == old_uuid
other_ts = tskit.load(self.temp_file)
assert other_ts.file_uuid is not None
assert other_ts.file_uuid != old_uuid

# Tables dumped from this ts are a deep copy, so they don't have
# the file_uuid.
tables = other_ts.dump_tables()
assert tables.file_uuid is None

# For now, ts.tables also returns a deep copy. This will hopefully
# change in the future though.
assert ts.tables.file_uuid is None
with tempfile.TemporaryDirectory() as tempdir:
temp_file = pathlib.Path(tempdir) / "tmp.trees"
assert ts.file_uuid is None
ts.dump(temp_file)
other_ts = tskit.load(temp_file)
assert other_ts.file_uuid is not None
assert len(other_ts.file_uuid), 36
uuid = other_ts.file_uuid
other_ts = tskit.load(temp_file)
assert other_ts.file_uuid == uuid
assert ts.tables == other_ts.tables

# Check that the UUID is well-formed.
parsed = _uuid.UUID("{" + uuid + "}")
assert str(parsed) == uuid

# Save the same tree sequence to the file. We should get a different UUID.
ts.dump(temp_file)
other_ts = tskit.load(temp_file)
assert other_ts.file_uuid is not None
assert other_ts.file_uuid != uuid

# Even saving a ts that has a UUID to another file changes the UUID
old_uuid = other_ts.file_uuid
other_ts.dump(temp_file)
assert other_ts.file_uuid == old_uuid
other_ts = tskit.load(temp_file)
assert other_ts.file_uuid is not None
assert other_ts.file_uuid != old_uuid

# Tables dumped from this ts are a deep copy, so they don't have
# the file_uuid.
tables = other_ts.dump_tables()
assert tables.file_uuid is None

# For now, ts.tables also returns a deep copy. This will hopefully
# change in the future though.
assert ts.tables.file_uuid is None

def test_simple_simulation(self):
ts = msprime.simulate(2, random_seed=1)
Expand Down Expand Up @@ -1859,7 +1867,7 @@ def verify_approximate_equality(self, ts1, ts2):
equal, taking into account the error incurred in exporting to text.
"""
assert ts1.sample_size == ts2.sample_size
self.assertAlmostEqual(ts1.sequence_length, ts2.sequence_length)
assert ts1.sequence_length == ts2.sequence_length
assert ts1.num_nodes == ts2.num_nodes
assert ts1.num_edges == ts2.num_edges
assert ts1.num_sites == ts2.num_sites
Expand All @@ -1869,24 +1877,24 @@ def verify_approximate_equality(self, ts1, ts2):
for n1, n2 in zip(ts1.nodes(), ts2.nodes()):
assert n1.population == n2.population
assert n1.metadata == n2.metadata
self.assertAlmostEqual(n1.time, n2.time)
assert n1.time == pytest.approx(n2.time)
checked += 1
assert checked == ts1.num_nodes

checked = 0
for r1, r2 in zip(ts1.edges(), ts2.edges()):
checked += 1
self.assertAlmostEqual(r1.left, r2.left)
self.assertAlmostEqual(r1.right, r2.right)
assert r1.left == pytest.approx(r2.left)
assert r1.right == pytest.approx(r2.right)
assert r1.parent == r2.parent
assert r1.child == r2.child
assert ts1.num_edges == checked

checked = 0
for s1, s2 in zip(ts1.sites(), ts2.sites()):
checked += 1
self.assertAlmostEqual(s1.position, s2.position)
self.assertAlmostEqual(s1.ancestral_state, s2.ancestral_state)
assert s1.position == pytest.approx(s2.position)
assert s1.ancestral_state == s2.ancestral_state
assert s1.metadata == s2.metadata
assert s1.mutations == s2.mutations
assert ts1.num_sites == checked
Expand All @@ -1897,7 +1905,7 @@ def verify_approximate_equality(self, ts1, ts2):
assert s1.site == s2.site
assert s1.node == s2.node
if not (math.isnan(s1.time) and math.isnan(s2.time)):
self.assertAlmostEqual(s1.time, s2.time)
assert s1.time == pytest.approx(s2.time)
assert s1.derived_state == s2.derived_state
assert s1.parent == s2.parent
assert s1.metadata == s2.metadata
Expand Down Expand Up @@ -2237,11 +2245,10 @@ def preorder_dist(tree, root):
for u, v in itertools.combinations(nx.descendants(g, root), 2):
mrca = tree.mrca(u, v)
tmrca = tree.time(mrca)
self.assertAlmostEqual(
tree.time(root) - tmrca,
assert tree.time(root) - tmrca == pytest.approx(
nx.shortest_path_length(
g, source=root, target=mrca, weight="branch_length"
),
)
)

def verify_nx_nearest_neighbor_search(self):
Expand Down Expand Up @@ -2352,7 +2359,7 @@ def test_total_branch_length(self):
if node != root:
bl += t1.get_branch_length(node)
assert bl > 0
self.assertAlmostEqual(t1.get_total_branch_length(), bl)
assert t1.get_total_branch_length() == pytest.approx(bl)

def test_branch_length_empty_tree(self):
tables = tskit.TableCollection(1)
Expand Down Expand Up @@ -2524,12 +2531,12 @@ def test_interval(self):
assert breakpoints[0] == 0
assert breakpoints[-1] == ts.sequence_length
for i, tree in enumerate(ts.trees()):
self.assertAlmostEqual(tree.interval[0], breakpoints[i])
self.assertAlmostEqual(tree.interval.left, breakpoints[i])
self.assertAlmostEqual(tree.interval[1], breakpoints[i + 1])
self.assertAlmostEqual(tree.interval.right, breakpoints[i + 1])
self.assertAlmostEqual(
tree.interval.span, breakpoints[i + 1] - breakpoints[i]
assert tree.interval[0] == pytest.approx(breakpoints[i])
assert tree.interval.left == pytest.approx(breakpoints[i])
assert tree.interval[1] == pytest.approx(breakpoints[i + 1])
assert tree.interval.right == pytest.approx(breakpoints[i + 1])
assert tree.interval.span == pytest.approx(
breakpoints[i + 1] - breakpoints[i]
)

def verify_empty_tree(self, tree):
Expand Down Expand Up @@ -2703,17 +2710,17 @@ class TestNodeOrdering(HighLevelTestCase):

num_random_permutations = 10

def verify_tree_sequences_equal(self, ts1, ts2, approx=False):
def verify_tree_sequences_equal(self, ts1, ts2, approximate=False):
assert ts1.get_num_trees() == ts2.get_num_trees()
assert ts1.get_sample_size() == ts2.get_sample_size()
assert ts1.get_num_nodes() == ts2.get_num_nodes()
j = 0
for r1, r2 in zip(ts1.edges(), ts2.edges()):
assert r1.parent == r2.parent
assert r1.child == r2.child
if approx:
self.assertAlmostEqual(r1.left, r2.left)
self.assertAlmostEqual(r1.right, r2.right)
if approximate:
assert r1.left == pytest.approx(r2.left)
assert r1.right == pytest.approx(r2.right)
else:
assert r1.left == r2.left
assert r1.right == r2.right
Expand All @@ -2723,8 +2730,8 @@ def verify_tree_sequences_equal(self, ts1, ts2, approx=False):
for n1, n2 in zip(ts1.nodes(), ts2.nodes()):
assert n1.metadata == n2.metadata
assert n1.population == n2.population
if approx:
self.assertAlmostEqual(n1.time, n2.time)
if approximate:
assert n1.time == pytest.approx(n2.time)
else:
assert n1.time == n2.time
j += 1
Expand Down Expand Up @@ -2780,8 +2787,10 @@ def verify_random_permutation(self, ts):
j += 1
assert j == ts.get_num_trees()
# Verify we can dump this new tree sequence OK.
other_ts.dump(self.temp_file)
ts3 = tskit.load(self.temp_file)
with tempfile.TemporaryDirectory() as tempdir:
temp_file = pathlib.Path(tempdir) / "tmp.trees"
other_ts.dump(temp_file)
ts3 = tskit.load(temp_file)
self.verify_tree_sequences_equal(other_ts, ts3)
nodes_file = io.StringIO()
edges_file = io.StringIO()
Expand Down
Loading