In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn
from scipy.stats import multivariate_normal, norm
from statsmodels.nonparametric.kde import KDEUnivariate

In [None]:
from simtransient.models.supernova import Sn1aOpticalEnsemble, Sn2OpticalEnsemble
from simtransient.modelrun import ModelRun
import simtransient.utils as stutils
from simtransient import measures, hammer
import simtransient.plot as stplot


In [None]:
# Global figure styling. seaborn's 'poster' context is used ('talk' is the
# smaller alternative), then explicit font sizes are layered on top.
seaborn.set_context('poster')
current_palette = seaborn.color_palette()

# Font sizes for the paper figures.
# Poster alternative: bigfontsize=40, labelfontsize=35, tickfontsize=25.
bigfontsize, labelfontsize, tickfontsize = 20, 18, 18
plt.rcParams.update({
    'font.size': bigfontsize,
    'axes.labelsize': labelfontsize,
    'xtick.labelsize': tickfontsize,
    'ytick.labelsize': tickfontsize,
    'legend.fontsize': tickfontsize,
})

In [None]:
# Observation model. Both thresholds are multiples of obs_sigma:
#   detection_thresh - a point must exceed this to trigger the first detection
#   analysis_thresh  - after detection, points above this are kept for analysis
obs_sigma, detection_thresh, analysis_thresh = 1, 4.0, 2.5

In [None]:
# The ensemble that generates the "true" light curve, plus a multivariate
# normal over its Gaussian parameters (means from the 'mu' row of gauss_pars,
# covariance from gauss_cov).
true_ensemble = Sn1aOpticalEnsemble()
gpar_hypers_rv = multivariate_normal(
    mean=true_ensemble.gauss_pars.loc['mu'],
    cov=true_ensemble.gauss_cov,
)

In [None]:
# Display the ensemble's Gaussian-parameter table; its 'mu' row supplies the
# prior means used for gpar_hypers_rv.
true_ensemble.gauss_pars

In [None]:
# Build the "true" parameter vector:
#   1. draw the Gaussian parameters from the ensemble prior;
#   2. fix t0;
#   3. pin a / rise_tau / decay_tau to fixed values.
# An explicit .copy() keeps this vector independent of
# true_ensemble.gauss_pars. Wrapping a (transposed) row in pd.Series may share
# memory, depending on the pandas version, and the [:] write below would then
# overwrite the ensemble's stored means.
true_gpars = true_ensemble.gauss_pars.loc['mu'].copy()
# NOTE(review): this draw uses the unseeded global numpy RNG, so any parameter
# not overridden below differs between kernel runs.
true_gpars[:] = gpar_hypers_rv.rvs()
true_t0 = 5  # +np.random.random()*5
true_pars = true_gpars.copy()
true_pars['t0'] = true_t0

# Item assignment always sets the labelled parameter. Attribute assignment on
# a label that is missing would instead create a plain Python attribute,
# leaving the parameter unchanged.
true_pars['a'] = 15.053480
true_pars['rise_tau'] = 2.803343
true_pars['decay_tau'] = 12.711032
true_pars

In [None]:
# Noiseless "true" light curve built from the chosen parameters; it is called
# with an array of times further down.
true_curve = true_ensemble.get_curve(**true_pars)

In [None]:
# Regular observing cadence: one epoch every `tstep` time units, starting at
# -30 and stopping before 50.
tstep = 3.5
sim_epochs = np.arange(-30., 50, tstep)

In [None]:
# Fixed-seed generator so the simulated noise realisation is reproducible:
# one Gaussian offset (width obs_sigma) per epoch.
rstate = np.random.RandomState(1)
noise_offsets = rstate.normal(scale=obs_sigma, size=sim_epochs.size)

In [None]:
# Simulated observations: true flux plus noise at each epoch, indexed by epoch.
sim_data = pd.Series(true_curve(sim_epochs) + noise_offsets, index=sim_epochs)

In [None]:
# Detection logic:
#   - the first epoch whose flux exceeds the detection threshold starts the
#     usable light curve;
#   - from then on, every point above the lower analysis threshold is kept.
threshold = obs_sigma * detection_thresh

