Skip to content

Commit

Permalink
Merge pull request #276 from sblunt/fix-241
Browse files Browse the repository at this point in the history
fix issue #241 and add a unit test
  • Loading branch information
sblunt committed Sep 4, 2021
2 parents c953cf7 + 9140745 commit 01cf83b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
18 changes: 12 additions & 6 deletions orbitize/results.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import numpy as np
import warnings
import h5py
import copy
import itertools

import astropy.units as u
import astropy.constants as consts
from astropy.io import fits
from astropy.time import Time
import astropy.table as table
from erfa import ErfaWarning
Expand All @@ -15,7 +13,6 @@
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import matplotlib.colors as colors
import pandas as pd

import corner

Expand Down Expand Up @@ -561,6 +558,7 @@ def plot_orbits(self, object_to_plot=1, start_mjd=51544.,
# Compute period (from Kepler's third law)
period = np.sqrt(4*np.pi**2.0*(sma*u.AU)**3/(consts.G*(mtot*u.Msun)))
period = period.to(u.day).value

# Create an epochs array to plot num_epochs_to_plot points over one orbital period
epochs[i, :] = np.linspace(start_mjd, float(
start_mjd+period[i]), num_epochs_to_plot)
Expand All @@ -583,11 +581,19 @@ def plot_orbits(self, object_to_plot=1, start_mjd=51544.,
cbar_param_arr), vmax=np.max(cbar_param_arr))

elif cbar_param == 'Epoch [year]':
norm = mpl.colors.Normalize(vmin=np.min(epochs), vmax=np.max(epochs[-1, :]))

min_cbar_date = np.min(epochs)
max_cbar_date = np.max(epochs[-1, :])

# if we're plotting orbital periods greater than 1,000 yrs, limit the colorbar dynamic range
if max_cbar_date - min_cbar_date > 1000 * 365.25:
max_cbar_date = min_cbar_date + 1000 * 365.25

norm = mpl.colors.Normalize(vmin=min_cbar_date, vmax=max_cbar_date)

norm_yr = mpl.colors.Normalize(
vmin=np.min(Time(epochs, format='mjd').decimalyear),
vmax=np.max(Time(epochs, format='mjd').decimalyear)
vmin=Time(min_cbar_date, format='mjd').decimalyear,
vmax=Time(max_cbar_date, format='mjd').decimalyear
)

# Before starting to plot rv data, make sure rv data exists:
Expand Down
40 changes: 24 additions & 16 deletions tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ def results_to_test():
# Return object for testing
return results_obj

def test_plot_long_periods(results_to_test):

# make all orbits in the results posterior have absurdly long orbits
mtot_idx = results_to_test.param_idx['mtot']
results_to_test.post[:,mtot_idx] = 1e-10

results_to_test.plot_orbits()

def test_save_and_load_results(results_to_test, has_lnlike=True):
"""
Expand Down Expand Up @@ -204,23 +211,24 @@ def test_plot_orbits(results_to_test):

if __name__ == "__main__":
test_results = test_init_and_add_samples()
test_results_radec = test_init_and_add_samples(radec_input=True)
test_plot_long_periods(test_results)
# test_results_radec = test_init_and_add_samples(radec_input=True)

test_save_and_load_results(test_results, has_lnlike=True)
test_save_and_load_results(test_results, has_lnlike=True)
test_save_and_load_results(test_results, has_lnlike=False)
test_save_and_load_results(test_results, has_lnlike=False)
test_corner_fig1, test_corner_fig2, test_corner_fig3 = test_plot_corner(test_results)
test_orbit_figs = test_plot_orbits(test_results)
test_orbit_figs = test_plot_orbits(test_results_radec)
test_corner_fig1.savefig('test_corner1.png');
test_corner_fig2.savefig('test_corner2.png')
test_corner_fig3.savefig('test_corner3.png')
test_orbit_figs[0].savefig('test_orbit1.png')
test_orbit_figs[1].savefig('test_orbit2.png')
test_orbit_figs[2].savefig('test_orbit3.png')
test_orbit_figs[3].savefig('test_orbit4.png')
test_orbit_figs[4].savefig('test_orbit5.png')
# test_save_and_load_results(test_results, has_lnlike=True)
# test_save_and_load_results(test_results, has_lnlike=True)
# test_save_and_load_results(test_results, has_lnlike=False)
# test_save_and_load_results(test_results, has_lnlike=False)
# test_corner_fig1, test_corner_fig2, test_corner_fig3 = test_plot_corner(test_results)
# test_orbit_figs = test_plot_orbits(test_results)
# test_orbit_figs = test_plot_orbits(test_results_radec)
# test_corner_fig1.savefig('test_corner1.png');
# test_corner_fig2.savefig('test_corner2.png')
# test_corner_fig3.savefig('test_corner3.png')
# test_orbit_figs[0].savefig('test_orbit1.png')
# test_orbit_figs[1].savefig('test_orbit2.png')
# test_orbit_figs[2].savefig('test_orbit3.png')
# test_orbit_figs[3].savefig('test_orbit4.png')
# test_orbit_figs[4].savefig('test_orbit5.png')

# clean up
os.system('rm test_*.png')

0 comments on commit 01cf83b

Please sign in to comment.