diff --git a/.circleci/config.yml b/.circleci/config.yml index 8aeb5779a1..98103185ea 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -114,7 +114,7 @@ jobs: name: Run Python tests command: | cd python - python -m pytest --cov=tskit --cov-report=xml --cov-branch -n `nproc` tests + python -m pytest --cov=tskit --cov-report=xml --cov-branch -n8 tests - run: name: Upload Python coverage diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 313f9e3606..ef59ce0c61 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -8,6 +8,14 @@ collections with the options to ignore top-level metadata/schema or provenance tables. (:user:`mufernando`, :issue:`896`, :pr:`897`). +- ``ts.dump`` and ``tskit.load`` now support reading and writing file objects such as + FIFOs and sockets. (:user:`benjeffery`, :issue:`657`, :pr:`909`) + +**Breaking changes** + +- The argument to ``ts.dump`` and ``tskit.load`` has been renamed from ``path`` to + ``file``, and the deprecated ``zlib_compression`` argument to ``ts.dump`` is removed. + -------------------- [0.3.2] - 2020-09-29 -------------------- diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index ef9fd9c451..57dd621941 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -709,6 +709,33 @@ table_get_column_array(size_t num_rows, void *data, int npy_type, size_t element return ret; } +static FILE * +make_file(PyObject *fileobj, const char *mode) +{ + FILE *ret = NULL; + FILE *file = NULL; + int fileobj_fd, new_fd; + + fileobj_fd = PyObject_AsFileDescriptor(fileobj); + if (fileobj_fd == -1) { + goto out; + } + new_fd = dup(fileobj_fd); + if (new_fd == -1) { + PyErr_SetFromErrno(PyExc_OSError); + goto out; + } + file = fdopen(new_fd, mode); + if (file == NULL) { + (void) close(new_fd); + PyErr_SetFromErrno(PyExc_OSError); + goto out; + } + ret = file; +out: + return ret; +} + /*=================================================================== * IndividualTable *=================================================================== @@ -5321,23 +5348,33 @@ static PyObject * 
TreeSequence_dump(TreeSequence *self, PyObject *args, PyObject *kwds) { int err; - char *path; + FILE *file = NULL; + PyObject *py_file = NULL; PyObject *ret = NULL; - static char *kwlist[] = { "path", NULL }; + static char *kwlist[] = { "file", NULL }; if (TreeSequence_check_tree_sequence(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", kwlist, &path)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &py_file)) { + goto out; + } + + file = make_file(py_file, "wb"); + if (file == NULL) { goto out; } - err = tsk_treeseq_dump(self->tree_sequence, path, 0); + + err = tsk_treeseq_dumpf(self->tree_sequence, file, 0); if (err != 0) { handle_library_error(err); goto out; } ret = Py_BuildValue(""); out: + if (file != NULL) { + (void) fclose(file); + } return ret; } @@ -5398,24 +5435,41 @@ static PyObject * TreeSequence_load(TreeSequence *self, PyObject *args, PyObject *kwds) { int err; - char *path; PyObject *ret = NULL; - static char *kwlist[] = { "path", NULL }; + PyObject *py_file; + FILE *file = NULL; + static char *kwlist[] = { "file", NULL }; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "s", kwlist, &path)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &py_file)) { + goto out; + } + file = make_file(py_file, "rb"); + if (file == NULL) { + goto out; + } + /* Set unbuffered mode to ensure no more bytes are read than requested. + * Buffered reads could read beyond the end of the current store in a + * multi-store file or stream. This data would be discarded when we + * fclose() the file below, such that attempts to load the next store + * will fail. 
*/ + if (setvbuf(file, NULL, _IONBF, 0) != 0) { + PyErr_SetFromErrno(PyExc_OSError); goto out; } err = TreeSequence_alloc(self); if (err != 0) { goto out; } - err = tsk_treeseq_load(self->tree_sequence, path, 0); + err = tsk_treeseq_loadf(self->tree_sequence, file, 0); if (err != 0) { handle_library_error(err); goto out; } ret = Py_BuildValue(""); out: + if (file != NULL) { + (void) fclose(file); + } return ret; } diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 56e7a26977..60019faa47 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -1,13 +1,62 @@ +# MIT License +# +# Copyright (c) 2018-2020 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Configuration and fixtures for pytest. Only put test-suite wide fixtures in here. Module +specific fixtures should live in their modules. + +To use a fixture in a test simply refer to it by name as an argument. This is called +dependency injection. 
Note that all fixtures should have the suffix "_fixture" to make +it clear in test code. + +For example, to use `ts_fixture` (a tree sequence with data in all fields) in a test: + +def test_something(ts_fixture): + assert ts_fixture.some_method() == expected + +Fixtures can be parameterised etc. see https://docs.pytest.org/en/stable/fixture.html + +Note that fixtures have a "scope"; for example, `ts_fixture` below is created once per +test session and re-used for subsequent tests. +""" +import msprime import pytest +from pytest import fixture + +import tskit def pytest_addoption(parser): + """ + Add an option to skip tests marked with `@pytest.mark.slow` + """ parser.addoption( "--skip-slow", action="store_true", default=False, help="Skip slow tests" ) def pytest_configure(config): + """ + Add docs on the "slow" marker + """ config.addinivalue_line("markers", "slow: mark test as slow to run") @@ -17,3 +66,64 @@ def pytest_collection_modifyitems(config, items): for item in items: if "slow" in item.keywords: item.add_marker(skip_slow) + + +@fixture(scope="session") +def simple_ts_fixture(): + return msprime.simulate(2, random_seed=42) + + +@fixture(scope="session") +def ts_fixture(): + """ + A tree sequence with data in all fields + """ + n = 10 + t = 1 + population_configurations = [ + msprime.PopulationConfiguration(n // 2), + msprime.PopulationConfiguration(n // 2), + msprime.PopulationConfiguration(0), + ] + demographic_events = [ + msprime.MassMigration(time=t, source=0, destination=2), + msprime.MassMigration(time=t, source=1, destination=2), + ] + ts = msprime.simulate( + population_configurations=population_configurations, + demographic_events=demographic_events, + random_seed=1, + mutation_rate=1, + record_migrations=True, + ) + tables = ts.dump_tables() + for table in [ + "edges", + "individuals", + "migrations", + "mutations", + "nodes", + "populations", + "sites", + ]: + getattr(tables, table).metadata_schema = tskit.MetadataSchema({"codec": "json"}) + metadatas = 
[f"n_{table}_{u}" for u in range(getattr(ts, f"num_{table}"))] + metadata, metadata_offset = tskit.pack_strings(metadatas) + getattr(tables, table).set_columns( + **{ + **getattr(tables, table).asdict(), + "metadata": metadata, + "metadata_offset": metadata_offset, + } + ) + tables.metadata_schema = tskit.MetadataSchema({"codec": "json"}) + tables.metadata = "Test metadata" + return tables.tree_sequence() + + +@fixture(scope="session") +def replicate_ts_fixture(): + """ + A list of tree sequences + """ + return list(msprime.simulate(10, num_replicates=10, random_seed=42)) diff --git a/python/tests/test_cli.py b/python/tests/test_cli.py index e2453f1e16..7ae295ea4e 100644 --- a/python/tests/test_cli.py +++ b/python/tests/test_cli.py @@ -556,7 +556,7 @@ def verify(self, command): with pytest.raises(TestException): capture_output(cli.tskit_main, ["info", "/no/such/file"]) mocked_exit.assert_called_once_with( - "Load error: [Errno 2] No such file or directory" + "Load error: [Errno 2] No such file or directory: '/no/such/file'" ) def test_info(self): diff --git a/python/tests/test_fileobj.py b/python/tests/test_fileobj.py new file mode 100644 index 0000000000..66139cf6ba --- /dev/null +++ b/python/tests/test_fileobj.py @@ -0,0 +1,307 @@ +# MIT License +# +# Copyright (c) 2018-2020 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for loading and dumping different types of files and streams +""" +import io +import multiprocessing +import os +import pathlib +import platform +import queue +import shutil +import socket +import socketserver +import tempfile +import traceback + +import pytest +from pytest import fixture + +import tskit + + +IS_WINDOWS = platform.system() == "Windows" + + +class TestPath: + @fixture + def tempfile_name(self): + with tempfile.TemporaryDirectory() as tmp_dir: + yield f"{tmp_dir}/plain_path" + + def test_pathlib(self, ts_fixture, tempfile_name): + ts_fixture.dump(tempfile_name) + ts2 = tskit.load(tempfile_name) + assert ts_fixture.tables == ts2.tables + + +class TestPathLib: + @fixture + def pathlib_tempfile(self): + fd, path = tempfile.mkstemp(prefix="tskit_test_pathlib") + os.close(fd) + temp_file = pathlib.Path(path) + yield temp_file + temp_file.unlink() + + def test_pathlib(self, ts_fixture, pathlib_tempfile): + ts_fixture.dump(pathlib_tempfile) + ts2 = tskit.load(pathlib_tempfile) + assert ts_fixture.tables == ts2.tables + + +class TestFileObj: + @fixture + def fileobj(self): + with tempfile.TemporaryDirectory() as tmp_dir: + with open(f"{tmp_dir}/fileobj", "wb") as f: + yield f + + def test_fileobj(self, ts_fixture, fileobj): + ts_fixture.dump(fileobj) + fileobj.close() + ts2 = tskit.load(fileobj.name) + assert ts_fixture.tables == ts2.tables + + def test_fileobj_multi(self, replicate_ts_fixture, fileobj): + file_offsets = [] + for ts in replicate_ts_fixture: + 
ts.dump(fileobj) + file_offsets.append(fileobj.tell()) + fileobj.close() + with open(fileobj.name, "rb") as f: + for ts, file_offset in zip(replicate_ts_fixture, file_offsets): + ts2 = tskit.load(f) + file_offset2 = f.tell() + assert ts.tables == ts2.tables + assert file_offset == file_offset2 + + +class TestFileObjRW: + @fixture + def fileobj(self): + with tempfile.TemporaryDirectory() as tmp_dir: + pathlib.Path(f"{tmp_dir}/fileobj").touch() + with open(f"{tmp_dir}/fileobj", "r+b") as f: + yield f + + def test_fileobj(self, ts_fixture, fileobj): + ts_fixture.dump(fileobj) + fileobj.seek(0) + ts2 = tskit.load(fileobj) + assert ts_fixture.tables == ts2.tables + + def test_fileobj_multi(self, replicate_ts_fixture, fileobj): + file_offsets = [] + for ts in replicate_ts_fixture: + ts.dump(fileobj) + file_offsets.append(fileobj.tell()) + fileobj.seek(0) + for ts, file_offset in zip(replicate_ts_fixture, file_offsets): + ts2 = tskit.load(fileobj) + file_offset2 = fileobj.tell() + assert ts.tables == ts2.tables + assert file_offset == file_offset2 + + +class TestFD: + @fixture + def fd(self): + with tempfile.TemporaryDirectory() as tmp_dir: + pathlib.Path(f"{tmp_dir}/fd").touch() + with open(f"{tmp_dir}/fd", "r+b") as f: + yield f.fileno() + + def test_fd(self, ts_fixture, fd): + ts_fixture.dump(fd) + os.lseek(fd, 0, os.SEEK_SET) + ts2 = tskit.load(fd) + assert ts_fixture.tables == ts2.tables + + def test_fd_multi(self, replicate_ts_fixture, fd): + for ts in replicate_ts_fixture: + ts.dump(fd) + os.lseek(fd, 0, os.SEEK_SET) + for ts in replicate_ts_fixture: + ts2 = tskit.load(fd) + assert ts.tables == ts2.tables + + +class TestUnsupportedObjects: + def test_string_io(self, ts_fixture): + with pytest.raises(io.UnsupportedOperation, match=r"fileno"): + ts_fixture.dump(io.StringIO()) + with pytest.raises(io.UnsupportedOperation, match=r"fileno"): + tskit.load(io.StringIO()) + with pytest.raises(io.UnsupportedOperation, match=r"fileno"): + ts_fixture.dump(io.BytesIO()) + with 
pytest.raises(io.UnsupportedOperation, match=r"fileno"): + tskit.load(io.BytesIO()) + + +def dump_to_stream(q_err, q_in, file_out): + """ + Get tree sequences from `q_in` and ts.dump() them to `file_out`. + Uncaught exceptions are placed onto the `q_err` queue. + """ + try: + with open(file_out, "wb") as f: + while True: + ts = q_in.get() + if ts is None: + break + ts.dump(f) + except Exception as exc: + tb = traceback.format_exc() + q_err.put((exc, tb)) + + +def load_from_stream(q_err, q_out, file_in): + """ + tskit.load() tree sequences from `file_in` and put them onto `q_out`. + Uncaught exceptions are placed onto the `q_err` queue. + """ + try: + with open(file_in, "rb") as f: + while True: + try: + ts = tskit.load(f) + except EOFError: + break + q_out.put(ts) + except Exception as exc: + tb = traceback.format_exc() + q_err.put((exc, tb)) + + +def stream(fifo, ts_list): + """ + data -> q_in -> ts.dump(fifo) -> tskit.load(fifo) -> q_out -> data_out + """ + q_err = multiprocessing.Queue() + q_in = multiprocessing.Queue() + q_out = multiprocessing.Queue() + proc1 = multiprocessing.Process(target=dump_to_stream, args=(q_err, q_in, fifo)) + proc2 = multiprocessing.Process(target=load_from_stream, args=(q_err, q_out, fifo)) + proc1.start() + proc2.start() + for data in ts_list: + q_in.put(data) + + q_in.put(None) # signal the process that we're done + proc1.join(timeout=3) + if not q_err.empty(): + # re-raise the first child exception + exc, tb = q_err.get() + print(tb) + raise exc + if proc1.is_alive(): + # prevent hang if proc1 failed to join + proc1.terminate() + proc2.terminate() + raise RuntimeError("proc1 (ts.dump) failed to join") + ts_list_out = [] + for _ in ts_list: + try: + data_out = q_out.get(timeout=3) + except queue.Empty: + # terminate proc2 so we don't hang + proc2.terminate() + raise + ts_list_out.append(data_out) + proc2.join(timeout=3) + if proc2.is_alive(): + # prevent hang if proc2 failed to join + proc2.terminate() + raise RuntimeError("proc2 
(tskit.load) failed to join") + + assert len(ts_list) == len(ts_list_out) + for ts, ts_out in zip(ts_list, ts_list_out): + assert ts.tables == ts_out.tables + + +@pytest.mark.skipif(IS_WINDOWS, reason="No FIFOs on Windows") +class TestFIFO: + @fixture + def fifo(self): + temp_dir = tempfile.mkdtemp(prefix="tsk_test_streaming") + temp_fifo = os.path.join(temp_dir, "fifo") + os.mkfifo(temp_fifo) + yield temp_fifo + shutil.rmtree(temp_dir) + + def test_single_stream(self, fifo, ts_fixture): + stream(fifo, [ts_fixture]) + + def test_multi_stream(self, fifo, replicate_ts_fixture): + stream(fifo, replicate_ts_fixture) + + +ADDRESS = ("localhost", 10009) + + +@pytest.mark.skipif(IS_WINDOWS, reason="Errors on Windows") +class TestSocket: + @fixture + def client_fd(self): + class Server(socketserver.ThreadingTCPServer): + allow_reuse_address = True + + class StoreEchoHandler(socketserver.BaseRequestHandler): + def handle(self): + while True: + try: + ts = tskit.load(self.request.fileno()) + except EOFError: + break + ts.dump(self.request.fileno()) + self.server.shutdown() + + def server_process(q): + server = Server(ADDRESS, StoreEchoHandler) + # Tell the client (on the other end of the queue) that it's OK to open + # a connection + q.put(None) + server.serve_forever() + + # Use a queue to synchronise the startup of the server and the client. 
+ q = multiprocessing.Queue() + _server_process = multiprocessing.Process(target=server_process, args=(q,)) + _server_process.start() + q.get() + client = socket.create_connection(ADDRESS) + yield client.fileno() + client.close() + _server_process.join() + + def verify_stream(self, ts_list, client_fd): + for ts in ts_list: + ts.dump(client_fd) + echo_ts = tskit.load(client_fd) + assert ts.tables == echo_ts.tables + + def test_single(self, ts_fixture, client_fd): + self.verify_stream([ts_fixture], client_fd) + + def test_multi(self, replicate_ts_fixture, client_fd): + self.verify_stream(replicate_ts_fixture, client_fd) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index e89ca63170..711b47b20d 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -32,6 +32,7 @@ import os import pathlib import pickle +import platform import random import shutil import tempfile @@ -1261,7 +1262,7 @@ def test_removed_methods(self): ts.newick_trees() def test_dump_pathlib(self): - ts = msprime.simulate(5, random_seed=1) + ts = msprime.simulate(2, random_seed=42) path = pathlib.Path(self.temp_dir) / "tmp.trees" assert path.exists assert path.is_file @@ -1269,15 +1270,27 @@ def test_dump_pathlib(self): other_ts = tskit.load(path) assert ts.tables == other_ts.tables - def test_zlib_compression_warning(self): + @pytest.mark.skipif(platform.system() == "Windows", reason="Windows doesn't raise") + def test_dump_load_errors(self): ts = msprime.simulate(5, random_seed=1) - with warnings.catch_warnings(record=True) as w: - ts.dump(self.temp_file, zlib_compression=True) - assert len(w) == 1 - assert issubclass(w[0].category, RuntimeWarning) - with warnings.catch_warnings(record=True) as w: - ts.dump(self.temp_file, zlib_compression=False) - assert len(w) == 0 + # Try to dump/load files we don't have access to or don't exist. 
+ for func in [ts.dump, tskit.load]: + for f in ["/", "/test.trees", "/dir_does_not_exist/x.trees"]: + with pytest.raises(OSError): + func(f) + try: + func(f) + except OSError as e: + message = str(e) + assert len(message) > 0 + f = "/" + 4000 * "x" + with pytest.raises(OSError): + func(f) + try: + func(f) + except OSError as e: + message = str(e) + assert "File name too long" in message def test_tables_sequence_length_round_trip(self): for sequence_length in [0.1, 1, 10, 100]: diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index e365055355..52c3c63a50 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -393,26 +393,9 @@ def loader(*args): for func in [ts1.dump, loader]: with pytest.raises(TypeError): func() - for bad_type in [1, None, [], {}]: + for bad_type in [None, [], {}]: with pytest.raises(TypeError): func(bad_type) - # Try to dump/load files we don't have access to or don't exist. - for f in ["/", "/test.trees", "/dir_does_not_exist/x.trees"]: - with pytest.raises(OSError): - func(f) - try: - func(f) - except OSError as e: - message = str(e) - assert len(message) > 0 - f = "/" + 4000 * "x" - with pytest.raises(OSError): - func(f) - try: - func(f) - except OSError as e: - message = str(e) - assert message.endswith("File name too long") def test_initial_state(self): # Check the initial state to make sure that it is empty. @@ -451,9 +434,11 @@ def verify_dump_equality(self, ts): Verifies that we can dump a copy of the specified tree sequence to the specified file, and load an identical copy. 
""" - ts.dump(self.temp_file) - ts2 = _tskit.TreeSequence() - ts2.load(self.temp_file) + with open(self.temp_file, "wb") as f: + ts.dump(f) + with open(self.temp_file, "rb") as f: + ts2 = _tskit.TreeSequence() + ts2.load(f) assert ts.get_num_samples() == ts2.get_num_samples() assert ts.get_sequence_length() == ts2.get_sequence_length() assert ts.get_num_mutations() == ts2.get_num_mutations() diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 8885f9d2e9..2273ccaf40 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -31,6 +31,7 @@ import functools import itertools import math +import os import textwrap import warnings from typing import Any @@ -2338,22 +2339,50 @@ def kc_distance(self, other, lambda_=0.0): return self._ll_tree.get_kc_distance(other._ll_tree, lambda_) -def load(path): +def load(file): """ - Loads a tree sequence from the specified file path. This file must be in the + Loads a tree sequence from the specified file object or path. The file must be in the :ref:`tree sequence file format ` produced by the :meth:`TreeSequence.dump` method. - :param str path: The file path of the ``.trees`` file containing the + :param str file: The file object or path of the ``.trees`` file containing the tree sequence we wish to load. :return: The tree sequence object containing the information stored in the specified file path. :rtype: :class:`tskit.TreeSequence` """ + # Get ourselves a local version of the file. The semantics here are complex + # because need to support a range of inputs and the free behaviour is + # slightly different on each. + _file = None + local_file = True try: - return TreeSequence.load(path) + # First, see if we can interpret the argument as a pathlike object. + path = os.fspath(file) + _file = open(path, "rb") + except TypeError: + pass + if _file is None: + # Now we try to open file. If it's not a pathlike object, it could be + # an integer fd or object with a fileno method. 
In this case we + # must make sure that close is **not** called on the fd. + try: + _file = open(file, "rb", closefd=False, buffering=0) + except TypeError: + pass + if _file is None: + # Assume that this is a file **but** we haven't opened it, so we must + # not close it. + _file = file + local_file = False + try: + return TreeSequence.load(_file) except exceptions.FileFormatError as e: + # TODO Fix this for new file semantics formats.raise_hdf5_format_error(path, e) + finally: + if local_file: + _file.close() def parse_individuals( @@ -2963,9 +2992,9 @@ def aslist(self, **kwargs): return [tree.copy() for tree in self.trees(**kwargs)] @classmethod - def load(cls, path): + def load(cls, file): ts = _tskit.TreeSequence() - ts.load(str(path)) + ts.load(file) return TreeSequence(ts) @classmethod @@ -2974,20 +3003,43 @@ def load_tables(cls, tables): ts.load_tables(tables._ll_tables) return TreeSequence(ts) - def dump(self, path, zlib_compression=False): + def dump(self, file): """ - Writes the tree sequence to the specified file path. + Writes the tree sequence to the specified file object. - :param str path: The file path to write the TreeSequence to. - :param bool zlib_compression: This parameter is deprecated and ignored. + :param str file: The file object or path to write the TreeSequence to. """ - if zlib_compression: - warnings.warn( - "The zlib_compression option is no longer supported and is ignored", - RuntimeWarning, - ) - # Convert the path to str to allow us use Pathlib inputs - self._ll_tree_sequence.dump(str(path)) + # Get ourselves a local version of the file. The semantics here are complex + # because need to support a range of inputs and the free behaviour is + # slightly different on each. + _file = None + local_file = True + try: + # First, see if we can interpret the argument as a pathlike object. + path = os.fspath(file) + _file = open(path, "wb") + except TypeError: + pass + if _file is None: + # Now we try to open file. 
If it's not a pathlike object, it could be + # an integer fd or object with a fileno method. In this case we + # must make sure that close is **not** called on the fd. + try: + _file = open(file, "wb", closefd=False) + except TypeError: + pass + if _file is None: + # Assume that this is a file **but** we haven't opened it, so we must + # not close it. + if not hasattr(file, "write"): + raise TypeError("file object must have a write method") + _file = file + local_file = False + try: + self._ll_tree_sequence.dump(_file) + finally: + if local_file: + _file.close() @property def tables_dict(self):