In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
from datetime import date, datetime, timedelta, time
import pandas as pd
import seaborn
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import emcee
import triangle
from scipy.stats import multivariate_normal, norm
import pathos.multiprocessing as pathos_mp

In [None]:
from simlightcurve.solvers import find_rise_t, find_peak
import scipy.optimize as op
from simtransient.models.supernova import Sn1aOpticalEnsemble
from simtransient.modelrun import ModelRun
import simtransient.utils as stutils
from simtransient import measures, hammer
import simtransient.plot as stplot


In [None]:
# Plot styling (the 'talk' context was also tried).
seaborn.set_context('poster')
current_palette = seaborn.color_palette()

# Font sizes for the paper layout; the poster layout used 40 / 35 / 25.
bigfontsize, labelfontsize, tickfontsize = 20, 18, 18
plt.rcParams.update({
    'font.size': bigfontsize,
    'axes.labelsize': labelfontsize,
    'xtick.labelsize': tickfontsize,
    'ytick.labelsize': tickfontsize,
    'legend.fontsize': tickfontsize,
})

In [None]:
# Width of the Gaussian noise on each simulated observation, and the detection /
# analysis thresholds expressed in units of obs_sigma.
obs_sigma = 1
detection_thresh = 5.
analysis_thresh = 3.

# Seed NumPy's global RNG so the scipy.stats `.rvs()` draws below (true
# parameters and simulated noise) are reproducible on Restart & Run All.
# Samplers that keep their own RandomState may not be affected by this.
random_seed = 42
np.random.seed(random_seed)

In [None]:
# The ensemble used to generate the simulated ("true") event.
true_ensemble = Sn1aOpticalEnsemble()
# Multivariate normal hyperprior built from the ensemble's 'mu' row of
# gauss_pars and its gauss_cov covariance.
gpar_hypers_rv = multivariate_normal(
    mean=true_ensemble.gauss_pars.loc['mu'],
    cov=true_ensemble.gauss_cov,
)

In [None]:
# Display the ensemble's Gaussian-parameter table (its 'mu' row is the hyperprior mean).
true_ensemble.gauss_pars

In [None]:
# Draw one set of "true" Gaussian parameters from the hyperprior.
# Take an explicit copy of the 'mu' row. Wrapping `gauss_pars.T.mu` in pd.Series
# may share memory with `true_ensemble.gauss_pars` (pandas-version dependent),
# so assigning the draw into it in place could silently overwrite the
# ensemble's prior means and make re-runs of this cell drift.
true_gpars = true_ensemble.gauss_pars.loc['mu'].copy()
true_gpars[:] = gpar_hypers_rv.rvs()

In [None]:
# Fixed t0 for the simulated event (add e.g. np.random.random()*5 to jitter it).
true_t0 = 5
true_pars = true_gpars.copy()
true_pars.loc['t0'] = true_t0
true_pars

In [None]:
# Build the noise-free "true" model curve from the drawn parameters; it is evaluated below as true_curve(t).
true_curve = true_ensemble.get_curve(**true_pars)

In [None]:
tstep = 3.5
sim_epochs = np.arange(-30., 50, tstep)
# Noise-free model at each epoch plus iid Gaussian noise of width obs_sigma.
sim_data = pd.Series(
    true_curve(sim_epochs) + norm(scale=obs_sigma).rvs(size=sim_epochs.size),
    index=sim_epochs,
)

In [None]:
# Detection: the first epoch whose value exceeds the detection threshold starts
# follow-up; from that epoch on, keep every point above the (lower) analysis
# threshold.
threshold = obs_sigma * detection_thresh
analysis_threshold = obs_sigma * analysis_thresh

print("Thresh: {}".format(threshold))
detectable = sim_data[sim_data > threshold]
monitorable = sim_data[sim_data > analysis_threshold]
if detectable.empty:
    # Fail with an explanation rather than an opaque IndexError on index[0].
    raise ValueError(
        "No simulated epoch exceeds the detection threshold ({}); "
        "re-draw the true parameters or lower detection_thresh.".format(threshold))
first_detection_epoch = detectable.index[0]
# Label-based slice: all monitorable epochs from the first detection onward.
usable_data = monitorable.loc[first_detection_epoch:]

In [None]:
# Fit only the first few usable epochs after detection.
# (Alternative: thin the data first, e.g. usable_data.iloc[::2].iloc[:3].)
n_data_epochs = 3
obs_data = usable_data.head(n_data_epochs)
obs_data

In [None]:
# Switch to seaborn's 'darkgrid' style and preview the colour palette used by the plots below.
seaborn.set_style('darkgrid')
seaborn.palplot(current_palette)

In [None]:
# Overview: the true curve, all simulated epochs, the detection / analysis
# thresholds, the usable post-detection points and the subset passed to the fit.
tsteps = np.linspace(-30, 50, 1000)
plt.plot(tsteps, true_curve(tsteps), c='y', ls='--', label='True')

# Use `color=` (as for the fitted-data scatter below) rather than `c=`: a single
# RGB tuple passed as `c` can be colour-mapped when its length matches the data.
plt.scatter(sim_data.index, sim_data, color=current_palette[2], label='Simulated')

plt.axhline(threshold, ls='--', label='Detection')
plt.axhline(obs_sigma * analysis_thresh, ls='-.', label='Analysis')

stplot.curve.graded_errorbar(usable_data, obs_sigma, alpha=0.5)

# Data we'll be using for fitting:
plt.scatter(obs_data.index, obs_data,
            color=current_palette[-1], s=160, lw=3, marker='x',
            label='Fitted data')

plt.ylim(-1.05 * np.abs(np.min(usable_data)),
         1.1 * np.max(usable_data + 2 * obs_sigma))
plt.xlabel('Epoch')
plt.ylabel('Signal')
plt.legend(loc='best')
plt.gcf().suptitle('Model, observables, detections');


In [None]:
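# NOTE(review): disabled -- `zero_t0_curve` is not defined anywhere above, so this
# cell would fail if uncommented. Kept as a sketch of estimating t0 from the first
# detection via find_rise_t.
# detection_t0_offset = find_rise_t(zero_t0_curve, threshold=obs_data.iloc[0], 
#                                      t_min=-3*zero_t0_curve.rise_tau, t_max=1)
# t0_guess = obs_data.index[0]-detection_t0_offset
# print t0_guess, true_t0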

In [None]:
# A fresh ensemble instance for fitting, separate from `true_ensemble` used to simulate the data.
model1=Sn1aOpticalEnsemble()

In [None]:
# Fit model1 to the selected epochs. emcee_kwargs (presumably forwarded to emcee)
# requests a single thread; to use a process pool instead, create
# `pool = pathos_mp.Pool(4)` and pass `pool=pool` inside emcee_kwargs.
mr1 = ModelRun(
    ensemble=model1,
    obs_data=obs_data,
    obs_sigma=obs_sigma,
    use_pt=False,
    emcee_kwargs={'threads': 1},
)
mr1.fit_data()


In [None]:
seaborn.set_style('dark')

# The ML and MAP curves exposed by mr1 versus the true curve, over the fitted data.
data_color = current_palette[1]
fit_color = current_palette[0]

stplot.curve.graded_errorbar(obs_data, obs_sigma,
                             color=data_color, alpha=0.8, label='Data')
plt.plot(tsteps, mr1.ml_curve(tsteps),
         c=fit_color, ls='-.', lw=5, label='ML')
plt.plot(tsteps, mr1.map_curve(tsteps),
         c=fit_color, ls='-', zorder=1, label='MAP')
plt.plot(tsteps, true_curve(tsteps),
         c=current_palette[-2], ls='--', lw=5, label='True')

# Zoom on the data: from zero up to 3 sigma above the largest fitted point.
plt.ylim(0, max(obs_data) + 3 * obs_sigma)
plt.legend()

In [None]:
# Run the sampler (100 is presumably the number of steps -- confirm against ModelRun.run); the return value is discarded.
_ = mr1.run(100)

In [None]:
# Display the run's chain statistics (presumably a per-parameter summary -- confirm).
mr1.chainstats

In [None]:
# Diagnostic plot of the sampler walkers (presumably per-parameter traces).
mr1.plot_walkers()

In [None]:
# Histograms of the sampled parameters.
mr1.plot_hists()

In [None]:
# Triangle (corner) plot of the samples; presumably via the imported `triangle` package.
mr1.plot_triangle()

In [None]:
# Shape of the trimmed sample array (presumably burn-in removed -- confirm in ModelRun).
mr1.trimmed.shape

In [None]:
# Forecast from the sampled fit with the true curve overlaid for comparison
# (subsample_size presumably sets how many samples are drawn for the plot).
mr1.plot_forecast(tsteps, true_curve=true_curve, subsample_size=150)

# Optional debugging overlays (initial guess / MAP estimate):
# plt.plot(tsteps, mr1.ensemble.get_curve(**mr1.init_pars)(tsteps), c='r', lw=5)
# plt.plot(tsteps, mr1.map_curve(tsteps), c='y', lw=5)