# Planet Property Calculator

In [None]:
import math
import warnings

import requests
%matplotlib widget
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import lightkurve as lk
import numpy as np
import pandas as pd
from astropy import units as u
from astropy.stats import BoxLeastSquares

import plot_utils as pu
# Silence library warnings so they do not clutter the notebook output.
warnings.filterwarnings('ignore')
# display/HTML are imported from the public IPython.display module, as the
# other cells of this notebook already do; the IPython.core.display path for
# `display` is deprecated.
from IPython.display import display, HTML
# Let the notebook container use the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import ipywidgets as widgets
from IPython.display import clear_output

# --- Query form: target id plus author and exposure-time filters ---
tic = widgets.Text(value='', description='TIC:')
menu_author = widgets.Dropdown(
    options=['SPOC', '*SPOC*', 'K2', '*'],
    value='SPOC',
    description='Author:',
)
menu_exptime = widgets.Dropdown(
    options=['*', '120', '20', '1800'],
    value='*',
    description='Exptime:',
)

button_run = widgets.Button(description='Run Query', button_style='primary')
out_run = widgets.Output()


def on_run_button_clicked(_):
    """Search for light curves matching the form and publish the result.

    Sets the notebook globals ``TIC_no`` (the stripped text-box content),
    ``TIC`` (``'TIC '`` + ``TIC_no``) and ``available_data_select`` (the
    value returned by ``lk.search_lightcurve``). Everything printed goes to
    the ``out_run`` output widget. The button argument is ignored.
    """
    global available_data_select
    global TIC_no
    global TIC
    with out_run:
        clear_output()
        author = menu_author.value
        exptime_choice = menu_exptime.value
        # '*' is handed over unchanged; every other menu entry is converted to int
        exptime = exptime_choice if exptime_choice == '*' else int(exptime_choice)
        TIC_no = tic.value.strip()
        TIC = 'TIC ' + TIC_no
        print("Searching for TIC = {} Author = {} Exptime = {}".format(TIC,author,exptime))
        # Retrieve the list of available data products for this TIC
        available_data_select = lk.search_lightcurve(TIC, author=author, exptime=exptime)
        print(available_data_select)


# Wire the button to its handler and show the form
button_run.on_click(on_run_button_clicked)

box = widgets.VBox([tic, menu_author, menu_exptime, button_run, out_run])
box

In [None]:
import ipywidgets as widgets
from IPython.display import clear_output

