## Initialize a few things

In [None]:
import emcee
import corner
import speclite as speclite; from speclite import filters
from tqdm import tqdm
from matplotlib import cm
from matplotlib.artist import Artist
from chromatic import *
from scipy.optimize import minimize
from scipy.optimize import curve_fit
from PyAstronomy import pyasl
from specutils.spectra import Spectrum1D, SpectralRegion
from specutils.fitting import fit_generic_continuum

In [None]:
# Load the two-column transmission table from disk.  The first column is
# tagged as wavelength in nanometres; the second is kept as a plain array
# (presumably the transmission values -- confirm against the data file).
transmission_data = np.loadtxt('../data/transmission_comp.txt')
mol_wave, mol_data = (
    transmission_data[:, 0] * u.nm,
    transmission_data[:, 1],
)

In [None]:
# Shared wavelength grid (400 points, 3800-10000 Angstrom) on which the
# filter response curves are sampled.
bandpass = np.linspace(3800., 10000., 400) * u.angstrom

# Load every 'sdss2010-*' filter and sample three of them on the grid.
# Indices 1, 2, 3 are used as the g, r and i bands (this assumes the
# wildcard load returns the filters in u, g, r, i, z order -- confirm).
sdss_responses = speclite.filters.load_filters('sdss2010-*')
response_g, response_r, response_i = (
    sdss_responses[band_index].interpolator(bandpass)
    for band_index in (1, 2, 3)
)

## Photometric variability model

