Skip to content

Commit

Permalink
make it possible to plot mulitple boxes in one figure
Browse files Browse the repository at this point in the history
default behavior of inspect with list of indices is to show each box in a seperate figure (if style is graph). This code makes it possible to plot them in a single figure.

outstanding issue from #124
  • Loading branch information
quaquel committed Nov 18, 2023
1 parent b6ac5cc commit e28ee4a
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 10 deletions.
13 changes: 11 additions & 2 deletions ema_workbench/analysis/prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from operator import itemgetter

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
Expand Down Expand Up @@ -398,7 +399,7 @@ def __getattr__(self, name):

def inspect(self, i=None, style="table", **kwargs):
"""Write the stats and box limits of the user specified box to
standard out. if i is not provided, the last box will be
standard out. If i is not provided, the last box will be
printed
Parameters
Expand Down Expand Up @@ -450,7 +451,13 @@ def _inspect(self, i=None, style="table", **kwargs):
if style == "table":
return self._inspect_table(i, uncs, qp_values)
elif style == "graph":
return self._inspect_graph(i, uncs, qp_values, **kwargs)
# makes it possible to use _inspect to plot multiple
# boxes into a single figure
try:
ax = kwargs.pop("ax")
except KeyError:
fig, ax = plt.subplots()
return self._inspect_graph(i, uncs, qp_values, ax=ax, **kwargs)
elif style == "data":
return self._inspect_data(i, uncs, qp_values)
else:
Expand Down Expand Up @@ -496,6 +503,7 @@ def _inspect_graph(
ticklabel_formatter="{} ({})",
boxlim_formatter="{: .2g}",
table_formatter="{:.3g}",
ax=None,
):
"""Helper method for visualizing box statistics in
graph form"""
Expand All @@ -510,6 +518,7 @@ def _inspect_graph(
ticklabel_formatter=ticklabel_formatter,
boxlim_formatter=boxlim_formatter,
table_formatter=table_formatter,
ax=ax,
)

def inspect_tradeoff(self):
Expand Down
22 changes: 15 additions & 7 deletions ema_workbench/analysis/scenario_discovery_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def plot_pair_wise_scatter(
fill_subplots=True,
):
"""helper function for pair wise scatter plotting
Parameters
----------
x : DataFrame
Expand Down Expand Up @@ -530,16 +531,19 @@ def plot_pair_wise_scatter(
return grid


def _setup_figure(uncs):
def _setup_figure(uncs, ax):
"""
helper function for creating the basic layout for the figures that
show the box lims.
Parameters
----------
uncs : list of str
ax : axes instance
"""
nr_unc = len(uncs)
fig = plt.figure()
ax = fig.add_subplot(111)

# create the shaded grey background
rect = mpl.patches.Rectangle(
Expand All @@ -551,7 +555,6 @@ def _setup_figure(uncs):
ax.yaxis.set_ticks(list(range(nr_unc)))
ax.xaxis.set_ticks([0, 0.25, 0.5, 0.75, 1])
ax.set_yticklabels(uncs[::-1])
return fig, ax


def plot_box(
Expand All @@ -564,6 +567,7 @@ def plot_box(
ticklabel_formatter="{} ({})",
boxlim_formatter="{: .2g}",
table_formatter="{:.3g}",
ax=None,
):
"""Helper function for parallel coordinate style visualization
of a box
Expand All @@ -579,6 +583,7 @@ def plot_box(
ticklabel_formatter : str
boxlim_formatter : str
table_formatter : str
ax : Axes instance
Returns
-------
Expand All @@ -587,8 +592,9 @@ def plot_box(
"""
norm_box_lim = _normalize(boxlim, box_init, uncs)
fig = plt.gcf()

fig, ax = _setup_figure(uncs)
_setup_figure(uncs, ax)
for j, u in enumerate(uncs):
# we want to have the most restricted dimension
# at the top of the figure
Expand Down Expand Up @@ -842,7 +848,8 @@ def plot_boxes(x, boxes, together):
norm_box_lims = [_normalize(box_lim, box_init, uncs) for box_lim in boxes]

if together:
fig, ax = _setup_figure(uncs)
fig, ax = plt.subplots()
_setup_figure(uncs, ax)

for i, u in enumerate(uncs):
colors = itertools.cycle(COLOR_LIST)
Expand All @@ -862,7 +869,8 @@ def plot_boxes(x, boxes, together):
colors = itertools.cycle(COLOR_LIST)

for j, norm_box_lim in enumerate(norm_box_lims):
fig, ax = _setup_figure(uncs)
fig, ax = plt.subplots()
_setup_figure(uncs, ax)
ax.set_title(f"box {j}")
color = next(colors)

Expand Down
9 changes: 8 additions & 1 deletion ema_workbench/examples/sd_prim_flu.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,14 @@ def classify(data):
box_1 = prim_obj.find_box()
box_1.show_ppt()
box_1.show_tradeoff()
box_1.inspect(5, style="graph", boxlim_formatter="{: .2f}")
# box_1.inspect([5, 6], style="graph", boxlim_formatter="{: .2f}")

fig, axes = plt.subplots(nrows=2, ncols=1)

for i, ax in zip([5, 6], axes):
box_1._inspect(i, style="graph", boxlim_formatter="{: .2f}", ax=ax)
plt.show()

box_1.inspect(5)
box_1.select(5)
box_1.write_ppt_to_stdout()
Expand Down
Binary file removed test/data/test.tar.gz
Binary file not shown.

0 comments on commit e28ee4a

Please sign in to comment.