In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
from __future__ import print_function
from datetime import date, datetime, timedelta, time
import pandas as pd
import seaborn
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import gridspec
import numpy as np

In [None]:
import scipy
from scipy.stats import multivariate_normal, norm
import scipy.optimize as op
import emcee
import triangle

In [None]:
import simtransient as st
import simtransient.utils as stutils
from simtransient import hammer
from simtransient.models.supernova import Sn1aOpticalEnsemble, Sn2OpticalEnsemble
from simtransient.modelrun import ModelRun

In [None]:
import math

In [None]:
# Commence plot tweaking!
# Global sizes shared by the figures below.
# Alternative 'Paper' sizes: bigfontsize=20, labelfontsize=18, tickfontsize=18.
bigfontsize = 40
labelfontsize = 35
tickfontsize = 30
linewidth = 9

# Push the sizes into matplotlib's global defaults, one key at a time.
plt.rcParams['font.size'] = bigfontsize
plt.rcParams['axes.labelsize'] = labelfontsize
plt.rcParams['xtick.labelsize'] = tickfontsize
plt.rcParams['ytick.labelsize'] = tickfontsize
# The legend deliberately shares the tick-label size.
plt.rcParams['legend.fontsize'] = tickfontsize

In [None]:


# Sequential palettes with seven shades each ...
rpal, bpal, gpal, orpal = (
    seaborn.color_palette(cmap_name, 7)
    for cmap_name in ('Reds', 'Blues', 'Greens', 'YlOrBr')
)
# ... and a six-shade grey ramp.
kpal = seaborn.color_palette('Greys', 6)

# Seaborn's named palettes.
darkpal = seaborn.color_palette('dark')
mutepal = seaborn.color_palette('muted')
deeppal = seaborn.color_palette('deep')

# Draw a swatch strip for every palette so colours can be picked by eye.
for pal in (rpal, bpal, gpal, orpal, kpal, darkpal, mutepal, deeppal):
    seaborn.palplot(pal)

In [None]:
# Colour roles handed to plot_forecast via its `palette` argument.
pal1 = dict(
    trace=bpal[3],
    data=gpal[-1],
    true=kpal[-2],
    forecast='k',
)
# Second palette: identical to pal1 except for a red trace colour.
pal2 = dict(pal1, trace=rpal[3])

In [None]:
# Value passed as `threads=` to ModelRun.get_sampler in the cells below.
nthreads=4

First: Display the model classes we'll be using:

In [None]:
# Keyword arguments shared by both sampler constructions.
emcee_kwargs = dict(threads=nthreads)

# Model runs built without any observed data, one per ensemble type.
priorrun1 = ModelRun(ensemble=Sn1aOpticalEnsemble())
sampler1 = priorrun1.get_sampler(**emcee_kwargs)

priorrun2 = ModelRun(ensemble=Sn2OpticalEnsemble())
sampler2 = priorrun2.get_sampler(**emcee_kwargs)

In [None]:
%%capture
# Run both prior-only samplers; the %%capture magic above swallows the cell output.
# NOTE(review): 300 is presumably the number of sampler steps -- confirm against ModelRun.run.
priorrun1.run(sampler1,300)
priorrun2.run(sampler2,300)
print()

In [None]:
# Side-by-side figure of the two prior-only model runs, saved as 'prior_models.pdf'.
# NOTE(review): seaborn.set_context may override the font sizes set through plt.rcParams
# earlier in the notebook -- confirm that this is intended.
seaborn.set_context('poster')
# Time grid handed to plot_forecast.
tsteps=np.linspace(-30,80,500)
# Only referenced by the commented-out t_forecast arguments below.
t_forecast=50
# Curve built from the prior 'mu' parameters with t0 fixed to 0.
# NOTE(review): `maxprior` is not used anywhere else in this cell -- possibly dead code.
mr=priorrun1
maxprior_pars = mr.ensemble.gauss_pars.T.mu.copy()
maxprior_pars['t0']=0
maxprior = mr.ensemble.get_curve(**maxprior_pars)