In [None]:
def photometric_variability_model(parameters=None,plot=False,
                                  samples_exist=False,samples=None,
                                  label=None,title=None,
                                  **kwargs):
    """Evaluate the spot variability model in three SDSS bands.

    A two-component photosphere (spot + ambient) is built from PHOENIX
    spectra, the wavelength-dependent variability semi-amplitude is
    computed from the spot contrast, and that amplitude is averaged over
    the module-level filter responses ``response_g``, ``response_r`` and
    ``response_i`` (weighted by the model spectrum).  The three band
    amplitudes are compared with the hard-coded photometric amplitudes
    defined in the body.

    Parameters
    ----------
    parameters : sequence of 4 floats
        ``(f_spot, df_spot, T_spot, T_amb)``: spot covering fraction, change
        in spot covering fraction, spot temperature and ambient temperature.
        Temperatures are truncated to integers before the spectra are
        requested.
    plot : bool
        If True, draw the data, the filter responses and (optionally)
        posterior draws, save the figure to
        ``../figs/{label}_photmodel.png`` and show it.
    samples_exist : bool
        If True (and ``plot`` is True), overplot 1000 random posterior draws
        taken from ``samples``.
    samples : sequence of 4 arrays, optional
        ``(fspot, dfspot, Tspot, Tamb)`` posterior samples; only read when
        ``plot`` and ``samples_exist`` are both True.
    label : str, optional
        Used in the output file name of the figure.
    title : str, optional
        Figure title.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    model : numpy.ndarray
        Model semi-amplitudes in the g, r and i bands.
    ln_like : float
        Gaussian log-likelihood of the data given the model.
    chisq : float
        Chi-squared of the data given the model.
    """
    f_spot, df_spot, T_spot, T_amb = parameters

    # Same grid as the module-level ``bandpass`` on which the filter
    # responses were sampled; the two must stay in sync.
    bandpass = np.linspace(3800., 10000., 400) * u.angstrom

    # Element [1] of the returned pair is used as the flux
    # (element [0] is presumably the wavelength grid).
    spotflux = get_phoenix_photons(temperature=int(T_spot), wavelength=bandpass,
                                   logg=4.52, metallicity=0.12)[1]
    ambflux = get_phoenix_photons(temperature=int(T_amb), wavelength=bandpass,
                                  logg=4.52, metallicity=0.12)[1]
    this_model_spectrum = f_spot*spotflux + (1.0-f_spot)*ambflux

    # Uniform grid spacing; it cancels in the ratios below but is kept so the
    # sums read as integrals.
    d_lambda = (bandpass[1]-bandpass[0])
    contrast = 1.-(spotflux/ambflux)
    ds_over_s = -df_spot * ( contrast / ( 1.-f_spot * contrast ) )
    semi_amplitude = np.abs(ds_over_s)

    # Spectrum-weighted mean of the semi-amplitude through each filter
    # (order: g, r, i).  NaNs in the spectra are ignored by nansum.
    model = np.array([
        np.nansum(semi_amplitude*this_model_spectrum*response*d_lambda)
        / np.nansum(this_model_spectrum*response*d_lambda)
        for response in (response_g, response_r, response_i)
    ])

    model_coords = [4750,6200,7550]  # x positions used for the g, r, i data points
    phot_data = np.array([0.076,0.071,0.041])
    phot_errs = np.array([0.006,0.006,0.007])*1.25  # quoted errors scaled by 1.25

    chisq = np.nansum((phot_data - model)**2./(phot_errs)**2.)
    # Gaussian log-normalisation, sum(ln(1/sqrt(2*pi*sigma^2))).  It is a
    # constant for fixed errors, so it shifts ln_like without changing the
    # shape of the posterior.
    log_norm = -0.5*np.sum(np.log(2.*np.pi*phot_errs**2.))
    ln_like = (log_norm - 0.5*chisq)

    if plot:
        fig, ax1 = plt.subplots(figsize=(6,4))
        ax1.set_title(f'{title}',fontsize=16,loc='left')
        ax1.set_xlabel(r'Wavelength $\AA$',fontsize=20)
        ax1.set_ylabel(r'$\frac{\Delta S(\lambda)}{S_{\rm avg}}$',fontsize=22)
        if samples_exist:
            fspot_sam, dfspot_sam, Tspot_sam, Tamb_sam = samples
            n_samples = len(Tamb_sam)
            for _ in range(1000):
                # randint's upper bound is exclusive, so pass the full length
                # to make every posterior sample (including the last) drawable.
                i = np.random.randint(n_samples)
                ds_spot = get_phoenix_photons(temperature=Tspot_sam[i], wavelength=bandpass,
                                              logg=4.52, metallicity=0.12)[1]
                ds_amb = get_phoenix_photons(temperature=Tamb_sam[i], wavelength=bandpass,
                                             logg=4.52, metallicity=0.12)[1]
                _contrast = 1.-(ds_spot/ds_amb)
                _ds_over_s = -dfspot_sam[i] * ( _contrast / ( 1.-fspot_sam[i] * _contrast ) )
                _semi_amplitude = np.abs(_ds_over_s)
                ax1.plot(bandpass, _semi_amplitude, zorder=0, alpha=0.02, color='deepskyblue')

        ax1.errorbar(model_coords,phot_data,yerr=phot_errs,color='black',label='F21',fmt='none',zorder=1000,linewidth=2)

        ax1.set_xlim(3800,8500)
        ax1.set_ylim(0,0.11)
        ax1.legend(loc='upper right')

        ax2 = ax1.twinx()  # second y-axis sharing the same x-axis, for the filter curves
        ax2.fill_between(bandpass.value, 0, response_g,color='orange',
                         zorder=-100,label='SDSS g filter response',alpha=0.3)
        ax2.fill_between(bandpass.value, 0, response_r,color='teal',
                         zorder=-100,label='SDSS r filter response',alpha=0.3)
        ax2.fill_between(bandpass.value, 0, response_i,color='purple',
                         zorder=-100,label='SDSS i filter response',alpha=0.3)
        ax2.set_ylabel('Filter Response',fontsize=20)

        ax2.legend(loc='lower left')

        plt.savefig(f'../figs/{label}_photmodel.png',dpi=200)
        plt.show()

    return model, ln_like, chisq