print("Thresh: {}".format(threshold))
detectable = sim_data[sim_data > threshold]
monitorable = sim_data[sim_data > obs_sigma * analysis_thresh]
if detectable.empty:
    raise ValueError(
        "No simulated point exceeds the detection threshold ({}); "
        "nothing to fit.".format(threshold))
first_detection_epoch = detectable.index[0]
# Label-based slice from the first detection onwards. With detection_thresh
# above analysis_thresh, every detectable point is also monitorable.
usable_data = monitorable[first_detection_epoch:]

In [None]:
# Detection step: treat only the first `n_data_epochs` usable points as the
# observed data (alternative tried: every other point, first three of those).
n_data_epochs = 1
obs_data = usable_data.head(n_data_epochs)
obs_data

In [None]:
# Switch to seaborn's darker palette for the data/fit figures, then preview
# its swatches under the 'darkgrid' style.
current_palette = seaborn.color_palette('dark')
seaborn.set_style('darkgrid')
seaborn.palplot(current_palette)


In [None]:
# Overview figure, which shows:
#   - the true curve;
#   - every simulated observation;
#   - both thresholds;
#   - the usable (post-detection) points with graded error bars;
#   - the points actually used for the first fit.
tsteps = np.linspace(-30, 50, 1000)
plt.plot(tsteps, true_curve(tsteps), c='y', ls='--', label='True')

plt.scatter(sim_data.index, sim_data, c=current_palette[2])

plt.axhline(obs_sigma*detection_thresh, ls='--', label='Detection')
plt.axhline(obs_sigma*analysis_thresh, ls='-.', label='Analysis')

stplot.curve.graded_errorbar(usable_data, obs_sigma, alpha=0.5)

# Data we'll be using for fitting:
plt.scatter(obs_data.index, obs_data,
            color=current_palette[-1], s=160, lw=3, marker='x',
            label='Fitted data')

# The lower limit has to make room for the scattered sim_data points, which
# include negative pure-noise values. usable_data is, by construction, entirely
# above the analysis threshold, so it cannot set that limit. The upper limit
# leaves 2*obs_sigma headroom above the highest usable point.
plt.ylim(-1.05*np.abs(np.min(sim_data)),
         1.1*np.max(usable_data + 2*obs_sigma))
plt.xlabel('Time')
plt.ylabel('Flux')
plt.legend(loc='best')
plt.gcf().suptitle('Model, observables, detections')

In [None]:
# Instantiate one ensemble of each type, presumably as a quick check that both
# construct. Each instance gets its own name so the type-2 ensemble does not
# overwrite the type-1 one.
model1 = Sn1aOpticalEnsemble()
model2 = Sn2OpticalEnsemble()

In [None]:
# Fit every prefix of the usable data with each candidate ensemble:
#   0 points (prior only), 1 point, ..., all len(usable_data) points.
# Afterwards model_runs[name][k] is the run that saw the first k usable points.
model_set = {'type1': Sn1aOpticalEnsemble, 'type2': Sn2OpticalEnsemble}
model_runs = {name: [] for name in model_set}

# "+ 1" so the final iteration fits the complete usable data set.
for n_data_epochs in range(len(usable_data) + 1):
    print("Running MCMC for {} datapoints".format(n_data_epochs))
    # No data -> the run samples the prior. Otherwise use the first
    # n_data_epochs usable points.
    obs_data = usable_data.iloc[:n_data_epochs] if n_data_epochs else None

    for model_name, model_ensemble in model_set.items():
        mr = ModelRun(ensemble=model_ensemble(),
                      obs_data=obs_data,
                      obs_sigma=obs_sigma,
                      use_pt=False)
        if obs_data is not None:
            mr.fit_data()
        sampler = mr.get_sampler(threads=4)
        mr.run(sampler, 500)  # 500: presumably the number of MCMC steps
        model_runs[model_name].append(mr)
        