# Two equal-width panels: grid columns 0-3 on the left, 4-7 on the right.
gs = gridspec.GridSpec(2, 8)# width_ratios=[1, 1])
axes = plt.subplot(gs[:,:4]),plt.subplot(gs[:,4:])

# Left panel: Type 1 run. The second element of `axes` is passed as None.
priorrun1.plot_forecast(tsteps,
#                        t_forecast=t_forecast, 
                axes=(axes[0],None), 
                       palette=pal1
                 )

# Right panel: Type 2 run, drawn with the second palette.
priorrun2.plot_forecast(tsteps, 
#                   t_forecast=t_forecast,
                  palette=pal2,
                  axes=(axes[1],None), 
                 )

axes[0].set_ylabel('Flux')
# The right panel has its y tick labels blanked.
axes[1].set_yticklabels([])
# NOTE(review): this limit is overwritten by set_ylim(0,27) in the loop just below.
axes[0].set_ylim(axes[1].get_ylim())
for ax in axes:
    ax.set_xlabel('Time')
    ax.set_ylim(0,27)
plt.tight_layout(pad=3)
axes[0].set_title('Type 1')
axes[1].set_title('Type 2')
# NOTE(review): tight_layout is called a second time here, and the figure is only
# resized afterwards -- the layout may not match the final 12x5 inch size.
plt.tight_layout()
fig = plt.gcf()
fig.set_size_inches(12, 5)
plt.savefig('prior_models.pdf')


Simulate some datapoints:

In [None]:
# Scale of the Gaussian noise added to the simulated observations.
obs_sigma=1
# Thresholds expressed as multiples of obs_sigma:
# a point above detection_thresh*obs_sigma counts as a detection ...
detection_thresh = 4.
# ... and points above analysis_thresh*obs_sigma are kept for analysis.
analysis_thresh = 2.5



In [None]:
# Ensemble used to generate the simulated ('true') transient; display its gauss_pars table.
true_ensemble = Sn1aOpticalEnsemble()
true_ensemble.gauss_pars

In [None]:
# Multivariate normal over the ensemble parameters: mean is the 'mu' row of gauss_pars,
# covariance is the ensemble's gauss_cov.
gpar_hypers_rv = multivariate_normal(mean=true_ensemble.gauss_pars.loc['mu'], 
                                     cov=true_ensemble.gauss_cov)

In [None]:
#Randomly generate some plausible parameters for the intrinsic lightcurve function
# The Series is built directly from the draw so that its dtype comes from the sampled
# values; an index-only Series has a pandas-version-dependent default dtype.
true_gpars = pd.Series(data=gpar_hypers_rv.rvs(),
                       index=true_ensemble.gauss_pars.keys())
true_t0 = 5 #+np.random.random()*5
true_pars = true_gpars.copy()
true_pars['t0']=true_t0

#Then ditch them and use some we made earlier that happen to produce pretty results!
# Item assignment is used rather than attribute assignment: attribute-style setting only
# updates an entry whose label already exists, and otherwise silently sets a plain
# Python attribute instead of a Series entry.
true_pars['a'] = 15.053480
true_pars['rise_tau'] = 2.803343
true_pars['decay_tau'] = 12.711032
true_pars

In [None]:
#Intrinsic lightcurve of our transient:
# Built by the ensemble from the parameters in true_pars (including t0); called as
# true_curve(t) in the cells below.
true_curve = true_ensemble.get_curve(**true_pars)

In [None]:
#OK, now generate some data:
tstep = 3.5 # Represents the cadence of our transient survey
# Observation epochs from -30 up to (but excluding) 50, one every `tstep`.
sim_epochs = np.arange(-30., 50, tstep)

# Fixed seed so the noise realisation is repeatable.
rstate = np.random.RandomState(1)
noise_offsets = rstate.normal(scale=obs_sigma, size=len(sim_epochs))

# Noisy samples of the true curve, indexed by epoch.
sim_data = pd.Series(true_curve(sim_epochs) + noise_offsets,
                     index=sim_epochs)

