From 3b1098bc1f56ed593e4c4c1d8d3e34ff765b6081 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Wed, 10 Mar 2021 15:00:09 +0000 Subject: [PATCH 1/2] Remap individuals in union --- c/tests/test_tables.c | 9 ++++++--- c/tskit/tables.c | 22 ++++++++++++++++++++++ python/tests/tsutil.py | 11 +++++++++++ 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 6d1f9abe58..2988a4693a 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -5629,6 +5629,7 @@ test_table_collection_union(void) tsk_table_collection_t tables_empty; tsk_table_collection_t tables_copy; tsk_id_t node_mapping[3]; + tsk_id_t parents[2] = { -1, -1 }; char example_metadata[100] = "An example of metadata with unicode 🎄🌳🌴🌲🎋"; tsk_size_t example_metadata_length = (tsk_size_t) strlen(example_metadata); @@ -5662,13 +5663,15 @@ test_table_collection_union(void) ret = tsk_node_table_add_row(&tables.nodes, TSK_NODE_IS_SAMPLE, 0.5, 1, 2, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_individual_table_add_row( - &tables.individuals, 0, NULL, 0, NULL, 0, NULL, 0); + &tables.individuals, 0, NULL, 0, parents, 2, NULL, 0); CU_ASSERT_FATAL(ret >= 0); + parents[0] = 0; ret = tsk_individual_table_add_row( - &tables.individuals, 0, NULL, 0, NULL, 0, NULL, 0); + &tables.individuals, 0, NULL, 0, parents, 2, NULL, 0); CU_ASSERT_FATAL(ret >= 0); + parents[1] = 1; ret = tsk_individual_table_add_row( - &tables.individuals, 0, NULL, 0, NULL, 0, NULL, 0); + &tables.individuals, 0, NULL, 0, parents, 2, NULL, 0); CU_ASSERT_FATAL(ret >= 0); ret = tsk_population_table_add_row(&tables.populations, NULL, 0); CU_ASSERT_FATAL(ret >= 0); diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 81666ad515..cd42f6f41e 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -9975,6 +9975,7 @@ tsk_table_collection_union(tsk_table_collection_t *self, int ret = 0; tsk_id_t k, i, new_parent, new_child; tsk_size_t num_shared_nodes = 0; + tsk_size_t num_individuals_self = self->individuals.num_rows; tsk_edge_t edge; tsk_mutation_t mut; tsk_site_t site; @@ -10027,6 +10028,18 @@ tsk_table_collection_union(tsk_table_collection_t *self, memset(population_map, 0xff, other->populations.num_rows * sizeof(*population_map)); memset(site_map, 0xff, other->sites.num_rows * sizeof(*site_map)); + /* We have to map the individuals who are linked to nodes in the intersection first + as otherwise an individual linked to one node in the intersection and one in + `other` would be duplicated. We assume that the individual in `self` takes + priority. + */ + for (k = 0; k < (tsk_id_t) other->nodes.num_rows; k++) { + if (other_node_mapping[k] != TSK_NULL + && other->nodes.individual[k] != TSK_NULL) { + individual_map[other->nodes.individual[k]] + = self->nodes.individual[other_node_mapping[k]]; + } + } // nodes, individuals, populations for (k = 0; k < (tsk_id_t) other->nodes.num_rows; k++) { if (other_node_mapping[k] != TSK_NULL) { @@ -10040,6 +10053,15 @@ tsk_table_collection_union(tsk_table_collection_t *self, } } + /* Now we know the full individual map we can remap the parents of the new + * individuals*/ + for (k = (tsk_id_t) self->individuals.parents_offset[num_individuals_self]; + k < (tsk_id_t) self->individuals.parents_length; k++) { + if (self->individuals.parents[k] != TSK_NULL) { + self->individuals.parents[k] = individual_map[self->individuals.parents[k]]; + } + } + // edges for (k = 0; k < (tsk_id_t) other->edges.num_rows; k++) { tsk_edge_table_get_row_unsafe(&other->edges, k, &edge); diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index f78463b1b6..5c4d99b5c5 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -727,6 +727,12 @@ def py_union(tables, other, nodes, record_provenance=True, add_populations=True) node_map = [tskit.NULL for _ in range(other.nodes.num_rows + 1)] site_map = [tskit.NULL for _ in range(other.sites.num_rows + 1)] mut_map = [tskit.NULL for _ in range(other.mutations.num_rows + 1)] + original_num_individuals = tables.individuals.num_rows + + for other_id, node in enumerate(other.nodes): + if nodes[other_id] != tskit.NULL and node.individual != tskit.NULL: + ind_map[node.individual] = tables.nodes[nodes[other_id]].individual + for other_id, node in enumerate(other.nodes): if nodes[other_id] != tskit.NULL: node_map[other_id] = nodes[other_id] @@ -755,6 +761,11 @@ def py_union(tables, other, nodes, record_provenance=True, add_populations=True) flags=node.flags, ) node_map[other_id] = node_id + individuals = tables.individuals + for i in range( + individuals.parents_offset[original_num_individuals], len(individuals.parents) + ): + individuals.parents[i] = ind_map[individuals.parents[i]] for edge in other.edges: if (nodes[edge.parent] == tskit.NULL) or (nodes[edge.child] == tskit.NULL): tables.edges.add_row( From 701112a32bc64cbbb87f59f98a3a71abec53f051 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 12 Mar 2021 13:58:20 +0000 Subject: [PATCH 2/2] Default to always record individuals in WF sim --- python/tests/data/svg/ts_multiroot.svg | 144 ++++++++++++------------- python/tests/test_tables.py | 4 - python/tests/test_topology.py | 9 +- python/tests/test_wright_fisher.py | 23 ++-- python/tests/tsutil.py | 9 ++ 5 files changed, 97 insertions(+), 92 deletions(-) diff --git a/python/tests/data/svg/ts_multiroot.svg b/python/tests/data/svg/ts_multiroot.svg index 37b0aa01c4..1140fa2b57 100644 --- a/python/tests/data/svg/ts_multiroot.svg +++ b/python/tests/data/svg/ts_multiroot.svg @@ -120,23 +120,23 @@ - + 0 - - - + + + 2 - + 4 - + 5 @@ -150,13 +150,13 @@ 6 - - + + 1 - + 3 @@ -172,18 +172,18 @@ - - + + 2 - + 4 - + 5 @@ -191,19 +191,19 @@ 6 - - + + 0 - - + + 1 - + 3 @@ -219,23 +219,23 @@ - + 0 - - - + + + 2 - + 4 - + @@ -254,13 +254,13 @@ 6 - - + + 1 - + 3 @@ -276,13 +276,13 @@ - + 0 - - - + + + @@ -292,7 +292,7 @@ 2 - + @@ -302,7 +302,7 @@ 4 - + 5 @@ -311,13 +311,13 @@ 6 - - + + 1 - + 3 @@ -333,13 +333,13 @@ - - + + 1 - + 3 @@ -347,25 +347,25 @@ 7 - - + + 0 - - + + 5 - - + + 2 - + 4 @@ -385,8 +385,8 @@ - - + + @@ -396,7 +396,7 @@ 2 - + 4 @@ -404,13 +404,13 @@ 6 - - + + 1 - + 3 @@ -418,13 +418,13 @@ 7 - - + + 0 - + 5 @@ -436,18 +436,18 @@ - - + + 2 - + 4 - + 3 @@ -455,17 +455,17 @@ 6 - + 1 - - + + 0 - + 5 @@ -477,18 +477,18 @@ - - + + 2 - + 4 - + 3 @@ -496,19 +496,19 @@ 6 - - + + 1 - - + + 0 - + 5 diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 1ada3ac795..860a8f9bcf 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -1532,7 +1532,6 @@ def get_wf_example(self, seed): seed=seed, num_loci=3, record_migrations=True, - record_individuals=True, ) tables.sort() ts = tables.tree_sequence() @@ -1547,7 +1546,6 @@ def test_wf_example(self): deep_history=False, seed=42, record_migrations=True, - record_individuals=True, ) self.verify_sort(tables, 42) @@ -2405,7 +2403,6 @@ def wf_sim_with_individual_metadata(self): deep_history=False, initial_generation_samples=False, num_loci=5, - record_individuals=True, ) assert tables.individuals.num_rows > 50 assert np.all(tables.nodes.individual >= 0) @@ -2502,7 +2499,6 @@ def test_individual_mapping(self): deep_history=True, initial_generation_samples=False, num_loci=5, - record_individuals=True, ) assert tables.individuals.num_rows > 50 node_md = [] diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 1d036b5735..40a395bfc9 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -5790,6 +5790,7 @@ def verify(self, ts): self.verify_keep_input_roots(ts, samples) def verify_keep_input_roots(self, ts, samples): + ts = tsutil.insert_unique_metadata(ts, "individuals") ts_with_roots, node_map = self.do_simplify( ts, samples, keep_input_roots=True, filter_sites=False, compare_lib=True ) @@ -5807,7 +5808,13 @@ def verify_keep_input_roots(self, ts, samples): new_node = ts_with_roots.node(root) assert new_node.time == input_node.time assert new_node.population == input_node.population - assert new_node.individual == input_node.individual + if new_node.individual == tskit.NULL: + assert new_node.individual == input_node.individual + else: + assert ( + ts_with_roots.individual(new_node.individual).metadata + == ts.individual(input_node.individual).metadata + ) assert new_node.metadata == input_node.metadata # This should only be marked as a sample if it's an # element of the samples list. diff --git a/python/tests/test_wright_fisher.py b/python/tests/test_wright_fisher.py index 72686401d3..bcaf5ab68b 100644 --- a/python/tests/test_wright_fisher.py +++ b/python/tests/test_wright_fisher.py @@ -62,7 +62,7 @@ def __init__( num_pops=1, mig_rate=0.0, record_migrations=False, - record_individuals=False, + record_individuals=True, ): self.N = N self.num_pops = num_pops @@ -214,7 +214,7 @@ def wf_sim( num_pops=1, mig_rate=0.0, record_migrations=False, - record_individuals=False, + record_individuals=True, ): sim = WrightFisherSimulator( N, @@ -248,7 +248,6 @@ def test_one_gen_multipop_mig_no_deep(self): deep_history=False, seed=self.random_seed, record_migrations=True, - record_individuals=True, ) assert tables.nodes.num_rows == 5 * 4 * (1 + 1) assert tables.edges.num_rows > 0 @@ -266,7 +265,6 @@ def test_multipop_mig_deep(self): mig_rate=1.0, seed=self.random_seed, record_migrations=True, - record_individuals=True, ) assert tables.nodes.num_rows > (num_pops * N * ngens) + N assert tables.edges.num_rows > 0 @@ -297,7 +295,6 @@ def test_multipop_mig_no_deep(self): deep_history=False, seed=self.random_seed, record_migrations=True, - record_individuals=True, ) assert tables.nodes.num_rows == num_pops * N * (ngens + 1) assert tables.edges.num_rows > 0 @@ -323,7 +320,7 @@ def test_non_overlapping_generations(self): assert tables.sites.num_rows == 0 assert tables.mutations.num_rows == 0 assert tables.migrations.num_rows == 0 - assert tables.individuals.num_rows == 0 + assert tables.individuals.num_rows > 0 tables.sort() tables.simplify() ts = tables.tree_sequence() @@ -344,7 +341,7 @@ def test_overlapping_generations(self): assert tables.sites.num_rows == 0 assert tables.mutations.num_rows == 0 assert tables.migrations.num_rows == 0 - assert tables.individuals.num_rows == 0 + assert tables.individuals.num_rows > 0 tables.sort() tables.simplify() ts = tables.tree_sequence() @@ -359,7 +356,7 @@ def test_one_generation_no_deep_history(self): assert tables.sites.num_rows == 0 assert tables.mutations.num_rows == 0 assert tables.migrations.num_rows == 0 - assert tables.individuals.num_rows == 0 + assert tables.individuals.num_rows > 0 tables.sort() tables.simplify() ts = tables.tree_sequence() @@ -383,7 +380,7 @@ def test_many_generations_no_deep_history(self): assert tables.sites.num_rows == 0 assert tables.mutations.num_rows == 0 assert tables.migrations.num_rows == 0 - assert tables.individuals.num_rows == 0 + assert tables.individuals.num_rows > 0 tables.sort() tables.simplify() ts = tables.tree_sequence() @@ -448,9 +445,7 @@ def test_with_recurrent_mutations(self): def test_record_individuals_initial_state(self): N = 10 - tables = wf_sim( - N=N, ngens=0, seed=12345, record_individuals=True, deep_history=False - ) + tables = wf_sim(N=N, ngens=0, seed=12345, deep_history=False) tables.sort() assert len(tables.individuals) == N assert len(tables.nodes) == N @@ -461,9 +456,7 @@ def test_record_individuals_initial_state(self): def test_record_individuals(self): N = 10 - tables = wf_sim( - N=N, ngens=10, seed=12345, record_individuals=True, deep_history=False - ) + tables = wf_sim(N=N, ngens=10, seed=12345, deep_history=False) assert len(tables.individuals) == len(tables.nodes) for node_id, individual in enumerate(tables.nodes.individual): assert node_id == individual diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 5c4d99b5c5..1f6cc3786f 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -28,6 +28,7 @@ import json import random import string +import struct import numpy as np @@ -1756,3 +1757,11 @@ def sort_individual_table(tables): tables.nodes.individual = [ind_id_map[i] for i in tables.nodes.individual] return tables + + +def insert_unique_metadata(ts, table): + tables = ts.dump_tables() + getattr(tables, table).packset_metadata( + [struct.pack("I", i) for i in range(getattr(tables, table).num_rows)] + ) + return tables.tree_sequence()