diff --git a/plotly/tools.py b/plotly/tools.py index ebb9a5fbdd8..10ae7838a20 100644 --- a/plotly/tools.py +++ b/plotly/tools.py @@ -6051,7 +6051,8 @@ def create_distplot(hist_data, group_labels, @staticmethod def create_dendrogram(X, orientation="bottom", labels=None, - colorscale=None): + colorscale=None, distfun=scs.distance.pdist, + linkagefun=lambda x: sch.linkage(x, 'complete')): """ BETA function that returns a dendrogram Plotly figure object. @@ -6059,6 +6060,9 @@ def create_dendrogram(X, orientation="bottom", labels=None, :param (str) orientation: 'top', 'right', 'bottom', or 'left' :param (list) labels: List of axis category labels(observation labels) :param (list) colorscale: Optional colorscale for dendrogram tree + :param (function) distfun: Function to compute the pairwise distance from the observations + :param (function) linkagefun: Function to compute the linkage matrix from the pairwise distances + clusters Example 1: Simple bottom oriented dendrogram @@ -6114,7 +6118,8 @@ def create_dendrogram(X, orientation="bottom", labels=None, if len(s) != 2: exceptions.PlotlyError("X should be 2-dimensional array.") - dendrogram = _Dendrogram(X, orientation, labels, colorscale) + dendrogram = _Dendrogram(X, orientation, labels, colorscale, + distfun=distfun, linkagefun=linkagefun) return {'layout': dendrogram.layout, 'data': dendrogram.data} @@ -7041,7 +7046,8 @@ class _Dendrogram(FigureFactory): """Refer to FigureFactory.create_dendrogram() for docstring.""" def __init__(self, X, orientation='bottom', labels=None, colorscale=None, - width="100%", height="100%", xaxis='xaxis', yaxis='yaxis'): + width="100%", height="100%", xaxis='xaxis', yaxis='yaxis', + distfun=scs.distance.pdist, linkagefun=lambda x: sch.linkage(x, 'complete')): # TODO: protected until #282 from plotly.graph_objs import graph_objs self.orientation = orientation @@ -7064,7 +7070,7 @@ def __init__(self, X, orientation='bottom', labels=None, colorscale=None, self.sign[self.yaxis] = -1 (dd_traces, xvals, yvals, - ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale) + ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale, distfun, linkagefun) self.labels = ordered_labels self.leaves = leaves @@ -7173,12 +7179,14 @@ def set_figure_layout(self, width, height): return self.layout - def get_dendrogram_traces(self, X, colorscale): + def get_dendrogram_traces(self, X, colorscale, distfun, linkagefun): """ Calculates all the elements needed for plotting a dendrogram. :param (ndarray) X: Matrix of observations as array of arrays :param (list) colorscale: Color scale for dendrogram tree clusters + :param (function) distfun: Function to compute the pairwise distance from the observations + :param (function) linkagefun: Function to compute the linkage matrix from the pairwise distances :rtype (tuple): Contains all the traces in the following order: (a) trace_list: List of Plotly trace objects for dendrogram tree (b) icoord: All X points of the dendrogram tree as array of arrays @@ -7192,8 +7200,8 @@ def get_dendrogram_traces(self, X, colorscale): """ # TODO: protected until #282 from plotly.graph_objs import graph_objs - d = scs.distance.pdist(X) - Z = sch.linkage(d, method='complete') + d = distfun(X) + Z = linkagefun(d) P = sch.dendrogram(Z, orientation=self.orientation, labels=self.labels, no_plot=True)