In [365]:
#%matplotlib notebook
%matplotlib widget
from IPython.display import display
import exoplanet as xo
import aesara_theano_fallback.tensor as tt
import theano
from celerite2.theano import terms, GaussianProcess
import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt

In [366]:
def ready_data(y, yerr, data_type='lc'):
    truth = y - yerr
    yerr = np.abs(yerr)

    if data_type == 'lc':
        mu = np.mean(y)
        y = (y / mu - 1) * 1e3
        yerr = yerr * 1e3 / mu
        truth = (truth / mu - 1) * 1e3

    return y, yerr, truth

In [428]:
# test data

#var_path = "Spots/Orientation 1/Simulated Data/Decay Rate/tau_0.050"
var_path = "example data/tau_0.050"

t = np.load(var_path + "_t.npy")
flux = np.load(var_path + "_f.npy")
flux_err = np.load(var_path + "_ferr.npy")
rv = np.load(var_path + "_rv.npy")
rv_err = np.load(var_path + "_rverr.npy")

flux, flux_err, f_truth = ready_data(flux, flux_err)
rv, rv_err, rv_truth = ready_data(rv, rv_err)
tau = t - t[0]

In [368]:
mean = tt.dscalar('mean')
jitter = tt.dscalar('jitter')
sigma = tt.dscalar('sigma')
rho = tt.dscalar('rho')
sigma_rot = tt.dscalar('sigma_rot')
period = tt.dscalar('period')
Q0 = tt.dscalar('Q0')
dQ = tt.dscalar('dQ')
f = tt.dscalar('f')

In [369]:
#def run_gp(p, x=t, y=flux, yerr=flux_err):
def run_gp(mean, jitter, sigma, rho, sigma_rot, period, Q0, dQ, f, x=t, y=flux, yerr=flux_err):    
    #mean, log_jitter, sigma, rho, sigma_rot, log_period, log_Q0, log_dQ, f = p

    kernel = terms.SHOTerm(sigma=sigma, rho=rho, Q=1.0/3.0)
    kernel += terms.RotationTerm(
        sigma=sigma_rot,
        period=period,
        Q0=Q0,
        dQ=dQ,
        f=f,)

    gp = GaussianProcess(
        kernel,
        t=x,
        diag=yerr**2 + jitter**2,
        mean=mean,
        quiet=True,)
    
    print()
    
    shot_psd = gp.kernel.terms[0].get_psd(omega)
    rotation_psd = gp.kernel.terms[1].get_psd(omega)
    full_psd = gp.kernel.get_psd(omega)

    mu = gp.predict(y, return_var=False)

    k_val = kernel.get_value(x - x[0])
    k_val /= k_val[0]

    return mu, k_val, shot_psd, rotation_psd, full_psd

In [372]:
gp_params = [mean, jitter, sigma, rho, sigma_rot, period, Q0, dQ, f]
#gp_func = theano.function(gp_params, run_gp(gp_params))

gp_func = theano.function(gp_params, run_gp(mean, jitter, sigma, rho, sigma_rot, period, Q0, dQ, f))

#k_params = [np.average(flux), np.std(flux), 1.5, 1.0, 0.1, 11.0, 5.0, 3.0, 0.5]
mu, k_val, shot_psd, rotation_psd, full_psd = gp_func(np.average(flux), np.average(flux_err), 1.5, 1.0, 0.1, 11.0, 5.0, 3.0, 0.5)

freq = np.linspace(1.0 / 20.0, 1.0 / 0.25, 1000)
omega = 2.0 * np.pi * freq
tau = t - t[0]
acor = xo.estimators.autocorr_function(flux)




In [427]:
lc_c = '#5198E6'
rv_c = '#F58D2F'

fig = plt.figure(figsize=(14, 6))
gs = fig.add_gridspec(nrows=2, ncols=4, left=0.075, right=0.98, wspace=0.375, hspace=0.35)
ax = fig.add_subplot(gs[:-1, :])
ax1 = fig.add_subplot(gs[-1, :2])
ax2 = fig.add_subplot(gs[-1, 2:])

ax1.plot(tau, acor, color='black', lw=2.0, ls='-', label='ACF')
line_kernel, = ax1.plot(tau, k_val, color=lc_c, lw=2.5, ls='--', label='GP Kernel')

psd1, = ax2.loglog(freq, shot_psd, color='purple', label=r'SHOT Term')
psd2, = ax2.loglog(freq, rotation_psd, color='gold', label=r'Rotation Term')
psd3, = ax2.loglog(freq, full_psd, color='black', lw=2.0, ls='dotted', label=r'Full Model')

