Skip to content

Commit

Permalink
Merge pull request #318 from sblunt/rv_plot_fix2
Browse files Browse the repository at this point in the history
Fix multiple plotting issues
  • Loading branch information
sblunt committed Apr 22, 2022
2 parents 8f56664 + 2b70168 commit efb638a
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 253 deletions.
274 changes: 54 additions & 220 deletions docs/tutorials/Plotting_tutorial.ipynb

Large diffs are not rendered by default.

124 changes: 93 additions & 31 deletions orbitize/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def plot_orbits(results, object_to_plot=1, start_mjd=51544.,
square_plot=True, show_colorbar=True, cmap=cmap,
sep_pa_color='lightgrey', sep_pa_end_year=2025.0,
cbar_param='Epoch [year]', mod180=False, rv_time_series=False, plot_astrometry=True,
plot_astrometry_insts=False, fig=None):
plot_astrometry_insts=False, plot_errorbars=True, fig=None):
"""
Plots one orbital period for a select number of fitted orbits
for a given object, with line segments colored according to time
Expand Down Expand Up @@ -179,6 +179,7 @@ def plot_orbits(results, object_to_plot=1, start_mjd=51544.,
display time series, set to True.
plot_astrometry (Boolean): set to True by default. Plots the astrometric data.
plot_astrometry_insts (Boolean): set to False by default. Plots the astrometric data by instruments.
plot_errorbars (Boolean): set to True by default. Plots error bars of measurements
fig (matplotlib.pyplot.Figure): optionally include a predefined Figure object to plot the orbit on.
Most users will not need this keyword.
Expand Down Expand Up @@ -336,8 +337,11 @@ def plot_orbits(results, object_to_plot=1, start_mjd=51544.,
radec_inds = np.where(data['quant_type'] == 'radec')
seppa_inds = np.where(data['quant_type'] == 'seppa')

sep_data, sep_err=data['quant1'][seppa_inds],data['quant1_err'][seppa_inds]
pa_data, pa_err=data['quant2'][seppa_inds],data['quant2_err'][seppa_inds]
# transform RA/Dec points to Sep/PA
sep_data = np.copy(data['quant1'])
sep_err = np.copy(data['quant1_err'])
pa_data = np.copy(data['quant2'])
pa_err = np.copy(data['quant2_err'])

if len(radec_inds[0] > 0):

Expand All @@ -356,16 +360,45 @@ def plot_orbits(results, object_to_plot=1, start_mjd=51544.,
np.array(data['quant12_corr'][radec_inds][j]), orbitize.system.radec2seppa
)

sep_data = np.append(sep_data, sep_from_ra_data)
sep_err = np.append(sep_err, sep_err_from_ra_data)
sep_data[radec_inds] = sep_from_ra_data
sep_err[radec_inds] = sep_err_from_ra_data

pa_data = np.append(pa_data, pa_from_dec_data)
pa_err = np.append(pa_err, pa_err_from_dec_data)
pa_data[radec_inds] = pa_from_dec_data
pa_err[radec_inds] = pa_err_from_dec_data

# Transform Sep/PA points to RA/Dec
ra_data = np.copy(data['quant1'])
ra_err = np.copy(data['quant1_err'])
dec_data = np.copy(data['quant2'])
dec_err = np.copy(data['quant2_err'])

if len(seppa_inds[0] > 0):

ra_from_seppa_data, dec_from_seppa_data = orbitize.system.seppa2radec(
data['quant1'][seppa_inds], data['quant2'][seppa_inds]
)

num_seppa_pts = len(seppa_inds[0])
ra_err_from_seppa_data = np.empty(num_seppa_pts)
dec_err_from_seppa_data = np.empty(num_seppa_pts)
for j in np.arange(num_seppa_pts):

ra_err_from_seppa_data[j], dec_err_from_seppa_data[j], _ = orbitize.system.transform_errors(
np.array(data['quant1'][seppa_inds][j]), np.array(data['quant2'][seppa_inds][j]),
np.array(data['quant1_err'][seppa_inds][j]), np.array(data['quant2_err'][seppa_inds][j]),
np.array(data['quant12_corr'][seppa_inds][j]), orbitize.system.seppa2radec
)

ra_data[seppa_inds] = ra_from_seppa_data
ra_err[seppa_inds] = ra_err_from_seppa_data

dec_data[seppa_inds] = dec_from_seppa_data
dec_err[seppa_inds] = dec_err_from_seppa_data

# For plotting different astrometry instruments
if plot_astrometry_insts:
astr_colors = ('#FF7F11', '#11FFE3', '#14FF11', '#7A11FF', '#FF1919')
astr_symbols = ('*', 'o', 'p', 's')
astr_colors = ('purple','#FF7F11', '#11FFE3', '#14FF11', '#7A11FF', '#FF1919')
astr_symbols = ( 'o', '*', 'p', 's')

ax_colors = itertools.cycle(astr_colors)
ax_symbols = itertools.cycle(astr_symbols)
Expand All @@ -391,17 +424,30 @@ def plot_orbits(results, object_to_plot=1, start_mjd=51544.,
lc.set_array(epochs[i, :])
ax.add_collection(lc)

if plot_astrometry:
ra_data,dec_data=orbitize.system.seppa2radec(sep_data,pa_data)

# Plot astrometry along with instruments
if plot_astrometry_insts:
for i in range(len(astr_insts)):
ra = ra_data[astr_inst_inds[astr_insts[i]]]
dec = dec_data[astr_inst_inds[astr_insts[i]]]
ax.scatter(ra, dec, marker=next(ax_symbols), c=next(ax_colors), zorder=10, s=60, label=astr_insts[i])
else:
ax.scatter(ra_data, dec_data, marker='*', c='#FF7F11', zorder=10, s=60)
# if plot_astrometry:

# # Plot astrometry along with instruments
# if plot_astrometry_insts:
# for i in range(len(astr_insts)):
# ra = ra_data[astr_inst_inds[astr_insts[i]]]
# dec = dec_data[astr_inst_inds[astr_insts[i]]]
# if plot_errorbars:
# xerr = ra_err[astr_inst_inds[astr_insts[i]]]
# yerr = dec_err[astr_inst_inds[astr_insts[i]]]
# else:
# xerr = None
# yerr = None

# ax.errorbar(ra, dec, xerr=xerr, yerr=yerr, marker=next(ax_symbols), c=next(ax_colors), zorder=10, label=astr_insts[i], linestyle='', ms=5, capsize=2)
# else:
# if plot_errorbars:
# xerr = ra_err
# yerr = dec_err
# else:
# xerr = None
# yerr = None

# ax.errorbar(ra_data, dec_data, xerr=xerr, yerr=yerr, marker='o', c='#FF7F11', zorder=10, linestyle='', capsize=2, ms=5)

# modify the axes
if square_plot:
Expand Down Expand Up @@ -474,28 +520,42 @@ def plot_orbits(results, object_to_plot=1, start_mjd=51544.,
seps, pas = orbitize.system.radec2seppa(raoff[i, :], deoff[i, :], mod180=mod180)

plt.sca(ax1)
plt.plot(yr_epochs, seps, color=sep_pa_color)
plt.plot(yr_epochs, seps, color=sep_pa_color, zorder=1)

plt.sca(ax2)
plt.plot(yr_epochs, pas, color=sep_pa_color)
plt.plot(yr_epochs, pas, color=sep_pa_color, zorder=1)

# Plot sep/pa instruments
if plot_astrometry_insts:
for i in range(len(astr_insts)):
sep = sep_data[astr_inst_inds[astr_insts[i]]]
pa = pa_data[astr_inst_inds[astr_insts[i]]]
epochs = astr_epochs[astr_inst_inds[astr_insts[i]]]
if plot_errorbars:
serr = sep_err[astr_inst_inds[astr_insts[i]]]
perr = pa_err[astr_inst_inds[astr_insts[i]]]
else:
yerr = None
perr = None

plt.sca(ax1)
plt.scatter(Time(epochs,format='mjd').decimalyear,sep,s=10,marker=next(ax1_symbols),c=next(ax1_colors),zorder=10,label=astr_insts[i])
plt.errorbar(Time(epochs,format='mjd').decimalyear,sep,yerr=serr,ms=5, linestyle='',marker=next(ax1_symbols),c=next(ax1_colors),zorder=10,label=astr_insts[i], capsize=2)
plt.sca(ax2)
plt.scatter(Time(epochs,format='mjd').decimalyear,pa,s=10,marker=next(ax2_symbols),c=next(ax2_colors),zorder=10)
plt.errorbar(Time(epochs,format='mjd').decimalyear,pa,yerr=perr,ms=5, linestyle='',marker=next(ax2_symbols),c=next(ax2_colors),zorder=10, capsize=2)
plt.sca(ax1)
plt.legend(title='Instruments', bbox_to_anchor=(1.3, 1), loc='upper right')
else:
if plot_errorbars:
serr = sep_err
perr = pa_err
else:
yerr = None
perr = None

plt.sca(ax1)
plt.scatter(Time(astr_epochs,format='mjd').decimalyear,sep_data,s=10,marker='*',c='purple',zorder=10)
plt.errorbar(Time(astr_epochs,format='mjd').decimalyear,sep_data,yerr=serr,ms=5, linestyle='',marker='o',c='purple',zorder=2, capsize=2)
plt.sca(ax2)
plt.scatter(Time(astr_epochs,format='mjd').decimalyear,pa_data,s=10,marker='*',c='purple',zorder=10)
plt.errorbar(Time(astr_epochs,format='mjd').decimalyear,pa_data,yerr=perr,ms=5, linestyle='',marker='o',c='purple',zorder=2, capsize=2)

if rv_time_series:

Expand All @@ -512,6 +572,7 @@ def plot_orbits(results, object_to_plot=1, start_mjd=51544.,

# get gamma/sigma labels and corresponding positions in the posterior
gams=['gamma_'+inst for inst in insts]
sigs = ['sigma_'+inst for inst in insts]

if isinstance(results.labels,list):
labels=np.array(results.labels)
Expand All @@ -528,7 +589,6 @@ def plot_orbits(results, object_to_plot=1, start_mjd=51544.,

# choose the orbit with the best log probability
best_like=np.where(results.lnlike==np.amax(results.lnlike))[0][0]
med_ga=[results.post[best_like,i] for i in gam_idx]

# Get the posteriors for this index and convert to standard basis
best_post = results.basis.to_standard_basis(results.post[best_like].copy())
Expand All @@ -539,7 +599,7 @@ def plot_orbits(results, object_to_plot=1, start_mjd=51544.,
best_mtot = best_m0 + best_m1

# colour/shape scheme scheme for rv data points
clrs=('#0496FF','#372554','#FF1053','#3A7CA5','#143109')
clrs=('purple', '#0496FF','#372554','#FF1053','#3A7CA5','#143109')
symbols=('o','^','v','s')

ax3_colors = itertools.cycle(clrs)
Expand All @@ -551,9 +611,11 @@ def plot_orbits(results, object_to_plot=1, start_mjd=51544.,
rvs=inst_data['quant1']
epochs=inst_data['epoch']
epochs=Time(epochs, format='mjd').decimalyear
rvs-=med_ga[i]
rvs -= best_post[results.param_idx[gams[i]]]
plt.scatter(epochs,rvs,s=5,marker=next(ax3_symbols),c=next(ax3_colors),label=name,zorder=5)
if plot_errorbars:
yerr = inst_data['quant1_err']
yerr = np.sqrt(yerr**2 + best_post[results.param_idx[sigs[i]]]**2)
plt.errorbar(epochs,rvs,yerr=yerr,ms=5, linestyle='',marker=next(ax3_symbols),c=next(ax3_colors),label=name,zorder=5,capsize=2)
if len(inds.keys()) == 1 and 'defrv' in inds.keys():
pass
else:
Expand All @@ -575,7 +637,7 @@ def plot_orbits(results, object_to_plot=1, start_mjd=51544.,
vz=vz*-(best_m1)/np.median(best_m0)

# plot rv trend
plt.plot(Time(epochs_seppa[0, :],format='mjd').decimalyear, vz, color=sep_pa_color)
plt.plot(Time(epochs_seppa[0, :],format='mjd').decimalyear, vz, color=sep_pa_color, zorder=1)


# add colorbar
Expand Down
6 changes: 4 additions & 2 deletions orbitize/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ def plot_orbits(self, object_to_plot=1, start_mjd=51544.,
sep_pa_color='lightgrey', sep_pa_end_year=2025.0,
cbar_param='Epoch [year]', mod180=False, rv_time_series=False,
plot_astrometry=True,
plot_astrometry_insts=False, fig=None
plot_astrometry_insts=False,
plot_errorbars=True, fig=None
):
"""
Wrapper for orbitize.plot.plot_orbits
Expand All @@ -327,7 +328,8 @@ def plot_orbits(self, object_to_plot=1, start_mjd=51544.,
sep_pa_color=sep_pa_color, sep_pa_end_year=sep_pa_end_year,
cbar_param=cbar_param, mod180=mod180, rv_time_series=rv_time_series,
plot_astrometry=plot_astrometry,
plot_astrometry_insts=plot_astrometry_insts, fig=fig
plot_astrometry_insts=plot_astrometry_insts,
plot_errorbars=plot_errorbars, fig=fig
)


Expand Down

0 comments on commit efb638a

Please sign in to comment.