In [None]:
%matplotlib inline
from datetime import date, datetime, timedelta, time
import warnings

import pandas as pd
import seaborn
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import gridspec
import numpy as np
import triangle

In [None]:
from simlightcurve.solvers import find_rise_t, find_peak
import scipy.optimize as op
import emcee
from simtransient.models.supernova import Sn1aOpticalEnsemble
from scipy.stats import multivariate_normal, norm
import simtransient.utils as simutils
from simtransient import measures, hammer

In [None]:
seaborn.set_context('poster')
current_palette = seaborn.color_palette()

# Font sizes per output target.
#   Poster preset (inactive): bigfontsize, labelfontsize, tickfontsize = 40, 35, 25
#   Paper preset (active):
bigfontsize, labelfontsize, tickfontsize = 20, 18, 18

plt.rcParams.update({
    'font.size': bigfontsize,
    'axes.labelsize': labelfontsize,
    'xtick.labelsize': tickfontsize,
    'ytick.labelsize': tickfontsize,
    'legend.fontsize': tickfontsize,
})

In [None]:
obs_sigma = 0.1          # std-dev of the Gaussian noise added to each simulated point
detection_thresh = 5.    # first-detection threshold, in units of obs_sigma
analysis_thresh = 3.     # threshold (in units of obs_sigma) for points kept after detection

# Seed numpy's global RNG so the random "true" parameters and the simulated noise
# (drawn through scipy.stats, which uses the global RNG by default) are reproducible
# under Restart & Run All. NOTE(review): the MCMC samplers may hold their own RNG
# instance; confirm whether simtransient.hammer seeds it.
rng_seed = 42
np.random.seed(rng_seed)

In [None]:
ensemble = Sn1aOpticalEnsemble()

# Joint Gaussian hyper-prior over the ensemble's Gaussian-distributed parameters;
# a single draw from it provides the simulation's "true" parameters below.
gpar_hypers_rv = multivariate_normal(
    mean=ensemble.gauss_pars.loc['mu'],
    cov=ensemble.gauss_cov,
)

In [None]:
# "True" Gaussian parameters for the simulation: one random draw from the hyper-prior,
# labelled with the ensemble's parameter names.
# copy=True so the in-place assignment below cannot write through to
# ensemble.gauss_pars (the transposed column can be a view, depending on the pandas
# version) -- the same precaution init_model_pars takes further down.
true_gpars = pd.Series(ensemble.gauss_pars.T.mu, copy=True)
true_gpars[:] = gpar_hypers_rv.rvs()

In [None]:
# Fixed "explosion" epoch for the simulation (randomising it, e.g. adding
# np.random.random()*5, is left disabled).
true_t0 = 5

# Full true parameter vector = Gaussian parameters plus t0 appended at the end.
true_pars = true_gpars.copy()
true_pars.loc['t0'] = true_t0
true_pars

In [None]:
# Parameters are passed positionally, in true_pars index order (Gaussian pars, then t0).
true_curve = ensemble.get_curve(*true_pars.values)

In [None]:
# Regularly sampled epochs with i.i.d. Gaussian noise of width obs_sigma added to
# the true light curve; stored as a Series indexed by epoch.
tstep = 1.5
sim_epochs = np.arange(-30., 50, tstep)
sim_data = pd.Series(
    data=true_curve(sim_epochs) + norm(scale=obs_sigma).rvs(size=sim_epochs.size),
    index=sim_epochs,
)

In [None]:
# Flux level a point must exceed to count as a detection.
threshold = obs_sigma*detection_thresh

print("Thresh: %s" % threshold)
detectable = sim_data[sim_data > threshold]
monitorable = sim_data[sim_data > obs_sigma*analysis_thresh]

# Fail with an explanation rather than an IndexError when the random draw
# produced a curve that never rises above the detection threshold.
if detectable.empty:
    raise ValueError("No simulated epoch exceeds the detection threshold (%s); "
                     "re-draw the parameters or lower detection_thresh." % threshold)

first_detection_epoch = detectable.index[0]
# Everything above the (lower) analysis threshold from the first detection onwards.
usable_data = monitorable[first_detection_epoch:]