ax.scatter(t, flux, c='black', marker='o', s=10.0, label='Data', zorder=-1)
ax.plot(t, f_truth, color='orange', alpha=0.6, lw=1.5, label='Truth', zorder=10)
line, = ax.plot(t, mu, color=lc_c,
                alpha=0.85, zorder=6, lw=3.5,
                label=r'GP')

def update(avg=np.average(flux), jitter=np.average(flux_err), sigma=3.0, rho=0.5, sigma_rot=0.5, period=10.0, Q0=1.0, dQ=1.0, f=0.5):
    new_mu, new_k_val, new_psd1, new_psd2, new_psd3 = gp_func(avg, jitter, sigma, rho, sigma_rot, period, Q0, dQ, f)
    line.set_ydata(new_mu)
    line_kernel.set_ydata(new_k_val)
    psd1.set_ydata(new_psd1)
    psd2.set_ydata(new_psd2)
    psd3.set_ydata(new_psd3)
    fig.canvas.draw_idle()
    
avg_widget = widgets.FloatSlider(description=r"$\mu$", min=np.min(flux), max=np.max(flux), step=0.05*np.ptp(flux), value=np.average(flux))
jitter_widget = widgets.FloatSlider(description=r"jitter", min=0.01, max=5.0, step=0.1, value=np.average(flux_err))
sigma_widget = widgets.FloatSlider(description=r"$\sigma$", min=0.01, max=2.5, step=0.01, value=1.0)
rho_widget = widgets.FloatSlider(description=r"$\rho$", min=0.01, max=50.0, step=0.01, value=0.5)
sigma_rot_widget = widgets.FloatSlider(description=r"$\sigma_\mathrm{rot}$", min=0.0, max=2.5, step=0.01, value=0.5)
period_widget = widgets.FloatSlider(description=r"$P$", min=0.01, max=20.0, step=0.1, value=10.0)
Q0_widget = widgets.FloatSlider(description=r"$Q_0$", min=0.01, max=25.0, step=0.01, value=1.0)
dQ_widget = widgets.FloatSlider(description=r"$dQ$", min=0.01, max=50.0, step=0.01, value=1.0)
f_widget = widgets.FloatSlider(description=r"$f$", min=0.01, max=2.5, step=0.01, value=0.5)

# old way, list all sliders but vertically stacked
#widgets.interact(update, avg=avg_widget, jitter=jitter_widget, sigma=sigma_widget, rho=rho_widget, sigma_rot=sigma_rot_widget,
#                 period=period_widget, Q0=Q0_widget, dQ=dQ_widget, f=f_widget);

all_params = {'avg':avg_widget, 'jitter':jitter_widget, 'sigma':sigma_widget, 'rho':rho_widget,
              'sigma_rot':sigma_rot_widget, 'period':period_widget, 'Q0':Q0_widget, 'dQ':dQ_widget, 'f':f_widget}
out = widgets.interactive_output(update, all_params)

display(widgets.HBox([avg_widget, jitter_widget]), out)
display(widgets.HBox([sigma_widget, rho_widget]), out)
display(widgets.HBox([sigma_rot_widget, period_widget, Q0_widget, dQ_widget, f_widget]), out)

ax.set_xlim([np.min(t), np.max(t)])
ax.tick_params(axis='both', which='major', labelsize=15)
ax.set_xlabel(r"Time (days)", fontsize=19)
ax.set_ylabel(r"Relative Flux (ppt)", fontsize=19)
ax.legend(fontsize=14, markerscale=2.)

ax1.set_xlim([t.min(), t.max()])
ax1.set_ylim(top=1.0)
ax1.tick_params(axis='both', which='major', labelsize=13)
ax1.set_ylabel(r"ACF or GP Kernel (ppt$^2$)", fontsize=14)
ax1.set_xlabel(r"$\tau$", fontsize=18)
ax1.legend(fontsize=14)

ax2.set_xlim([freq.min(), freq.max()])
ax2.set_ylim(top=50.0)
ax2.tick_params(axis='both', which='major', labelsize=13)
ax2.set_ylabel(r"PSD (day ppt$^2$)", fontsize=14)
ax2.set_xlabel(r"Frequency (day$^{-1}$)", fontsize=18)
ax2.legend(fontsize=14)

plt.tight_layout()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

HBox(children=(FloatSlider(value=-4.889586677934319e-14, description='$\\mu$', max=4.790708451977821, min=-9.9…

Output()

HBox(children=(FloatSlider(value=1.0, description='$\\sigma$', max=2.5, min=0.01, step=0.01), FloatSlider(valu…

Output()

HBox(children=(FloatSlider(value=0.5, description='$\\sigma_\\mathrm{rot}$', max=2.5, step=0.01), FloatSlider(…

Output()

  plt.tight_layout()


In [423]:
plt.close()