diff --git a/CHANGELOG.md b/CHANGELOG.md
index 346dd66aa81..3c5106fdee1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,10 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- Added some rounding to the `make_subplots` function to handle situations where the user-input specs cause the domain to exceed 1 by small amounts [[#4153](https://github.com/plotly/plotly.py/pull/4153)]
- Sanitize JSON output to prevent an XSS vector when graphs are inserted directly into HTML [[#4196](https://github.com/plotly/plotly.py/pull/4196)]
+### Added
+
+ - Added implementation of [UpSet plots](https://en.wikipedia.org/wiki/UpSet_Plot) in `plotly.figure_factory` via the `create_upset` method [[#4204](https://github.com/plotly/plotly.py/pull/4204)]
+
## [5.14.1] - 2023-04-05
### Fixed
diff --git a/doc/python/upset-plots.md b/doc/python/upset-plots.md
new file mode 100644
index 00000000000..0b9b454a346
--- /dev/null
+++ b/doc/python/upset-plots.md
@@ -0,0 +1,142 @@
+---
+jupyter:
+ jupytext:
+ notebook_metadata_filter: all
+ text_representation:
+ extension: .md
+ format_name: markdown
+ format_version: '1.1'
+ jupytext_version: 1.2.3
+ kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+ language_info:
+ codemirror_mode:
+ name: ipython
+ version: 3
+ file_extension: .py
+ mimetype: text/x-python
+ name: python
+ nbconvert_exporter: python
+ pygments_lexer: ipython3
+ version: 3.7.3
+ plotly:
+ description: How to make an UpSet plot in Python, which can be used to display counts of
+arbitrarily complex set intersections.
+ display_as: scientific
+ language: python
+ layout: base
+ name: UpSet Plots
+ order: 10
+ permalink: python/upset-plots/
+---
+
+[UpSet plots](https://en.wikipedia.org/wiki/UpSet_Plot) allow you to visualize data that counts different intersections
+subsets inside a set. This could arise by actual intersections, or counting tag occurrences on data which need not be
+disjoint. Data used in this method must be in one of two forms: wide or condensed. If the latter is provided, then the
+data will be transformed into the wide format before proceeding with the plot generation.
+
+#### A Simple UpSet Plot
+```python
+import plotly.express as px
+import plotly.figure_factory as ff
+
+df = px.data.iris()
+
+# Create categorical non-disjoint tags for "large" features
+df['SL'] = df['sepal_length'].apply(lambda x: int(x > 6))
+df['SW'] = df['sepal_width'].apply(lambda x: int(x > 3))
+df['PL'] = df['petal_length'].apply(lambda x: int(x > 3))
+df['PW'] = df['petal_width'].apply(lambda x: int(x > 1))
+
+df = df[['SL', 'SW', 'PL', 'PW']] # data in "wide" form
+fig = ff.create_upset(df, color_discrete_sequence=['#000000'])
+fig.show()
+```
+
+
+#### Using Condensed Format
+
+Sometimes it's more convenient to have data where one column is given as a list of (possibly) overlapping tags that data
+point has. This can be thought of as dividing our dataset into a family of subsets, one for each tag. UpSet plots can help
+analyze how different combinations of these tags are distributed in the data.
+
+As long as the entries in this column are a list/tuple, this method can handle the preprocessing step of getting the
+data into the "wide" format like above. We simulate this below.
+
+```python
+import plotly.express as px
+import plotly.figure_factory as ff
+
+df = px.data.iris()
+
+# Create categorical non-disjoint tags for "large" features
+df['SL'] = df['sepal_length'].apply(lambda x: int(x > 6))
+df['SW'] = df['sepal_width'].apply(lambda x: int(x > 3))
+df['PL'] = df['petal_length'].apply(lambda x: int(x > 3))
+df['PW'] = df['petal_width'].apply(lambda x: int(x > 1))
+
+# Simulate "tags" column
+df['tags'] = df['sepal_length'].apply(lambda x: ['SL'] if x > 6 else ['']) + df['sepal_width'].apply(lambda x: ['SW'] if x > 3 else ['']) + \
+ df['petal_length'].apply(lambda x: ['PL'] if x > 3 else ['']) + df['petal_width'].apply(lambda x: ['PW'] if x > 1 else [''])
+df['tags'] = df['tags'].apply(lambda x: [y for y in x if y != ''])
+
+# Note we can (optionally) choose the order for how the method unpacks the tags
+fig = ff.create_upset(df, subset_column='tags', subset_order=['PW', 'SW', 'PL', 'SL'], color_discrete_sequence=['#000000'])
+fig.show()
+```
+
+#### Grouping Data by Color
+
+This method supports two ways of grouping data to visualize counts of subset intersections. The first, shown here,
+allows you to see how these counts vary by subset in parallel across categories described by another column.
+
+```python
+import plotly.express as px
+import plotly.figure_factory as ff
+
+df = px.data.iris()
+
+# Create categorical non-disjoint tags for "large" features
+df['SL'] = df['sepal_length'].apply(lambda x: int(x > 6))
+df['SW'] = df['sepal_width'].apply(lambda x: int(x > 3))
+df['PL'] = df['petal_length'].apply(lambda x: int(x > 3))
+df['PW'] = df['petal_width'].apply(lambda x: int(x > 1))
+
+df = df[['species', 'SL', 'SW', 'PL', 'PW']] # data in "wide" form, with extra "species" column
+# Note: ONLY the extra color column was kept, as rest of columns are inferred to make "wide" format subset data
+fig = ff.create_upset(df, color='species', asc=True) # Can toggle in "asc" order
+fig.show()
+```
+
+#### Visualizing Distributions of Counts by Subset
+
+The other way to group data is to provide a column which provides label for different clusters of observations. This
+could be e.g. the day of the observation, different samples in biology, or any other way of dividing up the same
+observations in different situations. This technique lets you see how the different subset counts vary across this
+dimension.
+
+```python
+import plotly.express as px
+import plotly.figure_factory as ff
+import numpy as np
+
+df = px.data.iris()
+
+# Create categorical non-disjoint tags for "large" features
+df['SL'] = df['sepal_length'].apply(lambda x: int(x > 6))
+df['SW'] = df['sepal_width'].apply(lambda x: int(x > 3))
+df['PL'] = df['petal_length'].apply(lambda x: int(x > 3))
+df['PW'] = df['petal_width'].apply(lambda x: int(x > 1))
+df = df[['SL', 'SW', 'PL', 'PW']]
+
+# Simulate random "day" of observation
+np.random.seed(100)
+df['day'] = np.random.randint(0, 5, len(df))
+
+fig = ff.create_upset(df, x='day', plot_type='box', show_yaxis=True, title='Variation of Tags by Day')
+fig.update_layout(yaxis_side="right")
+fig.show()
+```
+
diff --git a/packages/python/plotly/plotly/figure_factory/__init__.py b/packages/python/plotly/plotly/figure_factory/__init__.py
index 0a41dca1ba2..4b82aeda542 100644
--- a/packages/python/plotly/plotly/figure_factory/__init__.py
+++ b/packages/python/plotly/plotly/figure_factory/__init__.py
@@ -25,6 +25,7 @@
from plotly.figure_factory._streamline import create_streamline
from plotly.figure_factory._table import create_table
from plotly.figure_factory._trisurf import create_trisurf
+from plotly.figure_factory._upset import create_upset
from plotly.figure_factory._violin import create_violin
if optional_imports.get_module("pandas") is not None:
@@ -65,5 +66,6 @@ def create_ternary_contour(*args, **kwargs):
"create_table",
"create_ternary_contour",
"create_trisurf",
+ "create_upset",
"create_violin",
]
diff --git a/packages/python/plotly/plotly/figure_factory/_upset.py b/packages/python/plotly/plotly/figure_factory/_upset.py
new file mode 100644
index 00000000000..8f02659e90f
--- /dev/null
+++ b/packages/python/plotly/plotly/figure_factory/_upset.py
@@ -0,0 +1,576 @@
+from __future__ import absolute_import
+
+from plotly import optional_imports
+import plotly.graph_objects as go
+import plotly.express as px
+
+pd = optional_imports.get_module("pandas")
+np = optional_imports.get_module("numpy")
+
+VALID_PLOT_TYPES = ["bar", "box", "violin"]
+
+
+def create_upset(
+ data_frame,
+ x=None,
+ color=None,
+ title=None,
+ plot_type="bar",
+ sort_by="Counts",
+ asc=False,
+ mode="Counts",
+ max_subsets=20,
+ subset_column=None,
+ subset_order=None,
+ subset_bgcolor="#C9C9C9",
+ subset_fgcolor="#000000",
+ category_orders=None,
+ color_discrete_sequence=None,
+ color_discrete_map=None,
+ log_y=False,
+ show_yaxis=False,
+ barmode="group",
+ textangle=0,
+ boxmode="group",
+ points="outliers",
+ notched=False,
+ violinmode="group",
+ box=False,
+):
+ """
+ Creates an UpSet plot, a scalable alternative to Venn diagrams. The interface supports a flexible range of use cases
+ input data formats.
+
+ :param (pandas.DataFrame) data_frame: a DataFrame either in wide format with subset/intersection inclusion data, or
+ with a column in condensed format; see the tutorial for more details
+ :param (str) x: (optional) column name in data_frame for data point labels, e.g. sample name to cluster intersection
+ observations by
+ :param (str) color: (optional) column name in data_frame for grouping intersection counts, similar to plotly.express
+ inputs
+ :param (str) title: (optional) title for plot
+ :param (str) plot_type: (default="bar") type of plot to visualize intersection count data; must be one of "bar", "box", or "violin";
+ the latter two should only be used if x is provided, in which case they represent the distribution of intersection
+ counts (across color groups)
+ :param (str) sort_by: (default="Counts") order in which counts are displayed; must be one of "Counts" or "Intersections";
+ ignored if color is provided
+ :param (bool) asc: (default=False) sort in ascending order
+ :param (str) mode: (default="Counts") how to represent counts; must be one of "Counts" or "Percent"
+ :param (int) max_subsets: (default=20) maximum number of intersection subsets to display
+ :param (str) subset_column: (optional) if data is formatted in condensed form, input column name here with that data;
+ do not use if data is already formatted in wide format
+ :param (list) subset_order: (optional) if subset_column is provided, use this list of entries to specify order of labels
+ :param (str) subset_bgcolor: (default="#C9C9C9") color for background dots on switchboard
+ :param (str) subset_fgcolor: (default="#000000") color for foreground dots on switchboard
+ :param (dict) category_orders: (optional) specify order for groups in color, as in plotly.express inputs
+ :param (list) color_discrete_sequence: (optional) list of colors to use for color input, as in plotly.express inputs
+ :param (dict) color_discrete_map: (optional) map of color categories to color, as in plotly.express inputs
+ :param (bool) log_y: (default=False) use logarithmic y scale
+ :param (bool) show_yaxis: (default=False) show y-axis tickmarks
+ :param (str) barmode: (default="group") argument passed to plotly.express.bar when selected for plotting
+ :param (int) textangle: (default=0) angle to use when displaying counts above bars in a bar chart
+ :param (str) boxmode: (default="group") argument passed to plotly.express.box when selected for plotting
+ :param (str) points: (default="outliers") argument passed to plotly.express.box when selected for plotting
+ :param (bool) notched: (default=False) argument passed to plotly.express.box when selected for plotting
+ :param (str) violinmode: (default="group") argument passed to plotly.express.violin when selected for plotting
+ :param (bool) box: (default=False) argument passed to plotly.express.violin when selected for plotting
+
+ :rtype (plotly.graph_objects.Figure): returns UpSet plot rendered according to input settings.
+
+ Example 1: Simple Counts
+
+ >>> import plotly.express as px
+ >>> import plotly.figure_factory as ff
+
+ >>> df = px.data.iris()
+ >>> # Create 4 subsets defined by qualitative "large" conditions
+ >>> df['SL'] = df['sepal_length'].apply(lambda x: int(x > 6))
+ >>> df['SW'] = df['sepal_width'].apply(lambda x: int(x > 3))
+ >>> df['PL'] = df['petal_length'].apply(lambda x: int(x > 3))
+ >>> df['PW'] = df['petal_width'].apply(lambda x: int(x > 1))
+
+ >>> df = df[['species', 'SL', 'SW', 'PL', 'PW']]
+ >>> # Only use columns with inclusion in subset (0/1) values for this example
+ >>> fig = ff.create_upset(df.drop(columns=['species']), color_discrete_sequence=['#000000'])
+ >>> fig.show()
+
+ Example 2: Counting by Group
+
+ >>> # Continued from Example 1
+ >>> fig = ff.create_upset(df, color='species', asc=True)
+ >>> fig.show()
+
+ Example 3: Tracking Variance of Counts Across a Category
+
+ >>> # Continued from Example 1
+ >>> import numpy as np
+
+ >>> np.random.seed(100)
+ >>> # Add a dummy variable for "day entry was observed" to track variation of subset counts across the days
+ >>> df['day'] = np.random.randint(0, 5, len(df))
+ >>> fig = ff.create_upset(df.drop(columns=['species']), x='day', plot_type='box', show_yaxis=True)
+ >>> fig.update_layout(yaxis_side="right")
+ >>> fig.show()
+ """
+ plot_obj = _Upset(**locals())
+ upset_plot = plot_obj.make_upset_plot()
+ return upset_plot
+
+
+def _expand_subset_column(df, subset_column, subset_order=None):
+ """
+ Takes a column of iterables and expands into binary columns representing inclusion. Also returns subset_names.
+ """
+ subset_names = (
+ subset_order
+ if subset_order is not None
+ else [
+ x for x in df[subset_column].explode().unique() if not pd.isnull(x)
+ ] # Remove empty subset = NaN
+ )
+ new_df = df.copy()
+ for name in subset_names:
+ new_df[name] = new_df[subset_column].apply(lambda x: int(name in x))
+ new_df = new_df[subset_names]
+ return new_df, subset_names
+
+
+def _transform_upset_data(df):
+ """
+ Takes raw data of binary vectors for set inclusion and produces counts over each.
+ """
+ intersect_counts = pd.DataFrame(
+ {
+ "Intersections": list(df.value_counts().to_dict().keys()),
+ "Counts": list(df.value_counts().to_dict().values()),
+ }
+ )
+ return intersect_counts
+
+
+def _make_binary(t):
+ """
+ Converts tuple of 0,1s to binary number. Used in _transform_upset_data for sort order.
+ """
+ return sum([t[i] * 2**i for i in range(len(t))])
+
+
+def _sort_intersect_counts(df, sort_by="Counts", asc=True):
+ """
+ Takes output from _transform_upset_data and sorts by method requested.
+ """
+ key = (
+ None
+ if (sort_by == "Counts")
+ else lambda x: x.apply(lambda y: (sum(y), _make_binary(y)))
+ )
+ df = df.sort_values(by=sort_by, key=key, ascending=asc)
+ return df
+
+
+class _Upset:
+ """
+ Represents builder object for UpSet plot. Refer to figure_factory.create_upset() for full docstring.
+ """
+
+ def __init__(
+ self,
+ data_frame,
+ x=None,
+ color=None,
+ title=None,
+ plot_type="bar",
+ sort_by="Counts",
+ asc=False,
+ mode="Counts",
+ max_subsets=20,
+ subset_column=None,
+ subset_order=None,
+ subset_bgcolor="#C9C9C9",
+ subset_fgcolor="#000000",
+ category_orders=None,
+ color_discrete_sequence=None,
+ color_discrete_map=None,
+ log_y=False,
+ show_yaxis=False,
+ barmode="group",
+ textangle=0,
+ boxmode="group",
+ points="outliers",
+ notched=False,
+ violinmode="group",
+ box=False,
+ ):
+
+ # Plot inputs and settings
+ self.df = data_frame
+ self.x = x
+ self.color = color
+ self.title = title
+ self.plot_type = plot_type
+ self.sort_by = sort_by
+ self.asc = asc
+ self.mode = mode
+ self.max_subsets = max_subsets
+ self.subset_column = subset_column
+ self.subset_order = subset_order
+ self.subset_bgcolor = subset_bgcolor
+ self.subset_fgcolor = subset_fgcolor
+ self.category_orders = category_orders
+ self.color_discrete_sequence = color_discrete_sequence
+ self.color_discrete_map = color_discrete_map
+ self.log_y = log_y
+ self.show_yaxis = show_yaxis
+ self.barmode = barmode
+ self.textangle = textangle
+ self.boxmode = boxmode
+ self.points = points
+ self.notched = notched
+ self.violinmode = violinmode
+ self.box = box
+
+ # Aggregate common plotting args
+ self.common_plot_args = {
+ "color": self.color,
+ "category_orders": self.category_orders,
+ "color_discrete_sequence": self.color_discrete_sequence,
+ "color_discrete_map": self.color_discrete_map,
+ "log_y": self.log_y,
+ }
+
+ # Collect plot specific args
+ self.bar_args = {
+ "barmode": self.barmode,
+ }
+
+ self.box_args = {
+ "boxmode": self.boxmode,
+ "points": self.points,
+ "notched": self.notched,
+ }
+
+ self.violin_args = {
+ "violinmode": self.violinmode,
+ "box": self.box,
+ "points": self.points,
+ }
+
+ # Figure-building specific attributes
+ self.fig = go.Figure()
+ self.intersect_counts = pd.DataFrame()
+ self.subset_names = None
+ self.switchboard_heights = []
+
+ # Validate inputs
+ self.validate_upset_inputs()
+
+ def make_upset_plot(self):
+ # If subset_column provided, expand into standard wider format
+ if self.subset_column is not None:
+ color_column = self.df[self.color] if self.color is not None else None
+ x_column = self.df[self.x] if self.x is not None else None
+ self.df, self.subset_names = _expand_subset_column(
+ self.df, self.subset_column, self.subset_order
+ )
+ if self.color is not None:
+ self.df = pd.concat([self.df, color_column], axis=1)
+ if self.x is not None:
+ self.df = pd.concat([self.df, x_column], axis=1)
+ else:
+ self.subset_names = [
+ c for c in self.df.columns if c != self.x and c != self.color
+ ]
+
+ # Create intersect_counts df depending on if color provided
+ groups = [x for x in [self.color, self.x] if x is not None]
+ if len(groups) > 0:
+ intersect_df = self.df.groupby(groups).apply(
+ lambda df: _transform_upset_data(df.drop(columns=groups)).reset_index(
+ drop=True
+ )
+ )
+
+ # Fill in counts for subsets where count is zero for certain color groups
+ filled_df = (
+ intersect_df.pivot_table(
+ index="Intersections",
+ columns=groups,
+ values="Counts",
+ fill_value=0,
+ )
+ .unstack()
+ .reset_index()
+ .rename(columns={0: "Counts"})
+ )
+
+ # Perform sorting within each color group
+ # WARNING: If sort_by="Counts" it will be ignored here since this won't make sense when using groups
+ # TODO: Make sensible behavior for sort by "Counts" in this case
+ self.intersect_counts = (
+ filled_df.groupby(groups)
+ .apply(
+ lambda df: _sort_intersect_counts(
+ df.drop(columns=groups),
+ sort_by="Intersections",
+ asc=self.asc,
+ ).reset_index()
+ )
+ .reset_index()
+ .drop(columns=["index"])
+ .rename(
+ columns={"level_1": "index", "level_2": "index"}
+ ) # Not sure how to tell the two apart...
+ )
+
+ # Truncate subsets if necessary
+ self.intersect_counts = self.intersect_counts.groupby(groups).head(
+ self.max_subsets
+ )
+
+ else:
+ self.intersect_counts = _transform_upset_data(self.df)
+ self.intersect_counts = _sort_intersect_counts(
+ self.intersect_counts, sort_by=self.sort_by, asc=self.asc
+ )
+ self.intersect_counts = self.intersect_counts.reset_index(
+ drop=True
+ ).reset_index()
+
+ self.intersect_counts = self.intersect_counts.head(self.max_subsets)
+
+ # Rescale for percents if requested
+ mode = self.mode
+ if mode == "Percent":
+ if self.color is not None:
+ denom = self.intersect_counts.groupby(self.color).sum().reset_index()
+ denom_dict = dict(zip(denom[self.color], denom["Counts"]))
+ self.intersect_counts["Counts"] = round(
+ self.intersect_counts["Counts"]
+ / self.intersect_counts[self.color].map(denom_dict),
+ 2,
+ )
+ else:
+ self.intersect_counts["Counts"] = round(
+ self.intersect_counts["Counts"]
+ / self.intersect_counts["Counts"].sum(),
+ 2,
+ )
+
+ # Create 3 main components for figure
+ self.make_primary_plot()
+ self.make_switchboard()
+ self.make_margin_plot()
+
+ # Add title
+ self.fig.update_layout(title=self.title, title_x=0.5)
+
+ return self.fig
+
+ def validate_upset_inputs(self):
+ # Check sorting inputs are valid
+ sort_by = self.sort_by
+ try:
+ assert (sort_by == "Counts") or (sort_by == "Intersections")
+ except AssertionError:
+ raise ValueError(
+ f'Invalid input for "sort_by". Must be either "Counts" or "Intersections" but you provided {sort_by}'
+ )
+
+ # Check mode is either Counts or Percent
+ mode = self.mode
+ try:
+ assert (mode == "Counts") or (mode == "Percent")
+ except AssertionError:
+ raise ValueError(
+ f'Invalid input for "mode". Must be either "Counts" or "Percent" but you provided {mode}'
+ )
+
+ # Check plot_type is valid
+ plot_type = self.plot_type
+ try:
+ assert plot_type in VALID_PLOT_TYPES
+ except AssertionError:
+ raise ValueError(
+ f'Invalid input for "plot_type". Must be one of "bar", "box", or "violin" but you provided {plot_type}'
+ )
+
+ def make_primary_plot(self):
+ plot_function = None
+ args = {}
+ update_traces = {}
+
+ if self.plot_type == "bar":
+ plot_function = px.bar
+ args = {**self.common_plot_args, **self.bar_args, "text": "Counts"}
+ update_traces = {
+ "textposition": "outside",
+ "cliponaxis": False,
+ "textangle": self.textangle,
+ }
+ elif self.plot_type == "box":
+ plot_function = px.box
+ args = {**self.common_plot_args, **self.box_args}
+ elif self.plot_type == "violin":
+ plot_function = px.violin
+ args = {**self.common_plot_args, **self.violin_args}
+
+ self.fig = plot_function(self.intersect_counts, x="index", y="Counts", **args)
+ self.fig.update_traces(**update_traces)
+
+ self.fig.update_layout(
+ plot_bgcolor="#FFFFFF",
+ xaxis_visible=False,
+ xaxis_showticklabels=False,
+ yaxis_visible=self.show_yaxis,
+ yaxis_showticklabels=self.show_yaxis,
+ )
+
+ def make_switchboard(self):
+ """
+ Method to add subset points to input fig px.bar chart in the style of UpSet plot.
+ Returns updated figure, and list of heights of dots for downstream convenience.
+ """
+ # Pull out full list of possible intersection combinations
+ intersections = list(self.intersect_counts["Intersections"].unique())
+
+ # Compute coordinates for bg subset scatter points
+ d = len(self.subset_names)
+ num_bars = len(intersections)
+ x_bg_scatter = np.repeat(range(num_bars), d)
+ y_scatter_offset = (
+ 0.2 # Offsetting ensures bars will hover just above the subset scatterplot
+ )
+ y_max = (1 + y_scatter_offset) * max([max(bar["y"]) for bar in self.fig.data])
+ self.switchboard_heights = [
+ -y_max / d * i - y_scatter_offset * y_max for i in list(range(d))
+ ]
+ y_bg_scatter = num_bars * self.switchboard_heights
+
+ # Add bg subset scatter points to figure below bar chart
+ labels = np.repeat(
+ [
+ "+".join([x for x, y in zip(self.subset_names, s) if y != 0])
+ for s in intersections
+ ],
+ d,
+ )
+ labels = ["None" if x == "" else x for x in labels]
+ self.fig.add_trace(
+ go.Scatter(
+ x=x_bg_scatter,
+ y=y_bg_scatter,
+ mode="markers",
+ showlegend=False,
+ marker=dict(size=16, color=self.subset_bgcolor, showscale=False),
+ text=labels,
+ hovertemplate="%{text}",
+ )
+ )
+ self.fig.update_layout(
+ xaxis=dict(showgrid=False, zeroline=False),
+ yaxis=dict(showgrid=True, zeroline=False),
+ margin=dict(t=40, l=40),
+ )
+
+ # Then fill in subset markers with fg color
+ x = 0
+ for s in intersections:
+ x_subsets = []
+ y_subsets = []
+ y = 0
+ for e in s:
+ if e:
+ x_subsets += [x]
+ y_subsets += [-y_max / d * y - y_scatter_offset * y_max]
+ y += 1
+ x += 1
+ self.fig.add_trace(
+ go.Scatter(
+ x=x_subsets,
+ y=y_subsets,
+ mode="markers+lines",
+ showlegend=False,
+ marker=dict(size=16, color=self.subset_fgcolor, showscale=False),
+ text=["+".join([x for x, y in zip(self.subset_names, s) if y != 0])]
+ * sum(s),
+ hovertemplate="%{text}",
+ )
+ )
+
+ def make_margin_plot(self):
+ """
+ Method to add left margin count px.bar chart in style of UpSet plot.
+ """
+ # Group and count according to inputs
+ color = self.color
+ groups = [x for x in [self.color, self.x] if x is not None]
+ # if len(groups) > 0:
+ # counts_df = self.df.groupby(groups).sum().reset_index()
+ if self.color is not None:
+ counts_df = self.df.groupby(self.color).sum().reset_index()
+ if self.x is not None:
+ counts_df = counts_df.drop(columns=[self.x])
+ else:
+ counts_df = (
+ self.df.sum()
+ .reset_index()
+ .rename(columns={"index": "variable", 0: "value"})
+ )
+
+ # Create counts px.bar chart
+ plot_df = (
+ counts_df.melt(id_vars=[self.color])
+ if self.color is not None
+ else counts_df
+ )
+ if self.mode == "Percent":
+ if color is not None:
+ denom = (
+ self.df.groupby(color)
+ .apply(lambda df: len(df))
+ .reset_index()
+ .rename(columns={0: "value"})
+ )
+ denom_dict = dict(zip(denom[color], denom["value"]))
+ plot_df["value"] = round(
+ plot_df["value"] / plot_df[color].map(denom_dict), 2
+ )
+ else:
+ plot_df["value"] = round(plot_df["value"] / len(self.df), 2)
+
+ plot_function = px.bar
+ update_traces = {"textposition": "outside", "cliponaxis": False}
+ args = {
+ **self.common_plot_args,
+ **self.bar_args,
+ "text": "value",
+ "hover_data": {"variable": False},
+ }
+
+ counts_fig = plot_function(
+ plot_df, x="value", y="variable", orientation="h", **args
+ )
+ counts_fig.update_traces(**update_traces)
+
+ # Add subset names as text into plot
+ max_name_len = max([len(s) for s in self.subset_names])
+ annotation_center = -1 + -0.01 * max_name_len
+ for i, s in enumerate(self.subset_names):
+ self.fig.add_annotation(
+ x=annotation_center,
+ y=self.switchboard_heights[i],
+ text=s,
+ showarrow=False,
+ font=dict(size=12, color="#000000"),
+ align="left",
+ )
+
+ # Reflect horizontally the bars while preserving labels; Shift heights to match input subset scatter heights
+ max_x = max([max(t["x"]) for t in counts_fig.data])
+ for trace in counts_fig.data:
+ trace["x"] = -trace["x"] / max_x
+ trace["y"] = self.switchboard_heights
+ counts_fig.update_traces(base=annotation_center - 1, showlegend=False)
+
+ # Add counts chart traces to main fig
+ for trace in counts_fig.data:
+ self.fig.add_trace(trace)
diff --git a/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py b/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py
index 8783dce1ab4..3a0ca92cd84 100644
--- a/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py
+++ b/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py
@@ -4516,3 +4516,58 @@ def test_build_dataframe(self):
assert len(fig6.frames) == n_frames
assert len(fig7.frames) == n_frames
assert fig6.data[0].geojson == fig1.data[0].geojson
+
+
+class TestUpset(NumpyTestUtilsMixin, TestCaseNoTemplate):
+ # Test compatibilities between using wide format input data vs condensed
+ def test_wide_vs_condensed(self):
+ np.random.seed(0)
+
+ a = np.random.randint(0, 2, 1000)
+ b = np.random.randint(0, 2, 1000)
+ c = np.random.randint(0, 2, 1000)
+ color = np.random.randint(0, 3, 1000).astype(str)
+
+ df = pd.DataFrame({"a": a, "b": b, "c": c, "color": color})
+ fig1 = ff.create_upset(df.drop(columns=["color"]))
+ fig2 = ff.create_upset(
+ df.drop(columns=["color"]), sort_by="Intersections", asc=False
+ )
+ fig3 = ff.create_upset(df, color="color")
+
+ for tag in ["a", "b", "c"]:
+ df[tag] = df[tag].map({1: [tag], 0: [""]})
+
+ df["tags"] = df["a"] + df["b"] + df["c"]
+ df["tags"] = df["tags"].apply(lambda x: [y for y in x if y != ""])
+
+ fig4 = ff.create_upset(
+ df.drop(columns=["a", "b", "c", "color"]),
+ subset_column="tags",
+ subset_order=["a", "b", "c"],
+ )
+ fig5 = ff.create_upset(
+ df.drop(columns=["a", "b", "c"]),
+ subset_column="tags",
+ subset_order=["a", "b", "c"],
+ sort_by="Intersections",
+ asc=False,
+ )
+ fig6 = ff.create_upset(
+ df.drop(columns=["a", "b", "c"]),
+ subset_column="tags",
+ subset_order=["a", "b", "c"],
+ color="color",
+ )
+
+ for data in zip(fig1.data, fig4.data):
+ self.assert_fig_equal(data[0], data[1])
+ self.assert_fig_equal(fig1.layout, fig4.layout)
+
+ for data in zip(fig2.data, fig5.data):
+ self.assert_fig_equal(data[0], data[1])
+ self.assert_fig_equal(fig2.layout, fig5.layout)
+
+ for data in zip(fig3.data, fig6.data):
+ self.assert_fig_equal(data[0], data[1])
+ self.assert_fig_equal(fig3.layout, fig6.layout)