Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
- Added ``TreeSequence.__repr__`` to display a summary for terminal usage.
(:user:`benjeffery`, :issue:`938`, :pr:`985`)

- Added ``TableCollection.dump`` and ``TableCollection.load``. This allows table
collections that are not valid tree sequences to be manipulated.
(:user:`benjeffery`, :issue:`14`, :pr:`986`)

**Breaking changes**

- The argument to ``ts.dump`` and ``tskit.load`` has been renamed from ``path`` to ``file``.
Expand Down
108 changes: 106 additions & 2 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -4495,6 +4495,26 @@ TableCollection_check_state(TableCollection *self)
return ret;
}

/*
 * Gives self a freshly allocated, zero-filled tsk_table_collection_t.
 *
 * Any table collection already held by self is freed first. Returns 0 on
 * success. On allocation failure a Python MemoryError is set, self->tables
 * is left NULL and -1 is returned.
 */
static int
TableCollection_alloc(TableCollection *self)
{
    /* Release whatever this object currently owns before replacing it. */
    if (self->tables != NULL) {
        tsk_table_collection_free(self->tables);
        PyMem_Free(self->tables);
    }

    self->tables = PyMem_Malloc(sizeof(*self->tables));
    if (self->tables == NULL) {
        PyErr_NoMemory();
        return -1;
    }

    /* Start from an all-zero struct; it is populated by the caller. */
    memset(self->tables, 0, sizeof(*self->tables));
    return 0;
}

static void
TableCollection_dealloc(TableCollection *self)
{
Expand Down Expand Up @@ -5449,6 +5469,82 @@ TableCollection_equals(TableCollection *self, PyObject *args, PyObject *kwds)
return ret;
}

/*
 * TableCollection.dump(file): writes this table collection to ``file``.
 *
 * Takes a single argument, positional or keyword ("file"), which is handed
 * to make_file() to obtain a C stream opened in "wb" mode.
 *
 * Returns None on success. Returns NULL when the state check, argument
 * parsing or make_file() fails (those helpers are expected to have set the
 * Python exception), when the library reports an error, or when closing the
 * stream fails after an otherwise successful write (OSError from errno).
 */
static PyObject *
TableCollection_dump(TableCollection *self, PyObject *args, PyObject *kwds)
{
    int err;
    FILE *file = NULL;
    PyObject *py_file = NULL;
    PyObject *ret = NULL;
    static char *kwlist[] = { "file", NULL };

    if (TableCollection_check_state(self) != 0) {
        goto out;
    }
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &py_file)) {
        goto out;
    }

    file = make_file(py_file, "wb");
    if (file == NULL) {
        goto out;
    }

    err = tsk_table_collection_dumpf(self->tables, file, 0);
    if (err != 0) {
        handle_library_error(err);
        goto out;
    }
    ret = Py_BuildValue("");
out:
    if (file != NULL) {
        /* fclose() flushes any buffered output, so a failure here means the
         * data may not have reached the file. Report it, but only when no
         * earlier error is already pending (ret == NULL in that case). */
        if (fclose(file) != 0 && ret != NULL) {
            PyErr_SetFromErrno(PyExc_OSError);
            Py_DECREF(ret);
            ret = NULL;
        }
    }
    return ret;
}

/*
 * TableCollection.load(file): reads a table collection from ``file`` into
 * this object, replacing any tables it currently holds.
 *
 * Takes a single argument, positional or keyword ("file"), which is handed
 * to make_file() to obtain a C stream opened in "rb" mode.
 *
 * Returns None on success, or NULL on failure. OSError is raised here if the
 * stream cannot be switched to unbuffered mode; other failures rely on the
 * called helpers to set the Python exception.
 *
 * NOTE(review): TableCollection_alloc() discards the existing tables before
 * the read is attempted, so after a failed load self->tables no longer
 * holds the previous contents -- confirm this is the intended contract.
 */