In [None]:
# Detect: fit only a sparse subset of the usable data -- every `epoch_stride`-th
# usable epoch, keeping the first `n_data_epochs` of those.
n_data_epochs = 3
epoch_stride = 2
obs_data = usable_data.iloc[::epoch_stride].iloc[:n_data_epochs]
obs_data

In [None]:
seaborn.set_style('darkgrid')
# Show the active colour palette; later plots pick colours from it by index.
seaborn.palplot(pal=current_palette)

In [None]:
# Fine time grid for drawing model curves (reused by later cells).
tsteps = np.linspace(-30, 50, 1000)

fig, ax = plt.subplots()
ax.plot(tsteps, true_curve(tsteps), c='g', ls='--', label='True')

# `color=` (not `c=`) so the RGB tuple is read as one colour, not as per-point values.
ax.scatter(sim_data.index, sim_data, color=current_palette[2], label='Simulated')

ax.axhline(obs_sigma*detection_thresh, ls='--', label='Detection')
ax.axhline(obs_sigma*analysis_thresh, ls='-.', label='Analysis')

ax.scatter(usable_data.index, usable_data, color=current_palette[1], s=55,
           label='Usable')

# Data we'll be using for fitting:
ax.scatter(obs_data.index, obs_data,
           color=current_palette[-1], s=160, lw=3, marker='x',
           label='Fitted data')

ax.set_ylim(-1.05*np.abs(np.min(usable_data)), 1.05*np.max(usable_data))
ax.set_xlabel('Epoch')
ax.set_ylabel('Flux')
ax.legend(loc='best')
fig.suptitle('Model, observables, detections')
plt.show()


In [None]:
# detection_t0_offset = find_rise_t(zero_t0_curve, threshold=obs_data.iloc[0], 
#                                      t_min=-3*zero_t0_curve.rise_tau, t_max=1)
# t0_guess = obs_data.index[0]-detection_t0_offset
# print t0_guess, true_t0

In [None]:
# Starting point for the ML fit: the means ('mu') of the Gaussian priors, with t0
# crudely guessed as the epoch of the first fitted data point.
init_model_pars = pd.Series(ensemble.gauss_pars.T.mu, name="Initial", copy=True)
init_model_pars.loc['t0'] = obs_data.index[0]
init_model_pars

In [None]:
# op.minimize minimises, so negate the Gaussian log-likelihood.
neg_likelihood = lambda *args: -measures.gauss_lnlikelihood(*args)
ml_results = op.minimize(neg_likelihood, init_model_pars.copy(),
                         args=(ensemble, obs_data, obs_sigma))
# Surface optimiser failures instead of silently using a non-converged point.
if not ml_results.success:
    warnings.warn("Max-likelihood fit did not converge: %s" % ml_results.message)

# Re-attach parameter names to the optimiser's raw array.
max_likelihood_pars = pd.Series(init_model_pars, name='MaxLikelihood', copy=True)
max_likelihood_pars[:] = ml_results.x
ml_curve = ensemble.get_curve(**max_likelihood_pars)
max_likelihood_pars

In [None]:
# Time window for a (currently disabled) bounded uniform prior on t0: the span of
# the fitted epochs, padded on each side by ten times the ML decay_tau.
# The bounds come from the epochs (obs_data.index), not the flux values: they are
# combined with a timescale and used to bound t0, which is a time.
obs_data_min = obs_data.index.min()
obs_data_max = obs_data.index.max()

max_timespan = max_likelihood_pars.decay_tau*10

# bounded_t0_lnprior = measures.get_uniform_lnprior(obs_data_min-max_timespan, obs_data_max+max_timespan)

def lnprob(model_pars, model_ensemble, obs_data, obs_sigma):
    """Unnormalised log-posterior: ensemble log-prior plus Gaussian log-likelihood.

    Parameters
    ----------
    model_pars : sequence of float
        Parameter vector, passed unchanged to the prior and likelihood.
    model_ensemble : object
        Must provide ``lnprior(model_pars)``; it is also forwarded to the likelihood.
    obs_data, obs_sigma :
        Observations and their noise level, forwarded to
        ``measures.gauss_lnlikelihood``.

    Returns
    -------
    float
        ``-inf`` whenever the log-prior is not finite (inf or NaN); the likelihood
        is then not evaluated at all.
    """
    log_prior = model_ensemble.lnprior(model_pars)
    if np.isfinite(log_prior):
        return log_prior + measures.gauss_lnlikelihood(
            model_pars, model_ensemble, obs_data, obs_sigma)
    return -np.inf