In [None]:
# Pick the pair of runs (one per ensemble) conditioned on the first `ndata`
# usable points.
ndata = 2
mr1, mr2 = model_runs['type1'][ndata], model_runs['type2'][ndata]

In [None]:
# Display all usable (post-detection) data points.
usable_data

In [None]:
# The subset of usable points the selected runs were conditioned on.
usable_data.head(ndata)

In [None]:
# Best-fit comparison for the selected `ndata` points. It shows the data with
# graded error bars, plus the type-1 ML and MAP fits and the true curve.
seaborn.set_style('darkgrid')
obs_data = usable_data.iloc[:ndata]

data_color = current_palette[1]
stplot.curve.graded_errorbar(obs_data, obs_sigma, color=data_color, alpha=0.8, label='Data')
if ndata:
    plt.plot(tsteps, mr1.ml_curve(tsteps), ls='-.', lw=5, label='ML fit', c=current_palette[0])
    plt.plot(tsteps, mr1.map_curve(tsteps), ls='-', label='MAP fit', c=current_palette[0], zorder=1)
    plt.plot(tsteps, true_curve(tsteps), ls='--', lw=5, label='True', c=current_palette[-2])

plt.xlim(-10, 30)
# max() of an empty selection raises, so for ndata == 0 anchor the upper limit
# on the detection threshold instead.
data_top = max(obs_data) if len(obs_data) else obs_sigma*detection_thresh
plt.ylim(0, data_top + 3*obs_sigma)
plt.axhline(obs_sigma*detection_thresh, ls='--', label='Detection threshold', c=current_palette[-1])
plt.legend()

plt.xlabel('Time')
plt.ylabel('Flux')
# The filename follows ndata so a different selection isn't saved under a
# misleading name.
plt.savefig('{}datapts_plot3.pdf'.format(ndata))

In [None]:
# Inspect chainstats of the selected type-1 run (presumably MCMC chain summary statistics).
mr1.chainstats

In [None]:
# Inspect chainstats of the selected type-2 run (presumably MCMC chain summary statistics).
mr2.chainstats

In [None]:
# Walker plot for the selected type-1 run (presumably per-walker sample traces).
mr1.plot_walkers()

In [None]:
# Sample histograms for the selected type-1 run.
mr1.plot_hists()

In [None]:
# Triangle plot (with contours) of the selected type-1 run on a plain white
# background. The saved filename follows ndata so a different selection isn't
# saved under a misleading name.
seaborn.set_style('white')
mr1.plot_triangle(plot_contours=True)
plt.savefig('{}datapts_constrained_sample.pdf'.format(ndata))

In [None]:
# Triangle plot of the selected type-2 run.
mr2.plot_triangle()

In [None]:
# Preview the 8-colour "Paired" palette used for the forecast colour scheme.
seaborn.palplot(seaborn.color_palette("Paired",8))

In [None]:
import itertools

# Colour roles for the forecast plots, taken in order from the 8-colour
# "Paired" palette (light/dark pairs):
#   pal1 (type 1): trace=palette[0], forecast=palette[1], data=palette[3],
#                  true=palette[7]
#   pal2 (type 2): trace=palette[4], forecast=palette[5]
# palette[2] is skipped. The builtin next() works on Python 2.6+ and 3, unlike
# the Python-2-only iterator .next() method.
palette = seaborn.color_palette("Paired", 8)
pal_cycle = itertools.cycle(palette)
pal1 = dict(trace=next(pal_cycle), forecast=next(pal_cycle))
next(pal_cycle)  # skip palette[2]
pal1['data'] = next(pal_cycle)
pal2 = dict(trace=next(pal_cycle), forecast=next(pal_cycle))
pal1['true'] = palette[-1]

In [None]:
# Forecast settings shared by the comparison plots below.
t_forecast, n_subsamples = 40, 150

In [None]:
# Plain 'dark' seaborn style (no grid) for the prior-forecast plots that follow,
# until 'darkgrid' is set again.
seaborn.set_style('dark')