static PyObject *
TableCollection_load(TableCollection *self, PyObject *args, PyObject *kwds)
{
    static char *kwlist[] = { "file", NULL };
    PyObject *py_file = NULL;
    PyObject *result = NULL;
    FILE *stream = NULL;
    int status;

    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &py_file)) {
        goto out;
    }
    stream = make_file(py_file, "rb");
    if (stream == NULL) {
        goto out;
    }
    /* Read unbuffered so that no bytes beyond those requested are consumed.
     * With buffering, stdio could read past the end of the current store in
     * a multi-store file or stream; that extra data would be thrown away by
     * the fclose() below and loading the next store would then fail. */
    if (setvbuf(stream, NULL, _IONBF, 0) != 0) {
        PyErr_SetFromErrno(PyExc_OSError);
        goto out;
    }
    /* Replace the current tables with a fresh, zeroed collection. */
    if (TableCollection_alloc(self) != 0) {
        goto out;
    }
    status = tsk_table_collection_loadf(self->tables, stream, 0);
    if (status != 0) {
        handle_library_error(status);
        goto out;
    }
    result = Py_BuildValue("");
out:
    if (stream != NULL) {
        (void) fclose(stream);
    }
    return result;
}

static PyGetSetDef TableCollection_getsetters[] = {
{ .name = "individuals",
.get = (getter) TableCollection_get_individuals,
Expand Down Expand Up @@ -5553,6 +5649,14 @@ static PyMethodDef TableCollection_methods[] = {
.ml_meth = (PyCFunction) TableCollection_has_index,
.ml_flags = METH_NOARGS,
.ml_doc = "Returns True if the TableCollection is indexed." },
{ .ml_name = "dump",
.ml_meth = (PyCFunction) TableCollection_dump,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Writes the table collection out to the specified file." },
{ .ml_name = "load",
.ml_meth = (PyCFunction) TableCollection_load,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Loads the table collection out to the specified file." },
{ NULL } /* Sentinel */
};

Expand Down Expand Up @@ -7340,11 +7444,11 @@ static PyMethodDef TreeSequence_methods[] = {
{ .ml_name = "dump",
.ml_meth = (PyCFunction) TreeSequence_dump,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Writes the tree sequence out to the specified path." },
.ml_doc = "Writes the tree sequence out to the specified file." },
{ .ml_name = "load",
.ml_meth = (PyCFunction) TreeSequence_load,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
.ml_doc = "Loads a tree sequence from the specified path." },
.ml_doc = "Loads a tree sequence from the specified file." },
{ .ml_name = "load_tables",
.ml_meth = (PyCFunction) TreeSequence_load_tables,
.ml_flags = METH_VARARGS | METH_KEYWORDS,
Expand Down
3 changes: 3 additions & 0 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,6 +1292,9 @@ def test_dump_load_errors(self):
except OSError as e:
message = str(e)
assert "File name too long" in message
for bad_filename in [[], None, {}]:
with pytest.raises(TypeError):
func(bad_filename)

def test_tables_sequence_length_round_trip(self):
for sequence_length in [0.1, 1, 10, 100]:
Expand Down
80 changes: 45 additions & 35 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,15 @@
import inspect
import itertools
import os
import platform
import random
import tempfile
import unittest

import msprime
import numpy as np
import pytest

import _tskit

IS_WINDOWS = platform.system() == "Windows"


def get_tracked_sample_counts(st, tracked_samples):
"""
Expand Down Expand Up @@ -70,7 +66,7 @@ def get_sample_counts(tree_sequence, st):
return nu


class LowLevelTestCase(unittest.TestCase):
class LowLevelTestCase:
"""
Superclass of tests for the low-level interface.
"""
Expand Down Expand Up @@ -166,6 +162,32 @@ class TestTableCollection(LowLevelTestCase):
Tests for the low-level TableCollection class
"""

def test_file_errors(self):
    """
    Both dump and load must raise TypeError when called with no argument,
    or with None, a list or a dict in place of a file.
    """
    source_tables = _tskit.TableCollection(1)
    self.get_example_tree_sequence().dump_tables(source_tables)

    def loader(*args):
        # Load into a throwaway collection so each call starts fresh.
        _tskit.TableCollection(1).load(*args)

    for func in (source_tables.dump, loader):
        with pytest.raises(TypeError):
            func()
        for bad_type in (None, [], {}):
            with pytest.raises(TypeError):
                func(bad_type)

def test_dump_equality(self, tmp_path):
    """
    Round-trips the tables of every example tree sequence through
    TableCollection.dump and TableCollection.load, and checks that the
    reloaded collection equals the one that was written.
    """
    trees_file = tmp_path / "tmp.trees"
    for ts in self.get_example_tree_sequences():
        written = _tskit.TableCollection(sequence_length=ts.get_sequence_length())
        ts.dump_tables(written)
        with open(trees_file, "wb") as out_stream:
            written.dump(out_stream)
        reloaded = _tskit.TableCollection()
        with open(trees_file, "rb") as in_stream:
            reloaded.load(in_stream)
        assert written.equals(reloaded)

def test_reference_deletion(self):
ts = msprime.simulate(10, mutation_rate=1, random_seed=1)
tc = ts.tables._ll_tables
Expand Down Expand Up @@ -518,7 +540,6 @@ def setUp(self):
def tearDown(self):
os.unlink(self.temp_file)

@pytest.mark.skipif(IS_WINDOWS, reason="File permissions on Windows")
def test_file_errors(self):
ts1 = self.get_example_tree_sequence()

Expand Down Expand Up @@ -565,38 +586,23 @@ def test_num_nodes(self):
max_node = node
assert max_node + 1 == ts.get_num_nodes()

def verify_dump_equality(self, ts):
"""
Verifies that we can dump a copy of the specified tree sequence
to the specified file, and load an identical copy.
"""
with open(self.temp_file, "wb") as f:
ts.dump(f)
with open(self.temp_file, "rb") as f:
ts2 = _tskit.TreeSequence()
ts2.load(f)
assert ts.get_num_samples() == ts2.get_num_samples()
assert ts.get_sequence_length() == ts2.get_sequence_length()
assert ts.get_num_mutations() == ts2.get_num_mutations()
assert ts.get_num_nodes() == ts2.get_num_nodes()
records1 = [ts.get_edge(j) for j in range(ts.get_num_edges())]
records2 = [ts2.get_edge(j) for j in range(ts2.get_num_edges())]
assert records1 == records2
mutations1 = [ts.get_mutation(j) for j in range(ts.get_num_mutations())]
mutations2 = [ts2.get_mutation(j) for j in range(ts2.get_num_mutations())]
assert mutations1 == mutations2
provenances1 = [ts.get_provenance(j) for j in range(ts.get_num_provenances())]
provenances2 = [ts2.get_provenance(j) for j in range(ts2.get_num_provenances())]
assert provenances1 == provenances2

def test_dump_equality(self):
def test_dump_equality(self, tmp_path):
for ts in self.get_example_tree_sequences():
tables = _tskit.TableCollection(sequence_length=ts.get_sequence_length())
ts.dump_tables(tables)
tables.compute_mutation_times()
ts = _tskit.TreeSequence()
ts.load_tables(tables)
self.verify_dump_equality(ts)
with open(tmp_path / "temp.trees", "wb") as f:
ts.dump(f)
with open(tmp_path / "temp.trees", "rb") as f:
ts2 = _tskit.TreeSequence()
ts2.load(f)
tc = _tskit.TableCollection(ts.get_sequence_length())
ts.dump_tables(tc)
tc2 = _tskit.TableCollection(ts2.get_sequence_length())
ts2.dump_tables(tc2)
assert tc.equals(tc2)

def verify_mutations(self, ts):
mutations = [ts.get_mutation(j) for j in range(ts.get_num_mutations())]
Expand Down Expand Up @@ -829,7 +835,7 @@ def test_kc_distance(self):
for lambda_ in [-1, 0, 1, 1000, -1e300]:
x1 = ts1.get_kc_distance(ts2, lambda_)
x2 = ts2.get_kc_distance(ts1, lambda_)
self.assertAlmostEqual(x1, x2)
assert x1 == x2

def test_load_tables_build_indexes(self):
for ts in self.get_example_tree_sequences():
Expand Down Expand Up @@ -2423,7 +2429,7 @@ def test_kc_distance(self):
for lambda_ in [-1, 0, 1, 1000, -1e300]:
x1 = t1.get_kc_distance(t2, lambda_)
x2 = t2.get_kc_distance(t1, lambda_)
self.assertAlmostEqual(x1, x2)
assert x1 == x2

def test_copy(self):
for ts in self.get_example_tree_sequences():
Expand Down Expand Up @@ -2561,7 +2567,11 @@ def test_tskit_version(self):

def test_uninitialised():
# These methods work from an instance that has a NULL ref so don't check
skip_list = ["TreeSequence_load", "TreeSequence_load_tables"]
skip_list = [
"TableCollection_load",
"TreeSequence_load",
"TreeSequence_load_tables",
]
for cls_name, cls in inspect.getmembers(_tskit):
if (
type(cls) == type
Expand Down
44 changes: 40 additions & 4 deletions python/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
import io
import json
import math
import pathlib
import pickle
import platform
import random
import time
import unittest
Expand Down Expand Up @@ -2604,8 +2606,8 @@ def test_indexes_roundtrip(self, simple_ts_fixture):
tables.drop_index()
assert not tskit.TableCollection.fromdict(tables.asdict()).has_index()

def test_asdict_lwt_concordence(self, ts_fixture):
def check_concordence(d1, d2):
def test_asdict_lwt_concordance(self, ts_fixture):
def check_concordance(d1, d2):
assert set(d1.keys()) == set(d2.keys())
for k1, v1 in d1.items():
v2 = d2[k1]
Expand All @@ -2632,12 +2634,46 @@ def check_concordence(d1, d2):
assert tables.has_index()
lwt = _tskit.LightweightTableCollection()
lwt.fromdict(tables.asdict())
check_concordence(lwt.asdict(), tables.asdict())
check_concordance(lwt.asdict(), tables.asdict())

tables.drop_index()
lwt = _tskit.LightweightTableCollection()
lwt.fromdict(tables.asdict())
check_concordence(lwt.asdict(), tables.asdict())
check_concordance(lwt.asdict(), tables.asdict())

def test_dump_pathlib(self, ts_fixture, tmp_path):
    """
    A TableCollection can be dumped to, and loaded from, a pathlib.Path.

    The existence checks run after the dump and call the Path methods;
    asserting on the bound methods themselves is always true and verifies
    nothing.
    """
    path = pathlib.Path(tmp_path) / "tmp.trees"
    tc = ts_fixture.dump_tables()
    tc.dump(path)
    assert path.exists()
    assert path.is_file()
    other_tc = tskit.TableCollection.load(path)
    assert tc == other_tc

@pytest.mark.skipif(platform.system() == "Windows", reason="Windows doesn't raise")
def test_dump_load_errors(self, ts_fixture):
    """
    dump and load must raise OSError with a non-empty message for paths that
    cannot be used, report over-long file names as such, and raise TypeError
    for arguments that are not file-like.

    The raised exception is inspected through the pytest.raises context so
    each failing call is made once and the message checked always belongs
    to that call.
    """
    tc = ts_fixture.dump_tables()
    # Try to dump/load files we don't have access to or don't exist.
    for func in [tc.dump, tskit.TableCollection.load]:
        for f in ["/", "/test.trees", "/dir_does_not_exist/x.trees"]:
            with pytest.raises(OSError) as excinfo:
                func(f)
            assert len(str(excinfo.value)) > 0
        f = "/" + 4000 * "x"
        with pytest.raises(OSError, match="File name too long"):
            func(f)
        for bad_filename in [[], None, {}]:
            with pytest.raises(TypeError):
                func(bad_filename)


class TestEqualityOptions:
Expand Down
Loading