diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 2988a4693a..0c082cb53f 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -4122,6 +4122,9 @@ test_sort_tables_canonical_errors(void) ret = tsk_table_collection_init(&tables, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); tables.sequence_length = 1; + tsk_id_t null_p[] = { -1 }; + tsk_id_t zero_p[] = { 0 }; + tsk_id_t one_p[] = { 1 }; ret = tsk_node_table_add_row(&tables.nodes, 0, 0.0, TSK_NULL, TSK_NULL, NULL, 0); CU_ASSERT_FATAL(ret >= 0); @@ -4140,7 +4143,7 @@ test_sort_tables_canonical_errors(void) CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MUTATION_PARENT_INCONSISTENT); ret = tsk_mutation_table_clear(&tables.mutations); - CU_ASSERT_FATAL(ret >= 0); + CU_ASSERT_FATAL(ret == 0); ret = tsk_mutation_table_add_row(&tables.mutations, 0, 0, 2, 0.0, "a", 1, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_mutation_table_add_row(&tables.mutations, 0, 0, 3, 0.0, "b", 1, NULL, 0); @@ -4153,6 +4156,44 @@ test_sort_tables_canonical_errors(void) ret = tsk_table_collection_canonicalise(&tables, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_node_table_add_row(&tables.nodes, 0, 0.0, TSK_NULL, 0, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_node_table_add_row(&tables.nodes, 0, 0.0, TSK_NULL, 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_individual_table_add_row( + &tables.individuals, 0, NULL, 0, one_p, 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_individual_table_add_row( + &tables.individuals, 0, NULL, 0, zero_p, 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + + ret = tsk_table_collection_canonicalise(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INDIVIDUAL_PARENT_CYCLE); + + ret = tsk_individual_table_clear(&tables.individuals); + CU_ASSERT_FATAL(ret == 0); + ret = tsk_individual_table_add_row( + &tables.individuals, 0, NULL, 0, zero_p, 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_individual_table_add_row( + &tables.individuals, 0, NULL, 0, zero_p, 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + + ret = tsk_table_collection_canonicalise(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INDIVIDUAL_SELF_PARENT); + + ret = tsk_individual_table_clear(&tables.individuals); + CU_ASSERT_FATAL(ret == 0); + ret = tsk_individual_table_add_row( + &tables.individuals, 0, NULL, 0, null_p, 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + ret = tsk_individual_table_add_row( + &tables.individuals, 0, NULL, 0, zero_p, 1, NULL, 0); + CU_ASSERT_FATAL(ret >= 0); + + ret = tsk_table_collection_canonicalise(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_table_collection_free(&tables); } @@ -4169,10 +4210,10 @@ test_sort_tables_canonical(void) "0 1 2 -1\n" "0 2 -1 2\n" "0 3 -1 -1\n"; - const char *individuals = "0 0.0\n" - "0 1.0\n" - "0 2.0\n" - "0 3.0\n"; + const char *individuals = "0 0.0 1\n" + "0 1.0 -1\n" + "0 2.0 1,3\n" + "0 3.0 -1,1\n"; const char *sites = "0 0\n" "0.2 0\n" "0.1 0\n"; @@ -4192,9 +4233,9 @@ test_sort_tables_canonical(void) "0 1 0 -1\n" "0 2 -1 2\n" "0 3 -1 -1\n"; - const char *individuals_sorted = "0 1.0\n" - "0 3.0\n" - "0 2.0\n"; + const char *individuals_sorted = "0 1.0 -1\n" + "0 3.0 -1,0\n" + "0 2.0 0,1\n"; const char *sites_sorted = "0 0\n" "0.1 0\n" "0.2 0\n"; @@ -4207,10 +4248,10 @@ test_sort_tables_canonical(void) "2 1 4 4 0.5\n" "2 1 5 6 0.5\n" "2 1 6 6 0.5\n"; - const char *individuals_sorted_kept = "0 1.0\n" - "0 3.0\n" - "0 2.0\n" - "0 0.0\n"; + const char *individuals_sorted_kept = "0 1.0 -1\n" + "0 3.0 -1,0\n" + "0 2.0 0,1\n" + "0 0.0 0\n"; ret = tsk_table_collection_init(&t1, 0); 
CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -5338,6 +5379,8 @@ test_table_collection_subset_with_options(tsk_flags_t options) tsk_table_collection_t tables_copy; int k; tsk_id_t nodes[4]; + tsk_id_t zero_p[] = { 0 }; + tsk_id_t one_p[] = { 1 }; ret = tsk_table_collection_init(&tables, options); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -5360,15 +5403,18 @@ test_table_collection_subset_with_options(tsk_flags_t options) ret = tsk_node_table_add_row( &tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, 1, NULL, 0); CU_ASSERT_FATAL(ret >= 0); + // unused individual who is the parent of others ret = tsk_individual_table_add_row( &tables.individuals, 0, NULL, 0, NULL, 0, NULL, 0); + ret = tsk_individual_table_add_row( + &tables.individuals, 0, NULL, 0, zero_p, 1, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_individual_table_add_row( - &tables.individuals, 0, NULL, 0, NULL, 0, NULL, 0); + &tables.individuals, 0, NULL, 0, one_p, 1, NULL, 0); CU_ASSERT_FATAL(ret >= 0); // unused individual ret = tsk_individual_table_add_row( - &tables.individuals, 0, NULL, 0, NULL, 0, NULL, 0); + &tables.individuals, 0, NULL, 0, one_p, 1, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_population_table_add_row(&tables.populations, NULL, 0); CU_ASSERT_FATAL(ret >= 0); @@ -5425,7 +5471,8 @@ test_table_collection_subset_with_options(tsk_flags_t options) ret = tsk_table_collection_subset(&tables_copy, NULL, 0, TSK_KEEP_UNREFERENCED); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(tables_copy.nodes.num_rows, 0); - CU_ASSERT_EQUAL_FATAL(tables_copy.individuals.num_rows, 3); + CU_ASSERT_FATAL( + tsk_individual_table_equals(&tables.individuals, &tables_copy.individuals, 0)); CU_ASSERT_EQUAL_FATAL(tables_copy.populations.num_rows, 2); CU_ASSERT_EQUAL_FATAL(tables_copy.mutations.num_rows, 0); CU_ASSERT_FATAL(tsk_site_table_equals(&tables.sites, &tables_copy.sites, 0)); @@ -5437,13 +5484,14 @@ test_table_collection_subset_with_options(tsk_flags_t options) &tables_copy, NULL, 0, TSK_KEEP_UNREFERENCED | TSK_NO_CHANGE_POPULATIONS); CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL_FATAL(tables_copy.nodes.num_rows, 0); - CU_ASSERT_EQUAL_FATAL(tables_copy.individuals.num_rows, 3); + CU_ASSERT_FATAL( + tsk_individual_table_equals(&tables.individuals, &tables_copy.individuals, 0)); CU_ASSERT_EQUAL_FATAL(tables_copy.mutations.num_rows, 0); CU_ASSERT_FATAL( tsk_population_table_equals(&tables.populations, &tables_copy.populations, 0)); CU_ASSERT_FATAL(tsk_site_table_equals(&tables.sites, &tables_copy.sites, 0)); - // the identity transformation, since unused inds/pops are at the end + // the identity transformation, since unused pops are at the end for (k = 0; k < 4; k++) { nodes[k] = k; } @@ -5458,6 +5506,8 @@ test_table_collection_subset_with_options(tsk_flags_t options) CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_table_collection_subset(&tables_copy, nodes, 4, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_check_integrity(&tables_copy, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_FATAL(tsk_node_table_equals(&tables.nodes, &tables_copy.nodes, 0)); CU_ASSERT_EQUAL_FATAL(tables_copy.individuals.num_rows, 2); CU_ASSERT_EQUAL_FATAL(tables_copy.populations.num_rows, 1); @@ -5465,7 +5515,7 @@ test_table_collection_subset_with_options(tsk_flags_t options) CU_ASSERT_FATAL( tsk_mutation_table_equals(&tables.mutations, &tables_copy.mutations, 0)); - // reverse twice should get back to the start, since unused inds/pops are at the end + // reverse twice should get back to the start, since unused pops are at the end for (k = 0; k < 4; k++) { 
nodes[k] = 3 - k; } @@ -5475,6 +5525,8 @@ test_table_collection_subset_with_options(tsk_flags_t options) CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_table_collection_subset(&tables_copy, nodes, 4, TSK_KEEP_UNREFERENCED); CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_table_collection_check_integrity(&tables_copy, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_FATAL(tsk_table_collection_equals(&tables, &tables_copy, 0)); tsk_table_collection_free(&tables_copy); @@ -5496,6 +5548,7 @@ test_table_collection_subset_unsorted(void) tsk_table_collection_t tables_copy; int k; tsk_id_t nodes[3]; + tsk_id_t one_p[] = { 1 }; ret = tsk_table_collection_init(&tables, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -5508,10 +5561,14 @@ test_table_collection_subset_unsorted(void) &tables.nodes, TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_node_table_add_row( - &tables.nodes, TSK_NODE_IS_SAMPLE, 0.5, TSK_NULL, TSK_NULL, NULL, 0); + &tables.nodes, TSK_NODE_IS_SAMPLE, 0.5, TSK_NULL, 1, NULL, 0); CU_ASSERT_FATAL(ret >= 0); - ret = tsk_node_table_add_row(&tables.nodes, 0, 1.0, TSK_NULL, TSK_NULL, NULL, 0); + ret = tsk_node_table_add_row(&tables.nodes, 0, 1.0, TSK_NULL, 0, NULL, 0); CU_ASSERT_FATAL(ret >= 0); + ret = tsk_individual_table_add_row( + &tables.individuals, 0, NULL, 0, one_p, 1, NULL, 0); + ret = tsk_individual_table_add_row( + &tables.individuals, 0, NULL, 0, NULL, 0, NULL, 0); ret = tsk_edge_table_add_row(&tables.edges, 0.0, 0.5, 2, 1, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_edge_table_add_row(&tables.edges, 0.0, 1.0, 1, 0, NULL, 0); diff --git a/c/tskit/tables.c b/c/tskit/tables.c index cd42f6f41e..0fd544af79 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -4548,6 +4548,12 @@ typedef struct { int num_descendants; } mutation_canonical_sort_t; +typedef struct { + tsk_individual_t ind; + tsk_id_t first_node; + tsk_size_t num_descendants; +} individual_canonical_sort_t; + typedef struct { double left; double right; @@ -4620,6 +4626,22 @@ cmp_mutation_canonical(const void *a, const void *b) return ret; } +static int +cmp_individual_canonical(const void *a, const void *b) +{ + const individual_canonical_sort_t *ia = (const individual_canonical_sort_t *) a; + const individual_canonical_sort_t *ib = (const individual_canonical_sort_t *) b; + int ret = (ia->num_descendants < ib->num_descendants) + - (ia->num_descendants > ib->num_descendants); + if (ret == 0) { + ret = (ia->first_node > ib->first_node) - (ia->first_node < ib->first_node); + } + if (ret == 0) { + ret = (ia->ind.id > ib->ind.id) - (ia->ind.id < ib->ind.id); + } + return ret; +} + static int cmp_edge(const void *a, const void *b) { @@ -4960,66 +4982,63 @@ tsk_table_sorter_sort_mutations_canonical(tsk_table_sorter_t *self) } static int -tsk_table_sorter_sort_individuals(tsk_table_sorter_t *self) +tsk_individual_table_topological_sort( + tsk_individual_table_t *self, tsk_id_t *traversal_order, tsk_size_t *num_descendants) { int ret = 0; - tsk_id_t i; - tsk_individual_table_t copy; + tsk_id_t i, j, p; tsk_individual_t individual; - tsk_individual_table_t *individuals = &self->tables->individuals; - tsk_node_table_t *nodes = &self->tables->nodes; - tsk_size_t num_individuals = individuals->num_rows; + tsk_size_t num_individuals = self->num_rows; tsk_size_t current_todo = 0; tsk_size_t todo_insertion_point = 0; tsk_size_t *incoming_edge_count = malloc(num_individuals * sizeof(*incoming_edge_count)); - tsk_id_t *individuals_todo - = malloc((num_individuals + 1) * sizeof(*individuals_todo)); - tsk_id_t 
*new_id_map = malloc(num_individuals * sizeof(*new_id_map)); + bool count_descendants = (num_descendants != NULL); - ret = tsk_individual_table_copy(individuals, ©, 0); - if (ret != 0) { - goto out; - } - - if (incoming_edge_count == NULL || individuals_todo == NULL || new_id_map == NULL) { + if (incoming_edge_count == NULL) { ret = TSK_ERR_NO_MEMORY; goto out; } for (i = 0; i < (tsk_id_t) num_individuals; i++) { incoming_edge_count[i] = 0; - individuals_todo[i] = TSK_NULL; - new_id_map[i] = TSK_NULL; + traversal_order[i] = TSK_NULL; + if (count_descendants) { + num_descendants[i] = 0; + } } - individuals_todo[num_individuals] = TSK_NULL; /* Sentinel value */ /* First find the set of individuals that have no children by creating * an array of incoming edge counts */ - for (i = 0; i < (tsk_id_t) individuals->parents_length; i++) { - if (individuals->parents[i] != TSK_NULL) { - incoming_edge_count[individuals->parents[i]]++; + for (i = 0; i < (tsk_id_t) self->parents_length; i++) { + if (self->parents[i] != TSK_NULL) { + incoming_edge_count[self->parents[i]]++; } } /* Use these as the starting points for checking all individuals, * doing this in reverse makes the sort stable */ for (i = (tsk_id_t) num_individuals - 1; i >= 0; i--) { if (incoming_edge_count[i] == 0) { - individuals_todo[todo_insertion_point] = i; + traversal_order[todo_insertion_point] = i; todo_insertion_point++; } } - /* Now emit individuals from the set that have no children, removing their edges - * as we go adding new individuals to the no children set. */ - while (individuals_todo[current_todo] != TSK_NULL) { - tsk_individual_table_get_row_unsafe( - individuals, individuals_todo[current_todo], &individual); + /* Now process individuals from the set that have no children, updating their + * parents' information as we go, and adding their parents to the list if + * this was their last child */ + while (current_todo < todo_insertion_point) { + j = traversal_order[current_todo]; + tsk_individual_table_get_row_unsafe(self, j, &individual); for (i = 0; i < (tsk_id_t) individual.parents_length; i++) { - if (individual.parents[i] != TSK_NULL) { - incoming_edge_count[individual.parents[i]]--; - if (incoming_edge_count[individual.parents[i]] == 0) { - individuals_todo[todo_insertion_point] = individual.parents[i]; + p = individual.parents[i]; + if (p != TSK_NULL) { + incoming_edge_count[p]--; + if (count_descendants) { + num_descendants[p] += 1 + num_descendants[j]; + } + if (incoming_edge_count[p] == 0) { + traversal_order[todo_insertion_point] = p; todo_insertion_point++; } } @@ -5035,21 +5054,55 @@ tsk_table_sorter_sort_individuals(tsk_table_sorter_t *self) } } +out: + tsk_safe_free(incoming_edge_count); + return ret; +} + +static int +tsk_table_sorter_sort_individuals(tsk_table_sorter_t *self) +{ + int ret = 0; + tsk_id_t i; + tsk_individual_table_t copy; + tsk_individual_t individual; + tsk_individual_table_t *individuals = &self->tables->individuals; + tsk_node_table_t *nodes = &self->tables->nodes; + tsk_size_t num_individuals = individuals->num_rows; + tsk_id_t *traversal_order = malloc(num_individuals * sizeof(*traversal_order)); + tsk_id_t *new_id_map = malloc(num_individuals * sizeof(*new_id_map)); + + if (new_id_map == NULL || traversal_order == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + memset(new_id_map, 0xff, num_individuals * sizeof(*new_id_map)); + + ret = tsk_individual_table_copy(individuals, ©, 0); + if (ret != 0) { + goto out; + } + ret = tsk_individual_table_clear(individuals); if (ret != 0) { goto 
out; } + ret = tsk_individual_table_topological_sort(©, traversal_order, NULL); + if (ret != 0) { + goto out; + } + /* The sorted individuals are in reverse order */ for (i = (tsk_id_t) num_individuals - 1; i >= 0; i--) { - tsk_individual_table_get_row_unsafe(©, individuals_todo[i], &individual); + tsk_individual_table_get_row_unsafe(©, traversal_order[i], &individual); ret = tsk_individual_table_add_row(individuals, individual.flags, individual.location, individual.location_length, individual.parents, individual.parents_length, individual.metadata, individual.metadata_length); if (ret < 0) { goto out; } - new_id_map[individuals_todo[i]] = ret; + new_id_map[traversal_order[i]] = ret; } /* Rewrite the parent ids */ @@ -5067,13 +5120,109 @@ tsk_table_sorter_sort_individuals(tsk_table_sorter_t *self) ret = 0; out: - tsk_safe_free(incoming_edge_count); - tsk_safe_free(individuals_todo); + tsk_safe_free(traversal_order); tsk_safe_free(new_id_map); tsk_individual_table_free(©); return ret; } +static int +tsk_table_sorter_sort_individuals_canonical(tsk_table_sorter_t *self) +{ + int ret = 0; + tsk_id_t i, j, parent, mapped_parent; + tsk_individual_table_t *individuals = &self->tables->individuals; + tsk_node_table_t *nodes = &self->tables->nodes; + tsk_individual_table_t copy; + tsk_size_t num_individuals = individuals->num_rows; + individual_canonical_sort_t *sorted_individuals + = malloc(num_individuals * sizeof(*sorted_individuals)); + tsk_id_t *individual_id_map = malloc(num_individuals * sizeof(*individual_id_map)); + tsk_size_t *num_descendants = malloc(num_individuals * sizeof(*num_descendants)); + tsk_id_t *traversal_order = malloc(num_individuals * sizeof(*traversal_order)); + + if (individual_id_map == NULL || sorted_individuals == NULL + || traversal_order == NULL || num_descendants == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + ret = tsk_individual_table_copy(individuals, ©, 0); + if (ret != 0) { + goto out; + } + ret = tsk_individual_table_clear(individuals); + if (ret != 0) { + goto out; + } + + ret = tsk_individual_table_topological_sort(©, traversal_order, num_descendants); + if (ret != 0) { + goto out; + } + + for (i = 0; i < (tsk_id_t) num_individuals; i++) { + sorted_individuals[i].num_descendants = num_descendants[i]; + sorted_individuals[i].first_node = (tsk_id_t) nodes->num_rows; + } + + /* find first referring node */ + for (j = 0; j < (tsk_id_t) nodes->num_rows; j++) { + if (nodes->individual[j] != TSK_NULL) { + sorted_individuals[nodes->individual[j]].first_node + = TSK_MIN(j, sorted_individuals[nodes->individual[j]].first_node); + } + } + + for (j = 0; j < (tsk_id_t) num_individuals; j++) { + tsk_individual_table_get_row_unsafe( + ©, (tsk_id_t) j, &sorted_individuals[j].ind); + } + + qsort(sorted_individuals, num_individuals, sizeof(*sorted_individuals), + cmp_individual_canonical); + + /* Make a first pass through the sorted individuals to build the ID map. 
*/ + for (j = 0; j < (tsk_id_t) num_individuals; j++) { + individual_id_map[sorted_individuals[j].ind.id] = (tsk_id_t) j; + } + + for (i = 0; i < (tsk_id_t) num_individuals; i++) { + for (j = 0; j < (tsk_id_t) sorted_individuals[i].ind.parents_length; j++) { + parent = sorted_individuals[i].ind.parents[j]; + if (parent != TSK_NULL) { + mapped_parent = individual_id_map[parent]; + sorted_individuals[i].ind.parents[j] = mapped_parent; + } + } + ret = tsk_individual_table_add_row(individuals, sorted_individuals[i].ind.flags, + sorted_individuals[i].ind.location, + sorted_individuals[i].ind.location_length, sorted_individuals[i].ind.parents, + sorted_individuals[i].ind.parents_length, sorted_individuals[i].ind.metadata, + sorted_individuals[i].ind.metadata_length); + if (ret < 0) { + goto out; + } + } + ret = 0; + + /* remap individuals in the node table */ + for (i = 0; i < (tsk_id_t) nodes->num_rows; i++) { + j = nodes->individual[i]; + if (j != TSK_NULL) { + nodes->individual[i] = individual_id_map[j]; + } + } + +out: + tsk_safe_free(sorted_individuals); + tsk_safe_free(individual_id_map); + tsk_safe_free(traversal_order); + tsk_safe_free(num_descendants); + tsk_individual_table_free(©); + return ret; +} + int tsk_table_sorter_run(tsk_table_sorter_t *self, const tsk_bookmark_t *start) { @@ -5144,7 +5293,7 @@ tsk_table_sorter_run(tsk_table_sorter_t *self, const tsk_bookmark_t *start) } } if (!skip_individuals) { - ret = tsk_table_sorter_sort_individuals(self); + ret = self->sort_individuals(self); if (ret != 0) { goto out; } @@ -5177,6 +5326,7 @@ tsk_table_sorter_init( /* Set the sort_edges and sort_mutations methods to the default. */ self->sort_edges = tsk_table_sorter_sort_edges; self->sort_mutations = tsk_table_sorter_sort_mutations; + self->sort_individuals = tsk_table_sorter_sort_individuals; out: return ret; } @@ -9191,6 +9341,7 @@ tsk_table_collection_canonicalise(tsk_table_collection_t *self, tsk_flags_t opti goto out; } sorter.sort_mutations = tsk_table_sorter_sort_mutations_canonical; + sorter.sort_individuals = tsk_table_sorter_sort_individuals_canonical; nodes = malloc(self->nodes.num_rows * sizeof(*nodes)); if (nodes == NULL) { @@ -9719,7 +9870,9 @@ tsk_table_collection_subset(tsk_table_collection_t *self, const tsk_id_t *nodes, tsk_size_t num_nodes, tsk_flags_t options) { int ret = 0; - tsk_id_t j, k, new_parent, new_child, new_node, site_id; + tsk_id_t j, k, parent_ind, new_parent, new_child, new_node, site_id; + tsk_size_t num_parents; + tsk_individual_t ind; tsk_edge_t edge; tsk_id_t *node_map = NULL; tsk_id_t *individual_map = NULL; @@ -9727,7 +9880,6 @@ tsk_table_collection_subset(tsk_table_collection_t *self, const tsk_id_t *nodes, tsk_id_t *site_map = NULL; tsk_id_t *mutation_map = NULL; tsk_table_collection_t tables; - tsk_individual_t ind; tsk_population_t pop; tsk_site_t site; tsk_mutation_t mut; @@ -9774,7 +9926,62 @@ tsk_table_collection_subset(tsk_table_collection_t *self, const tsk_id_t *nodes, } } - // Nodes, individuals, populations + // First do individuals so they stay in the same order. + // So we can remap individual parents and not rely on sortedness, + // we first check who to keep; then build the individual map, and + // finally populate the tables. 
+ if (keep_unreferenced) { + for (k = 0; k < (tsk_id_t) tables.individuals.num_rows; k++) { + // put a non-NULL value here; fill in the actual order next + individual_map[k] = 0; + } + } else { + for (k = 0; k < (tsk_id_t) num_nodes; k++) { + if (nodes[k] < 0 || nodes[k] >= (tsk_id_t) tables.nodes.num_rows) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + j = tables.nodes.individual[nodes[k]]; + if (j != TSK_NULL) { + individual_map[j] = 0; + } + } + } + j = 0; + for (k = 0; k < (tsk_id_t) tables.individuals.num_rows; k++) { + if (individual_map[k] != TSK_NULL) { + individual_map[k] = j; + j++; + } + } + for (k = 0; k < (tsk_id_t) tables.individuals.num_rows; k++) { + if (individual_map[k] != TSK_NULL) { + tsk_individual_table_get_row_unsafe(&tables.individuals, k, &ind); + num_parents = 0; + for (j = 0; j < (tsk_id_t) ind.parents_length; j++) { + parent_ind = ind.parents[j]; + new_parent = parent_ind; + if (parent_ind != TSK_NULL) { + new_parent = individual_map[parent_ind]; + } + if ((parent_ind == TSK_NULL) || (new_parent != TSK_NULL)) { + /* Beware: this modifies the parents column of tables.individuals + * in-place! But it's OK as we don't use it again. */ + ind.parents[num_parents] = new_parent; + num_parents++; + } + } + ret = tsk_individual_table_add_row(&self->individuals, ind.flags, + ind.location, ind.location_length, ind.parents, num_parents, + ind.metadata, ind.metadata_length); + if (ret < 0) { + goto out; + } + tsk_bug_assert(individual_map[k] == ret); + } + } + + // Nodes and populations for (k = 0; k < (tsk_id_t) num_nodes; k++) { ret = tsk_table_collection_add_and_remap_node( self, &tables, nodes[k], individual_map, population_map, node_map, true); @@ -9792,18 +9999,7 @@ tsk_table_collection_subset(tsk_table_collection_t *self, const tsk_id_t *nodes, } if (keep_unreferenced) { - // Keep unused individuals and populations - for (k = 0; k < (tsk_id_t) tables.individuals.num_rows; k++) { - if (individual_map[k] == TSK_NULL) { - tsk_individual_table_get_row_unsafe(&tables.individuals, k, &ind); - ret = tsk_individual_table_add_row(&self->individuals, ind.flags, - ind.location, ind.location_length, ind.parents, ind.parents_length, - ind.metadata, ind.metadata_length); - if (ret < 0) { - goto out; - } - } - } + // Keep unused populations for (k = 0; k < (tsk_id_t) tables.populations.num_rows; k++) { if (population_map[k] == TSK_NULL) { tsk_population_table_get_row_unsafe(&tables.populations, k, &pop); @@ -9881,13 +10077,6 @@ tsk_table_collection_subset(tsk_table_collection_t *self, const tsk_id_t *nodes, } } - /* Rewrite the individual parent ids */ - for (k = 0; k < (tsk_id_t) self->individuals.parents_length; k++) { - if (self->individuals.parents[k] != TSK_NULL) { - self->individuals.parents[k] = individual_map[self->individuals.parents[k]]; - } - } - ret = 0; out: tsk_safe_free(node_map); @@ -10115,9 +10304,7 @@ tsk_table_collection_union(tsk_table_collection_t *self, goto out; } - // provenance (new record is added in python) - - // deduplicating, sorting, and computing parents + // sorting, deduplicating, and computing parents ret = tsk_table_collection_sort(self, 0, 0); if (ret < 0) { goto out; diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 89e4831dd2..34d9f555ea 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -637,6 +637,8 @@ typedef struct _tsk_table_sorter_t { int (*sort_edges)(struct _tsk_table_sorter_t *self, tsk_size_t start); /** @brief The mutation sorting function. 
*/ int (*sort_mutations)(struct _tsk_table_sorter_t *self); + /** @brief The individual sorting function. */ + int (*sort_individuals)(struct _tsk_table_sorter_t *self); /** @brief An opaque pointer for use by client code */ void *user_data; /** @brief Mapping from input site IDs to output site IDs */ @@ -2765,14 +2767,18 @@ they appear in the ``nodes`` argument. Specifically, this subsets and reorders each of the tables as follows (but see options, below): 1. Nodes: if in the list of nodes, and in the order provided. -2. Individuals and Populations: if referred to by a retained node, and in the - order first seen when traversing the list of retained nodes. -3. Edges: if both parent and child are retained nodes. -4. Mutations: if the mutation's node is a retained node. -5. Sites: if any mutations remain at the site after removing mutations. - -Retained edges, mutations, and sites appear in the same -order as in the original tables. +2. Individuals: if referred to by a retained node. +3. Populations: if referred to by a retained node, and in the order first seen + when traversing the list of retained nodes. +4. Edges: if both parent and child are retained nodes. +5. Mutations: if the mutation's node is a retained node. +6. Sites: if any mutations remain at the site after removing mutations. + +Retained individuals, edges, mutations, and sites appear in the same +order as in the original tables. Note that only the information *directly* +associated with the provided nodes is retained - for instance, +subsetting to nodes=[A, B] does not retain nodes ancestral to A and B, +and only retains the individuals A and B are in, and not their parents. This function does *not* require the tables to be sorted. @@ -2787,10 +2793,9 @@ TSK_NO_CHANGE_POPULATIONS TSK_KEEP_UNREFERENCED If this flag is provided, then unreferenced sites, individuals, and populations - will not be removed. If so, the site table will not be changed, - unreferenced individuals will be placed last, in their original order, and - (unless TSK_NO_CHANGE_POPULATIONS is also provided), unreferenced - populations will also be placed last, in their original order. + will not be removed. If so, the site and individual tables will not be changed, + and (unless TSK_NO_CHANGE_POPULATIONS is also provided) unreferenced + populations will be placed last, in their original order. .. 
note:: Migrations are currently not supported by susbset, and an error will be raised if we attempt call subset on a table collection with greater diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 860a8f9bcf..dd999fe65d 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -1420,15 +1420,15 @@ def verify_canonical_equality(self, tables, seed): tables.migrations.clear() for ru in [True, False]: - tables1 = tables.copy() + tsk_tables = tables.copy() tsutil.shuffle_tables( - tables1, + tsk_tables, seed, ) - tables2 = tables1.copy() - tables1.canonicalise(remove_unreferenced=ru) - tsutil.py_canonicalise(tables2, remove_unreferenced=ru) - tsutil.assert_table_collections_equal(tables1, tables2) + py_tables = tsk_tables.copy() + tsk_tables.canonicalise(remove_unreferenced=ru) + tsutil.py_canonicalise(py_tables, remove_unreferenced=ru) + tsutil.assert_table_collections_equal(tsk_tables, py_tables) def verify_sort_mutation_consistency(self, orig_tables, seed): tables = orig_tables.copy() @@ -1669,6 +1669,15 @@ def test_no_mutation_parents(self): self.verify_sort_equality(t, 985) self.verify_sort_mutation_consistency(t, 985) + def test_stable_individual_order(self): + # canonical should retain individual order lacking any other information + tables = tskit.TableCollection(sequence_length=100) + for a in "arbol": + tables.individuals.add_row(metadata=a.encode()) + tables2 = tables.copy() + tables2.canonicalise(remove_unreferenced=False) + tsutil.assert_table_collections_equal(tables, tables2) + def test_discrete_times(self): ts = self.get_wf_example(seed=623) ts = tsutil.insert_discrete_time_mutations(ts) @@ -3446,7 +3455,6 @@ def get_wf_example(self, N=5, ngens=2, seed=1249): tables = wf.wf_sim(N, N, num_pops=2, seed=seed) tables.sort() ts = tables.tree_sequence() - # adding muts ts = tsutil.jukes_cantor(ts, 1, 10, seed=seed) ts = tsutil.add_random_metadata(ts, seed) ts = tsutil.insert_random_ploidy_individuals(ts, max_ploidy=2) @@ -3459,22 +3467,22 @@ def get_examples(self, seed): def verify_subset_equality(self, tables, nodes): for rp in [True, False]: for ru in [True, False]: - sub1 = tables.copy() - sub2 = tables.copy() + py_sub = tables.copy() + tsk_sub = tables.copy() tsutil.py_subset( - sub1, + py_sub, nodes, record_provenance=False, reorder_populations=rp, remove_unreferenced=ru, ) - sub2.subset( + tsk_sub.subset( nodes, record_provenance=False, reorder_populations=rp, remove_unreferenced=ru, ) - tsutil.assert_table_collections_equal(sub1, sub2) + tsutil.assert_table_collections_equal(py_sub, tsk_sub) def verify_subset(self, tables, nodes): self.verify_subset_equality(tables, nodes) @@ -3492,6 +3500,7 @@ def verify_subset(self, tables, nodes): indivs.append(ind) if pop not in pops and pop != tskit.NULL: pops.append(pop) + indivs.sort() # keep individuals in the same order ind_map = np.repeat(tskit.NULL, tables.individuals.num_rows + 1) ind_map[indivs] = np.arange(len(indivs), dtype="int32") pop_map = np.repeat(tskit.NULL, tables.populations.num_rows + 1) @@ -3507,7 +3516,13 @@ def verify_subset(self, tables, nodes): assert subset.individuals.num_rows == len(indivs) for k, i in zip(indivs, subset.individuals): ii = tables.individuals[k] - assert ii == i + assert np.all(np.equal(ii.location, i.location)) + assert ii.metadata == i.metadata + sub_parents = [] + for p in ii.parents: + if p == tskit.NULL or ind_map[p] != tskit.NULL: + sub_parents.append(ind_map[p]) + assert np.all(np.equal(sub_parents, i.parents)) assert 
subset.populations.num_rows == len(pops) for k, p in zip(pops, subset.populations): pp = tables.populations[k] @@ -3603,6 +3618,13 @@ def test_shuffled_tables(self): assert tables2.sites.num_rows == 0 assert tables2.mutations.num_rows == 0 + def test_doesnt_reorder_individuals(self): + tables = wf.wf_sim(N=5, ngens=5, num_pops=2, seed=123) + tsutil.shuffle_tables(tables, 7000) + tables2 = tables.copy() + tables2.subset(np.arange(tables.nodes.num_rows)) + assert tables.individuals == tables2.individuals + def test_random_subsets(self): rng = np.random.default_rng(1542) for tables in self.get_examples(9412): @@ -3706,14 +3728,19 @@ def split_example(self, ts, T): tables1.metadata = {"hello": "world"} return tables1, tables2, node_mapping + def verify_union(self, tables, other, node_mapping, add_populations=True): + self.verify_union_consistency(tables, other, node_mapping) + self.verify_union_equality( + tables, other, node_mapping, add_populations=add_populations + ) + def verify_union_equality(self, tables, other, node_mapping, add_populations=True): - # verifying against py impl uni1 = tables.copy() uni2 = tables.copy() uni1.union( other, node_mapping, - record_provenance=True, + record_provenance=False, add_populations=add_populations, ) tsutil.py_union( @@ -3731,13 +3758,147 @@ def verify_union_equality(self, tables, other, node_mapping, add_populations=Tru tables.subset(orig_nodes) tsutil.assert_table_collections_equal(uni1, tables, ignore_provenance=True) + def verify_union_consistency(self, tables, other, node_mapping): + ts1 = tsutil.insert_unique_metadata(tables) + ts2 = tsutil.insert_unique_metadata(other, offset=1000000) + tsu = ts1.union(ts2, node_mapping, check_shared_equality=False) + mapu = tsutil.metadata_map(tsu) + for j, n1 in enumerate(ts1.nodes()): + # nodes in ts1 should be preserved, in the same order + nu = tsu.node(j) + assert n1.metadata == nu.metadata + if n1.individual == tskit.NULL: + assert nu.individual == tskit.NULL + else: + assert ( + ts1.individual(n1.individual).metadata + == tsu.individual(nu.individual).metadata + ) + for j, k in enumerate(node_mapping): + # nodes in ts2 should match if they are not in node mapping + if k == tskit.NULL: + n2 = ts2.node(j) + md2 = n2.metadata + assert md2 in mapu["nodes"] + nu = tsu.node(mapu["nodes"][md2]) + if n2.individual == tskit.NULL: + assert nu.individual == tskit.NULL + else: + assert ( + ts2.individual(n2.individual).metadata + == tsu.individual(nu.individual).metadata + ) + for e1 in ts1.edges(): + # relationships between nodes in ts1 should be preserved + p1, c1 = e1.parent, e1.child + assert e1.metadata in mapu["edges"] + eu = tsu.edge(mapu["edges"][e1.metadata]) + pu, cu = eu.parent, eu.child + assert ts1.node(p1).metadata == tsu.node(pu).metadata + assert ts1.node(c1).metadata == tsu.node(cu).metadata + for e2 in ts2.edges(): + # relationships between nodes in ts2 should be preserved + # if both are new nodes + p2, c2 = e2.parent, e2.child + if node_mapping[p2] == tskit.NULL and node_mapping[c2] == tskit.NULL: + assert e2.metadata in mapu["edges"] + eu = tsu.edge(mapu["edges"][e2.metadata]) + pu, cu = eu.parent, eu.child + assert ts2.node(p2).metadata == tsu.node(pu).metadata + assert ts2.node(c2).metadata == tsu.node(cu).metadata + + for i1 in ts1.individuals(): + # individuals in ts1 should be preserved + assert i1.metadata in mapu["individuals"] + iu = tsu.individual(mapu["individuals"][i1.metadata]) + assert len(i1.parents) == len(iu.parents) + for p1, pu in zip(i1.parents, iu.parents): + if p1 == 
tskit.NULL: + assert pu == tskit.NULL + else: + assert ts1.individual(p1).metadata == tsu.individual(pu).metadata + # how should individual metadata from ts2 map to ts1 + # and only individuals without shared nodes should be added + indivs21 = {} + new_indivs2 = [True for _ in ts2.individuals()] + for j, k in enumerate(node_mapping): + n = ts2.node(j) + if n.individual != tskit.NULL: + i2 = ts2.individual(n.individual) + if k == tskit.NULL: + indivs21[i2.metadata] = i2.metadata + else: + new_indivs2[n.individual] = False + i1 = ts1.individual(ts1.node(k).individual) + if i2.metadata in indivs21: + assert indivs21[i2.metadata] == i1.metadata + else: + indivs21[i2.metadata] = i1.metadata + for i2 in ts2.individuals(): + if new_indivs2[i2.id]: + assert i2.metadata in mapu["individuals"] + iu = tsu.individual(mapu["individuals"][i2.metadata]) + assert np.sum(i2.parents == tskit.NULL) == np.sum( + iu.parents == tskit.NULL + ) + md2 = [ + ts2.individual(p).metadata for p in i2.parents if p != tskit.NULL + ] + md2u = [indivs21[md] for md in md2] + mdu = [ + tsu.individual(p).metadata for p in iu.parents if p != tskit.NULL + ] + assert set(md2u) == set(mdu) + else: + # the individual *should* be there, but by a different name + assert i2.metadata not in mapu["individuals"] + assert indivs21[i2.metadata] in mapu["individuals"] + for m1 in ts1.mutations(): + # all mutations in ts1 should be present + assert m1.metadata in mapu["mutations"] + mu = tsu.mutation(mapu["mutations"][m1.metadata]) + assert m1.derived_state == mu.derived_state + assert m1.node == mu.node + if tskit.is_unknown_time(m1.time): + assert tskit.is_unknown_time(mu.time) + else: + assert m1.time == mu.time + assert ts1.site(m1.site).position == tsu.site(mu.site).position + for m2 in ts2.mutations(): + # and those in ts2 if their node has been added + if node_mapping[m2.node] == tskit.NULL: + assert m2.metadata in mapu["mutations"] + mu = tsu.mutation(mapu["mutations"][m2.metadata]) + assert m2.derived_state == mu.derived_state + assert ts2.node(m2.node).metadata == tsu.node(mu.node).metadata + if tskit.is_unknown_time(m2.time): + assert tskit.is_unknown_time(mu.time) + else: + assert m2.time == mu.time + assert ts2.site(m2.site).position == tsu.site(mu.site).position + for s1 in ts1.sites(): + assert s1.metadata in mapu["sites"] + su = tsu.site(mapu["sites"][s1.metadata]) + assert s1.position == su.position + assert s1.ancestral_state == su.ancestral_state + for s2 in ts2.sites(): + if s2.position not in ts1.tables.sites.position: + assert s2.metadata in mapu["sites"] + su = tsu.site(mapu["sites"][s2.metadata]) + assert s2.position == su.position + assert s2.ancestral_state == su.ancestral_state + # check mutation parents + tables_union = tsu.tables + tables_union.compute_mutation_parents() + assert tables_union.mutations == tsu.tables.mutations + def test_union_empty(self): - ts1 = self.get_msprime_example(sample_size=3, T=2, seed=9328) - ts2 = tskit.TableCollection(sequence_length=ts1.sequence_length).tree_sequence() - uni = ts1.union(ts2, []) - tsutil.assert_table_collections_equal( - ts1.tables, uni.tables, ignore_provenance=True - ) + tables = self.get_msprime_example(sample_size=3, T=2, seed=9328).dump_tables() + tables.sort() + empty_tables = tskit.TableCollection(sequence_length=tables.sequence_length) + uni = tables.copy() + uni.union(empty_tables, []) + tsutil.assert_table_collections_equal(tables, uni, ignore_provenance=True) def test_noshared_example(self): ts1 = self.get_msprime_example(sample_size=3, T=2, seed=9328) 
@@ -3750,17 +3911,18 @@ def test_noshared_example(self): def test_all_shared_example(self): tables = self.get_wf_example(N=5, T=5, seed=11349).dump_tables() + tables.sort() uni = tables.copy() node_mapping = np.arange(tables.nodes.num_rows) uni.union(tables, node_mapping, record_provenance=False) - assert uni == tables + tsutil.assert_table_collections_equal(uni, tables) def test_no_add_pop(self): - self.verify_union_equality( + self.verify_union( *self.split_example(self.get_msprime_example(10, 10, seed=135), 10), add_populations=False, ) - self.verify_union_equality( + self.verify_union( *self.split_example(self.get_wf_example(10, 10, seed=157), 10), add_populations=False, ) @@ -3794,13 +3956,13 @@ def test_examples(self): tables = ts.tables tables.compute_mutation_times() ts = tables.tree_sequence() - self.verify_union_equality(*self.split_example(ts, T)) + self.verify_union(*self.split_example(ts, T)) ts = self.get_wf_example(N=N, T=T, seed=827) if mut_times: tables = ts.tables tables.compute_mutation_times() ts = tables.tree_sequence() - self.verify_union_equality(*self.split_example(ts, T)) + self.verify_union(*self.split_example(ts, T)) class TestSubsetUnion: diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 40a395bfc9..d940fda7cf 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -5790,7 +5790,7 @@ def verify(self, ts): self.verify_keep_input_roots(ts, samples) def verify_keep_input_roots(self, ts, samples): - ts = tsutil.insert_unique_metadata(ts, "individuals") + ts = tsutil.insert_unique_metadata(ts, ["individuals"]) ts_with_roots, node_map = self.do_simplify( ts, samples, keep_input_roots=True, filter_sites=False, compare_lib=True ) diff --git a/python/tests/test_vcf.py b/python/tests/test_vcf.py index 3dfd0a0d95..725cd8cc1b 100644 --- a/python/tests/test_vcf.py +++ b/python/tests/test_vcf.py @@ -199,7 +199,7 @@ class ExamplesMixin: def test_simple_infinite_sites_random_ploidy(self): ts = msprime.simulate(10, mutation_rate=1, random_seed=2) - ts = tsutil.insert_random_ploidy_individuals(ts) + ts = tsutil.insert_random_ploidy_individuals(ts, min_ploidy=1) assert ts.num_sites > 2 self.verify(ts) @@ -226,7 +226,7 @@ def test_simple_infinite_sites_ploidy_2_even_samples(self): def test_simple_jukes_cantor_random_ploidy(self): ts = msprime.simulate(10, random_seed=2) ts = tsutil.jukes_cantor(ts, num_sites=10, mu=1, seed=2) - ts = tsutil.insert_random_ploidy_individuals(ts) + ts = tsutil.insert_random_ploidy_individuals(ts, min_ploidy=1) self.verify(ts) def test_single_tree_multichar_mutations(self): diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 1f6cc3786f..e24852bd79 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -219,10 +219,13 @@ def insert_multichar_mutations(ts, seed=1, max_len=10): return tables.tree_sequence() -def insert_random_ploidy_individuals(ts, max_ploidy=5, max_dimension=3, seed=1): +def insert_random_ploidy_individuals( + ts, min_ploidy=0, max_ploidy=5, max_dimension=3, seed=1 +): """ Takes random contiguous subsets of the samples an assigns them to individuals. - Also creates random locations in variable dimensions in the unit interval. + Also creates random locations in variable dimensions in the unit interval, + and assigns random parents (including NULL parents). 
""" rng = random.Random(seed) samples = np.array(ts.samples(), dtype=int) @@ -231,12 +234,14 @@ def insert_random_ploidy_individuals(ts, max_ploidy=5, max_dimension=3, seed=1): tables.individuals.clear() individual = tables.nodes.individual[:] individual[:] = tskit.NULL + ind_id = -1 while j < len(samples): - ploidy = rng.randint(0, max_ploidy) + ploidy = rng.randint(min_ploidy, max_ploidy) nodes = samples[j : min(j + ploidy, len(samples))] dimension = rng.randint(0, max_dimension) location = [rng.random() for _ in range(dimension)] - ind_id = tables.individuals.add_row(location=location) + parents = rng.sample(range(-1, 1 + ind_id), min(1 + ind_id, rng.randint(0, 3))) + ind_id = tables.individuals.add_row(location=location, parents=parents) individual[nodes] = ind_id j += ploidy tables.nodes.individual = individual @@ -649,14 +654,34 @@ def py_subset( for j, pop in enumerate(full.populations): pop_map[j] = j tables.populations.add_row(metadata=pop.metadata) - for old_id in nodes: - node = full.nodes[old_id] - if node.individual not in ind_map and node.individual != tskit.NULL: - ind = full.individuals[node.individual] + # first build individual map + if not remove_unreferenced: + keep_ind = [True for _ in full.individuals] + else: + keep_ind = [False for _ in full.individuals] + for old_id in nodes: + i = full.nodes[old_id].individual + if i != tskit.NULL: + keep_ind[i] = True + new_ind_id = 0 + for j, k in enumerate(keep_ind): + if k: + ind_map[j] = new_ind_id + new_ind_id += 1 + # now the individual table + for j, k in enumerate(keep_ind): + if k: + ind = full.individuals[j] new_ind_id = tables.individuals.add_row( - ind.flags, ind.location, ind.parents, ind.metadata + ind.flags, + ind.location, + [ind_map[i] for i in ind.parents if i in ind_map], + ind.metadata, ) - ind_map[node.individual] = new_ind_id + assert new_ind_id == ind_map[j] + + for old_id in nodes: + node = full.nodes[old_id] if node.population not in pop_map and node.population != tskit.NULL: pop = full.populations[node.population] new_pop_id = tables.populations.add_row(pop.metadata) @@ -670,14 +695,6 @@ def py_subset( ) node_map[old_id] = new_id if not remove_unreferenced: - for j, ind in enumerate(full.individuals): - if j not in ind_map: - ind_map[j] = tables.individuals.add_row( - ind.flags, - location=ind.location, - parents=ind.parents, - metadata=ind.metadata, - ) for j, ind in enumerate(full.populations): if j not in pop_map: pop_map[j] = tables.populations.add_row(ind.metadata) @@ -763,10 +780,12 @@ def py_union(tables, other, nodes, record_provenance=True, add_populations=True) ) node_map[other_id] = node_id individuals = tables.individuals + new_parents = individuals.parents for i in range( individuals.parents_offset[original_num_individuals], len(individuals.parents) ): - individuals.parents[i] = ind_map[individuals.parents[i]] + new_parents[i] = ind_map[individuals.parents[i]] + individuals.parents = new_parents for edge in other.edges: if (nodes[edge.parent] == tskit.NULL) or (nodes[edge.child] == tskit.NULL): tables.edges.add_row( @@ -1041,6 +1060,22 @@ def cmp_migration(i, j, tables): return ret +def cmp_individual_canonical(i, j, tables, num_descendants): + ret = num_descendants[j] - num_descendants[i] + if ret == 0: + node_i = node_j = tables.nodes.num_rows + ni = np.where(tables.nodes.individual == i)[0] + if len(ni) > 0: + node_i = np.min(ni) + nj = np.where(tables.nodes.individual == j)[0] + if len(nj) > 0: + node_j = np.min(nj) + ret = node_i - node_j + if ret == 0: + ret = i - j + return ret + + 
def compute_mutation_num_descendants(tables): mutations = tables.mutations num_descendants = np.zeros(mutations.num_rows) @@ -1051,16 +1086,61 @@ def compute_mutation_num_descendants(tables): return num_descendants +def compute_individual_num_descendants(tables): + # adapted from sort_individual_table + individuals = tables.individuals + num_individuals = individuals.num_rows + num_descendants = np.zeros((num_individuals,), np.int64) + + # First find the set of individuals that have no children + # by creating an array of incoming edge counts + incoming_edge_count = np.zeros((num_individuals,), np.int64) + for parent in individuals.parents: + if parent != tskit.NULL: + incoming_edge_count[parent] += 1 + todo = np.full((num_individuals + 1,), -1, np.int64) + current_todo = 0 + todo_insertion_point = 0 + for individual, num_edges in enumerate(incoming_edge_count): + if num_edges == 0: + todo[todo_insertion_point] = individual + todo_insertion_point += 1 + + # Now process individuals from the set that have no children, updating their + # parents' information as we go, and adding their parents to the list if + # this was their last child + while todo[current_todo] != -1: + individual = todo[current_todo] + current_todo += 1 + for parent in individuals.parents[ + individuals.parents_offset[individual] : individuals.parents_offset[ + individual + 1 + ] + ]: + if parent != tskit.NULL: + incoming_edge_count[parent] -= 1 + num_descendants[parent] += 1 + num_descendants[individual] + if incoming_edge_count[parent] == 0: + todo[todo_insertion_point] = parent + todo_insertion_point += 1 + + if num_individuals > 0: + assert np.min(incoming_edge_count) >= 0 + if np.max(incoming_edge_count) > 0: + raise ValueError("Individual pedigree has cycles") + return num_descendants + + def py_canonicalise(tables, remove_unreferenced=True): tables.subset( np.arange(tables.nodes.num_rows), record_provenance=False, remove_unreferenced=remove_unreferenced, ) - py_sort(tables, use_num_descendants=True) + py_sort(tables, canonical=True) -def py_sort(tables, use_num_descendants=False): +def py_sort(tables, canonical=False): copy = tables.copy() tables.edges.clear() tables.sites.clear() @@ -1072,15 +1152,15 @@ def py_sort(tables, use_num_descendants=False): sorted_sites = sorted(range(copy.sites.num_rows), key=site_key) site_id_map = {k: j for j, k in enumerate(sorted_sites)} site_order = np.argsort(sorted_sites) - if use_num_descendants: - num_descendants = compute_mutation_num_descendants(copy) + if canonical: + mut_num_descendants = compute_mutation_num_descendants(copy) mut_key = functools.cmp_to_key( lambda a, b: cmp_mutation_canonical( a, b, tables=copy, site_order=site_order, - num_descendants=num_descendants, + num_descendants=mut_num_descendants, ) ) else: @@ -1130,7 +1210,31 @@ def py_sort(tables, use_num_descendants=False): copy.migrations[mig_id].metadata, ) - sort_individual_table(tables) + # individuals + if canonical: + tables.individuals.clear() + ind_num_descendants = compute_individual_num_descendants(copy) + ind_key = functools.cmp_to_key( + lambda a, b: cmp_individual_canonical( + a, + b, + tables=copy, + num_descendants=ind_num_descendants, + ) + ) + sorted_inds = sorted(range(copy.individuals.num_rows), key=ind_key) + ind_id_map = {k: j for j, k in enumerate(sorted_inds)} + ind_id_map[tskit.NULL] = tskit.NULL + for ind_id in sorted_inds: + tables.individuals.add_row( + flags=copy.individuals[ind_id].flags, + location=copy.individuals[ind_id].location, + parents=[ind_id_map[p] for p in 
copy.individuals[ind_id].parents], + metadata=copy.individuals[ind_id].metadata, + ) + tables.nodes.individual = [ind_id_map[i] for i in tables.nodes.individual] + else: + sort_individual_table(tables) def algorithm_T(ts): @@ -1689,15 +1793,15 @@ def assert_tables_equal(t1, t2, label=""): if hasattr(t1, "metadata_schema"): if t1.metadata_schema != t2.metadata_schema: msg = ( - f"{label} :::::::::: t1 ::::::::::::\n{t1.metadata_schema}" - f"{label} :::::::::: t2 ::::::::::::\n{t1.metadata_schema}" + f"\n{label} :::::::::: t1 ::::::::::::\n{t1.metadata_schema}" + f"\n{label} :::::::::: t2 ::::::::::::\n{t2.metadata_schema}" ) raise AssertionError(msg) for k, (e1, e2) in enumerate(zip(t1, t2)): if e1 != e2: msg = ( - f"{label} :::::::::: t1 (row {k}) ::::::::::::\n{e1}" - f"{label} :::::::::: t2 (row {k}) ::::::::::::\n{e2}" + f"\n{label} :::::::::: t1 (row {k}) ::::::::::::\n{e1}" + f"\n{label} :::::::::: t2 (row {k}) ::::::::::::\n{e2}" ) raise AssertionError(msg) if t1.num_rows != t2.num_rows: @@ -1759,9 +1863,44 @@ def sort_individual_table(tables): return tables -def insert_unique_metadata(ts, table): - tables = ts.dump_tables() - getattr(tables, table).packset_metadata( - [struct.pack("I", i) for i in range(getattr(tables, table).num_rows)] - ) +def insert_unique_metadata(tables, table=None, offset=0): + if isinstance(tables, tskit.TreeSequence): + tables = tables.dump_tables() + else: + tables = tables.copy() + if table is None: + table = [ + "populations", + "individuals", + "nodes", + "edges", + "sites", + "mutations", + "migrations", + ] + for t in table: + getattr(tables, t).packset_metadata( + [struct.pack("I", offset + i) for i in range(getattr(tables, t).num_rows)] + ) return tables.tree_sequence() + + +def metadata_map(tables): + # builds a mapping from metadata (as produced by insert_unique_metadata) + # to ID for all the tables (except provenance) + if isinstance(tables, tskit.TreeSequence): + tables = tables.dump_tables() + out = {} + for t in [ + "populations", + "individuals", + "nodes", + "edges", + "sites", + "mutations", + "migrations", + ]: + out[t] = {} + for j, x in enumerate(getattr(tables, t)): + out[t][x.metadata] = j + return out diff --git a/python/tskit/metadata.py b/python/tskit/metadata.py index 46bc297b83..73de443872 100644 --- a/python/tskit/metadata.py +++ b/python/tskit/metadata.py @@ -528,7 +528,7 @@ def decode(self, encoded: bytes) -> Any: def validate_bytes(data: Optional[bytes]) -> None: if data is not None and not isinstance(data, bytes): raise TypeError( - f"If no encoding is set metadata should be bytes, found {type(bytes)}" + f"If no encoding is set metadata should be bytes, found {type(data)}" )
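
The core of the new individual-sorting code is a Kahn-style topological pass over the individual pedigree, factored out as `tsk_individual_table_topological_sort` in C and mirrored by `tsutil.compute_individual_num_descendants` in Python. Below is a minimal standalone sketch of that pass, assuming plain Python lists of parent IDs (with -1 for a missing parent) instead of tskit tables; the function name and signature are illustrative, not part of the library API.

```python
# Kahn-style topological sort of an individual pedigree: individuals with no
# children are emitted first, and each parent is emitted only after all of its
# children have been processed. Optionally accumulates descendant counts in
# the same pass, as the canonical sorter in the patch does.
def topological_sort(parents, count_descendants=False):
    n = len(parents)
    incoming = [0] * n
    for plist in parents:
        for p in plist:
            if p != -1:
                incoming[p] += 1
    # Start from individuals with no children; iterating in reverse keeps the
    # sort stable with respect to the input order, as in the C code.
    order = [i for i in range(n - 1, -1, -1) if incoming[i] == 0]
    num_descendants = [0] * n
    k = 0
    while k < len(order):
        j = order[k]
        k += 1
        for p in parents[j]:
            if p != -1:
                incoming[p] -= 1
                if count_descendants:
                    num_descendants[p] += 1 + num_descendants[j]
                if incoming[p] == 0:
                    order.append(p)
    if len(order) != n:
        # Any unprocessed individual means the pedigree contains a cycle
        # (a self-parent is the length-one case).
        raise ValueError("individual pedigree has cycles")
    return (order, num_descendants) if count_descendants else order
```

In the patch, the default sorter consumes the reversed traversal order directly (so parents are emitted before their children), while the canonical sorter only needs the descendant counts; the new C tests exercise the `TSK_ERR_INDIVIDUAL_PARENT_CYCLE` and `TSK_ERR_INDIVIDUAL_SELF_PARENT` paths, which the sketch folds into a single cycle check.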
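
The canonical ordering itself is defined by `cmp_individual_canonical` (and the matching `tsutil.cmp_individual_canonical`): most descendants first, then the first node referring to the individual, then the original ID. A sketch of the same ordering expressed as a single Python sort key, reusing the `topological_sort` helper above; again, the names here are illustrative.

```python
# Canonical order for individuals: descending descendant count, then the
# smallest node ID that references the individual, then the original ID.
# `node_individual` plays the role of the nodes table's `individual` column.
def canonical_individual_order(parents, node_individual):
    n = len(parents)
    _, num_descendants = topological_sort(parents, count_descendants=True)
    # Sentinel larger than any real node ID for unreferenced individuals.
    first_node = [len(node_individual)] * n
    for node_id, ind in enumerate(node_individual):
        if ind != -1:
            first_node[ind] = min(first_node[ind], node_id)
    return sorted(range(n), key=lambda i: (-num_descendants[i], first_node[i], i))
```

Because a parent always has strictly more descendants than any one of its children, this places parents before their children, and the remaining tie-breakers make the result deterministic; when no nodes or parents are involved at all, the order collapses to the original IDs, which is what the new `test_stable_individual_order` test checks.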
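
Finally, `tsk_table_collection_subset` now processes individuals before nodes so that retained individuals keep their original relative order, remapping parent IDs as it goes and dropping any parent that is not retained while preserving explicit NULL parents. A sketch of that remapping on plain parent lists, assuming a per-individual boolean `keep`; all names are illustrative.

```python
# Order-preserving individual subset: build the old-to-new ID map first, then
# rewrite each retained individual's parent list against that map.
def remap_individuals(parents, keep):
    ind_map = {}
    for old_id, kept in enumerate(keep):
        if kept:
            ind_map[old_id] = len(ind_map)
    new_parents = []
    for old_id, kept in enumerate(keep):
        if not kept:
            continue
        remapped = []
        for p in parents[old_id]:
            if p == -1:
                remapped.append(-1)          # explicit NULL parents are kept
            elif p in ind_map:
                remapped.append(ind_map[p])  # retained parents are renumbered
            # parents that are not retained are dropped entirely
        new_parents.append(remapped)
    return ind_map, new_parents
```

Doing the mapping up front is why the old post-hoc loop that rewrote `individuals.parents`, and the `TSK_KEEP_UNREFERENCED` branch that appended unused individuals at the end, are removed in this patch: both cases are handled by the order-preserving pass, as the updated documentation in `tables.h` describes.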