In [None]:
# Flux level above which a simulated point counts as a detection.
threshold = obs_sigma*detection_thresh
print("Thresh:", threshold)

# Points above the detection threshold, and above the lower analysis threshold.
detectable = sim_data[sim_data > threshold]
monitorable = sim_data[sim_data > obs_sigma*analysis_thresh]

# Keep analysis-grade points from the epoch of the first detection onwards.
first_detection_epoch = detectable.index[0]
usable_data = monitorable[first_detection_epoch:]

In [None]:
#Determine early detections:
# Take the first n_data_epochs points of usable_data as the 'observed' data.
n_data_epochs = 2
obs_data = usable_data.iloc[:n_data_epochs]
# obs_data = usable_data.iloc[::2].iloc[:3]
obs_data

OK! Plot the initial datapoints:

In [None]:
%%capture 
#Don't actually use this plot in-poster - skip to fitted datapoints
# Overview of the true curve, all simulated points, both thresholds and the fitted points.
# The %%capture magic above swallows the cell output.
currentpal = deeppal

# Finer time grid for this plot; this replaces the notebook-level `tsteps`.
tsteps= np.linspace(-30,50, 1000)
plt.plot(tsteps, true_curve(tsteps),c='y', ls='--',label='True')

# Every simulated observation, whether or not it passes a threshold.
plt.scatter(sim_data.index, sim_data,c=currentpal[2])

plt.axhline(obs_sigma*detection_thresh, ls='--', label='Detection')
plt.axhline(obs_sigma*analysis_thresh, ls='-.', label='Analysis')

# plt.scatter(usable_data.index, usable_data,color=current_palette[1],s=55)
st.plot.curve.graded_errorbar(usable_data,obs_sigma,alpha=0.5)

#Data we'll be using for fitting:
plt.scatter(obs_data.index,obs_data,
            color=currentpal[-1],s=160, lw=3, marker='x',
           label='Fitted data')

# plt.yscale('log')
# plt.axhline(true_amp, ls=':')

# plt.axvline(true_t0, ls='--')
# y-range derived from the extremes of usable_data, with some headroom at the top.
plt.ylim(-1.05*np.abs(np.min(usable_data)),1.1*np.max(usable_data+2*obs_sigma))
plt.legend(loc='best')
plt.gcf().suptitle('Model, observables, detections')
#plt.savefig('data.png')

Try fitting the first two datapoints with both model types (Type 1 and Type 2):

In [None]:
# Fit both model types to the first two usable datapoints.
n_data_epochs = 2
obs_data = usable_data.iloc[:n_data_epochs]

# Ensemble classes keyed by model name; each is instantiated per run below.
model_set = {'type1': Sn1aOpticalEnsemble,
             'type2': Sn2OpticalEnsemble}
model_runs = {}

for model_name, model_ensemble in model_set.items():
    print("Running MCMC for ", model_name)
    mr = ModelRun(
        ensemble=model_ensemble(),
        obs_data=obs_data,
        obs_sigma=obs_sigma,
        use_pt=False,
    )
    # fit_data is only called when there are observations to fit.
    if obs_data is not None:
        mr.fit_data()
    sampler = mr.get_sampler(threads=nthreads)
    mr.run(sampler, 500)
    # Keep the finished run so later cells can look it up by model name.
    model_runs[model_name] = mr

In [None]:
# Short handles for the two fitted runs.
mr1, mr2 = model_runs['type1'], model_runs['type2']

In [None]:
# Plot the observed points with the Type 1 ML and MAP curves; saved as '2dpt_t1_bestfits.pdf'.
currentpal = deeppal
#Plot the best-fit results
# NOTE(review): data_color is only referenced by commented-out code in this cell.
data_color = currentpal[1]
#Observed:
st.plot.curve.graded_errorbar(obs_data, obs_sigma, color=gpal[6], alpha=0.8, label='Observed', zorder=10,ms=25)

