diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst
index 6b160fbada..f856915ce1 100644
--- a/python/CHANGELOG.rst
+++ b/python/CHANGELOG.rst
@@ -6,8 +6,11 @@ In development
**New features**
+- Add ``_repr_html_`` to tables, so that Jupyter notebooks render them as
+ HTML tables (:user:`benjeffery`, :pr:`514`)
+
- Remove support for ``kc_distance`` on trees with unary nodes
- (:user:`daniel-goldstein`, :`508`)
+ (:user:`daniel-goldstein`, :pr:`508`)
- Improve Kendall-Colijn tree distance algorithm to operate in O(n^2) time
instead of O(n^2 * log(n)) where n is the number of samples
@@ -19,10 +22,10 @@ In development
- Add a metadata column to the edges table. Works similarly to existing
metadata columns on other tables(:user:`benjeffery`, :pr:`496`).
-- Allow sites with missing data to be output by the `haplotypes` method, by
+- Allow sites with missing data to be output by the ``haplotypes`` method, by
default replacing with ``-``. Errors are no longer raised for missing data
- with `impute_missing_data=False`; the error types returned for bad alleles
- (e.g. multiletter or non-ascii) have also changed from `_tskit.LibraryError`
+ with ``impute_missing_data=False``; the error types returned for bad alleles
+ (e.g. multiletter or non-ascii) have also changed from ``_tskit.LibraryError``
to TypeError, or ValueError if the missing data character clashes
(:user:`hyanwong`, :pr:`426`).
@@ -36,7 +39,7 @@ In development
us to efficiently iterate over 'real' roots when we have
missing data (:user:`jeromekelleher`, :pr:`462`).
-- Add pickle support for `TreeSequence` (:user:`terhorst`, :pr:`473`).
+- Add pickle support for ``TreeSequence`` (:user:`terhorst`, :pr:`473`).
**Bugfixes**
@@ -75,7 +78,7 @@ method to manipulate tree sequence data.
- Fix height scaling issues with SVG tree drawing (:user:`jeromekelleher`,
:pr:`407`, :issue:`383`, :pr:`378`).
-- Do not reuse buffers in LdCalculator (:user:`jeromekelleher`). See :pr:`397` and
+- Do not reuse buffers in ``LdCalculator`` (:user:`jeromekelleher`). See :pr:`397` and
:issue:`396`.
--------------------
diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py
index 4a717a4f7f..eed28f4c23 100644
--- a/python/tests/test_tables.py
+++ b/python/tests/test_tables.py
@@ -503,6 +503,18 @@ def test_str(self):
s = str(table)
self.assertEqual(len(s.splitlines()), num_rows + 1)
+ def test_repr_html(self):
+ for num_rows in [0, 10]:
+ input_data = {col.name: col.get_input(num_rows) for col in self.columns}
+ for list_col, offset_col in self.ragged_list_columns:
+ value = list_col.get_input(num_rows)
+ input_data[list_col.name] = value
+ input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32)
+ table = self.table_class()
+ table.set_columns(**input_data)
+ html = table._repr_html_()
+ self.assertEqual(len(html.splitlines()), num_rows + 19)
+
def test_copy(self):
for num_rows in [0, 10]:
input_data = {col.name: col.get_input(num_rows) for col in self.columns}
@@ -1767,9 +1779,8 @@ def test_table_references(self):
mutations = tables.mutations
before_populations = str(tables.populations)
populations = tables.populations
- before_nodes = str(tables.nodes)
- provenances = tables.provenances
before_provenances = str(tables.provenances)
+ provenances = tables.provenances
del tables
self.assertEqual(str(individuals), before_individuals)
self.assertEqual(str(nodes), before_nodes)
diff --git a/python/tskit/tables.py b/python/tskit/tables.py
index ac8232dca4..ba304fa049 100644
--- a/python/tskit/tables.py
+++ b/python/tskit/tables.py
@@ -206,6 +206,38 @@ def set_columns(self, **kwargs):
"""
raise NotImplementedError()
+ def __str__(self):
+ headers, rows = self._text_header_and_rows()
+ return "\n".join("\t".join(row) for row in [headers] + rows)
+
+ def _repr_html_(self):
+ """
+ Called by Jupyter notebooks to render tables as HTML.
+ """
+ headers, rows = self._text_header_and_rows()
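+ headers = "".join(f"<th>{header}</th>" for header in headers)
+ rows = ("".join(f"<td>{cell}</td>" for cell in row) for row in rows)
+ rows = "".join(f"<tr>{row}</tr>\n" for row in rows)
+ return f"""
+ <div>
+ <style scoped="">
+ .tskit-table tbody tr th:only-of-type {{vertical-align: middle;}}
+ .tskit-table tbody tr th {{vertical-align: top;}}
+ .tskit-table tbody td {{text-align: right;}}
+ </style>
+ <table border="1" class="tskit-table">
+ <thead>
+ <tr>
+ {headers}
+ </tr>
+ </thead>
+ <tbody>
+ {rows}
+ </tbody>
+ </table>
+ </div>
+ """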
+
class MetadataMixin:
"""
@@ -270,16 +302,19 @@ def __init__(self, max_rows_increment=0, ll_table=None):
ll_table = _tskit.IndividualTable(max_rows_increment=max_rows_increment)
super().__init__(ll_table, IndividualTableRow)
- def __str__(self):
+ def _text_header_and_rows(self):
flags = self.flags
location = util.unpack_arrays(self.location, self.location_offset)
metadata = util.unpack_bytes(self.metadata, self.metadata_offset)
- ret = "id\tflags\tlocation\tmetadata\n"
+ headers = ("id", "flags", "location", "metadata")
+ rows = []
for j in range(self.num_rows):
md = base64.b64encode(metadata[j]).decode("utf8")
location_str = ",".join(map(str, location))
- ret += "{}\t{}\t{}\t{}\n".format(j, flags[j], location_str, md)
- return ret[:-1]
+ rows.append(
+ "{}\t{}\t{}\t{}".format(j, flags[j], location_str, md).split("\t")
+ )
+ return headers, rows
def add_row(self, flags=0, location=None, metadata=None):
"""
@@ -450,19 +485,22 @@ def __init__(self, max_rows_increment=0, ll_table=None):
ll_table = _tskit.NodeTable(max_rows_increment=max_rows_increment)
super().__init__(ll_table, NodeTableRow)
- def __str__(self):
+ def _text_header_and_rows(self):
time = self.time
flags = self.flags
population = self.population
individual = self.individual
metadata = util.unpack_bytes(self.metadata, self.metadata_offset)
- ret = "id\tflags\tpopulation\tindividual\ttime\tmetadata\n"
+ headers = ("id", "flags", "population", "individual", "time", "metadata")
+ rows = []
for j in range(self.num_rows):
md = base64.b64encode(metadata[j]).decode("utf8")
- ret += "{}\t{}\t{}\t{}\t{:.14f}\t{}\n".format(
- j, flags[j], population[j], individual[j], time[j], md
+ rows.append(
+ "{}\t{}\t{}\t{}\t{:.14f}\t{}".format(
+ j, flags[j], population[j], individual[j], time[j], md
+ ).split("\t")
)
- return ret[:-1]
+ return headers, rows
def add_row(self, flags=0, time=0, population=-1, individual=-1, metadata=None):
"""
@@ -624,19 +662,22 @@ def __init__(self, max_rows_increment=0, ll_table=None):
ll_table = _tskit.EdgeTable(max_rows_increment=max_rows_increment)
super().__init__(ll_table, EdgeTableRow)
- def __str__(self):
+ def _text_header_and_rows(self):
left = self.left
right = self.right
parent = self.parent
child = self.child
metadata = util.unpack_bytes(self.metadata, self.metadata_offset)
- ret = "id\tleft\t\tright\t\tparent\tchild\tmetadata\n"
+ headers = ("id", "left\t", "right\t", "parent", "child", "metadata")
+ rows = []
for j in range(self.num_rows):
md = base64.b64encode(metadata[j]).decode("utf8")
- ret += "{}\t{:.8f}\t{:.8f}\t{}\t{}\t{}\n".format(
- j, left[j], right[j], parent[j], child[j], md
+ rows.append(
+ "{}\t{:.8f}\t{:.8f}\t{}\t{}\t{}".format(
+ j, left[j], right[j], parent[j], child[j], md
+ ).split("\t")
)
- return ret[:-1]
+ return headers, rows
def add_row(self, left, right, parent, child, metadata=None):
"""
@@ -812,7 +853,7 @@ def __init__(self, max_rows_increment=0, ll_table=None):
ll_table = _tskit.MigrationTable(max_rows_increment=max_rows_increment)
super().__init__(ll_table, MigrationTableRow)
- def __str__(self):
+ def _text_header_and_rows(self):
left = self.left
right = self.right
node = self.node
@@ -820,13 +861,16 @@ def __str__(self):
dest = self.dest
time = self.time
metadata = util.unpack_bytes(self.metadata, self.metadata_offset)
- ret = "id\tleft\tright\tnode\tsource\tdest\ttime\tmetadata\n"
+ headers = ("id", "left", "right", "node", "source", "dest", "time", "metadata")
+ rows = []
for j in range(self.num_rows):
md = base64.b64encode(metadata[j]).decode("utf8")
- ret += "{}\t{:.8f}\t{:.8f}\t{}\t{}\t{}\t{:.8f}\t{}\n".format(
- j, left[j], right[j], node[j], source[j], dest[j], time[j], md
+ rows.append(
+ "{}\t{:.8f}\t{:.8f}\t{}\t{}\t{}\t{:.8f}\t{}".format(
+ j, left[j], right[j], node[j], source[j], dest[j], time[j], md
+ ).split("\t")
)
- return ret[:-1]
+ return headers, rows
def add_row(self, left, right, node, source, dest, time, metadata=None):
"""
@@ -1002,17 +1046,22 @@ def __init__(self, max_rows_increment=0, ll_table=None):
ll_table = _tskit.SiteTable(max_rows_increment=max_rows_increment)
super().__init__(ll_table, SiteTableRow)
- def __str__(self):
+ def _text_header_and_rows(self):
position = self.position
ancestral_state = util.unpack_strings(
self.ancestral_state, self.ancestral_state_offset
)
metadata = util.unpack_bytes(self.metadata, self.metadata_offset)
- ret = "id\tposition\tancestral_state\tmetadata\n"
+ headers = ("id", "position", "ancestral_state", "metadata")
+ rows = []
for j in range(self.num_rows):
md = base64.b64encode(metadata[j]).decode("utf8")
- ret += "{}\t{:.8f}\t{}\t{}\n".format(j, position[j], ancestral_state[j], md)
- return ret[:-1]
+ rows.append(
+ "{}\t{:.8f}\t{}\t{}".format(
+ j, position[j], ancestral_state[j], md
+ ).split("\t")
+ )
+ return headers, rows
def add_row(self, position, ancestral_state, metadata=None):
"""
@@ -1194,7 +1243,7 @@ def __init__(self, max_rows_increment=0, ll_table=None):
ll_table = _tskit.MutationTable(max_rows_increment=max_rows_increment)
super().__init__(ll_table, MutationTableRow)
- def __str__(self):
+ def _text_header_and_rows(self):
site = self.site
node = self.node
parent = self.parent
@@ -1202,13 +1251,16 @@ def __str__(self):
self.derived_state, self.derived_state_offset
)
metadata = util.unpack_bytes(self.metadata, self.metadata_offset)
- ret = "id\tsite\tnode\tderived_state\tparent\tmetadata\n"
+ headers = ("id", "site", "node", "derived_state", "parent", "metadata")
+ rows = []
for j in range(self.num_rows):
md = base64.b64encode(metadata[j]).decode("utf8")
- ret += "{}\t{}\t{}\t{}\t{}\t{}\n".format(
- j, site[j], node[j], derived_state[j], parent[j], md
+ rows.append(
+ "{}\t{}\t{}\t{}\t{}\t{}".format(
+ j, site[j], node[j], derived_state[j], parent[j], md
+ ).split("\t")
)
- return ret[:-1]
+ return headers, rows
def add_row(self, site, node, derived_state, parent=-1, metadata=None):
"""
@@ -1404,13 +1456,14 @@ def add_row(self, metadata=None):
"""
return self.ll_table.add_row(metadata=metadata)
- def __str__(self):
+ def _text_header_and_rows(self):
metadata = util.unpack_bytes(self.metadata, self.metadata_offset)
- ret = "id\tmetadata\n"
+ headers = ("id", "metadata")
+ rows = []
for j in range(self.num_rows):
md = base64.b64encode(metadata[j]).decode("utf8")
- ret += f"{j}\t{md}\n"
- return ret[:-1]
+ rows.append((str(j), str(md)))
+ return headers, rows
def set_columns(self, metadata=None, metadata_offset=None):
"""
@@ -1577,13 +1630,14 @@ def append_columns(
)
)
- def __str__(self):
+ def _text_header_and_rows(self):
timestamp = util.unpack_strings(self.timestamp, self.timestamp_offset)
record = util.unpack_strings(self.record, self.record_offset)
- ret = "id\ttimestamp\trecord\n"
+ headers = ("id", "timestamp", "record")
+ rows = []
for j in range(self.num_rows):
- ret += "{}\t{}\t{}\n".format(j, timestamp[j], record[j])
- return ret[:-1]
+ rows.append((str(j), str(timestamp[j]), str(record[j])))
+ return headers, rows
def packset_record(self, records):
"""