From 0657b9065c425d975c110fdd2a70a6bee22b3234 Mon Sep 17 00:00:00 2001 From: Tomas Stolker Date: Wed, 13 May 2020 10:16:31 +0200 Subject: [PATCH] Optional separation of data and model legend in plot_spectrum, type checks added to plot_spectrum --- species/plot/plot_color.py | 16 ++-- species/plot/plot_mcmc.py | 21 +++-- species/plot/plot_spectrum.py | 173 ++++++++++++++++++++++++---------- 3 files changed, 149 insertions(+), 61 deletions(-) diff --git a/species/plot/plot_color.py b/species/plot/plot_color.py index 900f93a0..e5e54f9e 100644 --- a/species/plot/plot_color.py +++ b/species/plot/plot_color.py @@ -21,12 +21,6 @@ from species.util import plot_util -mpl.rcParams['font.serif'] = ['Bitstream Vera Serif'] -mpl.rcParams['font.family'] = 'serif' - -plt.rc('axes', edgecolor='black', linewidth=2.2) - - @typechecked def plot_color_magnitude(boxes: list, objects: Optional[Union[List[Tuple[str, str, str, str]], @@ -107,6 +101,11 @@ def plot_color_magnitude(boxes: list, """ + mpl.rcParams['font.serif'] = ['Bitstream Vera Serif'] + mpl.rcParams['font.family'] = 'serif' + + plt.rc('axes', edgecolor='black', linewidth=2.2) + model_color = ('#234398', '#f6a432', 'black') model_linestyle = ('-', '--', ':', '-.') @@ -559,6 +558,11 @@ def plot_color_color(boxes: list, None """ + mpl.rcParams['font.serif'] = ['Bitstream Vera Serif'] + mpl.rcParams['font.family'] = 'serif' + + plt.rc('axes', edgecolor='black', linewidth=2.2) + model_color = ('#234398', '#f6a432', 'black') model_linestyle = ('-', '--', ':', '-.') diff --git a/species/plot/plot_mcmc.py b/species/plot/plot_mcmc.py index 73180291..a67688e9 100644 --- a/species/plot/plot_mcmc.py +++ b/species/plot/plot_mcmc.py @@ -19,12 +19,6 @@ from species.util import plot_util -mpl.rcParams['font.serif'] = ['Bitstream Vera Serif'] -mpl.rcParams['font.family'] = 'serif' - -plt.rc('axes', edgecolor='black', linewidth=2.2) - - @typechecked def plot_walkers(tag: str, nsteps: Optional[int] = None, @@ -52,6 +46,11 @@ def plot_walkers(tag: str, print(f'Plotting walkers: {output}...', end='', flush=True) + mpl.rcParams['font.serif'] = ['Bitstream Vera Serif'] + mpl.rcParams['font.family'] = 'serif' + + plt.rc('axes', edgecolor='black', linewidth=2.2) + species_db = database.Database() box = species_db.get_samples(tag) @@ -154,6 +153,11 @@ def plot_posterior(tag: str, None """ + mpl.rcParams['font.serif'] = ['Bitstream Vera Serif'] + mpl.rcParams['font.family'] = 'serif' + + plt.rc('axes', edgecolor='black', linewidth=2.2) + if burnin is None: burnin = 0 @@ -303,6 +307,11 @@ def plot_photometry(tag, None """ + mpl.rcParams['font.serif'] = ['Bitstream Vera Serif'] + mpl.rcParams['font.family'] = 'serif' + + plt.rc('axes', edgecolor='black', linewidth=2.2) + species_db = database.Database() samples = species_db.get_mcmc_photometry(tag, burnin, filter_id) diff --git a/species/plot/plot_spectrum.py b/species/plot/plot_spectrum.py index d34d1084..ec4456e5 100644 --- a/species/plot/plot_spectrum.py +++ b/species/plot/plot_spectrum.py @@ -4,39 +4,39 @@ import os import math +import warnings import itertools +from typing import Optional, Union, Tuple, List + import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt +from typeguard import typechecked + from species.core import box, constants from species.read import read_filter from species.util import plot_util -mpl.rcParams['font.serif'] = ['Bitstream Vera Serif'] -mpl.rcParams['font.family'] = 'serif' - -plt.rc('axes', edgecolor='black', linewidth=2.2) -plt.rcParams['axes.axisbelow'] = False - - -def plot_spectrum(boxes, - filters=None, - residuals=None, - plot_kwargs=None, - xlim=None, - ylim=None, - ylim_res=None, - scale=('linear', 'linear'), - title=None, - offset=None, - legend=None, - figsize=(7., 5.), - object_type='planet', - quantity='flux', - output='spectrum.pdf'): +@typechecked +def plot_spectrum(boxes: list, + filters: Optional[List[str]] = None, + residuals: Optional[box.ResidualsBox] = None, + plot_kwargs: Optional[List[Optional[dict]]] = None, + xlim: Optional[Tuple[float, float]] = None, + ylim: Optional[Tuple[float, float]] = None, + ylim_res: Optional[Tuple[float, float]] = None, + scale: Optional[Tuple[str, str]] = None, + title: Optional[str] = None, + offset: Optional[Tuple[float, float]] = None, + legend: Union[str, dict, Tuple[float, float], + List[Optional[Union[dict, str, Tuple[float, float]]]]] = None, + figsize: Optional[Tuple[float, float]] = (7., 5.), + object_type: str = 'planet', + quantity: str = 'flux', + output: str = 'spectrum.pdf'): """ Parameters ---------- @@ -75,15 +75,16 @@ def plot_spectrum(boxes, ylim_res : tuple(float, float), None Limits of the residuals axis. Automatically chosen (based on the minimum and maximum residual value) if set to None. - scale : tuple(str, str) - Scale of the axes ('linear' or 'log'). + scale : tuple(str, str), None + Scale of the x and y axes ('linear' or 'log'). The scale is set to ``('linear', 'linear')`` + if set to ``None``. title : str Title. offset : tuple(float, float) Offset for the label of the x- and y-axis. - legend : str, tuple, dict, None + legend : str, tuple, dict, list(dict, dict), None Location of the legend (str, tuple) or a dictionary with the ``**kwargs`` of - ``matplotlib.pyplot.legend``, e.g. ``{'loc': 'upper left', 'fontsize: 12.}``. + ``matplotlib.pyplot.legend``, for example ``{'loc': 'upper left', 'fontsize: 12.}``. figsize : tuple(float, float) Figure size. object_type : str @@ -100,6 +101,12 @@ def plot_spectrum(boxes, None """ + mpl.rcParams['font.serif'] = ['Bitstream Vera Serif'] + mpl.rcParams['font.family'] = 'serif' + + plt.rc('axes', edgecolor='black', linewidth=2.2) + plt.rcParams['axes.axisbelow'] = False + if plot_kwargs is None: plot_kwargs = [] @@ -190,7 +197,7 @@ def plot_spectrum(boxes, ax2.set_ylabel('Transmission', fontsize=13) if residuals is not None: - ax3.set_ylabel(r'Residual ($\sigma$)', fontsize=13) + ax3.set_ylabel(r'$\Delta$$F_\lambda$ ($\sigma$)', fontsize=13) if xlim is not None: ax1.set_xlim(xlim[0], xlim[1]) @@ -213,7 +220,7 @@ def plot_spectrum(boxes, exponent = math.floor(math.log10(ylim[1])) scaling = 10.**exponent - ylabel = r'Flux (10$^{'+str(exponent)+r'}$ W m$^{-2}$ $\mu$m$^{-1}$)' + ylabel = r'$F_\lambda$ (10$^{'+str(exponent)+r'}$ W m$^{-2}$ $\mu$m$^{-1}$)' ax1.set_ylabel(ylabel, fontsize=13) ax1.set_ylim(ylim[0]/scaling, ylim[1]/scaling) @@ -222,7 +229,7 @@ def plot_spectrum(boxes, ax1.axhline(0.0, linestyle='--', color='gray', dashes=(2, 4), zorder=0.5) else: - ax1.set_ylabel(r'Flux (W m$^{-2}$ $\mu$m$^{-1}$)', fontsize=13) + ax1.set_ylabel(r'$F_\lambda$ (W m$^{-2}$ $\mu$m$^{-1}$)', fontsize=13) scaling = 1. if filters is not None: @@ -263,6 +270,9 @@ def plot_spectrum(boxes, ax1.get_xaxis().set_label_coords(0.5, -0.12) ax1.get_yaxis().set_label_coords(-0.1, 0.5) + if scale is None: + scale = ('linear', 'linear') + ax1.set_xscale(scale[0]) ax1.set_yscale(scale[1]) @@ -346,7 +356,7 @@ def plot_spectrum(boxes, label = kwargs_copy['label'] del kwargs_copy['label'] - + ax1.plot(wavelength, masked/scaling, zorder=2, label=label, **kwargs_copy) else: @@ -391,8 +401,52 @@ def plot_spectrum(boxes, zorder=3) elif isinstance(boxitem, box.ObjectBox): + if boxitem.spectrum is not None: + spec_list = [] + wavel_list = [] + + for item in boxitem.spectrum: + spec_list.append(item) + wavel_list.append(boxitem.spectrum[item][0][0, 0]) + + sort_index = np.argsort(wavel_list) + spec_sort = [] + + for i in range(sort_index.size): + spec_sort.append(spec_list[sort_index[i]]) + + for key in spec_sort: + masked = np.ma.array(boxitem.spectrum[key][0], + mask=np.isnan(boxitem.spectrum[key][0])) + + if not plot_kwargs[j] or key not in plot_kwargs[j]: + plot_obj = ax1.errorbar(masked[:, 0], masked[:, 1]/scaling, + yerr=masked[:, 2]/scaling, ms=2, marker='s', + zorder=2.5, ls='none') + + plot_kwargs[j][key] = {'marker': 's', 'ms': 2., 'ls': 'none', + 'color': plot_obj[0].get_color()} + + else: + ax1.errorbar(masked[:, 0], masked[:, 1]/scaling, yerr=masked[:, 2]/scaling, + zorder=2.5, **plot_kwargs[j][key]) + if boxitem.flux is not None: + filter_list = [] + wavel_list = [] + for item in boxitem.flux: + read_filt = read_filter.ReadFilter(item) + filter_list.append(item) + wavel_list.append(read_filt.mean_wavelength()) + + sort_index = np.argsort(wavel_list) + filter_sort = [] + + for i in range(sort_index.size): + filter_sort.append(filter_list[sort_index[i]]) + + for item in filter_sort: transmission = read_filter.ReadFilter(item) wavelength = transmission.mean_wavelength() fwhm = transmission.filter_fwhm() @@ -430,23 +484,6 @@ def plot_spectrum(boxes, ax1.errorbar(wavelength, boxitem.flux[item][0]/scaling, xerr=fwhm/2., yerr=boxitem.flux[item][1]/scaling, zorder=3, **plot_kwargs[j][item]) - if boxitem.spectrum is not None: - for key, value in boxitem.spectrum.items(): - masked = np.ma.array(boxitem.spectrum[key][0], - mask=np.isnan(boxitem.spectrum[key][0])) - - if not plot_kwargs[j] or key not in plot_kwargs[j]: - plot_obj = ax1.errorbar(masked[:, 0], masked[:, 1]/scaling, - yerr=masked[:, 2]/scaling, ms=2, marker='s', - zorder=2.5, ls='none') - - plot_kwargs[j][key] = {'marker': 's', 'ms': 2., 'ls': 'none', - 'color': plot_obj[0].get_color()} - - else: - ax1.errorbar(masked[:, 0], masked[:, 1]/scaling, yerr=masked[:, 2]/scaling, - zorder=2.5, **plot_kwargs[j][key]) - elif isinstance(boxitem, box.SynphotBox): for i, find_item in enumerate(boxes): if isinstance(find_item, box.ObjectBox): @@ -558,11 +595,49 @@ def plot_spectrum(boxes, else: ax1.set_title(title, y=1.02, fontsize=15) - handles, _ = ax1.get_legend_handles_labels() + handles, labels = ax1.get_legend_handles_labels() if handles and legend is not None: - if isinstance(legend, (str, tuple)): + if isinstance(legend, list): + model_handles = [] + data_handles = [] + + model_labels = [] + data_labels = [] + + for i, item in enumerate(handles): + if isinstance(item, mpl.lines.Line2D): + model_handles.append(item) + model_labels.append(labels[i]) + + elif isinstance(item, mpl.container.ErrorbarContainer): + data_handles.append(item) + data_labels.append(labels[i]) + + else: + warnings.warn(f'The object type {item} is not implemented for the legend.') + + if legend[0] is not None: + if isinstance(legend[0], (str, tuple)): + leg_1 = ax1.legend(model_handles, model_labels, loc=legend[0], fontsize=10., frameon=False) + else: + leg_1 = ax1.legend(model_handles, model_labels, **legend[0]) + + else: + leg_1 = None + + if legend[1] is not None: + if isinstance(legend[1], (str, tuple)): + leg_2 = ax1.legend(data_handles, data_labels, loc=legend[1], fontsize=8, frameon=False) + else: + leg_2 = ax1.legend(data_handles, data_labels, **legend[1]) + + if leg_1 is not None: + ax1.add_artist(leg_1) + + elif isinstance(legend, (str, tuple)): ax1.legend(loc=legend, fontsize=8, frameon=False) + else: ax1.legend(**legend)