In [None]:
def plot_results(samples=None,nsteps=None,label=None,title=None,
                 plot_variability = False,
                 plot_samples = False,
                 plot_corner = False,
                 plot_contamination= False,
                 **kwargs):
    """Summarise a set of posterior samples and optionally make figures.

    Always prints the median effective temperature with its upper/lower
    1-sigma offsets, and the median photometric chi-squared over 1000 random
    posterior draws.  Each ``plot_*`` flag enables one figure, which is
    saved under ``../figs/`` with ``label`` as file-name prefix and shown.

    Parameters
    ----------
    samples : array of shape (4, n_samples)
        Rows are ``fspot``, ``dfspot``, ``Tspot`` and ``Tamb`` samples.  Must
        support ``.T`` (it is handed to ``corner.corner`` transposed).
    nsteps : int, optional
        Accepted for call compatibility; not used in the body.
    label : str, optional
        Prefix of the saved figure file names.
    title : str, optional
        Title used on the variability and sample-scatter figures.
    plot_variability : bool
        Plot the photometric variability model with posterior draws.
    plot_samples : bool
        Scatter plots of the samples coloured by log(Tspot/Tamb).
    plot_corner : bool
        Corner plot of the four parameters (also updates
        ``plt.rcParams['axes.labelsize']`` globally).
    plot_contamination : bool
        Spot-contamination spectra for two transit depths.
    **kwargs
        Accepted and ignored.
    """
    fspot_sam, dfspot_sam, Tspot_sam, Tamb_sam = samples
    n_samples = len(Tamb_sam)

    percentiles = [15.9, 50., 84.1]  # central 1-sigma interval around the median
    sig1_fspot = np.percentile(fspot_sam, percentiles)
    sig1_dfspot = np.percentile(dfspot_sam, percentiles)
    sig1_Tspot = np.percentile(Tspot_sam, percentiles)
    sig1_Tamb = np.percentile(Tamb_sam, percentiles)

    # Flux-weighted effective temperature of the two-component surface.
    famb_sam = 1.0 - fspot_sam
    Teff_sam = (fspot_sam*Tspot_sam**4 + famb_sam*Tamb_sam**4)**(1/4)
    sig1_Teff = np.percentile(Teff_sam, percentiles)

    # Per-parameter medians, used as the reference parameter set.
    best_params = np.array([sig1_fspot[1], sig1_dfspot[1],sig1_Tspot[1], sig1_Tamb[1]])
    variable_names = [r'$f_{\rm{spot}}$',r'$\Delta f_{\rm{spot}}$',r'T$_{\rm{spot}}$',r'T$_{\rm{amb}}$']

    print( 'Teff=',int(sig1_Teff[1]),int(sig1_Teff[2]-sig1_Teff[1]),int(sig1_Teff[1]-sig1_Teff[0]) )

    # Chi-squared of 1000 random posterior draws.  randint's upper bound is
    # exclusive, so the full length is passed to keep the last sample drawable.
    phot_chisq = []
    for i in np.random.randint(n_samples, size=1000):
        params = [fspot_sam[i],dfspot_sam[i],Tspot_sam[i],Tamb_sam[i]]
        phot_chisq.append(photometric_variability_model(parameters=params)[2])

    print(f'Phot Chisq={np.median(phot_chisq):.1f} (dof=1)')

    if plot_variability:
        photometric_variability_model(parameters=best_params,
                                      plot=True,samples_exist=True,
                                      samples=samples,label=label,title=title)

    if plot_corner:
        params = {'axes.labelsize': 'xx-large'}
        plt.rcParams.update(params)
        rng = 0.9995  # fraction of the samples each corner-plot axis must enclose
        fig = corner.corner(
            samples.T,show_titles=True, labels=variable_names,
            range=[rng,rng,rng,rng],
            smooth=1,quantiles=(0.16, 0.84),
            fill_contours=True, plot_datapoints=False,title_kwargs={"fontsize": 16},title_fmt='.3f',
            hist_kwargs={"linewidth": 2.5},levels=[(1-np.exp(-0.5)),(1-np.exp(-2)),(1-np.exp(-4.5))]
        )
        plt.savefig(f'../figs/{label}_corner.png',dpi=200)
        plt.show()
        plt.close()

    if plot_samples:
        fig, axs = plt.subplots(2,1,figsize=(4,4),sharex=True)
        fig.suptitle(f'{title}',fontsize=16)
        ax1, ax2 = axs[0], axs[1]
        colmap = 'hot'
        log_temp_ratio = np.log(Tspot_sam/Tamb_sam)  # colour scale shared by both panels

        ax1.scatter(dfspot_sam,Tspot_sam,c=log_temp_ratio,
                    cmap=colmap,alpha=0.05,edgecolor=None,vmin=-0.4,vmax=-0.01,s=0.5)
        ax1.set_ylabel(r'T$_{\rm{spot}}$',fontsize=20)
        ax1.set_ylim(2300,3800)

        bottom_plot = ax2.scatter(dfspot_sam,fspot_sam,
                                  c=log_temp_ratio,cmap=colmap,
                                  alpha=0.05,edgecolor=None,vmin=-0.4,vmax=-0.01,s=0.5)
        ax2.set_xlim(0.0,0.25)
        ax2.set_xlabel(r'$\Delta f_{\rm{spot}}$',fontsize=20)

        ax2.set_ylabel(r'$f_{\rm{spot}}$',fontsize=20)
        ax2.set_ylim(0,1.0)
        fig.colorbar(bottom_plot, ax=axs[:],label=r'Log($\rm T_{spot}/T_{amb}$)')

        plt.savefig(f'../figs/{label}_samples.png',dpi=200)
        plt.show()
        plt.close()

    if plot_contamination:

        fig, axs = plt.subplots(1,2,figsize=(8,3),sharey=True)
        ax1, ax2 = axs[0], axs[1]

        # Transit depths: the square of 0.0433 for AU Mic b and of 0.0313 for
        # AU Mic c (values presumably Rp/R* -- confirm against the source).
        transit_depth_b = (0.0433)**2.
        transit_depth_c = (0.0313)**2.

        depths = [transit_depth_b,transit_depth_c]
        titles = ['AU Mic b','AU Mic c']

        # Plot contamination spectra for random draws from the posterior.
        wavelength=np.linspace(0.5,5,200)*u.micron
        ax1.set_ylabel(r'$\Delta \rm D_{spot}$ (ppm)',fontsize=15)
        ax1.set_xlabel(r'Wavelength ($\mu$m)',fontsize=15)
        ax2.set_xlabel(r'Wavelength ($\mu$m)',fontsize=15)
        ax1.set_ylim(0,1400)
        ax1.set_xlim(0.5,5)
        ax2.set_xlim(0.5,5)
        ax2.axhline(10,color='red',label='Est. Atm Signal at 1 Scale Height',linestyle='--',zorder=1000)
        ax1.axhline(52,color='red',label='Est. Atm Signal at 1 Scale Height',linestyle='--',zorder=1000)

        for ax, depth, panel_title in zip([ax1,ax2], depths, titles):
            depth_factors = []
            model_wave = None
            for _ in range(500):
                j = np.random.randint(n_samples)  # exclusive upper bound: full length
                # NOTE(review): metallicity=0.0 here, whereas
                # photometric_variability_model uses metallicity=0.12 --
                # confirm which value is intended.
                s_spot = get_phoenix_photons(temperature=Tspot_sam[j], wavelength=wavelength,
                                             logg=4.52, metallicity=0.0)
                s_amb = get_phoenix_photons(temperature=Tamb_sam[j], wavelength=wavelength,
                                            logg=4.52, metallicity=0.0)
                model_wave = s_spot[0]  # x-axis values returned alongside the flux
                flux_ratio = s_spot[1]/s_amb[1]

                # Multiplicative change of the transit depth caused by
                # unocculted spots: 1 / ((1 - f_spot) + f_spot * F_spot/F_amb).
                depth_factor = 1. / ((1. - fspot_sam[j]) + fspot_sam[j] * flux_ratio)
                depth_factors.append(depth_factor)

                ax.plot(model_wave,(depth_factor-1)*depth*1e6,color = 'k',alpha=0.05,zorder=100)

            median_depth_factor = np.median(depth_factors,axis=0)
            ax.plot(model_wave,(median_depth_factor-1)*depth*1e6,color = 'turquoise',label='Median contamination model',alpha=1,zorder=10000)
            ax.set_title(f'{panel_title}',fontsize=14)

        ax2.legend(loc = 'upper right',fontsize=9)
        plt.savefig(f'../figs/{label}_contamination.png',dpi=200)
        plt.show()
        plt.close()

