Skip to content

Commit

Permalink
Merge pull request #50 from mmore500/main
Browse files Browse the repository at this point in the history
Create context managers for temporary plot9/seaborn patching
  • Loading branch information
ponnhide committed Dec 7, 2023
2 parents 4428f24 + 010008a commit 76c7f88
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 25 deletions.
12 changes: 12 additions & 0 deletions API.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@

`None`

- ### **`patched_axisgrid()`**

[Context manager](https://docs.python.org/3/reference/compound_stmts.html#with)/[decorator](https://docs.python.org/3/glossary.html#term-decorator)
interface for `overwrite_axisgrid` patching that reverts changes when leaving
`with`/function scope.

#### Returns

Context manager (i.e., `with patched_axisgrid():`) or decorator (i.e.,
`@patched_axisgrid()`) that temporarily patches seaborn for patchworklib
compatibility.

- ### **`load_seabornobj(g, label=None, labels=None, figsize=(3, 3))`**

Load a seaborn plot generated based on `seaborn._core.plot.Plotter` class.
Expand Down
140 changes: 115 additions & 25 deletions patchworklib/patchworklib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
import matplotlib.axes as axes
import plotnine

from contextlib import suppress
from contextlib import contextmanager, suppress
from types import SimpleNamespace as NS
from unittest.mock import patch
from matplotlib.offsetbox import AnchoredOffsetbox
from matplotlib.transforms import Bbox, TransformedBbox, Affine2D

Expand Down Expand Up @@ -210,11 +211,58 @@ def _reset_ggplot_legend(bricks):
else:
pass

def overwrite_plotnine():
def _needs_plotnine_ggplot_draw_patch():
"""Implementation detail for patched_plotnine, for internal use."""
import plotnine
plotnine.ggplot.draw = mp9.draw
plotnine_version = plotnine.__version__

return StrictVersion(plotnine_version) >= StrictVersion("0.12")

@contextmanager
def patched_plotnine():
"""
Temporarily patch plot9 for patchworklib compatibility. Can be used as a
context manager or a decorator.
Examples
-------
>>> with patched_plotnine():
... pw.load_ggplot(
... p9.ggplot(data, p9.aes(x="x", y="y", fill="fill")),
... )
Example use as a context manager.
>>> @patched_plotnine()
>>> def custom_plot():
... pw.load_ggplot(
... p9.ggplot(data, p9.aes(x="x", y="y", fill="fill")),
... )
>>> custom_plot()
Example use as a decorator.
"""
if _needs_plotnine_ggplot_draw_patch():
with patch("plotnine.ggplot.ggplot.draw", mp9.draw):
yield
else:
yield

def overwrite_plotnine():
"""
Modify plot9 for patchworklib compatibility.
See Also
--------
patched_plotnine : Context manager that applies then reverses plotnine
patches.
"""
patched_plotnine().__enter__()

def load_ggplot(ggplot=None, figsize=None):
@patched_plotnine()
def load_ggplot(ggplot=None, figsize=None):
"""
Convert a plotnine plot object to a patchworklib.Bricks object.
Expand All @@ -231,6 +279,8 @@ def load_ggplot(ggplot=None, figsize=None):
patchworklib.Bricks object.
"""
import plotnine
plotnine_version = plotnine.__version__

def draw_labels(bricks, gori, gcp, figsize):
get_property = gcp.theme.themeables.property
Expand Down Expand Up @@ -453,10 +503,6 @@ def draw_title(bricks, gori, gcp, figsize):
for ax in gori.axs:
gori.theme.themeables['plot_title'].apply(ax)

import plotnine
plotnine_version = plotnine.__version__
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
overwrite_plotnine()

#save_original_position
global _basefigure
Expand Down Expand Up @@ -659,33 +705,77 @@ def draw_title(bricks, gori, gcp, figsize):

return return_obj

@contextmanager
def patched_axisgrid():
"""
Temporarily patch seaborn.axisgrid methods with patchworklib counterparts.
This allows for custom behaviors in seaborn's grid objects like FacetGrid,
PairGrid, JointGrid, and ClusterGrid, particularly useful when integrating
seaborn's plotting with patchworklib's functionalities. Can be used as a
context manager or a decorator.
Examples
-------
>>> with patched_axisgrid():
... pw.load_seabornobj(
... sns.jointplot(x="x", y="y", data=data),
... )
Example use as a context manager.
>>> @patched_axisgrid()
>>> def custom_plot():
... pw.load_seabornobj(
... sns.jointplot(x="x", y="y", data=data),
... )
>>> custom_plot()
Example use as a decorator.
"""
# patch("sns.pairplot", mg.pairplot)
with patch.object(
sns.axisgrid.Grid, "_figure", _basefigure, create=True
), patch(
"seaborn.axisgrid.Grid.add_legend", mg.add_legend
), patch(
"seaborn.axisgrid.FacetGrid.__init__", mg.__init_for_facetgrid__
), patch(
"seaborn.axisgrid.FacetGrid.despine", mg.despine
), patch(
"seaborn.axisgrid.PairGrid.__init__", mg.__init_for_pairgrid__
), patch(
"seaborn.axisgrid.JointGrid.__init__", mg.__init_for_jointgrid__
), patch(
"seaborn.matrix.ClusterGrid.__init__", mg.__init_for_clustergrid__
), patch(
"seaborn.matrix.ClusterGrid.__setattr__", mg.__setattr_for_clustergrid__
), patch(
"seaborn.matrix.ClusterGrid.plot", mg.__plot_for_clustergrid__
):
yield

def overwrite_axisgrid():
"""
Overwrite `__init__` functions in seaborn.axisgrid.FacetGrid,
Overwrite `__init__` functions in seaborn.axisgrid.FacetGrid,
seaborn.axisgrid.PairGrid and seaborn.axisgrid.JointGrid.
The function changes the figure object given in the `__init__` functions of the
axisgrid class objects, which is used for drawing plots, to `_basefigure
in the patchworklib. If you want to import plots generated baseon
seabron.axisgrid.xxGrid objects as patchworklib.Brick(s) object by using
The function changes the figure object given in the `__init__` functions of the
axisgrid class objects, which is used for drawing plots, to `_basefigure
in the patchworklib. If you want to import plots generated baseon
seabron.axisgrid.xxGrid objects as patchworklib.Brick(s) object by using
`load_seaborngrid` function, you should execute the function in advance.
Returns
-------
None.
"""

#sns.pairplot = mg.pairplot
sns.axisgrid.Grid._figure = _basefigure
sns.axisgrid.Grid.add_legend = mg.add_legend
sns.axisgrid.FacetGrid.__init__ = mg.__init_for_facetgrid__
sns.axisgrid.FacetGrid.despine = mg.despine
sns.axisgrid.PairGrid.__init__ = mg.__init_for_pairgrid__
sns.axisgrid.JointGrid.__init__ = mg.__init_for_jointgrid__
sns.matrix.ClusterGrid.__init__ = mg.__init_for_clustergrid__
sns.matrix.ClusterGrid.__setattr__ = mg.__setattr_for_clustergrid__
sns.matrix.ClusterGrid.plot = mg.__plot_for_clustergrid__
See Also
--------
patched_axisgrid : Context manager that applies then reverses axisgrid
patches.
"""
patched_axisgrid().__enter__()

def load_seabornobj(g, label=None, labels=None, figsize=(3,3)):
"""
Expand Down
74 changes: 74 additions & 0 deletions tests/test_patchworklib.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,77 @@ def test_sns_and_p9(tmp_path: Path):
result_file = tmp_path / "g.png"
g.savefig(result_file)
assert result_file.exists()


@pw.patched_axisgrid()
def _make_seabornobj():
iris = sns.load_dataset("iris")
tips = sns.load_dataset("tips")

# An lmplot
g0 = sns.lmplot(
x="total_bill", y="tip", hue="smoker", data=tips, palette=dict(Yes="g", No="m")
)
g0 = pw.load_seaborngrid(g0, label="g0")

# A Pairplot
g1 = sns.pairplot(iris, hue="species")
g1 = pw.load_seaborngrid(g1, label="g1", figsize=(6, 6))

# A relplot
g2 = sns.relplot(
data=tips,
x="total_bill",
y="tip",
col="time",
hue="time",
size="size",
style="sex",
palette=["b", "r"],
sizes=(10, 100),
)
g2.set_titles("")
g2 = pw.load_seaborngrid(g2, label="g2")

# A JointGrid
g3 = sns.jointplot(
data=iris, x="sepal_width", y="petal_length", kind="kde", space=0
)
g3 = pw.load_seaborngrid(g3, label="g3")

composite = (((g0/g3)["g0"]|g1)["g1"]/g2)
return composite


def test_load_seabornobj(tmp_path: Path):
composite = _make_seabornobj()

result_file = tmp_path / "composite.png"
composite.savefig(result_file)
assert result_file.exists()


@pw.patched_axisgrid() # duplicate patch wrapper
def test_patch_nesting(tmp_path: Path):
composite = _make_seabornobj()

result_file = tmp_path / "composite.png"
composite.savefig(result_file)
assert result_file.exists()


def test_patched_axisgrid():
with pw.patched_axisgrid():
assert hasattr(sns.axisgrid.Grid, "_figure")
assert sns.axisgrid.FacetGrid.add_legend is pw.modified_grid.add_legend

assert not hasattr(sns.axisgrid.Grid, "_figure")
assert sns.axisgrid.FacetGrid.add_legend is not pw.modified_grid.add_legend


def test_patched_plotnine():
with pw.patched_plotnine():
if pw.patchworklib._needs_plotnine_ggplot_draw_patch:
assert p9.ggplot.draw is pw.modified_plotnine.draw

assert p9.ggplot.draw is not pw.modified_plotnine.draw

0 comments on commit 76c7f88

Please sign in to comment.