Skip to content

Commit

Permalink
Optional separation of data and model legend in plot_spectrum, type c…
Browse files Browse the repository at this point in the history
…hecks added to plot_spectrum
  • Loading branch information
Tomas Stolker committed May 13, 2020
1 parent 2c956de commit 0657b90
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 61 deletions.
16 changes: 10 additions & 6 deletions species/plot/plot_color.py
Expand Up @@ -21,12 +21,6 @@
from species.util import plot_util


mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)


@typechecked
def plot_color_magnitude(boxes: list,
objects: Optional[Union[List[Tuple[str, str, str, str]],
Expand Down Expand Up @@ -107,6 +101,11 @@ def plot_color_magnitude(boxes: list,
"""

mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)

model_color = ('#234398', '#f6a432', 'black')
model_linestyle = ('-', '--', ':', '-.')

Expand Down Expand Up @@ -559,6 +558,11 @@ def plot_color_color(boxes: list,
None
"""

mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)

model_color = ('#234398', '#f6a432', 'black')
model_linestyle = ('-', '--', ':', '-.')

Expand Down
21 changes: 15 additions & 6 deletions species/plot/plot_mcmc.py
Expand Up @@ -19,12 +19,6 @@
from species.util import plot_util


mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)


@typechecked
def plot_walkers(tag: str,
nsteps: Optional[int] = None,
Expand Down Expand Up @@ -52,6 +46,11 @@ def plot_walkers(tag: str,

print(f'Plotting walkers: {output}...', end='', flush=True)

mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)

species_db = database.Database()
box = species_db.get_samples(tag)

Expand Down Expand Up @@ -154,6 +153,11 @@ def plot_posterior(tag: str,
None
"""

mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)

if burnin is None:
burnin = 0

Expand Down Expand Up @@ -303,6 +307,11 @@ def plot_photometry(tag,
None
"""

mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)

species_db = database.Database()

samples = species_db.get_mcmc_photometry(tag, burnin, filter_id)
Expand Down
173 changes: 124 additions & 49 deletions species/plot/plot_spectrum.py
Expand Up @@ -4,39 +4,39 @@

import os
import math
import warnings
import itertools

from typing import Optional, Union, Tuple, List

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from typeguard import typechecked

from species.core import box, constants
from species.read import read_filter
from species.util import plot_util


mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)
plt.rcParams['axes.axisbelow'] = False


def plot_spectrum(boxes,
filters=None,
residuals=None,
plot_kwargs=None,
xlim=None,
ylim=None,
ylim_res=None,
scale=('linear', 'linear'),
title=None,
offset=None,
legend=None,
figsize=(7., 5.),
object_type='planet',
quantity='flux',
output='spectrum.pdf'):
@typechecked
def plot_spectrum(boxes: list,
filters: Optional[List[str]] = None,
residuals: Optional[box.ResidualsBox] = None,
plot_kwargs: Optional[List[Optional[dict]]] = None,
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
ylim_res: Optional[Tuple[float, float]] = None,
scale: Optional[Tuple[str, str]] = None,
title: Optional[str] = None,
offset: Optional[Tuple[float, float]] = None,
legend: Union[str, dict, Tuple[float, float],
List[Optional[Union[dict, str, Tuple[float, float]]]]] = None,
figsize: Optional[Tuple[float, float]] = (7., 5.),
object_type: str = 'planet',
quantity: str = 'flux',
output: str = 'spectrum.pdf'):
"""
Parameters
----------
Expand Down Expand Up @@ -75,15 +75,16 @@ def plot_spectrum(boxes,
ylim_res : tuple(float, float), None
Limits of the residuals axis. Automatically chosen (based on the minimum and maximum
residual value) if set to None.
scale : tuple(str, str)
Scale of the axes ('linear' or 'log').
scale : tuple(str, str), None
Scale of the x and y axes ('linear' or 'log'). The scale is set to ``('linear', 'linear')``
if set to ``None``.
title : str
Title.
offset : tuple(float, float)
Offset for the label of the x- and y-axis.
legend : str, tuple, dict, None
legend : str, tuple, dict, list(dict, dict), None
Location of the legend (str, tuple) or a dictionary with the ``**kwargs`` of
``matplotlib.pyplot.legend``, e.g. ``{'loc': 'upper left', 'fontsize: 12.}``.
``matplotlib.pyplot.legend``, for example ``{'loc': 'upper left', 'fontsize: 12.}``.
figsize : tuple(float, float)
Figure size.
object_type : str
Expand All @@ -100,6 +101,12 @@ def plot_spectrum(boxes,
None
"""

mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)
plt.rcParams['axes.axisbelow'] = False

if plot_kwargs is None:
plot_kwargs = []

Expand Down Expand Up @@ -190,7 +197,7 @@ def plot_spectrum(boxes,
ax2.set_ylabel('Transmission', fontsize=13)

if residuals is not None:
ax3.set_ylabel(r'Residual ($\sigma$)', fontsize=13)
ax3.set_ylabel(r'$\Delta$$F_\lambda$ ($\sigma$)', fontsize=13)

if xlim is not None:
ax1.set_xlim(xlim[0], xlim[1])
Expand All @@ -213,7 +220,7 @@ def plot_spectrum(boxes,
exponent = math.floor(math.log10(ylim[1]))
scaling = 10.**exponent

ylabel = r'Flux (10$^{'+str(exponent)+r'}$ W m$^{-2}$ $\mu$m$^{-1}$)'
ylabel = r'$F_\lambda$ (10$^{'+str(exponent)+r'}$ W m$^{-2}$ $\mu$m$^{-1}$)'

ax1.set_ylabel(ylabel, fontsize=13)
ax1.set_ylim(ylim[0]/scaling, ylim[1]/scaling)
Expand All @@ -222,7 +229,7 @@ def plot_spectrum(boxes,
ax1.axhline(0.0, linestyle='--', color='gray', dashes=(2, 4), zorder=0.5)

else:
ax1.set_ylabel(r'Flux (W m$^{-2}$ $\mu$m$^{-1}$)', fontsize=13)
ax1.set_ylabel(r'$F_\lambda$ (W m$^{-2}$ $\mu$m$^{-1}$)', fontsize=13)
scaling = 1.

if filters is not None:
Expand Down Expand Up @@ -263,6 +270,9 @@ def plot_spectrum(boxes,
ax1.get_xaxis().set_label_coords(0.5, -0.12)
ax1.get_yaxis().set_label_coords(-0.1, 0.5)

if scale is None:
scale = ('linear', 'linear')

ax1.set_xscale(scale[0])
ax1.set_yscale(scale[1])

Expand Down Expand Up @@ -346,7 +356,7 @@ def plot_spectrum(boxes,
label = kwargs_copy['label']

del kwargs_copy['label']

ax1.plot(wavelength, masked/scaling, zorder=2, label=label, **kwargs_copy)

else:
Expand Down Expand Up @@ -391,8 +401,52 @@ def plot_spectrum(boxes,
zorder=3)

elif isinstance(boxitem, box.ObjectBox):
if boxitem.spectrum is not None:
spec_list = []
wavel_list = []

for item in boxitem.spectrum:
spec_list.append(item)
wavel_list.append(boxitem.spectrum[item][0][0, 0])

sort_index = np.argsort(wavel_list)
spec_sort = []

for i in range(sort_index.size):
spec_sort.append(spec_list[sort_index[i]])

for key in spec_sort:
masked = np.ma.array(boxitem.spectrum[key][0],
mask=np.isnan(boxitem.spectrum[key][0]))

if not plot_kwargs[j] or key not in plot_kwargs[j]:
plot_obj = ax1.errorbar(masked[:, 0], masked[:, 1]/scaling,
yerr=masked[:, 2]/scaling, ms=2, marker='s',
zorder=2.5, ls='none')

plot_kwargs[j][key] = {'marker': 's', 'ms': 2., 'ls': 'none',
'color': plot_obj[0].get_color()}

else:
ax1.errorbar(masked[:, 0], masked[:, 1]/scaling, yerr=masked[:, 2]/scaling,
zorder=2.5, **plot_kwargs[j][key])

if boxitem.flux is not None:
filter_list = []
wavel_list = []

for item in boxitem.flux:
read_filt = read_filter.ReadFilter(item)
filter_list.append(item)
wavel_list.append(read_filt.mean_wavelength())

sort_index = np.argsort(wavel_list)
filter_sort = []

for i in range(sort_index.size):
filter_sort.append(filter_list[sort_index[i]])

for item in filter_sort:
transmission = read_filter.ReadFilter(item)
wavelength = transmission.mean_wavelength()
fwhm = transmission.filter_fwhm()
Expand Down Expand Up @@ -430,23 +484,6 @@ def plot_spectrum(boxes,
ax1.errorbar(wavelength, boxitem.flux[item][0]/scaling, xerr=fwhm/2.,
yerr=boxitem.flux[item][1]/scaling, zorder=3, **plot_kwargs[j][item])

if boxitem.spectrum is not None:
for key, value in boxitem.spectrum.items():
masked = np.ma.array(boxitem.spectrum[key][0],
mask=np.isnan(boxitem.spectrum[key][0]))

if not plot_kwargs[j] or key not in plot_kwargs[j]:
plot_obj = ax1.errorbar(masked[:, 0], masked[:, 1]/scaling,
yerr=masked[:, 2]/scaling, ms=2, marker='s',
zorder=2.5, ls='none')

plot_kwargs[j][key] = {'marker': 's', 'ms': 2., 'ls': 'none',
'color': plot_obj[0].get_color()}

else:
ax1.errorbar(masked[:, 0], masked[:, 1]/scaling, yerr=masked[:, 2]/scaling,
zorder=2.5, **plot_kwargs[j][key])

elif isinstance(boxitem, box.SynphotBox):
for i, find_item in enumerate(boxes):
if isinstance(find_item, box.ObjectBox):
Expand Down Expand Up @@ -558,11 +595,49 @@ def plot_spectrum(boxes,
else:
ax1.set_title(title, y=1.02, fontsize=15)

handles, _ = ax1.get_legend_handles_labels()
handles, labels = ax1.get_legend_handles_labels()

if handles and legend is not None:
if isinstance(legend, (str, tuple)):
if isinstance(legend, list):
model_handles = []
data_handles = []

model_labels = []
data_labels = []

for i, item in enumerate(handles):
if isinstance(item, mpl.lines.Line2D):
model_handles.append(item)
model_labels.append(labels[i])

elif isinstance(item, mpl.container.ErrorbarContainer):
data_handles.append(item)
data_labels.append(labels[i])

else:
warnings.warn(f'The object type {item} is not implemented for the legend.')

if legend[0] is not None:
if isinstance(legend[0], (str, tuple)):
leg_1 = ax1.legend(model_handles, model_labels, loc=legend[0], fontsize=10., frameon=False)
else:
leg_1 = ax1.legend(model_handles, model_labels, **legend[0])

else:
leg_1 = None

if legend[1] is not None:
if isinstance(legend[1], (str, tuple)):
leg_2 = ax1.legend(data_handles, data_labels, loc=legend[1], fontsize=8, frameon=False)
else:
leg_2 = ax1.legend(data_handles, data_labels, **legend[1])

if leg_1 is not None:
ax1.add_artist(leg_1)

elif isinstance(legend, (str, tuple)):
ax1.legend(loc=legend, fontsize=8, frameon=False)

else:
ax1.legend(**legend)

Expand Down

0 comments on commit 0657b90

Please sign in to comment.