In [None]:
# op.minimize minimises, so negate the log-posterior to obtain the MAP estimate.
neg_lnprob = lambda *args: -lnprob(*args)
map_results = op.minimize(neg_lnprob, init_model_pars,
                          args=(ensemble, obs_data, obs_sigma))
# Surface optimiser failures instead of silently using a non-converged point.
if not map_results.success:
    warnings.warn("MAP fit did not converge: %s" % map_results.message)

# Re-attach parameter names; rename so the Series is not mislabelled "Initial".
map_pars = init_model_pars.copy()
map_pars.name = 'MAP'
map_pars[:] = map_results.x
map_curve = ensemble.get_curve(**map_pars)
map_pars

In [None]:
fig, ax = plt.subplots()
ax.plot(tsteps, true_curve(tsteps), ls='--', label='true')
ax.plot(tsteps, ml_curve(tsteps), ls=':', label='ML')
ax.plot(tsteps, map_curve(tsteps), ls='-', label='MAP')

ax.scatter(obs_data.index, obs_data, s=50)
# 1-sigma (thick) and 2-sigma (thin) error bars on the fitted points; linewidth=0
# suppresses the line errorbar would otherwise draw between the points.
ax.errorbar(obs_data.index, obs_data, yerr=1*obs_sigma, c=current_palette[2],
            linewidth=0, elinewidth=3, ms=16, marker='.')
ax.errorbar(obs_data.index, obs_data, yerr=2*obs_sigma, c=current_palette[2],
            linewidth=0, elinewidth=1.5)

ax.set_ylim(0, 1.6)
ax.set_xlabel('Epoch')
ax.set_ylabel('Flux')
ax.set_title('Point estimates vs. true model')
ax.legend();

In [None]:
ndim = len(map_pars)  # number of parameters in the model
nwalkers = 50         # base walker count; the samplers below use en_nwalkers = 3x this
nsteps = 100          # MCMC steps per walker
nthreads = 4
ntemps = 10           # intended for the PT sampler. NOTE(review): not passed to
                      # prep_pt_sampler below -- confirm whether it should be.

# Chain post-processing parameters, read by the thinning-histogram and
# sample-extraction cells; they must exist before those cells run.
nburn = nsteps // 2   # leading steps per walker discarded as burn-in
acorr = 5             # thinning interval; a rough stand-in for the chain's
                      # autocorrelation length -- TODO estimate it from the chain

prop_scale = 3        # forwarded to both samplers as `a`
en_nwalkers = nwalkers*3

# Both samplers are initialised from the MAP estimate.
theta_init = map_pars.values

en_sampler, en_init_ball = hammer.prep_ensemble_sampler(
    lnprob=lnprob,
    init_params=theta_init,
    args=(ensemble, obs_data, obs_sigma),
    nwalkers=en_nwalkers,
    nthreads=nthreads,
    a=prop_scale,
)

pt_sampler, pt_init_ball = hammer.prep_pt_sampler(
    lnprior=ensemble.lnprior,
    init_params=theta_init,
    lnlikelihood=measures.gauss_lnlikelihood,
    lnlikeargs=(ensemble, obs_data, obs_sigma),
    nwalkers=en_nwalkers,
    nthreads=nthreads,
    a=prop_scale,
)

In [None]:
# Start from an empty chain (reset), then advance the walkers from the initial ball.
en_sampler.reset()
# numpy floating-point warnings raised during sampling are silenced.
with np.errstate(all='ignore'):
    _ = en_sampler.run_mcmc(en_init_ball, nsteps)
print("Done")

In [None]:
# %%capture
# pt_sampler.reset()
# # with np.errstate(over='ignore', under='ignore'):
# with np.errstate(all='ignore'):
#     _=pt_sampler.run_mcmc(pt_init_ball, nsteps)
# print "Done"

In [None]:
# Generic names for the chain under inspection (currently the ensemble sampler).
sampler = en_sampler
acceptance = np.median(sampler.acceptance_fraction)
plotchain = sampler.chain

In [None]:
# Re-display the MAP estimate for comparison with the MCMC chains that follow.
map_pars