## Define the log probability

In [None]:
def lnprob(parameters=None,**kwargs):
    """Log-probability of the spot model, for use as the emcee target.

    Flat priors are applied: ``0 <= f_spot <= 1``, ``0 <= df_spot <= f_spot``
    and ``2300 <= T_spot <= T_amb <= 12000``.  Inside the allowed region the
    log-likelihood is the sum of the terms enabled by the module-level flags
    ``do_Photometry`` (photometric variability) and ``do_Teff`` (Gaussian
    constraint of 3650 +/- 100 on the flux-weighted effective temperature).

    Parameters
    ----------
    parameters : sequence of 4 floats
        ``(f_spot, df_spot, T_spot, T_amb)``.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    float
        The log-probability, or ``-np.inf`` outside the prior bounds.
    """
    f_spot,df_spot,T_spot,T_amb = parameters

    inside_prior = (
        (0.0<=f_spot<=1.0)
        and (0.0<=df_spot<=f_spot)
        and (12000>=T_amb>=T_spot>=2300.0)
    )
    if not inside_prior:
        return -np.inf

    ln_like = 0.0

    if do_Photometry:
        ln_like += photometric_variability_model(parameters=parameters)[1]

    if do_Teff:
        Teff_obs, Teff_err = 3650., 100.
        f_amb = 1.0 - f_spot
        Teff_model = (f_spot*(T_spot**4.) + f_amb*(T_amb**4.) )**(1./4.)
        chisq_Teff = (Teff_obs - Teff_model)**2./(Teff_err)**2.
        # Gaussian log-normalisation ln(1/sqrt(2*pi*sigma^2)); a constant
        # offset that does not change the shape of the posterior.
        log_norm_Teff = -0.5*np.log(2.*np.pi*Teff_err**2.)
        ln_like += (log_norm_Teff - 0.5*chisq_Teff)

    return ln_like

