Skip to content

Commit

Permalink
Support for different formats of the histogram titles in plot_posterior
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas Stolker committed Aug 5, 2020
1 parent 5625e66 commit a340303
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 12 deletions.
71 changes: 60 additions & 11 deletions species/plot/plot_mcmc.py
Expand Up @@ -3,8 +3,9 @@
"""

import os
import warnings

from typing import Optional, Tuple, List
from typing import List, Optional, Tuple, Union

import h5py
import corner
Expand Down Expand Up @@ -119,7 +120,7 @@ def plot_posterior(tag: str,
burnin: Optional[int] = None,
title: Optional[str] = None,
offset: Optional[Tuple[float, float]] = None,
title_fmt: str = '.2f',
title_fmt: Union[str, List[str]] = '.2f',
limits: Optional[List[Tuple[float, float]]] = None,
max_posterior: bool = False,
inc_luminosity: bool = False,
Expand All @@ -138,8 +139,10 @@ def plot_posterior(tag: str,
Plot title. No title is shown if set to ``None``.
offset : tuple(float, float), None
Offset of the x- and y-axis label. Default values are used if set to ``None``.
title_fmt : str
Format of the median and error values.
title_fmt : str, list(str)
Format of the titles above the 1D distributions. Either a single string, which will be used
for all parameters, or a list with the title format for each parameter separately (in the
order as shown in the corner plot).
limits : list(tuple(float, float), ), None
Axis limits of all parameters. Automatically set if set to ``None``.
max_posterior : bool
Expand Down Expand Up @@ -265,7 +268,6 @@ def plot_posterior(tag: str,
logg_index = np.argwhere(np.array(box.parameters) == 'logg')[0]
radius_index = np.argwhere(np.array(box.parameters) == 'radius')[0]


mass_samples = read_util.get_mass(samples[..., logg_index], samples[..., radius_index])

samples = np.append(samples, mass_samples, axis=-1)
Expand All @@ -275,13 +277,56 @@ def plot_posterior(tag: str,
else:
warnings.warn('Samples with the log(g) and radius are required for \'inc_mass=True\'.')

if isinstance(title_fmt, list) and len(title_fmt) != ndim:
raise ValueError(f'The number of items in the list of \'title_fmt\' ({len(title_fmt)}) is '
f'not equal to the number of dimensions of the samples ({ndim}).')

labels = plot_util.update_labels(box.parameters)

samples = samples.reshape((-1, ndim))

fig = corner.corner(samples, labels=labels, quantiles=[0.16, 0.5, 0.84],
label_kwargs={'fontsize': 13}, show_titles=True,
title_kwargs={'fontsize': 12}, title_fmt=title_fmt)
hist_titles = []

for i, item in enumerate(labels):
unit_start = item.find('(')

if unit_start == -1:
param_label = item
unit_label = None

else:
param_label = item[:unit_start]
# Remove parenthesis from the units
unit_label = item[unit_start+1:-1]

q_16, q_50, q_84 = corner.quantile(samples[:, i], [0.16, 0.5, 0.84])
q_minus, q_plus = q_50-q_16, q_84-q_50

if isinstance(title_fmt, str):
fmt = '{{0:{0}}}'.format(title_fmt).format

elif isinstance(title_fmt, list):
fmt = '{{0:{0}}}'.format(title_fmt[i]).format

best_fit = r'${{{0}}}_{{-{1}}}^{{+{2}}}$'
best_fit = best_fit.format(fmt(q_50), fmt(q_minus), fmt(q_plus))

if unit_label is None:
hist_title = f'{param_label} = {best_fit}'

else:
hist_title = f'{param_label} = {best_fit} {unit_label}'

hist_titles.append(hist_title)

fig = corner.corner(samples,
quantiles=[0.16, 0.5, 0.84],
labels=labels,
label_kwargs={'fontsize': 13},
titles=hist_titles,
show_titles=True,
title_fmt=None,
title_kwargs={'fontsize': 12})

axes = np.array(fig.axes).reshape((ndim, ndim))

Expand Down Expand Up @@ -383,9 +428,13 @@ def plot_photometry(tag,

print(f'Plotting photometry samples: {output}...', end='', flush=True)

fig = corner.corner(samples, labels=['Magnitude'], quantiles=[0.16, 0.5, 0.84],
label_kwargs={'fontsize': 13}, show_titles=True,
title_kwargs={'fontsize': 12}, title_fmt='.2f')
fig = corner.corner(samples,
labels=['Magnitude'],
quantiles=[0.16, 0.5, 0.84],
label_kwargs={'fontsize': 13},
show_titles=True,
title_kwargs={'fontsize': 12},
title_fmt='.2f')

axes = np.array(fig.axes).reshape((1, 1))

Expand Down
2 changes: 1 addition & 1 deletion species/util/plot_util.py
Expand Up @@ -226,7 +226,7 @@ def update_labels(param: List[str]) -> List[str]:
param[i] = rf'$\mathregular{{c}}_\mathregular{{{item[11:]}}}$ (nm)'

elif item[0:9] == 'corr_len_':
param[i] = rf'$\mathregular{{log}}\,\ell_\mathregular{{{item[9:]}}}/\mathregular{{µm}}$'
param[i] = rf'$\mathregular{{log}}\,\ell_\mathregular{{{item[9:]}}}$'

elif item[0:9] == 'corr_amp_':
param[i] = rf'$\mathregular{{f}}_\mathregular{{{item[9:]}}}$'
Expand Down

0 comments on commit a340303

Please sign in to comment.