In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
from datetime import date, datetime, timedelta, time
import pandas as pd
import seaborn
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import IPython.html.widgets as widgets

In [None]:
import scipy
from scipy.stats import norm
import scipy.optimize as op
import emcee
import triangle

In [None]:
# Use seaborn's 'talk' plotting context for every figure that follows.
seaborn.set_context(context='talk')

In [None]:
import simtransient as st
import simtransient.utils as stutils
from simtransient import hammer
from simtransient.models.supernova import Sn1aOpticalEnsemble, Sn2OpticalEnsemble

In [None]:
# Per-model container: one column per supernova model, one row per artefact
# (ensemble, sampler run, chain statistics, trimmed samples).
# dtype=object lets cells hold arbitrary Python objects. `np.object` was only
# an alias for the builtin `object` and was removed in NumPy 1.24.
models = pd.DataFrame(index=('ensemble', 'sampler', 'chainstats', 'trimmed'),
                      columns=('sn1', 'sn2'),
                      dtype=object)

# Assign through .at on the frame itself. `models.sn1.ensemble = ...` is
# chained assignment on a column Series and does not update `models` when
# pandas Copy-on-Write is active.
models.at['ensemble', 'sn1'] = Sn1aOpticalEnsemble()
models.at['ensemble', 'sn2'] = Sn2OpticalEnsemble()

In [None]:
# Display the Gaussian parameter table of the sn1 ensemble (its transposed
# `mu` column is used below as the sampler's starting point).
models.at['ensemble', 'sn1'].gauss_pars

In [None]:
def sample_gaussian_prior(ensemble, nwalkers, nsteps, nthreads):
    """Run an emcee ensemble sampler over an ensemble's Gaussian log-prior.

    Parameters
    ----------
    ensemble : object
        Must expose ``gauss_pars`` (a DataFrame whose transpose has a ``mu``
        column, used as the centre of the initial walker ball) and
        ``gauss_lnprior`` (the log-probability function that is sampled).
    nwalkers : int
        Number of walkers passed to ``hammer.prep_ensemble_sampler``.
    nsteps : int
        Number of MCMC steps to run.
    nthreads : int
        Thread count passed to ``hammer.prep_ensemble_sampler``.

    Returns
    -------
    dict
        ``sampler``: the sampler after running.
        ``chainstats`` and ``trimmed``: the two values returned by
        ``hammer.trim_chain(sampler, pt=False)``.
    """
    # Previously nwalkers/nthreads/nsteps were accepted but ignored in favour
    # of hard-coded 100/4/300. They are now honoured.
    sampler, init_ball = hammer.prep_ensemble_sampler(
        ensemble.gauss_pars.T.mu.values,
        ensemble.gauss_lnprior,
        args=[],
        nwalkers=nwalkers,
        ballsize=1e-3,
        threads=nthreads,
    )
    # Step count is passed positionally: emcee 2 names it `N`, emcee 3 `nsteps`.
    sampler.run_mcmc(init_ball, nsteps)
    chainstats, trimmed = hammer.trim_chain(sampler, pt=False)
    return dict(sampler=sampler,
                trimmed=trimmed,
                chainstats=chainstats)

In [None]:
# Sampler settings shared by every model.
nwalkers = 100
nsteps = 300
nthreads = 4

for modelname in models.columns:
    results = sample_gaussian_prior(models.at['ensemble', modelname],
                                    nwalkers, nsteps, nthreads)
    # Write back through .at on the frame itself. Setting attributes on the
    # column Series (`model.sampler = ...`) is chained assignment and does not
    # reach `models` when pandas Copy-on-Write is active.
    for key in ('sampler', 'trimmed', 'chainstats'):
        models.at[key, modelname] = results[key]

In [None]:
# Per-walker chain plot for the sn2 model's prior run.
model = models['sn2']
st.plot.chain.all_walkers(
    model.sampler.chain,
    model.chainstats,
    model.ensemble.gauss_pars.keys(),
)

In [None]:
# _=triangle.corner(models.sn2.trimmed,
#                  labels=models.sn2.ensemble.gauss_pars.keys(),
#                  quantiles=[0.05, 0.5, 0.95],
#                  truths=models.sn2.ensemble.gauss_pars.T.mu)

In [None]:
# seaborn.palplot(seaborn.color_palette('colorblind'))

In [None]:
palette = seaborn.color_palette('Paired')
# One colour per model column, in column order: sn1 gets palette[1],
# sn2 gets the palette's last entry.
models.loc['color'] = (palette[1], palette[-1])
# Per-model switch read by plt_traces; every model starts enabled.
models.loc['plotthis'] = True

In [None]:
def plt_traces(ntrace):
    """Overlay light curves drawn from each enabled model's trimmed samples.

    Reads the notebook-level ``models`` frame. For every column whose
    ``plotthis`` entry is truthy, it:

    * picks up to ``ntrace`` distinct rows of ``trimmed`` at random;
    * evaluates ``ensemble.evaluate(t, *row, t0=0)`` on a fixed time grid;
    * draws one translucent line per sample in that model's ``color`` on the
      current matplotlib axes.

    Parameters
    ----------
    ntrace : int
        Number of samples to draw per model. Values <= 0 draw nothing. The
        value is capped at the number of available samples, so sampling
        without replacement cannot fail.
    """
    ntrace = int(ntrace)
    if ntrace <= 0:
        return
    # The time grid is identical for every model, so build it once.
    t = np.linspace(-30, 80, 1000)
    # DataFrame.iteritems() was removed in pandas 2.0; items() is equivalent.
    for _modelname, model in models.items():
        if not model.plotthis:
            continue
        n_available = len(model.trimmed)
        n_draw = min(ntrace, n_available)
        if n_draw == 0:
            continue
        choice_idx = np.random.choice(n_available, size=n_draw, replace=False)
        subsamples = model.trimmed[choice_idx]
        lcs = np.array([model.ensemble.evaluate(t, *pltpars, t0=0)
                        for pltpars in subsamples])
        # Stand-in for seaborn.tsplot(err_style="unit_traces", ls='') (removed
        # in seaborn 0.10): one faint line per sample, no central line.
        ax = plt.gca()
        ax.plot(t, lcs.T, color=model.color, alpha=0.2, label='_nolegend_')
        ax.set_ylabel('flux')

In [None]:
# Enable every model's traces by default. To hide one, set its flag on the
# frame itself, e.g. models.at['plotthis', 'sn1'] = False
models.loc['plotthis'] = True

# ntrace=(0, 150) is the (min, max) range of the control for plt_traces'
# only argument. interact_manual is used instead of widgets.interact,
# presumably so the plot is only redrawn on demand.
widgets.interact_manual(plt_traces, ntrace=(0, 150))