## Make a wrapper for the sampler

In [None]:
def do_mcmc(label='oops you didnt label your samples :sadface:',
            nsteps=100,burnin=25,ndim=4,nwalkers=100,**kwargs):
    """Run an emcee ensemble sampler on ``lnprob`` and return the samples.

    Walkers are started from uniform random draws in the ranges hard-coded
    below.  The chain is stored in ``../data/samples/{label}.h5``; any chain
    already stored in that backend is reset first.  The progress bar is
    controlled by the module-level ``progress`` variable.

    Parameters
    ----------
    label : str
        Base name of the HDF5 file that stores the chain.
    nsteps : int
        Number of steps per walker.
    burnin : int
        Number of initial steps discarded from the returned samples.
    ndim : int
        Number of parameters.  The initial positions always have four
        columns, so values other than 4 are not expected to work.
    nwalkers : int
        Number of walkers.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    numpy.ndarray
        Post-burn-in samples with shape ``(ndim, (nsteps - burnin) * nwalkers)``,
        ordered walker by walker.

    Raises
    ------
    ValueError
        If ``burnin`` is negative or not smaller than ``nsteps``.
    """
    if not 0 <= burnin < nsteps:
        raise ValueError(
            f'burnin must satisfy 0 <= burnin < nsteps (got burnin={burnin}, nsteps={nsteps})'
        )

    # Initial walker positions
    fspot_init = np.random.uniform(0.1, 0.45, nwalkers)
    dfspot_init = np.random.uniform(0.01, 0.1, nwalkers)
    Tspot_init = np.random.uniform(2500, 3400, nwalkers)
    Tamb_init = np.random.uniform(3600, 4100, nwalkers)
    p0 = np.transpose([fspot_init,dfspot_init, Tspot_init, Tamb_init])

    # Set up file saving for the samples
    filename = f"../data/samples/{label}.h5"
    backend = emcee.backends.HDFBackend(filename)
    backend.reset(nwalkers, ndim)

    # Initialize and run the sampler
    sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, backend=backend)
    sampler.run_mcmc(p0, nsteps,store=True, progress=progress)

    # chain has shape (nsteps - burnin, nwalkers, ndim).  Swapping the first
    # two axes before flattening keeps the walker-by-walker ordering of the
    # deprecated ``sampler.chain`` property.
    chain = sampler.get_chain(discard=burnin)
    samples = np.swapaxes(chain, 0, 1).reshape((-1, ndim)).T

    # Integrated autocorrelation time per parameter, estimated from the
    # (step, walker, parameter) chain.  quiet=True turns the "chain too short"
    # error into a warning so a finished run still returns its samples.
    tau = np.atleast_1d(emcee.autocorr.integrated_time(chain, quiet=True))
    for tau_f in tau:
        if np.isfinite(tau_f) and tau_f > 0:
            print('(Nsteps-burnin)*nwalkers/tau=',int((nsteps-burnin)*nwalkers/tau_f))
        else:
            print('(Nsteps-burnin)*nwalkers/tau=','undefined (autocorrelation time could not be estimated)')

    return samples

