In [None]:
%matplotlib inline
from datetime import date, datetime, timedelta, time
import pandas as pd
import seaborn
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

In [None]:
from simlightcurve.curves import GaussExp as Gred
from simlightcurve.solvers import find_rise_t, find_peak
import scipy.optimize as op
import emcee
from simtransient.models.supernova import Sn1aOpticalEnsemble
from scipy.stats import multivariate_normal, norm

In [None]:
seaborn.set_context('poster')
current_palette = seaborn.color_palette()

In [None]:
obs_sigma=0.05
detection_thresh = 5.
analysis_thresh = 3.

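Log-likelihood of $N$ observations $x_i$ under a Gaussian of mean $\alpha$ and width $\sigma$ (each summand is what `gaussian_logpdf` below computes):

$\ln \mathcal{L} = -\frac{1}{2} \sum_{i=1}^N \left[ \ln(2\pi\sigma^2) + (x_i - \alpha)^2 / \sigma^2 \right]$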

In [None]:
def gaussian_logpdf(x, mean, sigma):
    """Natural log of the normal density N(mean, sigma**2) evaluated at ``x``.

    Element-wise for numpy arrays (ordinary broadcasting rules apply).
    """
    z = (x - mean) / sigma
    log_norm = np.log(2 * np.pi * sigma ** 2)
    return -0.5 * (log_norm + z ** 2)

# amplitude_tail_cut = 1e-5
# amp_minmax = (amplitude.ppf(amplitude_tail_cut), amplitude.ppf(1-amplitude_tail_cut))

# def amp_prior(x):    
#     if amp_minmax[0]<=x<=amp_minmax[1]:
#         return gaussian_logpdf(x, amp_mean, amp_sigma)
#     else:
#         return -np.inf

# np_rvs = norm(loc=amp_mean, scale=amp_sigma)
# print gaussian_logpdf(0.6, amp_mean,amp_sigma)
# print np_rvs.logpdf(0.6)


In [None]:
# Light-curve model ensemble. Its generative parameters ("gpars") get a
# multivariate-normal hyperprior with mean = the 'mu' row of ensemble.gpars
# and covariance = ensemble.gpar_cov.
ensemble = Sn1aOpticalEnsemble()
gpar_hypers_rv = multivariate_normal(
    mean=ensemble.gpars.loc['mu'],
    cov=ensemble.gpar_cov,
)

In [None]:
true_gpars = pd.Series(ensemble.gpars.loc['mu'])
true_gpars[:]= gpar_hypers_rv.rvs()
true_gpars
true_pars = true_gpars.copy()
true_pars

In [None]:
true_t0 = 5 #+np.random.random()*5
true_pars['t0']=true_t0
true_pars

In [None]:
true_curve = ensemble.get_curve(*true_gpars, t0=true_t0)

maxprior_curve = ensemble.get_curve(*ensemble.gpars.T.mu, t0=true_t0)

In [None]:
true_curve

In [None]:
maxprior_curve

In [None]:
# Regularly spaced epochs covering [-30, 50) with spacing tstep.
tstep = 1.5
sim_epochs = np.arange(start=-30., stop=50, step=tstep)
# Simulated photometry: true flux plus Gaussian noise of width obs_sigma,
# indexed by epoch.
sim_data = pd.Series(
    data=true_curve(sim_epochs) + norm(loc=0, scale=obs_sigma).rvs(size=len(sim_epochs)),
    index=sim_epochs,
)
# sim_data

In [None]:
threshold = obs_sigma*detection_thresh

# detection_time = find_rise_t(true_curve, threshold=threshold, 
#                             t_min=true_curve.t0-2*true_curve.rise_tau, t_max=find_peak(true_curve,true_curve.t0)[0]) 
print "Thresh:", threshold
detectable = sim_data[sim_data>obs_sigma*detection_thresh]
monitorable = sim_data[sim_data>obs_sigma*analysis_thresh]
first_detection_epoch = detectable.index[0]
usable_data = monitorable[first_detection_epoch:]
# usable_data

In [None]:
detectable.index[0]

In [None]:
seaborn.palplot(current_palette)

In [None]:
tsteps= np.linspace(-30,50, 1000)
plt.plot(tsteps, true_curve(tsteps),c='g', ls='--',label='True')
# plt.plot(tsteps, maxprior_curve(tsteps),c='b', ls=':', label='Priori')
# plt.plot(tsteps, true_curve(tsteps)+sigma, ls=':',c='g')
# plt.plot(tsteps, true_curve(tsteps)-sigma, ls=':',c='g')

plt.scatter(sim_data.index, sim_data,c=current_palette[2])

plt.axhline(obs_sigma*detection_thresh, ls='--', label='Detection')
plt.axhline(obs_sigma*analysis_thresh, ls='-.', label='Analysis')

plt.scatter(usable_data.index, usable_data,c=current_palette[1],s=55)

# plt.yscale('log')
# plt.axhline(true_amp, ls=':')

# plt.axvline(true_t0, ls='--')
plt.ylim(-0.2,.8)
plt.legend()


In [None]:
# support = np.linspace(amplitude.ppf(0.01), amplitude.ppf(.99), 100)
# plt.plot(support, amplitude.pdf(support))
# # plt.plot(support, np.log(amplitude.pdf(support)))
# plt.plot(support, amp_prior(support))

In [None]:
#Detect:
n_data_epochs = 3
obs_data = usable_data.iloc[:n_data_epochs]
obs_data

In [None]:
# comparison = pd.DataFrame(index=usable_data.index, data={'noisy':usable_data, 'true':true_curve(usable_data.index)})
# comparison['err'] = comparison.true - comparison.noisy
# comparison.plot()

In [None]:
zero_t0_curve = ensemble.get_curve(*ensemble.gpars.T.mu, t0=0)

In [None]:
detection_t0_offset = find_rise_t(zero_t0_curve, threshold=obs_data.iloc[0], 
                                     t_min=-2*zero_t0_curve.rise_tau, t_max=0)
t0_guess = obs_data.index[0]-detection_t0_offset
print t0_guess, true_t0

In [None]:
maxprior_curve = ensemble.get_curve(*ensemble.gpars.T.mu,t0=t0_guess)

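$$
 \ln \mathcal{L}_\mathrm{gauss}(\theta) = -\frac{1}{2} \sum_{i=1}^N \left[ \ln(2\pi\sigma^2) + \frac{\left(y_i - f(t_i;\theta)\right)^2}{\sigma^2} \right]
$$

Here $y_i$ is the observed flux at epoch $t_i$, $f(t_i;\theta)$ is the model flux for parameters $\theta$, and $\sigma$ is `obs_sigma`. This is what `gauss_lnlikelihood` below computes.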

In [None]:
def gauss_lnlikelihood(model_pars, model_ensemble, obs_data, obs_sigma):
    """
    Log-likelihood of ``obs_data`` under unbiased Gaussian noise of width ``obs_sigma``.

    The model fluxes come from ``model_ensemble.evaluate(obs_data.index, *model_pars)``;
    ``obs_data`` must therefore expose ``.index`` (the observation epochs), e.g. a
    pandas Series. Returns the summed log-density over all observations.
    """
    model_flux = model_ensemble.evaluate(obs_data.index, *model_pars)
    scaled_resid = (obs_data - model_flux) / obs_sigma
    log_norm = np.log(2 * np.pi * obs_sigma ** 2)
    return -0.5 * np.sum(log_norm + scaled_resid ** 2)

In [None]:
init_model_pars = pd.Series(ensemble.gpars.T.mu)
init_model_pars.name="SN1a Params"
init_model_pars['t0']=t0_guess
init_model_pars

In [None]:
gauss_lnlikelihood(init_model_pars, ensemble,obs_data,obs_sigma)

In [None]:
neg_likelihood = lambda *args: -gauss_lnlikelihood(*args)
results = op.minimize(neg_likelihood, init_model_pars, args=(ensemble,obs_data, obs_sigma))
results.x
# print t0_guess, results.x, t0_guess - results.x

In [None]:
max_likelihood_pars = init_model_pars.copy()
max_likelihood_pars[:]=results.x
max_likelihood_pars

In [None]:
ml_curve=ensemble.get_curve(**max_likelihood_pars)

In [None]:
def t0_prior(t0, tmin=-60., tmax=40.):
    """Log of a uniform (top-hat) prior on ``t0`` over the half-open interval [tmin, tmax).

    Parameters
    ----------
    t0 : float
        Candidate t0 value.
    tmin, tmax : float, optional
        Support of the prior. The defaults reproduce the original hard-coded bounds.

    Returns
    -------
    float
        ``-ln(tmax - tmin)`` for ``tmin <= t0 < tmax``, otherwise ``-inf``.

    Raises
    ------
    ValueError
        If ``tmax <= tmin`` (empty support).
    """
    if tmax <= tmin:
        raise ValueError("t0_prior needs tmax > tmin, got [%s, %s)" % (tmin, tmax))
    if tmin <= t0 < tmax:
        # -ln(width) instead of ln(1/width): with integer bounds under Python 2,
        # 1/width would floor to 0 and give ln(0) = -inf inside the support.
        return -np.log(float(tmax - tmin))
    return -np.inf

def sn1a_lnprior(model_pars):
    """Log-prior for the parameter vector [generative params..., t0].

    All but the last entry are scored by the module-level ``ensemble.gpar_lnprior``.
    The last entry (t0) is scored by ``t0_prior``.

    ``model_pars`` may be a labelled pandas Series (as in the notebook cells) or a
    plain array (as passed by scipy.optimize / emcee). It is converted to an ndarray
    so that ``[-1]`` is always positional; on a string-indexed Series, ``[-1]``
    relies on pandas' deprecated positional fallback.
    """
    pars = np.asarray(model_pars)
    return ensemble.gpar_lnprior(pars[:-1]) + t0_prior(pars[-1])

def sn1a_lnprob(model_pars, model_ensemble, obs_data, obs_sigma):
    """Unnormalised log-posterior: ``sn1a_lnprior`` plus ``gauss_lnlikelihood``.

    The likelihood is only evaluated when the prior is finite. Outside the prior
    support this returns ``-inf`` directly.
    """
    log_prior = sn1a_lnprior(model_pars)
    if np.isfinite(log_prior):
        return log_prior + gauss_lnlikelihood(model_pars, model_ensemble, obs_data, obs_sigma)
    return -np.inf

In [None]:
sn1a_lnprior(max_likelihood_pars), sn1a_lnprior(init_model_pars), sn1a_lnprior(true_pars)

In [None]:
sn1a_lnprob(true_pars, ensemble, obs_data, obs_sigma)

In [None]:
neg_lnprob = lambda *args: -sn1a_lnprob(*args)
results = op.minimize(neg_lnprob, init_model_pars, args=(ensemble,obs_data, obs_sigma))
# results.x
map_pars = init_model_pars.copy()
map_pars[:]=results.x
map_pars


In [None]:
map_curve = ensemble.get_curve(**map_pars)

In [None]:
plt.plot(tsteps,true_curve(tsteps), ls='--', label='true')

plt.plot(tsteps,maxprior_curve(tsteps), ls='--', label='Prior + rise fit')

# plt.plot(tsteps,ml_curve(tsteps), ls='--', label='ML')
plt.plot(tsteps,map_curve(tsteps), ls='--', label='MAP')
plt.scatter(obs_data.index, obs_data, s=50)
plt.errorbar(obs_data.index, obs_data, yerr=1*obs_sigma, c=current_palette[2], linewidth=0,elinewidth=3, ms=16, marker='.')
plt.errorbar(obs_data.index, obs_data, yerr=2*obs_sigma, c=current_palette[2], linewidth=0,elinewidth=1.5)
plt.ylim(0,1.6)
plt.legend()

In [None]:
ndim = len(map_pars)  # number of parameters in the model
nwalkers = 5 * 50  # number of MCMC walkers (used by both samplers below)
nsteps = 500  # number of MCMC steps to take
nthreads = 4
# Number of temperatures; only used by the parallel-tempering (PT) sampler.
ntemps = 20

# Walkers start in a tight Gaussian ball (width 1e-4) around the MAP estimate.
theta_init = map_pars.values

# Stretch-move scale parameter `a`, passed to both samplers.
prop_scale = 4

en_sampler = emcee.EnsembleSampler(nwalkers, ndim, sn1a_lnprob,
                                   a=prop_scale,
                                   args=(ensemble, obs_data, obs_sigma),
                                   threads=nthreads)
en_theta_init_ball = theta_init + 1e-4 * np.random.randn(nwalkers, ndim)

# PT needs an extra leading ntemps dimension. The name matches the
# commented-out sampler switch at the end of this cell.
pt_theta_init_ball = theta_init + 1e-4 * np.random.randn(ntemps, nwalkers, ndim)
# NOTE(review): emcee >= 3 removed PTSampler and the `threads` argument;
# this cell assumes emcee 2.x.
pt_sampler = emcee.PTSampler(ntemps,
                             nwalkers, ndim,
                             logl=gauss_lnlikelihood,
                             logp=sn1a_lnprior,
                             a=prop_scale,
                             loglargs=(ensemble, obs_data, obs_sigma),
                             logpargs=(),
                             threads=nthreads)

# Select which sampler to run; swap in the commented pair for parallel tempering.
sampler = en_sampler
theta_init_ball = en_theta_init_ball

# sampler = pt_sampler
# theta_init_ball = pt_theta_init_ball

In [None]:
# theta_init_ball

In [None]:
# reset() clears any chain stored by a previous run before sampling again.
sampler.reset()
_ = sampler.run_mcmc(theta_init_ball, nsteps)
print("Done")

In [None]:

# Thin by the slowest parameter's autocorrelation time. It is rounded up to a whole
# number of steps and cast to int, because numpy rejects float slice indices.
acorr = int(np.ceil(np.max(sampler.get_autocorr_time())))
print("Acorr array: %s" % (sampler.acor,))
print("Acorr: %s" % (acorr,))
# Discard roughly 2.5 autocorrelation times as burn-in.
nburn = int(np.ceil(acorr * 2.5))
print("nburn: %s" % (nburn,))
#Ensemble
# Drop burn-in, thin, then flatten walkers and steps into (n_samples, ndim).
# (Previously the reshape result was discarded.)
samples = sampler.chain[:, nburn::acorr, :].reshape(-1, ndim)
if samples.shape[0] == 0:
    raise ValueError("Burn-in of %d steps leaves no samples from a %d-step chain; "
                     "increase nsteps." % (nburn, sampler.chain.shape[1]))
plotchain = sampler.chain
acceptance = np.median(sampler.acceptance_fraction)

#PT (use instead when the PT sampler is selected)
# plotchain = sampler.chain[0]
# samples = sampler.chain[0][:, nburn::acorr, :].reshape(-1, ndim)
# acceptance = np.median(sampler.acceptance_fraction[0])

print("Acceptance: %s" % (acceptance,))


In [None]:
map_pars

In [None]:
varindex = 3
plt.subplot(2,1,1)
for walker in plotchain[:,:,varindex]:
    plt.plot(walker)
plt.axvline(nburn, ls=':', color='k')
print "Acorr",sampler.get_autocorr_time()
print "Acceptance",acceptance
plt.subplot(2,1,2)
for walker in plotchain[:,::acorr,varindex]:
    plt.plot(walker)

In [None]:
varindex=3
plt.hist(plotchain[:,nburn::1,varindex].ravel(),normed=True,alpha=0.8)
plt.hist(plotchain[:,nburn::acorr,varindex].ravel(),normed=True,alpha=0.5)

In [None]:
samples = samples.ravel().reshape(-1,ndim)
samples.shape

In [None]:
plt.hexbin(samples[:,0], samples[:,1])

In [None]:
import triangle

In [None]:
# Corner plot of the posterior with the simulated truth marked. The columns of
# samples follow map_pars' order (generative params, then t0), which matches true_pars.
# NOTE(review): the `triangle` package was later renamed `corner`.
figure = triangle.corner(samples,
                         truths=true_pars.values,
                         labels=true_pars.keys())

In [None]:
print(len(samples))
# 100 posterior draws, chosen uniformly with replacement.
subsamples = samples[np.random.randint(len(samples), size=100)]

# Light curve of each drawn parameter vector, evaluated on the plotting grid.
forecasts = np.asarray([ensemble.get_curve(*theta)(tsteps) for theta in subsamples])
forecasts.shape

In [None]:
# np.max(subsamples_t0[:,1])
# subsamples

In [None]:
#Forecast:
# t_forecast= 15
# forecast_data=[zero_t0_class2(t_forecast - t0) for t0 in samples]

In [None]:
from matplotlib import gridspec

fig = plt.figure()
# Wide left panel for the forecast curves. The narrow right column is reserved
# for a histogram of forecast fluxes (currently disabled, see below).
gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])

ts_axes = plt.subplot(gs[0])

# One trace per posterior draw in `forecasts`.
# NOTE(review): seaborn.tsplot was deprecated and later removed (seaborn 0.10).
seaborn.tsplot(forecasts, tsteps, err_style="unit_traces", ax=ts_axes)
# Numeric linewidth (was the string '5').
ts_axes.plot(tsteps, true_curve(tsteps), ls='--', c='y', label='true', lw=5)
ts_axes.scatter(obs_data.index, obs_data, c='r', s=200, zorder=10, label='observed')
ts_axes.set_xlabel('Epoch')
ts_axes.set_ylabel('Flux')

# hist_axes = plt.subplot(gs[1])
# hist_axes.hist(forecast_data, orientation='horizontal')
# _ = hist_axes.set_ylim(ts_axes.get_ylim())
ts_axes.legend()