Skip to content

Commit

Permalink
inspect multiple boxes and display them in a single figure (#317)
Browse files Browse the repository at this point in the history
* make it possible to plot mulitple boxes in one figure

default behavior of inspect with list of indices is to show each box in a seperate figure (if style is graph). This code makes it possible to plot them in a single figure.

closes #124

* Update test_prim.py
  • Loading branch information
quaquel committed Dec 4, 2023
1 parent 2002532 commit 2ae0278
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 13 deletions.
35 changes: 31 additions & 4 deletions ema_workbench/analysis/prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from operator import itemgetter

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
Expand Down Expand Up @@ -396,9 +397,9 @@ def __getattr__(self, name):
else:
raise AttributeError

def inspect(self, i=None, style="table", **kwargs):
def inspect(self, i=None, style="table", ax=None, **kwargs):
"""Write the stats and box limits of the user specified box to
standard out. if i is not provided, the last box will be
standard out. If i is not provided, the last box will be
printed
Parameters
Expand All @@ -409,20 +410,38 @@ def inspect(self, i=None, style="table", **kwargs):
the style of the visualization. 'table' prints the stats and
boxlim. 'graph' creates a figure. 'data' returns a list of
tuples, where each tuple contains the stats and the box_lims.
ax : axes or list of axes instances, optional
used in conjunction with `graph` style, allows you to control the axes on which graph is plotted
if i is list, axes should be list of equal length. If axes is None, each i_j in i will be plotted
in a separate figure.
additional kwargs are passed to the helper function that
generates the table or graph
"""
if style not in {"table", "graph", "data"}:
raise ValueError(f"style must be one of 'table', 'graph', or 'data', not {style}")

if i is None:
i = [self._cur_box]
elif isinstance(i, int):
i = [i]

if isinstance(ax, mpl.axes.Axes):
ax = [ax]

if not all(isinstance(x, int) for x in i):
raise TypeError(f"i must be an integer or list of integers, not {type(i)}")

return [self._inspect(entry, style=style, **kwargs) for entry in i]
if (ax is not None) and style == "graph":
if len(ax) != len(i):
raise ValueError(
f"the number of axes ({len(ax)}) does not match the number of boxes to inspect ({len(i)})"
)
else:
return [self._inspect(i_j, style=style, ax=ax, **kwargs) for i_j, ax in zip(i, ax)]
else:
return [self._inspect(entry, style=style, **kwargs) for entry in i]

def _inspect(self, i=None, style="table", **kwargs):
"""Helper method for inspecting one or more boxes on the
Expand Down Expand Up @@ -450,7 +469,13 @@ def _inspect(self, i=None, style="table", **kwargs):
if style == "table":
return self._inspect_table(i, uncs, qp_values)
elif style == "graph":
return self._inspect_graph(i, uncs, qp_values, **kwargs)
# makes it possible to use _inspect to plot multiple
# boxes into a single figure
try:
ax = kwargs.pop("ax")
except KeyError:
fig, ax = plt.subplots()
return self._inspect_graph(i, uncs, qp_values, ax=ax, **kwargs)
elif style == "data":
return self._inspect_data(i, uncs, qp_values)
else:
Expand Down Expand Up @@ -496,6 +521,7 @@ def _inspect_graph(
ticklabel_formatter="{} ({})",
boxlim_formatter="{: .2g}",
table_formatter="{:.3g}",
ax=None,
):
"""Helper method for visualizing box statistics in
graph form"""
Expand All @@ -507,6 +533,7 @@ def _inspect_graph(
uncs,
self.peeling_trajectory.at[i, "coverage"],
self.peeling_trajectory.at[i, "density"],
ax,
ticklabel_formatter=ticklabel_formatter,
boxlim_formatter=boxlim_formatter,
table_formatter=table_formatter,
Expand Down
22 changes: 15 additions & 7 deletions ema_workbench/analysis/scenario_discovery_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def plot_pair_wise_scatter(
fill_subplots=True,
):
"""helper function for pair wise scatter plotting
Parameters
----------
x : DataFrame
Expand Down Expand Up @@ -530,16 +531,19 @@ def plot_pair_wise_scatter(
return grid


def _setup_figure(uncs):
def _setup_figure(uncs, ax):
"""
helper function for creating the basic layout for the figures that
show the box lims.
Parameters
----------
uncs : list of str
ax : axes instance
"""
nr_unc = len(uncs)
fig = plt.figure()
ax = fig.add_subplot(111)

# create the shaded grey background
rect = mpl.patches.Rectangle(
Expand All @@ -551,7 +555,6 @@ def _setup_figure(uncs):
ax.yaxis.set_ticks(list(range(nr_unc)))
ax.xaxis.set_ticks([0, 0.25, 0.5, 0.75, 1])
ax.set_yticklabels(uncs[::-1])
return fig, ax


def plot_box(
Expand All @@ -561,6 +564,7 @@ def plot_box(
uncs,
coverage,
density,
ax,
ticklabel_formatter="{} ({})",
boxlim_formatter="{: .2g}",
table_formatter="{:.3g}",
Expand All @@ -579,6 +583,7 @@ def plot_box(
ticklabel_formatter : str
boxlim_formatter : str
table_formatter : str
ax : Axes instance
Returns
-------
Expand All @@ -587,8 +592,9 @@ def plot_box(
"""
norm_box_lim = _normalize(boxlim, box_init, uncs)
fig = plt.gcf()

fig, ax = _setup_figure(uncs)
_setup_figure(uncs, ax)
for j, u in enumerate(uncs):
# we want to have the most restricted dimension
# at the top of the figure
Expand Down Expand Up @@ -842,7 +848,8 @@ def plot_boxes(x, boxes, together):
norm_box_lims = [_normalize(box_lim, box_init, uncs) for box_lim in boxes]

if together:
fig, ax = _setup_figure(uncs)
fig, ax = plt.subplots()
_setup_figure(uncs, ax)

for i, u in enumerate(uncs):
colors = itertools.cycle(COLOR_LIST)
Expand All @@ -862,7 +869,8 @@ def plot_boxes(x, boxes, together):
colors = itertools.cycle(COLOR_LIST)

for j, norm_box_lim in enumerate(norm_box_lims):
fig, ax = _setup_figure(uncs)
fig, ax = plt.subplots()
_setup_figure(uncs, ax)
ax.set_title(f"box {j}")
color = next(colors)

Expand Down
8 changes: 7 additions & 1 deletion ema_workbench/examples/sd_prim_flu.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ def classify(data):
box_1 = prim_obj.find_box()
box_1.show_ppt()
box_1.show_tradeoff()
box_1.inspect(5, style="graph", boxlim_formatter="{: .2f}")
# box_1.inspect([5, 6], style="graph", boxlim_formatter="{: .2f}")

fig, axes = plt.subplots(nrows=2, ncols=1)

box_1.inspect([5, 6], style="graph", boxlim_formatter="{: .2f}", ax=axes)
plt.show()

box_1.inspect(5)
box_1.select(5)
box_1.write_ppt_to_stdout()
Expand Down
Binary file removed test/data/test.tar.gz
Binary file not shown.
11 changes: 11 additions & 0 deletions test/test_analysis/test_prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from ema_workbench.analysis import prim
from ema_workbench.analysis.prim import PrimBox
Expand Down Expand Up @@ -69,9 +70,19 @@ def test_inspect(self):
box.inspect(1)
box.inspect()
box.inspect(style="graph")
box.inspect(style="data")

box.inspect([0, 1])

fig, axes = plt.subplots(2)
box.inspect([0, 1], ax=axes, style="graph")

fig, ax = plt.subplots()
box.inspect(0, ax=ax, style="graph")

with pytest.raises(ValueError):
fig, axes = plt.subplots(3)
box.inspect([0, 1], ax=axes, style="graph")
with pytest.raises(ValueError):
box.inspect(style="some unknown style")
with pytest.raises(TypeError):
Expand Down
3 changes: 2 additions & 1 deletion test/test_analysis/test_scenario_discovery_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ def test_plot_box(self):

qp_values = {"a": [0.05, 0.9], "c": [0.05, -1]}

sdutil.plot_box(boxlim, qp_values, box_init, restricted_dims, 1, 1)
fig, ax = plt.subplots()
sdutil.plot_box(boxlim, qp_values, box_init, restricted_dims, 1, 1, ax)
plt.draw()
plt.close("all")

Expand Down

0 comments on commit 2ae0278

Please sign in to comment.