Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

**Changes**

- Tables in a table collection can be replaced using the replace_with method
(:user:`hyanwong`, :issue:`1489` :pr:`2389`)

- SVG drawing routines now return a special string object that is automatically
rendered in a Jupyter notebook (:user:`hyanwong`, :pr:`2377`)

Expand Down
25 changes: 25 additions & 0 deletions python/tests/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,11 @@ def test_bad_offsets(self):
t.append_columns(**input_data)
input_data[offset_col.name] = np.copy(original_offset)

def test_replace_with_wrong_class(self):
t = self.table_class()
with pytest.raises(TypeError, match="wrong type"):
t.replace_with(tskit.BaseTable(None, None))


class MetadataTestsMixin:
"""
Expand Down Expand Up @@ -3834,6 +3839,12 @@ def test_dump_load_errors(self, ts_fixture):
with pytest.raises(TypeError):
func(bad_filename)

def test_set_table(self):
tc = tskit.TableCollection()
for name, table in tc.table_name_map.items():
with pytest.raises(AttributeError, match="replace_with"):
setattr(tc, name, table)


class TestEqualityOptions:
def test_equals_provenance(self):
Expand Down Expand Up @@ -4383,6 +4394,20 @@ def test_set_columns_not_implemented(self):
with pytest.raises(NotImplementedError):
t.set_columns()

def test_replace_with(self, ts_fixture):
# Although replace_with is a BaseTable method, it is simpler to test it
# on the subclasses directly, as some differ e.g. in having metadata schemas
original_tables = ts_fixture.dump_tables()
original_tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json()
new_tables = ts_fixture.dump_tables()
new_tables.clear(clear_provenance=True, clear_metadata_schemas=True)

# write all the data back in again
for name, table in new_tables.table_name_map.items():
new_table = getattr(original_tables, name)
table.replace_with(new_table)
new_tables.assert_equals(original_tables)


class TestSubsetTables:
"""
Expand Down
53 changes: 53 additions & 0 deletions python/tskit/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,23 @@ def append(self, row):
}
)

def replace_with(self, other):
# Overwrite the contents of this table with a copy of the other table
params = {}
for column in self.column_names:
try:
params[column] = getattr(other, column)
except AttributeError:
raise TypeError(
"Replacement table has wrong type: it lacks a {column} column"
)
try:
# Not all tables have a metadata_schema: if they do, encode it with repr
params["metadata_schema"] = repr(other.metadata_schema)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use repr here and not a direct API? Seem indirect and brittle (we might change repr to something slightly different, perhaps)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

repr is the standard way to get the bytes representation of a schema - it's used by (for example) metdata.py:812 and is well tested.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fine internally maybe, but it's not a great convention. A well named method would be much clearer to the reader

except AttributeError:
pass
self.set_columns(**params)

def clear(self):
"""
Deletes all rows in this table.
Expand Down Expand Up @@ -2869,6 +2886,10 @@ class TableCollection(metadata.MetadataProvider):
method.
"""

set_err_text = (
"Cannot set tables in a table collection: use table.replace_with() instead."
)

def __init__(self, sequence_length=0):
self._ll_tables = _tskit.TableCollection(sequence_length)
super().__init__(self._ll_tables)
Expand All @@ -2880,55 +2901,87 @@ def individuals(self) -> IndividualTable:
"""
return IndividualTable(ll_table=self._ll_tables.individuals)

@individuals.setter
def individuals(self, value):
raise AttributeError(self.set_err_text)

@property
def nodes(self) -> NodeTable:
"""
The :ref:`sec_node_table_definition` in this collection.
"""
return NodeTable(ll_table=self._ll_tables.nodes)

@nodes.setter
def nodes(self, value):
raise AttributeError(self.set_err_text)

@property
def edges(self) -> EdgeTable:
"""
The :ref:`sec_edge_table_definition` in this collection.
"""
return EdgeTable(ll_table=self._ll_tables.edges)

@edges.setter
def edges(self, value):
raise AttributeError(self.set_err_text)

@property
def migrations(self) -> MigrationTable:
"""
The :ref:`sec_migration_table_definition` in this collection
"""
return MigrationTable(ll_table=self._ll_tables.migrations)

@migrations.setter
def migrations(self, value):
raise AttributeError(self.set_err_text)

@property
def sites(self) -> SiteTable:
"""
The :ref:`sec_site_table_definition` in this collection.
"""
return SiteTable(ll_table=self._ll_tables.sites)

@sites.setter
def sites(self, value):
raise AttributeError(self.set_err_text)

@property
def mutations(self) -> MutationTable:
"""
The :ref:`sec_mutation_table_definition` in this collection.
"""
return MutationTable(ll_table=self._ll_tables.mutations)

@mutations.setter
def mutations(self, value):
raise AttributeError(self.set_err_text)

@property
def populations(self) -> PopulationTable:
"""
The :ref:`sec_population_table_definition` in this collection.
"""
return PopulationTable(ll_table=self._ll_tables.populations)

@populations.setter
def populations(self, value):
raise AttributeError(self.set_err_text)

@property
def provenances(self) -> ProvenanceTable:
"""
The :ref:`sec_provenance_table_definition` in this collection.
"""
return ProvenanceTable(ll_table=self._ll_tables.provenances)

@provenances.setter
def provenances(self, value):
raise AttributeError(self.set_err_text)

@property
def indexes(self) -> TableCollectionIndexes:
"""
Expand Down