diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index df096aa219..2e77fd653c 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -714,7 +714,7 @@ parse_allele_list(PyObject *allele_tuple) /* * Retrieves the PyObject* corresponding the specified key in the * specified dictionary. If required is true, raise a TypeError if the - * value is None. + * value is None or absent. * * NB This returns a *borrowed reference*, so don't DECREF it! */ @@ -725,7 +725,7 @@ get_table_dict_value(PyObject *dict, const char *key_str, bool required) ret = PyDict_GetItemString(dict, key_str); if (ret == NULL) { - PyErr_Format(PyExc_ValueError, "'%s' not specified", key_str); + ret = Py_None; } if (required && ret == Py_None) { PyErr_Format(PyExc_TypeError, "'%s' is required", key_str); @@ -817,6 +817,24 @@ table_read_offset_array(PyObject *input, size_t *num_rows, size_t length, bool c return ret; } +static const char * +parse_metadata_schema_arg(PyObject *arg, Py_ssize_t* metadata_schema_length) +{ + const char *ret = NULL; + if (arg == NULL) { + PyErr_Format( + PyExc_AttributeError, + "Cannot del metadata_schema, set to empty string (\"\") to clear."); + goto out; + } + ret = PyUnicode_AsUTF8AndSize(arg, metadata_schema_length); + if (ret == NULL) { + goto out; + } +out: + return ret; +} + static int parse_individual_table_dict(tsk_individual_table_t *table, PyObject *dict, bool clear_table) { @@ -837,6 +855,9 @@ parse_individual_table_dict(tsk_individual_table_t *table, PyObject *dict, bool PyArrayObject *metadata_array = NULL; PyObject *metadata_offset_input = NULL; PyArrayObject *metadata_offset_array = NULL; + PyObject *metadata_schema_input = NULL; + const char *metadata_schema = NULL; + Py_ssize_t metadata_schema_length = 0; /* Get the input values */ flags_input = get_table_dict_value(dict, "flags", true); @@ -859,6 +880,10 @@ parse_individual_table_dict(tsk_individual_table_t *table, PyObject *dict, bool if (metadata_offset_input == NULL) { goto out; } + metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + if (metadata_schema_input == NULL) { + goto out; + } /* Pull out the arrays */ flags_array = table_read_column_array(flags_input, NPY_UINT32, &num_rows, false); @@ -904,6 +929,20 @@ parse_individual_table_dict(tsk_individual_table_t *table, PyObject *dict, bool metadata_offset_data = PyArray_DATA(metadata_offset_array); } + if (metadata_schema_input != Py_None) { + metadata_schema = parse_metadata_schema_arg( + metadata_schema_input, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_individual_table_set_metadata_schema( + table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + if (clear_table) { err = tsk_individual_table_clear(table); if (err != 0) { @@ -951,6 +990,9 @@ parse_node_table_dict(tsk_node_table_t *table, PyObject *dict, bool clear_table) PyArrayObject *metadata_array = NULL; PyObject *metadata_offset_input = NULL; PyArrayObject *metadata_offset_array = NULL; + PyObject *metadata_schema_input = NULL; + const char *metadata_schema = NULL; + Py_ssize_t metadata_schema_length = 0; /* Get the input values */ flags_input = get_table_dict_value(dict, "flags", true); @@ -977,6 +1019,10 @@ parse_node_table_dict(tsk_node_table_t *table, PyObject *dict, bool clear_table) if (metadata_offset_input == NULL) { goto out; } + metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + if (metadata_schema_input == NULL) { + goto out; + } /* Create the arrays */ flags_array = table_read_column_array(flags_input, NPY_UINT32, &num_rows, false); @@ -1022,6 +1068,20 @@ parse_node_table_dict(tsk_node_table_t *table, PyObject *dict, bool clear_table) } metadata_offset_data = PyArray_DATA(metadata_offset_array); } + if (metadata_schema_input != Py_None) { + metadata_schema = parse_metadata_schema_arg( + metadata_schema_input, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_node_table_set_metadata_schema( + table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + if (clear_table) { err = tsk_node_table_clear(table); if (err != 0) { @@ -1068,6 +1128,9 @@ parse_edge_table_dict(tsk_edge_table_t *table, PyObject *dict, bool clear_table) PyArrayObject *metadata_array = NULL; PyObject *metadata_offset_input = NULL; PyArrayObject *metadata_offset_array = NULL; + PyObject *metadata_schema_input = NULL; + const char *metadata_schema = NULL; + Py_ssize_t metadata_schema_length = 0; /* Get the input values */ left_input = get_table_dict_value(dict, "left", true); @@ -1094,6 +1157,10 @@ parse_edge_table_dict(tsk_edge_table_t *table, PyObject *dict, bool clear_table) if (metadata_offset_input == NULL) { goto out; } + metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + if (metadata_schema_input == NULL) { + goto out; + } /* Create the arrays */ @@ -1132,7 +1199,19 @@ parse_edge_table_dict(tsk_edge_table_t *table, PyObject *dict, bool clear_table) } metadata_offset_data = PyArray_DATA(metadata_offset_array); } - + if (metadata_schema_input != Py_None) { + metadata_schema = parse_metadata_schema_arg( + metadata_schema_input, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_edge_table_set_metadata_schema( + table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + } if (clear_table) { err = tsk_edge_table_clear(table); @@ -1185,6 +1264,9 @@ parse_migration_table_dict(tsk_migration_table_t *table, PyObject *dict, bool cl PyArrayObject *metadata_array = NULL; PyObject *metadata_offset_input = NULL; PyArrayObject *metadata_offset_array = NULL; + PyObject *metadata_schema_input = NULL; + const char *metadata_schema = NULL; + Py_ssize_t metadata_schema_length = 0; /* Get the input values */ left_input = get_table_dict_value(dict, "left", true); @@ -1219,6 +1301,10 @@ parse_migration_table_dict(tsk_migration_table_t *table, PyObject *dict, bool cl if (metadata_offset_input == NULL) { goto out; } + metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + if (metadata_schema_input == NULL) { + goto out; + } /* Build the arrays */ left_array = table_read_column_array(left_input, NPY_FLOAT64, &num_rows, false); @@ -1264,6 +1350,20 @@ parse_migration_table_dict(tsk_migration_table_t *table, PyObject *dict, bool cl } metadata_offset_data = PyArray_DATA(metadata_offset_array); } + if (metadata_schema_input != Py_None) { + metadata_schema = parse_metadata_schema_arg( + metadata_schema_input, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_migration_table_set_metadata_schema( + table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + if (clear_table) { err = tsk_migration_table_clear(table); if (err != 0) { @@ -1311,6 +1411,10 @@ parse_site_table_dict(tsk_site_table_t *table, PyObject *dict, bool clear_table) PyArrayObject *metadata_offset_array = NULL; char *metadata_data; uint32_t *metadata_offset_data; + PyObject *metadata_schema_input = NULL; + const char *metadata_schema = NULL; + Py_ssize_t metadata_schema_length = 0; + /* Get the input values */ position_input = get_table_dict_value(dict, "position", true); @@ -1333,6 +1437,11 @@ parse_site_table_dict(tsk_site_table_t *table, PyObject *dict, bool clear_table) if (metadata_offset_input == NULL) { goto out; } + metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + if (metadata_schema_input == NULL) { + goto out; + } + /* Get the arrays */ position_array = table_read_column_array(position_input, NPY_FLOAT64, &num_rows, false); @@ -1371,6 +1480,19 @@ parse_site_table_dict(tsk_site_table_t *table, PyObject *dict, bool clear_table) } metadata_offset_data = PyArray_DATA(metadata_offset_array); } + if (metadata_schema_input != Py_None) { + metadata_schema = parse_metadata_schema_arg( + metadata_schema_input, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_site_table_set_metadata_schema( + table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + } if (clear_table) { err = tsk_site_table_clear(table); @@ -1421,6 +1543,9 @@ parse_mutation_table_dict(tsk_mutation_table_t *table, PyObject *dict, bool clea PyArrayObject *metadata_offset_array = NULL; char *metadata_data; uint32_t *metadata_offset_data; + PyObject *metadata_schema_input = NULL; + const char *metadata_schema = NULL; + Py_ssize_t metadata_schema_length = 0; /* Get the input values */ site_input = get_table_dict_value(dict, "site", true); @@ -1451,6 +1576,10 @@ parse_mutation_table_dict(tsk_mutation_table_t *table, PyObject *dict, bool clea if (metadata_offset_input == NULL) { goto out; } + metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + if (metadata_schema_input == NULL) { + goto out; + } /* Get the arrays */ site_array = table_read_column_array(site_input, NPY_INT32, &num_rows, false); @@ -1502,6 +1631,19 @@ parse_mutation_table_dict(tsk_mutation_table_t *table, PyObject *dict, bool clea } metadata_offset_data = PyArray_DATA(metadata_offset_array); } + if (metadata_schema_input != Py_None) { + metadata_schema = parse_metadata_schema_arg( + metadata_schema_input, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_mutation_table_set_metadata_schema( + table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + } if (clear_table) { err = tsk_mutation_table_clear(table); @@ -1541,6 +1683,9 @@ parse_population_table_dict(tsk_population_table_t *table, PyObject *dict, bool PyArrayObject *metadata_array = NULL; PyObject *metadata_offset_input = NULL; PyArrayObject *metadata_offset_array = NULL; + PyObject *metadata_schema_input = NULL; + const char *metadata_schema = NULL; + Py_ssize_t metadata_schema_length = 0; /* Get the inputs */ metadata_input = get_table_dict_value(dict, "metadata", true); @@ -1551,6 +1696,10 @@ parse_population_table_dict(tsk_population_table_t *table, PyObject *dict, bool if (metadata_offset_input == NULL) { goto out; } + metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + if (metadata_schema_input == NULL) { + goto out; + } /* Get the arrays */ metadata_array = table_read_column_array(metadata_input, NPY_INT8, @@ -1563,6 +1712,19 @@ parse_population_table_dict(tsk_population_table_t *table, PyObject *dict, bool if (metadata_offset_array == NULL) { goto out; } + if (metadata_schema_input != Py_None) { + metadata_schema = parse_metadata_schema_arg( + metadata_schema_input, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_population_table_set_metadata_schema( + table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + } if (clear_table) { err = tsk_population_table_clear(table); @@ -1666,6 +1828,10 @@ parse_table_collection_dict(tsk_table_collection_t *tables, PyObject *tables_dic { int ret = -1; PyObject *value = NULL; + int err; + char *metadata = NULL; + const char *metadata_schema = NULL; + Py_ssize_t metadata_length, metadata_schema_length; value = get_table_dict_value(tables_dict, "sequence_length", true); if (value == NULL) { @@ -1677,6 +1843,50 @@ parse_table_collection_dict(tsk_table_collection_t *tables, PyObject *tables_dic } tables->sequence_length = PyFloat_AsDouble(value); + /* metadata_schema */ + value = get_table_dict_value(tables_dict, "metadata_schema", false); + if (value == NULL) { + goto out; + } + if (value != Py_None) { + if (!PyUnicode_Check(value)) { + PyErr_Format(PyExc_TypeError, "'metadata_schema' is not a string"); + goto out; + } + metadata_schema = parse_metadata_schema_arg(value, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_table_collection_set_metadata_schema( + tables, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + + /* metadata */ + value = get_table_dict_value(tables_dict, "metadata", false); + if (value == NULL) { + goto out; + } + if (value != Py_None) { + if (!PyBytes_Check(value)) { + PyErr_Format(PyExc_TypeError, "'metadata' is not bytes"); + goto out; + } + err = PyBytes_AsStringAndSize(value, &metadata, &metadata_length); + if (err != 0) { + goto out; + } + err = tsk_table_collection_set_metadata( + tables, metadata, metadata_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + } + /* individuals */ value = get_table_dict_value(tables_dict, "individuals", true); if (value == NULL) { @@ -1786,24 +1996,6 @@ parse_table_collection_dict(tsk_table_collection_t *tables, PyObject *tables_dic return ret; } -static const char * -parse_metadata_schema_arg(PyObject *arg, Py_ssize_t* metadata_schema_length) -{ - const char *ret = NULL; - if (arg == NULL) { - PyErr_Format( - PyExc_AttributeError, - "Cannot del metadata_schema, set to empty string (\"\") to clear."); - goto out; - } - ret = PyUnicode_AsUTF8AndSize(arg, metadata_schema_length); - if (ret == NULL) { - goto out; - } -out: - return ret; -} - static int write_table_arrays(tsk_table_collection_t *tables, PyObject *dict) { @@ -1816,6 +2008,8 @@ write_table_arrays(tsk_table_collection_t *tables, PyObject *dict) struct table_desc { const char *name; struct table_col *cols; + char *metadata_schema; + tsk_size_t metadata_schema_length; }; int ret = -1; PyObject *array = NULL; @@ -1949,14 +2143,21 @@ write_table_arrays(tsk_table_collection_t *tables, PyObject *dict) }; struct table_desc table_descs[] = { - {"individuals", individual_cols}, - {"nodes", node_cols}, - {"edges", edge_cols}, - {"migrations", migration_cols}, - {"sites", site_cols}, - {"mutations", mutation_cols}, - {"populations", population_cols}, - {"provenances", provenance_cols}, + {"individuals", individual_cols, + tables->individuals.metadata_schema, tables->individuals.metadata_schema_length}, + {"nodes", node_cols, + tables->nodes.metadata_schema, tables->nodes.metadata_schema_length}, + {"edges", edge_cols, + tables->edges.metadata_schema, tables->edges.metadata_schema_length}, + {"migrations", migration_cols, + tables->migrations.metadata_schema, tables->migrations.metadata_schema_length}, + {"sites", site_cols, + tables->sites.metadata_schema, tables->sites.metadata_schema_length}, + {"mutations", mutation_cols, + tables->mutations.metadata_schema, tables->mutations.metadata_schema_length}, + {"populations", population_cols, + tables->populations.metadata_schema, tables->populations.metadata_schema_length}, + {"provenances", provenance_cols, NULL, 0}, }; for (j = 0; j < sizeof(table_descs) / sizeof(*table_descs); j++) { @@ -1977,6 +2178,19 @@ write_table_arrays(tsk_table_collection_t *tables, PyObject *dict) array = NULL; col++; } + if (table_descs[j].metadata_schema_length > 0) { + array = make_Py_Unicode_FromStringAndLength(table_descs[j].metadata_schema, + table_descs[j].metadata_schema_length); + if (array == NULL) { + goto out; + } + if (PyDict_SetItemString(table_dict, "metadata_schema", array) != 0) { + goto out; + } + Py_DECREF(array); + array = NULL; + } + if (PyDict_SetItemString(dict, table_descs[j].name, table_dict) != 0) { goto out; } @@ -2003,6 +2217,18 @@ dump_tables_dict(tsk_table_collection_t *tables) if (dict == NULL) { goto out; } + + /* Dict representation version */ + val = Py_BuildValue("ll", 1, 1); + if (val == NULL) { + goto out; + } + if (PyDict_SetItemString(dict, "encoding_version", val) != 0) { + goto out; + } + Py_DECREF(val); + val = NULL; + val = Py_BuildValue("d", tables->sequence_length); if (val == NULL) { goto out; @@ -2013,6 +2239,31 @@ dump_tables_dict(tsk_table_collection_t *tables) Py_DECREF(val); val = NULL; + if (tables->metadata_schema_length > 0) { + val = make_Py_Unicode_FromStringAndLength( + tables->metadata_schema, tables->metadata_schema_length); + if (val == NULL) { + goto out; + } + if (PyDict_SetItemString(dict, "metadata_schema", val) != 0) { + goto out; + } + Py_DECREF(val); + val = NULL; + } + + if (tables->metadata_length > 0) { + val = PyBytes_FromStringAndSize(tables->metadata, tables->metadata_length); + if (val == NULL) { + goto out; + } + if (PyDict_SetItemString(dict, "metadata", val) != 0) { + goto out; + } + Py_DECREF(val); + val = NULL; + } + err = write_table_arrays(tables, dict); if (err != 0) { goto out; diff --git a/python/tests/data/dict-encodings/generate_msprime.py b/python/tests/data/dict-encodings/generate_msprime.py new file mode 100644 index 0000000000..ba2990ef5c --- /dev/null +++ b/python/tests/data/dict-encodings/generate_msprime.py @@ -0,0 +1,21 @@ +import pathlib +import pickle + +import _msprime +import msprime + +pop_configs = [msprime.PopulationConfiguration(5) for _ in range(2)] +migration_matrix = [[0, 1], [1, 0]] +ts = msprime.simulate( + population_configurations=pop_configs, + migration_matrix=migration_matrix, + mutation_rate=1, + record_migrations=True, + random_seed=1, +) +lwt = _msprime.LightweightTableCollection() +lwt.fromdict(ts.tables.asdict()) + +test_dir = pathlib.Path(__file__).parent +with open(test_dir / f"msprime-{msprime.__version__}.pkl", "wb") as f: + pickle.dump(lwt.asdict(), f) diff --git a/python/tests/data/dict-encodings/msprime-0.7.4.pkl b/python/tests/data/dict-encodings/msprime-0.7.4.pkl new file mode 100644 index 0000000000..1c8d6d2bde Binary files /dev/null and b/python/tests/data/dict-encodings/msprime-0.7.4.pkl differ diff --git a/python/tests/test_dict_encoding.py b/python/tests/test_dict_encoding.py index 3a6dfe4b0e..e1c29d7cfa 100644 --- a/python/tests/test_dict_encoding.py +++ b/python/tests/test_dict_encoding.py @@ -23,6 +23,8 @@ Test cases for the low-level dictionary encoding used to move data around in C. """ +import pathlib +import pickle import unittest import msprime @@ -105,9 +107,46 @@ def get_example_tables(): for j in range(10): tables.populations.add_row(metadata=b"p" * j) tables.provenances.add_row(timestamp="x" * j, record="y" * j) + tables.metadata_schema = tskit.MetadataSchema( + { + "codec": "struct", + "type": "object", + "properties": { + "top-level": { + "type": "array", + "items": {"type": "integer", "binaryFormat": "B"}, + "noLengthEncodingExhaustBuffer": True, + } + }, + } + ) + tables.metadata = {"top-level": [1, 2, 3, 4]} + for table in [ + "individuals", + "nodes", + "edges", + "migrations", + "sites", + "mutations", + "populations", + ]: + t = getattr(tables, table) + t.metadata_schema = tskit.MetadataSchema( + { + "codec": "struct", + "type": "object", + "properties": {table: {"type": "string", "binaryFormat": "50p"}}, + } + ) return tables +class TestEncodingVersion(unittest.TestCase): + def test_version(self): + lwt = c_module.LightweightTableCollection() + self.assertEqual(lwt.asdict()["encoding_version"], (1, 1)) + + class TestRoundTrip(unittest.TestCase): """ Tests if we can do a simple round trip on simulated data. @@ -154,7 +193,35 @@ def test_migration(self): self.verify(ts.tables) def test_example(self): - self.verify(get_example_tables()) + tables = get_example_tables() + tables.metadata_schema = tskit.MetadataSchema( + { + "codec": "struct", + "type": "object", + "properties": {"top-level": {"type": "string", "binaryFormat": "50p"}}, + } + ) + tables.metadata = {"top-level": "top-level-metadata"} + for table in [ + "individuals", + "nodes", + "edges", + "migrations", + "sites", + "mutations", + "populations", + ]: + t = getattr(tables, table) + t.packset_metadata([f"{table}-{i}".encode() for i in range(t.num_rows)]) + t.metadata_schema = tskit.MetadataSchema( + { + "codec": "struct", + "type": "object", + "properties": {table: {"type": "string", "binaryFormat": "50p"}}, + } + ) + + self.verify(tables) class TestMissingData(unittest.TestCase): @@ -167,35 +234,46 @@ def test_missing_sequence_length(self): d = tables.asdict() del d["sequence_length"] lwt = c_module.LightweightTableCollection() - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): lwt.fromdict(d) + def test_missing_metadata(self): + tables = get_example_tables() + self.assertNotEqual(tables.metadata, b"") + d = tables.asdict() + del d["metadata"] + lwt = c_module.LightweightTableCollection() + lwt.fromdict(d) + tables = tskit.TableCollection.fromdict(lwt.asdict()) + # Empty byte field still gets interpreted by schema + self.assertEqual(tables.metadata, {"top-level": []}) + + def test_missing_metadata_schema(self): + tables = get_example_tables() + self.assertNotEqual(str(tables.metadata_schema), "") + d = tables.asdict() + del d["metadata_schema"] + lwt = c_module.LightweightTableCollection() + lwt.fromdict(d) + tables = tskit.TableCollection.fromdict(lwt.asdict()) + self.assertEqual(str(tables.metadata_schema), "") + def test_missing_tables(self): tables = get_example_tables() d = tables.asdict() - table_names = set(d.keys()) - {"sequence_length"} + table_names = d.keys() - { + "sequence_length", + "metadata", + "metadata_schema", + "encoding_version", + } for table_name in table_names: d = tables.asdict() del d[table_name] lwt = c_module.LightweightTableCollection() - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): lwt.fromdict(d) - def test_missing_columns(self): - tables = get_example_tables() - d = tables.asdict() - table_names = set(d.keys()) - {"sequence_length"} - for table_name in table_names: - table_dict = d[table_name] - for colname in table_dict.keys(): - copy = dict(table_dict) - del copy[colname] - lwt = c_module.LightweightTableCollection() - d = tables.asdict() - d[table_name] = copy - with self.assertRaises(ValueError): - lwt.fromdict(d) - class TestBadTypes(unittest.TestCase): """ @@ -205,10 +283,15 @@ class TestBadTypes(unittest.TestCase): def verify_columns(self, value): tables = get_example_tables() d = tables.asdict() - table_names = set(d.keys()) - {"sequence_length"} + table_names = set(d.keys()) - { + "sequence_length", + "metadata", + "metadata_schema", + "encoding_version", + } for table_name in table_names: table_dict = d[table_name] - for colname in table_dict.keys(): + for colname in set(table_dict.keys()) - {"metadata_schema"}: copy = dict(table_dict) copy[colname] = value lwt = c_module.LightweightTableCollection() @@ -226,7 +309,7 @@ def test_str(self): def test_bad_top_level_types(self): tables = get_example_tables() d = tables.asdict() - for key in d.keys(): + for key in set(d.keys()) - {"encoding_version"}: bad_type_dict = tables.asdict() # A list should be a ValueError for both the tables and sequence_length bad_type_dict[key] = ["12345"] @@ -244,10 +327,15 @@ def verify(self, num_rows): tables = get_example_tables() d = tables.asdict() - table_names = set(d.keys()) - {"sequence_length"} + table_names = set(d.keys()) - { + "sequence_length", + "metadata", + "metadata_schema", + "encoding_version", + } for table_name in sorted(table_names): table_dict = d[table_name] - for colname in sorted(table_dict.keys()): + for colname in set(table_dict.keys()) - {"metadata_schema"}: copy = dict(table_dict) copy[colname] = table_dict[colname][:num_rows].copy() lwt = c_module.LightweightTableCollection() @@ -281,7 +369,7 @@ def verify_required_columns(self, tables, table_name, required_cols): for col in required_cols: self.assertTrue(np.array_equal(other[table_name][col], table_dict[col])) - # Removing any one of these required columns gives an error. + # Any one of these required columns as None gives an error. for col in required_cols: d = tables.asdict() copy = dict(table_dict) @@ -291,6 +379,16 @@ def verify_required_columns(self, tables, table_name, required_cols): with self.assertRaises(TypeError): lwt.fromdict(d) + # Removing any one of these required columns gives an error. + for col in required_cols: + d = tables.asdict() + copy = dict(table_dict) + del copy[col] + d[table_name] = copy + lwt = c_module.LightweightTableCollection() + with self.assertRaises(TypeError): + lwt.fromdict(d) + def verify_optional_column(self, tables, table_len, table_name, col_name): d = tables.asdict() table_dict = d[table_name] @@ -304,27 +402,52 @@ def verify_optional_column(self, tables, table_len, table_name, col_name): ) ) - def verify_offset_pair(self, tables, table_len, table_name, col_name): + def verify_offset_pair( + self, tables, table_len, table_name, col_name, required=False + ): offset_col = col_name + "_offset" + if not required: + d = tables.asdict() + table_dict = d[table_name] + table_dict[col_name] = None + table_dict[offset_col] = None + lwt = c_module.LightweightTableCollection() + lwt.fromdict(d) + out = lwt.asdict() + self.assertEqual(out[table_name][col_name].shape, (0,)) + self.assertTrue( + np.array_equal( + out[table_name][offset_col], + np.zeros(table_len + 1, dtype=np.uint32), + ) + ) + d = tables.asdict() + table_dict = d[table_name] + del table_dict[col_name] + del table_dict[offset_col] + lwt = c_module.LightweightTableCollection() + lwt.fromdict(d) + out = lwt.asdict() + self.assertEqual(out[table_name][col_name].shape, (0,)) + self.assertTrue( + np.array_equal( + out[table_name][offset_col], + np.zeros(table_len + 1, dtype=np.uint32), + ) + ) + + # Setting one or the other raises a TypeError d = tables.asdict() table_dict = d[table_name] table_dict[col_name] = None - table_dict[offset_col] = None lwt = c_module.LightweightTableCollection() - lwt.fromdict(d) - out = lwt.asdict() - self.assertEqual(out[table_name][col_name].shape, (0,)) - self.assertTrue( - np.array_equal( - out[table_name][offset_col], np.zeros(table_len + 1, dtype=np.uint32) - ) - ) + with self.assertRaises(TypeError): + lwt.fromdict(d) - # Setting one or the other raises a ValueError d = tables.asdict() table_dict = d[table_name] - table_dict[col_name] = None + del table_dict[col_name] lwt = c_module.LightweightTableCollection() with self.assertRaises(TypeError): lwt.fromdict(d) @@ -336,6 +459,13 @@ def verify_offset_pair(self, tables, table_len, table_name, col_name): with self.assertRaises(TypeError): lwt.fromdict(d) + d = tables.asdict() + table_dict = d[table_name] + del table_dict[offset_col] + lwt = c_module.LightweightTableCollection() + with self.assertRaises(TypeError): + lwt.fromdict(d) + d = tables.asdict() table_dict = d[table_name] bad_offset = np.zeros_like(table_dict[offset_col]) @@ -346,6 +476,16 @@ def verify_offset_pair(self, tables, table_len, table_name, col_name): with self.assertRaises(c_module.LibraryError): lwt.fromdict(d) + def verify_metadata_schema(self, tables, table_name): + d = tables.asdict() + d[table_name]["metadata_schema"] = None + lwt = c_module.LightweightTableCollection() + lwt.fromdict(d) + out = lwt.asdict() + self.assertNotIn("metadata_schema", out[table_name]) + tables = tskit.TableCollection.fromdict(out) + self.assertEqual(str(getattr(tables, table_name).metadata_schema), "") + def test_individuals(self): tables = get_example_tables() self.verify_required_columns(tables, "individuals", ["flags"]) @@ -355,6 +495,7 @@ def test_individuals(self): self.verify_offset_pair( tables, len(tables.individuals), "individuals", "metadata" ) + self.verify_metadata_schema(tables, "individuals") def test_nodes(self): tables = get_example_tables() @@ -362,18 +503,25 @@ def test_nodes(self): self.verify_optional_column(tables, len(tables.nodes), "nodes", "population") self.verify_optional_column(tables, len(tables.nodes), "nodes", "individual") self.verify_required_columns(tables, "nodes", ["flags", "time"]) + self.verify_metadata_schema(tables, "nodes") def test_edges(self): tables = get_example_tables() self.verify_required_columns( tables, "edges", ["left", "right", "parent", "child"] ) + self.verify_offset_pair(tables, len(tables.edges), "edges", "metadata") + self.verify_metadata_schema(tables, "edges") def test_migrations(self): tables = get_example_tables() self.verify_required_columns( tables, "migrations", ["left", "right", "node", "source", "dest", "time"] ) + self.verify_offset_pair( + tables, len(tables.migrations), "migrations", "metadata" + ) + self.verify_metadata_schema(tables, "migrations") def test_sites(self): tables = get_example_tables() @@ -381,6 +529,7 @@ def test_sites(self): tables, "sites", ["position", "ancestral_state", "ancestral_state_offset"] ) self.verify_offset_pair(tables, len(tables.sites), "sites", "metadata") + self.verify_metadata_schema(tables, "sites") def test_mutations(self): tables = get_example_tables() @@ -390,12 +539,15 @@ def test_mutations(self): ["site", "node", "derived_state", "derived_state_offset"], ) self.verify_offset_pair(tables, len(tables.mutations), "mutations", "metadata") + self.verify_metadata_schema(tables, "mutations") def test_populations(self): tables = get_example_tables() self.verify_required_columns( tables, "populations", ["metadata", "metadata_offset"] ) + self.verify_metadata_schema(tables, "populations") + self.verify_offset_pair(tables, len(tables.nodes), "nodes", "metadata", True) def test_provenances(self): tables = get_example_tables() @@ -404,3 +556,47 @@ def test_provenances(self): "provenances", ["record", "record_offset", "timestamp", "timestamp_offset"], ) + + def test_top_level_metadata(self): + tables = get_example_tables() + d = tables.asdict() + # None should give default value + d["metadata"] = None + lwt = c_module.LightweightTableCollection() + lwt.fromdict(d) + out = lwt.asdict() + self.assertNotIn("metadata", out) + tables = tskit.TableCollection.fromdict(out) + # We only removed the metadata, not the schema. So empty bytefield + # still gets interpreted + self.assertEqual(tables.metadata, {"top-level": []}) + # Missing is tested in TestMissingData above + + def test_top_level_metadata_schema(self): + tables = get_example_tables() + d = tables.asdict() + # None should give default value + d["metadata_schema"] = None + lwt = c_module.LightweightTableCollection() + lwt.fromdict(d) + out = lwt.asdict() + self.assertNotIn("metadata_schema", out) + tables = tskit.TableCollection.fromdict(out) + self.assertEqual(str(tables.metadata_schema), "") + # Missing is tested in TestMissingData above + + +class TestExamples(unittest.TestCase): + def test_pickled_examples(self): + seen_msprime = False + test_dir = pathlib.Path(__file__).parent / "data/dict-encodings" + for filename in test_dir.glob("*.pkl"): + if "msprime" in str(filename): + seen_msprime = True + with open(test_dir / filename, "rb") as f: + d = pickle.load(f) + lwt = c_module.LightweightTableCollection() + lwt.fromdict(d) + tskit.TableCollection.fromdict(d) + # Check we've done something + self.assertTrue(seen_msprime) diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 8a26eef2af..fb6280f822 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -1845,6 +1845,34 @@ class TestTableCollection(unittest.TestCase): Tests for the convenience wrapper around a collection of related tables. """ + def add_metadata(self, tc): + tc.metadata_schema = tskit.MetadataSchema( + { + "codec": "struct", + "type": "object", + "properties": {"top-level": {"type": "string", "binaryFormat": "50p"}}, + } + ) + tc.metadata = {"top-level": "top-level-metadata"} + for table in [ + "individuals", + "nodes", + "edges", + "migrations", + "sites", + "mutations", + "populations", + ]: + t = getattr(tc, table) + t.packset_metadata([f"{table}-{i}".encode() for i in range(t.num_rows)]) + t.metadata_schema = tskit.MetadataSchema( + { + "codec": "struct", + "type": "object", + "properties": {table: {"type": "string", "binaryFormat": "50p"}}, + } + ) + def test_table_references(self): ts = msprime.simulate(10, mutation_rate=2, random_seed=1) tables = ts.tables @@ -1883,8 +1911,12 @@ def test_str(self): def test_asdict(self): ts = msprime.simulate(10, mutation_rate=1, random_seed=1) t = ts.tables + self.add_metadata(t) d1 = { + "encoding_version": (1, 1), "sequence_length": t.sequence_length, + "metadata_schema": str(t.metadata_schema), + "metadata": t.metadata_schema.encode_row(t.metadata), "individuals": t.individuals.asdict(), "populations": t.populations.asdict(), "nodes": t.nodes.asdict(), @@ -1896,12 +1928,19 @@ def test_asdict(self): } d2 = t.asdict() self.assertEqual(set(d1.keys()), set(d2.keys())) + t1 = tskit.TableCollection.fromdict(d1) + t2 = tskit.TableCollection.fromdict(d2) + self.assertEqual(t1, t2) def test_from_dict(self): ts = msprime.simulate(10, mutation_rate=1, random_seed=1) t1 = ts.tables + self.add_metadata(t1) d = { + "encoding_version": (1, 1), "sequence_length": t1.sequence_length, + "metadata_schema": str(t1.metadata_schema), + "metadata": t1.metadata_schema.encode_row(t1.metadata), "individuals": t1.individuals.asdict(), "populations": t1.populations.asdict(), "nodes": t1.nodes.asdict(), @@ -1914,6 +1953,16 @@ def test_from_dict(self): t2 = tskit.TableCollection.fromdict(d) self.assertEquals(t1, t2) + def test_roundtrip_dict(self): + ts = msprime.simulate(10, mutation_rate=1, random_seed=1) + t1 = ts.tables + t2 = tskit.TableCollection.fromdict(t1.asdict()) + self.assertEqual(t1, t2) + + self.add_metadata(t1) + t2 = tskit.TableCollection.fromdict(t1.asdict()) + self.assertEqual(t1, t2) + def test_iter(self): def test_iter(table_collection): table_names = [ @@ -2207,12 +2256,13 @@ def test_bad_metadata(self): self.assertEqual(tc.ll_tables.metadata, b"") -class TestTableCollectionPickle(unittest.TestCase): +class TestTableCollectionPickle(TestTableCollection): """ Tests that we can round-trip table collections through pickle. """ def verify(self, tables): + self.add_metadata(tables) other_tables = pickle.loads(pickle.dumps(tables)) self.assertEqual(tables, other_tables) diff --git a/python/tskit/tables.py b/python/tskit/tables.py index b97e01ac33..ece9d88478 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -255,7 +255,13 @@ def asdict(self): Returns a dictionary mapping the names of the columns in this table to the corresponding numpy arrays. """ - return {col: getattr(self, col) for col in self.column_names} + ret = {col: getattr(self, col) for col in self.column_names} + # Not all tables have metadata + try: + ret["metadata_schema"] = str(self.metadata_schema) + except AttributeError: + pass + return ret def set_columns(self, **kwargs): """ @@ -432,6 +438,7 @@ def set_columns( location_offset=None, metadata=None, metadata_offset=None, + metadata_schema=None, ): """ Sets the values for each column in this :class:`IndividualTable` using the @@ -461,6 +468,7 @@ def set_columns( :type metadata: numpy.ndarray, dtype=np.int8 :param metadata_offset: The offsets into the ``metadata`` array. :type metadata_offset: numpy.ndarray, dtype=np.uint32. + :param metadata_schema: The encoded metadata schema. """ self._check_required_args(flags=flags) self.ll_table.set_columns( @@ -470,6 +478,7 @@ def set_columns( location_offset=location_offset, metadata=metadata, metadata_offset=metadata_offset, + metadata_schema=metadata_schema, ) ) @@ -627,6 +636,7 @@ def set_columns( individual=None, metadata=None, metadata_offset=None, + metadata_schema=None, ): """ Sets the values for each column in this :class:`NodeTable` using the values in @@ -655,6 +665,7 @@ def set_columns( :type metadata: numpy.ndarray, dtype=np.int8 :param metadata_offset: The offsets into the ``metadata`` array. :type metadata_offset: numpy.ndarray, dtype=np.uint32. + :param metadata_schema: The encoded metadata schema. """ self._check_required_args(flags=flags, time=time) self.ll_table.set_columns( @@ -665,6 +676,7 @@ def set_columns( individual=individual, metadata=metadata, metadata_offset=metadata_offset, + metadata_schema=metadata_schema, ) ) @@ -714,6 +726,7 @@ def append_columns( individual=individual, metadata=metadata, metadata_offset=metadata_offset, + metadata_schema=None, ) ) @@ -807,6 +820,7 @@ def set_columns( child=None, metadata=None, metadata_offset=None, + metadata_schema=None, ): """ Sets the values for each column in this :class:`EdgeTable` using the values @@ -835,7 +849,7 @@ def set_columns( :type metadata: numpy.ndarray, dtype=np.int8 :param metadata_offset: The offsets into the ``metadata`` array. :type metadata_offset: numpy.ndarray, dtype=np.uint32. - + :param metadata_schema: The encoded metadata schema. """ self._check_required_args(left=left, right=right, parent=parent, child=child) self.ll_table.set_columns( @@ -846,6 +860,7 @@ def set_columns( child=child, metadata=metadata, metadata_offset=metadata_offset, + metadata_schema=metadata_schema, ) ) @@ -1010,6 +1025,7 @@ def set_columns( time=None, metadata=None, metadata_offset=None, + metadata_schema=None, ): """ Sets the values for each column in this :class:`MigrationTable` using the values @@ -1041,6 +1057,7 @@ def set_columns( :type metadata: numpy.ndarray, dtype=np.int8 :param metadata_offset: The offsets into the ``metadata`` array. :type metadata_offset: numpy.ndarray, dtype=np.uint32. + :param metadata_schema: The encoded metadata schema. """ self._check_required_args( left=left, right=right, node=node, source=source, dest=dest, time=time @@ -1055,6 +1072,7 @@ def set_columns( time=time, metadata=metadata, metadata_offset=metadata_offset, + metadata_schema=metadata_schema, ) ) @@ -1200,6 +1218,7 @@ def set_columns( ancestral_state_offset=None, metadata=None, metadata_offset=None, + metadata_schema=None, ): """ Sets the values for each column in this :class:`SiteTable` using the values @@ -1230,6 +1249,7 @@ def set_columns( :type metadata: numpy.ndarray, dtype=np.int8 :param metadata_offset: The offsets into the ``metadata`` array. :type metadata_offset: numpy.ndarray, dtype=np.uint32. + :param metadata_schema: The encoded metadata schema. """ self._check_required_args( position=position, @@ -1243,6 +1263,7 @@ def set_columns( ancestral_state_offset=ancestral_state_offset, metadata=metadata, metadata_offset=metadata_offset, + metadata_schema=metadata_schema, ) ) @@ -1410,6 +1431,7 @@ def set_columns( parent=None, metadata=None, metadata_offset=None, + metadata_schema=None, ): """ Sets the values for each column in this :class:`MutationTable` using the values @@ -1445,6 +1467,7 @@ def set_columns( :type metadata: numpy.ndarray, dtype=np.int8 :param metadata_offset: The offsets into the ``metadata`` array. :type metadata_offset: numpy.ndarray, dtype=np.uint32. + :param metadata_schema: The encoded metadata schema. """ self._check_required_args( site=site, @@ -1461,6 +1484,7 @@ def set_columns( derived_state_offset=derived_state_offset, metadata=metadata, metadata_offset=metadata_offset, + metadata_schema=metadata_schema, ) ) @@ -1593,7 +1617,7 @@ def _text_header_and_rows(self): rows.append((str(j), str(md))) return headers, rows - def set_columns(self, metadata=None, metadata_offset=None): + def set_columns(self, metadata=None, metadata_offset=None, metadata_schema=None): """ Sets the values for each column in this :class:`PopulationTable` using the values in the specified arrays. Overwrites any data currently stored in the @@ -1611,9 +1635,14 @@ def set_columns(self, metadata=None, metadata_offset=None): :type metadata: numpy.ndarray, dtype=np.int8 :param metadata_offset: The offsets into the ``metadata`` array. :type metadata_offset: numpy.ndarray, dtype=np.uint32. + :param metadata_schema: The encoded metadata schema. """ self.ll_table.set_columns( - dict(metadata=metadata, metadata_offset=metadata_offset) + dict( + metadata=metadata, + metadata_offset=metadata_offset, + metadata_schema=metadata_schema, + ) ) def append_columns(self, metadata=None, metadata_offset=None): @@ -1916,7 +1945,10 @@ def asdict(self): map of table names to the tables themselves was returned. """ return { + "encoding_version": (1, 1), "sequence_length": self.sequence_length, + "metadata_schema": str(self.metadata_schema), + "metadata": self.metadata_schema.encode_row(self.metadata), "individuals": self.individuals.asdict(), "nodes": self.nodes.asdict(), "edges": self.edges.asdict(), @@ -1983,6 +2015,8 @@ def __getstate__(self): # Unpickle support def __setstate__(self, state): self.__init__(state["sequence_length"]) + self.metadata_schema = tskit.parse_metadata_schema(state["metadata_schema"]) + self.metadata = self.metadata_schema.decode_row(state["metadata"]) self.individuals.set_columns(**state["individuals"]) self.nodes.set_columns(**state["nodes"]) self.edges.set_columns(**state["edges"]) @@ -1995,6 +2029,17 @@ def __setstate__(self, state): @classmethod def fromdict(self, tables_dict): tables = TableCollection(tables_dict["sequence_length"]) + try: + tables.metadata_schema = tskit.parse_metadata_schema( + tables_dict["metadata_schema"] + ) + except KeyError: + pass + try: + tables.metadata = tables.metadata_schema.decode_row(tables_dict["metadata"]) + except KeyError: + pass + tables.individuals.set_columns(**tables_dict["individuals"]) tables.nodes.set_columns(**tables_dict["nodes"]) tables.edges.set_columns(**tables_dict["edges"])