# Type 1 curves taken from the run's ml_curve / map_curve callables.
plt.plot(tsteps,mr1.ml_curve(tsteps), ls='-.',lw=linewidth, label='T1 ML fit',color=bpal[4])
plt.plot(tsteps,mr1.map_curve(tsteps), ls='-', lw=linewidth, label='T1 MAP fit', color=bpal[3])
# plt.plot(tsteps,mr2.ml_curve(tsteps), ls='-.',lw=5, label='T2 ML fit',c=rpal[3])
# plt.plot(tsteps,mr2.map_curve(tsteps), ls='-', label='T2 MAP fit', c=rpal[3], zorder=1)
# plt.plot(tsteps,true_curve(tsteps), ls='--', lw=linewidth, label='True', c=kpal[-2])


# # plt.errorbar(obs_data.index, obs_data, color=data_color)
# Zoom in on the epochs around the observed points.
plt.xlim(-10,30)
plt.ylim(0,max(obs_data)+3*obs_sigma)
# plt.axhline(obs_sigma*detection_thresh, ls='--', label='Detection threshold', c=current_palette[-1])
plt.legend()
# # plt.gcf().suptitle('Best fits, 2 datapoints',size=25)

plt.xlabel('Time')
plt.ylabel('Flux')
# NOTE(review): the figure is resized after tight_layout, so the computed layout
# may not match the final 7x7 inch size.
plt.tight_layout()
plt.gcf().set_size_inches(7, 7)
plt.savefig('2dpt_t1_bestfits.pdf')

In [None]:
# Draw and save the Type 1 triangle plot with every font size scaled by `scaling`.
#
# The overrides are applied through rc_context, which restores the previous rcParams
# when the `with` block exits. Binding `plt.rcParams` to another name does not take a
# snapshot (both names refer to the same live object), so assigning it back afterwards
# would leave the scaled sizes in force for every later figure.
scaling=0.75
with plt.rc_context(rc={'font.size': math.floor(bigfontsize*scaling),
                        'axes.labelsize': math.floor(labelfontsize*scaling),
                        'xtick.labelsize': math.floor(tickfontsize*scaling),
                        'ytick.labelsize': math.floor(tickfontsize*scaling),
                        'legend.fontsize': math.floor(tickfontsize*scaling),
                        }):
    mr1.plot_triangle(max_n_ticks=2)
    # Saved inside the context so the file is rendered with the scaled sizes.
    plt.savefig('2dpts_t1_triangle.pdf')

In [None]:
def overplot_ensemble_forecasts(modelruns,t_forecast,n_subsamples, axes=None):
    """
    Overlay the Type 2 and Type 1 ensemble forecasts on one shared set of axes.

    The Type 2 run is drawn first (without data points), then the Type 1 run is
    drawn on the axes returned by the first call, together with the data and the
    true curve. The axes are then labelled.

    Reads the notebook-level globals ``tsteps``, ``obs_sigma``, ``pal1``, ``pal2``
    and ``true_curve``.

    Args:
        modelruns: Mapping with the keys 'type1' and 'type2'; each value must
            provide a ``plot_forecast`` method.
        t_forecast: Forwarded to ``plot_forecast``. Also indexed to title the
            histogram axes, so it must be a sequence whenever ``plot_forecast``
            returns histogram axes. May be None.
        n_subsamples: Forwarded to ``plot_forecast`` as ``subsample_size``.
        axes: Optional axes forwarded to the first ``plot_forecast`` call
            (passed elsewhere in this notebook as a ``(ts_ax, hist_axes)`` pair).
            None lets ``plot_forecast`` supply them.

    Returns:
        The ``(ts_ax, hist_axes)`` pair returned by the final ``plot_forecast`` call.
    """
    # Both runs come from the argument, never from a notebook-level dict, so the
    # function plots exactly what the caller handed in.
    mr1 = modelruns['type1']
    mr2 = modelruns['type2']

    axes=mr2.plot_forecast(tsteps,
                  t_forecast=t_forecast,
                  forecast_marker=False,
                  kde_noise_sigma=obs_sigma,
                  axes=axes,
                  palette=pal2,
                  plot_data=False,
                  subsample_size=n_subsamples
                 )
    # The second call reuses the axes returned by the first, so both ensembles overlap.
    axes=mr1.plot_forecast(tsteps,
                       t_forecast=t_forecast,
                       kde_noise_sigma=obs_sigma,
                       axes=axes,
                       plot_data=True,
                       palette=pal1,
                       subsample_size=n_subsamples,
                    true_curve=true_curve
                 )
    ts_ax, hist_axes = axes
    ts_ax.set_xlabel('Time')
    ts_ax.set_ylabel('Flux')
    ts_ax.set_title('Overview')

    if hist_axes:
        last_idx = len(hist_axes)-1
        for idx, ax in enumerate(hist_axes):
            ax.set_xlabel('Prob.')
            if idx != last_idx:
                # Inner histogram panels get no y tick labels.
                ax.set_yticklabels([])
            else:
                # The outermost panel carries its y ticks on the right-hand side.
                ax.yaxis.tick_right()
                ax.yaxis.set_label_position("right")
            ax.set_title('$t={}$'.format(t_forecast[idx]))
            ax.set_xticklabels([])
    return axes

