From 43921cc4ff8696f5fd19f7f04b48222fd4f692d0 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 25 Nov 2021 16:47:34 +0000 Subject: [PATCH] Basic undocumented support for reference sequences. Closes #146 --- c/tests/test_file_format.c | 56 ++ c/tests/test_tables.c | 249 +++++---- c/tests/test_trees.c | 27 + c/tests/testlib.c | 7 + c/tskit/core.h | 2 +- c/tskit/tables.c | 172 +++--- c/tskit/tables.h | 25 +- c/tskit/trees.c | 6 + c/tskit/trees.h | 2 + docs/python-api.md | 13 +- python/_tskitmodule.c | 383 +++++++++++++- python/lwt_interface/dict_encoding_testlib.py | 179 +++++-- python/lwt_interface/tskit_lwt_interface.h | 500 ++++++++++++------ python/requirements/development.txt | 1 + python/tests/conftest.py | 11 +- python/tests/test_file_format.py | 72 ++- python/tests/test_lowlevel.py | 186 +++++++ python/tests/test_reference_sequence.py | 243 +++++++++ python/tests/test_tables.py | 4 +- python/tskit/metadata.py | 70 ++- python/tskit/tables.py | 174 ++++-- python/tskit/trees.py | 11 + 22 files changed, 1931 insertions(+), 462 deletions(-) create mode 100644 python/tests/test_reference_sequence.py diff --git a/c/tests/test_file_format.c b/c/tests/test_file_format.c index cd71782ded..e10cff3821 100644 --- a/c/tests/test_file_format.c +++ b/c/tests/test_file_format.c @@ -636,6 +636,47 @@ test_malformed_indexes(void) free(ts); } +static void +test_missing_reference_sequence(void) +{ + int ret; + tsk_treeseq_t *ts = caterpillar_tree(5, 3, 3); + tsk_table_collection_t t1, t2; + const char *cols[] = { "reference_sequence/data", "reference_sequence/url", + "reference_sequence/metadata_schema", "reference_sequence/metadata" }; + + CU_ASSERT_TRUE(tsk_treeseq_has_reference_sequence(ts)); + + ret = tsk_treeseq_copy_tables(ts, &t1, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + copy_store_drop_columns(ts, 1, cols, _tmp_file_name); + ret = tsk_table_collection_load(&t2, _tmp_file_name, 0); + CU_ASSERT_TRUE(tsk_table_collection_has_reference_sequence(&t2)); + 
tsk_table_collection_free(&t2); + + copy_store_drop_columns(ts, 2, cols, _tmp_file_name); + ret = tsk_table_collection_load(&t2, _tmp_file_name, 0); + CU_ASSERT_TRUE(tsk_table_collection_has_reference_sequence(&t2)); + tsk_table_collection_free(&t2); + + copy_store_drop_columns(ts, 3, cols, _tmp_file_name); + ret = tsk_table_collection_load(&t2, _tmp_file_name, 0); + CU_ASSERT_TRUE(tsk_table_collection_has_reference_sequence(&t2)); + tsk_table_collection_free(&t2); + + /* Dropping all the columns gives us a NULL reference_sequence, though */ + copy_store_drop_columns(ts, 4, cols, _tmp_file_name); + ret = tsk_table_collection_load(&t2, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FALSE(tsk_table_collection_has_reference_sequence(&t2)); + tsk_table_collection_free(&t2); + + tsk_table_collection_free(&t1); + tsk_treeseq_free(ts); + free(ts); +} + static void test_bad_column_types(void) { @@ -699,6 +740,18 @@ test_bad_column_types(void) CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_COLUMN_TYPE); tsk_table_collection_free(&tables); + cols[0] = "reference_sequence/metadata"; + copy_store_drop_columns(ts, 1, cols, _tmp_file_name); + ret = kastore_open(&store, _tmp_file_name, "a", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = kastore_puts(&store, cols[0], NULL, 0, KAS_FLOAT32, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = kastore_close(&store); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_load(&tables, _tmp_file_name, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_COLUMN_TYPE); + tsk_table_collection_free(&tables); + free(col_memory); tsk_treeseq_free(ts); free(ts); @@ -760,6 +813,8 @@ test_metadata_schemas_optional(void) const char *cols[] = { "metadata", "metadata_schema", + "reference_sequence/metadata", + "reference_sequence/metadata_schema", "individuals/metadata_schema", "populations/metadata_schema", "nodes/metadata_schema", @@ -1290,6 +1345,7 @@ main(int argc, char **argv) { "test_format_data_load_errors", test_format_data_load_errors }, { 
"test_missing_indexes", test_missing_indexes }, { "test_malformed_indexes", test_malformed_indexes }, + { "test_missing_reference_sequence", test_missing_reference_sequence }, { "test_bad_column_types", test_bad_column_types }, { "test_missing_required_columns", test_missing_required_columns }, { "test_missing_optional_column_pairs", test_missing_optional_column_pairs }, diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index a41af008d4..fe66a3787f 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -359,12 +359,76 @@ test_table_collection_simplify_errors(void) tsk_table_collection_free(&tables); } +static void +test_reference_sequence_state_machine(void) +{ + + tsk_reference_sequence_t r1; + + tsk_reference_sequence_init(&r1, 0); + CU_ASSERT_EQUAL(r1.data, NULL); + CU_ASSERT_EQUAL(r1.url, NULL); + CU_ASSERT_EQUAL(r1.metadata, NULL); + CU_ASSERT_EQUAL(r1.metadata_schema, NULL); + CU_ASSERT_TRUE(tsk_reference_sequence_is_null(&r1)); + + CU_ASSERT_EQUAL(tsk_reference_sequence_set_data(&r1, "x", 1), 0); + CU_ASSERT_FALSE(tsk_reference_sequence_is_null(&r1)); + /* Setting the value back to NULL makes the reference whole object NULL */ + CU_ASSERT_EQUAL(tsk_reference_sequence_set_data(&r1, NULL, 0), 0); + CU_ASSERT_TRUE(tsk_reference_sequence_is_null(&r1)); + tsk_reference_sequence_free(&r1); + CU_ASSERT_TRUE(tsk_reference_sequence_is_null(&r1)); + + /* Any empty string is the same thing. 
*/ + tsk_reference_sequence_init(&r1, 0); + CU_ASSERT_EQUAL(tsk_reference_sequence_set_data(&r1, "", 0), 0); + CU_ASSERT_TRUE(tsk_reference_sequence_is_null(&r1)); + tsk_reference_sequence_free(&r1); + + tsk_reference_sequence_init(&r1, 0); + CU_ASSERT_EQUAL(tsk_reference_sequence_set_url(&r1, "x", 1), 0); + CU_ASSERT_FALSE(tsk_reference_sequence_is_null(&r1)); + tsk_reference_sequence_free(&r1); + + tsk_reference_sequence_init(&r1, 0); + CU_ASSERT_EQUAL(tsk_reference_sequence_set_metadata(&r1, "x", 1), 0); + CU_ASSERT_FALSE(tsk_reference_sequence_is_null(&r1)); + tsk_reference_sequence_free(&r1); + + tsk_reference_sequence_init(&r1, 0); + CU_ASSERT_EQUAL(tsk_reference_sequence_set_metadata_schema(&r1, "x", 1), 0); + CU_ASSERT_FALSE(tsk_reference_sequence_is_null(&r1)); + tsk_reference_sequence_free(&r1); + + tsk_reference_sequence_init(&r1, 0); + CU_ASSERT_EQUAL(tsk_reference_sequence_set_metadata(&r1, "x", 1), 0); + CU_ASSERT_FALSE(tsk_reference_sequence_is_null(&r1)); + CU_ASSERT_EQUAL(tsk_reference_sequence_set_metadata_schema(&r1, "x", 1), 0); + CU_ASSERT_FALSE(tsk_reference_sequence_is_null(&r1)); + CU_ASSERT_EQUAL(tsk_reference_sequence_set_url(&r1, "x", 1), 0); + CU_ASSERT_FALSE(tsk_reference_sequence_is_null(&r1)); + CU_ASSERT_EQUAL(tsk_reference_sequence_set_data(&r1, "x", 1), 0); + CU_ASSERT_FALSE(tsk_reference_sequence_is_null(&r1)); + + CU_ASSERT_EQUAL(tsk_reference_sequence_set_metadata(&r1, "", 0), 0); + CU_ASSERT_FALSE(tsk_reference_sequence_is_null(&r1)); + CU_ASSERT_EQUAL(tsk_reference_sequence_set_metadata_schema(&r1, "", 0), 0); + CU_ASSERT_FALSE(tsk_reference_sequence_is_null(&r1)); + CU_ASSERT_EQUAL(tsk_reference_sequence_set_url(&r1, "", 0), 0); + CU_ASSERT_FALSE(tsk_reference_sequence_is_null(&r1)); + CU_ASSERT_EQUAL(tsk_reference_sequence_set_data(&r1, "", 0), 0); + CU_ASSERT_TRUE(tsk_reference_sequence_is_null(&r1)); + + tsk_reference_sequence_free(&r1); +} + static void test_reference_sequence(void) { int ret; - tsk_reference_sequence_t 
*r1 = NULL; - tsk_reference_sequence_t *r2 = NULL; + tsk_reference_sequence_t r1; + tsk_reference_sequence_t r2; char example_data[100] = "An example string with unicode 🎄🌳🌴🌲🎋"; tsk_size_t example_data_length = (tsk_size_t) strlen(example_data); @@ -375,99 +439,94 @@ test_reference_sequence(void) char example_schema[100] = "An example schema with unicode 🎄🌳🌴🌲🎋"; tsk_size_t example_schema_length = (tsk_size_t) strlen(example_schema); - // Test equality - CU_ASSERT_TRUE(tsk_reference_sequence_equals(r1, r2, 0)); + tsk_reference_sequence_init(&r1, 0); + tsk_reference_sequence_init(&r2, 0); - r1 = tsk_malloc(sizeof(tsk_reference_sequence_t)); - CU_ASSERT_NOT_EQUAL_FATAL(r1, NULL); - tsk_reference_sequence_init(r1); + /* NULL sequences are initially equal */ + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, 0)); - ret = tsk_reference_sequence_set_data(r1, example_data, example_data_length); + ret = tsk_reference_sequence_set_data(&r1, example_data, example_data_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_FALSE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_FALSE(tsk_reference_sequence_equals(&r1, &r2, 0)); - r2 = tsk_malloc(sizeof(tsk_reference_sequence_t)); - CU_ASSERT_NOT_EQUAL_FATAL(r2, NULL); - tsk_reference_sequence_init(r2); + ret = tsk_reference_sequence_set_data(&r1, "", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, 0)); - ret = tsk_reference_sequence_set_data(r1, "", 0); + ret = tsk_reference_sequence_set_data(&r2, "", 0); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_TRUE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, 0)); - ret = tsk_reference_sequence_set_data(r1, example_data, example_data_length); + ret = tsk_reference_sequence_set_data(&r1, example_data, example_data_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_FALSE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_FALSE(tsk_reference_sequence_equals(&r1, &r2, 0)); - ret = 
tsk_reference_sequence_set_data(r2, example_data, example_data_length); + ret = tsk_reference_sequence_set_data(&r2, example_data, example_data_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_TRUE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, 0)); - ret = tsk_reference_sequence_set_url(r1, example_url, example_url_length); + ret = tsk_reference_sequence_set_url(&r1, example_url, example_url_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_FALSE(tsk_reference_sequence_equals(r1, r2, 0)); - ret = tsk_reference_sequence_set_url(r2, example_url, example_url_length); + CU_ASSERT_FALSE(tsk_reference_sequence_equals(&r1, &r2, 0)); + ret = tsk_reference_sequence_set_url(&r2, example_url, example_url_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_TRUE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, 0)); ret = tsk_reference_sequence_set_metadata( - r1, example_metadata, example_metadata_length); + &r1, example_metadata, example_metadata_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_FALSE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_FALSE(tsk_reference_sequence_equals(&r1, &r2, 0)); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, TSK_CMP_IGNORE_METADATA)); ret = tsk_reference_sequence_set_metadata( - r2, example_metadata, example_metadata_length); + &r2, example_metadata, example_metadata_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_TRUE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, 0)); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, TSK_CMP_IGNORE_METADATA)); ret = tsk_reference_sequence_set_metadata_schema( - r1, example_schema, example_schema_length); + &r1, example_schema, example_schema_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_FALSE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_FALSE(tsk_reference_sequence_equals(&r1, &r2, 0)); + 
CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, TSK_CMP_IGNORE_METADATA)); ret = tsk_reference_sequence_set_metadata_schema( - r2, example_schema, example_schema_length); + &r2, example_schema, example_schema_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_TRUE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, 0)); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, TSK_CMP_IGNORE_METADATA)); // Test copy - tsk_reference_sequence_free(r1); - tsk_safe_free(r1); - r1 = NULL; - tsk_reference_sequence_free(r2); - tsk_safe_free(r2); - r2 = NULL; + tsk_reference_sequence_free(&r1); + tsk_reference_sequence_free(&r2); - r1 = tsk_malloc(sizeof(tsk_reference_sequence_t)); - CU_ASSERT_NOT_EQUAL_FATAL(r1, NULL); - tsk_reference_sequence_init(r1); - ret = tsk_reference_sequence_set_data(r1, example_data, example_data_length); + tsk_reference_sequence_init(&r1, 0); + ret = tsk_reference_sequence_set_data(&r1, example_data, example_data_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_reference_sequence_copy(r1, &r2, 0); + ret = tsk_reference_sequence_copy(&r1, &r2, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_TRUE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, 0)); - ret = tsk_reference_sequence_set_url(r1, example_url, example_url_length); + ret = tsk_reference_sequence_set_url(&r1, example_url, example_url_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_reference_sequence_copy(r1, &r2, 0); + ret = tsk_reference_sequence_copy(&r1, &r2, TSK_NO_INIT); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_TRUE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, 0)); ret = tsk_reference_sequence_set_metadata( - r1, example_metadata, example_metadata_length); + &r1, example_metadata, example_metadata_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_reference_sequence_copy(r1, &r2, 0); + ret = 
tsk_reference_sequence_copy(&r1, &r2, TSK_NO_INIT); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_TRUE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, 0)); ret = tsk_reference_sequence_set_metadata_schema( - r1, example_schema, example_schema_length); + &r1, example_schema, example_schema_length); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_reference_sequence_copy(r1, &r2, 0); + ret = tsk_reference_sequence_copy(&r1, &r2, TSK_NO_INIT); CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_TRUE(tsk_reference_sequence_equals(r1, r2, 0)); + CU_ASSERT_TRUE(tsk_reference_sequence_equals(&r1, &r2, 0)); - tsk_reference_sequence_free(r1); - tsk_safe_free(r1); - tsk_reference_sequence_free(r2); - tsk_safe_free(r2); + tsk_reference_sequence_free(&r1); + tsk_reference_sequence_free(&r2); } static void @@ -492,48 +551,40 @@ test_table_collection_reference_sequence(void) CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tc1, &tc2, 0)); - tc1.reference_sequence = tsk_malloc(sizeof(tsk_reference_sequence_t)); - CU_ASSERT_NOT_EQUAL_FATAL(tc1.reference_sequence, NULL); - tsk_reference_sequence_init(tc1.reference_sequence); - ret = tsk_reference_sequence_set_data( - tc1.reference_sequence, example_data, example_data_length); + &tc1.reference_sequence, example_data, example_data_length); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_FALSE(tsk_table_collection_equals(&tc1, &tc2, 0)); - tc2.reference_sequence = tsk_malloc(sizeof(tsk_reference_sequence_t)); - CU_ASSERT_NOT_EQUAL_FATAL(tc2.reference_sequence, NULL); - tsk_reference_sequence_init(tc2.reference_sequence); - ret = tsk_reference_sequence_set_data( - tc2.reference_sequence, example_data, example_data_length); + &tc2.reference_sequence, example_data, example_data_length); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tc1, &tc2, 0)); ret = tsk_reference_sequence_set_url( - tc1.reference_sequence, example_url, example_url_length); + 
&tc1.reference_sequence, example_url, example_url_length); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_FALSE(tsk_table_collection_equals(&tc1, &tc2, 0)); ret = tsk_reference_sequence_set_url( - tc2.reference_sequence, example_url, example_url_length); + &tc2.reference_sequence, example_url, example_url_length); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tc1, &tc2, 0)); ret = tsk_reference_sequence_set_metadata( - tc1.reference_sequence, example_metadata, example_metadata_length); + &tc1.reference_sequence, example_metadata, example_metadata_length); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_FALSE(tsk_table_collection_equals(&tc1, &tc2, 0)); ret = tsk_reference_sequence_set_metadata( - tc2.reference_sequence, example_metadata, example_metadata_length); + &tc2.reference_sequence, example_metadata, example_metadata_length); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tc1, &tc2, 0)); ret = tsk_reference_sequence_set_metadata_schema( - tc1.reference_sequence, example_schema, example_schema_length); + &tc1.reference_sequence, example_schema, example_schema_length); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_FALSE(tsk_table_collection_equals(&tc1, &tc2, 0)); ret = tsk_reference_sequence_set_metadata_schema( - tc2.reference_sequence, example_schema, example_schema_length); + &tc2.reference_sequence, example_schema, example_schema_length); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tc1, &tc2, 0)); @@ -543,60 +594,51 @@ test_table_collection_reference_sequence(void) ret = tsk_table_collection_init(&tc1, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); - tc1.reference_sequence = tsk_malloc(sizeof(tsk_reference_sequence_t)); - CU_ASSERT_NOT_EQUAL_FATAL(tc1.reference_sequence, NULL); - tsk_reference_sequence_init(tc1.reference_sequence); - ret = tsk_reference_sequence_set_data( - tc1.reference_sequence, example_data, example_data_length); + &tc1.reference_sequence, example_data, example_data_length); 
CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_table_collection_copy(&tc1, &tc2, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tc1, &tc2, 0)); ret = tsk_reference_sequence_set_url( - tc1.reference_sequence, example_url, example_url_length); + &tc1.reference_sequence, example_url, example_url_length); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_table_collection_copy(&tc1, &tc2, TSK_NO_INIT); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tc1, &tc2, 0)); ret = tsk_reference_sequence_set_metadata( - tc1.reference_sequence, example_metadata, example_metadata_length); + &tc1.reference_sequence, example_metadata, example_metadata_length); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_table_collection_copy(&tc1, &tc2, TSK_NO_INIT); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tc1, &tc2, 0)); ret = tsk_reference_sequence_set_metadata_schema( - tc1.reference_sequence, example_schema, example_schema_length); + &tc1.reference_sequence, example_schema, example_schema_length); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_table_collection_copy(&tc1, &tc2, TSK_NO_INIT); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_TRUE(tsk_table_collection_equals(&tc1, &tc2, 0)); - - // Test dump and load tsk_table_collection_free(&tc1); tsk_table_collection_free(&tc2); + + // Test dump and load ret = tsk_table_collection_init(&tc1, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); tc1.sequence_length = 1.0; - - tc1.reference_sequence = tsk_malloc(sizeof(tsk_reference_sequence_t)); - CU_ASSERT_NOT_EQUAL_FATAL(tc1.reference_sequence, NULL); - tsk_reference_sequence_init(tc1.reference_sequence); - ret = tsk_reference_sequence_set_data( - tc1.reference_sequence, example_data, example_data_length); + &tc1.reference_sequence, example_data, example_data_length); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_reference_sequence_set_url( - tc1.reference_sequence, example_url, example_url_length); + &tc1.reference_sequence, example_url, example_url_length); 
CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_reference_sequence_set_metadata( - tc1.reference_sequence, example_metadata, example_metadata_length); + &tc1.reference_sequence, example_metadata, example_metadata_length); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_reference_sequence_set_metadata_schema( - tc1.reference_sequence, example_schema, example_schema_length); + &tc1.reference_sequence, example_schema, example_schema_length); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_table_collection_dump(&tc1, _tmp_file_name, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -607,6 +649,29 @@ test_table_collection_reference_sequence(void) tsk_table_collection_free(&tc2); } +static void +test_table_collection_has_reference_sequence(void) +{ + int ret; + tsk_table_collection_t tc; + + ret = tsk_table_collection_init(&tc, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tc.sequence_length = 1.0; + + CU_ASSERT_FALSE(tsk_table_collection_has_reference_sequence(&tc)); + ret = tsk_reference_sequence_set_data(&tc.reference_sequence, "A", 1); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_table_collection_has_reference_sequence(&tc)); + /* Goes back to NULL by setting an empty string. See + * test_reference_sequence_state_machine for detailed tests. 
*/ + ret = tsk_reference_sequence_set_data(&tc.reference_sequence, "", 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FALSE(tsk_table_collection_has_reference_sequence(&tc)); + + tsk_table_collection_free(&tc); +} + static void test_table_collection_metadata(void) { @@ -9079,7 +9144,11 @@ main(int argc, char **argv) { "test_table_collection_time_units", test_table_collection_time_units }, { "test_table_collection_reference_sequence", test_table_collection_reference_sequence }, + { "test_table_collection_has_reference_sequence", + test_table_collection_has_reference_sequence }, { "test_table_collection_metadata", test_table_collection_metadata }, + { "test_reference_sequence_state_machine", + test_reference_sequence_state_machine }, { "test_reference_sequence", test_reference_sequence }, { "test_simplify_tables_drops_indexes", test_simplify_tables_drops_indexes }, diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 76b346be6a..5fe63adb26 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -6949,6 +6949,32 @@ test_time_uncalibrated(void) tsk_table_collection_free(&tables); } +static void +test_reference_sequence(void) +{ + int ret; + tsk_table_collection_t tables; + tsk_treeseq_t ts; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; + + ret = tsk_treeseq_init(&ts, &tables, TSK_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FALSE(tsk_treeseq_has_reference_sequence(&ts)); + tsk_treeseq_free(&ts); + + ret = tsk_reference_sequence_set_data(&tables.reference_sequence, "abc", 3); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_init(&ts, &tables, TSK_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_TRUE(tsk_treeseq_has_reference_sequence(&ts)); + tsk_treeseq_free(&ts); + + tsk_table_collection_free(&tables); +} + int main(int argc, char **argv) { @@ -7117,6 +7143,7 @@ main(int argc, char **argv) { "test_zero_edges", test_zero_edges }, { 
"test_tree_sequence_metadata", test_tree_sequence_metadata }, { "test_time_uncalibrated", test_time_uncalibrated }, + { "test_reference_sequence", test_reference_sequence }, { NULL, NULL }, }; diff --git a/c/tests/testlib.c b/c/tests/testlib.c index 7e4947501a..860116c9d3 100644 --- a/c/tests/testlib.c +++ b/c/tests/testlib.c @@ -823,6 +823,13 @@ caterpillar_tree(tsk_size_t n, tsk_size_t num_sites, tsk_size_t num_mutations) tsk_table_collection_set_metadata(&tables, ts_metadata, strlen(ts_metadata)); tsk_table_collection_set_metadata_schema( &tables, ts_metadata_schema, strlen(ts_metadata_schema)); + tsk_reference_sequence_set_metadata_schema( + &tables.reference_sequence, ts_metadata_schema, strlen(ts_metadata_schema)); + tsk_reference_sequence_set_metadata( + &tables.reference_sequence, ts_metadata, strlen(ts_metadata)); + tsk_reference_sequence_set_data(&tables.reference_sequence, "A", 1); + tsk_reference_sequence_set_url(&tables.reference_sequence, "B", 1); + tsk_population_table_set_metadata_schema( &tables.populations, metadata_schema, strlen(metadata_schema)); tsk_individual_table_set_metadata_schema( diff --git a/c/tskit/core.h b/c/tskit/core.h index 533e8e8654..a0f4927a34 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -178,7 +178,7 @@ to the API or ABI are introduced, i.e., internal refactors of bugfixes. #define TSK_FILE_FORMAT_NAME "tskit.trees" #define TSK_FILE_FORMAT_NAME_LENGTH 11 #define TSK_FILE_FORMAT_VERSION_MAJOR 12 -#define TSK_FILE_FORMAT_VERSION_MINOR 6 +#define TSK_FILE_FORMAT_VERSION_MINOR 7 /** @defgroup GENERAL_ERROR_GROUP General errors. 
diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 7006f59659..391d22ec1d 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -631,7 +631,8 @@ write_metadata_schema_header( *************************/ int -tsk_reference_sequence_init(tsk_reference_sequence_t *self) +tsk_reference_sequence_init( + tsk_reference_sequence_t *self, tsk_flags_t TSK_UNUSED(options)) { tsk_memset(self, 0, sizeof(*self)); return 0; @@ -647,70 +648,78 @@ tsk_reference_sequence_free(tsk_reference_sequence_t *self) return 0; } +bool +tsk_reference_sequence_is_null(const tsk_reference_sequence_t *self) +{ + return self->data_length == 0 && self->url_length == 0 && self->metadata_length == 0 + && self->metadata_schema_length == 0; +} + bool tsk_reference_sequence_equals(const tsk_reference_sequence_t *self, const tsk_reference_sequence_t *other, tsk_flags_t options) { - if (self == NULL && other == NULL) { + int ret; + bool self_null = tsk_reference_sequence_is_null(self); + bool other_null = tsk_reference_sequence_is_null(other); + + if (self_null && other_null) { return true; } /* If one or the other is NULL they are not equal */ - if ((self == NULL) != (other == NULL)) { + if (self_null != other_null) { return false; } - return ( - (self->data_length == other->data_length && self->url_length == other->url_length - && ((options & TSK_CMP_IGNORE_TS_METADATA) - || self->metadata_length == other->metadata_length) - && ((options & TSK_CMP_IGNORE_TS_METADATA) - || self->metadata_schema_length == other->metadata_schema_length) - && tsk_memcmp(self->data, other->data, self->data_length * sizeof(char)) == 0 - && tsk_memcmp(self->url, other->url, self->url_length * sizeof(char)) == 0 - && ((options & TSK_CMP_IGNORE_TS_METADATA) - || tsk_memcmp(self->metadata, other->metadata, - self->metadata_length * sizeof(char)) - == 0) - && ((options & TSK_CMP_IGNORE_TS_METADATA) - || tsk_memcmp(self->metadata_schema, other->metadata_schema, - self->metadata_schema_length * sizeof(char)) - == 0))); + ret = 
self->data_length == other->data_length + && self->url_length == other->url_length + && tsk_memcmp(self->data, other->data, self->data_length * sizeof(char)) == 0 + && tsk_memcmp(self->url, other->url, self->url_length * sizeof(char)) == 0; + + if (!(options & TSK_CMP_IGNORE_METADATA)) { + ret = ret && self->metadata_length == other->metadata_length + && self->metadata_schema_length == other->metadata_schema_length + && tsk_memcmp(self->metadata, other->metadata, + self->metadata_length * sizeof(char)) + == 0 + && tsk_memcmp(self->metadata_schema, other->metadata_schema, + self->metadata_schema_length * sizeof(char)) + == 0; + } + return ret; } int tsk_reference_sequence_copy(const tsk_reference_sequence_t *self, - tsk_reference_sequence_t **dest, tsk_flags_t TSK_UNUSED(options)) + tsk_reference_sequence_t *dest, tsk_flags_t options) { int ret = 0; - if (*dest != NULL) { - tsk_reference_sequence_free(*dest); - tsk_safe_free(*dest); - *dest = NULL; - } - - if (self != NULL) { - *dest = tsk_malloc(sizeof(tsk_reference_sequence_t)); - if (*dest == NULL) { - ret = TSK_ERR_NO_MEMORY; + if (!(options & TSK_NO_INIT)) { + ret = tsk_reference_sequence_init(dest, 0); + if (ret != 0) { goto out; } - tsk_reference_sequence_init(*dest); + } - ret = tsk_reference_sequence_set_data(*dest, self->data, self->data_length); + if (tsk_reference_sequence_is_null(self)) { + /* This is a simple way to get any input into the NULL state */ + tsk_reference_sequence_free(dest); + } else { + ret = tsk_reference_sequence_set_data(dest, self->data, self->data_length); if (ret != 0) { goto out; } - ret = tsk_reference_sequence_set_url(*dest, self->url, self->url_length); + ret = tsk_reference_sequence_set_url(dest, self->url, self->url_length); if (ret != 0) { goto out; } ret = tsk_reference_sequence_set_metadata( - *dest, self->metadata, self->metadata_length); + dest, self->metadata, self->metadata_length); if (ret != 0) { goto out; } ret = tsk_reference_sequence_set_metadata_schema( - *dest, 
self->metadata_schema, self->metadata_schema_length); + dest, self->metadata_schema, self->metadata_schema_length); if (ret != 0) { goto out; } @@ -720,37 +729,33 @@ tsk_reference_sequence_copy(const tsk_reference_sequence_t *self, } int -tsk_reference_sequence_set_data(tsk_reference_sequence_t *self, - const char *reference_sequence, tsk_size_t reference_sequence_length) +tsk_reference_sequence_set_data( + tsk_reference_sequence_t *self, const char *data, tsk_size_t data_length) { - return replace_string( - &self->data, &self->data_length, reference_sequence, reference_sequence_length); + return replace_string(&self->data, &self->data_length, data, data_length); } int -tsk_reference_sequence_set_url(tsk_reference_sequence_t *self, - const char *reference_sequence_url, tsk_size_t reference_sequence_url_length) +tsk_reference_sequence_set_url( + tsk_reference_sequence_t *self, const char *url, tsk_size_t url_length) { - return replace_string(&self->url, &self->url_length, reference_sequence_url, - reference_sequence_url_length); + return replace_string(&self->url, &self->url_length, url, url_length); } int -tsk_reference_sequence_set_metadata(tsk_reference_sequence_t *self, - const char *reference_sequence_metadata, - tsk_size_t reference_sequence_metadata_length) +tsk_reference_sequence_set_metadata( + tsk_reference_sequence_t *self, const char *metadata, tsk_size_t metadata_length) { - return replace_string(&self->metadata, &self->metadata_length, - reference_sequence_metadata, reference_sequence_metadata_length); + return replace_string( + &self->metadata, &self->metadata_length, metadata, metadata_length); } int tsk_reference_sequence_set_metadata_schema(tsk_reference_sequence_t *self, - const char *reference_sequence_metadata_schema, - tsk_size_t reference_sequence_metadata_schema_length) + const char *metadata_schema, tsk_size_t metadata_schema_length) { return replace_string(&self->metadata_schema, &self->metadata_schema_length, - 
reference_sequence_metadata_schema, reference_sequence_metadata_schema_length); + metadata_schema, metadata_schema_length); } /************************* @@ -9939,6 +9944,10 @@ tsk_table_collection_init(tsk_table_collection_t *self, tsk_flags_t options) if (ret != 0) { goto out; } + ret = tsk_reference_sequence_init(&self->reference_sequence, 0); + if (ret != 0) { + goto out; + } out: return ret; } @@ -9954,10 +9963,7 @@ tsk_table_collection_free(tsk_table_collection_t *self) tsk_mutation_table_free(&self->mutations); tsk_population_table_free(&self->populations); tsk_provenance_table_free(&self->provenances); - if (self->reference_sequence != NULL) { - tsk_reference_sequence_free(self->reference_sequence); - } - tsk_safe_free(self->reference_sequence); + tsk_reference_sequence_free(&self->reference_sequence); tsk_safe_free(self->indexes.edge_insertion_order); tsk_safe_free(self->indexes.edge_removal_order); tsk_safe_free(self->file_uuid); @@ -10012,7 +10018,7 @@ tsk_table_collection_equals(const tsk_table_collection_t *self, } ret = ret && tsk_reference_sequence_equals( - self->reference_sequence, other->reference_sequence, options); + &self->reference_sequence, &other->reference_sequence, options); return ret; } @@ -10071,6 +10077,12 @@ tsk_table_collection_has_index( && self->indexes.num_edges == self->edges.num_rows; } +bool +tsk_table_collection_has_reference_sequence(const tsk_table_collection_t *self) +{ + return !tsk_reference_sequence_is_null(&self->reference_sequence); +} + int tsk_table_collection_drop_index( tsk_table_collection_t *self, tsk_flags_t TSK_UNUSED(options)) @@ -10218,7 +10230,7 @@ tsk_table_collection_copy(const tsk_table_collection_t *self, goto out; } ret = tsk_reference_sequence_copy( - self->reference_sequence, &dest->reference_sequence, options); + &self->reference_sequence, &dest->reference_sequence, options); if (ret != 0) { goto out; } @@ -10235,11 +10247,13 @@ tsk_table_collection_read_format_data(tsk_table_collection_t *self, 
kastore_t *s uint32_t *version; int8_t *format_name, *uuid; double *L; - char *time_units = NULL; char *metadata = NULL; char *metadata_schema = NULL; size_t time_units_length, metadata_length, metadata_schema_length; + /* TODO we could simplify this function quite a bit if we use the + * read_table_properties infrastructure. We would need to add the + * ability to have non-optional columns to that though. */ ret = kastore_gets_int8(store, "format/name", &format_name, &len); if (ret != 0) { @@ -10447,7 +10461,6 @@ tsk_table_collection_load_reference_sequence( char *metadata = NULL; char *metadata_schema = NULL; tsk_size_t data_length = 0, url_length, metadata_length, metadata_schema_length; - bool reference_sequence_loaded; read_table_property_t properties[] = { { "reference_sequence/data", (void **) &data, &data_length, KAS_UINT8, @@ -10465,36 +10478,22 @@ tsk_table_collection_load_reference_sequence( if (ret != 0) { goto out; } - reference_sequence_loaded - = data != NULL || url != NULL || metadata != NULL || metadata_schema != NULL; - if (self->reference_sequence != NULL) { - tsk_reference_sequence_free(self->reference_sequence); - tsk_safe_free(self->reference_sequence); - } - if (reference_sequence_loaded) { - self->reference_sequence = tsk_malloc(sizeof(tsk_reference_sequence_t)); - if (self->reference_sequence == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - tsk_reference_sequence_init(self->reference_sequence); - } if (data != NULL) { ret = tsk_reference_sequence_set_data( - self->reference_sequence, data, (tsk_size_t) data_length); + &self->reference_sequence, data, (tsk_size_t) data_length); if (ret != 0) { goto out; } } if (metadata != NULL) { ret = tsk_reference_sequence_set_metadata( - self->reference_sequence, metadata, (tsk_size_t) metadata_length); + &self->reference_sequence, metadata, (tsk_size_t) metadata_length); if (ret != 0) { goto out; } } if (metadata_schema != NULL) { - ret = 
tsk_reference_sequence_set_metadata_schema(self->reference_sequence, + ret = tsk_reference_sequence_set_metadata_schema(&self->reference_sequence, metadata_schema, (tsk_size_t) metadata_schema_length); if (ret != 0) { goto out; @@ -10502,14 +10501,13 @@ tsk_table_collection_load_reference_sequence( } if (url != NULL) { ret = tsk_reference_sequence_set_url( - self->reference_sequence, url, (tsk_size_t) url_length); + &self->reference_sequence, url, (tsk_size_t) url_length); if (ret != 0) { goto out; } } out: - return ret; } @@ -10687,10 +10685,11 @@ tsk_table_collection_write_format_data(const tsk_table_collection_t *self, } static int TSK_WARN_UNUSED -tsk_table_collection_reference_sequence_dump(const tsk_table_collection_t *self, +tsk_table_collection_dump_reference_sequence(const tsk_table_collection_t *self, kastore_t *store, tsk_flags_t TSK_UNUSED(options)) { - const tsk_reference_sequence_t *ref = self->reference_sequence; + int ret = 0; + const tsk_reference_sequence_t *ref = &self->reference_sequence; write_table_col_t write_cols[] = { { "reference_sequence/data", (void *) ref->data, ref->data_length, KAS_UINT8 }, { "reference_sequence/url", (void *) ref->url, ref->url_length, KAS_UINT8 }, @@ -10700,7 +10699,10 @@ tsk_table_collection_reference_sequence_dump(const tsk_table_collection_t *self, ref->metadata_schema_length, KAS_UINT8 }, { .name = NULL }, }; - return write_table_cols(store, write_cols, 0); + if (tsk_table_collection_has_reference_sequence(self)) { + ret = write_table_cols(store, write_cols, 0); + } + return ret; } int TSK_WARN_UNUSED @@ -10791,11 +10793,9 @@ tsk_table_collection_dumpf( if (ret != 0) { goto out; } - if (self->reference_sequence != NULL) { - ret = tsk_table_collection_reference_sequence_dump(self, &store, options); - if (ret != 0) { - goto out; - } + ret = tsk_table_collection_dump_reference_sequence(self, &store, options); + if (ret != 0) { + goto out; } ret = kastore_close(&store); diff --git a/c/tskit/tables.h 
b/c/tskit/tables.h index 35e615e3e9..2e1cd0b024 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -565,7 +565,7 @@ typedef struct { /** @brief The metadata schema */ char *metadata_schema; tsk_size_t metadata_schema_length; - tsk_reference_sequence_t *reference_sequence; + tsk_reference_sequence_t reference_sequence; /** @brief The individual table */ tsk_individual_table_t individuals; /** @brief The node table */ @@ -4004,6 +4004,8 @@ life-cycle. bool tsk_table_collection_has_index( const tsk_table_collection_t *self, tsk_flags_t options); +bool tsk_table_collection_has_reference_sequence(const tsk_table_collection_t *self); + /** @brief Deletes the indexes for this table collection. @@ -4140,22 +4142,21 @@ int tsk_table_collection_compute_mutation_parents( int tsk_table_collection_compute_mutation_times( tsk_table_collection_t *self, double *random, tsk_flags_t TSK_UNUSED(options)); -int tsk_reference_sequence_init(tsk_reference_sequence_t *self); +int tsk_reference_sequence_init(tsk_reference_sequence_t *self, tsk_flags_t options); int tsk_reference_sequence_free(tsk_reference_sequence_t *self); +bool tsk_reference_sequence_is_null(const tsk_reference_sequence_t *self); bool tsk_reference_sequence_equals(const tsk_reference_sequence_t *self, const tsk_reference_sequence_t *other, tsk_flags_t options); int tsk_reference_sequence_copy(const tsk_reference_sequence_t *self, - tsk_reference_sequence_t **dest, tsk_flags_t options); -int tsk_reference_sequence_set_data(tsk_reference_sequence_t *self, - const char *reference_sequence, tsk_size_t reference_sequence_length); -int tsk_reference_sequence_set_url(tsk_reference_sequence_t *self, - const char *reference_sequence_url, tsk_size_t reference_sequence_url_length); -int tsk_reference_sequence_set_metadata(tsk_reference_sequence_t *self, - const char *reference_sequence_metadata, - tsk_size_t reference_sequence_metadata_length); + tsk_reference_sequence_t *dest, tsk_flags_t options); +int 
tsk_reference_sequence_set_data( + tsk_reference_sequence_t *self, const char *data, tsk_size_t data_length); +int tsk_reference_sequence_set_url( + tsk_reference_sequence_t *self, const char *url, tsk_size_t url_length); +int tsk_reference_sequence_set_metadata( + tsk_reference_sequence_t *self, const char *metadata, tsk_size_t metadata_length); int tsk_reference_sequence_set_metadata_schema(tsk_reference_sequence_t *self, - const char *reference_sequence_metadata_schema, - tsk_size_t reference_sequence_metadata_schema_length); + const char *metadata_schema, tsk_size_t metadata_schema_length); /** @defgroup TABLE_SORTER_API_GROUP Low-level table sorter API. diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 2e15ebe063..90623320d6 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -704,6 +704,12 @@ tsk_treeseq_get_discrete_time(const tsk_treeseq_t *self) return self->discrete_time; } +bool +tsk_treeseq_has_reference_sequence(const tsk_treeseq_t *self) +{ + return tsk_table_collection_has_reference_sequence(self->tables); +} + /* Stats functions */ #define GET_2D_ROW(array, row_len, row) (array + (((size_t)(row_len)) * (size_t) row)) diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 8c6abc2979..896820e268 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -243,6 +243,8 @@ void tsk_treeseq_print_state(const tsk_treeseq_t *self, FILE *out); /** @} */ +bool tsk_treeseq_has_reference_sequence(const tsk_treeseq_t *self); + tsk_size_t tsk_treeseq_get_num_nodes(const tsk_treeseq_t *self); tsk_size_t tsk_treeseq_get_num_edges(const tsk_treeseq_t *self); tsk_size_t tsk_treeseq_get_num_migrations(const tsk_treeseq_t *self); diff --git a/docs/python-api.md b/docs/python-api.md index e30fd19299..f5bc4fcd66 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -22,7 +22,7 @@ This page documents the full tskit Python API. Brief thematic summaries of commo classes and methods are presented first. 
The {ref}`sec_reference_api` at the end then contains full details which aim to be concise, precise and exhaustive. Note that this may not therefore be the best place to start if you are new -to a particular piece of functionality. +to a particular piece of functionality. (sec_python_api_trees_and_tree_sequences)= @@ -224,7 +224,7 @@ which perform the same actions but modify the {class}`TableCollection` in place. #### Tables -The underlying data in a tree sequence is stored in a +The underlying data in a tree sequence is stored in a {ref}`collection of tables`. The following methods give access to tables and associated functionality. Since tables can be modified, this allows tree sequences to be edited: see the {ref}`sec_tables` tutorial for @@ -285,7 +285,7 @@ efficient methods sometimes exist for entire tree sequences: ```{eval-rst} .. autosummary:: TreeSequence.count_topologies - + ``` (sec_python_api_tree_sequences_display)= @@ -365,7 +365,7 @@ It is sometimes useful to create an entirely new tree sequence consisting of just a single tree (a "one-tree sequence"). The follow methods create such an object and return a {class}`Tree` instance corresponding to that tree. The new tree sequence to which the tree belongs is available through the -{attr}`~Tree.tree_sequence` property. +{attr}`~Tree.tree_sequence` property. ```{eval-rst} Creating a new tree @@ -406,7 +406,7 @@ available via simple and high performance {class}`Tree` methods ##### Simple measures These return a simple number, or (usually) short list of numbers relevant to a specific -node or limited set of nodes. +node or limited set of nodes. ```{eval-rst} Node information @@ -434,7 +434,7 @@ Descendant nodes Tree.samples Tree.num_samples Tree.num_tracked_samples - + Multiple nodes .. autosummary:: Tree.is_descendant @@ -1152,6 +1152,7 @@ The following constants are used throughout the `tskit` API. ```{eval-rst} .. 
autoclass:: TableCollection + :inherited-members: :members: ``` diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 6b06c60443..1ba13a35ed 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -142,12 +142,6 @@ typedef struct { tsk_tree_t *tree; } Tree; -typedef struct { - PyObject_HEAD - Tree *tree; - int first; -} TreeIterator; - typedef struct { PyObject_HEAD TreeSequence *tree_sequence; @@ -184,6 +178,13 @@ typedef struct { tsk_viterbi_matrix_t *viterbi_matrix; } ViterbiMatrix; +typedef struct { + PyObject_HEAD + PyObject *owner; + bool read_only; + tsk_reference_sequence_t *reference_sequence; +} ReferenceSequence; + typedef struct { PyObject_HEAD tsk_identity_segments_t *identity_segments; @@ -5835,6 +5836,294 @@ static PyTypeObject IdentitySegmentsType = { // clang-format on }; +/*=================================================================== + * ReferenceSequence + *=================================================================== + */ + +static int +ReferenceSequence_check_read(ReferenceSequence *self) +{ + int ret = -1; + if (self->reference_sequence == NULL) { + PyErr_SetString(PyExc_SystemError, "ReferenceSequence not initialised"); + goto out; + } + ret = 0; +out: + return ret; +} + +static int +ReferenceSequence_check_write(ReferenceSequence *self) +{ + int ret = ReferenceSequence_check_read(self); + + if (ret != 0) { + goto out; + } + if (self->read_only) { + PyErr_SetString(PyExc_AttributeError, + "ReferenceSequence is read-only and can only be modified " + "in a TableCollection"); + ret = -1; + goto out; + } + ret = 0; +out: + return ret; +} + +static void +ReferenceSequence_dealloc(ReferenceSequence *self) +{ + self->reference_sequence = NULL; + Py_XDECREF(self->owner); + Py_TYPE(self)->tp_free((PyObject *) self); +} + +static int +ReferenceSequence_init(ReferenceSequence *self, PyObject *args, PyObject *kwds) +{ + self->reference_sequence = NULL; + self->owner = NULL; + self->read_only = true; + return 0; +} + 
+static PyObject * +ReferenceSequence_get_data(ReferenceSequence *self, void *closure) +{ + PyObject *ret = NULL; + + if (ReferenceSequence_check_read(self) != 0) { + goto out; + } + /* This isn't zero-copy, so we'll possible want to return a + * numpy array wrapping this at some point */ + ret = make_Py_Unicode_FromStringAndLength( + self->reference_sequence->data, self->reference_sequence->data_length); +out: + return ret; +} + +typedef int(refseq_string_setter_func)( + tsk_reference_sequence_t *obj, const char *str, tsk_size_t len); + +static int +ReferenceSequence_set_string_attr(ReferenceSequence *self, PyObject *arg, + const char *attr_name, refseq_string_setter_func setter_func) +{ + int ret = -1; + int err; + const char *str; + Py_ssize_t length; + + if (ReferenceSequence_check_write(self) != 0) { + goto out; + } + if (arg == NULL) { + PyErr_Format( + PyExc_AttributeError, "Cannot del %s, set to None to clear.", attr_name); + goto out; + } + if (!PyUnicode_Check(arg)) { + PyErr_Format(PyExc_TypeError, "%s must be a string", attr_name); + goto out; + } + str = PyUnicode_AsUTF8AndSize(arg, &length); + if (str == NULL) { + goto out; + } + err = setter_func(self->reference_sequence, str, (tsk_size_t) length); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + +static int +ReferenceSequence_set_data(ReferenceSequence *self, PyObject *arg, void *closure) +{ + return ReferenceSequence_set_string_attr( + self, arg, "data", tsk_reference_sequence_set_data); +} + +static PyObject * +ReferenceSequence_get_url(ReferenceSequence *self, void *closure) +{ + PyObject *ret = NULL; + + if (ReferenceSequence_check_read(self) != 0) { + goto out; + } + ret = make_Py_Unicode_FromStringAndLength( + self->reference_sequence->url, self->reference_sequence->url_length); +out: + return ret; +} + +static int +ReferenceSequence_set_url(ReferenceSequence *self, PyObject *arg, void *closure) +{ + return ReferenceSequence_set_string_attr( + 
self, arg, "url", tsk_reference_sequence_set_url); +} + +static PyObject * +ReferenceSequence_get_metadata_schema(ReferenceSequence *self, void *closure) +{ + PyObject *ret = NULL; + + if (ReferenceSequence_check_read(self) != 0) { + goto out; + } + ret = make_Py_Unicode_FromStringAndLength(self->reference_sequence->metadata_schema, + self->reference_sequence->metadata_schema_length); +out: + return ret; +} + +static int +ReferenceSequence_set_metadata_schema( + ReferenceSequence *self, PyObject *arg, void *closure) +{ + return ReferenceSequence_set_string_attr( + self, arg, "metadata_schema", tsk_reference_sequence_set_metadata_schema); +} + +static PyObject * +ReferenceSequence_get_metadata(ReferenceSequence *self, void *closure) +{ + PyObject *ret = NULL; + + if (ReferenceSequence_check_read(self) != 0) { + goto out; + } + + ret = PyBytes_FromStringAndSize( + self->reference_sequence->metadata, self->reference_sequence->metadata_length); +out: + return ret; +} + +static int +ReferenceSequence_set_metadata(ReferenceSequence *self, PyObject *arg, void *closure) +{ + int ret = -1; + int err; + char *metadata; + Py_ssize_t metadata_length; + + if (ReferenceSequence_check_write(self) != 0) { + goto out; + } + if (arg == NULL) { + PyErr_Format(PyExc_AttributeError, + "Cannot del metadata, set to empty string (b\"\") to clear."); + goto out; + } + err = PyBytes_AsStringAndSize(arg, &metadata, &metadata_length); + if (err != 0) { + goto out; + } + err = tsk_reference_sequence_set_metadata( + self->reference_sequence, metadata, metadata_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} + +static PyObject * +ReferenceSequence_is_null(ReferenceSequence *self) +{ + PyObject *ret = NULL; + + if (ReferenceSequence_check_read(self) != 0) { + goto out; + } + ret = Py_BuildValue( + "i", (int) tsk_reference_sequence_is_null(self->reference_sequence)); +out: + return ret; +} + +static PyMethodDef ReferenceSequence_methods[] = 
{ + { .ml_name = "is_null", + .ml_meth = (PyCFunction) ReferenceSequence_is_null, + .ml_flags = METH_NOARGS, + .ml_doc = "Returns True if this is the null reference sequence ." }, + { NULL } /* Sentinel */ +}; + +static PyGetSetDef ReferenceSequence_getsetters[] = { + { .name = "data", + .set = (setter) ReferenceSequence_set_data, + .get = (getter) ReferenceSequence_get_data, + .doc = "The data string for this reference sequence. " }, + { .name = "url", + .set = (setter) ReferenceSequence_set_url, + .get = (getter) ReferenceSequence_get_url, + .doc = "The url string for this reference sequence. " }, + { .name = "metadata_schema", + .set = (setter) ReferenceSequence_set_metadata_schema, + .get = (getter) ReferenceSequence_get_metadata_schema, + .doc = "The metadata_schema string for this reference sequence. " }, + { .name = "metadata", + .set = (setter) ReferenceSequence_set_metadata, + .get = (getter) ReferenceSequence_get_metadata, + .doc = "The metadata string for this reference sequence. 
" }, + { NULL } /* Sentinel */ +}; + +static PyTypeObject ReferenceSequenceType = { + // clang-format off + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "_tskit.ReferenceSequence", + .tp_basicsize = sizeof(ReferenceSequence), + .tp_dealloc = (destructor) ReferenceSequence_dealloc, + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_doc = "A thin Python translation layer over the C tsk_reference_sequence_t struct", + .tp_methods = ReferenceSequence_methods, + .tp_getset = ReferenceSequence_getsetters, + .tp_init = (initproc) ReferenceSequence_init, + .tp_new = PyType_GenericNew, + // clang-format on +}; + +static PyObject * +ReferenceSequence_get_new( + tsk_reference_sequence_t *refseq, PyObject *owner, bool read_only) +{ + + PyObject *ret = NULL; + ReferenceSequence *py_refseq = NULL; + + py_refseq = (ReferenceSequence *) PyObject_CallObject( + (PyObject *) &ReferenceSequenceType, NULL); + if (py_refseq == NULL) { + goto out; + } + py_refseq->reference_sequence = refseq; + py_refseq->owner = owner; + py_refseq->read_only = read_only; + /* We increment the reference on the owner */ + Py_INCREF(owner); + + ret = (PyObject *) py_refseq; + py_refseq = NULL; +out: + Py_XDECREF(py_refseq); + return ret; +} + /*=================================================================== * TableCollection *=================================================================== @@ -6252,6 +6541,20 @@ TableCollection_set_metadata_schema(TableCollection *self, PyObject *arg, void * return ret; } +static PyObject * +TableCollection_get_reference_sequence(TableCollection *self, void *closure) +{ + PyObject *ret = NULL; + + if (TableCollection_check_state(self) != 0) { + goto out; + } + ret = ReferenceSequence_get_new( + &self->tables->reference_sequence, (PyObject *) self, false); +out: + return ret; +} + static PyObject * TableCollection_simplify(TableCollection *self, PyObject *args, PyObject *kwds) { @@ -6863,6 +7166,20 @@ TableCollection_set_indexes(TableCollection *self, PyObject *arg, void 
*closure) return ret; } +static PyObject * +TableCollection_has_reference_sequence(TableCollection *self) +{ + PyObject *ret = NULL; + + if (TableCollection_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue( + "i", (int) tsk_table_collection_has_reference_sequence(self->tables)); +out: + return ret; +} + static PyObject * TableCollection_has_index(TableCollection *self) { @@ -7132,6 +7449,9 @@ static PyGetSetDef TableCollection_getsetters[] = { .get = (getter) TableCollection_get_metadata_schema, .set = (setter) TableCollection_set_metadata_schema, .doc = "The metadata schema." }, + { .name = "reference_sequence", + .get = (getter) TableCollection_get_reference_sequence, + .doc = "The reference sequence." }, { NULL } /* Sentinel */ }; @@ -7200,6 +7520,10 @@ static PyMethodDef TableCollection_methods[] = { .ml_meth = (PyCFunction) TableCollection_drop_index, .ml_flags = METH_NOARGS, .ml_doc = "Drops indexes." }, + { .ml_name = "has_reference_sequence", + .ml_meth = (PyCFunction) TableCollection_has_reference_sequence, + .ml_flags = METH_NOARGS, + .ml_doc = "Returns True if the TableCollection has a reference sequence." 
}, { .ml_name = "has_index", .ml_meth = (PyCFunction) TableCollection_has_index, .ml_flags = METH_NOARGS, @@ -9018,6 +9342,34 @@ TreeSequence_get_genotype_matrix(TreeSequence *self, PyObject *args, PyObject *k return ret; } +static PyObject * +TreeSequence_has_reference_sequence(TreeSequence *self) +{ + PyObject *ret = NULL; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue( + "i", (int) tsk_treeseq_has_reference_sequence(self->tree_sequence)); +out: + return ret; +} + +static PyObject * +TreeSequence_get_reference_sequence(TreeSequence *self, void *closure) +{ + PyObject *ret = NULL; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + ret = ReferenceSequence_get_new( + &self->tree_sequence->tables->reference_sequence, (PyObject *) self, true); +out: + return ret; +} + static PyMethodDef TreeSequence_methods[] = { { .ml_name = "dump", .ml_meth = (PyCFunction) TreeSequence_dump, @@ -9223,6 +9575,17 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_get_genotype_matrix, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Returns the genotypes matrix." }, + { .ml_name = "has_reference_sequence", + .ml_meth = (PyCFunction) TreeSequence_has_reference_sequence, + .ml_flags = METH_NOARGS, + .ml_doc = "Returns True if the TreeSequence has a reference sequence." }, + { NULL } /* Sentinel */ +}; + +static PyGetSetDef TreeSequence_getsetters[] = { + { .name = "reference_sequence", + .get = (getter) TreeSequence_get_reference_sequence, + .doc = "The reference sequence." 
}, { NULL } /* Sentinel */ }; @@ -9235,6 +9598,7 @@ static PyTypeObject TreeSequenceType = { .tp_flags = Py_TPFLAGS_DEFAULT, .tp_doc = "TreeSequence objects", .tp_methods = TreeSequence_methods, + .tp_getset = TreeSequence_getsetters, .tp_init = (initproc) TreeSequence_init, .tp_new = PyType_GenericNew, // clang-format on @@ -11951,6 +12315,13 @@ PyInit__tskit(void) PyModule_AddObject( module, "IdentitySegmentList", (PyObject *) &IdentitySegmentListType); + /* ReferenceSequence type */ + if (PyType_Ready(&ReferenceSequenceType) < 0) { + return NULL; + } + Py_INCREF(&ReferenceSequenceType); + PyModule_AddObject(module, "ReferenceSequence", (PyObject *) &ReferenceSequenceType); + /* Metadata schemas namedtuple type*/ if (PyStructSequence_InitType2(&MetadataSchemas, &metadata_schemas_desc) < 0) { return NULL; diff --git a/python/lwt_interface/dict_encoding_testlib.py b/python/lwt_interface/dict_encoding_testlib.py index 2fdbc21453..8e249739ee 100644 --- a/python/lwt_interface/dict_encoding_testlib.py +++ b/python/lwt_interface/dict_encoding_testlib.py @@ -26,6 +26,7 @@ compiled module exporting the LightweightTableCollection class. See the test_example_c_module file for an example. """ +import copy import json import kastore @@ -44,35 +45,31 @@ @pytest.fixture(scope="session") def full_ts(): """ - Return a tree sequence that has data in all fields. - """ - """ - A tree sequence with data in all fields - duplcated from tskit's conftest.py + A tree sequence with data in all fields - duplicated from tskit's conftest.py as other test suites using this file will not have that fixture defined. 
""" - n = 10 - t = 1 - population_configurations = [ - msprime.PopulationConfiguration(n // 2), - msprime.PopulationConfiguration(n // 2), - msprime.PopulationConfiguration(0), - ] - demographic_events = [ - msprime.MassMigration(time=t, source=0, destination=2), - msprime.MassMigration(time=t, source=1, destination=2), - ] - ts = msprime.simulate( - population_configurations=population_configurations, - demographic_events=demographic_events, + demography = msprime.Demography() + demography.add_population(initial_size=100, name="A") + demography.add_population(initial_size=100, name="B") + demography.add_population(initial_size=100, name="C") + demography.add_population_split(time=10, ancestral="C", derived=["A", "B"]) + + ts = msprime.sim_ancestry( + {"A": 5, "B": 5}, + demography=demography, random_seed=1, - mutation_rate=1, + sequence_length=10, record_migrations=True, ) + assert ts.num_migrations > 0 + assert ts.num_individuals > 0 + ts = msprime.sim_mutations(ts, rate=0.1, random_seed=2) + assert ts.num_mutations > 0 tables = ts.dump_tables() - # TODO replace this with properly linked up individuals using sim_ancestry - # once 1.0 is released. 
- for j in range(n): - tables.individuals.add_row(flags=j, location=(j, j), parents=(j - 1, j - 1)) + tables.individuals.clear() + + for ind in ts.individuals(): + tables.individuals.add_row(flags=0, location=[ind.id, ind.id], parents=[-1, -1]) for name, table in tables.name_map.items(): if name != "provenances": @@ -87,7 +84,12 @@ def full_ts(): } ) tables.metadata_schema = tskit.MetadataSchema({"codec": "json"}) - tables.metadata = "Test metadata" + tables.metadata = {"A": "Test metadata"} + + tables.reference_sequence.data = "A" * int(tables.sequence_length) + tables.reference_sequence.url = "https://example.com/sequence" + tables.reference_sequence.metadata_schema = tskit.MetadataSchema.permissive_json() + tables.reference_sequence.metadata = {"A": "Test metadata"} # Add some more provenance so we have enough rows for the offset deletion test. for j in range(10): @@ -109,14 +111,13 @@ def test_check_ts_full(tmp_path, full_ts): full_ts.dump(tmp_path / "tables") store = kastore.load(tmp_path / "tables") for v in store.values(): - # Check we really have data in every field assert v.nbytes > 0 class TestEncodingVersion: def test_version(self): lwt = lwt_module.LightweightTableCollection() - assert lwt.asdict()["encoding_version"] == (1, 5) + assert lwt.asdict()["encoding_version"] == (1, 6) class TestRoundTrip: @@ -128,7 +129,7 @@ def verify(self, tables): lwt = lwt_module.LightweightTableCollection() lwt.fromdict(tables.asdict()) other_tables = tskit.TableCollection.fromdict(lwt.asdict()) - assert tables == other_tables + tables.assert_equals(other_tables) def test_simple(self): ts = msprime.simulate(10, mutation_rate=1, random_seed=2) @@ -242,6 +243,7 @@ def test_missing_tables(self, tables): "metadata_schema", "encoding_version", "indexes", + "reference_sequence", } for table_name in table_names: d = tables.asdict() @@ -265,15 +267,16 @@ def verify_columns(self, value, tables): "metadata_schema", "encoding_version", "indexes", + "reference_sequence", } for 
table_name in table_names: table_dict = d[table_name] for colname in set(table_dict.keys()) - {"metadata_schema"}: - copy = dict(table_dict) - copy[colname] = value + d_copy = dict(table_dict) + d_copy[colname] = value lwt = lwt_module.LightweightTableCollection() d = tables.asdict() - d[table_name] = copy + d[table_name] = d_copy with pytest.raises(ValueError): lwt.fromdict(d) @@ -308,15 +311,16 @@ def verify(self, num_rows, tables): "metadata_schema", "encoding_version", "indexes", + "reference_sequence", } for table_name in sorted(table_names): table_dict = d[table_name] for colname in set(table_dict.keys()) - {"metadata_schema"}: - copy = dict(table_dict) - copy[colname] = table_dict[colname][:num_rows].copy() + d_copy = dict(table_dict) + d_copy[colname] = table_dict[colname][:num_rows].copy() lwt = lwt_module.LightweightTableCollection() d = tables.asdict() - d[table_name] = copy + d[table_name] = d_copy with pytest.raises(ValueError): lwt.fromdict(d) @@ -350,6 +354,64 @@ def test_bad_index_length(self, tables): lwt.fromdict(d) +class TestParsingUtilities: + def test_missing_required(self, tables): + d = tables.asdict() + del d["sequence_length"] + lwt = lwt_module.LightweightTableCollection() + with pytest.raises(TypeError, match="'sequence_length' is required"): + lwt.fromdict(d) + + def test_string_bad_type(self, tables): + d = tables.asdict() + d["time_units"] = b"sdf" + lwt = lwt_module.LightweightTableCollection() + with pytest.raises(TypeError, match="'time_units' is not a string"): + lwt.fromdict(d) + + def test_bytes_bad_type(self, tables): + d = tables.asdict() + d["metadata"] = 1234 + lwt = lwt_module.LightweightTableCollection() + with pytest.raises(TypeError, match="'metadata' is not bytes"): + lwt.fromdict(d) + + def test_dict_bad_type(self, tables): + d = tables.asdict() + d["nodes"] = b"sdf" + lwt = lwt_module.LightweightTableCollection() + with pytest.raises(TypeError, match="'nodes' is not a dict"): + lwt.fromdict(d) + + def 
test_bad_strings(self, tables): + def verify_unicode_error(d): + lwt = lwt_module.LightweightTableCollection() + with pytest.raises(UnicodeEncodeError): + lwt.fromdict(d) + + def verify_bad_string_type(d): + lwt = lwt_module.LightweightTableCollection() + with pytest.raises(TypeError): + lwt.fromdict(d) + + d = tables.asdict() + for k, v in d.items(): + if isinstance(v, str): + d_copy = copy.deepcopy(d) + d_copy[k] = NON_UTF8_STRING + verify_unicode_error(d_copy) + d_copy[k] = 12345 + verify_bad_string_type(d_copy) + if isinstance(v, dict): + for kp, vp in v.items(): + if isinstance(vp, str): + d_copy = copy.deepcopy(d) + d_copy[k][kp] = NON_UTF8_STRING + verify_unicode_error(d_copy) + d_copy[k][kp] = 12345 + verify_bad_string_type(d_copy) + + class TestRequiredAndOptionalColumns: """ Tests that specifying None for some columns will give the intended @@ -371,9 +433,9 @@ def verify_required_columns(self, tables, table_name, required_cols): # Any one of these required columns as None gives an error. for col in required_cols: d = tables.asdict() - copy = dict(table_dict) - copy[col] = None - d[table_name] = copy + d_copy = copy.deepcopy(table_dict) + d_copy[col] = None + d[table_name] = d_copy lwt = lwt_module.LightweightTableCollection() with pytest.raises(TypeError): lwt.fromdict(d) @@ -381,9 +443,9 @@ def verify_required_columns(self, tables, table_name, required_cols): # Removing any one of these required columns gives an error. 
for col in required_cols: d = tables.asdict() - copy = dict(table_dict) - del copy[col] - d[table_name] = copy + d_copy = copy.deepcopy(table_dict) + del d_copy[col] + d[table_name] = d_copy lwt = lwt_module.LightweightTableCollection() with pytest.raises(TypeError): lwt.fromdict(d) @@ -595,6 +657,37 @@ def test_index(self, tables): ): lwt.fromdict(d) + def test_index_bad_type(self, tables): + d = tables.asdict() + lwt = lwt_module.LightweightTableCollection() + d["indexes"] = "asdf" + with pytest.raises(TypeError): + lwt.fromdict(d) + + def test_reference_sequence(self, tables): + self.verify_metadata_schema(tables, "reference_sequence") + + def get_refseq(d): + tables = tskit.TableCollection.fromdict(d) + return tables.reference_sequence + + d = tables.asdict() + refseq_dict = d.pop("reference_sequence") + assert get_refseq(d).is_null() + + # All empty strings is the same thing + d["reference_sequence"] = dict( + data="", url="", metadata_schema="", metadata=b"" + ) + assert get_refseq(d).is_null() + + del refseq_dict["metadata_schema"] # handled above + for key, value in refseq_dict.items(): + d["reference_sequence"] = {key: value} + refseq = get_refseq(d) + assert not refseq.is_null() + assert getattr(refseq, key) == value + def test_top_level_time_units(self, tables): d = tables.asdict() # None should give default value @@ -606,7 +699,6 @@ def test_top_level_time_units(self, tables): assert tables.time_units == tskit.TIME_UNITS_UNKNOWN # Missing is tested in TestMissingData above d = tables.asdict() - # None should give default value d["time_units"] = NON_UTF8_STRING lwt = lwt_module.LightweightTableCollection() with pytest.raises(UnicodeEncodeError): @@ -713,3 +805,10 @@ def test_values_equal(self, tables): col_64 = offsets_64[col_name] assert col_64.shape == col_32.shape assert np.all(col_64 == col_32) + + +@pytest.mark.parametrize("bad_type", [None, "", []]) +def test_fromdict_bad_type(bad_type): + lwt = lwt_module.LightweightTableCollection() + with 
pytest.raises(TypeError): + lwt.fromdict(bad_type) diff --git a/python/lwt_interface/tskit_lwt_interface.h b/python/lwt_interface/tskit_lwt_interface.h index 24b353ec87..ab611ad618 100644 --- a/python/lwt_interface/tskit_lwt_interface.h +++ b/python/lwt_interface/tskit_lwt_interface.h @@ -66,7 +66,7 @@ make_Py_Unicode_FromStringAndLength(const char *str, size_t length) * NB This returns a *borrowed reference*, so don't DECREF it! */ static PyObject * -get_table_dict_value(PyObject *dict, const char *key_str, bool required) +get_dict_value(PyObject *dict, const char *key_str, bool required) { PyObject *ret = NULL; @@ -81,6 +81,62 @@ get_table_dict_value(PyObject *dict, const char *key_str, bool required) return ret; } +/* Specialised version of get_dict_value that checks if the + * value is a dictionary. */ +static PyObject * +get_dict_value_dict(PyObject *dict, const char *key_str, bool required) +{ + PyObject *ret = NULL; + PyObject *value = get_dict_value(dict, key_str, required); + + if (value == NULL) { + goto out; + } + if (value != Py_None && !PyDict_Check(value)) { + PyErr_Format(PyExc_TypeError, "'%s' is not a dict", key_str); + goto out; + } + ret = value; +out: + return ret; +} + +static PyObject * +get_dict_value_string(PyObject *dict, const char *key_str, bool required) +{ + PyObject *ret = NULL; + PyObject *value = get_dict_value(dict, key_str, required); + + if (value == NULL) { + goto out; + } + if (value != Py_None && !PyUnicode_Check(value)) { + PyErr_Format(PyExc_TypeError, "'%s' is not a string", key_str); + goto out; + } + ret = value; +out: + return ret; +} + +static PyObject * +get_dict_value_bytes(PyObject *dict, const char *key_str, bool required) +{ + PyObject *ret = NULL; + PyObject *value = get_dict_value(dict, key_str, required); + + if (value == NULL) { + goto out; + } + if (value != Py_None && !PyBytes_Check(value)) { + PyErr_Format(PyExc_TypeError, "'%s' is not bytes", key_str); + goto out; + } + ret = value; +out: + return ret; +} + 
static PyArrayObject * table_read_column_array( PyObject *input, int npy_type, size_t *num_rows, bool check_num_rows) @@ -199,35 +255,35 @@ parse_individual_table_dict( Py_ssize_t metadata_schema_length = 0; /* Get the input values */ - flags_input = get_table_dict_value(dict, "flags", true); + flags_input = get_dict_value(dict, "flags", true); if (flags_input == NULL) { goto out; } - location_input = get_table_dict_value(dict, "location", false); + location_input = get_dict_value(dict, "location", false); if (location_input == NULL) { goto out; } - location_offset_input = get_table_dict_value(dict, "location_offset", false); + location_offset_input = get_dict_value(dict, "location_offset", false); if (location_offset_input == NULL) { goto out; } - parents_input = get_table_dict_value(dict, "parents", false); + parents_input = get_dict_value(dict, "parents", false); if (parents_input == NULL) { goto out; } - parents_offset_input = get_table_dict_value(dict, "parents_offset", false); + parents_offset_input = get_dict_value(dict, "parents_offset", false); if (parents_offset_input == NULL) { goto out; } - metadata_input = get_table_dict_value(dict, "metadata", false); + metadata_input = get_dict_value(dict, "metadata", false); if (metadata_input == NULL) { goto out; } - metadata_offset_input = get_table_dict_value(dict, "metadata_offset", false); + metadata_offset_input = get_dict_value(dict, "metadata_offset", false); if (metadata_offset_input == NULL) { goto out; } - metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + metadata_schema_input = get_dict_value(dict, "metadata_schema", false); if (metadata_schema_input == NULL) { goto out; } @@ -362,31 +418,31 @@ parse_node_table_dict(tsk_node_table_t *table, PyObject *dict, bool clear_table) Py_ssize_t metadata_schema_length = 0; /* Get the input values */ - flags_input = get_table_dict_value(dict, "flags", true); + flags_input = get_dict_value(dict, "flags", true); if (flags_input == NULL) { 
goto out; } - time_input = get_table_dict_value(dict, "time", true); + time_input = get_dict_value(dict, "time", true); if (time_input == NULL) { goto out; } - population_input = get_table_dict_value(dict, "population", false); + population_input = get_dict_value(dict, "population", false); if (population_input == NULL) { goto out; } - individual_input = get_table_dict_value(dict, "individual", false); + individual_input = get_dict_value(dict, "individual", false); if (individual_input == NULL) { goto out; } - metadata_input = get_table_dict_value(dict, "metadata", false); + metadata_input = get_dict_value(dict, "metadata", false); if (metadata_input == NULL) { goto out; } - metadata_offset_input = get_table_dict_value(dict, "metadata_offset", false); + metadata_offset_input = get_dict_value(dict, "metadata_offset", false); if (metadata_offset_input == NULL) { goto out; } - metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + metadata_schema_input = get_dict_value(dict, "metadata_schema", false); if (metadata_schema_input == NULL) { goto out; } @@ -500,31 +556,31 @@ parse_edge_table_dict(tsk_edge_table_t *table, PyObject *dict, bool clear_table) Py_ssize_t metadata_schema_length = 0; /* Get the input values */ - left_input = get_table_dict_value(dict, "left", true); + left_input = get_dict_value(dict, "left", true); if (left_input == NULL) { goto out; } - right_input = get_table_dict_value(dict, "right", true); + right_input = get_dict_value(dict, "right", true); if (right_input == NULL) { goto out; } - parent_input = get_table_dict_value(dict, "parent", true); + parent_input = get_dict_value(dict, "parent", true); if (parent_input == NULL) { goto out; } - child_input = get_table_dict_value(dict, "child", true); + child_input = get_dict_value(dict, "child", true); if (child_input == NULL) { goto out; } - metadata_input = get_table_dict_value(dict, "metadata", false); + metadata_input = get_dict_value(dict, "metadata", false); if 
(metadata_input == NULL) { goto out; } - metadata_offset_input = get_table_dict_value(dict, "metadata_offset", false); + metadata_offset_input = get_dict_value(dict, "metadata_offset", false); if (metadata_offset_input == NULL) { goto out; } - metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + metadata_schema_input = get_dict_value(dict, "metadata_schema", false); if (metadata_schema_input == NULL) { goto out; } @@ -635,39 +691,39 @@ parse_migration_table_dict( Py_ssize_t metadata_schema_length = 0; /* Get the input values */ - left_input = get_table_dict_value(dict, "left", true); + left_input = get_dict_value(dict, "left", true); if (left_input == NULL) { goto out; } - right_input = get_table_dict_value(dict, "right", true); + right_input = get_dict_value(dict, "right", true); if (right_input == NULL) { goto out; } - node_input = get_table_dict_value(dict, "node", true); + node_input = get_dict_value(dict, "node", true); if (node_input == NULL) { goto out; } - source_input = get_table_dict_value(dict, "source", true); + source_input = get_dict_value(dict, "source", true); if (source_input == NULL) { goto out; } - dest_input = get_table_dict_value(dict, "dest", true); + dest_input = get_dict_value(dict, "dest", true); if (dest_input == NULL) { goto out; } - time_input = get_table_dict_value(dict, "time", true); + time_input = get_dict_value(dict, "time", true); if (time_input == NULL) { goto out; } - metadata_input = get_table_dict_value(dict, "metadata", false); + metadata_input = get_dict_value(dict, "metadata", false); if (metadata_input == NULL) { goto out; } - metadata_offset_input = get_table_dict_value(dict, "metadata_offset", false); + metadata_offset_input = get_dict_value(dict, "metadata_offset", false); if (metadata_offset_input == NULL) { goto out; } - metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + metadata_schema_input = get_dict_value(dict, "metadata_schema", false); if 
(metadata_schema_input == NULL) { goto out; } @@ -782,28 +838,27 @@ parse_site_table_dict(tsk_site_table_t *table, PyObject *dict, bool clear_table) Py_ssize_t metadata_schema_length = 0; /* Get the input values */ - position_input = get_table_dict_value(dict, "position", true); + position_input = get_dict_value(dict, "position", true); if (position_input == NULL) { goto out; } - ancestral_state_input = get_table_dict_value(dict, "ancestral_state", true); + ancestral_state_input = get_dict_value(dict, "ancestral_state", true); if (ancestral_state_input == NULL) { goto out; } - ancestral_state_offset_input - = get_table_dict_value(dict, "ancestral_state_offset", true); + ancestral_state_offset_input = get_dict_value(dict, "ancestral_state_offset", true); if (ancestral_state_offset_input == NULL) { goto out; } - metadata_input = get_table_dict_value(dict, "metadata", false); + metadata_input = get_dict_value(dict, "metadata", false); if (metadata_input == NULL) { goto out; } - metadata_offset_input = get_table_dict_value(dict, "metadata_offset", false); + metadata_offset_input = get_dict_value(dict, "metadata_offset", false); if (metadata_offset_input == NULL) { goto out; } - metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + metadata_schema_input = get_dict_value(dict, "metadata_schema", false); if (metadata_schema_input == NULL) { goto out; } @@ -917,40 +972,39 @@ parse_mutation_table_dict(tsk_mutation_table_t *table, PyObject *dict, bool clea Py_ssize_t metadata_schema_length = 0; /* Get the input values */ - site_input = get_table_dict_value(dict, "site", true); + site_input = get_dict_value(dict, "site", true); if (site_input == NULL) { goto out; } - node_input = get_table_dict_value(dict, "node", true); + node_input = get_dict_value(dict, "node", true); if (node_input == NULL) { goto out; } - parent_input = get_table_dict_value(dict, "parent", false); + parent_input = get_dict_value(dict, "parent", false); if (parent_input == NULL) 
{ goto out; } - time_input = get_table_dict_value(dict, "time", false); + time_input = get_dict_value(dict, "time", false); if (time_input == NULL) { goto out; } - derived_state_input = get_table_dict_value(dict, "derived_state", true); + derived_state_input = get_dict_value(dict, "derived_state", true); if (derived_state_input == NULL) { goto out; } - derived_state_offset_input - = get_table_dict_value(dict, "derived_state_offset", true); + derived_state_offset_input = get_dict_value(dict, "derived_state_offset", true); if (derived_state_offset_input == NULL) { goto out; } - metadata_input = get_table_dict_value(dict, "metadata", false); + metadata_input = get_dict_value(dict, "metadata", false); if (metadata_input == NULL) { goto out; } - metadata_offset_input = get_table_dict_value(dict, "metadata_offset", false); + metadata_offset_input = get_dict_value(dict, "metadata_offset", false); if (metadata_offset_input == NULL) { goto out; } - metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + metadata_schema_input = get_dict_value(dict, "metadata_schema", false); if (metadata_schema_input == NULL) { goto out; } @@ -1072,15 +1126,15 @@ parse_population_table_dict( Py_ssize_t metadata_schema_length = 0; /* Get the inputs */ - metadata_input = get_table_dict_value(dict, "metadata", true); + metadata_input = get_dict_value(dict, "metadata", true); if (metadata_input == NULL) { goto out; } - metadata_offset_input = get_table_dict_value(dict, "metadata_offset", true); + metadata_offset_input = get_dict_value(dict, "metadata_offset", true); if (metadata_offset_input == NULL) { goto out; } - metadata_schema_input = get_table_dict_value(dict, "metadata_schema", false); + metadata_schema_input = get_dict_value(dict, "metadata_schema", false); if (metadata_schema_input == NULL) { goto out; } @@ -1147,19 +1201,19 @@ parse_provenance_table_dict( PyArrayObject *record_offset_array = NULL; /* Get the inputs */ - timestamp_input = 
get_table_dict_value(dict, "timestamp", true); + timestamp_input = get_dict_value(dict, "timestamp", true); if (timestamp_input == NULL) { goto out; } - timestamp_offset_input = get_table_dict_value(dict, "timestamp_offset", true); + timestamp_offset_input = get_dict_value(dict, "timestamp_offset", true); if (timestamp_offset_input == NULL) { goto out; } - record_input = get_table_dict_value(dict, "record", true); + record_input = get_dict_value(dict, "record", true); if (record_input == NULL) { goto out; } - record_offset_input = get_table_dict_value(dict, "record_offset", true); + record_offset_input = get_dict_value(dict, "record_offset", true); if (record_offset_input == NULL) { goto out; } @@ -1220,11 +1274,11 @@ parse_indexes_dict(tsk_table_collection_t *tables, PyObject *dict) PyArrayObject *removal_array = NULL; /* Get the inputs */ - insertion_input = get_table_dict_value(dict, "edge_insertion_order", false); + insertion_input = get_dict_value(dict, "edge_insertion_order", false); if (insertion_input == NULL) { goto out; } - removal_input = get_table_dict_value(dict, "edge_removal_order", false); + removal_input = get_dict_value(dict, "edge_removal_order", false); if (removal_input == NULL) { goto out; } @@ -1271,6 +1325,89 @@ parse_indexes_dict(tsk_table_collection_t *tables, PyObject *dict) return ret; } +static int +parse_reference_sequence_dict(tsk_reference_sequence_t *ref, PyObject *dict) +{ + int err; + int ret = -1; + PyObject *value = NULL; + const char *metadata_schema, *data, *url; + char *metadata; + Py_ssize_t metadata_schema_length, metadata_length, data_length, url_length; + + /* metadata_schema */ + value = get_dict_value_string(dict, "metadata_schema", false); + if (value == NULL) { + goto out; + } + if (value != Py_None) { + metadata_schema = parse_unicode_arg(value, &metadata_schema_length); + if (metadata_schema == NULL) { + goto out; + } + err = tsk_reference_sequence_set_metadata_schema( + ref, metadata_schema, (tsk_size_t) 
metadata_schema_length); + if (err != 0) { + handle_tskit_error(err); + goto out; + } + } + + /* metadata */ + value = get_dict_value_bytes(dict, "metadata", false); + if (value == NULL) { + goto out; + } + if (value != Py_None) { + err = PyBytes_AsStringAndSize(value, &metadata, &metadata_length); + if (err != 0) { + goto out; + } + err = tsk_reference_sequence_set_metadata(ref, metadata, metadata_length); + if (err != 0) { + handle_tskit_error(err); + goto out; + } + } + + /* data */ + value = get_dict_value_string(dict, "data", false); + if (value == NULL) { + goto out; + } + if (value != Py_None) { + data = parse_unicode_arg(value, &data_length); + if (data == NULL) { + goto out; + } + err = tsk_reference_sequence_set_data(ref, data, (tsk_size_t) data_length); + if (err != 0) { + handle_tskit_error(err); + goto out; + } + } + + /* url */ + value = get_dict_value_string(dict, "url", false); + if (value == NULL) { + goto out; + } + if (value != Py_None) { + url = parse_unicode_arg(value, &url_length); + if (url == NULL) { + goto out; + } + err = tsk_reference_sequence_set_url(ref, url, (tsk_size_t) url_length); + if (err != 0) { + handle_tskit_error(err); + goto out; + } + } + ret = 0; +out: + return ret; +} + static int parse_table_collection_dict(tsk_table_collection_t *tables, PyObject *tables_dict) { @@ -1282,7 +1419,7 @@ parse_table_collection_dict(tsk_table_collection_t *tables, PyObject *tables_dic const char *metadata_schema = NULL; Py_ssize_t time_units_length, metadata_length, metadata_schema_length; - value = get_table_dict_value(tables_dict, "sequence_length", true); + value = get_dict_value(tables_dict, "sequence_length", true); if (value == NULL) { goto out; } @@ -1293,15 +1430,11 @@ parse_table_collection_dict(tsk_table_collection_t *tables, PyObject *tables_dic tables->sequence_length = PyFloat_AsDouble(value); /* metadata_schema */ - value = get_table_dict_value(tables_dict, "metadata_schema", false); + value = get_dict_value_string(tables_dict, 
"metadata_schema", false); if (value == NULL) { goto out; } if (value != Py_None) { - if (!PyUnicode_Check(value)) { - PyErr_Format(PyExc_TypeError, "'metadata_schema' is not a string"); - goto out; - } metadata_schema = parse_unicode_arg(value, &metadata_schema_length); if (metadata_schema == NULL) { goto out; @@ -1315,15 +1448,11 @@ parse_table_collection_dict(tsk_table_collection_t *tables, PyObject *tables_dic } /* metadata */ - value = get_table_dict_value(tables_dict, "metadata", false); + value = get_dict_value_bytes(tables_dict, "metadata", false); if (value == NULL) { goto out; } if (value != Py_None) { - if (!PyBytes_Check(value)) { - PyErr_Format(PyExc_TypeError, "'metadata' is not bytes"); - goto out; - } err = PyBytes_AsStringAndSize(value, &metadata, &metadata_length); if (err != 0) { goto out; @@ -1336,15 +1465,11 @@ parse_table_collection_dict(tsk_table_collection_t *tables, PyObject *tables_dic } /* time_units */ - value = get_table_dict_value(tables_dict, "time_units", false); + value = get_dict_value_string(tables_dict, "time_units", false); if (value == NULL) { goto out; } if (value != Py_None) { - if (!PyUnicode_Check(value)) { - PyErr_Format(PyExc_TypeError, "'time_units' is not a string"); - goto out; - } time_units = parse_unicode_arg(value, &time_units_length); if (time_units == NULL) { goto out; @@ -1357,124 +1482,98 @@ parse_table_collection_dict(tsk_table_collection_t *tables, PyObject *tables_dic } /* individuals */ - value = get_table_dict_value(tables_dict, "individuals", true); + value = get_dict_value_dict(tables_dict, "individuals", true); if (value == NULL) { goto out; } - if (!PyDict_Check(value)) { - PyErr_SetString(PyExc_TypeError, "not a dictionary"); - goto out; - } if (parse_individual_table_dict(&tables->individuals, value, true) != 0) { goto out; } /* nodes */ - value = get_table_dict_value(tables_dict, "nodes", true); + value = get_dict_value_dict(tables_dict, "nodes", true); if (value == NULL) { goto out; } - if 
(!PyDict_Check(value)) { - PyErr_SetString(PyExc_TypeError, "not a dictionary"); - goto out; - } if (parse_node_table_dict(&tables->nodes, value, true) != 0) { goto out; } /* edges */ - value = get_table_dict_value(tables_dict, "edges", true); + value = get_dict_value_dict(tables_dict, "edges", true); if (value == NULL) { goto out; } - if (!PyDict_Check(value)) { - PyErr_SetString(PyExc_TypeError, "not a dictionary"); - goto out; - } if (parse_edge_table_dict(&tables->edges, value, true) != 0) { goto out; } /* migrations */ - value = get_table_dict_value(tables_dict, "migrations", true); + value = get_dict_value_dict(tables_dict, "migrations", true); if (value == NULL) { goto out; } - if (!PyDict_Check(value)) { - PyErr_SetString(PyExc_TypeError, "not a dictionary"); - goto out; - } if (parse_migration_table_dict(&tables->migrations, value, true) != 0) { goto out; } /* sites */ - value = get_table_dict_value(tables_dict, "sites", true); + value = get_dict_value_dict(tables_dict, "sites", true); if (value == NULL) { goto out; } - if (!PyDict_Check(value)) { - PyErr_SetString(PyExc_TypeError, "not a dictionary"); - goto out; - } if (parse_site_table_dict(&tables->sites, value, true) != 0) { goto out; } /* mutations */ - value = get_table_dict_value(tables_dict, "mutations", true); + value = get_dict_value_dict(tables_dict, "mutations", true); if (value == NULL) { goto out; } - if (!PyDict_Check(value)) { - PyErr_SetString(PyExc_TypeError, "not a dictionary"); - goto out; - } if (parse_mutation_table_dict(&tables->mutations, value, true) != 0) { goto out; } /* populations */ - value = get_table_dict_value(tables_dict, "populations", true); + value = get_dict_value_dict(tables_dict, "populations", true); if (value == NULL) { goto out; } - if (!PyDict_Check(value)) { - PyErr_SetString(PyExc_TypeError, "not a dictionary"); - goto out; - } if (parse_population_table_dict(&tables->populations, value, true) != 0) { goto out; } /* provenances */ - value = 
get_table_dict_value(tables_dict, "provenances", true); + value = get_dict_value_dict(tables_dict, "provenances", true); if (value == NULL) { goto out; } - if (!PyDict_Check(value)) { - PyErr_SetString(PyExc_TypeError, "not a dictionary"); - goto out; - } if (parse_provenance_table_dict(&tables->provenances, value, true) != 0) { goto out; } /* indexes */ - value = get_table_dict_value(tables_dict, "indexes", false); + value = get_dict_value_dict(tables_dict, "indexes", false); if (value == NULL) { goto out; } if (value != Py_None) { - if (!PyDict_Check(value)) { - PyErr_SetString(PyExc_TypeError, "not a dictionary"); - goto out; - } if (parse_indexes_dict(tables, value) != 0) { goto out; } } + /* reference_sequence */ + value = get_dict_value_dict(tables_dict, "reference_sequence", false); + if (value == NULL) { + goto out; + } + if (value != Py_None) { + if (parse_reference_sequence_dict(&tables->reference_sequence, value) != 0) { + goto out; + } + } ret = 0; out: return ret; @@ -1573,11 +1672,47 @@ write_ragged_col(tsklwt_ragged_col_t *col, PyObject *table_dict, bool force_offs return ret; } +static int +write_string_to_dict(PyObject *dict, const char *key, const char *str, tsk_size_t length) +{ + int ret = -1; + PyObject *val = make_Py_Unicode_FromStringAndLength(str, length); + + if (val == NULL) { + goto out; + } + if (PyDict_SetItemString(dict, key, val) != 0) { + goto out; + } + ret = 0; +out: + Py_XDECREF(val); + return ret; +} + +static int +write_bytes_to_dict( + PyObject *dict, const char *key, const char *bytes, tsk_size_t length) +{ + int ret = -1; + PyObject *val = PyBytes_FromStringAndSize(bytes, length); + + if (val == NULL) { + goto out; + } + if (PyDict_SetItemString(dict, key, val) != 0) { + goto out; + } + ret = 0; +out: + Py_XDECREF(val); + return ret; +} + static PyObject * write_table_dict(const tsklwt_table_desc_t *table_desc, bool force_offset_64) { PyObject *ret = NULL; - PyObject *str = NULL; PyObject *table_dict = NULL; 
tsklwt_table_col_t *col; tsklwt_ragged_col_t *ragged_col; @@ -1602,19 +1737,15 @@ write_table_dict(const tsklwt_table_desc_t *table_desc, bool force_offset_64) } } if (table_desc->metadata_schema_length > 0) { - str = make_Py_Unicode_FromStringAndLength( - table_desc->metadata_schema, table_desc->metadata_schema_length); - if (str == NULL) { - goto out; - } - if (PyDict_SetItemString(table_dict, "metadata_schema", str) != 0) { + if (write_string_to_dict(table_dict, "metadata_schema", + table_desc->metadata_schema, table_desc->metadata_schema_length) + != 0) { goto out; } } ret = table_dict; table_dict = NULL; out: - Py_XDECREF(str); Py_XDECREF(table_dict); return ret; } @@ -1808,22 +1939,15 @@ write_table_arrays( return ret; } -/* Returns a dictionary encoding of the specified table collection */ -static PyObject * -dump_tables_dict(tsk_table_collection_t *tables, bool force_offset_64) +static int +write_top_level_data( + const tsk_table_collection_t *tables, PyObject *dict, bool force_offset_64) { - PyObject *ret = NULL; - PyObject *dict = NULL; + int ret = -1; PyObject *val = NULL; - int err; - - dict = PyDict_New(); - if (dict == NULL) { - goto out; - } /* Dict representation version */ - val = Py_BuildValue("ll", 1, 5); + val = Py_BuildValue("ll", 1, 6); if (val == NULL) { goto out; } @@ -1843,42 +1967,100 @@ dump_tables_dict(tsk_table_collection_t *tables, bool force_offset_64) Py_DECREF(val); val = NULL; + if (write_string_to_dict( + dict, "time_units", tables->time_units, tables->time_units_length) + != 0) { + goto out; + } if (tables->metadata_schema_length > 0) { - val = make_Py_Unicode_FromStringAndLength( - tables->metadata_schema, tables->metadata_schema_length); - if (val == NULL) { + if (write_string_to_dict(dict, "metadata_schema", tables->metadata_schema, + tables->metadata_schema_length) + != 0) { goto out; } - if (PyDict_SetItemString(dict, "metadata_schema", val) != 0) { + } + if (tables->metadata_length > 0) { + if (write_bytes_to_dict( + dict, 
"metadata", tables->metadata, tables->metadata_length) + != 0) { goto out; } - Py_DECREF(val); - val = NULL; } - if (tables->metadata_length > 0) { - val = PyBytes_FromStringAndSize(tables->metadata, tables->metadata_length); - if (val == NULL) { + ret = 0; +out: + Py_XDECREF(val); + return ret; +} + +static PyObject * +write_reference_sequence_dict(const tsk_reference_sequence_t *ref, bool force_offset_64) +{ + PyObject *ret = NULL; + PyObject *dict = NULL; + + dict = PyDict_New(); + if (dict == NULL) { + goto out; + } + + if (ref->metadata_schema_length > 0) { + if (write_string_to_dict(dict, "metadata_schema", ref->metadata_schema, + ref->metadata_schema_length) + != 0) { goto out; } - if (PyDict_SetItemString(dict, "metadata", val) != 0) { + } + if (ref->metadata_length > 0) { + if (write_bytes_to_dict(dict, "metadata", ref->metadata, ref->metadata_length) + != 0) { goto out; } - Py_DECREF(val); - val = NULL; } - - val = make_Py_Unicode_FromStringAndLength( - tables->time_units, tables->time_units_length); - if (val == NULL) { + if (write_string_to_dict(dict, "data", ref->data, ref->data_length) != 0) { goto out; } - if (PyDict_SetItemString(dict, "time_units", val) != 0) { + if (write_string_to_dict(dict, "url", ref->url, ref->url_length) != 0) { goto out; } - Py_DECREF(val); - val = NULL; + ret = dict; + dict = NULL; +out: + Py_XDECREF(dict); + return ret; +} + +/* Returns a dictionary encoding of the specified table collection */ +static PyObject * +dump_tables_dict(tsk_table_collection_t *tables, bool force_offset_64) +{ + PyObject *ret = NULL; + PyObject *dict = NULL; + PyObject *ref_dict = NULL; + int err; + + dict = PyDict_New(); + if (dict == NULL) { + goto out; + } + + err = write_top_level_data(tables, dict, force_offset_64); + if (err != 0) { + goto out; + } + if (tsk_table_collection_has_reference_sequence(tables)) { + ref_dict = write_reference_sequence_dict( + &tables->reference_sequence, force_offset_64); + if (ref_dict == NULL) { + goto out; + } 
+ if (PyDict_SetItemString(dict, "reference_sequence", ref_dict) != 0) { + goto out; + } + Py_DECREF(ref_dict); + ref_dict = NULL; + } err = write_table_arrays(tables, dict, force_offset_64); if (err != 0) { goto out; @@ -1887,7 +2069,7 @@ dump_tables_dict(tsk_table_collection_t *tables, bool force_offset_64) dict = NULL; out: Py_XDECREF(dict); - Py_XDECREF(val); + Py_XDECREF(ref_dict); return ret; } diff --git a/python/requirements/development.txt b/python/requirements/development.txt index bd5ac945ea..b4568f7b3d 100644 --- a/python/requirements/development.txt +++ b/python/requirements/development.txt @@ -30,6 +30,7 @@ sphinx>=4.3 sphinx-argparse sphinx-issues sphinx-jupyterbook-latex +sphinxcontrib-prettyspecialmethods pydata_sphinx_theme>=0.7.2 svgwrite>=1.1.10 xmlunittest diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 741dc3881c..8d09031ce1 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -150,7 +150,7 @@ def ts_fixture(): # Add metadata for name, table in tables.name_map.items(): if name != "provenances": - table.metadata_schema = tskit.MetadataSchema({"codec": "json"}) + table.metadata_schema = tskit.MetadataSchema.permissive_json() metadatas = [f'{{"foo":"n_{name}_{u}"}}' for u in range(len(table))] metadata, metadata_offset = tskit.pack_strings(metadatas) table.set_columns( @@ -160,10 +160,17 @@ def ts_fixture(): "metadata_offset": metadata_offset, } ) - tables.metadata_schema = tskit.MetadataSchema({"codec": "json"}) + tables.metadata_schema = tskit.MetadataSchema.permissive_json() tables.metadata = "Test metadata" tables.time_units = "Test time units" + tables.reference_sequence.metadata_schema = tskit.MetadataSchema.permissive_json() + tables.reference_sequence.metadata = "Test reference metadata" + tables.reference_sequence.data = "A" * int(ts.sequence_length) + # NOTE: it's unclear whether we'll want to have this set at the same time as + # 'data', but it's useful to have something in all columns for now. 
+ tables.reference_sequence.url = "http://example.com/a_reference" + # Add some more rows to provenance to have enough for testing. for _ in range(3): tables.provenances.add_row(record="A") diff --git a/python/tests/test_file_format.py b/python/tests/test_file_format.py index 4e361f908b..79b1195241 100644 --- a/python/tests/test_file_format.py +++ b/python/tests/test_file_format.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2020 Tskit Developers +# Copyright (c) 2018-2021 Tskit Developers # Copyright (c) 2016-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -43,7 +43,7 @@ CURRENT_FILE_MAJOR = 12 -CURRENT_FILE_MINOR = 6 +CURRENT_FILE_MINOR = 7 test_data_dir = os.path.join(os.path.dirname(__file__), "data") @@ -887,6 +887,74 @@ def test_empty_individual_parents(self): tables.assert_equals(ts3.tables) +class TestReferenceSequence: + def test_fixture_has_reference_sequence(self, ts_fixture): + assert ts_fixture.has_reference_sequence() + + def test_round_trip(self, ts_fixture, tmp_path): + ts1 = ts_fixture + temp_file = tmp_path / "tmp.trees" + ts1.dump(temp_file) + ts2 = tskit.load(temp_file) + ts1.tables.assert_equals(ts2.tables) + + def test_no_reference_sequence(self, ts_fixture, tmp_path): + ts1 = ts_fixture + temp_file = tmp_path / "tmp.trees" + ts1.dump(temp_file) + with kastore.load(temp_file) as store: + all_data = dict(store) + del all_data["reference_sequence/metadata_schema"] + del all_data["reference_sequence/metadata"] + del all_data["reference_sequence/data"] + del all_data["reference_sequence/url"] + for key in all_data.keys(): + assert not key.startswith("reference_sequence") + kastore.dump(all_data, temp_file) + ts2 = tskit.load(temp_file) + assert not ts2.has_reference_sequence() + tables = ts2.dump_tables() + tables.reference_sequence = ts1.reference_sequence + tables.assert_equals(ts1.tables) + + @pytest.mark.parametrize("attr", ["data", "url"]) + def test_missing_attr(self, 
ts_fixture, tmp_path, attr): + ts1 = ts_fixture + temp_file = tmp_path / "tmp.trees" + ts1.dump(temp_file) + with kastore.load(temp_file) as store: + all_data = dict(store) + del all_data[f"reference_sequence/{attr}"] + kastore.dump(all_data, temp_file) + ts2 = tskit.load(temp_file) + assert ts2.has_reference_sequence() + assert getattr(ts2.reference_sequence, attr) == "" + + def test_missing_metadata(self, ts_fixture, tmp_path): + ts1 = ts_fixture + temp_file = tmp_path / "tmp.trees" + ts1.dump(temp_file) + with kastore.load(temp_file) as store: + all_data = dict(store) + del all_data["reference_sequence/metadata"] + kastore.dump(all_data, temp_file) + ts2 = tskit.load(temp_file) + assert ts2.has_reference_sequence() + assert ts2.reference_sequence.metadata_bytes == b"" + + def test_missing_metadata_schema(self, ts_fixture, tmp_path): + ts1 = ts_fixture + temp_file = tmp_path / "tmp.trees" + ts1.dump(temp_file) + with kastore.load(temp_file) as store: + all_data = dict(store) + del all_data["reference_sequence/metadata_schema"] + kastore.dump(all_data, temp_file) + ts2 = tskit.load(temp_file) + assert ts2.has_reference_sequence() + assert repr(ts2.reference_sequence.metadata_schema) == "" + + class TestFileFormatErrors(TestFileFormat): """ Tests for errors in the HDF5 format. 
diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 71af7d60c7..b3c4ba048e 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -39,6 +39,9 @@ import tskit +NON_UTF8_STRING = "\ud861\udd37" + + def get_tracked_sample_counts(ts, st, tracked_samples): """ Returns a list giving the number of samples in the specified list @@ -3308,6 +3311,189 @@ def test_named_tuple_init(self): assert metadata_schemas != metadata_schemas3 +class TestReferenceSequenceInputErrors: + @pytest.mark.parametrize("bad_type", [1234, b"bytes", None, {}]) + @pytest.mark.parametrize("attr", ["data", "url", "metadata_schema"]) + def test_string_bad_type(self, attr, bad_type): + refseq = _tskit.TableCollection().reference_sequence + with pytest.raises(TypeError, match=f"{attr} must be a string"): + setattr(refseq, attr, bad_type) + + @pytest.mark.parametrize("bad_type", [1234, "unicode", None, {}]) + def test_metadata_bad_type(self, bad_type): + refseq = _tskit.TableCollection().reference_sequence + with pytest.raises(TypeError): + refseq.metadata = bad_type + + @pytest.mark.parametrize("attr", ["data", "url", "metadata_schema"]) + def test_unicode_error(self, attr): + refseq = _tskit.TableCollection().reference_sequence + with pytest.raises(UnicodeEncodeError): + setattr(refseq, attr, NON_UTF8_STRING) + + @pytest.mark.parametrize("attr", ["data", "url", "metadata", "metadata_schema"]) + def test_del_attr(self, attr): + refseq = _tskit.TableCollection().reference_sequence + with pytest.raises(AttributeError, match=f"Cannot del {attr}"): + delattr(refseq, attr) + + +class TestReferenceSequenceUpdates: + @pytest.mark.parametrize("value", ["abc", "🎄🌳🌴🌲🎋"]) + @pytest.mark.parametrize("attr", ["data", "url", "metadata_schema"]) + def test_set_string(self, attr, value): + refseq = _tskit.TableCollection().reference_sequence + assert refseq.is_null() + setattr(refseq, attr, value) + assert getattr(refseq, attr) == value + assert not refseq.is_null() 
+ + @pytest.mark.parametrize("attr", ["data", "url", "metadata_schema"]) + def test_set_string_null_none(self, attr): + refseq = _tskit.TableCollection().reference_sequence + assert refseq.is_null() + setattr(refseq, attr, "a") + assert not refseq.is_null() + setattr(refseq, attr, "") + assert refseq.is_null() + + @pytest.mark.parametrize("value", [b"x", b"{}", b"abc\0defg"]) + def test_set_metadata(self, value): + refseq = _tskit.TableCollection().reference_sequence + assert refseq.is_null() + refseq.metadata = value + assert not refseq.is_null() + refseq.metadata = b"" + assert refseq.is_null() + + +class TestReferenceSequenceTableCollection: + def test_references(self): + tables = _tskit.TableCollection() + refseq = tables.reference_sequence + assert refseq is not tables.reference_sequence + + def test_state(self): + tables = _tskit.TableCollection() + refseq = tables.reference_sequence + assert refseq.is_null() + assert not tables.has_reference_sequence() + # Setting any non empty string changes the state to "non-null" + refseq.data = "x" + assert tables.has_reference_sequence() + assert not refseq.is_null() + + @pytest.mark.parametrize("ref_data", ["abc", "A" * 10, "🎄🌳🌴🌲🎋"]) + def test_data(self, ref_data): + tables = _tskit.TableCollection() + refseq = tables.reference_sequence + assert refseq.data == "" + refseq.data = ref_data + assert refseq.data == ref_data + assert tables.reference_sequence.data == ref_data + + @pytest.mark.parametrize("url", ["", "abc", "A" * 10, "🎄🌳🌴🌲🎋"]) + def test_url(self, url): + tables = _tskit.TableCollection() + refseq = tables.reference_sequence + assert refseq.url == "" + refseq.url = url + assert refseq.url == url + assert tables.reference_sequence.url == url + + def test_metadata_default_none(self): + tables = _tskit.TableCollection() + assert tables.reference_sequence.metadata_schema == "" + assert tables.reference_sequence.metadata == b"" + + # we don't actually check the form here, just pass in and out strings + 
@pytest.mark.parametrize("schema", ["", "{}", "abcdefg"]) + def test_metadata_schema(self, schema): + tables = _tskit.TableCollection() + tables.reference_sequence.metadata_schema = schema + assert tables.has_reference_sequence + assert tables.reference_sequence.metadata_schema == schema + + @pytest.mark.parametrize("metadata", [b"", b"{}", b"abcdefg"]) + def test_metadata(self, metadata): + tables = _tskit.TableCollection() + tables.reference_sequence.metadata = metadata + assert tables.has_reference_sequence + assert tables.reference_sequence.metadata == metadata + + +class TestReferenceSequenceTreeSequence: + def test_references(self): + tc = _tskit.TableCollection() + tc.sequence_length = 1 + ts = _tskit.TreeSequence() + ts.load_tables(tc, build_indexes=True) + refseq = ts.reference_sequence + assert refseq is not ts.reference_sequence + assert refseq is not tc + + def test_state(self): + tc = _tskit.TableCollection() + tc.sequence_length = 1 + ts = _tskit.TreeSequence() + ts.load_tables(tc, build_indexes=True) + assert not ts.has_reference_sequence() + + def test_write(self): + tc = _tskit.TableCollection() + tc.sequence_length = 1 + ts = _tskit.TreeSequence() + ts.load_tables(tc, build_indexes=True) + refseq = ts.reference_sequence + with pytest.raises(AttributeError, match="read-only"): + refseq.data = "asdf" + with pytest.raises(AttributeError, match="read-only"): + refseq.url = "asdf" + with pytest.raises(AttributeError, match="read-only"): + refseq.metadata_schema = "asdf" + with pytest.raises(AttributeError, match="read-only"): + refseq.metadata = "asdf" + + @pytest.mark.parametrize("ref_data", ["", "ACTG" * 10, "🎄🌳🌴🌲🎋"]) + def test_data(self, ref_data): + tc = _tskit.TableCollection() + tc.sequence_length = 1 + tc.reference_sequence.data = ref_data + ts = _tskit.TreeSequence() + ts.load_tables(tc, build_indexes=True) + assert ts.reference_sequence.data == ref_data + + @pytest.mark.parametrize("url", ["", "ACTG" * 10, "🎄🌳🌴🌲🎋"]) + def test_url(self, url): 
+ tc = _tskit.TableCollection() + tc.sequence_length = 1 + tc.reference_sequence.url = url + ts = _tskit.TreeSequence() + ts.load_tables(tc, build_indexes=True) + assert ts.reference_sequence.url == url + + # we don't actually check the form here, just pass in and out strings + @pytest.mark.parametrize("schema", ["", "{}", "abcdefg"]) + def test_metadata_schema(self, schema): + tc = _tskit.TableCollection() + tc.sequence_length = 1 + tc.reference_sequence.metadata_schema = schema + ts = _tskit.TreeSequence() + ts.load_tables(tc, build_indexes=True) + assert ts.has_reference_sequence + assert ts.reference_sequence.metadata_schema == schema + + @pytest.mark.parametrize("metadata", [b"", b"{}", b"abcdefg"]) + def test_metadata(self, metadata): + tc = _tskit.TableCollection() + tc.sequence_length = 1 + tc.reference_sequence.metadata = metadata + ts = _tskit.TreeSequence() + ts.load_tables(tc, build_indexes=True) + assert ts.has_reference_sequence + assert ts.reference_sequence.metadata == metadata + + class TestModuleFunctions: """ Tests for the module level functions. diff --git a/python/tests/test_reference_sequence.py b/python/tests/test_reference_sequence.py new file mode 100644 index 0000000000..ec3f3fce59 --- /dev/null +++ b/python/tests/test_reference_sequence.py @@ -0,0 +1,243 @@ +# MIT License +# +# Copyright (c) 2021 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Tests for reference sequence support. +""" +import pytest + +import tskit + + +class TestTablesProperties: + def test_initially_not_set(self): + tables = tskit.TableCollection(1) + assert not tables.has_reference_sequence() + tables.reference_sequence.data = "ABCDEF" + assert tables.reference_sequence.data == "ABCDEF" + assert tables.has_reference_sequence() + + def test_does_not_have_reference_sequence_if_empty(self): + tables = tskit.TableCollection(1) + assert not tables.has_reference_sequence() + tables.reference_sequence.data = "" + assert not tables.has_reference_sequence() + + def test_same_object(self): + tables = tskit.TableCollection(1) + refseq = tables.reference_sequence + tables.reference_sequence.data = "asdf" + assert refseq.data == "asdf" + # Not clear we want to do this, but keeping the same pattern as the + # tables for now. 
+ assert tables.reference_sequence is not refseq + + def test_clear(self, ts_fixture): + tables = ts_fixture.dump_tables() + tables.reference_sequence.clear() + assert not tables.has_reference_sequence() + + def test_write_object_fails_bad_type(self): + tables = tskit.TableCollection(1) + with pytest.raises(AttributeError): + tables.reference_sequence = None + + def test_write_object(self, ts_fixture): + tables = tskit.TableCollection(1) + tables.reference_sequence = ts_fixture.reference_sequence + tables.reference_sequence.assert_equals(ts_fixture.reference_sequence) + + def test_asdict_no_reference(self): + tables = tskit.TableCollection(1) + d = tables.asdict() + assert "reference_sequence" not in d + + def test_asdict_reference_no_metadata(self): + tables = tskit.TableCollection(1) + tables.reference_sequence.data = "ABCDEF" + d = tables.asdict()["reference_sequence"] + assert d["data"] == "ABCDEF" + assert d["url"] == "" + assert "metadata" not in d + assert "metadata_schema" not in d + + def test_asdict_reference_metadata(self): + tables = tskit.TableCollection(1) + tables.reference_sequence.metadata_schema = ( + tskit.MetadataSchema.permissive_json() + ) + tables.reference_sequence.metadata = {"a": "ABCDEF"} + d = tables.asdict()["reference_sequence"] + assert d["data"] == "" + assert d["url"] == "" + assert d["metadata_schema"] == '{"codec":"json"}' + assert d["metadata"] == b'{"a":"ABCDEF"}' + + def test_fromdict_reference_data(self): + d = tskit.TableCollection(1).asdict() + d["reference_sequence"] = {"data": "XYZ"} + tables = tskit.TableCollection.fromdict(d) + assert tables.has_reference_sequence() + assert tables.reference_sequence.data == "XYZ" + assert tables.reference_sequence.url == "" + assert repr(tables.reference_sequence.metadata_schema) == "" + assert tables.reference_sequence.metadata == b"" + + def test_fromdict_reference_url(self): + d = tskit.TableCollection(1).asdict() + d["reference_sequence"] = {"url": "file://file.fasta"} + tables = 
tskit.TableCollection.fromdict(d) + assert tables.has_reference_sequence() + assert tables.reference_sequence.data == "" + assert tables.reference_sequence.url == "file://file.fasta" + assert repr(tables.reference_sequence.metadata_schema) == "" + assert tables.reference_sequence.metadata == b"" + + def test_fromdict_reference_metadata(self): + tables = tskit.TableCollection(1) + tables.reference_sequence.metadata_schema = ( + tskit.MetadataSchema.permissive_json() + ) + tables.reference_sequence.metadata = {"a": "ABCDEF"} + tables = tskit.TableCollection.fromdict(tables.asdict()) + assert tables.has_reference_sequence() + assert tables.reference_sequence.data == "" + assert ( + tables.reference_sequence.metadata_schema + == tskit.MetadataSchema.permissive_json() + ) + assert tables.reference_sequence.metadata == {"a": "ABCDEF"} + + def test_fromdict_no_reference(self): + d = tskit.TableCollection(1).asdict() + tables = tskit.TableCollection.fromdict(d) + assert not tables.has_reference_sequence() + + def test_fromdict_all_values_empty(self): + d = tskit.TableCollection(1).asdict() + d["reference_sequence"] = dict( + data="", url="", metadata_schema="", metadata=b"" + ) + tables = tskit.TableCollection.fromdict(d) + assert not tables.has_reference_sequence() + + +class TestSummaries: + def test_repr(self): + tables = tskit.TableCollection(1) + refseq = tables.reference_sequence + # TODO add better tests when summaries are updated + assert repr(refseq).startswith("ReferenceSequence") + + +class TestAssertEquals: + def test_success_self(self, ts_fixture): + ts_fixture.reference_sequence.assert_equals(ts_fixture.reference_sequence) + + def test_success_empty(self): + tables = tskit.TableCollection(1) + tables.reference_sequence.assert_equals(tables.reference_sequence) + + @pytest.mark.parametrize("attr", ["url", "data"]) + def test_fails_attr_missing(self, ts_fixture, attr): + t1 = ts_fixture.tables + d = t1.asdict() + del d["reference_sequence"][attr] + t2 = 
tskit.TableCollection.fromdict(d) + with pytest.raises(AssertionError, match=attr): + t1.reference_sequence.assert_equals(t2.reference_sequence) + with pytest.raises(AssertionError, match=attr): + t2.reference_sequence.assert_equals(t1.reference_sequence) + + def test_fails_metadata_different(self, ts_fixture): + t1 = ts_fixture.dump_tables() + t2 = t1.copy() + t1.reference_sequence.metadata = {"different": "metadata"} + with pytest.raises(AssertionError, match="metadata"): + t1.reference_sequence.assert_equals(t2.reference_sequence) + with pytest.raises(AssertionError, match="metadata"): + t2.reference_sequence.assert_equals(t1.reference_sequence) + + def test_fails_metadata_schema_different(self, ts_fixture): + t1 = ts_fixture.dump_tables() + t2 = t1.copy() + t1.reference_sequence.metadata_schema = tskit.MetadataSchema(None) + with pytest.raises(AssertionError, match="schemas"): + t1.reference_sequence.assert_equals(t2.reference_sequence) + with pytest.raises(AssertionError, match="schemas"): + t2.reference_sequence.assert_equals(t1.reference_sequence) + + +class TestTreeSequenceProperties: + @pytest.mark.parametrize("data", ["abcd", "🎄🌳🌴"]) + def test_data_inherited_from_tables(self, data): + tables = tskit.TableCollection(1) + tables.reference_sequence.data = data + ts = tables.tree_sequence() + assert ts.reference_sequence.data == data + assert ts.has_reference_sequence() + + @pytest.mark.parametrize("url", ["http://xyx.z", "file://"]) + def test_url_inherited_from_tables(self, url): + tables = tskit.TableCollection(1) + tables.reference_sequence.url = url + ts = tables.tree_sequence() + assert ts.reference_sequence.url == url + assert ts.has_reference_sequence() + + def test_no_reference_sequence(self): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + assert not ts.has_reference_sequence() + + def test_write_data_fails(self): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + with pytest.raises(AttributeError, 
match="read-only"): + ts.reference_sequence.data = "xyz" + + def test_write_url_fails(self): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + with pytest.raises(AttributeError, match="read-only"): + ts.reference_sequence.url = "xyz" + + def test_write_metadata_fails(self): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + with pytest.raises(AttributeError, match="read-only"): + # NOTE: it can be slightly confusing here because we try to encode + # first, and so we don't get an AttributeError for all inputs. + ts.reference_sequence.metadata = b"xyz" + + def test_write_metadata_schema_fails(self): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + with pytest.raises(AttributeError, match="read-only"): + ts.reference_sequence.metadata_schema = ( + tskit.MetadataSchema.permissive_json() + ) + + def test_write_object_fails(self, ts_fixture): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + with pytest.raises(AttributeError): + ts.reference_sequence = ts_fixture.reference_sequence diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index ae3e709a8d..ee753e43a8 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -3346,7 +3346,7 @@ def test_nbytes(self, tmp_path, ts_fixture): def test_asdict(self, ts_fixture): t = ts_fixture.dump_tables() d1 = { - "encoding_version": (1, 5), + "encoding_version": (1, 6), "sequence_length": t.sequence_length, "metadata_schema": repr(t.metadata_schema), "metadata": t.metadata_schema.encode_row(t.metadata), @@ -3360,6 +3360,7 @@ def test_asdict(self, ts_fixture): "migrations": t.migrations.asdict(), "provenances": t.provenances.asdict(), "indexes": t.indexes.asdict(), + "reference_sequence": t.reference_sequence.asdict(), } d2 = t.asdict() assert set(d1.keys()) == set(d2.keys()) @@ -3414,6 +3415,7 @@ def test_from_dict(self, ts_fixture): "migrations": t1.migrations.asdict(), "provenances": t1.provenances.asdict(), "indexes": 
t1.indexes.asdict(), + "reference_sequence": t1.reference_sequence.asdict(), } t2 = tskit.TableCollection.fromdict(d) t1.assert_equals(t2) diff --git a/python/tskit/metadata.py b/python/tskit/metadata.py index f172be86ee..187b439127 100644 --- a/python/tskit/metadata.py +++ b/python/tskit/metadata.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020 Tskit Developers +# Copyright (c) 2020-2021 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,8 @@ """ Classes for metadata decoding, encoding and validation """ +from __future__ import annotations + import abc import builtins import collections @@ -36,6 +38,7 @@ from typing import Mapping from typing import Optional from typing import Type +from typing import Union import jsonschema @@ -697,6 +700,14 @@ def encode_row(self, row: Any) -> bytes: # Set by __init__ pass # pragma: no cover + # Utility to make a simple permissive JSON schema. Probably should be + # part of the documented API. See + # https://github.com/tskit-dev/tskit/issues/1956 for more details. + + @staticmethod + def permissive_json(): + return MetadataSchema({"codec": "json"}) + # Often many replicate tree sequences are processed with identical schemas, so cache them @functools.lru_cache(maxsize=128) @@ -771,3 +782,60 @@ def new_init(self, *args, metadata_decoder=None, **kwargs): for k, v in sloted_members.items(): setattr(new_cls, k, v) return new_cls + + +class MetadataProvider: + """ + Abstract superclass of container objects that provide metadata. + """ + + def __init__(self, ll_object): + self._ll_object = ll_object + + @property + def metadata_schema(self) -> MetadataSchema: + """ + The :class:`tskit.MetadataSchema` for this object. 
+ """ + return parse_metadata_schema(self._ll_object.metadata_schema) + + @metadata_schema.setter + def metadata_schema(self, schema: MetadataSchema) -> None: + # Check the schema is a valid schema instance by roundtripping it. + text_version = repr(schema) + parse_metadata_schema(text_version) + self._ll_object.metadata_schema = text_version + + @property + def metadata(self) -> Any: + """ + The decoded metadata for this object. + """ + return self.metadata_schema.decode_row(self.metadata_bytes) + + @metadata.setter + def metadata(self, metadata: Optional[Union[bytes, dict]]) -> None: + encoded = self.metadata_schema.validate_and_encode_row(metadata) + self._ll_object.metadata = encoded + + @property + def metadata_bytes(self) -> Any: + """ + The raw bytes of metadata for this TableCollection + """ + return self._ll_object.metadata + + @property + def nbytes(self) -> int: + return len(self._ll_object.metadata) + len(self._ll_object.metadata_schema) + + def assert_equals(self, other: MetadataProvider): + if self.metadata_schema != other.metadata_schema: + raise AssertionError( + f"Metadata schemas differ: self={self.metadata_schema} " + f"other={other.metadata_schema}" + ) + if self.metadata != other.metadata: + raise AssertionError( + f"Metadata differs: self={self.metadata} " f"other={other.metadata}" + ) diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 15d3c7a172..dc982a565d 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -34,7 +34,6 @@ from collections.abc import Mapping from dataclasses import dataclass from functools import reduce -from typing import Any from typing import Dict from typing import Optional from typing import Union @@ -547,11 +546,17 @@ def _repr_html_(self): """ -class MetadataMixin: +class MetadataColumnMixin: """ Mixin class for tables that have a metadata column. """ + # TODO this class has some overlap with the MetadataProvider base class + # and also the TreeSequence class. 
These all have methods to deal with + # schemas and essentially do the same thing (provide a facade for the + # low-level get/set metadata schemas functionality). We should refactor + # this so we're only doing it in one place. + # https://github.com/tskit-dev/tskit/issues/1957 def __init__(self): base_row_class = self.row_class @@ -647,7 +652,7 @@ def getter(d, k): return out -class IndividualTable(BaseTable, MetadataMixin): +class IndividualTable(BaseTable, MetadataColumnMixin): """ A table defining the individuals in a tree sequence. Note that although each Individual has associated nodes, reference to these is not stored in @@ -905,7 +910,7 @@ def packset_parents(self, parents): self.set_columns(**d) -class NodeTable(BaseTable, MetadataMixin): +class NodeTable(BaseTable, MetadataColumnMixin): """ A table defining the nodes in a tree sequence. See the :ref:`definitions ` for details on the columns @@ -1103,7 +1108,7 @@ def append_columns( ) -class EdgeTable(BaseTable, MetadataMixin): +class EdgeTable(BaseTable, MetadataColumnMixin): """ A table defining the edges in a tree sequence. See the :ref:`definitions ` for details on the columns @@ -1308,7 +1313,7 @@ def squash(self): self.ll_table.squash() -class MigrationTable(BaseTable, MetadataMixin): +class MigrationTable(BaseTable, MetadataColumnMixin): """ A table defining the migrations in a tree sequence. See the :ref:`definitions ` for details on the columns @@ -1530,7 +1535,7 @@ def append_columns( ) -class SiteTable(BaseTable, MetadataMixin): +class SiteTable(BaseTable, MetadataColumnMixin): """ A table defining the sites in a tree sequence. See the :ref:`definitions ` for details on the columns @@ -1742,7 +1747,7 @@ def packset_ancestral_state(self, ancestral_states): self.set_columns(**d) -class MutationTable(BaseTable, MetadataMixin): +class MutationTable(BaseTable, MetadataColumnMixin): """ A table defining the mutations in a tree sequence. 
See the :ref:`definitions ` for details on the columns @@ -2004,7 +2009,7 @@ def packset_derived_state(self, derived_states): self.set_columns(**d) -class PopulationTable(BaseTable, MetadataMixin): +class PopulationTable(BaseTable, MetadataColumnMixin): """ A table defining the populations referred to in a tree sequence. The PopulationTable stores metadata for populations that may be referred to @@ -2493,7 +2498,75 @@ def __len__(self): return self.num_pairs -class TableCollection: +# TODO move to reference_sequence.py when we start adding more functionality. +class ReferenceSequence(metadata.MetadataProvider): + def __init__(self, ll_reference_sequence): + super().__init__(ll_reference_sequence) + self._ll_reference_sequence = ll_reference_sequence + + def is_null(self) -> bool: + return bool(self._ll_reference_sequence.is_null()) + + def clear(self): + self.data = "" + self.url = "" + self.metadata_schema = tskit.MetadataSchema(None) + self.metadata = b"" + + # https://github.com/tskit-dev/tskit/issues/1984 + # TODO add a __str__ method + # TODO add a _repr_html_ + # FIXME This is a shortcut, we want to put the values in explicitly + # here to get more control over how they are displayed. 
+ def __repr__(self): + return f"ReferenceSequence({repr(self.asdict())})" + + @property + def data(self) -> str: + return self._ll_reference_sequence.data + + @data.setter + def data(self, value): + self._ll_reference_sequence.data = value + + @property + def url(self) -> str: + return self._ll_reference_sequence.url + + @url.setter + def url(self, value): + self._ll_reference_sequence.url = value + + def asdict(self) -> dict: + return { + "metadata_schema": repr(self.metadata_schema), + "metadata": self.metadata_bytes, + "data": self.data, + "url": self.url, + } + + def assert_equals(self, other, ignore_metadata=False): + if not ignore_metadata: + super().assert_equals(other) + + if self.data != other.data: + raise AssertionError( + f"Reference sequence data differs: self={self.data} " + f"other={other.data}" + ) + if self.url != other.url: + raise AssertionError( + f"Reference sequence url differs: self={self.url} " f"other={other.url}" + ) + + @property + def nbytes(self): + # TODO this will be inefficient when we work with large references. + # Make a dedicated low-level method for getting the length of data. + return super().nbytes + len(self.url) + len(self.data) + + +class TableCollection(metadata.MetadataProvider): """ A collection of mutable tables defining a tree sequence. See the :ref:`sec_data_model` section for definition on the various tables @@ -2509,6 +2582,7 @@ class TableCollection: def __init__(self, sequence_length=0): self._ll_tables = _tskit.TableCollection(sequence_length) + super().__init__(self._ll_tables) @property def individuals(self) -> IndividualTable: @@ -2597,39 +2671,6 @@ def file_uuid(self) -> str: """ return self._ll_tables.file_uuid - @property - def metadata_schema(self) -> metadata.MetadataSchema: - """ - The :class:`tskit.MetadataSchema` for this TableCollection. 
- """ - return metadata.parse_metadata_schema(self._ll_tables.metadata_schema) - - @metadata_schema.setter - def metadata_schema(self, schema: metadata.MetadataSchema) -> None: - # Check the schema is a valid schema instance by roundtripping it. - metadata.parse_metadata_schema(repr(schema)) - self._ll_tables.metadata_schema = repr(schema) - - @property - def metadata(self) -> Any: - """ - The decoded metadata for this TableCollection. - """ - return self.metadata_schema.decode_row(self._ll_tables.metadata) - - @metadata.setter - def metadata(self, metadata: Optional[Union[bytes, dict]]) -> None: - self._ll_tables.metadata = self.metadata_schema.validate_and_encode_row( - metadata - ) - - @property - def metadata_bytes(self) -> Any: - """ - The raw bytes of metadata for this TableCollection - """ - return self._ll_tables.metadata - @property def time_units(self) -> str: """ @@ -2641,6 +2682,28 @@ def time_units(self) -> str: def time_units(self, time_units: str) -> None: self._ll_tables.time_units = time_units + def has_reference_sequence(self): + """ + Returns True if this TableCollection has an associated reference + sequence. + """ + return bool(self._ll_tables.has_reference_sequence()) + + @property + def reference_sequence(self): + # NOTE: arguably we should cache the reference to this object + # during init, rather than creating a new instance each time. + # However, following the pattern of the Table classes for now + # for consistency. 
+ return ReferenceSequence(self._ll_tables.reference_sequence) + + @reference_sequence.setter + def reference_sequence(self, value: ReferenceSequence): + self.reference_sequence.metadata_schema = value.metadata_schema + self.reference_sequence.metadata = value.metadata + self.reference_sequence.data = value.data + self.reference_sequence.url = value.url + def asdict(self, force_offset_64=False): """ Returns the nested dictionary representation of this TableCollection @@ -2658,6 +2721,9 @@ def asdict(self, force_offset_64=False): """ return self._ll_tables.asdict(force_offset_64) + # TODO rename this to "table_name_map" to resolve the issue with whether + # we should regard ReferenceSequence as being in it or not. + # https://github.com/tskit-dev/tskit/issues/1981 @property def name_map(self) -> Dict: """ @@ -2686,10 +2752,10 @@ def nbytes(self) -> int: return sum( ( 8, # sequence_length takes 8 bytes - len(self.metadata_bytes), + super().nbytes, # metadata len(self.time_units.encode()), - len(repr(self.metadata_schema).encode()), self.indexes.nbytes, + self.reference_sequence.nbytes, sum(table.nbytes for table in self.name_map.values()), ) ) @@ -2824,23 +2890,19 @@ def assert_equals( ): return + if not ignore_metadata or ignore_ts_metadata: + super().assert_equals(other) + + self.reference_sequence.assert_equals( + other.reference_sequence, ignore_metadata=ignore_metadata + ) + if self.time_units != other.time_units: raise AssertionError( f"Time units differs: self={self.time_units} " f"other={other.time_units}" ) - if not ignore_metadata or ignore_ts_metadata: - if self.metadata_schema != other.metadata_schema: - raise AssertionError( - f"Metadata schemas differ: self={self.metadata_schema} " - f"other={other.metadata_schema}" - ) - if self.metadata != other.metadata: - raise AssertionError( - f"Metadata differs: self={self.metadata} " f"other={other.metadata}" - ) - if self.sequence_length != other.sequence_length: raise AssertionError( f"Sequence Length" diff 
--git a/python/tskit/trees.py b/python/tskit/trees.py index cabb312adb..69d78ef2c1 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -3690,6 +3690,17 @@ def dump(self, file_or_path, zlib_compression=False): if local_file: file.close() + @property + def reference_sequence(self): + return tables.ReferenceSequence(self._ll_tree_sequence.reference_sequence) + + def has_reference_sequence(self): + """ + Returns True if this TreeSequence has an associated reference + sequence. + """ + return bool(self._ll_tree_sequence.has_reference_sequence()) + @property def tables_dict(self): """