From 0ae45237b0cc51b95a53eab2d680a47f4e9ed239 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Tue, 7 Apr 2020 01:02:47 +0100 Subject: [PATCH] Add _repr_html_ to tables --- python/CHANGELOG.rst | 15 +++-- python/tests/test_tables.py | 15 ++++- python/tskit/tables.py | 126 +++++++++++++++++++++++++----------- 3 files changed, 112 insertions(+), 44 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 6b160fbada..f856915ce1 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -6,8 +6,11 @@ In development **New features** +- Add ``_repr_html_`` to tables, so that jupyter notebooks render them as + html tables (:user:`benjeffery`, :pr:`514`) + - Remove support for ``kc_distance`` on trees with unary nodes - (:user:`daniel-goldstein`, :`508`) + (:user:`daniel-goldstein`, :pr:`508`) - Improve Kendall-Colijn tree distance algorithm to operate in O(n^2) time instead of O(n^2 * log(n)) where n is the number of samples @@ -19,10 +22,10 @@ In development - Add a metadata column to the edges table. Works similarly to existing metadata columns on other tables(:user:`benjeffery`, :pr:`496`). -- Allow sites with missing data to be output by the `haplotypes` method, by +- Allow sites with missing data to be output by the ``haplotypes`` method, by default replacing with ``-``. Errors are no longer raised for missing data - with `impute_missing_data=False`; the error types returned for bad alleles - (e.g. multiletter or non-ascii) have also changed from `_tskit.LibraryError` + with ``impute_missing_data=False``; the error types returned for bad alleles + (e.g. multiletter or non-ascii) have also changed from ``_tskit.LibraryError`` to TypeError, or ValueError if the missing data character clashes (:user:`hyanwong`, :pr:`426`). @@ -36,7 +39,7 @@ In development us to efficiently iterate over 'real' roots when we have missing data (:user:`jeromekelleher`, :pr:`462`). -- Add pickle support for `TreeSequence` (:user:`terhorst`, :pr:`473`). +- Add pickle support for ``TreeSequence`` (:user:`terhorst`, :pr:`473`). **Bugfixes** @@ -75,7 +78,7 @@ method to manipulate tree sequence data. - Fix height scaling issues with SVG tree drawing (:user:`jeromekelleher`, :pr:`407`, :issue:`383`, :pr:`378`). -- Do not reuse buffers in LdCalculator (:user:`jeromekelleher`). See :pr:`397` and +- Do not reuse buffers in ``LdCalculator`` (:user:`jeromekelleher`). See :pr:`397` and :issue:`396`. -------------------- diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 4a717a4f7f..eed28f4c23 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -503,6 +503,18 @@ def test_str(self): s = str(table) self.assertEqual(len(s.splitlines()), num_rows + 1) + def test_repr_html(self): + for num_rows in [0, 10]: + input_data = {col.name: col.get_input(num_rows) for col in self.columns} + for list_col, offset_col in self.ragged_list_columns: + value = list_col.get_input(num_rows) + input_data[list_col.name] = value + input_data[offset_col.name] = np.arange(num_rows + 1, dtype=np.uint32) + table = self.table_class() + table.set_columns(**input_data) + html = table._repr_html_() + self.assertEqual(len(html.splitlines()), num_rows + 19) + def test_copy(self): for num_rows in [0, 10]: input_data = {col.name: col.get_input(num_rows) for col in self.columns} @@ -1767,9 +1779,8 @@ def test_table_references(self): mutations = tables.mutations before_populations = str(tables.populations) populations = tables.populations - before_nodes = str(tables.nodes) - provenances = tables.provenances before_provenances = str(tables.provenances) + provenances = tables.provenances del tables self.assertEqual(str(individuals), before_individuals) self.assertEqual(str(nodes), before_nodes) diff --git a/python/tskit/tables.py b/python/tskit/tables.py index ac8232dca4..ba304fa049 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -206,6 +206,38 @@ def set_columns(self, **kwargs): """ raise NotImplementedError() + def __str__(self): + headers, rows = self._text_header_and_rows() + return "\n".join("\t".join(row) for row in [headers] + rows) + + def _repr_html_(self): + """ + Called by jupyter notebooks to render tables + """ + headers, rows = self._text_header_and_rows() + headers = "".join(f"{header}" for header in headers) + rows = ("".join(f"{cell}" for cell in row) for row in rows) + rows = "".join(f"{row}\n" for row in rows) + return f""" +
+ + + + + {headers} + + + + {rows} + +
+
+ """ + class MetadataMixin: """ @@ -270,16 +302,19 @@ def __init__(self, max_rows_increment=0, ll_table=None): ll_table = _tskit.IndividualTable(max_rows_increment=max_rows_increment) super().__init__(ll_table, IndividualTableRow) - def __str__(self): + def _text_header_and_rows(self): flags = self.flags location = util.unpack_arrays(self.location, self.location_offset) metadata = util.unpack_bytes(self.metadata, self.metadata_offset) - ret = "id\tflags\tlocation\tmetadata\n" + headers = ("id", "flags", "location", "metadata") + rows = [] for j in range(self.num_rows): md = base64.b64encode(metadata[j]).decode("utf8") location_str = ",".join(map(str, location)) - ret += "{}\t{}\t{}\t{}\n".format(j, flags[j], location_str, md) - return ret[:-1] + rows.append( + "{}\t{}\t{}\t{}".format(j, flags[j], location_str, md).split("\t") + ) + return headers, rows def add_row(self, flags=0, location=None, metadata=None): """ @@ -450,19 +485,22 @@ def __init__(self, max_rows_increment=0, ll_table=None): ll_table = _tskit.NodeTable(max_rows_increment=max_rows_increment) super().__init__(ll_table, NodeTableRow) - def __str__(self): + def _text_header_and_rows(self): time = self.time flags = self.flags population = self.population individual = self.individual metadata = util.unpack_bytes(self.metadata, self.metadata_offset) - ret = "id\tflags\tpopulation\tindividual\ttime\tmetadata\n" + headers = ("id", "flags", "population", "individual", "time", "metadata") + rows = [] for j in range(self.num_rows): md = base64.b64encode(metadata[j]).decode("utf8") - ret += "{}\t{}\t{}\t{}\t{:.14f}\t{}\n".format( - j, flags[j], population[j], individual[j], time[j], md + rows.append( + "{}\t{}\t{}\t{}\t{:.14f}\t{}".format( + j, flags[j], population[j], individual[j], time[j], md + ).split("\t") ) - return ret[:-1] + return headers, rows def add_row(self, flags=0, time=0, population=-1, individual=-1, metadata=None): """ @@ -624,19 +662,22 @@ def __init__(self, max_rows_increment=0, ll_table=None): ll_table = _tskit.EdgeTable(max_rows_increment=max_rows_increment) super().__init__(ll_table, EdgeTableRow) - def __str__(self): + def _text_header_and_rows(self): left = self.left right = self.right parent = self.parent child = self.child metadata = util.unpack_bytes(self.metadata, self.metadata_offset) - ret = "id\tleft\t\tright\t\tparent\tchild\tmetadata\n" + headers = ("id", "left\t", "right\t", "parent", "child", "metadata") + rows = [] for j in range(self.num_rows): md = base64.b64encode(metadata[j]).decode("utf8") - ret += "{}\t{:.8f}\t{:.8f}\t{}\t{}\t{}\n".format( - j, left[j], right[j], parent[j], child[j], md + rows.append( + "{}\t{:.8f}\t{:.8f}\t{}\t{}\t{}".format( + j, left[j], right[j], parent[j], child[j], md + ).split("\t") ) - return ret[:-1] + return headers, rows def add_row(self, left, right, parent, child, metadata=None): """ @@ -812,7 +853,7 @@ def __init__(self, max_rows_increment=0, ll_table=None): ll_table = _tskit.MigrationTable(max_rows_increment=max_rows_increment) super().__init__(ll_table, MigrationTableRow) - def __str__(self): + def _text_header_and_rows(self): left = self.left right = self.right node = self.node @@ -820,13 +861,16 @@ def __str__(self): dest = self.dest time = self.time metadata = util.unpack_bytes(self.metadata, self.metadata_offset) - ret = "id\tleft\tright\tnode\tsource\tdest\ttime\tmetadata\n" + headers = ("id", "left", "right", "node", "source", "dest", "time", "metadata") + rows = [] for j in range(self.num_rows): md = base64.b64encode(metadata[j]).decode("utf8") - ret += "{}\t{:.8f}\t{:.8f}\t{}\t{}\t{}\t{:.8f}\t{}\n".format( - j, left[j], right[j], node[j], source[j], dest[j], time[j], md + rows.append( + "{}\t{:.8f}\t{:.8f}\t{}\t{}\t{}\t{:.8f}\t{}".format( + j, left[j], right[j], node[j], source[j], dest[j], time[j], md + ).split("\t") ) - return ret[:-1] + return headers, rows def add_row(self, left, right, node, source, dest, time, metadata=None): """ @@ -1002,17 +1046,22 @@ def __init__(self, max_rows_increment=0, ll_table=None): ll_table = _tskit.SiteTable(max_rows_increment=max_rows_increment) super().__init__(ll_table, SiteTableRow) - def __str__(self): + def _text_header_and_rows(self): position = self.position ancestral_state = util.unpack_strings( self.ancestral_state, self.ancestral_state_offset ) metadata = util.unpack_bytes(self.metadata, self.metadata_offset) - ret = "id\tposition\tancestral_state\tmetadata\n" + headers = ("id", "position", "ancestral_state", "metadata") + rows = [] for j in range(self.num_rows): md = base64.b64encode(metadata[j]).decode("utf8") - ret += "{}\t{:.8f}\t{}\t{}\n".format(j, position[j], ancestral_state[j], md) - return ret[:-1] + rows.append( + "{}\t{:.8f}\t{}\t{}".format( + j, position[j], ancestral_state[j], md + ).split("\t") + ) + return headers, rows def add_row(self, position, ancestral_state, metadata=None): """ @@ -1194,7 +1243,7 @@ def __init__(self, max_rows_increment=0, ll_table=None): ll_table = _tskit.MutationTable(max_rows_increment=max_rows_increment) super().__init__(ll_table, MutationTableRow) - def __str__(self): + def _text_header_and_rows(self): site = self.site node = self.node parent = self.parent @@ -1202,13 +1251,16 @@ def __str__(self): self.derived_state, self.derived_state_offset ) metadata = util.unpack_bytes(self.metadata, self.metadata_offset) - ret = "id\tsite\tnode\tderived_state\tparent\tmetadata\n" + headers = ("id", "site", "node", "derived_state", "parent", "metadata") + rows = [] for j in range(self.num_rows): md = base64.b64encode(metadata[j]).decode("utf8") - ret += "{}\t{}\t{}\t{}\t{}\t{}\n".format( - j, site[j], node[j], derived_state[j], parent[j], md + rows.append( + "{}\t{}\t{}\t{}\t{}\t{}".format( + j, site[j], node[j], derived_state[j], parent[j], md + ).split("\t") ) - return ret[:-1] + return headers, rows def add_row(self, site, node, derived_state, parent=-1, metadata=None): """ @@ -1404,13 +1456,14 @@ def add_row(self, metadata=None): """ return self.ll_table.add_row(metadata=metadata) - def __str__(self): + def _text_header_and_rows(self): metadata = util.unpack_bytes(self.metadata, self.metadata_offset) - ret = "id\tmetadata\n" + headers = ("id", "metadata") + rows = [] for j in range(self.num_rows): md = base64.b64encode(metadata[j]).decode("utf8") - ret += f"{j}\t{md}\n" - return ret[:-1] + rows.append((str(j), str(md))) + return headers, rows def set_columns(self, metadata=None, metadata_offset=None): """ @@ -1577,13 +1630,14 @@ def append_columns( ) ) - def __str__(self): + def _text_header_and_rows(self): timestamp = util.unpack_strings(self.timestamp, self.timestamp_offset) record = util.unpack_strings(self.record, self.record_offset) - ret = "id\ttimestamp\trecord\n" + headers = ("id", "timestamp", "record") + rows = [] for j in range(self.num_rows): - ret += "{}\t{}\t{}\n".format(j, timestamp[j], record[j]) - return ret[:-1] + rows.append((str(j), str(timestamp[j]), str(record[j]))) + return headers, rows def packset_record(self, records): """