diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c867c37d46..a175017ddc0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,8 @@ All notable changes to this project will be documented in this file. This project adheres to [Semantic Versioning](http://semver.org/). ## UNRELEASED + - Added the UpSet plot type to the figure factory (`create_upset`) for showing the intersections between non-mutually exclusive characteristics in a population (this forum post describes the functionality in more detail: https://community.plotly.com/t/plotly-upset-plot/63858) + ### Updated - Updated Plotly.js to from version 2.12.1 to version 2.13.3. See the [plotly.js CHANGELOG](https://github.com/plotly/plotly.js/blob/master/CHANGELOG.md#2133----2022-07-25) for more information. Notable changes include: diff --git a/packages/python/plotly/plotly/figure_factory/__init__.py b/packages/python/plotly/plotly/figure_factory/__init__.py index 0a41dca1ba2..7dede8026e8 100644 --- a/packages/python/plotly/plotly/figure_factory/__init__.py +++ b/packages/python/plotly/plotly/figure_factory/__init__.py @@ -26,6 +26,7 @@ from plotly.figure_factory._table import create_table from plotly.figure_factory._trisurf import create_trisurf from plotly.figure_factory._violin import create_violin +from plotly.figure_factory._upset import create_upset if optional_imports.get_module("pandas") is not None: from plotly.figure_factory._county_choropleth import create_choropleth @@ -66,4 +67,5 @@ def create_ternary_contour(*args, **kwargs): "create_ternary_contour", "create_trisurf", "create_violin", + "create_upset" ] diff --git a/packages/python/plotly/plotly/figure_factory/_upset.py b/packages/python/plotly/plotly/figure_factory/_upset.py new file mode 100644 index 00000000000..acad116410f --- /dev/null +++ b/packages/python/plotly/plotly/figure_factory/_upset.py @@ -0,0 +1,81 @@ +from __future__ import absolute_import + +from numbers import Number +import itertools +from plotly import exceptions, optional_imports 
def create_upset(df, include_empty_set=False, max_width=50):
    """
    Return an UpSet-style figure of the intersections between the columns of df.

    Every column of ``df`` is read as a membership flag (truthy = the row has
    that characteristic). Each row is counted in exactly one intersection: the
    one made of precisely the columns that are truthy for that row. The figure
    draws one bar per intersection, largest first, above a dot matrix showing
    which columns make up each intersection.

    :param (DataFrame) df: one row per observation, one flag column per
        characteristic. Values are coerced with ``astype(bool)``.
    :param (bool) include_empty_set: if True, also draw the intersection of
        rows for which every column is falsy.
    :param (int|None) max_width: maximum number of intersections to draw; the
        largest ones are kept. ``None`` draws all of them.

    :raises ValueError: if ``df`` has no columns.

    :rtype (Figure): the assembled figure.
    """
    from collections import Counter

    # Labels are stringified so non-string column names can be joined/measured.
    names = [str(col) for col in df.columns]
    n_sets = len(names)
    if n_sets == 0:
        raise ValueError("create_upset requires a DataFrame with at least one column.")

    # Intersections are tuples of column *positions*, ordered by subset size.
    members = [
        combo
        for k in range(1, n_sets + 1)
        for combo in itertools.combinations(range(n_sets), k)
    ]
    if include_empty_set:
        members.append(())

    # One pass over the rows: count each exact True/False membership pattern.
    # Coercing to bool first keeps the counts right for flags other than 0/1.
    pattern_counts = Counter(map(tuple, df.astype(bool).values.tolist()))
    sizes = [
        pattern_counts[tuple(j in combo for j in range(n_sets))]
        for combo in members
    ]

    # Largest intersections first; the y scale is fixed before truncation.
    ranked = sorted(zip(members, sizes), key=lambda pair: pair[1], reverse=True)
    largest = ranked[0][1]
    max_y = largest + 0.1 * largest
    if max_width is not None:
        ranked = ranked[:max_width]
    n_bars = len(ranked)

    # The dot matrix lives below the x axis: one row of dots per column of df.
    row_y = [-j * max_y / n_sets - 0.1 * max_y for j in range(n_sets)]
    grid_x = [i for i in range(n_bars) for _ in range(n_sets)]
    grid_y = row_y * n_bars

    fig = graph_objs.Figure()
    # Grey background dots for every (intersection, column) cell.
    fig.add_trace(
        graph_objs.Scatter(
            x=grid_x,
            y=grid_y,
            mode="markers",
            showlegend=False,
            marker=dict(size=16, color="#C9C9C9"),
            hovertemplate=[""] * len(grid_x),
        )
    )
    fig.update_layout(
        xaxis=dict(showgrid=False, zeroline=False),
        yaxis=dict(showgrid=True, zeroline=False),
        plot_bgcolor="#FFFFFF",
        margin=dict(t=40, l=150),
    )

    # Black, line-connected dots marking the columns in each intersection.
    for i, (combo, _) in enumerate(ranked):
        fig.add_trace(
            graph_objs.Scatter(
                x=[i] * len(combo),
                y=[row_y[j] for j in combo],
                mode="markers+lines",
                showlegend=False,
                marker=dict(size=16, color="#000000", showscale=False),
                hovertemplate=[""] * len(combo),
            )
        )
    fig.update_xaxes(showticklabels=False)  # Hide x axis ticks
    fig.update_yaxes(showticklabels=False)  # Hide y axis ticks

    bar_labels = ["+".join(names[j] for j in combo) for combo, _ in ranked]
    bar_sizes = [size for _, size in ranked]
    # Plotly hover text uses <br> for line breaks.
    bar_hover = [
        f"<br>{label}<br>N-Count: {size}<br>"
        for label, size in zip(bar_labels, bar_sizes)
    ]
    fig.add_trace(
        graph_objs.Bar(
            x=list(range(n_bars)),
            y=bar_sizes,
            marker=dict(color="#000000"),
            text=bar_sizes,
            hovertemplate=bar_hover,
            textposition="outside",
            hoverinfo="none",
        )
    )

    # Column names sit left of the dot matrix, one per matrix row. The x
    # offsets scale with the longest name; this is a heuristic.
    max_string_len = max(len(name) for name in names)
    fig.add_trace(
        graph_objs.Scatter(
            x=[-0.02 * max_string_len] * n_sets,
            y=row_y,
            text=names,
            mode="text",
            textposition="middle left",
            showlegend=False,
            hovertemplate=[""] * n_sets,
        )
    )
    fig.update_layout(
        title="Intersections",
        yaxis_range=[-max_y - 0.1 * max_y - 1, max_y + 0.1 * max_y],
        xaxis_range=[-0.015 * max_string_len**1.3, n_bars],
        showlegend=False,
        title_x=0.5,
    )

    return fig
class TestUpsetPlot:
    """Structure checks for ``ff.create_upset`` on a reproducible 5-flag frame."""

    @staticmethod
    def _make_df():
        """Build 1000 rows of five independent 0/1 flags from a fixed seed."""
        # Seeded generator so intersection sizes are identical on every run.
        rng = np.random.RandomState(0)
        df = pd.DataFrame()
        for i, p in enumerate([0.3, 0.3, 0.8, 0.05, 0.1]):
            df["Attr %d" % (i + 1)] = rng.binomial(1, p, size=1000)
        return df

    @staticmethod
    def _bar_sizes(fig):
        """Return the heights of the single bar trace in ``fig`` as a list."""
        bars = [trace for trace in fig.data if trace.type == "bar"]
        assert len(bars) == 1
        assert len(bars[0].x) == len(bars[0].y)
        return list(bars[0].y)

    def test_defaults(self):
        df = self._make_df()
        fig = ff.create_upset(df)
        sizes = self._bar_sizes(fig)
        # 2**5 - 1 non-empty intersections, all under the default max_width.
        assert len(sizes) == 31
        assert sizes == sorted(sizes, reverse=True)
        # Every row with at least one flag set lands in exactly one bar.
        assert sum(sizes) == int(df.astype(bool).any(axis=1).sum())

    def test_max_width(self):
        df = self._make_df()
        fig = ff.create_upset(df, max_width=10)
        sizes = self._bar_sizes(fig)
        assert len(sizes) == 10
        assert sizes == sorted(sizes, reverse=True)

    def test_empty_set_true(self):
        df = self._make_df()
        fig = ff.create_upset(df, include_empty_set=True, max_width=10)
        sizes = self._bar_sizes(fig)
        assert len(sizes) == 10
        assert sizes == sorted(sizes, reverse=True)

        # With the empty set and no truncation, the bars partition all rows.
        fig_all = ff.create_upset(df, include_empty_set=True, max_width=None)
        sizes_all = self._bar_sizes(fig_all)
        assert len(sizes_all) == 32
        assert sum(sizes_all) == len(df)