diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0c867c37d46..a175017ddc0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,8 @@ All notable changes to this project will be documented in this file.
This project adheres to [Semantic Versioning](http://semver.org/).
## UNRELEASED
+ - Added the UpSet plot type to the figure factory (`ff.create_upset`) for showing the sizes of the intersections between non-mutually exclusive characteristics in a population. See the [community forum post](https://community.plotly.com/t/plotly-upset-plot/63858) for more details on the functionality.
+
### Updated
- Updated Plotly.js to from version 2.12.1 to version 2.13.3. See the [plotly.js CHANGELOG](https://github.com/plotly/plotly.js/blob/master/CHANGELOG.md#2133----2022-07-25) for more information. Notable changes include:
diff --git a/packages/python/plotly/plotly/figure_factory/__init__.py b/packages/python/plotly/plotly/figure_factory/__init__.py
index 0a41dca1ba2..7dede8026e8 100644
--- a/packages/python/plotly/plotly/figure_factory/__init__.py
+++ b/packages/python/plotly/plotly/figure_factory/__init__.py
@@ -26,6 +26,7 @@
from plotly.figure_factory._table import create_table
from plotly.figure_factory._trisurf import create_trisurf
from plotly.figure_factory._violin import create_violin
+from plotly.figure_factory._upset import create_upset
if optional_imports.get_module("pandas") is not None:
from plotly.figure_factory._county_choropleth import create_choropleth
@@ -66,4 +67,5 @@ def create_ternary_contour(*args, **kwargs):
"create_ternary_contour",
"create_trisurf",
"create_violin",
+ "create_upset"
]
diff --git a/packages/python/plotly/plotly/figure_factory/_upset.py b/packages/python/plotly/plotly/figure_factory/_upset.py
new file mode 100644
index 00000000000..acad116410f
--- /dev/null
+++ b/packages/python/plotly/plotly/figure_factory/_upset.py
@@ -0,0 +1,159 @@
+from __future__ import absolute_import
+
+from collections import Counter
+from numbers import Number
+import itertools
+from plotly import exceptions, optional_imports
+import plotly.colors as clrs
+from plotly.graph_objs import graph_objs
+from plotly.subplots import make_subplots
+
+pd = optional_imports.get_module("pandas")
+np = optional_imports.get_module("numpy")
+
+
+def _intersection_sizes(df, include_empty_set):
+    """
+    Count the rows of df falling in each exclusive intersection of its columns.
+
+    :param (pandas.DataFrame) df: one column per characteristic, truthy cells
+        marking the rows that have it.
+    :param (bool) include_empty_set: also count the rows falsy in every column.
+
+    :rtype (list): ``(positions, size)`` pairs where positions is the tuple of
+        column positions forming the intersection.
+    """
+    n_cols = len(df.columns)
+    # A row belongs to exactly one intersection (the columns where it is truthy),
+    # so a single pass over the rows yields every size.
+    pattern_counts = Counter(
+        tuple(bool(value) for value in row)
+        for row in df.itertuples(index=False, name=None)
+    )
+    combos = [
+        combo
+        for k in range(1, n_cols + 1)
+        for combo in itertools.combinations(range(n_cols), k)
+    ]
+    if include_empty_set:
+        combos.append(())
+    return [
+        (combo, pattern_counts[tuple(j in combo for j in range(n_cols))])
+        for combo in combos
+    ]
+
+
+def create_upset(df, include_empty_set=False, max_width=50):
+    """
+    Return an UpSet plot of the exclusive intersections between the columns of df.
+
+    Each bar counts the rows that are truthy in exactly the columns marked by a
+    black dot below it and falsy in every other column. Bars run from the
+    largest intersection to the smallest.
+
+    :param (pandas.DataFrame) df: one row per observation and one column per
+        characteristic; a truthy cell means the observation has it.
+    :param (bool) include_empty_set: if True, also show the rows that are falsy
+        in every column.
+    :param (int|None) max_width: keep only this many of the largest
+        intersections; None keeps them all.
+
+    :raises ValueError: if df has no columns.
+    :rtype (plotly.graph_objs.Figure): the UpSet figure.
+    """
+    names = [str(name) for name in df.columns]
+    d = len(names)
+    if d == 0:
+        raise ValueError("create_upset needs a dataframe with at least one column")
+
+    # sorted() is stable, so equally sized intersections keep a fixed order.
+    ranked = sorted(
+        _intersection_sizes(df, include_empty_set),
+        key=lambda pair: pair[1],
+        reverse=True,
+    )
+    # The vertical scale follows the largest bar, so compute it before truncating.
+    largest = ranked[0][1]
+    max_y = largest + 0.1 * largest
+    if max_width is not None:
+        ranked = ranked[:max_width]
+    n_bars = len(ranked)
+    sizes = [size for _, size in ranked]
+    labels = ["+".join(names[j] for j in combo) for combo, _ in ranked]
+    # One row of dots per column, stacked underneath the bars.
+    row_y = [-j * max_y / d - 0.1 * max_y for j in range(d)]
+    longest = max(len(name) for name in names)
+
+    fig = graph_objs.Figure()
+    # Grey backdrop: a dot for every (intersection, column) pair.
+    fig.add_trace(
+        graph_objs.Scatter(
+            x=[i for i in range(n_bars) for _ in range(d)],
+            y=row_y * n_bars,
+            mode="markers",
+            marker=dict(size=16, color="#C9C9C9"),
+            hoverinfo="skip",
+            showlegend=False,
+        )
+    )
+    # Black dots joined by a line mark the columns forming each intersection.
+    for i, (combo, _) in enumerate(ranked):
+        fig.add_trace(
+            graph_objs.Scatter(
+                x=[i] * len(combo),
+                y=[row_y[j] for j in combo],
+                mode="markers+lines",
+                marker=dict(size=16, color="#000000"),
+                hoverinfo="skip",
+                showlegend=False,
+            )
+        )
+    fig.add_trace(
+        graph_objs.Bar(
+            x=list(range(n_bars)),
+            y=sizes,
+            text=sizes,
+            textposition="outside",
+            marker=dict(color="#000000"),
+            # "<br>" is the line break plotly understands in hover labels.
+            hovertemplate=[
+                f"<br>{label}<br>N-Count: {size}"
+                for label, size in zip(labels, sizes)
+            ],
+            hoverinfo="none",
+        )
+    )
+    # Column names sit left of the dot grid; the offsets scale with the longest
+    # name, a heuristic that keeps the text inside the x range.
+    fig.add_trace(
+        graph_objs.Scatter(
+            x=[-0.02 * longest] * d,
+            y=row_y,
+            text=names,
+            mode="text",
+            textposition="middle left",
+            hoverinfo="skip",
+            showlegend=False,
+        )
+    )
+    fig.update_layout(
+        title="Intersections",
+        title_x=0.5,
+        showlegend=False,
+        plot_bgcolor="#FFFFFF",
+        margin=dict(t=40, l=150),
+        xaxis=dict(
+            showgrid=False,
+            zeroline=False,
+            showticklabels=False,
+            range=[-0.015 * longest ** 1.3, n_bars],
+        ),
+        yaxis=dict(
+            showgrid=True,
+            zeroline=False,
+            showticklabels=False,
+            range=[-max_y - 0.1 * max_y - 1, max_y + 0.1 * max_y],
+        ),
+    )
+    return fig
+
\ No newline at end of file
diff --git a/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py b/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py
index 8783dce1ab4..e062b4c0af1 100644
--- a/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py
+++ b/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py
@@ -4516,3 +4516,34 @@ def test_build_dataframe(self):
assert len(fig6.frames) == n_frames
assert len(fig7.frames) == n_frames
assert fig6.data[0].geojson == fig1.data[0].geojson
+
+class TestUpsetPlot:
+    """Checks on the bar trace of the figure built by ``ff.create_upset``."""
+
+    @staticmethod
+    def _make_df():
+        rng = np.random.RandomState(0)  # fixed seed: identical table on every run
+        df = pd.DataFrame()
+        for i, p in enumerate([0.3, 0.3, 0.8, 0.05, 0.1], start=1):
+            df["Attr %d" % i] = list(rng.binomial(1, p, size=1000))
+        return df
+
+    @staticmethod
+    def _bar_sizes(df, **kwargs):
+        # The intersection sizes are expected on the figure's only bar trace.
+        fig = ff.create_upset(df, **kwargs)
+        (bar,) = [trace for trace in fig.data if trace.type == "bar"]
+        return list(bar.y)
+
+    def test_defaults(self):
+        df = self._make_df()
+        sizes = self._bar_sizes(df)
+        assert len(sizes) == 2 ** 5 - 1  # one bar per non-empty combination
+        assert sum(sizes) == int(df.any(axis=1).sum())
+
+    def test_max_width(self):
+        assert len(self._bar_sizes(self._make_df(), max_width=10)) == 10
+
+    def test_empty_set_true(self):
+        sizes = self._bar_sizes(self._make_df(), include_empty_set=True, max_width=10)
+        assert len(sizes) == 10 and sizes == sorted(sizes, reverse=True)
\ No newline at end of file