In [None]:
import os
import warnings
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import ascii
import sys
import warnings
import pickle
import argparse
from multiprocess import Pool

# Location of the exotic-ld data directory. An EXO_LD_PATH already set in the
# environment takes precedence; otherwise this machine-specific default is used
# (edit it for your setup). NOTE(review): presumably read by `fit`/exotic-ld,
# so keep this above those imports -- confirm.
os.environ.setdefault('EXO_LD_PATH', '/Users/tylergordon/research/exotic_ld_data')

from fit import fit_joint_wlc, fit_joint_spec
from plot_utils import get_wl_models, get_canonical_params
from distributions import *
import sys_params
import plot_utils

def load_all(base_dir, prefix):
    """Load one visit's reduced data products from ``base_dir``.

    Reads ``<prefix>_times_bjd.npy``, ``<prefix>.npy``,
    ``<prefix>_cleaned_cube.npy`` and ``<prefix>_wav.npy``. ``base_dir`` may
    be given with or without a trailing path separator.

    Parameters
    ----------
    base_dir : str
        Directory containing the ``.npy`` files.
    prefix : str
        Common file-name prefix of the visit's products.

    Returns
    -------
    time : ndarray
        Timestamps shifted so that the first one is 0.
    wavs : ndarray
        Contents of ``<prefix>_wav.npy``.
    spec : ndarray
        Contents of ``<prefix>.npy``.
    cube : ndarray
        Contents of ``<prefix>_cleaned_cube.npy``.
    time_offset : scalar
        The original first timestamp, so absolute time is ``time + time_offset``.

    Raises
    ------
    ValueError
        If the times file is empty (there is no first timestamp to shift by).
    """
    def _load(suffix):
        # os.path.join copes with base_dir with or without a trailing slash.
        return np.load(os.path.join(base_dir, prefix + suffix))

    time = _load('_times_bjd.npy')
    if time.size == 0:
        raise ValueError(f"no timestamps found for prefix {prefix!r} in {base_dir!r}")
    time_offset = time[0]
    time = time - time_offset
    spec = _load('.npy')
    cube = _load('_cleaned_cube.npy')
    wavs = _load('_wav.npy')

    return time, wavs, spec, cube, time_offset

def fit_wrapper(base_dirs, prefixes, detector):
    """Jointly fit the white-light curves of several visits for one detector.

    Each visit is loaded with ``load_all``. All visits are then placed on a
    common time axis measured from the first timestamp of the first visit in
    ``prefixes`` order, sorted chronologically, and passed to
    ``fit_joint_wlc``.

    Parameters
    ----------
    base_dirs : sequence of str
        Data directory of each visit (forwarded to ``load_all``).
    prefixes : sequence of str
        File prefix of each visit; same order and length as ``base_dirs``.
    detector : str
        Detector name (e.g. ``'nrs1'``), repeated once per visit.

    Returns
    -------
    The value returned by ``fit_joint_wlc`` (called with ``return_chains=True``).

    Raises
    ------
    ValueError
        If ``base_dirs`` and ``prefixes`` differ in length or are empty.

    Notes
    -----
    Reads the notebook-level globals ``priors_dict`` and ``st_params_dict``;
    they must be defined before this is called.
    """
    if len(base_dirs) != len(prefixes):
        raise ValueError(
            f"base_dirs ({len(base_dirs)}) and prefixes ({len(prefixes)}) "
            "must have the same length"
        )
    if len(prefixes) == 0:
        raise ValueError("at least one visit is required")

    times, offsets, cubes, specs, wavs = [], [], [], [], []
    for base_dir, prefix in zip(base_dirs, prefixes):
        t, w, s, cube, offset = load_all(base_dir, prefix)
        times.append(t)
        offsets.append(offset)
        cubes.append(cube)
        specs.append(s)
        wavs.append(w)

    # load_all zeroes every visit at its own first timestamp. Shift each visit
    # by its offset relative to the first visit so all share one time axis.
    # Working with offset differences avoids re-adding the large absolute
    # epoch and losing precision.
    reference = offsets[0]
    times = [t + (offset - reference) for t, offset in zip(times, offsets)]

    order = np.argsort([t[0] for t in times])
    times = [times[i] for i in order]
    cubes = [cubes[i] for i in order]
    specs = [specs[i] for i in order]
    wavs = [wavs[i] for i in order]

    joint_results = fit_joint_wlc(
        times,
        specs,
        wavs,
        priors_dict,
        st_params_dict,
        [detector] * len(prefixes),
        cubes=cubes,
        gp=False,
        n_components=8,
        samples=10000,
        burnin=0,
        thin=1,
        nproc=2,
        save_chains=False,
        return_chains=True,
        polyorder=1,
        progress=True
    )

    return joint_results