In [None]:
def overplot_ensemble_forecasts(ndata, t_forecast, n_subsamples, save_path=None):
    """Overlay the type-2 and type-1 forecasts conditioned on `ndata` points.

    The type-2 forecast is drawn first. The type-1 forecast, together with the
    observed data (if any) and the true curve, is then drawn on the same axes.
    Relies on the notebook-level ``model_runs``, ``tsteps``, ``obs_sigma``,
    ``pal1``, ``pal2`` and ``true_curve``.

    Parameters
    ----------
    ndata : int
        Number of usable data points the runs were conditioned on; used as the
        index into ``model_runs['type1']`` and ``model_runs['type2']``.
    t_forecast : float or None
        Passed through to ``plot_forecast`` as ``t_forecast``.
    n_subsamples : int
        Passed through to ``plot_forecast`` as ``subsample_size``.
    save_path : str, optional
        If given, the finished figure is written there with ``plt.savefig``.

    Raises
    ------
    IndexError
        If no run exists for ``ndata``. Negative values are rejected too; they
        would otherwise silently select runs from the end of the list.
    """
    n_runs = min(len(model_runs['type1']), len(model_runs['type2']))
    if not 0 <= ndata < n_runs:
        raise IndexError(
            "No model runs for ndata={}; valid range is 0..{}".format(ndata, n_runs - 1))
    mr1 = model_runs['type1'][ndata]
    mr2 = model_runs['type2'][ndata]
    # Only plot data points when the runs were conditioned on data; the
    # zero-data runs were created with obs_data=None.
    plot_data = mr1.obs_data is not None

    # Type 2 is drawn first; its axes are handed to type 1 so the two overlay.
    axes = mr2.plot_forecast(tsteps,
                             t_forecast=t_forecast,
                             kde_noise_sigma=obs_sigma,
                             axes=None,
                             palette=pal2,
                             plot_data=False,
                             subsample_size=n_subsamples)
    mr1.plot_forecast(tsteps,
                      t_forecast=t_forecast,
                      kde_noise_sigma=obs_sigma,
                      axes=axes,
                      plot_data=plot_data,
                      palette=pal1,
                      subsample_size=n_subsamples,
                      true_curve=true_curve)
    plt.gcf().suptitle('Comparison, {} datapoints'.format(ndata), size=25)
    if save_path is not None:
        plt.savefig(save_path)

In [None]:
# The zero-data runs, i.e. samples from each ensemble's prior.
mr1_prior, mr2_prior = model_runs['type1'][0], model_runs['type2'][0]

In [None]:
# Prior forecast of the type-1 ensemble on its own.
mr1_prior.plot_forecast(tsteps, palette=pal1, subsample_size=n_subsamples)
plt.gcf().suptitle('Prior ensemble, Type 1', size=25)

In [None]:
# Prior forecast of the type-2 ensemble on its own.
mr2_prior.plot_forecast(tsteps, palette=pal2, subsample_size=n_subsamples)
plt.gcf().suptitle('Prior ensemble, Type 2', size=25)

In [None]:
# Zero data points: both runs are prior-only, and t_forecast is passed as None.
overplot_ensemble_forecasts(0, t_forecast=None, n_subsamples=n_subsamples)

In [None]:
# Back to the gridded style for the data-conditioned comparisons.
seaborn.set_style('darkgrid')
overplot_ensemble_forecasts(ndata=1, t_forecast=t_forecast, n_subsamples=n_subsamples)

In [None]:
overplot_ensemble_forecasts(ndata=2, t_forecast=t_forecast, n_subsamples=n_subsamples)

In [None]:
overplot_ensemble_forecasts(ndata=3, t_forecast=t_forecast, n_subsamples=n_subsamples)

In [None]:
overplot_ensemble_forecasts(ndata=4, t_forecast=t_forecast, n_subsamples=n_subsamples)

In [None]:
overplot_ensemble_forecasts(ndata=6, t_forecast=t_forecast, n_subsamples=n_subsamples)

In [None]:
# overplot_ensemble_forecasts(8,t_forecast,n_subsamples)