## Run emcee on the variability model

In [None]:
# Likelihood terms to include; lnprob reads these as globals.
do_Photometry, do_Teff = True, True

# Run configuration; do_mcmc reads `progress` as a global.
nsteps = 2000
progress = True
modeltype = 'Variability Model'

_label = f'{modeltype}_{nsteps}steps'

# Run the MCMC, discarding the first quarter of the steps as burn-in.
print('beginning MCMC..')
samples = do_mcmc(
    label=_label,
    nsteps=nsteps,
    burnin=int(0.25*nsteps),
)

# Read the stored chain back from disk and arrange it as (4, n_samples).
reader = emcee.backends.HDFBackend(f'../data/samples/{_label}.h5')
sampler = reader.get_chain(discard=int(0.25*nsteps), flat=True)
photsamples = np.transpose(sampler.reshape((-1, 4)))

# Print the summary statistics; all figures are switched off.
plot_results(
    samples=photsamples,
    nsteps=nsteps,
    label=_label,
    title=r'T$_{\rm eff}$ & Variability',
    plot_variability=False,
    plot_samples=False,
    plot_corner=False,
    plot_contamination=False,
)

In [None]:
# nsteps = 500
# modeltype = 'spectraldecomp'
# visit = 'jointvisit'
# progress = False
# do_Photometry=False
# do_Teff = True
# model_spectra_by_order=True
# model_all_good_orders=False

# # orders_to_model = all_orders
# # orders_to_model = [73,76,77]

# for order in tqdm(orders_to_model):
    
#     _label = f'{visit}_3T_{modeltype}_{order}_{nsteps}steps'

#     # Run the MCMC
#     print('beginning MCMC..')
#     samples = do_mcmc(label = _label,
#                       nsteps = nsteps, burnin = int(0.25*nsteps))