In [None]:
# Target and visits to fit jointly.
target = '175.01'
visits = ['T1', 'T2']
# NOTE(review): machine-specific absolute path -- adjust for your setup.
basedir = '/Users/tylergordon/research/compass/targets/'

# Stage-3 reduction directory of every visit, and per-detector file prefixes.
basedirs = [f'{basedir}{target}/{visit}/reduction/stage3/' for visit in visits]
prefixes_nrs1 = [f'{target}_{visit}_nrs1' for visit in visits]
prefixes_nrs2 = [f'{target}_{visit}_nrs2' for visit in visits]

# Priors on the planet's parameters for this target (see distributions.py
# for the available prior distributions).
priors_dict = sys_params.priors_dict[target]

# Stellar parameters, used to build priors on the limb-darkening coefficients.
st_params_dict = sys_params.stellar_params_dict[target]

# Joint white-light fit across all visits, one detector at a time.
result_nrs1 = fit_wrapper(basedirs, prefixes_nrs1, 'nrs1')
result_nrs2 = fit_wrapper(basedirs, prefixes_nrs2, 'nrs2')

In [None]:
# Overlay both detectors' white-light posteriors on a single corner plot,
# discarding the first 5000 steps of each chain.
fig = plt.figure(figsize=(12, 12))
for result, shade in ((result_nrs1, 0.7), (result_nrs2, 0.1)):
    plot_utils.plot_corner(result[0], fig=fig, burnin=5000, color=plt.cm.terrain(shade))

In [None]:
def fit_spec_wrapper(result):
    """Run ``fit_joint_spec`` on one detector's white-light fit result.

    Uses this notebook's fixed spectroscopic-fit settings (0.02 per wavelength
    bin, 1000 samples with 500 burn-in, 12 processes, no GP, first-order
    polynomial).

    Parameters
    ----------
    result
        The value returned by ``fit_wrapper`` for a single detector.

    Returns
    -------
    Whatever ``fit_joint_spec`` returns with ``return_chains=True``; the calls
    below unpack it into two items.
    """
    return fit_joint_spec(
        result,
        wav_per_bin=0.02,
        samples=1000,
        burnin=500,
        nproc=12,
        n_components_spec=0,
        save_chains=False,
        return_chains=True,
        gp=False,
        polyorder=1,
        progress=True
    )


wavs_nrs1, chains_nrs1 = fit_spec_wrapper(result_nrs1)
wavs_nrs2, chains_nrs2 = fit_spec_wrapper(result_nrs2)

In [None]:
# Transmission spectrum: transit depth in each wavelength bin, per detector.
fig, ax = plt.subplots(figsize=(10, 5))
for detector, wav, samplers, shade in zip(
    ['nrs1', 'nrs2'],
    [wavs_nrs1, wavs_nrs2],
    [chains_nrs1, chains_nrs2],
    [0.7, 0.1],
):
    # NOTE(review): assumes fit_joint_spec returns (wavelengths, samplers) in
    # that order, as the unpacking in the previous cell names them. Earlier
    # code here indexed the result the other way round (samplers first) -- confirm.
    # Chain parameter -3 is squared to give depth, so it is presumably the
    # radius ratio Rp/Rs -- confirm against fit_joint_spec's parameter layout.
    rp_samples = [np.asarray(sampler.get_chain())[:, :, -3] for sampler in samplers]
    rp = np.array([np.mean(s) for s in rp_samples])
    rp_err = np.array([np.std(s) for s in rp_samples])
    # depth = rp**2, so first-order error propagation gives sigma = 2 * rp * sigma_rp.
    ax.errorbar(
        wav, rp**2 * 1e6, yerr=2 * rp * rp_err * 1e6,
        fmt='o', color=plt.cm.terrain(shade), label=detector,
    )
ax.set(xlabel='Wavelength', ylabel='Transit depth (ppm)')
ax.legend();