Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable faceting for geo, geojson everywhere possible, text/symbols fo… #2923

Merged
merged 5 commits into from
Nov 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).
## [4.13.0] - UNRELEASED

### Added

- `px.choropleth`, `px.scatter_geo` and `px.line_geo` now support faceting as well as `fitbounds` and `basemap_visible` [2923](https://github.com/plotly/plotly.py/pull/2923)
- `px.scatter_geo` and `px.line_geo` now support `geojson`/`featureidkey` input [2923](https://github.com/plotly/plotly.py/pull/2923)
- `px.scatter_geo` now supports `symbol` [2923](https://github.com/plotly/plotly.py/pull/2923)
- `go.Figure` now has a `set_subplots` method to set subplots on an already
existing figure. [2866](https://github.com/plotly/plotly.py/pull/2866)
- Added `Turbo` colorscale and fancier swatch display functions
Expand All @@ -37,6 +39,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).

### Fixed

- `px.scatter_geo` support for `text` is fixed [2923](https://github.com/plotly/plotly.py/pull/2923)
- the `x` and `y` parameters of `px.imshow` are now used also in the case where
an Image trace is used (for RGB data or with `binary_string=True`). However,
only numerical values are accepted (while the Heatmap trace allows date or
Expand Down
23 changes: 22 additions & 1 deletion doc/python/facet-plots.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,26 @@ fig = px.histogram(df, x="total_bill", y="tip", color="sex", facet_row="time", f
fig.show()
```

### Choropleth Column Facets

*new in version 4.13*

```python
import plotly.express as px

df = px.data.election()
df = df.melt(id_vars="district", value_vars=["Coderre", "Bergeron", "Joly"],
var_name="candidate", value_name="votes")
geojson = px.data.election_geojson()

fig = px.choropleth(df, geojson=geojson, color="votes", facet_col="candidate",
locations="district", featureidkey="properties.district",
projection="mercator"
)
fig.update_geos(fitbounds="locations", visible=False)
fig.show()
```

### Adding Lines and Rectangles to Facet Plots

*introduced in plotly 4.12*
Expand Down Expand Up @@ -133,7 +153,8 @@ trace.update(legendgroup="trendline", showlegend=False)
fig.add_trace(trace, row="all", col="all", exclude_empty_subplots=True)

# set only the last trace added to appear in the legend
fig.data[-1].update(showlegend=True)
# `selector=-1` introduced in plotly v4.13
fig.update_traces(selector=-1, showlegend=True)
fig.show()
```

Expand Down
47 changes: 30 additions & 17 deletions packages/python/plotly/plotly/express/_chart_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,11 @@ def choropleth(
geojson=None,
featureidkey=None,
color=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
facet_row_spacing=None,
facet_col_spacing=None,
hover_name=None,
hover_data=None,
custom_data=None,
Expand All @@ -955,6 +960,8 @@ def choropleth(
projection=None,
scope=None,
center=None,
fitbounds=None,
basemap_visible=None,
title=None,
template=None,
width=None,
Expand All @@ -967,13 +974,7 @@ def choropleth(
return make_figure(
args=locals(),
constructor=go.Choropleth,
trace_patch=dict(
locationmode=locationmode,
featureidkey=featureidkey,
geojson=geojson
if not hasattr(geojson, "__geo_interface__") # for geopandas
else geojson.__geo_interface__,
),
trace_patch=dict(locationmode=locationmode),
)


Expand All @@ -986,8 +987,16 @@ def scatter_geo(
lon=None,
locations=None,
locationmode=None,
geojson=None,
featureidkey=None,
color=None,
text=None,
symbol=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
facet_row_spacing=None,
facet_col_spacing=None,
hover_name=None,
hover_data=None,
custom_data=None,
Expand All @@ -1001,11 +1010,15 @@ def scatter_geo(
color_continuous_scale=None,
range_color=None,
color_continuous_midpoint=None,
symbol_sequence=None,
symbol_map={},
opacity=None,
size_max=None,
projection=None,
scope=None,
center=None,
fitbounds=None,
basemap_visible=None,
title=None,
template=None,
width=None,
Expand All @@ -1031,9 +1044,16 @@ def line_geo(
lon=None,
locations=None,
locationmode=None,
geojson=None,
featureidkey=None,
color=None,
line_dash=None,
text=None,
facet_row=None,
facet_col=None,
facet_col_wrap=0,
facet_row_spacing=None,
facet_col_spacing=None,
hover_name=None,
hover_data=None,
custom_data=None,
Expand All @@ -1049,6 +1069,8 @@ def line_geo(
projection=None,
scope=None,
center=None,
fitbounds=None,
basemap_visible=None,
title=None,
template=None,
width=None,
Expand Down Expand Up @@ -1138,16 +1160,7 @@ def choropleth_mapbox(
In a Mapbox choropleth map, each row of `data_frame` is represented by a
colored region on a Mapbox map.
"""
return make_figure(
args=locals(),
constructor=go.Choroplethmapbox,
trace_patch=dict(
featureidkey=featureidkey,
geojson=geojson
if not hasattr(geojson, "__geo_interface__") # for geopandas
else geojson.__geo_interface__,
),
)
return make_figure(args=locals(), constructor=go.Choroplethmapbox)


choropleth_mapbox.__doc__ = make_docstring(choropleth_mapbox)
Expand Down
98 changes: 45 additions & 53 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,33 +616,27 @@ def configure_cartesian_axes(args, fig, orders):
if "is_timeline" in args:
fig.update_xaxes(type="date")

return fig.layout


def configure_ternary_axes(args, fig, orders):
fig.update_layout(
ternary=dict(
aaxis=dict(title_text=get_label(args, args["a"])),
baxis=dict(title_text=get_label(args, args["b"])),
caxis=dict(title_text=get_label(args, args["c"])),
)
fig.update_ternaries(
aaxis=dict(title_text=get_label(args, args["a"])),
baxis=dict(title_text=get_label(args, args["b"])),
caxis=dict(title_text=get_label(args, args["c"])),
)


def configure_polar_axes(args, fig, orders):
layout = dict(
polar=dict(
angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]),
radialaxis=dict(),
)
patch = dict(
angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]),
radialaxis=dict(),
)

for var, axis in [("r", "radialaxis"), ("theta", "angularaxis")]:
if args[var] in orders:
layout["polar"][axis]["categoryorder"] = "array"
layout["polar"][axis]["categoryarray"] = orders[args[var]]
patch[axis]["categoryorder"] = "array"
patch[axis]["categoryarray"] = orders[args[var]]

radialaxis = layout["polar"]["radialaxis"]
radialaxis = patch["radialaxis"]
if args["log_r"]:
radialaxis["type"] = "log"
if args["range_r"]:
Expand All @@ -652,21 +646,19 @@ def configure_polar_axes(args, fig, orders):
radialaxis["range"] = args["range_r"]

if args["range_theta"]:
layout["polar"]["sector"] = args["range_theta"]
fig.update(layout=layout)
patch["sector"] = args["range_theta"]
fig.update_polars(patch)


def configure_3d_axes(args, fig, orders):
layout = dict(
scene=dict(
xaxis=dict(title_text=get_label(args, args["x"])),
yaxis=dict(title_text=get_label(args, args["y"])),
zaxis=dict(title_text=get_label(args, args["z"])),
)
patch = dict(
xaxis=dict(title_text=get_label(args, args["x"])),
yaxis=dict(title_text=get_label(args, args["y"])),
zaxis=dict(title_text=get_label(args, args["z"])),
)

for letter in ["x", "y", "z"]:
axis = layout["scene"][letter + "axis"]
axis = patch[letter + "axis"]
if args["log_" + letter]:
axis["type"] = "log"
if args["range_" + letter]:
Expand All @@ -677,7 +669,7 @@ def configure_3d_axes(args, fig, orders):
if args[letter] in orders:
axis["categoryorder"] = "array"
axis["categoryarray"] = orders[args[letter]]
fig.update(layout=layout)
fig.update_scenes(patch)


def configure_mapbox(args, fig, orders):
Expand All @@ -687,23 +679,21 @@ def configure_mapbox(args, fig, orders):
lat=args["data_frame"][args["lat"]].mean(),
lon=args["data_frame"][args["lon"]].mean(),
)
fig.update_layout(
mapbox=dict(
accesstoken=MAPBOX_TOKEN,
center=center,
zoom=args["zoom"],
style=args["mapbox_style"],
)
fig.update_mapboxes(
accesstoken=MAPBOX_TOKEN,
center=center,
zoom=args["zoom"],
style=args["mapbox_style"],
)


def configure_geo(args, fig, orders):
fig.update_layout(
geo=dict(
center=args["center"],
scope=args["scope"],
projection=dict(type=args["projection"]),
)
fig.update_geos(
center=args["center"],
scope=args["scope"],
fitbounds=args["fitbounds"],
visible=args["basemap_visible"],
projection=dict(type=args["projection"]),
)


Expand Down Expand Up @@ -1750,6 +1740,14 @@ def infer_config(args, constructor, trace_patch, layout_patch):
if "line_shape" in args:
trace_patch["line"] = dict(shape=args["line_shape"])

if "geojson" in args:
trace_patch["featureidkey"] = args["featureidkey"]
trace_patch["geojson"] = (
args["geojson"]
if not hasattr(args["geojson"], "__geo_interface__") # for geopandas
else args["geojson"].__geo_interface__
)

# Compute marginal attribute
if "marginal" in args:
position = "marginal_x" if args["orientation"] == "v" else "marginal_y"
Expand Down Expand Up @@ -2062,20 +2060,12 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):

def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels):
# Build subplot specs
specs = [[{}] * ncols for _ in range(nrows)]
for frame in frame_list:
for trace in frame["data"]:
row0 = trace._subplot_row - 1
col0 = trace._subplot_col - 1
if isinstance(trace, go.Splom):
# Splom not compatible with make_subplots, treat as domain
specs[row0][col0] = {"type": "domain"}
else:
specs[row0][col0] = {"type": trace.type}
specs = [[dict(type=subplot_type or "domain")] * ncols for _ in range(nrows)]

# Default row/column widths uniform
column_widths = [1.0] * ncols
row_heights = [1.0] * nrows
facet_col_wrap = args.get("facet_col_wrap", 0)

# Build column_widths/row_heights
if subplot_type == "xy":
Expand All @@ -2087,7 +2077,7 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la

row_heights = [main_size] * (nrows - 1) + [1 - main_size]
vertical_spacing = 0.01
elif args.get("facet_col_wrap", 0):
elif facet_col_wrap:
vertical_spacing = args.get("facet_row_spacing", None) or 0.07
else:
vertical_spacing = args.get("facet_row_spacing", None) or 0.03
Expand All @@ -2108,10 +2098,12 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
#
# We can customize subplot spacing per type once we enable faceting
# for all plot types
vertical_spacing = 0.1
horizontal_spacing = 0.1
if facet_col_wrap:
vertical_spacing = args.get("facet_row_spacing", None) or 0.07
else:
vertical_spacing = args.get("facet_row_spacing", None) or 0.03
horizontal_spacing = args.get("facet_col_spacing", None) or 0.02

facet_col_wrap = args.get("facet_col_wrap", 0)
if facet_col_wrap:
subplot_labels = [None] * nrows * ncols
while len(col_labels) < nrows * ncols:
Expand Down
4 changes: 3 additions & 1 deletion packages/python/plotly/plotly/express/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,11 @@
"If `True`, an extra line segment is drawn between the first and last point.",
],
line_shape=["str (default `'linear'`)", "One of `'linear'` or `'spline'`."],
fitbounds=["str (default `False`).", "One of `False`, `locations` or `geojson`."],
basemap_visible=["bool", "Force the basemap visibility."],
scope=[
"str (default `'world'`).",
"One of `'world'`, `'usa'`, `'europe'`, `'asia'`, `'africa'`, `'north america'`, or `'south america'`)"
"One of `'world'`, `'usa'`, `'europe'`, `'asia'`, `'africa'`, `'north america'`, or `'south america'`"
"Default is `'world'` unless `projection` is set to `'albers usa'`, which forces `'usa'`.",
],
projection=[
Expand Down