diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 3d09d0b1af..f2b6c6d12f 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,6 +4,9 @@ **Changes** +- Tables in a table collection can be replaced using the replace_with method + (:user:`hyanwong`, :issue:`1489` :pr:`2389`) + - SVG drawing routines now return a special string object that is automatically rendered in a Jupyter notebook (:user:`hyanwong`, :pr:`2377`) diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 025d280631..1f573b5b94 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -823,6 +823,11 @@ def test_bad_offsets(self): t.append_columns(**input_data) input_data[offset_col.name] = np.copy(original_offset) + def test_replace_with_wrong_class(self): + t = self.table_class() + with pytest.raises(TypeError, match="wrong type"): + t.replace_with(tskit.BaseTable(None, None)) + class MetadataTestsMixin: """ @@ -3834,6 +3839,12 @@ def test_dump_load_errors(self, ts_fixture): with pytest.raises(TypeError): func(bad_filename) + def test_set_table(self): + tc = tskit.TableCollection() + for name, table in tc.table_name_map.items(): + with pytest.raises(AttributeError, match="replace_with"): + setattr(tc, name, table) + class TestEqualityOptions: def test_equals_provenance(self): @@ -4383,6 +4394,20 @@ def test_set_columns_not_implemented(self): with pytest.raises(NotImplementedError): t.set_columns() + def test_replace_with(self, ts_fixture): + # Although replace_with is a BaseTable method, it is simpler to test it + # on the subclasses directly, as some differ e.g. in having metadata schemas + original_tables = ts_fixture.dump_tables() + original_tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json() + new_tables = ts_fixture.dump_tables() + new_tables.clear(clear_provenance=True, clear_metadata_schemas=True) + + # write all the data back in again + for name, table in new_tables.table_name_map.items(): + new_table = getattr(original_tables, name) + table.replace_with(new_table) + new_tables.assert_equals(original_tables) + class TestSubsetTables: """ diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 08bfba0e22..74de73cc89 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -589,6 +589,23 @@ def append(self, row): } ) + def replace_with(self, other): + # Overwrite the contents of this table with a copy of the other table + params = {} + for column in self.column_names: + try: + params[column] = getattr(other, column) + except AttributeError: + raise TypeError( + "Replacement table has wrong type: it lacks a {column} column" + ) + try: + # Not all tables have a metadata_schema: if they do, encode it with repr + params["metadata_schema"] = repr(other.metadata_schema) + except AttributeError: + pass + self.set_columns(**params) + def clear(self): """ Deletes all rows in this table. @@ -2869,6 +2886,10 @@ class TableCollection(metadata.MetadataProvider): method. """ + set_err_text = ( + "Cannot set tables in a table collection: use table.replace_with() instead." + ) + def __init__(self, sequence_length=0): self._ll_tables = _tskit.TableCollection(sequence_length) super().__init__(self._ll_tables) @@ -2880,6 +2901,10 @@ def individuals(self) -> IndividualTable: """ return IndividualTable(ll_table=self._ll_tables.individuals) + @individuals.setter + def individuals(self, value): + raise AttributeError(self.set_err_text) + @property def nodes(self) -> NodeTable: """ @@ -2887,6 +2912,10 @@ def nodes(self) -> NodeTable: """ return NodeTable(ll_table=self._ll_tables.nodes) + @nodes.setter + def nodes(self, value): + raise AttributeError(self.set_err_text) + @property def edges(self) -> EdgeTable: """ @@ -2894,6 +2923,10 @@ def edges(self) -> EdgeTable: """ return EdgeTable(ll_table=self._ll_tables.edges) + @edges.setter + def edges(self, value): + raise AttributeError(self.set_err_text) + @property def migrations(self) -> MigrationTable: """ @@ -2901,6 +2934,10 @@ def migrations(self) -> MigrationTable: """ return MigrationTable(ll_table=self._ll_tables.migrations) + @migrations.setter + def migrations(self, value): + raise AttributeError(self.set_err_text) + @property def sites(self) -> SiteTable: """ @@ -2908,6 +2945,10 @@ def sites(self) -> SiteTable: """ return SiteTable(ll_table=self._ll_tables.sites) + @sites.setter + def sites(self, value): + raise AttributeError(self.set_err_text) + @property def mutations(self) -> MutationTable: """ @@ -2915,6 +2956,10 @@ def mutations(self) -> MutationTable: """ return MutationTable(ll_table=self._ll_tables.mutations) + @mutations.setter + def mutations(self, value): + raise AttributeError(self.set_err_text) + @property def populations(self) -> PopulationTable: """ @@ -2922,6 +2967,10 @@ def populations(self) -> PopulationTable: """ return PopulationTable(ll_table=self._ll_tables.populations) + @populations.setter + def populations(self, value): + raise AttributeError(self.set_err_text) + @property def provenances(self) -> ProvenanceTable: """ @@ -2929,6 +2978,10 @@ def provenances(self) -> ProvenanceTable: """ return ProvenanceTable(ll_table=self._ll_tables.provenances) + @provenances.setter + def provenances(self, value): + raise AttributeError(self.set_err_text) + @property def indexes(self) -> TableCollectionIndexes: """