In [None]:
# ??plt.tight_layout

In [None]:
# Settings for the ensemble forecast figure:
# n_subsamples is passed on as plot_forecast's subsample_size,
# t_forecast lists the epochs that get a forecast histogram panel.
n_subsamples = 150
t_forecast=[18,47]
# t_forecast=[18]

In [None]:
# Ensemble forecasts for the two-datapoint fits, saved as '2dpt_ensemble_w_forecasts.pdf'.
overplot_ensemble_forecasts(model_runs, t_forecast=t_forecast,n_subsamples=n_subsamples)
fig = plt.gcf()
# Resize first, then let tight_layout pack the panels with little horizontal padding.
fig.set_size_inches(12, 7)
plt.tight_layout(w_pad=0.05)
plt.savefig('2dpt_ensemble_w_forecasts.pdf')

In [None]:
# Simulate one extra noisy observation of the true curve at t=new_t and store it in
# obs_data under that index label.
# NOTE(review): the following cell rebuilds obs_data from usable_data, which appears to
# discard this point -- confirm which third datapoint is intended.
# NOTE(review): not idempotent -- every re-run draws a fresh value from rstate and
# overwrites the entry at new_t.
new_t=9
newdata =  true_curve(new_t)+rstate.normal(scale=obs_sigma)
obs_data.loc[new_t]=newdata
obs_data

In [None]:
# Refit both model types, this time on the first three usable datapoints.
# NOTE(review): obs_data is rebuilt from usable_data here, so any entry added to the
# previous obs_data by hand is not part of this fit -- confirm that this is intended.
n_data_epochs = 3
obs_data = usable_data.iloc[:n_data_epochs]

# Ensemble classes keyed by model name; each is instantiated per run below.
model_set = {'type1': Sn1aOpticalEnsemble,
             'type2': Sn2OpticalEnsemble}
model_runs = {}

for model_name, model_ensemble in model_set.items():
    print("Running MCMC for ", model_name)
    mr = ModelRun(
        ensemble=model_ensemble(),
        obs_data=obs_data,
        obs_sigma=obs_sigma,
        use_pt=False,
    )
    # fit_data is only called when there are observations to fit.
    if obs_data is not None:
        mr.fit_data()
    sampler = mr.get_sampler(threads=nthreads)
    mr.run(sampler, 500)
    # Keep the finished run so later cells can look it up by model name.
    model_runs[model_name] = mr

# Short handles for the two refitted runs.
mr1, mr2 = model_runs['type1'], model_runs['type2']

In [None]:
# Ensemble plot for the three-datapoint fits, with t_forecast=None; saved as '3dpt_ensemble.pdf'.
overplot_ensemble_forecasts(model_runs, t_forecast=None,n_subsamples=n_subsamples)
fig = plt.gcf()
fig.set_size_inches(10, 7)
# NOTE(review): plt.ylim acts on the current axes -- presumably the time-series panel; confirm.
plt.ylim(0,15)
plt.tight_layout()
plt.savefig('3dpt_ensemble.pdf')