From 9400ee74e13f85740d5ce3cf8d83196c3a20438a Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Wed, 11 Nov 2020 15:28:49 +0000 Subject: [PATCH] Add TableCollection dump/load --- python/CHANGELOG.rst | 4 ++ python/_tskitmodule.c | 108 ++++++++++++++++++++++++++++++++- python/tests/test_highlevel.py | 3 + python/tests/test_lowlevel.py | 80 +++++++++++++----------- python/tests/test_tables.py | 44 ++++++++++++-- python/tskit/tables.py | 22 +++++++ python/tskit/trees.py | 87 ++++++-------------------- python/tskit/util.py | 31 ++++++++++ 8 files changed, 270 insertions(+), 109 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 22135f3042..abb512cd69 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -34,6 +34,10 @@ - Added ``TreeSequence.__repr__`` to display a summary for terminal usage. (:user:`benjeffery`, :issue:`938`, :pr:`985`) +- Added ``TableCollection.dump`` and ``TableCollection.load``. This allows table + collections that are not valid tree sequences to be manipulated. + (:user:`benjeffery`, :issue:`14`, :pr:`986`) + **Breaking changes** - The argument to ``ts.dump`` and ``tskit.load`` has been renamed `file` from `path`. 
diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 922d2595b9..92b48b73e9 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -4495,6 +4495,26 @@ TableCollection_check_state(TableCollection *self) return ret; } +static int +TableCollection_alloc(TableCollection *self) +{ + int ret = -1; + + if (self->tables != NULL) { + tsk_table_collection_free(self->tables); + PyMem_Free(self->tables); + } + self->tables = PyMem_Malloc(sizeof(tsk_table_collection_t)); + if (self->tables == NULL) { + PyErr_NoMemory(); + goto out; + } + memset(self->tables, 0, sizeof(*self->tables)); + ret = 0; +out: + return ret; +} + static void TableCollection_dealloc(TableCollection *self) { @@ -5449,6 +5469,82 @@ TableCollection_equals(TableCollection *self, PyObject *args, PyObject *kwds) return ret; } +static PyObject * +TableCollection_dump(TableCollection *self, PyObject *args, PyObject *kwds) +{ + int err; + FILE *file = NULL; + PyObject *py_file = NULL; + PyObject *ret = NULL; + static char *kwlist[] = { "file", NULL }; + + if (TableCollection_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &py_file)) { + goto out; + } + + file = make_file(py_file, "wb"); + if (file == NULL) { + goto out; + } + + err = tsk_table_collection_dumpf(self->tables, file, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + if (file != NULL) { + (void) fclose(file); + } + return ret; +} + +static PyObject * +TableCollection_load(TableCollection *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + PyObject *py_file; + FILE *file = NULL; + static char *kwlist[] = { "file", NULL }; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &py_file)) { + goto out; + } + file = make_file(py_file, "rb"); + if (file == NULL) { + goto out; + } + /* Set unbuffered mode to ensure no more bytes are read than requested. 
+ * Buffered reads could read beyond the end of the current store in a + * multi-store file or stream. This data would be discarded when we + * fclose() the file below, such that attempts to load the next store + * will fail. */ + if (setvbuf(file, NULL, _IONBF, 0) != 0) { + PyErr_SetFromErrno(PyExc_OSError); + goto out; + } + err = TableCollection_alloc(self); + if (err != 0) { + goto out; + } + err = tsk_table_collection_loadf(self->tables, file, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + if (file != NULL) { + (void) fclose(file); + } + return ret; +} + static PyGetSetDef TableCollection_getsetters[] = { { .name = "individuals", .get = (getter) TableCollection_get_individuals, @@ -5553,6 +5649,14 @@ static PyMethodDef TableCollection_methods[] = { .ml_meth = (PyCFunction) TableCollection_has_index, .ml_flags = METH_NOARGS, .ml_doc = "Returns True if the TableCollection is indexed." }, + { .ml_name = "dump", + .ml_meth = (PyCFunction) TableCollection_dump, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Writes the table collection out to the specified file." }, + { .ml_name = "load", + .ml_meth = (PyCFunction) TableCollection_load, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Loads the table collection from the specified file." }, { NULL } /* Sentinel */ }; @@ -7340,11 +7444,11 @@ static PyMethodDef TreeSequence_methods[] = { { .ml_name = "dump", .ml_meth = (PyCFunction) TreeSequence_dump, .ml_flags = METH_VARARGS | METH_KEYWORDS, - .ml_doc = "Writes the tree sequence out to the specified path." }, + .ml_doc = "Writes the tree sequence out to the specified file." }, { .ml_name = "load", .ml_meth = (PyCFunction) TreeSequence_load, .ml_flags = METH_VARARGS | METH_KEYWORDS, - .ml_doc = "Loads a tree sequence from the specified path." }, + .ml_doc = "Loads a tree sequence from the specified file." 
}, { .ml_name = "load_tables", .ml_meth = (PyCFunction) TreeSequence_load_tables, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index b05848c2d5..b0caed0e5d 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1292,6 +1292,9 @@ def test_dump_load_errors(self): except OSError as e: message = str(e) assert "File name too long" in message + for bad_filename in [[], None, {}]: + with pytest.raises(TypeError): + func(bad_filename) def test_tables_sequence_length_round_trip(self): for sequence_length in [0.1, 1, 10, 100]: diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index b63a6b5bf9..e51f38da62 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -27,10 +27,8 @@ import inspect import itertools import os -import platform import random import tempfile -import unittest import msprime import numpy as np @@ -38,8 +36,6 @@ import _tskit -IS_WINDOWS = platform.system() == "Windows" - def get_tracked_sample_counts(st, tracked_samples): """ @@ -70,7 +66,7 @@ def get_sample_counts(tree_sequence, st): return nu -class LowLevelTestCase(unittest.TestCase): +class LowLevelTestCase: """ Superclass of tests for the low-level interface. 
""" @@ -166,6 +162,32 @@ class TestTableCollection(LowLevelTestCase): Tests for the low-level TableCollection class """ + def test_file_errors(self): + tc1 = _tskit.TableCollection(1) + self.get_example_tree_sequence().dump_tables(tc1) + + def loader(*args): + tc = _tskit.TableCollection(1) + tc.load(*args) + + for func in [tc1.dump, loader]: + with pytest.raises(TypeError): + func() + for bad_type in [None, [], {}]: + with pytest.raises(TypeError): + func(bad_type) + + def test_dump_equality(self, tmp_path): + for ts in self.get_example_tree_sequences(): + tc = _tskit.TableCollection(sequence_length=ts.get_sequence_length()) + ts.dump_tables(tc) + with open(tmp_path / "tmp.trees", "wb") as f: + tc.dump(f) + with open(tmp_path / "tmp.trees", "rb") as f: + tc2 = _tskit.TableCollection() + tc2.load(f) + assert tc.equals(tc2) + def test_reference_deletion(self): ts = msprime.simulate(10, mutation_rate=1, random_seed=1) tc = ts.tables._ll_tables @@ -518,7 +540,6 @@ def setUp(self): def tearDown(self): os.unlink(self.temp_file) - @pytest.mark.skipif(IS_WINDOWS, reason="File permissions on Windows") def test_file_errors(self): ts1 = self.get_example_tree_sequence() @@ -565,38 +586,23 @@ def test_num_nodes(self): max_node = node assert max_node + 1 == ts.get_num_nodes() - def verify_dump_equality(self, ts): - """ - Verifies that we can dump a copy of the specified tree sequence - to the specified file, and load an identical copy. 
- """ - with open(self.temp_file, "wb") as f: - ts.dump(f) - with open(self.temp_file, "rb") as f: - ts2 = _tskit.TreeSequence() - ts2.load(f) - assert ts.get_num_samples() == ts2.get_num_samples() - assert ts.get_sequence_length() == ts2.get_sequence_length() - assert ts.get_num_mutations() == ts2.get_num_mutations() - assert ts.get_num_nodes() == ts2.get_num_nodes() - records1 = [ts.get_edge(j) for j in range(ts.get_num_edges())] - records2 = [ts2.get_edge(j) for j in range(ts2.get_num_edges())] - assert records1 == records2 - mutations1 = [ts.get_mutation(j) for j in range(ts.get_num_mutations())] - mutations2 = [ts2.get_mutation(j) for j in range(ts2.get_num_mutations())] - assert mutations1 == mutations2 - provenances1 = [ts.get_provenance(j) for j in range(ts.get_num_provenances())] - provenances2 = [ts2.get_provenance(j) for j in range(ts2.get_num_provenances())] - assert provenances1 == provenances2 - - def test_dump_equality(self): + def test_dump_equality(self, tmp_path): for ts in self.get_example_tree_sequences(): tables = _tskit.TableCollection(sequence_length=ts.get_sequence_length()) ts.dump_tables(tables) tables.compute_mutation_times() ts = _tskit.TreeSequence() ts.load_tables(tables) - self.verify_dump_equality(ts) + with open(tmp_path / "temp.trees", "wb") as f: + ts.dump(f) + with open(tmp_path / "temp.trees", "rb") as f: + ts2 = _tskit.TreeSequence() + ts2.load(f) + tc = _tskit.TableCollection(ts.get_sequence_length()) + ts.dump_tables(tc) + tc2 = _tskit.TableCollection(ts2.get_sequence_length()) + ts2.dump_tables(tc2) + assert tc.equals(tc2) def verify_mutations(self, ts): mutations = [ts.get_mutation(j) for j in range(ts.get_num_mutations())] @@ -829,7 +835,7 @@ def test_kc_distance(self): for lambda_ in [-1, 0, 1, 1000, -1e300]: x1 = ts1.get_kc_distance(ts2, lambda_) x2 = ts2.get_kc_distance(ts1, lambda_) - self.assertAlmostEqual(x1, x2) + assert x1 == x2 def test_load_tables_build_indexes(self): for ts in self.get_example_tree_sequences(): 
@@ -2423,7 +2429,7 @@ def test_kc_distance(self): for lambda_ in [-1, 0, 1, 1000, -1e300]: x1 = t1.get_kc_distance(t2, lambda_) x2 = t2.get_kc_distance(t1, lambda_) - self.assertAlmostEqual(x1, x2) + assert x1 == x2 def test_copy(self): for ts in self.get_example_tree_sequences(): @@ -2561,7 +2567,11 @@ def test_tskit_version(self): def test_uninitialised(): # These methods work from an instance that has a NULL ref so don't check - skip_list = ["TreeSequence_load", "TreeSequence_load_tables"] + skip_list = [ + "TableCollection_load", + "TreeSequence_load", + "TreeSequence_load_tables", + ] for cls_name, cls in inspect.getmembers(_tskit): if ( type(cls) == type diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 099d718513..93704bb8b9 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -27,7 +27,9 @@ import io import json import math +import pathlib import pickle +import platform import random import time import unittest @@ -2604,8 +2606,8 @@ def test_indexes_roundtrip(self, simple_ts_fixture): tables.drop_index() assert not tskit.TableCollection.fromdict(tables.asdict()).has_index() - def test_asdict_lwt_concordence(self, ts_fixture): - def check_concordence(d1, d2): + def test_asdict_lwt_concordance(self, ts_fixture): + def check_concordance(d1, d2): assert set(d1.keys()) == set(d2.keys()) for k1, v1 in d1.items(): v2 = d2[k1] @@ -2632,12 +2634,46 @@ def check_concordence(d1, d2): assert tables.has_index() lwt = _tskit.LightweightTableCollection() lwt.fromdict(tables.asdict()) - check_concordence(lwt.asdict(), tables.asdict()) + check_concordance(lwt.asdict(), tables.asdict()) tables.drop_index() lwt = _tskit.LightweightTableCollection() lwt.fromdict(tables.asdict()) - check_concordence(lwt.asdict(), tables.asdict()) + check_concordance(lwt.asdict(), tables.asdict()) + + def test_dump_pathlib(self, ts_fixture, tmp_path): + path = pathlib.Path(tmp_path) / "tmp.trees" + assert path.exists + assert path.is_file + tc 
= ts_fixture.dump_tables() + tc.dump(path) + other_tc = tskit.TableCollection.load(path) + assert tc == other_tc + + @pytest.mark.skipif(platform.system() == "Windows", reason="Windows doesn't raise") + def test_dump_load_errors(self, ts_fixture): + tc = ts_fixture.dump_tables() + # Try to dump/load files we don't have access to or don't exist. + for func in [tc.dump, tskit.TableCollection.load]: + for f in ["/", "/test.trees", "/dir_does_not_exist/x.trees"]: + with pytest.raises(OSError): + func(f) + try: + func(f) + except OSError as e: + message = str(e) + assert len(message) > 0 + f = "/" + 4000 * "x" + with pytest.raises(OSError): + func(f) + try: + func(f) + except OSError as e: + message = str(e) + assert "File name too long" in message + for bad_filename in [[], None, {}]: + with pytest.raises(TypeError): + func(bad_filename) class TestEqualityOptions: diff --git a/python/tskit/tables.py b/python/tskit/tables.py index cb4f4b3bda..adf80b628b 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -2258,6 +2258,28 @@ def __eq__(self, other): def __getstate__(self): return self.asdict() + @classmethod + def load(cls, file_or_path): + file, local_file = util.convert_file_like_to_open_file(file_or_path, "rb") + ll_tc = _tskit.TableCollection(1) + ll_tc.load(file) + tc = TableCollection(1) + tc._ll_tables = ll_tc + return tc + + def dump(self, file_or_path): + """ + Writes the table collection to the specified path or file object. + + :param str file_or_path: The file object or path to write the table collection to. 
+ """ + file, local_file = util.convert_file_like_to_open_file(file_or_path, "wb") + try: + self._ll_tables.dump(file) + finally: + if local_file: + file.close() + # Unpickle support def __setstate__(self, state): self.__init__(state["sequence_length"]) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 51717b2a5b..885459f744 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -32,7 +32,6 @@ import itertools import json import math -import os import textwrap import warnings from typing import Any @@ -2430,38 +2429,7 @@ def load(file): stored in the specified file path. :rtype: :class:`tskit.TreeSequence` """ - # Get ourselves a local version of the file. The semantics here are complex - # because need to support a range of inputs and the free behaviour is - # slightly different on each. - _file = None - local_file = True - try: - # First, see if we can interpret the argument as a pathlike object. - path = os.fspath(file) - _file = open(path, "rb") - except TypeError: - pass - if _file is None: - # Now we try to open file. If it's not a pathlike object, it could be - # an integer fd or object with a fileno method. In this case we - # must make sure that close is **not** called on the fd. - try: - _file = open(file, "rb", closefd=False, buffering=0) - except TypeError: - pass - if _file is None: - # Assume that this is a file **but** we haven't opened it, so we must - # not close it. 
- _file = file - local_file = False - try: - return TreeSequence.load(_file) - except exceptions.FileFormatError as e: - # TODO Fix this for new file semantics - formats.raise_hdf5_format_error(path, e) - finally: - if local_file: - _file.close() + return TreeSequence.load(file) def parse_individuals( @@ -3071,10 +3039,18 @@ def aslist(self, **kwargs): return [tree.copy() for tree in self.trees(**kwargs)] @classmethod - def load(cls, file): - ts = _tskit.TreeSequence() - ts.load(file) - return TreeSequence(ts) + def load(cls, file_or_path): + file, local_file = util.convert_file_like_to_open_file(file_or_path, "rb") + try: + ts = _tskit.TreeSequence() + ts.load(file) + return TreeSequence(ts) + except exceptions.FileFormatError as e: + # TODO Fix this for new file semantics + formats.raise_hdf5_format_error(file_or_path, e) + finally: + if local_file: + file.close() @classmethod def load_tables(cls, tables, *, build_indexes=False): @@ -3082,43 +3058,18 @@ def load_tables(cls, tables, *, build_indexes=False): ts.load_tables(tables._ll_tables, build_indexes=build_indexes) return TreeSequence(ts) - def dump(self, file): + def dump(self, file_or_path): """ - Writes the tree sequence to the specified file object. + Writes the tree sequence to the specified path or file object. - :param str file: The file object or path to write the TreeSequence to. + :param str file_or_path: The file object or path to write the TreeSequence to. """ - # Get ourselves a local version of the file. The semantics here are complex - # because need to support a range of inputs and the free behaviour is - # slightly different on each. - _file = None - local_file = True - try: - # First, see if we can interpret the argument as a pathlike object. - path = os.fspath(file) - _file = open(path, "wb") - except TypeError: - pass - if _file is None: - # Now we try to open file. If it's not a pathlike object, it could be - # an integer fd or object with a fileno method. 
In this case we - # must make sure that close is **not** called on the fd. - try: - _file = open(file, "wb", closefd=False) - except TypeError: - pass - if _file is None: - # Assume that this is a file **but** we haven't opened it, so we must - # not close it. - if not hasattr(file, "write"): - raise TypeError("file object must have a write method") - _file = file - local_file = False + file, local_file = util.convert_file_like_to_open_file(file_or_path, "wb") try: - self._ll_tree_sequence.dump(_file) + self._ll_tree_sequence.dump(file) finally: if local_file: - _file.close() + file.close() @property def tables_dict(self): diff --git a/python/tskit/util.py b/python/tskit/util.py index 8d055440b2..b1296fc97f 100644 --- a/python/tskit/util.py +++ b/python/tskit/util.py @@ -23,6 +23,7 @@ Module responsible for various utility functions used in other modules. """ import json +import os import numpy as np @@ -435,3 +436,33 @@ def tree_sequence_html(ts): """ # noqa: B950 + + +def convert_file_like_to_open_file(file_like, mode): + # Get ourselves a local version of the file. The semantics here are complex + # because need to support a range of inputs and the free behaviour is + # slightly different on each. + _file = None + local_file = True + try: + # First, see if we can interpret the argument as a pathlike object. + path = os.fspath(file_like) + _file = open(path, mode) + except TypeError: + pass + if _file is None: + # Now we try to open file. If it's not a pathlike object, it could be + # an integer fd or object with a fileno method. In this case we + # must make sure that close is **not** called on the fd. + try: + _file = open(file_like, mode, closefd=False, buffering=0) + except TypeError: + pass + if _file is None: + # Assume that this is a file **but** we haven't opened it, so we must + # not close it. 
+ if mode == "wb" and not hasattr(file_like, "write"): + raise TypeError("file object must have a write method") + _file = file_like + local_file = False + return _file, local_file