diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 3ac71ee727..a662f89250 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -184,6 +184,25 @@ typedef struct { tsk_viterbi_matrix_t *viterbi_matrix; } ViterbiMatrix; +/* A named tuple of metadata schemas for a tree sequence */ +static PyTypeObject MetadataSchemas; +static PyStructSequence_Field metadata_schemas_fields[] = { + {"node", "The node metadata schema"}, + {"edge", "The edge metadata schema"}, + {"site", "The site metadata schema"}, + {"mutation", "The mutation metadata schema"}, + {"migration", "The migration metadata schema"}, + {"individual", "The individual metadata schema"}, + {"population", "The node metadata schema"}, + {NULL} +}; +static PyStructSequence_Desc metadata_schemas_desc = { + "MetadataSchemas", + "Namedtuple of metadata schemas for this tree sequence", + metadata_schemas_fields, + 7 +}; + static void handle_library_error(int err) { @@ -238,6 +257,19 @@ make_metadata(const char *metadata, Py_ssize_t length) return PyBytes_FromStringAndSize(m, length); } +static PyObject * +make_Py_Unicode_FromStringAndLength(const char * str, size_t length) { + PyObject * ret; + /* Py_BuildValue returns Py_None for zero length */ + if (length == 0) { + ret = PyUnicode_FromString(""); + } else { + ret = Py_BuildValue("s#", str, length); + } + return ret; +} + + static PyObject * make_mutation(tsk_mutation_t *mutation) { @@ -2454,6 +2486,49 @@ IndividualTable_get_metadata_offset(IndividualTable *self, void *closure) return ret; } +static PyObject * +IndividualTable_get_metadata_schema(IndividualTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (IndividualTable_check_state(self) != 0) { + goto out; + } + ret = make_Py_Unicode_FromStringAndLength(self->table->metadata_schema, self->table->metadata_schema_length); +out: + return ret; +} + +static int +IndividualTable_set_metadata_schema(IndividualTable *self, PyObject *arg, void *closure) +{ + int ret = -1; + int err; + const char *metadata_schema; + Py_ssize_t metadata_schema_length; + + if (arg == NULL) { + PyErr_Format(PyExc_ValueError, "Cannot del metadata_schema"); + goto out; + } + if (IndividualTable_check_state(self) != 0) { + goto out; + } + metadata_schema = PyUnicode_AsUTF8AndSize(arg, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_individual_table_set_metadata_schema( + self->table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + static PyGetSetDef IndividualTable_getsetters[] = { {"max_rows_increment", (getter) IndividualTable_get_max_rows_increment, NULL, "The size increment"}, @@ -2468,6 +2543,8 @@ static PyGetSetDef IndividualTable_getsetters[] = { {"metadata", (getter) IndividualTable_get_metadata, NULL, "The metadata array"}, {"metadata_offset", (getter) IndividualTable_get_metadata_offset, NULL, "The metadata offset array"}, + {"metadata_schema", (getter) IndividualTable_get_metadata_schema, + (setter) IndividualTable_set_metadata_schema, "The metadata schema"}, {NULL} /* Sentinel */ }; @@ -2882,6 +2959,49 @@ NodeTable_get_metadata_offset(NodeTable *self, void *closure) return ret; } +static PyObject * +NodeTable_get_metadata_schema(NodeTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (NodeTable_check_state(self) != 0) { + goto out; + } + ret = make_Py_Unicode_FromStringAndLength(self->table->metadata_schema, self->table->metadata_schema_length); +out: + return ret; +} + +static int 
+NodeTable_set_metadata_schema(NodeTable *self, PyObject *arg, void *closure) +{ + int ret = -1; + int err; + const char *metadata_schema; + Py_ssize_t metadata_schema_length; + + if (arg == NULL) { + PyErr_Format(PyExc_ValueError, "Cannot del metadata_schema"); + goto out; + } + if (NodeTable_check_state(self) != 0) { + goto out; + } + metadata_schema = PyUnicode_AsUTF8AndSize(arg, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_node_table_set_metadata_schema( + self->table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + static PyGetSetDef NodeTable_getsetters[] = { {"max_rows_increment", (getter) NodeTable_get_max_rows_increment, NULL, "The size increment"}, @@ -2896,6 +3016,8 @@ static PyGetSetDef NodeTable_getsetters[] = { {"metadata", (getter) NodeTable_get_metadata, NULL, "The metadata array"}, {"metadata_offset", (getter) NodeTable_get_metadata_offset, NULL, "The metadata offset array"}, + {"metadata_schema", (getter) NodeTable_get_metadata_schema, + (setter) NodeTable_set_metadata_schema, "The metadata schema"}, {NULL} /* Sentinel */ }; @@ -3327,6 +3449,49 @@ EdgeTable_get_metadata_offset(EdgeTable *self, void *closure) return ret; } +static PyObject * +EdgeTable_get_metadata_schema(EdgeTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (EdgeTable_check_state(self) != 0) { + goto out; + } + ret = make_Py_Unicode_FromStringAndLength(self->table->metadata_schema, self->table->metadata_schema_length); +out: + return ret; +} + +static int +EdgeTable_set_metadata_schema(EdgeTable *self, PyObject *arg, void *closure) +{ + int ret = -1; + int err; + const char *metadata_schema; + Py_ssize_t metadata_schema_length; + + if (arg == NULL) { + PyErr_Format(PyExc_ValueError, "Cannot del metadata_schema"); + goto out; + } + if (EdgeTable_check_state(self) != 0) { + goto out; + } + metadata_schema = PyUnicode_AsUTF8AndSize(arg, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_edge_table_set_metadata_schema( + self->table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + static PyGetSetDef EdgeTable_getsetters[] = { {"max_rows_increment", (getter) EdgeTable_get_max_rows_increment, NULL, @@ -3342,7 +3507,8 @@ static PyGetSetDef EdgeTable_getsetters[] = { {"metadata", (getter) EdgeTable_get_metadata, NULL, "The metadata array"}, {"metadata_offset", (getter) EdgeTable_get_metadata_offset, NULL, "The metadata offset array"}, - + {"metadata_schema", (getter) EdgeTable_get_metadata_schema, + (setter) EdgeTable_set_metadata_schema, "The metadata schema"}, {NULL} /* Sentinel */ }; @@ -3782,6 +3948,49 @@ MigrationTable_get_metadata_offset(MigrationTable *self, void *closure) return ret; } +static PyObject * +MigrationTable_get_metadata_schema(MigrationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MigrationTable_check_state(self) != 0) { + goto out; + } + ret = make_Py_Unicode_FromStringAndLength(self->table->metadata_schema, self->table->metadata_schema_length); +out: + return ret; +} + +static int +MigrationTable_set_metadata_schema(MigrationTable *self, PyObject *arg, void *closure) +{ + int ret = -1; + int err; + const char *metadata_schema; + Py_ssize_t metadata_schema_length; + + if (arg == NULL) { + PyErr_Format(PyExc_ValueError, "Cannot del metadata_schema"); + goto out; + } + if (MigrationTable_check_state(self) 
!= 0) { + goto out; + } + metadata_schema = PyUnicode_AsUTF8AndSize(arg, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_migration_table_set_metadata_schema( + self->table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + static PyGetSetDef MigrationTable_getsetters[] = { {"max_rows_increment", (getter) MigrationTable_get_max_rows_increment, NULL, "The size increment"}, @@ -3798,6 +4007,8 @@ static PyGetSetDef MigrationTable_getsetters[] = { {"metadata", (getter) MigrationTable_get_metadata, NULL, "The metadata array"}, {"metadata_offset", (getter) MigrationTable_get_metadata_offset, NULL, "The metadata offset array"}, + {"metadata_schema", (getter) MigrationTable_get_metadata_schema, + (setter) MigrationTable_set_metadata_schema, "The metadata schema"}, {NULL} /* Sentinel */ }; @@ -4198,6 +4409,49 @@ SiteTable_get_metadata_offset(SiteTable *self, void *closure) return ret; } +static PyObject * +SiteTable_get_metadata_schema(SiteTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (SiteTable_check_state(self) != 0) { + goto out; + } + ret = make_Py_Unicode_FromStringAndLength(self->table->metadata_schema, self->table->metadata_schema_length); +out: + return ret; +} + +static int +SiteTable_set_metadata_schema(SiteTable *self, PyObject *arg, void *closure) +{ + int ret = -1; + int err; + const char *metadata_schema; + Py_ssize_t metadata_schema_length; + + if (arg == NULL) { + PyErr_Format(PyExc_ValueError, "Cannot del metadata_schema"); + goto out; + } + if (SiteTable_check_state(self) != 0) { + goto out; + } + metadata_schema = PyUnicode_AsUTF8AndSize(arg, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_site_table_set_metadata_schema( + self->table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + static PyGetSetDef SiteTable_getsetters[] = { {"max_rows_increment", (getter) SiteTable_get_max_rows_increment, NULL, @@ -4218,6 +4472,8 @@ static PyGetSetDef SiteTable_getsetters[] = { "The metadata array."}, {"metadata_offset", (getter) SiteTable_get_metadata_offset, NULL, "The metadata offset array."}, + {"metadata_schema", (getter) SiteTable_get_metadata_schema, + (setter) SiteTable_set_metadata_schema, "The metadata schema"}, {NULL} /* Sentinel */ }; @@ -4653,6 +4909,49 @@ MutationTable_get_metadata_offset(MutationTable *self, void *closure) return ret; } +static PyObject * +MutationTable_get_metadata_schema(MutationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (MutationTable_check_state(self) != 0) { + goto out; + } + ret = make_Py_Unicode_FromStringAndLength(self->table->metadata_schema, self->table->metadata_schema_length); +out: + return ret; +} + +static int +MutationTable_set_metadata_schema(MutationTable *self, PyObject *arg, void *closure) +{ + int ret = -1; + int err; + const char *metadata_schema; + Py_ssize_t metadata_schema_length; + + if (arg == NULL) { + PyErr_Format(PyExc_ValueError, "Cannot del metadata_schema"); + goto out; + } + if (MutationTable_check_state(self) != 0) { + goto out; + } + metadata_schema = PyUnicode_AsUTF8AndSize(arg, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_mutation_table_set_metadata_schema( + self->table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: 
+ return ret; +} + static PyGetSetDef MutationTable_getsetters[] = { {"max_rows_increment", (getter) MutationTable_get_max_rows_increment, NULL, @@ -4674,6 +4973,8 @@ static PyGetSetDef MutationTable_getsetters[] = { "The metadata array"}, {"metadata_offset", (getter) MutationTable_get_metadata_offset, NULL, "The metadata_offset array"}, + {"metadata_schema", (getter) MutationTable_get_metadata_schema, + (setter) MutationTable_set_metadata_schema, "The metadata schema"}, {NULL} /* Sentinel */ }; @@ -5024,6 +5325,49 @@ PopulationTable_get_metadata_offset(PopulationTable *self, void *closure) return ret; } +static PyObject * +PopulationTable_get_metadata_schema(PopulationTable *self, void *closure) +{ + PyObject *ret = NULL; + + if (PopulationTable_check_state(self) != 0) { + goto out; + } + ret = make_Py_Unicode_FromStringAndLength(self->table->metadata_schema, self->table->metadata_schema_length); +out: + return ret; +} + +static int +PopulationTable_set_metadata_schema(PopulationTable *self, PyObject *arg, void *closure) +{ + int ret = -1; + int err; + const char *metadata_schema; + Py_ssize_t metadata_schema_length; + + if (arg == NULL) { + PyErr_Format(PyExc_ValueError, "Cannot del metadata_schema"); + goto out; + } + if (PopulationTable_check_state(self) != 0) { + goto out; + } + metadata_schema = PyUnicode_AsUTF8AndSize(arg, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_population_table_set_metadata_schema( + self->table, metadata_schema, metadata_schema_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + static PyGetSetDef PopulationTable_getsetters[] = { {"max_rows_increment", (getter) PopulationTable_get_max_rows_increment, NULL, "The size increment"}, @@ -5034,6 +5378,8 @@ static PyGetSetDef PopulationTable_getsetters[] = { {"metadata", (getter) PopulationTable_get_metadata, NULL, "The metadata array"}, {"metadata_offset", (getter) PopulationTable_get_metadata_offset, NULL, "The metadata offset array"}, + {"metadata_schema", (getter) PopulationTable_get_metadata_schema, + (setter) PopulationTable_set_metadata_schema, "The metadata schema"}, {NULL} /* Sentinel */ }; @@ -6303,6 +6649,48 @@ TreeSequence_get_site(TreeSequence *self, PyObject *args) return ret; } +static PyObject * +TreeSequence_get_metadata_schemas(TreeSequence *self) { + PyObject *ret = NULL; + PyObject *value = NULL; + PyObject *schema = NULL; + size_t j; + tsk_table_collection_t *tables = self->tree_sequence->tables; + + struct schema_pair { + const char * schema; + tsk_size_t length; + }; + + struct schema_pair schema_pairs[] = { + {tables->nodes.metadata_schema, tables->nodes.metadata_schema_length}, + {tables->edges.metadata_schema, tables->edges.metadata_schema_length}, + {tables->sites.metadata_schema, tables->sites.metadata_schema_length}, + {tables->mutations.metadata_schema, tables->mutations.metadata_schema_length}, + {tables->migrations.metadata_schema, tables->migrations.metadata_schema_length}, + {tables->individuals.metadata_schema, tables->individuals.metadata_schema_length}, + {tables->populations.metadata_schema, tables->populations.metadata_schema_length}, + }; + + value = PyStructSequence_New(&MetadataSchemas); + if (value == NULL) { + goto out; + } + for (j = 0; j < sizeof(schema_pairs) / sizeof(*schema_pairs); j++) { + schema = make_Py_Unicode_FromStringAndLength( + schema_pairs[j].schema, schema_pairs[j].length); + if (schema == NULL) { + goto out; + } + PyStructSequence_SetItem(value, j, schema); + 
} + ret = value; + value = NULL; +out: + Py_XDECREF(value); + return ret; +} + static PyObject * TreeSequence_get_mutation(TreeSequence *self, PyObject *args) { @@ -7736,6 +8124,8 @@ static PyMethodDef TreeSequence_methods[] = { "Returns the number of unique nodes in the tree sequence." }, {"get_num_samples", (PyCFunction) TreeSequence_get_num_samples, METH_NOARGS, "Returns the sample size" }, + {"get_metadata_schemas", (PyCFunction) TreeSequence_get_metadata_schemas, METH_NOARGS, + "Returns the metadata schemas for the tree sequence tables"}, {"get_samples", (PyCFunction) TreeSequence_get_samples, METH_NOARGS, "Returns the samples." }, {"genealogical_nearest_neighbours", @@ -10259,6 +10649,13 @@ PyInit__tskit(void) Py_INCREF(&LsHmmType); PyModule_AddObject(module, "LsHmm", (PyObject *) &LsHmmType); + /* Metadata schemas namedtuple type*/ + if (PyStructSequence_InitType2(&MetadataSchemas, &metadata_schemas_desc) < 0) { + return NULL; + }; + Py_INCREF(&MetadataSchemas); + PyModule_AddObject(module, "MetadataSchemas", (PyObject*)&MetadataSchemas); + /* Errors and constants */ TskitException = PyErr_NewException("_tskit.TskitException", NULL, NULL); Py_INCREF(TskitException); diff --git a/python/tests/__init__.py b/python/tests/__init__.py index 6134936d74..26e5237891 100644 --- a/python/tests/__init__.py +++ b/python/tests/__init__.py @@ -213,7 +213,10 @@ def make_mutation(id_): node=node, derived_state=derived_state, parent=parent, - metadata=metadata, + encoded_metadata=metadata, + metadata_decoder=tskit.metadata.MetadataSchema.from_str( + ll_ts.get_metadata_schemas().mutation + ).decode_row, ) for j in range(tree_sequence.num_sites): @@ -224,7 +227,10 @@ def make_mutation(id_): position=pos, ancestral_state=ancestral_state, mutations=[make_mutation(ll_mut) for ll_mut in ll_mutations], - metadata=metadata, + encoded_metadata=metadata, + metadata_decoder=tskit.metadata.MetadataSchema.from_str( + ll_ts.get_metadata_schemas().site + ).decode_row, ) ) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 2de877956a..c51f2935b2 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1332,6 +1332,84 @@ def test_sequence_iteration(self): self.assertEqual(n.id, 0) +class TestTreeSequenceMetadata(unittest.TestCase): + metadata_tables = [ + "node", + "edge", + "site", + "mutation", + "migration", + "individual", + "population", + ] + metadata_schema = tskit.metadata.MetadataSchema( + encoding="json", + schema={ + "title": "Example Metadata", + "type": "object", + "properties": { + "table": {"type": "string"}, + "string_prop": {"type": "string"}, + "num_prop": {"type": "number"}, + }, + "required": ["table", "string_prop", "num_prop"], + "additionalProperties": False, + }, + ) + + def test_metadata_schemas(self): + ts = msprime.simulate(5) + tables = ts.dump_tables() + schemas = { + table: tskit.metadata.MetadataSchema( + encoding="json", schema={"TEST": f"{table}-SCHEMA"} + ) + for table in self.metadata_tables + } + for table in self.metadata_tables: + getattr(tables, f"{table}s").metadata_schema = schemas[table] + ts = tskit.TreeSequence.load_tables(tables) + # Each table should get its own schema back + for table in self.metadata_tables: + self.assertEqual( + getattr(ts.metadata_schemas, table).to_str(), schemas[table].to_str(), + ) + + def test_metadata_round_trip_via_row_getters(self): + ts = msprime.simulate(8, random_seed=3, mutation_rate=1) + self.assertGreater(ts.num_sites, 2) + new_tables = ts.dump_tables() + tables_copy = 
ts.dump_tables() + for table in self.metadata_tables: + table_obj = getattr(new_tables, f"{table}s") + table_obj.metadata_schema = self.metadata_schema + table_obj.clear() + # Write back the rows, but adding unique metadata + for j, row in enumerate(getattr(tables_copy, f"{table}s")): + row_data = {k: v for k, v in zip(row._fields, row)} + row_data["metadata"] = { + "table": table, + "string_prop": f"Row number{j}", + "num_prop": j, + } + table_obj.add_row(**row_data) + new_ts = new_tables.tree_sequence() + for table in self.metadata_tables: + self.assertEqual( + getattr(new_ts, f"num_{table}s"), getattr(ts, f"num_{table}s") + ) + for table in self.metadata_tables: + for row in getattr(new_ts, f"{table}s")(): + self.assertDictEqual( + row.metadata, + { + "table": table, + "string_prop": f"Row number{row.id}", + "num_prop": row.id, + }, + ) + + class TestPickle(HighLevelTestCase): """ Test pickling of a TreeSequence. @@ -2437,30 +2515,98 @@ def test_repr(self): self.assertGreater(len(repr(c)), 0) -class TestIndividualContainer(unittest.TestCase, SimpleContainersMixin): +class SimpleContainersWithMetadataMixin: + """ + Tests for the SimpleContainerWithMetadata classes. + """ + + def test_metadata(self): + # Test decoding + instances = self.get_instances(5) + for j, inst in enumerate(instances): + self.assertEqual(inst.metadata, ("x" * j) + "decoded") + + # Decoder doesn't effect equality + (inst,) = self.get_instances(1) + (inst2,) = self.get_instances(1) + self.assertTrue(inst == inst2) + inst._metadata_decoder = lambda m: "different decoder" + self.assertTrue(inst == inst2) + + def test_decoder_run_once(self): + # For a given instance, the decoded metadata should be cached + (inst,) = self.get_instances(1) + times_run = 0 + + def decoder(m): + nonlocal times_run + times_run += 1 + return m.decode() + "decoded" + + inst._metadata_decoder = decoder + self.assertEqual(times_run, 0) + _ = inst.metadata + self.assertEqual(times_run, 1) + _ = inst.metadata + self.assertEqual(times_run, 1) + + +class TestIndividualContainer( + unittest.TestCase, SimpleContainersMixin, SimpleContainersWithMetadataMixin +): def get_instances(self, n): return [ - tskit.Individual(id_=j, flags=j, location=[j], nodes=[j], metadata=b"x" * j) + tskit.Individual( + id_=j, + flags=j, + location=[j], + nodes=[j], + encoded_metadata=b"x" * j, + metadata_decoder=lambda m: m.decode() + "decoded", + ) for j in range(n) ] -class TestNodeContainer(unittest.TestCase, SimpleContainersMixin): +class TestNodeContainer( + unittest.TestCase, SimpleContainersMixin, SimpleContainersWithMetadataMixin +): def get_instances(self, n): return [ tskit.Node( - id_=j, flags=j, time=j, population=j, individual=j, metadata=b"x" * j + id_=j, + flags=j, + time=j, + population=j, + individual=j, + encoded_metadata=b"x" * j, + metadata_decoder=lambda m: m.decode() + "decoded", ) for j in range(n) ] -class TestEdgeContainer(unittest.TestCase, SimpleContainersMixin): +class TestEdgeContainer( + unittest.TestCase, SimpleContainersMixin, SimpleContainersWithMetadataMixin +): def get_instances(self, n): - return [tskit.Edge(left=j, right=j, parent=j, child=j, id_=j) for j in range(n)] + return [ + tskit.Edge( + left=j, + right=j, + parent=j, + child=j, + encoded_metadata=b"x" * j, + metadata_decoder=lambda m: m.decode() + "decoded", + id_=j, + ) + for j in range(n) + ] -class TestSiteContainer(unittest.TestCase, SimpleContainersMixin): +class TestSiteContainer( + unittest.TestCase, SimpleContainersMixin, SimpleContainersWithMetadataMixin +): def 
get_instances(self, n): return [ tskit.Site( @@ -2468,13 +2614,16 @@ def get_instances(self, n): position=j, ancestral_state="A" * j, mutations=TestMutationContainer().get_instances(j), - metadata=b"x" * j, + encoded_metadata=b"x" * j, + metadata_decoder=lambda m: m.decode() + "decoded", ) for j in range(n) ] -class TestMutationContainer(unittest.TestCase, SimpleContainersMixin): +class TestMutationContainer( + unittest.TestCase, SimpleContainersMixin, SimpleContainersWithMetadataMixin +): def get_instances(self, n): return [ tskit.Mutation( @@ -2483,23 +2632,44 @@ def get_instances(self, n): node=j, derived_state="A" * j, parent=j, - metadata=b"x" * j, + encoded_metadata=b"x" * j, + metadata_decoder=lambda m: m.decode() + "decoded", ) for j in range(n) ] -class TestMigrationContainer(unittest.TestCase, SimpleContainersMixin): +class TestMigrationContainer( + unittest.TestCase, SimpleContainersMixin, SimpleContainersWithMetadataMixin +): def get_instances(self, n): return [ - tskit.Migration(left=j, right=j, node=j, source=j, dest=j, time=j) + tskit.Migration( + left=j, + right=j, + node=j, + source=j, + dest=j, + time=j, + encoded_metadata=b"x" * j, + metadata_decoder=lambda m: m.decode() + "decoded", + ) for j in range(n) ] -class TestPopulationContainer(unittest.TestCase, SimpleContainersMixin): +class TestPopulationContainer( + unittest.TestCase, SimpleContainersMixin, SimpleContainersWithMetadataMixin +): def get_instances(self, n): - return [tskit.Population(id_=j, metadata="x" * j) for j in range(n)] + return [ + tskit.Population( + id_=j, + encoded_metadata=b"x" * j, + metadata_decoder=lambda m: m.decode() + "decoded", + ) + for j in range(n) + ] class TestProvenanceContainer(unittest.TestCase, SimpleContainersMixin): diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index aca43b579d..f612151dbf 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -445,6 +445,36 @@ def test_mean_descendants(self): A = ts.mean_descendants([focal[2:], focal[:2]]) self.assertEqual(A.shape, (ts.get_num_nodes(), 2)) + def test_metadata_schemas(self): + tables = _tskit.TableCollection(1.0) + metadata_tables = [ + "node", + "edge", + "site", + "individual", + "mutation", + "migration", + "population", + ] + for table_name in metadata_tables: + table = getattr(tables, f"{table_name}s") + table.metadata_schema = f"{table_name} test metadata schema" + ts = _tskit.TreeSequence() + ts.load_tables(tables) + schemas = ts.get_metadata_schemas() + for table_name in metadata_tables: + self.assertEqual( + getattr(schemas, table_name), f"{table_name} test metadata schema" + ) + for table_name in metadata_tables: + table = getattr(tables, f"{table_name}s") + table.metadata_schema = "" + ts = _tskit.TreeSequence() + ts.load_tables(tables) + schemas = ts.get_metadata_schemas() + for table_name in metadata_tables: + self.assertEqual(getattr(schemas, table_name), "") + class StatsInterfaceMixin: """ @@ -2013,6 +2043,58 @@ def test_map_mutations_errors(self): self.assertRaises(_tskit.LibraryError, tree.map_mutations, genotypes) +class MetadataTestMixin: + tables = [ + "nodes", + "edges", + "sites", + "mutations", + "migrations", + "individuals", + "populations", + ] + + +class TestTableMetadataSchema(unittest.TestCase, MetadataTestMixin): + def test_metadata_schema_attribute(self): + tables = _tskit.TableCollection(1.0) + for table in self.tables: + table = getattr(tables, table) + self.assertEqual(table.metadata_schema, "") + example = "An example of metadata schema 
with unicode 🎄🌳🌴🌲🎋" + table.metadata_schema = example + self.assertEqual(table.metadata_schema, example) + with self.assertRaises(ValueError): + del table.metadata_schema + table.metadata_schema = "" + self.assertEqual(table.metadata_schema, "") + with self.assertRaises(TypeError): + table.metadata_schema = None + + +class TestMetadataSchemaNamedTuple(unittest.TestCase, MetadataTestMixin): + def test_named_tuple_init(self): + with self.assertRaises(TypeError): + metadata_schemas = _tskit.MetadataSchemas() + with self.assertRaises(TypeError): + metadata_schemas = _tskit.MetadataSchemas([]) + with self.assertRaises(TypeError): + metadata_schemas = _tskit.MetadataSchemas(["test_schema"]) + metadata_schemas = _tskit.MetadataSchemas( + f"{table}_test_schema" for table in self.tables + ) + self.assertEqual( + metadata_schemas, tuple(f"{table}_test_schema" for table in self.tables) + ) + for table in self.tables: + self.assertEqual( + getattr(metadata_schemas, table[:-1]), f"{table}_test_schema" + ) + for table in self.tables: + with self.assertRaises(AttributeError): + setattr(metadata_schemas, table[:-1], "") + + class TestModuleFunctions(unittest.TestCase): """ Tests for the module level functions. diff --git a/python/tests/test_metadata.py b/python/tests/test_metadata.py index 50be388e35..75c7c9b156 100644 --- a/python/tests/test_metadata.py +++ b/python/tests/test_metadata.py @@ -284,3 +284,9 @@ def test_populations(self): expected = ["mno", ")(*&^%$#@!"] for a, b in zip(expected, p): self.assertEqual(a.encode("utf8"), b.metadata) + + +class TestMetadataSchema(unittest.TestCase): + """ + Tests that use the MetadataSchema Class + """ diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index eed28f4c23..3c026f7353 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -37,6 +37,8 @@ import _tskit import tests.tsutil as tsutil import tskit +import tskit.exceptions as exceptions +import tskit.metadata as metadata class Column: @@ -75,6 +77,14 @@ class CommonTestsMixin: we have to make this a mixin. """ + def make_input_data(self, num_rows): + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + return input_data + def test_max_rows_increment(self): for bad_value in [-1, -(2 ** 10)]: self.assertRaises( @@ -147,11 +157,7 @@ def test_set_columns_string_errors(self): self.assertRaises(TypeError, table.set_columns, **kwargs) def test_set_columns_interface(self): - kwargs = {c.name: c.get_input(1) for c in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(1) - kwargs[list_col.name] = value - kwargs[offset_col.name] = [0, 1] + kwargs = self.make_input_data(1) # Make sure this works. table = self.table_class() table.set_columns(**kwargs) @@ -170,11 +176,7 @@ def test_set_columns_interface(self): self.assertRaises(ValueError, table.append_columns, **error_kwargs) def test_set_columns_from_dict(self): - kwargs = {c.name: c.get_input(1) for c in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(1) - kwargs[list_col.name] = value - kwargs[offset_col.name] = [0, 1] + kwargs = self.make_input_data(1) # Make sure this works. 
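# --- Editor's illustrative sketch, not part of the diff: the low-level round trip
# exercised by the test_lowlevel.py additions above, assuming the _tskit extension
# is built from this patch. Unset schemas read back as empty strings.
import _tskit

ll_tables = _tskit.TableCollection(1.0)
ll_tables.nodes.metadata_schema = "node test metadata schema"
ll_ts = _tskit.TreeSequence()
ll_ts.load_tables(ll_tables)
ll_schemas = ll_ts.get_metadata_schemas()      # MetadataSchemas struct sequence
assert ll_schemas.node == "node test metadata schema"
assert ll_schemas.edge == ""                   # tables without a schema yield ""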
t1 = self.table_class() t1.set_columns(**kwargs) @@ -183,11 +185,7 @@ def test_set_columns_from_dict(self): self.assertEqual(t1, t2) def test_set_columns_dimension(self): - kwargs = {c.name: c.get_input(1) for c in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(1) - kwargs[list_col.name] = value - kwargs[offset_col.name] = [0, 1] + kwargs = self.make_input_data(1) table = self.table_class() table.set_columns(**kwargs) table.append_columns(**kwargs) @@ -198,8 +196,7 @@ def test_set_columns_dimension(self): error_kwargs[focal_col.name] = bad_dims self.assertRaises(ValueError, table.set_columns, **error_kwargs) self.assertRaises(ValueError, table.append_columns, **error_kwargs) - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(1) + for _, offset_col in self.ragged_list_columns: error_kwargs = dict(kwargs) for bad_dims in [5, [[1], [1]], np.zeros((2, 2))]: error_kwargs[offset_col.name] = bad_dims @@ -210,13 +207,9 @@ def test_set_columns_dimension(self): self.assertRaises(ValueError, table.set_columns, **error_kwargs) def test_set_columns_input_sizes(self): - num_rows = 100 - input_data = {col.name: col.get_input(num_rows) for col in self.columns} + input_data = self.make_input_data(100) col_map = {col.name: col for col in self.columns} for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) col_map[list_col.name] = list_col col_map[offset_col.name] = offset_col table = self.table_class() @@ -252,11 +245,7 @@ def test_set_column_attributes_empty(self): def test_set_column_attributes_data(self): table = self.table_class() for num_rows in [1, 10, 100]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) table.set_columns(**input_data) for list_col, offset_col in self.ragged_list_columns: @@ -296,11 +285,7 @@ def test_set_column_attributes_data(self): def test_set_column_attributes_errors(self): table = self.table_class() num_rows = 10 - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) table.set_columns(**input_data) for list_col, offset_col in self.ragged_list_columns: @@ -356,11 +341,7 @@ def test_add_row_data(self): def test_add_row_round_trip(self): for num_rows in [0, 10, 100]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) t1 = self.table_class() t1.set_columns(**input_data) for colname, input_array in input_data.items(): @@ -450,12 +431,9 @@ def test_truncate_errors(self): def test_append_columns_data(self): for num_rows in [0, 10, 100, 1000]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} + input_data = 
self.make_input_data(num_rows) offset_cols = set() - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + for _, offset_col in self.ragged_list_columns: offset_cols.add(offset_col.name) table = self.table_class() for j in range(1, 10): @@ -478,11 +456,7 @@ def test_append_columns_data(self): def test_append_columns_max_rows(self): for num_rows in [0, 10, 100, 1000]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) for max_rows in [0, 1, 8192]: table = self.table_class(max_rows_increment=max_rows) for j in range(1, 10): @@ -493,11 +467,7 @@ def test_append_columns_max_rows(self): def test_str(self): for num_rows in [0, 10]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) table = self.table_class() table.set_columns(**input_data) s = str(table) @@ -517,11 +487,7 @@ def test_repr_html(self): def test_copy(self): for num_rows in [0, 10]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) table = self.table_class() table.set_columns(**input_data) for _ in range(10): @@ -533,11 +499,7 @@ def test_copy(self): def test_pickle(self): for num_rows in [0, 10, 100]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) table = self.table_class() table.set_columns(**input_data) pkl = pickle.dumps(table) @@ -550,11 +512,7 @@ def test_pickle(self): def test_equality(self): for num_rows in [1, 10, 100]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) t1 = self.table_class() t2 = self.table_class() self.assertEqual(t1, t1) @@ -605,11 +563,7 @@ def test_equality(self): def test_bad_offsets(self): for num_rows in [10, 100]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) t = self.table_class() t.set_columns(**input_data) @@ -644,13 +598,36 @@ class MetadataTestsMixin: Tests for column that have metadata columns. 
""" + metadata_schema = metadata.MetadataSchema( + encoding="json", + schema={ + "title": "Example Metadata", + "type": "object", + "properties": {"one": {"type": "string"}, "two": {"type": "number"}}, + "required": ["one", "two"], + "additionalProperties": False, + }, + ) + + def metadata_example_data(self): + try: + self.val += 1 + except AttributeError: + self.val = 0 + return {"one": "val one", "two": self.val} + + def input_data_for_add_row(self): + input_data = {col.name: col.get_input(1) for col in self.columns} + kwargs = {col: data[0] for col, data in input_data.items()} + for col in self.string_colnames: + kwargs[col] = "x" + for col in self.binary_colnames: + kwargs[col] = b"x" + return kwargs + def test_random_metadata(self): for num_rows in [0, 10, 100]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) table = self.table_class() metadatas = [tsutil.random_bytes(10) for _ in range(num_rows)] metadata, metadata_offset = tskit.pack_bytes(metadatas) @@ -663,36 +640,29 @@ def test_random_metadata(self): self.assertEqual(metadatas, unpacked_metadatas) def test_optional_metadata(self): - for num_rows in [0, 10, 100]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) - table = self.table_class() - del input_data["metadata"] - del input_data["metadata_offset"] - table.set_columns(**input_data) - self.assertEqual(len(list(table.metadata)), 0) - self.assertEqual( - list(table.metadata_offset), [0 for _ in range(num_rows + 1)] - ) - # Supplying None is the same not providing the column. - input_data["metadata"] = None - input_data["metadata_offset"] = None - table.set_columns(**input_data) - self.assertEqual(len(list(table.metadata)), 0) - self.assertEqual( - list(table.metadata_offset), [0 for _ in range(num_rows + 1)] - ) + if not getattr(self, "metadata_mandatory", False): + for num_rows in [0, 10, 100]: + input_data = self.make_input_data(num_rows) + table = self.table_class() + del input_data["metadata"] + del input_data["metadata_offset"] + table.set_columns(**input_data) + self.assertEqual(len(list(table.metadata)), 0) + self.assertEqual( + list(table.metadata_offset), [0 for _ in range(num_rows + 1)] + ) + # Supplying None is the same not providing the column. 
+ input_data["metadata"] = None + input_data["metadata_offset"] = None + table.set_columns(**input_data) + self.assertEqual(len(list(table.metadata)), 0) + self.assertEqual( + list(table.metadata_offset), [0 for _ in range(num_rows + 1)] + ) def test_packset_metadata(self): for num_rows in [0, 10, 100]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) table = self.table_class() table.set_columns(**input_data) metadatas = [tsutil.random_bytes(10) for _ in range(num_rows)] @@ -701,9 +671,85 @@ def test_packset_metadata(self): self.assertTrue(np.array_equal(table.metadata, metadata)) self.assertTrue(np.array_equal(table.metadata_offset, metadata_offset)) + def test_set_metadata_schema(self): + metadata_schema2 = metadata.MetadataSchema("json", {}) + table = self.table_class() + # Set + table.metadata_schema = self.metadata_schema + self.assertEqual(table.metadata_schema.to_str(), self.metadata_schema.to_str()) + # Remove + del table.metadata_schema + self.assertEqual( + table.metadata_schema.to_str(), metadata.NullMetadataSchema().to_str() + ) + # Overwrite + table.metadata_schema = self.metadata_schema + table.metadata_schema = metadata_schema2 + self.assertEqual(table.metadata_schema.to_str(), metadata_schema2.to_str()) + # Delete + del table.metadata_schema + self.assertEqual( + table.metadata_schema.to_str(), metadata.NullMetadataSchema().to_str() + ) + # Empty string results in NullMetadataSchema + table.ll_table.metadata_schema = "" + table._update_metadata_schema_cache_from_ll() + self.assertEqual( + table.metadata_schema.to_str(), metadata.NullMetadataSchema().to_str() + ) + + def test_default_metadata_schema(self): + table = self.table_class() + # Default is no-op metadata codec + self.assertEqual( + table.metadata_schema.to_str(), metadata.NullMetadataSchema().to_str() + ) + table.add_row( + **{**self.input_data_for_add_row(), "metadata": b"acceptable bytes"} + ) + # Adding non-bytes metadata should error + with self.assertRaises(TypeError): + table.add_row( + **{ + **self.input_data_for_add_row(), + "metadata": self.metadata_example_data(), + } + ) + + def test_bad_metadata_schema(self): + table = self.table_class() + table.ll_table.metadata_schema = "I'm not JSON" + with self.assertRaises(ValueError): + table._update_metadata_schema_cache_from_ll() + with self.assertRaises(AttributeError): + table.metadata_schema = "I'm not JSON" + + def test_row_round_trip_metadata_schema(self): + data = self.metadata_example_data() + table = self.table_class() + table.metadata_schema = self.metadata_schema + table.add_row(**{**self.input_data_for_add_row(), "metadata": data}) + self.assertDictEqual(table[0].metadata, data) -class TestIndividualTable(unittest.TestCase, CommonTestsMixin, MetadataTestsMixin): + def test_bad_row_metadata_schema(self): + data = self.metadata_example_data() + data["I really shouldn't be here"] = 6 + table = self.table_class() + table.metadata_schema = self.metadata_schema + with self.assertRaises(exceptions.MetadataValidationError): + table.add_row(**{**self.input_data_for_add_row(), "metadata": data}) + self.assertEqual(len(table), 0) + + def test_absent_metadata_with_required_schema(self): + table = self.table_class() + table.metadata_schema = self.metadata_schema + input_data = 
self.input_data_for_add_row() + del input_data["metadata"] + with self.assertRaises(exceptions.MetadataValidationError): + table.add_row(**{**input_data}) + +class TestIndividualTable(unittest.TestCase, CommonTestsMixin, MetadataTestsMixin): columns = [UInt32Column("flags")] ragged_list_columns = [ (DoubleColumn("location"), UInt32Column("location_offset")), @@ -928,11 +974,7 @@ def test_add_row_bad_data(self): def test_packset_ancestral_state(self): for num_rows in [0, 10, 100]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) table = self.table_class() table.set_columns(**input_data) ancestral_states = [tsutil.random_strings(10) for _ in range(num_rows)] @@ -988,11 +1030,7 @@ def test_add_row_bad_data(self): def test_packset_derived_state(self): for num_rows in [0, 10, 100]: - input_data = {col.name: col.get_input(num_rows) for col in self.columns} - for list_col, offset_col in self.ragged_list_columns: - value = list_col.get_input(num_rows) - input_data[list_col.name] = value - input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + input_data = self.make_input_data(num_rows) table = self.table_class() table.set_columns(**input_data) derived_states = [tsutil.random_strings(10) for _ in range(num_rows)] @@ -1104,7 +1142,8 @@ def test_packset_record(self): self.assertEqual(t[1].record, "BBBB") -class TestPopulationTable(unittest.TestCase, CommonTestsMixin): +class TestPopulationTable(unittest.TestCase, CommonTestsMixin, MetadataTestsMixin): + metadata_mandatory = True columns = [] ragged_list_columns = [(CharColumn("metadata"), UInt32Column("metadata_offset"))] equal_len_columns = [[]] diff --git a/python/tskit/exceptions.py b/python/tskit/exceptions.py index 1e79f468a2..4b6f50c596 100644 --- a/python/tskit/exceptions.py +++ b/python/tskit/exceptions.py @@ -55,5 +55,11 @@ class DuplicatePositionsError(TskitException): class ProvenanceValidationError(TskitException): """ - A JSON document did non validate against the provenance schema. + A JSON document did not validate against the provenance schema. + """ + + +class MetadataValidationError(TskitException): + """ + A metadata object did not validate against the provenance schema. """ diff --git a/python/tskit/metadata.py b/python/tskit/metadata.py new file mode 100644 index 0000000000..52a55459f3 --- /dev/null +++ b/python/tskit/metadata.py @@ -0,0 +1,115 @@ +# MIT License +# +# Copyright (c) 2020 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Classes for metadata decoding, encoding and validation +""" +import abc +import json +from typing import Any +from typing import Optional + +import jsonschema + +import tskit.exceptions as exceptions + + +class MetadataCodec(abc.ABC): + @abc.abstractmethod + def decode(self, encoded_bytes: bytes) -> Any: + pass + + @abc.abstractmethod + def encode(self, obj: Any) -> bytes: + pass + + +class JSONCodec(MetadataCodec): + name = "json" + + def decode(self, encoded_bytes: bytes) -> Any: + return json.loads(encoded_bytes.decode()) + + def encode(self, obj: Any) -> bytes: + return json.dumps(obj).encode() + + +class AbstractMetadataSchema(abc.ABC): + @abc.abstractmethod + def to_str(self) -> Optional[str]: + pass + + @abc.abstractmethod + def validate_and_encode_row(self, row: Any) -> bytes: + pass + + @abc.abstractmethod + def decode_row(self, row: bytes) -> Any: + pass + + +class MetadataSchema(AbstractMetadataSchema): + @classmethod + def from_str(cls, encoded_schema: Optional[str]) -> AbstractMetadataSchema: + if encoded_schema == "": + return NullMetadataSchema() + else: + try: + decoded = json.loads(encoded_schema) + except json.decoder.JSONDecodeError: + raise ValueError(f"Metadata schema is not JSON, found {encoded_schema}") + return cls(decoded["encoding"], decoded["schema"]) + + def __init__(self, encoding, schema) -> None: + self.encoding: str = encoding + self.schema: str = schema + + def _get_codec(self) -> MetadataCodec: + try: + return {codec().name: codec() for codec in MetadataCodec.__subclasses__()}[ + self.encoding + ] + except KeyError: + raise ValueError(f"Unrecognised metadata encoding:{self.encoding}") + + def to_str(self) -> str: + return json.dumps({"encoding": self.encoding, "schema": self.schema}) + + def validate_and_encode_row(self, row: Any) -> bytes: + try: + jsonschema.validate(row, self.schema) + except jsonschema.exceptions.ValidationError as ve: + raise exceptions.MetadataValidationError from ve + return self._get_codec().encode(row) + + def decode_row(self, row: bytes) -> Any: + return self._get_codec().decode(row) + + +class NullMetadataSchema(AbstractMetadataSchema): + def to_str(self) -> None: + return None + + def validate_and_encode_row(self, row: bytes) -> bytes: + return row + + def decode_row(self, row: bytes) -> bytes: + return row diff --git a/python/tskit/tables.py b/python/tskit/tables.py index ba304fa049..58bf4b843e 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -28,11 +28,14 @@ import datetime import json import warnings +from typing import Any +from typing import Tuple import numpy as np import _tskit import tskit +import tskit.metadata as metadata import tskit.provenance as provenance import tskit.util as util @@ -100,9 +103,10 @@ class BaseTable: # The list of columns in the table. Must be set by subclasses. 
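# --- Editor's illustrative sketch, not part of the patch: encoding, validating and
# decoding a row with the MetadataSchema class defined in metadata.py above. The
# schema contents here are arbitrary example values; jsonschema must be installed.
import tskit.metadata as metadata

example_schema = metadata.MetadataSchema(
    encoding="json",
    schema={
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
        "required": ["name", "age"],
    },
)
encoded = example_schema.validate_and_encode_row({"name": "ind0", "age": 7})
assert isinstance(encoded, bytes)
assert example_schema.decode_row(encoded) == {"name": "ind0", "age": 7}
# The string form is what is stored on the table and in the file:
assert metadata.MetadataSchema.from_str(example_schema.to_str()).to_str() == example_schema.to_str()
# A row that violates the schema raises tskit.exceptions.MetadataValidationError.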
column_names = [] - def __init__(self, ll_table, row_class): + def __init__(self, ll_table, row_class, **kwargs): self.ll_table = ll_table self.row_class = row_class + super().__init__(**kwargs) def _check_required_args(self, **kwargs): for k, v in kwargs.items(): @@ -154,7 +158,13 @@ def __getitem__(self, index): index += len(self) if index < 0 or index >= len(self): raise IndexError("Index out of bounds") - return self.row_class(*self.ll_table.get_row(index)) + row = self.ll_table.get_row(index) + try: + row = self.decode_row(row) + except AttributeError: + # This means the class returns the low-level row unchanged. + pass + return self.row_class(*row) def clear(self): """ @@ -244,6 +254,10 @@ class MetadataMixin: Mixin class for tables that have a metadata column. """ + def __init__(self): + self.metadata_column_index = self.row_class._fields.index("metadata") + self._update_metadata_schema_cache_from_ll() + def packset_metadata(self, metadatas): """ Packs the specified list of metadata values and updates the ``metadata`` @@ -258,6 +272,32 @@ def packset_metadata(self, metadatas): d["metadata_offset"] = offset self.set_columns(**d) + @property + def metadata_schema(self) -> metadata.MetadataSchema: + return self._metadata_schema_cache + + @metadata_schema.setter + def metadata_schema(self, schema: metadata.MetadataSchema) -> None: + self.ll_table.metadata_schema = schema.to_str() + self._update_metadata_schema_cache_from_ll() + + @metadata_schema.deleter + def metadata_schema(self) -> None: + del self.ll_table.metadata_schema + self._update_metadata_schema_cache_from_ll() + + def decode_row(self, row: Tuple[Any]) -> Tuple: + return ( + row[: self.metadata_column_index] + + (self._metadata_schema_cache.decode_row(row[self.metadata_column_index]),) + + row[self.metadata_column_index + 1 :] + ) + + def _update_metadata_schema_cache_from_ll(self) -> None: + self._metadata_schema_cache = metadata.MetadataSchema.from_str( + self.ll_table.metadata_schema + ) + class IndividualTable(BaseTable, MetadataMixin): """ @@ -330,6 +370,7 @@ def add_row(self, flags=0, location=None, metadata=None): :return: The ID of the newly added node. :rtype: int """ + metadata = self.metadata_schema.validate_and_encode_row(metadata) return self.ll_table.add_row(flags=flags, location=location, metadata=metadata) def set_columns( @@ -518,6 +559,7 @@ def add_row(self, flags=0, time=0, population=-1, individual=-1, metadata=None): :return: The ID of the newly added node. :rtype: int """ + metadata = self.metadata_schema.validate_and_encode_row(metadata) return self.ll_table.add_row(flags, time, population, individual, metadata) def set_columns( @@ -693,6 +735,7 @@ def add_row(self, left, right, parent, child, metadata=None): :return: The ID of the newly added edge. :rtype: int """ + metadata = self.metadata_schema.validate_and_encode_row(metadata) return self.ll_table.add_row(left, right, parent, child, metadata) def set_columns( @@ -888,6 +931,7 @@ def add_row(self, left, right, node, source, dest, time, metadata=None): :return: The ID of the newly added migration. :rtype: int """ + metadata = self.metadata_schema.validate_and_encode_row(metadata) return self.ll_table.add_row(left, right, node, source, dest, time, metadata) def set_columns( @@ -1075,6 +1119,7 @@ def add_row(self, position, ancestral_state, metadata=None): :return: The ID of the newly added site. 
:rtype: int """ + metadata = self.metadata_schema.validate_and_encode_row(metadata) return self.ll_table.add_row(position, ancestral_state, metadata) def set_columns( @@ -1277,6 +1322,7 @@ def add_row(self, site, node, derived_state, parent=-1, metadata=None): :return: The ID of the newly added mutation. :rtype: int """ + metadata = self.metadata_schema.validate_and_encode_row(metadata) return self.ll_table.add_row(site, node, derived_state, parent, metadata) def set_columns( @@ -1454,6 +1500,7 @@ def add_row(self, metadata=None): :return: The ID of the newly added population. :rtype: int """ + metadata = self.metadata_schema.validate_and_encode_row(metadata) return self.ll_table.add_row(metadata=metadata) def _text_header_and_rows(self): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 7e4f3792ee..92ca518bfe 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -39,6 +39,7 @@ import tskit.drawing as drawing import tskit.exceptions as exceptions import tskit.formats as formats +import tskit.metadata as metadata import tskit.provenance as provenance import tskit.tables as tables import tskit.util as util @@ -51,6 +52,11 @@ "CoalescenceRecord", ["left", "right", "node", "children", "time", "population"] ) +MetadataSchemas = collections.namedtuple( + "MetadataSchemas", + ["node", "edge", "site", "mutation", "migration", "individual", "population"], +) + # TODO this interface is rubbish. Should have much better printing options. # TODO we should be use __slots__ here probably. @@ -65,7 +71,46 @@ def __repr__(self): return repr(self.__dict__) -class Individual(SimpleContainer): +class SimpleContainerWithMetadata(SimpleContainer): + """ + This class allows metadata to be lazily decoded and cached + """ + + class CachedMetadata: + def __get__(self, container, owner): + decoded = container._metadata_decoder(container._encoded_metadata) + container.__dict__["metadata"] = decoded + return decoded + + metadata = CachedMetadata() + + def __eq__(self, other): + # We need to remove metadata and the decoder so we are just comparing + # the encoded metadata, along with the other attributes + other = {**other.__dict__} + try: + del other["metadata"] + except KeyError: + pass + del other["_metadata_decoder"] + self_ = {**self.__dict__} + try: + del self_["metadata"] + except KeyError: + pass + del self_["_metadata_decoder"] + return self_ == other + + def __repr__(self): + # Make sure we have a decoded metadata + _ = self.metadata + out = {**self.__dict__} + del out["_encoded_metadata"] + del out["_metadata_decoder"] + return repr(out) + + +class Individual(SimpleContainerWithMetadata): """ An :ref:`individual ` in a tree sequence. Since nodes correspond to genomes, individuals are associated with a collection @@ -91,11 +136,20 @@ class Individual(SimpleContainer): :vartype metadata: bytes """ - def __init__(self, id_=None, flags=0, location=None, nodes=None, metadata=""): + def __init__( + self, + id_=None, + flags=0, + location=None, + nodes=None, + encoded_metadata=b"", + metadata_decoder=lambda metadata: metadata, + ): self.id = id_ self.flags = flags self.location = location - self.metadata = metadata + self._encoded_metadata = encoded_metadata + self._metadata_decoder = metadata_decoder self.nodes = nodes def __eq__(self, other): @@ -108,7 +162,7 @@ def __eq__(self, other): ) -class Node(SimpleContainer): +class Node(SimpleContainerWithMetadata): """ A :ref:`node ` in a tree sequence, corresponding to a single genome. 
The ``time`` and ``population`` are attributes of the @@ -134,13 +188,21 @@ class Node(SimpleContainer): """ def __init__( - self, id_=None, flags=0, time=0, population=NULL, individual=NULL, metadata="" + self, + id_=None, + flags=0, + time=0, + population=NULL, + individual=NULL, + encoded_metadata=b"", + metadata_decoder=lambda metadata: metadata, ): self.id = id_ self.time = time self.population = population self.individual = individual - self.metadata = metadata + self._encoded_metadata = encoded_metadata + self._metadata_decoder = metadata_decoder self.flags = flags def is_sample(self): @@ -153,7 +215,7 @@ def is_sample(self): return self.flags & NODE_IS_SAMPLE -class Edge(SimpleContainer): +class Edge(SimpleContainerWithMetadata): """ An :ref:`edge ` in a tree sequence. @@ -179,13 +241,23 @@ class Edge(SimpleContainer): :vartype metadata: bytes """ - def __init__(self, left, right, parent, child, metadata=b"", id_=None): + def __init__( + self, + left, + right, + parent, + child, + encoded_metadata=b"", + id_=None, + metadata_decoder=lambda metadata: metadata, + ): self.id = id_ self.left = left self.right = right self.parent = parent self.child = child - self.metadata = metadata + self._encoded_metadata = encoded_metadata + self._metadata_decoder = metadata_decoder def __repr__(self): return ( @@ -206,7 +278,7 @@ def span(self): return self.right - self.left -class Site(SimpleContainer): +class Site(SimpleContainerWithMetadata): """ A :ref:`site ` in a tree sequence. @@ -231,15 +303,24 @@ class Site(SimpleContainer): :vartype mutations: list[:class:`Mutation`] """ - def __init__(self, id_, position, ancestral_state, mutations, metadata): + def __init__( + self, + id_, + position, + ancestral_state, + mutations, + encoded_metadata=b"", + metadata_decoder=lambda metadata: metadata, + ): self.id = id_ self.position = position self.ancestral_state = ancestral_state self.mutations = mutations - self.metadata = metadata + self._encoded_metadata = encoded_metadata + self._metadata_decoder = metadata_decoder -class Mutation(SimpleContainer): +class Mutation(SimpleContainerWithMetadata): """ A :ref:`mutation ` in a tree sequence. @@ -279,17 +360,19 @@ def __init__( node=NULL, derived_state=None, parent=NULL, - metadata=None, + encoded_metadata=b"", + metadata_decoder=lambda metadata: metadata, ): self.id = id_ self.site = site self.node = node self.derived_state = derived_state self.parent = parent - self.metadata = metadata + self._encoded_metadata = encoded_metadata + self._metadata_decoder = metadata_decoder -class Migration(SimpleContainer): +class Migration(SimpleContainerWithMetadata): """ A :ref:`migration ` in a tree sequence. @@ -314,7 +397,18 @@ class Migration(SimpleContainer): :vartype time: float """ - def __init__(self, left, right, node, source, dest, time, metadata=b"", id_=None): + def __init__( + self, + left, + right, + node, + source, + dest, + time, + encoded_metadata=b"", + metadata_decoder=lambda metadata: metadata, + id_=None, + ): self.id = id_ self.left = left self.right = right @@ -322,7 +416,8 @@ def __init__(self, left, right, node, source, dest, time, metadata=b"", id_=None self.source = source self.dest = dest self.time = time - self.metadata = metadata + self._encoded_metadata = encoded_metadata + self._metadata_decoder = metadata_decoder def __repr__(self): return ( @@ -340,7 +435,7 @@ def __repr__(self): ) -class Population(SimpleContainer): +class Population(SimpleContainerWithMetadata): """ A :ref:`population ` in a tree sequence. 
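A rough sketch of the lazy decoding behaviour added in ``SimpleContainerWithMetadata`` above (an editor's illustration rather than part of the patch): each container stores the raw bytes plus a decoder, and ``metadata`` is only decoded on first access and then cached on the instance.

    import tskit

    calls = 0

    def decoder(raw):
        global calls
        calls += 1
        return raw.decode()

    node = tskit.Node(
        id_=0, flags=1, time=0.5, encoded_metadata=b'{"x": 1}', metadata_decoder=decoder
    )
    assert calls == 0                    # nothing decoded yet
    assert node.metadata == '{"x": 1}'
    assert node.metadata == '{"x": 1}'
    assert calls == 1                    # decoded once, then served from the cache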
@@ -354,9 +449,12 @@ class Population(SimpleContainer): :vartype metadata: bytes """ - def __init__(self, id_, metadata=""): + def __init__( + self, id_, encoded_metadata=b"", metadata_decoder=lambda metadata: metadata, + ): self.id = id_ - self.metadata = metadata + self._encoded_metadata = encoded_metadata + self._metadata_decoder = metadata_decoder class Variant(SimpleContainer): @@ -2350,6 +2448,13 @@ class TreeSequence: def __init__(self, ll_tree_sequence): self._ll_tree_sequence = ll_tree_sequence + ll_metadata_schemas = self._ll_tree_sequence.get_metadata_schemas() + self._metadata_schemas = MetadataSchemas( + *[ + metadata.MetadataSchema.from_str(getattr(ll_metadata_schemas, name)) + for name in MetadataSchemas._fields + ] + ) # Implement the pickle protocol for TreeSequence def __getstate__(self): @@ -2619,6 +2724,10 @@ def num_samples(self): """ return self._ll_tree_sequence.get_num_samples() + @property + def metadata_schemas(self): + return self._metadata_schemas + @property def sample_size(self): # Deprecated alias for num_samples @@ -2889,9 +2998,10 @@ def edge_diffs(self): :rtype: :class:`collections.abc.Iterable` """ iterator = _tskit.TreeDiffIterator(self._ll_tree_sequence) + metadata_decoder = self.metadata_schemas.edge.decode_row for interval, edge_tuples_out, edge_tuples_in in iterator: - edges_out = [Edge(*e) for e in edge_tuples_out] - edges_in = [Edge(*e) for e in edge_tuples_in] + edges_out = [Edge(*(e + (metadata_decoder,))) for e in edge_tuples_out] + edges_in = [Edge(*(e + (metadata_decoder,))) for e in edge_tuples_in] yield interval, edges_out, edges_in def sites(self): @@ -3291,7 +3401,12 @@ def individual(self, id_): """ flags, location, metadata, nodes = self._ll_tree_sequence.get_individual(id_) return Individual( - id_=id_, flags=flags, location=location, metadata=metadata, nodes=nodes + id_=id_, + flags=flags, + location=location, + encoded_metadata=metadata, + metadata_decoder=self.metadata_schemas.node.decode_row, + nodes=nodes, ) def node(self, id_): @@ -3314,7 +3429,8 @@ def node(self, id_): time=time, population=population, individual=individual, - metadata=metadata, + encoded_metadata=metadata, + metadata_decoder=self.metadata_schemas.node.decode_row, ) def edge(self, id_): @@ -3331,7 +3447,8 @@ def edge(self, id_): right=right, parent=parent, child=child, - metadata=metadata, + encoded_metadata=metadata, + metadata_decoder=self.metadata_schemas.edge.decode_row, ) def migration(self, id_): @@ -3358,7 +3475,8 @@ def migration(self, id_): source=source, dest=dest, time=time, - metadata=metadata, + encoded_metadata=metadata, + metadata_decoder=self.metadata_schemas.migration.decode_row, ) def mutation(self, id_): @@ -3381,7 +3499,8 @@ def mutation(self, id_): node=node, derived_state=derived_state, parent=parent, - metadata=metadata, + encoded_metadata=metadata, + metadata_decoder=self.metadata_schemas.mutation.decode_row, ) def site(self, id_): @@ -3399,7 +3518,8 @@ def site(self, id_): position=pos, ancestral_state=ancestral_state, mutations=mutations, - metadata=metadata, + encoded_metadata=metadata, + metadata_decoder=self.metadata_schemas.site.decode_row, ) def population(self, id_): @@ -3410,7 +3530,11 @@ def population(self, id_): :rtype: :class:`Population` """ (metadata,) = self._ll_tree_sequence.get_population(id_) - return Population(id_=id_, metadata=metadata) + return Population( + id_=id_, + encoded_metadata=metadata, + metadata_decoder=self.metadata_schemas.population.decode_row, + ) def provenance(self, id_): timestamp, record = 
self._ll_tree_sequence.get_provenance(id_)
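Taken together, the user-facing workflow this patch enables looks roughly like the following (an editor's sketch assuming msprime is available and tskit is built from this branch; the schema content is an arbitrary example):

    import msprime
    import tskit.metadata as metadata

    ts = msprime.simulate(5, random_seed=1)
    tables = ts.dump_tables()
    tables.populations.metadata_schema = metadata.MetadataSchema(
        encoding="json",
        schema={"type": "object", "properties": {"name": {"type": "string"}}},
    )
    tables.populations.clear()
    tables.populations.add_row(metadata={"name": "pop0"})
    new_ts = tables.tree_sequence()
    print(new_ts.metadata_schemas.population.to_str())
    print(new_ts.population(0).metadata)   # decoded back to {"name": "pop0"}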