In [None]:
# Trim the ensemble chain with the project helper and show what it returned:
# `cs` (presumably the cut/burn-in step -- TODO confirm in
# simtransient.hammer.trim_chain) and the shape of the trimmed chain.
cs, trimmed = hammer.trim_chain(sampler, pt=False)
cs, trimmed.shape

In [None]:
# Plot the chain of every parameter, labelled with the model parameter names;
# `cs` from hammer.trim_chain is forwarded to the plotting helper.
hammer.plot_all_param_chain(sampler.chain, cs, map_pars.index)
plt.tight_layout()

In [None]:
# Compare the marginal posterior of one parameter with and without thinning.
varindex = 3
varname = map_pars.index[varindex]
plt.xlabel(varname)
# normed=True normalises each histogram to unit area, so the y axis is a probability
# density, not a count. (`normed` is the pre-2.1 matplotlib spelling of `density`.)
plt.ylabel("Density")
plt.hist(plotchain[:, nburn:, varindex].ravel(), normed=True, alpha=0.8,
         label='All post-burn-in steps')
plt.hist(plotchain[:, nburn::acorr, varindex].ravel(), normed=True, alpha=0.5,
         label='Thinned (every %d steps)' % acorr)
plt.legend()
plt.gcf().suptitle('Effects of sample thinning');

In [None]:
# Chain dimensions -- presumably (walkers, steps, parameters), matching the
# [:, step-slice, parameter] indexing used in the neighbouring cells.
plotchain.shape

In [None]:
# Pool the thinned, post-burn-in steps of all walkers into one (n_samples, ndim)
# array (rows in walker-major C order).
samples = plotchain[:, nburn::acorr, :].reshape(-1, ndim)
samples.shape

In [None]:
# Plain white background for the corner plot. The `triangle` package is imported
# with the other dependencies at the top of the notebook.
seaborn.set_style('white')

In [None]:
# Corner plot of the posterior samples, with the simulation's true parameters marked.
# Sample columns and true_pars share the same parameter order (Gaussian pars, then t0).
figure = triangle.corner(
    samples,
    truths=true_pars.values,
    labels=true_pars.index,
)

In [None]:
n_plot_curves = 200  # number of posterior draws to overplot

print(len(samples))
# Pick distinct rows (no replacement) so no light-curve trace is drawn twice;
# cap at the number of available samples.
plot_idx = np.random.choice(len(samples), size=min(n_plot_curves, len(samples)),
                            replace=False)
subsamples = samples[plot_idx]

# One row per drawn parameter vector: its light curve evaluated on tsteps.
sample_curves = np.asarray([ensemble.get_curve(*theta)(tsteps) for theta in subsamples])
sample_curves.shape

In [None]:
#Forecast:
# Forecast: evaluate every posterior sample's light curve at one future epoch.
t_forecast = 40
forecast_data = []
for theta in samples:
    forecast_data.append(ensemble.get_curve(*theta)(t_forecast))

In [None]:
seaborn.set_style('darkgrid')
fig = plt.figure()
# Wide light-curve panel on the left, narrow forecast histogram on the right.
gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])

ts_axes = plt.subplot(gs[0])
# Posterior light-curve draws, one trace per row of sample_curves.
seaborn.tsplot(sample_curves, tsteps, err_style="unit_traces", ax=ts_axes, ls='')
ts_axes.plot(tsteps, true_curve(tsteps), ls='--', c='y', label='true', lw=5)
ts_axes.scatter(obs_data.index, obs_data, c='r', s=200, zorder=10,
                label='fitted data')

forecast_mean = np.mean(forecast_data)
ts_axes.axvline(t_forecast, ls=':')
ts_axes.axhline(forecast_mean, ls=':', label='forecast mean')
# The legend belongs on the light-curve panel, which holds the labelled artists;
# a bare plt.legend() here would target the (unlabelled) histogram axes.
ts_axes.legend()

hist_axes = plt.subplot(gs[1])
hist_axes.axhline(forecast_mean, ls=':')
hist_axes.axhline(true_curve(t_forecast), ls='--', c='y')
hist_axes.hist(forecast_data, orientation='horizontal')
# Align the histogram's flux axis with the light-curve panel.
hist_axes.set_ylim(ts_axes.get_ylim());