Skip to content

Commit

Permalink
Merge pull request #492 from plotly/clustergram-row-labels
Browse files Browse the repository at this point in the history
Fix Clustergram when passing labels.
  • Loading branch information
mkcor committed Mar 16, 2020
2 parents 5468be6 + fe739cd commit c920884
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 33 deletions.
8 changes: 6 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Changelog

## [Unreleased]
## [0.4.8] - 2020-03-16
### Changed
* [#489](https://github.com/plotly/dash-bio/pull/489) Renamed async modules with hyphen `-` instead of tilde `~`
- [#494](https://github.com/plotly/dash-bio/pull/494) Update from React 16.8.6 to 16.13.0
* [#494](https://github.com/plotly/dash-bio/pull/494) Update from React 16.8.6 to 16.13.0

### Added
* [#492](https://github.com/plotly/dash-bio/pull/492) Added working support for
row and column labels in Clustergram.

## [0.4.7] - 2020-02-21
### Added
Expand Down
2 changes: 1 addition & 1 deletion dash_bio/bundle.js

Large diffs are not rendered by default.

59 changes: 32 additions & 27 deletions dash_bio/component_factory/_clustergram.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def Clustergram(
(observation labels).
- column_labels (list; optional): List of column category labels
(observation labels).
- hidden_labels (list; optional): List of labels not to display on the
final plot.
- hidden_labels (list; optional): List containing strings 'row' and/or 'col'
if row and/or column labels should be hidden on the final plot.
- standardize (string; default 'none'): The dimension for standardizing
values, so that the mean is 0 and the standard deviation is 1,
along the specified dimension: 'row', 'column', or 'none'.
Expand Down Expand Up @@ -162,8 +162,6 @@ def Clustergram(
- width (number; default 500): The width of the graph, in px.
"""
if hidden_labels is None:
hidden_labels = []
if color_threshold is None:
color_threshold = dict(row=0, col=0)

Expand Down Expand Up @@ -243,16 +241,16 @@ def __init__(
hidden_labels = []
if color_threshold is None:
color_threshold = dict(row=0, col=0)
if row_labels is None:
row_labels = [str(i) for i in range(data.shape[0])]
hidden_labels.append("row")
if column_labels is None:
column_labels = [str(i) for i in range(data.shape[1])]
hidden_labels.append("col")
# Always keep unique identifiers for rows
row_ids = list(range(data.shape[0]))
# Always keep unique identifiers for columns
column_ids = list(range(data.shape[1]))

self._data = data
self._row_labels = row_labels
self._row_ids = row_ids
self._column_labels = column_labels
self._column_ids = column_ids
self._cluster = cluster
self._row_dist = row_dist
self._col_dist = col_dist
Expand Down Expand Up @@ -359,16 +357,23 @@ def figure(self, computed_traces=None):
(
dt,
self._data,
self._row_labels,
self._column_labels,
self._row_ids,
self._column_ids,
) = self._compute_clustered_data()
else:
# use, if available, the precomputed dendrogram and heatmap
# traces (as well as the row and column ids)
dt = computed_traces["dendro_traces"]
heatmap = computed_traces["heatmap"]
self._row_labels = computed_traces["row_labels"]
self._column_labels = computed_traces["column_labels"]
self._row_ids = computed_traces["row_ids"]
self._column_ids = computed_traces["column_ids"]

# Match reordered rows and columns with their respective labels
if self._row_labels:
self._row_labels = [self._row_labels[r] for r in self._row_ids]
if self._column_labels:
self._column_labels = [self._column_labels[r]
for r in self._column_ids]

# this dictionary relates curve numbers (accessible from the
# hoverData/clickData props) to cluster numbers
Expand Down Expand Up @@ -501,7 +506,7 @@ def figure(self, computed_traces=None):
xaxis2.update(scaleanchor="x5")

if len(tickvals_col) == 0:
tickvals_col = [10 * i + 5 for i in range(len(self._column_labels))]
tickvals_col = [10 * i + 5 for i in range(len(self._column_ids))]

# add in all of the labels
fig["layout"]["xaxis5"].update( # pylint: disable=invalid-sequence-index
Expand All @@ -518,7 +523,7 @@ def figure(self, computed_traces=None):
)

if len(tickvals_row) == 0:
tickvals_row = [10 * i + 5 for i in range(len(self._row_labels))]
tickvals_row = [10 * i + 5 for i in range(len(self._row_ids))]

fig["layout"]["yaxis5"].update( # pylint: disable=invalid-sequence-index
tickmode="array",
Expand Down Expand Up @@ -676,8 +681,8 @@ def figure(self, computed_traces=None):
computed_traces = {
"dendro_traces": dt,
"heatmap": heatmap,
"row_labels": self._row_labels,
"column_labels": self._column_labels,
"row_ids": self._row_ids,
"column_ids": self._column_ids,
}

return (fig, computed_traces, cluster_curve_numbers)
Expand Down Expand Up @@ -746,8 +751,8 @@ def _compute_clustered_data(self):
# first, compute the clusters
(Zcol, Zrow) = self._get_clusters()

clustered_column_labels = self._column_labels
clustered_row_labels = self._row_labels
clustered_column_ids = self._column_ids
clustered_row_ids = self._row_ids

# calculate dendrogram from clusters; sch.dendrogram returns sets
# of four coordinates that make up the 'u' shapes in the dendrogram
Expand All @@ -756,18 +761,18 @@ def _compute_clustered_data(self):
Zcol,
orientation="top",
color_threshold=self._color_threshold["col"],
labels=self._column_labels,
labels=self._column_ids,
no_plot=True,
)
clustered_column_labels = Pcol["ivl"]
clustered_column_ids = Pcol["ivl"]
trace_list["col"] = self._color_dendro_clusters(Pcol, "col")

if Zrow is not None:
Prow = sch.dendrogram(
Zrow,
orientation="left",
color_threshold=self._color_threshold["row"],
labels=self._row_labels,
labels=self._row_ids,
no_plot=True,
)
# need to flip the coordinates for the row dendrogram
Expand All @@ -776,21 +781,21 @@ def _compute_clustered_data(self):
"dcoord": Prow["icoord"],
"color_list": Prow["color_list"],
}
clustered_row_labels = Prow["ivl"]
clustered_row_ids = Prow["ivl"]
trace_list["row"] = self._color_dendro_clusters(Prow_tmp, "row")

# now, we need to rearrange the data array to match the clustered row and column ids

# first get reordered indices
rl_indices = [self._row_labels.index(r) for r in clustered_row_labels]
cl_indices = [self._column_labels.index(c) for c in clustered_column_labels]
rl_indices = [self._row_ids.index(r) for r in clustered_row_ids]
cl_indices = [self._column_ids.index(c) for c in clustered_column_ids]

# modify the data here; first shuffle rows,
# then transpose and shuffle columns,
# then transpose again
clustered_data = self._data[rl_indices].T[cl_indices].T

return trace_list, clustered_data, clustered_row_labels, clustered_column_labels
return trace_list, clustered_data, clustered_row_ids, clustered_column_ids

def _color_dendro_clusters(self, P, dim):
"""Color each cluster below the color threshold separately.
Expand Down
2 changes: 1 addition & 1 deletion dash_bio/package-info.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"name": "dash_bio", "version": "0.4.7", "author": "The Plotly Team <dashbio@plot.ly>"}
{"name": "dash_bio", "version": "0.4.8", "author": "The Plotly Team <dashbio@plot.ly>"}
2 changes: 1 addition & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "dash-bio",
"version": "0.4.7",
"version": "0.4.8",
"description": "Dash components for bioinformatics",
"repository": {
"type": "git",
Expand Down
61 changes: 61 additions & 0 deletions tests/integration/test_clustergram.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,64 @@ def test_dbcl006_df_input_row_cluster(dash_duo):

assert len(dash_duo.find_elements('g.subplot.x2y2')) == 0
assert len(dash_duo.find_elements('g.subplot.x4y4')) == 1


def test_dbcl007_hidden_labels(dash_duo):
    """Check that `hidden_labels` hides the row ticks or the column ticks.

    Two scenarios are run one after the other, each on a freshly created
    app: first 'row' is passed as the `hidden_labels` value, then 'col'.
    In each scenario the hidden axis must render no tick elements, while
    the other axis must render one tick element per supplied label.
    """

    data = _mtcars_data
    row_labels = list(_mtcars_data.index)
    col_labels = list(_mtcars_data.columns)

    # CSS selectors for the tick elements of the heatmap's y and x axes
    row_tick_selector = 'g.yaxislayer-above g.y5tick'
    col_tick_selector = 'g.xaxislayer-above g.x5tick'

    # (value of hidden_labels, selector that must match nothing,
    #  selector that must still match, labels expected on that axis)
    scenarios = (
        ('row', row_tick_selector, col_tick_selector, col_labels),
        ('col', col_tick_selector, row_tick_selector, row_labels),
    )

    for hidden_axis, hidden_selector, shown_selector, shown_labels in scenarios:
        # create a new instance of the app for every scenario
        app = dash.Dash(__name__)

        clustergram = dash_bio.Clustergram(
            data=data,
            row_labels=row_labels,
            column_labels=col_labels
        )
        app.layout = html.Div(nested_component_layout(clustergram))

        nested_component_app_callback(
            app,
            dash_duo,
            component=dash_bio.Clustergram,
            component_data=data,
            test_prop_name='hidden_labels',
            test_prop_value=hidden_axis,
            prop_value_type='string'
        )

        # ensure that the labels of the hidden axis are not rendered
        assert len(dash_duo.find_elements(hidden_selector)) == 0
        # ensure that the labels of the other axis are all displayed
        assert len(dash_duo.find_elements(shown_selector)) == \
            len(shown_labels)
34 changes: 34 additions & 0 deletions tests/unit/test_clustergram.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,37 @@ def test_read_dataframe():
clustered_data = CLUSTERED_DATA

assert np.array_equal(curves_dict['heatmap']['z'], clustered_data)


def test_row_labels():
    """Check that passing row labels does not alter the clustered heatmap."""

    # the labels are deliberately non-unique
    labels = ['a', 'b', 'b', 'b', 'b', 'b']

    options = dict(
        generate_curves_dict=True,
        return_computed_traces=True,
        row_labels=labels,
        center_values=False,
    )
    _, _, curves_dict = Clustergram(DATA, **options)

    # the heatmap values must equal the reference clustering
    heatmap_values = curves_dict['heatmap']['z']
    assert np.array_equal(heatmap_values, CLUSTERED_DATA)


def test_column_labels():
    """Check that passing column labels does not alter the clustered heatmap."""

    # the labels are deliberately non-unique
    labels = ['a', 'b', 'b', 'b', 'b', 'b']

    options = dict(
        generate_curves_dict=True,
        return_computed_traces=True,
        column_labels=labels,
        center_values=False,
    )
    # the input is the transposed reference data, so the expected result
    # is the transposed reference clustering
    _, _, curves_dict = Clustergram(DATA.T, **options)

    heatmap_values = curves_dict['heatmap']['z']
    assert np.array_equal(heatmap_values, CLUSTERED_DATA.T)

0 comments on commit c920884

Please sign in to comment.