Skip to content

Commit

Permalink
Add show results function (a wrapper around show_priors but wit ha be…
Browse files Browse the repository at this point in the history
…tter name) (#45)
  • Loading branch information
xgarrido committed Mar 26, 2024
1 parent 533dbe5 commit 1e4cc61
Showing 1 changed file with 50 additions and 4 deletions.
54 changes: 50 additions & 4 deletions cobaya_utilities/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,41 @@ def get_mc_samples(
return samples, labels, colors


def show_results(g, results, with_legend=True, legend_kwargs=None):
"""Show results values on a given set of axes
Parameters
----------
g: getdist.plots
the getdist plotter instance
results: dict
dictionary holding result values. The dict can be either a parameter: loc/scale combo or
parameter: dict(loc/scale + and other settings such as color, linestyle)
with_legend: bool
add legend
legend_kwargs: dict
set legend kwargs
"""
for name, result in results.items():
show_priors(
g,
result.get("values", result),
color=result.get("color"),
ls=result.get("ls", "-"),
with_legend=False,
)
kwargs = dict(
ax=g.fig.axes[-1],
labels=results.keys(),
colors=[v.get("color") for v in results.values()],
fontsize=10,
bbox_to_anchor=(1, 1),
loc="upper left",
)
kwargs |= legend_kwargs or {}
if with_legend:
add_legend(**kwargs)


def show_priors(g, priors, color="gray", ls="--", with_legend=True, legend_kwargs=None):
"""Show prior values on a given set of axes
Parameters
Expand Down Expand Up @@ -528,18 +563,29 @@ def plot_mean_distributions(samples, params, colors="0.7", return_results=False,
return pd.DataFrame.from_dict(results, orient="index")


def add_legend(fig=None, ax=None, labels=None, colors=None, ls=None, **kwargs):
def add_legend(obj, labels=None, colors=None, ls=None, **kwargs):
"""Add legend
if not fig and not ax:
raise ValueError("Missing either fig or axis instance!")
Parameters
----------
obj: figure or axis
a matplotlib figure or axis
params: dict or list
a dict holding the parameter names and its default value or
a unique list of parameter names
colors: list or str
the colors of the different markers. If colors == "chi2" then the markers will be colored
relatively to their chi2 values (if default values are given for parameters)
return_results: bool
return a pandas Dataframe holding the results
"""

labels = labels or kwargs.get("legend_labels")

colors = colors or [None for label in labels]
ls = ls or ["-" for label in labels]
handles = [Line2D([0], [0], color=colors[i], ls=ls[i]) for i, label in enumerate(labels)]

obj = fig or ax
leg = obj.legend(
handles,
labels,
Expand Down

0 comments on commit 1e4cc61

Please sign in to comment.