# Slider choosing which rows of the search result to download; the two
# handles are used as a Python slice [start:stop] in the handler below.
range_sec = widgets.IntRangeSlider(
    value=[0, len(available_data_select)],
    min=0,
    max=len(available_data_select),
    step=1,
    description='Range:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

button_fetch = widgets.Button(description='Fetch Data', button_style='primary')
out_fetch = widgets.Output()


def on_fetch_button_clicked(_):
    """Download the selected search-result rows and stitch them together.

    Sets the notebook globals ``lc_collection`` (the stitched light curve)
    and ``sectors`` (unique mission labels of the downloaded rows with the
    ``"TESS Sector "`` prefix removed). Output goes to ``out_fetch``.
    If nothing could be downloaded a message is printed and the globals
    are left untouched.
    """
    global available_data_select
    global lc_collection
    global sectors
    with out_fetch:
        clear_output()
        # download the data
        min_sec, max_sec = range_sec.value
        print("Downloading {} to {}".format(min_sec,max_sec))
        data_select = available_data_select[int(min_sec):int(max_sec)]
        print(data_select)
        lc_coll_sel = data_select.download_all()
        # An empty selection (e.g. both slider handles on the same value)
        # yields nothing to stitch; stop here instead of failing on .stitch()
        if lc_coll_sel is None or len(lc_coll_sel) == 0:
            print("Nothing was downloaded - widen the range and fetch again")
            return
        print("Done")
        print(lc_coll_sel)
        # 'stitch' the data from the different sectors together (which also normalizes the data)
        lc_collection = lc_coll_sel.stitch()
        sectors = np.unique(np.char.replace(list(data_select.mission), "TESS Sector ", ""))


# linking button and function together using a button's method
button_fetch.on_click(on_fetch_button_clicked)
print(available_data_select)
box1 = widgets.VBox([range_sec, button_fetch, out_fetch])
box1


# Plot the lightcurve

In [None]:
%%time
# Build the binned light curve that the next cell overlays on the raw data.
#=================
#lc_collection = lc_collection.select_flux(flux_column="sap_flux",flux_err_column="sap_flux_err").normalize() # Switch to SAP Flux if needed (comment out if not)
#=================
# 15/24/60 is 15 minutes written as a fraction of a day.
# NOTE(review): assumes bin() reads a bare float as days - confirm for the installed lightkurve version.
bin_time = 15/24/60
lc_binned = lc_collection.bin(bin_time)


In [None]:
plt.style.use('default')
%matplotlib widget
fig, ax = plt.subplots(figsize=(16,6))
lc_collection.plot(ax=ax, linewidth=0,marker='o', color='gold',markersize=1, label="Flux")
lc_binned.plot(ax=ax, linewidth=0,marker='o', color='black',markersize=1, label="Binned")
#lc_collection.bin(20/24/60).plot(ax=ax, color='red')

fluxmin = np.nanmin(lc_collection.flux)
fluxmax = np.nanmax(lc_collection.flux)
ymax = fluxmax+(fluxmax-1)*0.1
ymin = fluxmin-(fluxmax-1)*0.1
# Reset the max and min
#plt.ylim(ymin, ymax)
ax.set(title=TIC)
pu.add_sector_labels(ax)
plt.show()
# Define the functions to allow the user to click on two points in the chart
coords = []  # (x, y) data coordinates of the right-clicks collected so far
tran = []    # the vertical marker lines drawn for those clicks


def onclick(event):
    """Record a right-click on the light curve as a candidate transit.

    The click position is stored in the globals ``ix``/``iy`` for every
    mouse press. A right-click (button 3) inside an axes additionally
    appends ``(ix, iy)`` to ``coords`` and draws a purple vertical marker
    on the axes that was clicked. When ``coords`` is empty (first click, or
    first click after "Done"), markers left over from the previous
    selection are removed first.
    """
    global ix, iy
    global coords
    ix, iy = event.xdata, event.ydata

    # Ignore other buttons, and clicks outside the axes: there xdata/ydata
    # are None, which would poison the list of transit times.
    if event.button != 3 or event.inaxes is None:
        return

    if not coords:  # remove any existing markers if this is the first click
        for marker in tran:
            try:
                marker.remove()
            except Exception:
                # best effort: the marker may already be gone from its axes
                pass
        tran.clear()

    coords.append((ix, iy))
    # Draw on the clicked axes rather than pyplot's "current" axes, which is
    # a different figure once later cells have created their own plots.
    tran.append(event.inaxes.axvline(ix, color='purple', zorder=-1))
    event.canvas.draw_idle()


cid = fig.canvas.mpl_connect('button_press_event', onclick)  # Connect the event

button_done = widgets.Button(description='Done', button_style='primary')
out_done = widgets.Output()


def on_done_button_clicked(_):
    """Turn the clicked points into transit times and period estimates.

    Needs at least two entries in ``coords``; otherwise only a reminder is
    printed. On success it sets the globals ``transit_time`` (sorted x
    values of the clicks), ``periods`` (gaps between consecutive transit
    times) and ``starting_period`` (the smallest gap), and empties
    ``coords`` so the user can click a fresh selection.
    """
    global transit_time
    global coords
    global periods
    global starting_period
    with out_done:
        clear_output()
        if len(coords) < 2:
            print ("Click on at least 2 transits")
            return

        # only the x coordinate (time) of each click matters here
        transit_time = sorted(point[0] for point in coords)
        print (transit_time)
        # Tidy up so user can reclick if required
        coords.clear()
        # the period is the separation between two consecutive transits
        periods = [later - earlier
                   for earlier, later in zip(transit_time, transit_time[1:])]
        starting_period = min(periods)
        print(periods)
        print ("Starting Period = {} days".format(starting_period))
        print ("Average = {} days".format(sum(periods)/len(periods)))


# linking button and function together using a button's method
button_done.on_click(on_done_button_clicked)
box2 = widgets.VBox([button_done, out_done])
box2

# STOP Here and look at the light curve
- RIGHT Click on the transits - at least 2
- Then click Done

# Check the background to make sure none of the 'dips' are background spikes

In [None]:
%matplotlib widget
# set up the plotting region
fig, axbg = plt.subplots(figsize = (10,5))

# plot the time vs the background flux
plt.plot(lc_collection.time.value, lc_collection.sap_bkg.value, color = 'blue', lw = 0, marker = '.', ms = 1)

plt.ylabel("Background flux") # label the axes
plt.xlabel("Time (TJD)")
plt.title("TIC-" + TIC_no)
plt.tight_layout()
for transit in transit_time:
    plt.axvline(transit, color = 'lightblue', zorder = -1)
print ('\033[1mBackground flux for sector\033[0m')
plt.show()

# Interactive BLS

In [None]:
# Report how many points the stitched light curve holds and the time span it covers
print('Num of observations:', len(lc_collection))
print('Observation elapsed time:', lc_collection.time.max()  - lc_collection.time.min())

# use the interactive transit period detection
#   caveat: unsure whether combining observations over time makes sense for the algorithm
lc_collection.interact_bls()

# Disabled alternative: static BLS and Lomb-Scargle periodograms. Change False to True to run it.
if False: 
    x_min = 1
    x_max = 20
    # Box Least Squares assumes a U-shaped transit model (rapid dips)
    pdg_bls = lc_collection.remove_nans().to_periodogram(method='bls')
    print('BLS')
    pdg_bls.show_properties()
    ax = pdg_bls.plot()
    ax.set_title('BLS Periodogram, in period')
    ax.set_xlim(x_min, x_max)

    # Lomb-Scargle is better for general variable curves with a sine-like shape (gradual fluctuation)
    pdg_ls = lc_collection.remove_nans().to_periodogram(method='lombscargle')
    print('Lomb Scargle')
    pdg_ls.show_properties()
    ax = pdg_ls.plot(view='period')
    ax.set_title('Lomb Scargle Periodogram, in period')    
    ax.set_xlim(x_min, x_max)

# BLS - manual method - slow

In [None]:
# you don't have to change the code in this cell, but you do have to run it once for the rest of the notebook to work
def plot_bls(alltime, allflux, alltimebinned, allfluxbinned, model, results, period, duration, t0, mid_transit_t0, in_transit = [0], in_transit_notbinned = [0]):
    '''
    Draw the three-panel BLS summary figure. Called from data_bls().

    Parameters
    ----------
    alltime  :  array
        times (not binned)
    allflux  :  array
        normalized flux (not binned)
    alltimebinned  :  array
        binned times
    allfluxbinned  :  array
        normalized binned flux
    model :  object
        BLS model; its ``model(time, period, duration, t0)`` method is
        evaluated to draw the box-shaped fit
    results :  object
        BLS results; the ``period`` and ``power`` attributes are plotted
    period :  float
        period of the strongest peak
    duration :  float
        transit duration belonging to that peak
    t0  :  float
        reference transit time used for the model and the phase fold
    mid_transit_t0  :  float
        transit time shown in the periodogram legend
    in_transit = [0] :  array
        mask over the binned points. The length-1 default is a sentinel for
        the first run (nothing removed); otherwise the masked points are
        hidden and the second colour scheme is used
    in_transit_notbinned = [0] :  array
        the same mask for the unbinned points; only read when ``in_transit``
        is not the sentinel

    Returns
    -------
        Nothing. Shows a figure with three panels: periodogram, light curve
        with best-fit model, phase-folded light curve with model.
    '''

    # The sentinel default has length 1: it marks the initial, unmasked run
    first_pass = len(in_transit) == 1

    # Choose colours, title and which points to show once, up front
    if first_pass:
        color1, color2, title = '#DC143C', 'darkorange', 'Initial BLS'
        time_raw, flux_raw = alltime, allflux
        time_bin, flux_bin = alltimebinned, allfluxbinned
    else:
        color1, color2, title = 'deepskyblue', '#4682B4', 'Initial event removed'
        time_raw = alltime[~in_transit_notbinned]
        flux_raw = allflux[~in_transit_notbinned]
        time_bin = alltimebinned[~in_transit]
        flux_bin = allfluxbinned[~in_transit]

    raw_style = dict(marker=".", alpha=0.4, color=color2, ms=2, lw=0, markerfacecolor='none')
    bin_style = dict(marker="o", alpha=0.6, color='black', ms=3, lw=0, markerfacecolor='none')

    fig, (ax_pgram, ax_curve, ax_fold) = plt.subplots(3, 1, figsize=(10, 14))

    # ------------
    # Panel 1: periodogram, with the peak period and its multiples/fractions marked
    ax_pgram.axvline(period, alpha=0.4, lw=5, color=color1)
    for n in range(2, 15):
        for harmonic in (n*period, period / n):
            ax_pgram.axvline(harmonic, alpha=0.4, lw=2, linestyle="dashed", color=color2)

    ax_pgram.plot(results.period, results.power, "k", lw=0.5, label = 'P = %.3f T0 = %.3f' % (period,mid_transit_t0))
    ax_pgram.set_title(title)
    ax_pgram.set_xlim(results.period.min(), results.period.max())
    ax_pgram.set_xlabel("period (days)")
    ax_pgram.set_ylabel("log likelihood")
    ax_pgram.legend(fontsize = 10, loc = 1)

    # ------------
    # Panel 2: light curve with the model on top. The model grid and the
    # x-limits always span the full binned time range, masked or not.
    ax_curve.plot(time_raw, flux_raw, **raw_style)
    ax_curve.plot(time_bin, flux_bin, **bin_style)

    t_first, t_last = alltimebinned.min(), alltimebinned.max()
    grid = np.linspace(t_first, t_last, 3*len(alltimebinned))
    ax_curve.plot(grid, model.model(grid, period, duration, t0), lw=2, color = color1)
    ax_curve.set_xlim(t_first, t_last)
    ax_curve.set_xlabel("time (days)")
    ax_curve.set_ylabel("de-trended flux (ppt)")

    # ------------
    # Panel 3: phase fold around t0, showing only points within half a day of the transit
    half_period = 0.5*period
    phase_raw = (time_raw - t0 + half_period) % period - half_period
    phase_bin = (time_bin - t0 + half_period) % period - half_period
    near_raw = np.abs(phase_raw) < 0.5
    near_bin = np.abs(phase_bin) < 0.5

    ax_fold.plot(phase_raw[near_raw], flux_raw[near_raw], **raw_style)
    ax_fold.plot(phase_bin[near_bin], flux_bin[near_bin], **bin_style)

    grid = np.linspace(-0.5, 0.5, 1000)
    ax_fold.plot(grid, model.model(grid + t0, period, duration, t0), lw=2, color = color1)
    ax_fold.set_xlim(-0.5, 0.5)
    ax_fold.set_xlabel("time since transit (days)")
    ax_fold.set_ylabel("de-trended flux (ppt)")
    plt.tight_layout()

    plt.show()



In [None]:
# you don't have to change the code in this cell, but you do have to run it once for the rest of the notebook to work
def data_bls(lc):
    '''
    Run the BLS routine on a light curve and plot the results. The BLS is run
    twice; for the second run the points around the strongest event of the
    first run are masked out.

    Parameters
    ----------
    lc: lightkurve light curve (normalized here) or a collection of them
        (stitched here), covering one or more sectors

    Returns
    -------
    pandas.DataFrame
        Two rows (first and second BLS run) with the columns
        "Period (days)", "T0 (TBJD)", "Transit depth (ppm)",
        "Odd depth (ppm)" and "Even depth (ppm)".

        Exception: when the best duration is not shorter than the best period
        the second run is skipped and a tuple ``(stats, [-999])`` is returned
        instead, where ``stats`` is the list
        [period, t0, depth, depth_phased, depth_half, depth_odd, depth_even]
        of the first run.
    '''

    # normalize the data; an object without normalize() is stitched instead
    # (presumably a light-curve collection - stitching also normalizes)
    try:
        lc = lc.normalize()
    except AttributeError:
        lc = lc.stitch()

    alltime = lc.time.value
    allflux = lc.flux.value

    lc_bin = lc.bin(15/60/24) # you can change the binning factor here if you like
    alltimebinned = lc_bin.time.value
    allfluxbinned = lc_bin.flux.value

    # make sure that there are no nan (empty) values in the data - they cause everything to crash so let's get rid of them
    mask_binned = np.isfinite(alltimebinned) & np.isfinite(allfluxbinned)
    mask = np.isfinite(alltime) & np.isfinite(allflux)

    alltimebinned = np.array(alltimebinned)[mask_binned]
    allfluxbinned = np.array(allfluxbinned)[mask_binned]
    alltime = np.array(alltime)[mask]
    allflux = np.array(allflux)[mask]

    # -----------------------
    # Search grid: 15 trial durations and periods from 0.7 up to the time span of the data

    durations = np.linspace(0.05, 0.5, 15) # ????? CHECK THESE
    periods = np.arange(0.7, (np.nanmax(alltimebinned) - np.nanmin(alltimebinned)), 0.01)

    model = BoxLeastSquares(alltimebinned, allfluxbinned)
    results = model.power(periods, durations)

    # Parameters of the strongest peak
    index = np.argmax(results.power)
    period = results.period[index]
    t0 = results.transit_time[index]
    duration = results.duration[index]

    # compute_stats is evaluated once and all statistics are read from the same result
    stats = model.compute_stats(period, duration, t0)
    mid_transit_t0 = stats['transit_times'][0]

    # call the first round of plotting

    plot_bls(alltime, allflux, alltimebinned, allfluxbinned, model, results, period, duration, t0, mid_transit_t0)

    if (1*duration) >= period: # if the 'found' events are very short period, don't run the BLS twice as the code would crash.
        return [period, mid_transit_t0, stats['depth'], stats['depth_phased'], stats['depth_half'], stats['depth_odd'], stats['depth_even']], [-999]

    # Find the in-transit points using a longer duration as a buffer to avoid ingress and egress
    in_transit = model.transit_mask(alltimebinned, period, 2*duration, t0)
    in_transit_notbinned = model.transit_mask(alltime, period, 2*duration, t0)

    # Re-run the algorithm, and plot the results
    model2 = BoxLeastSquares(alltimebinned[~in_transit], allfluxbinned[~in_transit])
    results2 = model2.power(periods, durations)

    # Extract the parameters of the best-fit model
    index2 = np.argmax(results2.power)
    period2 = results2.period[index2]
    t02 = results2.transit_time[index2]
    duration2 = results2.duration[index2]
    # The first transit time is still taken from the unmasked model, as before
    mid_transit_t02 = model.compute_stats(period2, duration2, t02)['transit_times'][0]

    # call the second round of plotting - once the initial transit has been removed
    plot_bls(alltime, allflux, alltimebinned, allfluxbinned, model2, results2, period2, duration2, t02, mid_transit_t02, in_transit = in_transit, in_transit_notbinned = in_transit_notbinned)

    # The statistics of the second fit must use its own reference time t02
    # (they were previously evaluated at the first fit's t0)
    stats2 = model2.compute_stats(period2, duration2, t02)

    # each depth entry is indexed with [0] to take its first element
    df = pd.DataFrame({"Period (days)": [period, period2], "T0 (TBJD)": [mid_transit_t0, mid_transit_t02], "Transit depth (ppm)": [stats['depth'][0], stats2['depth'][0]], "Odd depth (ppm)": [stats['depth_odd'][0], stats2['depth_odd'][0]], "Even depth (ppm)": [stats['depth_even'][0], stats2['depth_even'][0]]})
    return df



In [None]:
# Run the above functions with this simple command here (once you have downloaded the data - held in lc_collection)
# This cell can take a minute or so to run (it's doing a lot of work)

# lc_collection is the stitched light curve produced by the "Fetch Data" button;
# the name lc_coll that was passed here before is never defined in this notebook.
df = data_bls(lc_collection)

print ("Some statistics about the two fits (these numbers are just estimates and should be taken with a large pinch of salt!)")
df


# Colour graded scatter plot for use later when folding

In [None]:

# Scatter of the whole light curve to see when the transit events happen.
# Every point is coloured by its own timestamp (colormap 'tab20b'), so that
# points from different times stay distinguishable once the curve is folded.
plt.style.use('default')
fig, ax = plt.subplots(figsize=(20,6))

plt.scatter(
    lc_collection.time.value,
    lc_collection.flux.value,
    c = lc_collection.time.value,
    s = 1,
    cmap = 'tab20b',
)

ax.set_xlabel("Time (BJD - 2457000)")
ax.set_ylabel("Normalized Flux")
# Reset the max and min
ax.set_ylim(ymin, ymax)
ax.set_title(TIC_no)
plt.show()
print("Suggested period = {}".format(starting_period))

# Interactive plot to refine the period

In [None]:
'''
In the code below you can change the period using a slider and watch how different orbital periods affect the phase folded lightcurve. If you click on the slider, you can use the arrows on the keypad to move it by smaller amounts. NOTE: there is often a bit of a lag between changing the value and the figure updating!

Run the interactive widget using this line of code (see example below):

*interact(plot_phase_folded_color, period = widgets.FloatSlider(**value**=xx,**min**=xx,**max**=xx,**step**=xx,description='period:', readout_format='.4f'))*

The values in bold need to be changed for different targets!
- **value**: starting guess at the period
- **min**: the minimum period that you think it could be
- **max**: the maximum period that you think it could be
- **step**: the step size that the widget jumps when you press the up or down arrow on your keyboard (good starting point=0.0001 - if you see no change in the phase folded lightcurve when you press the up and down arrows then make the step size larger!)

In order to use this widget, you need to give it a range of periods that you want to test (so you already need an idea of the period!). 
'''

import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
from IPython.display import clear_output

def plot_phase_folded_color(period, xrange=1):

    # phase fold the light curve
    lc_phased = lc_collection.fold(period = period, epoch_time = transit_time[0])
    
    # plot the binned and unbinned phase folded lightcurve on the same figure
    %matplotlib inline 
    fig, ax = plt.subplots(figsize = (30,12))
    
    plt.scatter(lc_phased.time.value, lc_phased.flux.value, c = lc_phased.time_original.value, s = 4, marker = 'o', cmap = 'tab20b')

    # if you prefer all the point to have the same color, delete the line above and get rid of the hastag in fron of the line below
    #lc_phased.plot(ax = ax, marker = '.', linewidth = 0, color = 'black', alpha = 0.2, markersize = 3)
    
    plt.xlabel("Phase")
    plt.ylabel("Normalized flux")
    #plt.ylim(ymin, ymax)
    xlims = xrange
    plt.xlim(-xlims,xlims)
    #plt.ylim(0.97, 1.03)
    plt.show()
    
start_p = widgets.Text(
       value=str(starting_period),
       description='Period', )
start_r = widgets.Text(
       value=str(0.2),
       description='Slider Range:', )
start_step = widgets.Text(
       value=str(0.0001),
       description='Slider Step:', )
start_x = widgets.Text(
       value=str(int(starting_period/2)+1),
       description='Xrange', )

button_plot = widgets.Button(description='Plot', button_style='primary')
out_plot = widgets.Output()
def on_plot_button_clicked(_):

    with out_plot:
        clear_output()
        #========= Set Starting Period and slider range ===============
        starting_period=float(start_p.value)
        starting_range=float(start_r.value)
        step = float(start_step.value)
        #=========Set X range starting value =============================================
        xrange_start = float(start_x.value)
        #=================================================================================
        interact(plot_phase_folded_color, period = widgets.FloatSlider(value=starting_period,min =starting_period-starting_range,max=starting_period+starting_range,step=step,description='period:', readout_format='.5f'), xrange = widgets.FloatSlider(value=xrange_start,min =0,max=(starting_period/2)+1,step=0.05,description='xrange:'))     
# linking button and function together using a button's method
button_plot.on_click(on_plot_button_clicked)        

box3 = widgets.VBox([start_p,start_r,start_step,start_x,button_plot,out_plot ])
box3

# Plot again to enable measurements to be taken

In [None]:
# ========================= look at the period selected above and copy and paste it here =====================
period = 1.703583731306844
offset = 0
# =============================================================================================================

from IPython.display import Markdown, display
def printmd(string):
    display(Markdown(string))
printmd("**Right-click the six parameters of the dip** -  1: Start   2: End    3: Start of floor    4: End of floor    5: Top    6: Bottom")

# phase fold the light curve using this period and the time of one of the above transit as the value of t0 
lc_phased = lc_collection.fold(period = period, epoch_time = transit_time[0], epoch_phase=offset)
lc_phased_binned = lc_phased.bin(10/24/60)

# This is a manual alternative to the slider above if needed
#fig, ax = plt.subplots(figsize = (20,8))
#lc_phased.plot(ax = ax, marker = '.', linewidth = 0, color = 'blue', alpha = 0.2, markersize = 3)
#plt.xlim(-.5,.5)
#plt.ylim(.80,1.1)
# change the period above and re run this cell until the transit events below line up

# Odd/Even plot to look for possible different planets or EB, along with binned overlay to help with measurements
%matplotlib widget
plt.style.use('default')
fig, ax = plt.subplots(figsize=(16,8))
lc_phased[lc_phased.odd_mask].plot(ax=ax, linewidth=0,marker='.', color='red',alpha=0.2,markersize=5,label='unbinned_odd')
lc_phased[lc_phased.even_mask].plot(ax=ax, linewidth=0,marker='.', color='blue',alpha=0.2,markersize=5,label='unbinned_even')
lc_phased_binned.plot(ax=ax, linewidth=0,marker='o', color='black',markersize=3,alpha=0.8,label='binned')

xlims =70
plt.xlim(-xlims,xlims)
focus_factor = 1 #  Set higher to focus the y dimension
plt.ylim(fluxmin+(fluxmax-1)*focus_factor, fluxmax-(fluxmax-1)*focus_factor)
#plt.ylim(0.995,1.005)
plt.ylim(ymin, ymax)
ax.set(title=TIC)
plt.show()

# Click event to select all 6 parameters one after the other with right click.
def onclick(event):
    """Collect the six right-clicks that describe the transit shape.

    The click position is stored in the globals ``ix``/``iy`` for every
    mouse press. Right-clicks (button 3) inside an axes are appended to
    ``coords`` and marked on the clicked axes:
    clicks 1-2 with a blue vertical line, clicks 3-4 with a grey vertical
    line, clicks 5-6 with a green horizontal line. A further right-click
    once six points are stored removes all markers and empties ``coords``.
    """
    global ix, iy, coords, tran
    ix, iy = event.xdata, event.ydata

    # Ignore other buttons, and clicks outside the axes: there xdata/ydata
    # are None and would break the measurements taken from coords.
    if event.button != 3 or event.inaxes is None:
        return

    # Draw on the clicked axes rather than on pyplot's "current" axes
    target = event.inaxes
    picked = len(coords)

    if picked < 2:
        coords.append((ix, iy))
        tran.append(target.axvline(ix, color = 'blue', zorder = -1))
    elif picked < 4:
        coords.append((ix, iy))
        tran.append(target.axvline(ix, color = 'grey', linestyle = '-', zorder = -1))
    elif picked < 6:
        coords.append((ix, iy))
        tran.append(target.axhline(iy, color = 'green', zorder = -1))
    elif picked == 6:
        for line in tran:
            try:
                line.remove()
            except Exception:
                # best effort: the marker may already be gone from its axes
                pass
        tran.clear()
        coords.clear()

    event.canvas.draw_idle()


# Initialise the arrays before we start
coords = []
tran = []

cid = fig.canvas.mpl_connect('button_press_event', onclick)  # Connect the event

button_reset = widgets.Button(description='Reset', button_style='warning')

out_op = widgets.Output()


def on_reset_button_clicked(_):
    """Discard the current selection: remove all markers and empty ``coords``.

    Also clears the ``out_op`` output widget. The button argument is ignored.
    """
    global coords, tran
    with out_op:
        clear_output()
        for line in tran:
            try:
                line.remove()
            except Exception:
                # best effort: a marker that cannot be removed must not stop
                # the reset (KeyboardInterrupt is no longer swallowed)
                pass
        tran.clear()
        coords.clear()


# linking button and function together using a button's method
button_reset.on_click(on_reset_button_clicked)

button_op = widgets.Button(description='Done', button_style='primary')


def on_op_button_clicked(_):
    """Convert the six clicked points into transit measurements.

    Uses the x values of clicks 1-4 and the y values of clicks 5-6 and sets
    the globals ``transit_depth``, ``transit_start``, ``transit_end``,
    ``transit_length`` (x difference of clicks 1-2 times 24),
    ``transit_mid`` and ``floor_length`` (x difference of clicks 3-4 times
    24). Differences are taken as absolute values, so the click order
    within each pair does not matter. With fewer than six points only a
    reminder is printed.
    """
    global coords
    global transit_depth,transit_start,transit_end,transit_length,transit_mid,floor_length
    with out_op:
        clear_output()
        if len(coords) < 6:
            print("Select the six parameters before continuing")
            return

        # clicks 5 and 6: top and bottom of the dip (y values)
        transit_depth = abs(coords[4][1] - coords[5][1])
        print("Depth = %6.4f" %transit_depth)

        # clicks 1 and 2: start and end of the dip (x values)
        transit_start, transit_end = coords[0][0], coords[1][0]
        transit_length = abs(transit_end - transit_start) * 24
        transit_mid = transit_start + (transit_end - transit_start)/2
        print('Mid-point = %6.4f' %transit_mid)
        print('Duration = %6.4f hours' %transit_length)

        # clicks 3 and 4: start and end of the flat floor (x values)
        floor_start, floor_end = coords[2][0], coords[3][0]
        floor_length = abs(floor_end - floor_start) * 24
        print("Floor = %6.4f hours" %floor_length)


# linking button and function together using a button's method
button_op.on_click(on_op_button_clicked)
boxop = widgets.VBox([button_reset, button_op, out_op])
boxop

# Planet Calculations

In [None]:
import star_data as star

# Catalogue values for the target; every name starts out as NaN
GAIA_ID = TIC_rad = TIC_erad = TIC_mass = TIC_emass = TIC_Teff = TIC_eTeff = TIC_lum = TIC_Vmag = TIC_dist = TIC_mind = TIC_maxd = TIC_ra = TIC_dec = float('nan')
GAIA_ID, TIC_rad, TIC_erad, TIC_mass, TIC_emass, TIC_Teff, TIC_eTeff, TIC_lum, TIC_Vmag, TIC_dist, TIC_mind, TIC_maxd, TIC_ra, TIC_dec = star.get_exofop_data(TIC_no)
# Default stellar radius and mass: the values returned by get_exofop_data
R_star = float(TIC_rad)
M_star = float(TIC_mass)
star.get_simbad_and_EB_data(TIC_no)

# GAIA values default to NaN so the names exist even when the lookup fails
gaia_rad = gaia_rad_lower = gaia_rad_upper = gaia_rad_sigma = gaia_teff = gaia_teff_lower = gaia_teff_upper = gaia_teff_sigma = gaia_lum = gaia_lum_lower = gaia_lum_upper = gaia_mass = gaia_mass_lower = gaia_mass_upper = gaia_mass_sigma = float('nan')
try:
    gaia_rad, gaia_rad_lower, gaia_rad_upper, gaia_rad_sigma, gaia_teff, gaia_teff_lower, gaia_teff_upper, gaia_teff_sigma, gaia_lum, gaia_lum_lower, gaia_lum_upper, gaia_mass, gaia_mass_lower, gaia_mass_upper, gaia_mass_sigma = star.get_gaia_data(GAIA_ID)
    # GAIA radius and mass replace the defaults whenever they are available
    if not np.isnan(gaia_rad):
        R_star = float(gaia_rad)
    if not np.isnan(gaia_mass):
        M_star = float(gaia_mass)
except Exception:
    # best effort: keep the defaults above, but let KeyboardInterrupt
    # and SystemExit propagate
    print("Unable to get all GAIA data")

In [None]:
# ============= Radius relative to sol ========================
#TIC_rad = 0.761  # Override radius
#R_star = float(TIC_rad)
#R_star= float(gaia_rad)
# =============================================================
# Stellar radius as an astropy quantity in solar radii
Rs = R_star * u.Rsun
# Planet radius computed as sqrt(transit depth) * stellar radius
r_pl_solar_radius = np.sqrt(transit_depth) * Rs
# Stellar radius in kilometres, used later for the inclination
Rs_km = Rs.to(u.kilometer)
# The bare name displays the value as the cell output
r_pl_solar_radius

In [None]:
# Radius relative to earth
# Same planet radius, converted from solar radii to Earth radii; the bare name displays it
r_pl_earth_radius = r_pl_solar_radius.to(u.Rearth)
r_pl_earth_radius

In [None]:
# Radius relative to jupiter
# Same planet radius, converted from solar radii to Jupiter radii; the bare name displays it
r_pl_jup_radius = r_pl_solar_radius.to(u.Rjupiter)
r_pl_jup_radius

In [None]:
# Radius in kilometers
# Same planet radius, converted from solar radii to kilometres; the bare name displays it
r_pl_km_radius = r_pl_solar_radius.to(u.kilometer)
r_pl_km_radius

In [None]:
# ============ Mass of star relative to sol ==================
#TIC_mass = 0.820    # Override masses
#M_star = float(TIC_mass)
#M_star = float(gaia_mass)
# ============================================================
# The measured durations are in hours; the A/r* formula below needs days
transit_length_days = transit_length/24
floor_length_days = floor_length/24

# Impact Parameter
import math
sqdf = math.sqrt(transit_depth)
# squared ratio of floor duration to total duration
ts = (floor_length**2)/(transit_length**2)

#bs = math.sqrt((((1-sqdf)*(1-sqdf))-(ts*(1+sqdf)*(1+sqdf)))/(1.0-ts))
bs = math.sqrt(((1-sqdf)**2-(ts*(1+sqdf)**2))/(1-ts))

# A/r*
ar = (2 * period * math.pow(transit_depth,0.25))/(math.pi * math.sqrt((transit_length_days**2)-(floor_length_days**2)))

# Tc
tc = transit_length/math.sqrt(1-(bs**2))

# Separation
# Kepler's 3rd law in solar units (AU, years, solar masses): a^3 = P^2 * M.
# The cube root uses the exact exponent 1/3; the former 0.333 skewed the result.
pp = period/365
au = math.pow(pp*pp*M_star, 1/3)
au_km = (au * u.AU).to(u.kilometer)
# Inclination from cos(i) = b * R_star / a
i = math.degrees(math.acos((bs*Rs_km)/au_km))

In [None]:
# Mean stellar density: mass in grams divided by the volume of a sphere with the stellar radius in centimetres
import math
from astropy import units as u
Ms = M_star * u.Msun
Ms_g = Ms.to(u.gram)
print(Ms_g)
Rs = R_star * u.Rsun
Rs_cm = Rs.to(u.centimeter)
print(Rs_cm)
# rho is a plain float in g/cm^3, because only the .value parts are used
rho = Ms_g.value/((4/3)*math.pi*math.pow(Rs_cm.value,3))
print(rho)

In [None]:
 
# Summary of the measured and derived quantities for the current target.
# Every value printed here is a notebook global set by an earlier cell, so
# those cells must have been run first.
print ('\033[1mSummary\033[0m: ' + TIC)    
print ('T0 = %8.4f BJD' %transit_time[0])
print ('Transit_Depth = %6.4f' %(transit_depth))
print ('Transit_Length = %6.4f hours' %(transit_length))
print ('Floor_Length = %6.4f hours' %(floor_length))
print ('Star_Radius = ' + f"{R_star:0.04f} Rsol")
print ('Star_Mass = ' + f"{M_star:0.04f} Msol")
print ('Star_Density = %6.4f g/cm^3' %(rho))
# NOTE(review): %d requires GAIA_ID to be numeric - confirm it is not a string.
print ('Vmag = %7.4f  GAIA = %d' %(TIC_Vmag, GAIA_ID))
print ('Distance = %8.4f (-%0.4f +%0.4f)pc' %(TIC_dist, TIC_mind, TIC_maxd))
# The two radii are astropy Quantities, so their units are printed with them.
print ('Planet_Radius = ' + f"{r_pl_jup_radius:0.02f}" + "   " + f"{r_pl_earth_radius:0.02f}")   
print ('Period = %7.4f days' %(period))
print ('Impact_Parameter = %5.4f' %bs)
print ('i = %6.4f deg' %i)
print ('A/rstar = %6.4f' %ar)
print ('Transit_Central = %6.4f hours (est.)' %tc)
print ('Separation = %6.4f AU' %au)

In [None]:
# Calculate period from transit parameters
# NOTE: this cell rebinds the notebook-level `au` and `ar`; `au` here is the
# separation implied by the transit duration, not the one derived from the
# measured period.
# Separation
tcm = (tc/(13 * R_star))*math.sqrt(M_star)  #Using solar units and relative to transit time of earth - 13 hours
au = tcm ** 2
calc_period = math.sqrt((au*au*au)/M_star)*365  #Keplers 3rd law - x365 to get period in days rather than earth years

# A/r* (same expression as in the impact-parameter cell)
ar = (2 * period * math.pow(transit_depth,0.25))/(math.pi * math.sqrt((transit_length_days**2)-(floor_length_days**2)))

print ('A/r* = %6.4f' %ar)
print ('Period = %7.4f days' %(calc_period))
print ('Separation = %6.4f AU' %au)

# Export files for Pyaneti

In [None]:
def _show_first(label, value):
    """Print ``label = value``; array-like values show their first element."""
    # np.ravel accepts scalars and sequences alike, unlike ``format(*value)``
    # which raises TypeError on a plain float.
    print("{} = {}".format(label, np.ravel(value)[0]))


# Stellar parameters from both catalogues, side by side.  A missing or empty
# value is reported instead of silently ending the listing.
try:
    print("Exofop data")
    _show_first("Radius", TIC_rad)
    _show_first("Mass", TIC_mass)
    _show_first("Teff", TIC_Teff)
    print("GAIA data")
    _show_first("Radius", gaia_rad)
    _show_first("Mass", gaia_mass)
    _show_first("Teff", gaia_teff)  # was `gaia_Teff`, a name that is never defined
except (NameError, IndexError) as err:
    print("Stellar data incomplete: {}".format(err))

In [None]:
import ipywidgets as widgets
from IPython.display import clear_output
# Option widgets for the pyaneti export; their values are copied into notebook
# globals by on_set_button_clicked when "Set Options" is pressed.

# Ticked: report "Using exofop data"; unticked: "Using GAIA data".
chk_exofop = widgets.Checkbox(
           value=False,
           description='Use exofop:',)
# How the light curve is prepared for export.
menu_method = widgets.Dropdown(
       options=['DETREND', 'CUTOUT', '30MIN'],
       value='DETREND',
       description='Method:')
# Ticked: fit the stellar density; unticked: fit the semi-major axis.
chk_density = widgets.Checkbox(
           value=False,
           description='Fit density:',)
# Suffix used for file and directory names of this model run.
suffix = widgets.Text(
       value='_a_fit',
       description='Suffix:', )

button_set = widgets.Button(description='Set Options', button_style='primary')
# Output area that receives the confirmation messages.
out_set = widgets.Output()
def on_set_button_clicked(_):
    """Copy the option widgets into notebook globals and echo the choices.

    Sets ``use_exofop``, ``data_method``, ``fit_density`` and
    ``model_suffix`` from the widgets and prints a confirmation of each
    into the ``out_set`` output area.  The click event argument is ignored.
    """
    global use_exofop, data_method, fit_density, model_suffix
    with out_set:
        clear_output()
        use_exofop, data_method = chk_exofop.value, menu_method.value
        fit_density, model_suffix = chk_density.value, suffix.value

        print("Using exofop data" if use_exofop else "Using GAIA data")
        print("Data method: {}".format(data_method))
        print("Fitting density" if fit_density else "Fitting semi-major axis")
        print("File suffix: {}".format(model_suffix))
        
# Register the handler so that pressing "Set Options" runs on_set_button_clicked.
button_set.on_click(on_set_button_clicked)    

# Stack the controls vertically; the bare name displays the box as the cell output.
box4 = widgets.VBox([chk_exofop,menu_method,chk_density,suffix,button_set,out_set])
box4

In [None]:
import numpy as np
import pandas as pd
import astropy.io.fits as pf

def detrend_pyaneti(tic, transit_time_list, alltime, allflux, allflux_err, lim_window, lim, poly_n, save = True):
    """Detrend the light curve around each transit and export it for pyaneti.

    For every transit a window of ``+/- lim_window`` around the transit time
    is cut out, a polynomial of order ``poly_n`` is fitted to the points of
    that window lying outside ``+/- lim`` of the transit time, and the fit is
    subtracted from the whole window (re-centred on 1).  The detrended
    windows are concatenated, bracketed by the first and last sample of the
    input series, plotted, and optionally written to ``pyaneti_data_file``.

    Parameters
    ----------
    tic
        Unused; kept so existing calls keep working.
    transit_time_list
        Transit times, in the same time system as ``alltime``.
    alltime, allflux, allflux_err
        Equal-length sequences holding the full light curve.
    lim_window
        Half-width of the cut-out window, in the units of ``alltime``.
    lim
        Half-width of the region around the transit excluded from the fit.
    poly_n
        Order of the detrending polynomial.
    save
        When true, write the tab-separated data file.  When false the file
        is left untouched (it used to be truncated to a header line).

    Notes
    -----
    Reads the notebook globals ``LC_max`` (times below it are labelled "LC",
    all others "SC") and ``pyaneti_data_file``.  Non-finite samples are
    dropped rather than replaced by made-up values.
    """
    import csv  # local import: the notebook imports csv only after this definition

    alltime = np.asarray(alltime)
    allflux = np.asarray(allflux)
    allflux_err = np.asarray(allflux_err)

    window_time = []
    window_flux = []
    window_err = []

    for transit_time in transit_time_list:
        try:
            in_window = (alltime > transit_time - lim_window) & (alltime < transit_time + lim_window)
            usable = in_window & np.isfinite(alltime) & np.isfinite(allflux)

            x = alltime[usable]
            y = allflux[usable]
            y_err = allflux_err[usable]

            if x.size == 0:
                print("{} not in the data set".format(transit_time))
                continue

            # Fit the trend to the out-of-transit points only, then evaluate
            # it over the whole window.
            in_transit = (x > (transit_time - lim)) & (x < (transit_time + lim))
            model = np.polyfit(x[~in_transit], y[~in_transit], poly_n)
            predicted = np.polyval(model, x)

            fig, axes = plt.subplots(nrows=2, sharex=True, figsize=(10,10))

            axes[0].plot(x, y, '#c44e52', marker='o', lw =0)
            axes[0].plot(x, predicted, 'k-')
            axes[0].set(title='Original Data and Polynomial Trend (order {})'.format(poly_n))
            axes[0].axvline(transit_time-lim)
            axes[0].axvline(transit_time+lim)

            axes[1].plot(x, y - predicted,  '#c44e52', marker='o', lw =0)
            axes[1].set(title='Detrended Residual')

            window_time.append(x)
            window_flux.append((y - predicted) + 1)
            window_err.append(y_err)

            plt.show()
        except Exception as e:
            # Best effort: report the problem and carry on with the next transit.
            print(e)

    # Include the first and last data points so the model extends to the full time range
    combined_corr_time = np.hstack([alltime[:1]] + window_time + [alltime[-1:]])
    combined_corr_flux = np.hstack([allflux[:1]] + window_flux + [allflux[-1:]])
    combined_corr_err  = np.hstack([allflux_err[:1]] + window_err + [allflux_err[-1:]])

    finite_mask = np.isfinite(combined_corr_time) & np.isfinite(combined_corr_flux) & np.isfinite(combined_corr_err)
    combined_corr_time = combined_corr_time[finite_mask]
    combined_corr_flux = combined_corr_flux[finite_mask]
    combined_corr_err  = combined_corr_err[finite_mask]

    fig, axes = plt.subplots(figsize=(14,7))
    axes.plot(combined_corr_time,combined_corr_flux, marker='o', lw =0)
    plt.show()

    # Band label per sample: long cadence before LC_max, short cadence otherwise.
    cadence_col = ["LC" if t < LC_max else "SC" for t in combined_corr_time]

    if save:
        with open(pyaneti_data_file, "w", newline='') as f:
            writer = csv.writer(f, delimiter='\t')
            writer.writerow(['#time', 'flux', 'flux_err', 'band'])
            writer.writerows(zip(combined_corr_time, combined_corr_flux, combined_corr_err, cadence_col))

        print("Created detrended data file: " + pyaneti_data_file)

#  Set up data and options

import os
import csv
import shutil

# Add Long cadence (LC) and short cadence (SC) column
# LC_max is the highest time value of long cadence data, used to separate the
# two cadences - set to 0 if all short cadence

nbands = 1
LC_max = 0
if nbands > 1:
    bands = "bands = ['SC','LC']\nt_cad = [2/24/60,30/24/60]\nn_cad = [1,10]\nis_multi_radius = False"
    is_jitter_tr = "True"
else:
    bands = "n_cad = [1]\nt_cad = [2/24/60]"
    #bands = "n_cad = [1]\nt_cad = [0.3/24/60]"
    is_jitter_tr = "False"

# *****************
# Stellar parameters for pyaneti.  With GAIA selected, every value that GAIA
# does not provide (NaN) falls back to the ExoFOP one; the radius now follows
# the same rule as mass and Teff instead of passing NaN on to pyaneti.

if use_exofop:
    pya_rad = TIC_rad
    pya_erad = TIC_erad
    pya_mass = TIC_mass
    pya_emass = TIC_emass
    pya_Teff = TIC_Teff
    pya_eTeff = TIC_eTeff
    radius_source = "Exofop"
    mass_source = "Exofop"
    Teff_source = "Exofop"
else:
    #GAIA data
    if np.isnan(gaia_rad):
        pya_rad = TIC_rad
        radius_source = "Exofop"
    else:
        pya_rad = gaia_rad
        radius_source = "GAIA"
    pya_erad = TIC_erad if np.isnan(gaia_rad_sigma) else gaia_rad_sigma

    if np.isnan(gaia_mass):
        pya_mass = TIC_mass
        mass_source = "Exofop"
    else:
        pya_mass = gaia_mass
        mass_source = "GAIA"
    pya_emass = TIC_emass if np.isnan(gaia_mass_sigma) else gaia_mass_sigma

    if np.isnan(gaia_teff):
        pya_Teff = TIC_Teff
        Teff_source = "Exofop"
    else:
        pya_Teff = gaia_teff
        Teff_source = "GAIA"
    pya_eTeff = TIC_eTeff if np.isnan(gaia_teff_sigma) else gaia_teff_sigma
    data_source = "GAIA"

# Prior on `a`: gaussian around the stellar density (sigma = 15 % of it) when
# fitting the density, otherwise a wide uniform range.
if fit_density:
    min_a = rho
    max_a = rho*0.15
    sample_stellar_density = "True"
    fit_a = "g"
    fit_a_comment = "We fit a with gaussian priors (given by the stellar parameters)"
else:
    min_a = 0.001
    max_a = 200
    sample_stellar_density = "False"
    fit_a = "u"
    fit_a_comment = "We fit the scaled semi-major axis"
# Pyaneti files

# An empty TIC number would make target_in_dir equal to inpy_dir itself and
# the clean-up below would wipe every target's input directory.
if not TIC_no:
    raise ValueError("TIC_no is empty - run the TIC query before setting up pyaneti")

#pyaneti_home_dir = r".."
pyaneti_home_dir = "/home/ian/pyaneti"
inpy_dir = pyaneti_home_dir + "/inpy/"
outpy_dir = pyaneti_home_dir + "/outpy/"
pyaneti_data_filename = '{}_data.txt'.format(TIC_no)
pyaneti_data_file = inpy_dir + '{}/'.format(TIC_no) + pyaneti_data_filename
input_fit_filename = "input_fit.py"
pyaneti_input_file = inpy_dir + '{}/{}'.format(TIC_no,input_fit_filename)
pyaneti_command = 'python pyaneti.py {}'.format(TIC_no)
target_out_dir = outpy_dir + '{}_out'.format(TIC_no)
target_in_dir = inpy_dir + '{}'.format(TIC_no)
pyaneti_full_data_file = target_out_dir + '/full_' + pyaneti_data_filename

#remove any existing files first (best effort: a missing directory is fine)
shutil.rmtree(target_out_dir, ignore_errors=True)
shutil.rmtree(target_in_dir, ignore_errors=True)

# Create directories to store the files
for path in (target_in_dir, target_out_dir):
    os.makedirs(path, exist_ok=True)

In [None]:
# Use this if flux_errs are NaNs in some of the data
#lc_collection['flux_err'] = [0.0003 if math.isnan(x) else x for x in lc_collection['flux_err']]
#lc_collection['flux_err'] = [0.003 if x<0.003 else x for x in lc_collection['flux_err']]
#print(lc_collection['flux_err'])
# All the data in Mito
#import mitosheet
#import pandas as pd
#dat=pd.DataFrame(lc_collection['flux'])
#mitosheet.sheet(dat)  #pass the data frame into mito
#print(lc_collection['flux'].unit)
#print(lc_collection['flux_err'].unit)

In [None]:
if data_method == "CUTOUT":  # Option to export just cutouts round the transit (no detrending)
    # Data file - just the 2 days surrounding each dip (the full light curve
    # slows pyaneti down).
    lc_norm = lc_collection.normalize()
    lc_clean = lc_norm.remove_nans()  # filter once; all three columns share it

    alltime = lc_clean.time.btjd
    allflux = lc_clean.flux
    allfluxerr = lc_clean.flux_err

    # Union of the per-transit windows, so transits closer than two days
    # simply merge.  The loop variable is deliberately not `i`, which holds
    # the orbital inclination elsewhere in the notebook.
    pyaneti_mask = np.zeros(len(alltime), dtype=bool)
    for t_transit in transit_time:
        pyaneti_mask |= (alltime > (t_transit - 1)) & (alltime < (t_transit + 1))

    # Include the first and last data points so the model extends to the full time range
    pyaneti_mask[0] = True
    pyaneti_mask[-1] = True

    # Band label per exported sample: long cadence before LC_max, short otherwise.
    cadence_col = ["LC" if t < LC_max else "SC" for t in alltime[pyaneti_mask]]

    # Export file for pyaneti
    with open(pyaneti_data_file, "w", newline='') as f:
        writer = csv.writer(f, delimiter='\t')
        writer.writerow(['#time', 'flux', 'flux_err', 'band'])
        writer.writerows(zip(alltime[pyaneti_mask],allflux[pyaneti_mask],allfluxerr[pyaneti_mask],cadence_col))  # Comment out for full data
        #writer.writerows(zip(alltime,allflux,allfluxerr))  # Uncomment for full data

    fig, axes = plt.subplots(figsize=(20,7))
    axes.plot(alltime[pyaneti_mask],allflux[pyaneti_mask], marker='o', lw =0)
    plt.show()
    print("Created cutout data file: " + pyaneti_data_file)
elif data_method == "30MIN":  # Option to export a full light curve of 30 minute cadence data if want to compare FFI lcs
    #lc_collection.flux = lc_collection['sap_flux']
    lc_30min = lc_collection.normalize().bin(30/24/60)
    lc_30min.plot(lw = 0, marker = '.', ms = 1)
    lc_clean = lc_30min.remove_nans()  # filter once; all three columns share it

    alltime = lc_clean.time.btjd
    allflux = lc_clean.flux
    allfluxerr = lc_clean.flux_err

    # Export file for pyaneti
    # NOTE(review): the header names a fourth column but only three are
    # written - confirm this is what pyaneti expects for single-band data.
    with open(pyaneti_data_file, "w", newline='') as f:
        writer = csv.writer(f, delimiter='\t')
        writer.writerow(['#time', 'flux', 'flux_err', 'cadence'])
        writer.writerows(zip(alltime,allflux,allfluxerr))

    fig, axes = plt.subplots(figsize=(20,7))
    axes.plot(alltime,allflux, marker='o', lw =0)
    plt.show()
    print("Created 30 minute data file: " + pyaneti_data_file)
elif data_method == "DETREND":  # Option to export detrended LC
    lc_detrend = lc_collection.normalize()
    lc_clean = lc_detrend.remove_nans()  # filter once; all three columns share it

    # Extract the time and fluxes
    alltime = np.array(lc_clean.time.btjd)
    allflux = np.array(lc_clean.flux.value)
    allflux_err = np.array(lc_clean.flux_err.value)

    # to get a good fit you need to adjust these parameters

    # the size of the cut-off window - ideally you want this to be a couple of times the transit duration
    cutoutwindow_width = (transit_length/24)*1.5

    # this is to mask out the transit event (the stellar trend will be fit to the data excluding this data)
    # this window is marked by the blue vertical lines
    transit_mask = transit_length/24

    # this is the order of the polynomial - don't go too high! I recommend no higher than 3.
    polynomial_order = 2

    detrend_pyaneti(TIC_no, transit_time, alltime, allflux, allflux_err,cutoutwindow_width, transit_mask, polynomial_order, save = True)

    # the second panel shows the detrended LC - check that it looks okay. The data is automatically saved and ready for pyaneti.
else:
    print ("Select an data method")
    
# Export full data for plotting later

# Full normalized light curve (NaNs removed), written next to the pyaneti
# output so it can be plotted later.
lc_norm = lc_collection.normalize()
lc_full = lc_norm.remove_nans()
alltime, allflux, allfluxerr = lc_full.time.btjd, lc_full.flux, lc_full.flux_err

# Header and rows go out through a single handle; the result is the same
# tab-separated file.
with open(pyaneti_full_data_file, "w", newline='') as f:
    writer = csv.writer(f, delimiter='\t')
    writer.writerow(['#time', 'flux', 'flux_err'])
    writer.writerows(zip(alltime,allflux,allfluxerr))

#fig, axes = plt.subplots(figsize=(20,7))
#axes.plot(alltime,allflux, marker='o', lw =0)
#plt.show()
print("Created full data file: " + pyaneti_full_data_file)

# Create the Pyaneti command file

In [None]:
# Pyaneti input data file
template = """#Input file for pyaneti
#fname_tr contains the transit data
fname_tr = ['{pyaneti_data_filename}']
'''
TIC {TIC_no}
Sectors {sectors}
{no_transits} Transits {transit_time}
Light Curve {data_method}
'''
#MCMC controls
#the thin factor for the chains
thin_factor = 10
#The number of iterations to be taken into account
#The TOTAL number of iterations for the burn-in phase is thin_factor*niter
niter       = 500
#Number of independent Markov chains for the ensemble sampler
nchains     = 100

#Choose the method that we want to use
# mcmc -> runs the mcmc fit program
# plot -> this option create the plots only if a previus run was done
method = 'mcmc'
#method = 'plot'

#bands
{bands}

#If you want a plot with the seaborn library, is_seaborn_plot has to be True
is_seaborn_plot = False
plot_binned_data = True

#Set this true for multi-band
is_jitter_tr = {is_jitter_tr}

#Define the star parameters to calculate the planet parameters
mstar_mean  = {pya_mass:0.04f}  # {mass_source}
mstar_sigma = {pya_emass:0.04f}
rstar_mean  = {pya_rad:0.04f}  # {radius_source}
rstar_sigma = {pya_erad:0.04f}
tstar_mean  = {pya_Teff:0.04f}  # {Teff_source}
tstar_sigma = {pya_eTeff:0.04f}

#What units do you prefer for your planet parameters?
# earth, jupiter or solar
unit_mass = 'earth'

#If we want posterior, correlation and/or chain plots these options have to be set True
is_plot_posterior    = True
is_plot_correlations = True
is_plot_chains       = False

nplanets = 1

#Are we setting gaussian priors on the semi-major axis based on the stellar parameters?
a_from_kepler = [True]*nplanets

#We want to fit transit and RV 
#For a pure RV fit, fit_tr has to be False
#For a pure TR fit, fit_rv has to be False
#For multi-planet fits fit_rv and fit_tr have the form [True,True,False,...]
#one element for each planet.
fit_rv = [False]*nplanets
fit_tr = [True]*nplanets

#is_ew controls the parametrization sqrt(e)sin w and sqrt(e) cos w
#if True we fit for the parametrization parameters
#if False we fit for e and w
#Default is True
is_ew = False

# Change the next line to False and alter a priors if you want to use a/R* rather than density
sample_stellar_density = {sample_stellar_density}

#Prior section
# f -> fixed value
# u -> Uniform priors
# g -> Gaussian priors
fit_t0 = ['u']*nplanets   #We fit for t0 with uniform priors
fit_P  = ['u']*nplanets   #We fit for P with uniform priors
fit_ew1= ['f']*nplanets   #We fit sqrt(e) sin w, it works only if is_ew = True - Change to u to model eccentricity
fit_ew2= ['f']*nplanets   #We fit sqrt(e) cos w, it works only if is_ew = True - Change to u to model eccentricity
fit_e  = ['f']*nplanets   #We fix e, it works only if is_ew = False
fit_w  = ['f']*nplanets   #We fix w, it works only if is_ew = False
fit_b  = ['u']*nplanets   #We fit the impact factor
fit_a  = ['{fit_a}']*nplanets   #{fit_a_comment}
fit_rp = ['u']*nplanets   #We fit rp with uniform priors
fit_k  = ['u']*nplanets   #We fit k with uniform priors
fit_v0 = 'u'     #We fit systemc velicities with uniform priors
fit_q1 = ['u']*{nbands}     #We fit q1 with uniform priors
fit_q2 = ['u']*{nbands}     #We fit q2 with uniform priors

#Prior ranges for a parameter A
#if 'f' is selected for the parameter A, A is fixed to the one given by min_A
#if 'u' is selected for the parameter A, sets uniform priors between min_A and max_A
#if 'g' is selected for the parameter A, sets gaussian priors with mean min_A and standard deviation max_A

min_t0  = [{transit_0_lower:0.04f}] 
max_t0  = [{transit_0_upper:0.04f}]  
min_P   = [{period_lower}]
max_P   = [{period_upper}]
min_ew1 = [0.0]*nplanets  # Change to -1 if modelling ew
min_ew2 = [0.0]*nplanets  # Change to -1 if modelling ew
max_ew1 = [1.0]*nplanets
max_ew2 = [1.0]*nplanets
min_e   = [0]*nplanets
max_e   = [1]*nplanets
min_w   = [0]*nplanets
max_w   = [2*np.pi]*nplanets
min_a   = [{min_a}]*nplanets
max_a   = [{max_a}]*nplanets
min_b   = [0.0]*nplanets
max_b   = [1.15]*nplanets
min_k   = [0.0]*nplanets
max_k   = [0.001]*nplanets
min_rp  = [{prad_min:0.04f}]
max_rp  = [{prad_max:0.04f}]
min_q1  = [0.0]*{nbands} 
max_q1  = [1.0]*{nbands} 
min_q2  = [0.0]*{nbands} 
max_q2  = [1.0]*{nbands} 

""" 
# Values substituted into the `template` placeholders.
context = dict(
    # Bookkeeping written into the header of the input file
    pyaneti_data_filename=pyaneti_data_filename,
    TIC_no=TIC_no,
    sectors=sectors,
    no_transits=len(transit_time),
    transit_time=transit_time,
    data_method=data_method,
    bands=bands,
    radius_source=radius_source,
    mass_source=mass_source,
    Teff_source=Teff_source,
    is_jitter_tr=is_jitter_tr,
    # Stellar parameters, forced to plain floats for the `:0.04f` format specs
    pya_mass=float(pya_mass),
    pya_emass=float(pya_emass),
    pya_rad=float(pya_rad),
    pya_erad=float(pya_erad),
    pya_Teff=float(pya_Teff),
    pya_eTeff=float(pya_eTeff),
    # Prior configuration
    sample_stellar_density=sample_stellar_density,
    fit_a=fit_a,
    fit_a_comment=fit_a_comment,
    min_a=min_a,
    max_a=max_a,
    # T0 within +/- 0.03 of the first transit, period within +/- 1 %,
    # planet radius between 0.2x and 2x the estimate
    transit_0_lower=transit_time[0]-0.03,
    transit_0_upper=transit_time[0]+0.03,
    period_lower=period*.99,
    period_upper=period*1.01,
    prad_min=r_pl_solar_radius.value*0.2,
    prad_max=r_pl_solar_radius.value*2,
    nbands=nbands,
)
#print(context)

# The `with` block closes the file, so no explicit close is needed.
with open(pyaneti_input_file, "w", newline='') as f:
    f.write(template.format(**context))

print("Created " + pyaneti_input_file)

In [None]:
!gedit $pyaneti_input_file

In [None]:
%%time
!echo Running Pyaneti
!cd $pyaneti_home_dir ; $pyaneti_command

In [None]:
import pyaneti_utils as py
%matplotlib widget
py.display_results(TIC_no,target_out_dir)

In [None]:
#pyaneti_home_dir = "/home/ian/pyaneti"
#!cd $pyaneti_home_dir ; python pyaneti.py 312543349

# Publish in PHT Shared

In [None]:
import os
import shutil
from pathlib import Path

# Publishes the pyaneti output of the current target to the shared Google
# Drive folder, then renames the working directories with the model suffix so
# the next run does not overwrite them.
publish_dir = '/mnt/g/Google Drive/Astrophysics/PHT Shared/{}{}/'.format(TIC_no,model_suffix)
print(publish_dir)
# Save the params `.dat` as `.txt` so that it can be easily viewed on Google Drive.
file_params = Path(target_out_dir, f"{TIC_no}_params.dat")
file_params_txt = Path(target_out_dir, f"{TIC_no}_params.txt")
shutil.copyfile(file_params, file_params_txt)

# Copy `input_fit.py` and input data file to output directory so that it can be easily shared (on Google Drive).
destination = Path(target_out_dir, input_fit_filename)
shutil.copyfile(pyaneti_input_file, destination)
destination = Path(target_out_dir, pyaneti_data_filename)
shutil.copyfile(pyaneti_data_file, destination)

# Publish to PHT Shared.  shutil.copytree replaces distutils' copy_tree
# (distutils was removed in Python 3.12); dirs_exist_ok (Python 3.8+) keeps
# the old merge behaviour if the stale copy could not be removed.
shutil.rmtree(publish_dir, ignore_errors=True)
shutil.copytree(target_out_dir, publish_dir, dirs_exist_ok=True)

# Rename the model directories to preserve them.  With an empty suffix the
# "old copy" would be the live directory itself, so nothing is renamed.
if model_suffix:
    for model_dir in (target_out_dir, target_in_dir):
        shutil.rmtree(model_dir + model_suffix, ignore_errors=True)
        shutil.move(model_dir, model_dir + model_suffix)
else:
    print("No suffix set - model directories left in place")

# Google publish
To embed a Google Drive image in a discussion post, use this URL prefix:
https://drive.google.com/uc?export=view&id=

- Append the image's file ID to the end of the URL.

# Odd-Even Check

In [None]:
def odd_even_phase(lc, period, t0, plot_size = False, same_axes = False, binning = False):
    """Plot odd and even transits phase-folded separately.

    The light curve is normalized, optionally binned, and folded on twice
    the period: once with the epoch at ``t0`` ("odd") and once with the
    epoch at ``t0 + period`` ("even").

    Parameters
    ----------
    lc : light curve offering ``normalize()``, ``bin()``, ``time.value``
        and ``flux.value``.
    period : orbital period, in the units of the light curve's time axis.
    t0 : time of the first (odd) transit.
    plot_size : half-width of the phase range shown; ``False`` shows everything.
    same_axes : ``False`` draws two side-by-side panels, anything else
        overlays both folds on one set of axes.
    binning : ``False`` for no binning, otherwise the value is divided by
        60 and by 24 and handed to ``lc.bin``.
    """
    lc = lc.normalize()
    if binning != False:
        lc = lc.bin(binning/60/24) # you can change the binning factor here if you like

    time = np.asarray(lc.time.value)
    flux = lc.flux.value

    # Folding on the doubled period keeps alternate transits apart.
    double_period = period*2
    half_fold = 0.5*double_period

    def _fold(epoch):
        # Phase in [-0.5, 0.5) with the transit at phase 0.
        return -0.5 + ((time - epoch - half_fold) % double_period) / double_period

    folds = (
        ("ODD", _fold(t0), 'navy'),
        ("EVEN", _fold(t0 + period), 'maroon'),
    )

    if same_axes == False:
        fig, ax = plt.subplots(1,2, figsize = (20,8), sharey = True)

        for panel, (tag, phase, colour) in zip(ax, folds):
            panel.plot(phase, flux, lw = 0, color = colour, marker = '.', alpha =0.4)
            panel.set_xlabel("Phase")
            panel.annotate(tag, (0.3, np.nanmin(flux)), fontsize = 14)
            if plot_size != False:
                panel.set_xlim(-plot_size, plot_size)

        ax[0].set_ylabel("Normalized flux")
        plt.subplots_adjust(wspace=0.02)
        return

    fig, ax = plt.subplots(figsize = (10,8))

    for tag, phase, colour in folds:
        ax.plot(phase, flux, lw = 0, color = colour, marker = '.', alpha =0.4, label = tag.lower())

    ax.set_xlabel("Phase")
    ax.set_ylabel("Normalized flux")

    plt.legend()

    if plot_size != False:
        ax.set_xlim(-plot_size, plot_size)


In [None]:
#period = transits[1] - transits[0]
t0 = transit_time[0]
%matplotlib widget
# Odd/even comparison of the full light curve, zoomed to +/- 0.05 in phase.
plt.style.use('default')
# NOTE(review): binning=True is used as the number 1 inside odd_even_phase
# (True/60/24) - confirm that this bin width is intended.
odd_even_phase(lc_collection, period, t0=t0, plot_size = 0.05, same_axes = False, binning=True)
# Bare name: shows the period used as the cell output.
period

# Periodogram

In [None]:
# Periodogram of the full light curve for periods up to 100 (in the units
# lightkurve uses for the period axis), plotted against period.
pg = lc_collection.to_periodogram(maximum_period=100)
pg.plot(view='period');
# Period of the strongest peak, followed by the whole period grid.
print(pg.period_at_max_power)
print(pg.period)

In [None]:
# Create a model light curve for the highest peak in the periodogram,
# evaluated at the time stamps of the observed light curve
lc_model = pg.model(time=lc_collection.time, frequency=pg.frequency_at_max_power)
# Plot the light curve
axper = lc_collection.plot()
# Plot the model light curve on top (red dashed line)
lc_model.plot(ax=axper, lw=3, ls='--', c='red');

In [None]:
# Remove the signals associated with the 50 highest peaks: each pass divides
# the light curve by the model of the current strongest periodogram peak.
# The loop counter is `_` rather than `i`, because `i` holds the orbital
# inclination elsewhere in the notebook and must not be overwritten.
# NOTE: `pg` and `model` are rebound on every pass, so afterwards `pg` is the
# periodogram of the cleaned light curve.
newlc = lc_collection.copy()
for _ in range(50):
    pg = newlc.to_periodogram()
    model = pg.model(time=newlc.time, frequency=pg.frequency_at_max_power)
    newlc.flux = newlc.flux / model.flux

# Plot the new light curve on top of the original one
axper1 = lc_collection.plot(alpha=.5, label='Original',lw = 0, marker = '.', ms = 1);
newlc.plot(ax=axper1, label='New');

In [None]:
# Folded
# NOTE: this overwrites the notebook-level `period` with a hard-coded value.
period = 8.03
# Fold the cleaned light curve on that period around the first transit and
# overlay a binned version of the same fold.
axperf = newlc.fold(period=period, epoch_time = transit_time[0]).plot(label='Unbinned',lw = 0, marker = '.', ms = 1)
newlc.fold(period=period, epoch_time = transit_time[0]).bin(0.1).plot(ax=axperf, lw=2, label='Binned');

# River Plot

In [None]:
%matplotlib inline
# NOTE: this overwrites the notebook-level `period` with a hard-coded value.
period = 12.0319985
lcl = lc_collection
# Optional flattening before folding (disabled).
#lcl = lcl.flatten(21)
# Fold on the period around the first transit and draw the river plot.
lcl_fold = lcl.fold(period = period, epoch_time = transit_time[0])
lcl_fold.plot_river()

# Plot against known exoplanets

In [None]:
# Check NASA Exoplanet Archive and retrieve planets with similar period and radius (within tolerance of each)
import astropy.units as u
from astropy.time import Time
from astroquery.ipac.nexsci.nasa_exoplanet_archive import NasaExoplanetArchive
import numpy as np

# Candidate period and radius (Jupiter radii) used as the centre of the search box.
period_c =  float(period)
radius_c = r_pl_jup_radius.value

# Fractional half-width of the search box: +/- 25 % in both period and radius.
tolerance = 0.25

query_string = "(pl_orbper > %0.3f and pl_orbper < %0.3f) and (pl_radj > %0.3f and pl_radj < %0.3f)" %(period_c-(period_c*tolerance), period_c+(period_c*tolerance), radius_c-(radius_c*tolerance), radius_c+(radius_c*tolerance))
print(query_string)
planet_data = NasaExoplanetArchive.query_criteria(table="pscomppars", where=query_string, order="hostname")
if len(planet_data) > 0:
    planets = np.array(planet_data)  #extract table data into an array
    # `pd` is imported in the pyaneti export cell, which must have been run first.
    planets_df = pd.DataFrame(planets, columns=["pl_name","disc_year","pl_orbper","pl_trandur","pl_radj","pl_rade","st_rad","st_teff"])  #create a pandas data frame from the array and the headers
    print(planets_df)
else:
    print ("No Planets found")

In [None]:
# Period-radius scatter of the archive planets: marker size follows the host
# star radius, colour the log of its effective temperature.
per=np.array(planet_data["pl_orbper"])
rad=np.array(planet_data["pl_radj"])
st_rad=np.array(planet_data["st_rad"])
teff=np.array(planet_data['st_teff'])

fig_plan, ax_plan = plt.subplots(figsize=(10, 10))
sc = plt.scatter(per, rad, s=st_rad*50, c = np.log(teff), cmap = 'RdYlBu', alpha = 0.9, zorder = -1, edgecolors = 'white', linewidths = 0.1)
sc.set_clim(np.log(2500), np.log(20000))
# The target itself, drawn with a black edge and the same colour scale.
sc_target = plt.scatter(period_c,radius_c,s=R_star*50, c = np.log(TIC_Teff), cmap = 'RdYlBu', alpha = 0.8, zorder = -1, edgecolors = 'black', linewidths = 0.5)
sc_target.set_clim(np.log(2500), np.log(20000))
# Label every planet; the text is offset from the point by the xytext value.
# The names are zipped with the coordinates instead of indexing with `i`,
# which holds the orbital inclination elsewhere in the notebook.
for planet_name, planet_per, planet_rad in zip(planet_data["pl_name"], per, rad):
    ax_plan.annotate(planet_name, (planet_per, planet_rad),color='grey',fontsize=10,textcoords="offset points",xytext=(4,3))
ax_plan.annotate(TIC, (period_c,radius_c),color='grey',fontsize=10,textcoords="offset points",xytext=(4,3))

plt.xlabel('Period (d)')
plt.ylabel('Radius (Rjup)')

# Create star chart round a target

In [None]:
import star_data as star
TIC_Teff, TIC_lum, TIC_rad = star.plot_star_chart(TIC_no)

# HR Diagram

In [None]:
import star_data as star
star.plot_HR(TIC_no, TIC_Teff, TIC_lum, TIC_rad, TOI = False)

In [None]:
import csv
# Export file for pyaneti: header line plus the whole NaN-free light curve,
# written through a single file handle.
lc_no_nans = lc_collection.remove_nans()
with open('{}/{}_data_pyaneti.dat'.format(TIC_no, TIC_no), "w") as f:
    writer = csv.writer(f, delimiter='\t')
    writer.writerow(['#time', 'flux', 'flux_err'])
    # If the dip separations are too small, then don't create cut outs and save the whole dataset
    #if (len(transit_list) > 1) and ((transit_list[1] - transit_list[0]) < 2): # if there are LOTS of transit events on short period (if so it's probably a TOI but let's keep it here as a condition)
    writer.writerows(zip(lc_no_nans.time, lc_no_nans.flux, lc_no_nans.flux_err)) # save all the data

# else create a cut out of the data around the time of the transit events
#else:
#    for transit in transit_list:
#        # save the data 
#        # get rid of nan values first - this is used for the pyaneti code
#        pyaneti_mask = (alltime_ar > (transit - 1)) * (alltime_ar < (transit + 1))

#        with open('{}/{}/{}_data_pyaneti.dat'.format(indir, tic, tic), "a") as f:
#            writer = csv.writer(f, delimiter='\t')
#            writer.writerows(zip(alltime_ar[pyaneti_mask],allflux_ar[pyaneti_mask],allflux_err_ar[pyaneti_mask]))

# Old Calc actions

In [None]:
# Click event to select the TOP and BOTTOM of the Transit 
def onclick(event):
    """Record right clicks that mark the TOP and BOTTOM of the transit.

    Every click updates the globals ``ix`` and ``iy``.  A right click inside
    the axes is appended to ``coords`` and marked with an orange horizontal
    line kept in ``tran``; the first click of a pair removes the previous
    markers, and the handler disconnects itself after the second.  Right
    clicks outside the axes carry no data coordinates and are ignored.
    """
    global ix, iy
    global coords
    ix, iy = event.xdata, event.ydata
    if event.button != 3:  # Only act on a right click
        return
    if ix is None or iy is None:  # clicked outside the axes
        return

    if len(coords) == 0:  # remove any existing markers if this is the first click
        for marker in tran[:2]:
            try:
                marker.remove()
            except Exception:
                pass  # best effort: the marker may already be gone
        tran.clear()

    coords.append((ix, iy))
    tran.append(plt.axhline(iy, color = 'orange', zorder = -1))

    if len(coords) == 2:
        fig.canvas.mpl_disconnect(cid)
            
# Click event to select the START and END of the Transit 
def onclick1(event):
    """Record right clicks that mark the START and END of the transit.

    Every click updates the globals ``ix`` and ``iy``.  A right click inside
    the axes is appended to ``coords`` and marked with a blue vertical line
    kept in ``tran``; the first click of a pair removes the previous
    markers, and the handler disconnects itself after the second.  Right
    clicks outside the axes carry no data coordinates and are ignored.
    """
    global ix, iy
    global coords
    ix, iy = event.xdata, event.ydata
    if event.button != 3:  # Only act on a right click
        return
    if ix is None or iy is None:  # clicked outside the axes
        return

    if len(coords) == 0:  # remove any existing markers if this is the first click
        for marker in tran[:2]:
            try:
                marker.remove()
            except Exception:
                pass  # best effort: the marker may already be gone
        tran.clear()

    coords.append((ix, iy))
    tran.append(plt.axvline(ix, color = 'blue', zorder = -1))

    if len(coords) == 2:
        fig.canvas.mpl_disconnect(cid)
            
# Click event to select the Floor START and END of the Transit
def onclick2(event):
    """Record a right-click as one end (start or end) of the transit floor.

    Meant to be connected to a figure's ``button_press_event``.  Every call
    stores the pointer position in the globals ``ix`` / ``iy``.  A right-click
    inside an axes additionally appends ``(ix, iy)`` to the global ``coords``
    list and draws a grey vertical marker at ``ix``; the marker is kept in the
    global ``tran`` list.  When ``coords`` holds two points the handler
    disconnects the connection stored in the global ``cid`` from ``fig``.

    Parameters
    ----------
    event : matplotlib.backend_bases.MouseEvent
        The button-press event delivered by the canvas.
    """
    global ix, iy
    ix, iy = event.xdata, event.ydata

    if event.button != 3:  # Only act on a right click
        return
    if event.inaxes is None:
        # Outside the axes xdata/ydata are None; storing them would break the
        # arithmetic done on coords afterwards, so ignore the click.
        return

    if not coords:  # first click of a new selection: drop any existing markers
        for marker in tran:
            try:
                marker.remove()
            except (ValueError, NotImplementedError):
                # The marker is no longer attached to an axes; nothing to do.
                pass
        tran.clear()

    coords.append((ix, iy))
    # ix is in the data coordinates of the clicked axes, so draw the marker there.
    tran.append(event.inaxes.axvline(ix, color='grey', linestyle='-', zorder=-1))

    if len(coords) == 2:
        fig.canvas.mpl_disconnect(cid)

In [None]:
# Transit depth: right-click the TOP and BOTTOM of the transit on the plot,
# then run this cell.

# The y value of each of the two recorded clicks marks one edge of the dip.
transit_top = coords[0][1]
transit_bottom = coords[1][1]

# Empty the click list and listen again, so the selection can be redone and
# this cell re-run if we change our mind.
coords.clear()
cid = fig.canvas.mpl_connect('button_press_event', onclick)

# The two clicks can be made in either order, so flip a negative difference.
transit_depth = transit_top - transit_bottom
if transit_depth < 0:
    transit_depth = -transit_depth
print('Transit Depth = %6.4f' % transit_depth)

In [None]:
# Run this once you are happy with the top/bottom selection and BEFORE clicking
# the START and END of the transit: it swaps the click handler over to onclick1.

# Drop the connection registered under the current cid (must happen before cid
# is overwritten below).
fig.canvas.mpl_disconnect(cid)
# Empties the marker list only; this cell does not call remove() on any marker.
tran.clear()

cid = fig.canvas.mpl_connect('button_press_event', onclick1)  # Connect the event

In [None]:
# Transit length: right-click the START and END of the transit on the plot,
# then run this cell.

# The x value of each of the two recorded clicks marks one end of the transit.
transit_start = coords[0][0]
transit_end = coords[1][0]

# Start a fresh click list and listen again, so the selection can be redone
# and this cell re-run if we change our mind.
coords = []
cid = fig.canvas.mpl_connect('button_press_event', onclick1)

# Multiply the x difference by 24 for the value reported in hours; the clicks
# can be made in either order, so flip a negative result.
transit_length = 24 * (transit_end - transit_start)
if transit_length < 0:
    transit_length = -transit_length
print('Transit Length = %6.4f hours' % transit_length)

In [None]:
# Run this once you are happy with the start/end selection and BEFORE clicking
# the START and END of the floor: it swaps the click handler over to onclick2.

# Drop the connection registered under the current cid (must happen before cid
# is overwritten below).
fig.canvas.mpl_disconnect(cid)
# Empties the marker list only; this cell does not call remove() on any marker.
tran.clear()

cid = fig.canvas.mpl_connect('button_press_event', onclick2)  # Connect the event

In [None]:
# Floor length: right-click the START and END of the transit floor on the
# plot, then run this cell.

# The x value of each of the two recorded clicks marks one end of the floor.
floor_start = coords[0][0]
floor_end = coords[1][0]

# Start a fresh click list and listen again, so the selection can be redone
# and this cell re-run if we change our mind.
coords = []
cid = fig.canvas.mpl_connect('button_press_event', onclick2)

# Multiply the x difference by 24 for the value reported in hours; the clicks
# can be made in either order, so flip a negative result.
floor_length = 24 * (floor_end - floor_start)
if floor_length < 0:
    floor_length = -floor_length
print('Floor Length = %6.4f hours' % floor_length)