In [None]:
# Imports necessary Packages
import numpy as np
import scipy.stats as sps
import pandas as pd

# Specific Plotting Packages
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import patches
from mpl_toolkits.mplot3d import Axes3D

# sklearn
from sklearn.preprocessing import normalize
from sklearn.mixture import BayesianGaussianMixture

# other specific packages
from scipy.integrate import quad
from scipy.special import gamma as gamma_func
from scipy.signal import find_peaks
import warnings

# Import Custom Packages
from distr_tools import mixture_dist
import supplemental_funcs as sf
import example_master as EM # sets values for example

# to show in notebook
%matplotlib inline

In [None]:
from importlib import reload

In [None]:
# Use a 14pt base font for every figure in this notebook
plt.rcParams['font.size'] = 14

In [None]:
# Registry of every figure produced below, keyed by output filename
# (e.g. 'fig_exact_update.png' -> Figure); read by the save-all cell at the end.
fig_all_master = dict()

# Example: DCI with Nonlinear 1D to 1D

Here we use the tri-peak density as the observed distribution for the nonlinear model shown below. We then use data-consistent inversion (DCI) to find the corresponding distribution in the parameter space.

* $Q(\lambda) = \frac{1}{4}(\lambda+1)(\lambda-4)(\lambda-7)$

* Domain of parameter space: $\Lambda = [0,10]$

* Domain of observed distribution: $\mathcal{D}=Q(\Lambda) = [-3.704,49.5]$

In [None]:
# Forward map Q: Lambda -> D, taken from example_master
# (presumably the cubic Q(lambda) stated in the markdown above)
Q_map = getattr(EM, 'Q_nonlinear_1D_to_1D')

In [None]:
# Evaluate the forward map on the lambda grid supplied by example_master
lamx = EM.tri_peak_lamx
q_dom = Q_map(lamx)

plt.plot(lamx,q_dom)
# raw string: '\l' is an invalid escape sequence (SyntaxWarning on Python >= 3.12)
plt.xlabel(r'$\lambda$')
plt.ylabel('$q$')
# NOTE(review): reference line at q = 20; its purpose is not stated in the notebook
plt.axhline(20,color='r');

Here we show the observed:

\begin{align}
\pi_{obs}(q) \sim \text{TriPeak}(f_1,f_2,f_3)
\end{align}

Where the tripeak density is defined in [this notebook](Review%20of%20Density%20Estimation.ipynb).

This density is shown below.

In [None]:
# Observed (tri-peak) quantities from example_master:
#   qx       - points at which data-space densities are evaluated
#   obs_dist - the exact observed mixture distribution
#   q_sample - the observed sample
qx, obs_dist, q_sample = EM.tri_peak_qx, EM.tri_peak_mixture, EM.tri_peak_sample

# exact observed density over a normalized histogram of the sample
plt.plot(qx, obs_dist.pdf(qx))
plt.hist(q_sample, density=True, color='xkcd:sky blue', edgecolor='k')

Initial Distribution:

$$\pi_{init}(\lambda)\sim 10\cdot\text{Beta}(a,b), \qquad a=1,\ b=1.5,$$

where the Beta distribution has been scaled to the domain of $\Lambda=[0,10]$.

In [None]:
# Initial density on Lambda = [0, 10]: Beta(1, 1.5) stretched by scale=10
init_dist = sps.beta(1, 1.5, scale=10)

In [None]:
# Visual check of the initial density on the lambda grid
init_pdf_on_grid = init_dist.pdf(lamx)
plt.plot(lamx, init_pdf_on_grid)

Approximate the predicted density (the push-forward of the Beta initial distribution through $Q$) using a Gaussian kernel density estimate (GKDE).

In [None]:
# Push initial samples through Q and approximate the predicted (push-forward)
# density with a Gaussian KDE using scipy's default bandwidth rule.
# The fixed seed makes predict_kde -- and therefore every DCI update built from
# it below -- reproducible under Restart & Run All.
PREDICT_N_SAMPLES = 5000
PREDICT_SEED = 12345
predict_sample = Q_map(init_dist.rvs(size=PREDICT_N_SAMPLES, random_state=PREDICT_SEED))
predict_kde = sps.gaussian_kde(predict_sample)

In [None]:
# Compare the KDE-approximated predicted density with the exact observed density.
# Q is not monotone on [0, 10], so q_dom = Q(lamx) is unsorted. Sort it so each
# curve is drawn once from left to right instead of doubling back on itself.
q_dom_sorted = np.sort(q_dom)
plt.plot(q_dom_sorted,predict_kde(q_dom_sorted))
plt.plot(q_dom_sorted,obs_dist.pdf(q_dom_sorted))

In [None]:
def check_nd_shape(lam,dim=1):
    """Coerce ``lam`` into a 2-D array of evaluation points.

    Parameters
    ----------
    lam : scalar or array_like
        Point(s) to evaluate. Plain Python scalars (such as the floats passed
        by ``scipy.integrate.quad``) and lists are accepted as well as arrays.
    dim : int, optional
        Only used for 1-D input. ``dim == 1`` treats each entry as a separate
        1-D point and returns shape ``(n, 1)``. Any other value treats the
        vector as a single point and returns shape ``(1, n)``.

    Returns
    -------
    numpy.ndarray
        0-d input gives shape ``(1, 1)``. 1-d input gives ``(n, 1)`` or
        ``(1, n)``. Arrays with two or more dimensions are returned unchanged.
    """
    # asarray lets scalars/lists through; an existing ndarray is not copied
    vals = np.asarray(lam)
    if vals.ndim == 0:
        return vals.reshape(1, 1)
    if vals.ndim == 1:
        return vals.reshape(-1, 1) if dim == 1 else vals.reshape(1, -1)
    return vals

In [None]:
class dci_update:
    """Data-consistent (DCI) update of an initial density.

    The updated density is

        pi_up(lam) = pi_init(lam) * pi_obs(Q(lam)) / pi_pred(Q(lam)),

    evaluated only where the initial density is nonzero and the predicted
    density is not negligible relative to the observed one (the
    predictability assumption). All other points get update value 0.

    Parameters
    ----------
    init : initial distribution on the parameter space.
    pred : predicted distribution on the data space.
    obs : observed distribution on the data space.
    Qmap : callable mapping parameter values to data values.
    pred_tol : float, optional
        A point is updated only if ``pred > pred_tol * obs`` at ``Q(lam)``.
        Defaults to 1e-6, the previously hard-coded threshold.

    NOTE(review): all densities are evaluated through ``sf.eval_pdf``; which
    distribution types it accepts is defined in supplemental_funcs.
    """
    def __init__(self,init,pred,obs,Qmap,pred_tol=1e-6):
        self.init_dist = init
        self.pred_dist = pred
        self.obs_dist = obs
        self.Q = Qmap
        self.pred_tol = pred_tol

    def pdf(self,lam):
        """Evaluate the updated density at ``lam``.

        Parameters
        ----------
        lam : scalar or array_like
            Evaluation point(s). Plain Python scalars (e.g. from
            ``scipy.integrate.quad``) and lists are accepted as well as arrays.

        Returns
        -------
        numpy.ndarray
            Updated density values, squeezed (0-d for a single point).
        """
        # coerce to a column of points and push them through Q
        lam_vals = check_nd_shape(np.asarray(lam))
        q = self.Q(lam_vals)

        # initial, predicted and observed densities, coerced to 2-D columns
        init_vals = check_nd_shape(sf.eval_pdf(lam_vals,self.init_dist))
        pred_vals = check_nd_shape(sf.eval_pdf(q,self.pred_dist))
        obs_vals = check_nd_shape(sf.eval_pdf(q,self.obs_dist))

        # predictability assumption: only update where the initial density is
        # nonzero and the predicted density is not negligible relative to the
        # observed one (this also avoids dividing by ~0)
        up_ind = (init_vals != 0) & (pred_vals > self.pred_tol*obs_vals)

        # the update is 0 wherever the assumption fails
        lam_update = np.zeros((lam_vals.shape[0], 1))
        lam_update[up_ind] = init_vals[up_ind]*obs_vals[up_ind]/pred_vals[up_ind]

        return np.squeeze(lam_update)
        
    

In [None]:
# Reference ("exact") update: uses the exact observed density, while the
# predicted density is still the KDE approximation predict_kde
exact_update = dci_update(init=init_dist, pred=predict_kde, obs=obs_dist, Qmap=Q_map)

In [None]:
# Initial density (dashed grey) overlaid with the exact DCI update
init_pdf_vals = init_dist.pdf(lamx)
plt.plot(lamx, init_pdf_vals, ls='--', alpha=0.7, color='gray')
plt.plot(lamx, exact_update.pdf(lamx))


### Exact Update and Observed Plot

This shows the target observed density (in the data space) and the target update distribution used to generate the data (in the parameter space).

In [None]:
# Reference figure: initial vs. exact update in Lambda, predicted vs. observed in D
fig_exact_update, (axL,axD) = plt.subplots(1,2)
fig_exact_update.set_figwidth(10)
fig_exact_update.set_figheight(5)

# parameter space
axL.plot(lamx,init_dist.pdf(lamx),ls='--',color='gray',
         alpha=0.7,label='Initial')
axL.plot(lamx,exact_update.pdf(lamx),label='Update')

# data space
axD.plot(qx,predict_kde.pdf(qx),ls='--',color='gray',label='Predicted')
axD.plot(qx,obs_dist.pdf(qx),label='Observed')

# typical labels (raw strings: '\L', '\l' and '\m' are invalid escape
# sequences that raise SyntaxWarning on Python >= 3.12)
axL.legend()
axL.set_title(r'Parameter Space $\Lambda$')
axL.set_xlabel(r'$\lambda$')
axD.legend()
axD.set_title(r'Data Space $\mathcal{D}$')
axD.set_xlabel('$q$')

# register the figure for the save-all cell
this_fig_title = 'fig_exact_update.png'
fig_all_master[this_fig_title] = fig_exact_update

# save just this fig
# fig_exact_update.savefig('../'+this_fig_title)

# Update with Estimation of the Observed Distribution

As the figures below show, the distribution of observations is skewed right. As a result, the standard deviation of the whole data set is inflated by the observations in the right tail.

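So instead of using the sample standard deviation $\hat{\sigma}$ in Scott's rule, $h = \hat{\sigma}n^{-1/5}$, we replace $\hat{\sigma}$ with half the interquartile range, $h = \frac{\mathrm{IQR}}{2}\,n^{-1/5}$. This gives a smaller bandwidth that is less sensitive to the right tail.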

In [None]:
# One boxplot per sample size in N_list, each built from the first n observations
N_list = EM.tri_peak_N_list
for pos, n_obs in enumerate(N_list):
    plt.boxplot(q_sample[:n_obs], positions=np.array([pos]))

In [None]:
# Subsample used by the single-sample illustrations: the first N_list[3]
# observations (the original note says n=250 -- confirm against example_master)
ex_n = N_list[3]
this_q_sample = q_sample[:ex_n]
print(this_q_sample.shape)

In [None]:
# Two views of the example subsample: density histogram (left), horizontal boxplot (right)
fig_TPD_boxplots, (ax1, ax2) = plt.subplots(1, 2)
fig_TPD_boxplots.set_figwidth(8)

ax1.hist(this_q_sample, density=True, color='xkcd:sky blue', edgecolor='k', alpha=0.8)
ax1.set(title='Histogram of Data', xlabel='$q$')

ax2.boxplot(this_q_sample, vert=False)
ax2.set_yticks([])
ax2.set(xlabel='$q$', title='Boxplot of Data')

# register the figure for the save-all cell
this_fig_title = 'fig_TPD_data_analysis.png'
fig_all_master[this_fig_title] = fig_TPD_boxplots

# fig_TPD_boxplots.savefig('../'+this_fig_title)

## Using a GKDE

Here we determine the Data-consistent update using a GKDE for the observed density.

In [None]:
# Approximate the observed density with a GKDE whose bandwidth follows the
# IQR-based rule described above: h = (IQR/2) * n^(-1/5).
IQR_dev = sps.iqr(q_sample)/2
h_special = IQR_dev*len(q_sample)**(-1/5)
# gaussian_kde scales a scalar `bw_method` by the sample standard deviation
# (ddof=1) to get the kernel bandwidth, so divide it out to obtain exactly h_special
kde_factor = h_special/np.std(q_sample,ddof=1)
# previously kde_factor was computed but not passed, so scipy's default
# Scott factor was silently used instead of h_special
obs_kde_approx = sps.gaussian_kde(q_sample, bw_method=kde_factor)
# NOTE(review): n here is len(q_sample); the plug-in CI below uses
# N_list[3] -- confirm the two agree.
update_kde_approx = dci_update(init_dist,predict_kde,obs_kde_approx,Q_map)

In [None]:
# GKDE-based approximation: update in Lambda (left), observed density in D (right)
fig_kde_approx, (axL,axD) = plt.subplots(1,2)
fig_kde_approx.set_figwidth(10)

# parameter space
axL.plot(lamx,exact_update.pdf(lamx),label='Exact Update',linewidth=2,alpha=0.75)
axL.plot(lamx,update_kde_approx.pdf(lamx),label='Approx Update',linewidth=2)

# dashed guides at the 3rd and 4th local maxima of the exact update
# (NOTE(review): indices 2:4 are hard-coded for this example)
peak_ind,prop = find_peaks(exact_update.pdf(lamx))
axL.vlines(lamx[peak_ind][2:4],ymin=0,ymax=exact_update.pdf(lamx)[peak_ind][2:4],
           color='xkcd:sky blue',ls='--',alpha=0.6)

# data space: dashed guide at the first local maximum of the exact observed density
peak_ind,prop = find_peaks(obs_dist.pdf(qx))
axD.plot(qx,obs_dist.pdf(qx),label='Exact Observed')
axD.plot(qx,obs_kde_approx.pdf(qx), label='GKDE')
axD.vlines(qx[peak_ind[0]],ymin=0,ymax=obs_dist.pdf(qx[peak_ind[0]]),ls='--',
           color='xkcd:sky blue',alpha=0.75)

# typical labels (raw strings avoid invalid-escape SyntaxWarnings on Python >= 3.12)
axL.legend()
axL.set_title(r'Parameter Space $\Lambda$')
axL.set_xlabel(r'$\lambda$')
axD.legend()
axD.set_title(r'Data Space $\mathcal{D}$')
axD.set_xlabel('$q$')


# register the figure for the save-all cell
this_fig_title = 'fig_TPD_KDE_update.png'
fig_all_master[this_fig_title] = fig_kde_approx

# # save just this fig
# fig_kde_approx.savefig('../'+this_fig_title)

In [None]:
# Integration windows for the error norms in each space.
# NOTE(review): D = [-3.704, 49.5], so the upper bound 25 ignores q > 25
# (presumably negligible observed mass there -- confirm).
qint_bounds = (-5,25)
lamint_bounds = (0,10)

# (header, exact density, approximation, integration window) for each space
error_report = [
    ('Dspace: Update Error', obs_dist, obs_kde_approx, qint_bounds),
    ('Pspace: Update Error', exact_update, update_kde_approx, lamint_bounds),
]
for idx, (header, exact, approx, (lo, hi)) in enumerate(error_report):
    if idx > 0:
        print()
    print(header)
    print('L2: ', sf.L2_err_1D(exact, approx, lo, hi))
    print('L1: ', sf.L1_err_1D(exact, approx, lo, hi))

In [None]:
# Rough ratio of two hard-coded error values.
# NOTE(review): presumably copied from printed error output above --
# recompute if that output changes.
0.016 / 0.009

### Plot the MISE

Here we make convergence plots of the mean integrated squared error (MISE) for a GKDE, using pre-computed bootstrapped GKDEs.

In [None]:
# Load the pre-computed (expensive) error arrays from .npz files; fail loudly
# with a hint if they are missing. allow_pickle=True: only load trusted files.
missing_msg = 'File not found. Did you run the notebook with expensive computations?'
try:
    update_kde_error = np.load(EM.tri_peak_update_MISE_name + '.npz', allow_pickle=True)
    kde_error_data = np.load(EM.tri_peak_MISE_name + '.npz', allow_pickle=True)
except FileNotFoundError:
    print(missing_msg)
    raise

# list the arrays stored in each file
for npz_file in (update_kde_error, kde_error_data):
    print(npz_file.files)
# leading dimension of the GKDE MISE array (presumably the number of bootstrap replicates)
this_B = kde_error_data['hIQRD_MISE'].shape[0]

In [None]:
# Fit a line to log(error) vs. log(n) for each error type; the slope estimates the
# convergence rate. Each error array is assumed to be (replicates, len(N_list)),
# matching the axis=0 std used when plotting.
MISE_lines = {}
these_kde_keys = ['updateMISE_L1', 'updateMISE_L2','hIQRD_MISE']
for label in these_kde_keys:
    source = kde_error_data if label == 'hIQRD_MISE' else update_kde_error
    err = np.asarray(source[label])
    # use each array's own replicate count (previously the GKDE file's count was
    # applied to every array)
    n_reps = err.shape[0]
    # repeat the sample sizes once per replicate so they line up with the row-major
    # flattening of `err`; np.tile works whether N_list is a list or an array
    # (`N_list*B` only repeats a list -- on an array it multiplies elementwise)
    x_n = np.log(np.tile(N_list, n_reps))
    y_err = np.log(err.reshape(-1))
    MISE_lines[label] = np.polynomial.polynomial.Polynomial.fit(x_n,y_err,1)


In [None]:
# Convergence of the GKDE-based update: fitted log-log lines with error bars.
# Left: update errors in Lambda (L1, L2). Right: errors in D -- the L1 update error
# (NOTE(review): reused for D, presumably because the DCI update preserves the L1
# error) and the L2 MISE of the GKDE itself.
fig_update_kde_converge, (ax1,ax2) = plt.subplots(1,2,sharey=True)
fig_update_kde_converge.set_figwidth(12)

line_labels = ["L1", "L2"]

# every fitted line is drawn against log(n) for the sample sizes in N_list
x_n = np.log(N_list)
midpoint = (x_n[-1]+x_n[0])/2


def _fit_label(line):
    """Annotation text 'm=<slope>, b=<intercept>' for a fitted log-log line.

    Polynomial.fit stores coefficients for a scaled/shifted window, so
    convert() is required to report slope and intercept in (log n, log error)
    coordinates.
    """
    return "$m={1:0.2}$, $b={0:0.2}$".format(*line.convert().coef)


# (axis, error key, source file, legend label, text y-offset, extra annotate kwargs)
plot_specs = [
    (ax1, 'updateMISE_L1', update_kde_error, line_labels[0], -0.25, {'va': 'top'}),
    (ax2, 'updateMISE_L1', update_kde_error, line_labels[0], -0.25, {'va': 'top'}),
    (ax1, 'updateMISE_L2', update_kde_error, line_labels[1], -0.5, {}),
    (ax2, 'hIQRD_MISE', kde_error_data, line_labels[1], -0.5, {}),
]
for ax, label, source, line_label, text_offset, text_kw in plot_specs:
    this_line = MISE_lines[label]
    # the y-axis is log error, so the error bars use the spread of the log errors
    y_std = np.std(np.log(source[label]),ddof=1,axis=0)
    ax.annotate(_fit_label(this_line),
                xy=(midpoint, this_line(midpoint)+text_offset), ha='right', **text_kw)
    ax.errorbar(x_n, this_line(x_n), yerr=y_std, marker='o',
                barsabove=True, capsize=4, label=line_label)

# labels and stuff (raw strings avoid invalid-escape SyntaxWarnings)
ax1.legend(loc='center right')
ax2.legend()
ax1.set_title(r'Convergence Rate in $\Lambda$')
ax1.set_xlabel('Log Sample Size $n$')
ax1.set_ylabel('Log Error')
ax2.set_title(r'Convergence Rate in $\mathcal{D}$')
ax2.set_xlabel('Log Sample Size $n$')
ax2.set_ylabel('Log Error')

# register the figure for the save-all cell
this_fig_title = 'fig_TP_kde_converge.png'
fig_all_master[this_fig_title] = fig_update_kde_converge

# # save just this fig
# fig_update_kde_converge.savefig('../'+this_fig_title)

## CI with GKDE

Here we construct a confidence interval for the GKDE in the data space. We use this to determine the CI for the update in the parameter space.

In [None]:
# Load the pre-computed bootstrap samples for the confidence intervals (.npz);
# fail loudly with a hint if missing. allow_pickle=True: only load trusted files.
missing_msg = 'File not found. Did you run the notebook with expensive computations?'
try:
    kde_CI_data = np.load(EM.tri_peak_CI_name + '.npz', allow_pickle=True)
except FileNotFoundError:
    print(missing_msg)
    raise

# list the arrays stored in the file
print(kde_CI_data.files)

In [None]:
# Pointwise 95% confidence bands for the GKDE (data space) and the resulting update
# (parameter space), from a plug-in normal approximation and from a bootstrap.
ex_dist = 3
this_h = h_special # h_IQRD
RK = 1/(2*np.sqrt(np.pi))  # R(K) = integral of K(u)^2 for the Gaussian kernel
this_n = N_list[ex_dist]
data_dim = 1  # dimension d of the KDE's data space
print('N, R(K), h: ', this_n,RK,this_h) # take values from earlier
# NOTE(review): the plug-in band assumes obs_kde_approx was built from this_n
# points with bandwidth this_h -- confirm len(q_sample) == N_list[ex_dist].

# CI value
alphaCI = 0.05
z_lower,z_upper = sps.norm.ppf(alphaCI/2),sps.norm.ppf(1-alphaCI/2)

# the kde values for this estimate
this_kde_vals = obs_kde_approx(qx)
this_update_vals = update_kde_approx.pdf(lamx)


def _plugin_band(f_vals):
    """Plug-in band f -/+ z_{1-alpha/2} * sqrt(R(K) f / (n h^d)).

    Uses the asymptotic KDE variance R(K) f(x) / (n h^d) (d = 1 here, i.e. n*h;
    previously h**2 was used). z_upper > 0, so the first returned array is the
    lower bound (previously the negative z_lower swapped the bounds).
    """
    half_width = z_upper*np.sqrt(RK*f_vals/(this_n*this_h**data_dim))
    return f_vals - half_width, f_vals + half_width


def _bootstrap_band(f_vals, boot_vals):
    """Band f -/+ the pointwise (1-alpha) quantile of |bootstrap - f|."""
    dev_bound = np.quantile(np.abs(boot_vals - f_vals),q=1-alphaCI,axis=0)
    return f_vals - dev_bound, f_vals + dev_bound


# NOTE(review): applying the KDE plug-in formula to the update density ('lamy')
# is a heuristic, not an asymptotic result for the update itself.
pointwiseCI = {
    'Plugin': {'qy': _plugin_band(this_kde_vals),
               'lamy': _plugin_band(this_update_vals)},
    'Bootstrap': {'qy': _bootstrap_band(this_kde_vals, kde_CI_data['hIQRDCI']),
                  'lamy': _bootstrap_band(this_update_vals, kde_CI_data['updateCI'])},
}


In [None]:
# 2x2 grid: rows = CI method (plug-in, bootstrap); columns = Lambda (left), D (right)
fig_update_KDE_confidence, axes = plt.subplots(2,2)
fig_update_KDE_confidence.set_figwidth(12)
fig_update_KDE_confidence.set_figheight(8)

this_x = {'qy': qx, 'lamy': lamx}
# raw strings: '\m' and '\L' are invalid escape sequences (SyntaxWarning on Python >= 3.12)
spacename = {'qy': r'$\mathcal{D}$', 'lamy': r'$\Lambda$'}
mid_pdf = {'qy': this_kde_vals, 'lamy': this_update_vals}
target_pdf = {'qy': obs_dist.pdf(qx), 'lamy': exact_update.pdf(lamx)}
for i,ax_row in enumerate(axes):
    label = 'Plugin' if i == 0 else 'Bootstrap'
    for j, ax in enumerate(ax_row):
        key = 'qy' if j==1 else 'lamy'
        ax.plot(this_x[key],target_pdf[key],color='xkcd:black',ls='--',alpha=0.7,
                label='Target')
        ax.plot(this_x[key],mid_pdf[key],label='GKDE',color='C1')

        # shaded pointwise confidence band
        y_lower,y_upper = pointwiseCI[label][key]
        ax.fill_between(this_x[key],y_lower,y_upper,edgecolor='xkcd:red',
                        facecolor='xkcd:yellow orange',alpha=0.5,zorder=2,
                        label='{:0.0f}% CI'.format(100*(1-alphaCI)))

        ax.set_title('{}: {}% CI with {}'.format(spacename[key],
                                                 int(100*(1-alphaCI)),
                                                 label))
        ax.set_xlabel(r'$\lambda$' if key == 'lamy' else '$q$')
        ax.legend()

# register the figure for the save-all cell
this_fig_title = 'fig_TP_update_kde_CI.png'
fig_all_master[this_fig_title] = fig_update_KDE_confidence

# # save just this fig
# fig_update_KDE_confidence.savefig('../'+this_fig_title)

## DPMM

In this section we use a Dirichlet Process Mixture Model (DPMM) to fit the tripeak density.

Note that we choose $\psi$ using the IQR for an appropriate window length (rather than using the default covariance of the observed density) because the distribution is skewed right (similar to our consideration for choosing the bandwidth $h$).


In [None]:
# Additional BayesianGaussianMixture options from example_master
arg_dict = EM.tri_peak_BGMM_arg_dict
# upper bound on the number of mixture components
this_K = 30

# DPMM prior hyperparameters
arg_prior_dict_DPMM = dict(
    n_components=this_K,
    weight_concentration_prior_type='dirichlet_process',
    weight_concentration_prior=1,
    mean_prior=np.atleast_1d(1),
    mean_precision_prior=1,      # kappa
    degrees_of_freedom_prior=1,  # nu
    # psi: IQR-based scale (see markdown above). NOTE(review): np.round yields 0,
    # a singular prior, if IQR_dev**2 < 0.5
    covariance_prior=np.atleast_2d(np.round(IQR_dev**2)),
)
print(arg_prior_dict_DPMM)

In [None]:
# Unfitted DPMM: prior settings above plus the options from example_master
# (a key present in both dicts raises TypeError rather than silently overriding)
obs_DPMM_approx = BayesianGaussianMixture(**arg_prior_dict_DPMM, **arg_dict)

In [None]:
# Fit on the observed sample as a single-feature column
obs_DPMM_approx.fit(np.reshape(q_sample, (-1, 1)))

### Show the Update

Compute the data-consistent update using the DPMM for the observed density.

In [None]:
# DCI update that uses the fitted DPMM as the observed density
update_DPMM_approx = dci_update(init=init_dist, pred=predict_kde, obs=obs_DPMM_approx, Qmap=Q_map)

In [None]:
# DPMM-based approximation: update in Lambda (left), observed density in D (right)
fig_DPMM_approx, (axL,axD) = plt.subplots(1,2)
fig_DPMM_approx.set_figwidth(12)
fig_DPMM_approx.set_figheight(5)

# parameter space
axL.plot(lamx,exact_update.pdf(lamx),label='Exact Update',linewidth=2,alpha=0.75)
axL.plot(lamx,sf.eval_pdf(lamx,update_DPMM_approx),label='Approx Update',linewidth=2)

# dashed guides at the 3rd and 4th local maxima of the exact update
# (NOTE(review): indices 2:4 are hard-coded for this example)
peak_ind,prop = find_peaks(exact_update.pdf(lamx))
axL.vlines(lamx[peak_ind][2:4],ymin=0,ymax=exact_update.pdf(lamx)[peak_ind][2:4],
           color='xkcd:sky blue',ls='--',alpha=0.6)

# data space: dashed guide at the first local maximum of the exact observed density
peak_ind,prop = find_peaks(obs_dist.pdf(qx))
axD.plot(qx,obs_dist.pdf(qx),label='Exact Observed')
axD.plot(qx,sf.eval_pdf(qx,obs_DPMM_approx), label='Exp. Posterior DPMM')
axD.vlines(qx[peak_ind[0]],ymin=0,ymax=obs_dist.pdf(qx[peak_ind[0]]),ls='--',
           color='xkcd:sky blue',alpha=0.75)

# typical labels (raw strings avoid invalid-escape SyntaxWarnings on Python >= 3.12)
axL.legend()
axL.set_title(r'Parameter Space $\Lambda$')
axL.set_xlabel(r'$\lambda$')
axD.legend()
axD.set_title(r'Data Space $\mathcal{D}$')
axD.set_xlabel('$q$')


# register the figure for the save-all cell
this_fig_title = 'fig_TPD_DPMM_update.png'
fig_all_master[this_fig_title] = fig_DPMM_approx

# # save just this fig
# fig_DPMM_approx.savefig('../'+this_fig_title)

In [None]:
# Integration windows for the error norms (same as for the GKDE comparison).
# NOTE(review): D = [-3.704, 49.5], so q > 25 is ignored -- confirm negligible.
qint_bounds = (-5,25)
lamint_bounds = (0,10)

# (header, exact density, DPMM-based approximation, integration window)
error_report = [
    ('Dspace: Update Error', obs_dist, obs_DPMM_approx, qint_bounds),
    ('Pspace: Update Error', exact_update, update_DPMM_approx, lamint_bounds),
]
for idx, (header, exact, approx, (lo, hi)) in enumerate(error_report):
    if idx > 0:
        print()
    print(header)
    print('L2: ', sf.L2_err_1D(exact, approx, lo, hi))
    print('L1: ', sf.L1_err_1D(exact, approx, lo, hi))

### Credible Interval for DPMM

Here we use the credible interval for the observed distribution in the data space to construct a credible interval for the update in the parameter space.

In [None]:
# Draw this_M parameter sets (keys 'mean', 'cov', 'weight' are used below) from
# sf.Forward_BGM_Model -- presumably samples from the fitted DPMM posterior.
# Each set defines one sampled observed density.
this_M = 500
obs_DPMM_forward = sf.Forward_BGM_Model(obs_DPMM_approx)
obs_DPMM_sample_params = obs_DPMM_forward.rvs(this_M)

In [None]:
# shape of the sampled covariances (presumably (this_M, K, 1, 1) -- confirm)
np.shape(obs_DPMM_sample_params['cov'])

In [None]:
# For each of the this_M sampled DPMM parameter sets, evaluate the sampled observed
# density on qx and the resulting DCI update on lamx. Pointwise quantiles of these
# sampled curves form the credible bands plotted below.
pointwiseCI_DPMM = {'qy': {'x': qx}, 'lamy': {'x': lamx}}


def _sampled_update_pdf(m, x):
    """DCI update density on x using the m-th sampled observed mixture."""
    # ravel (instead of squeeze) keeps 1-D arrays even for a single component,
    # where squeeze would produce 0-d arrays that cannot be zipped
    means = np.ravel(obs_DPMM_sample_params['mean'][m])
    stds = np.ravel(np.sqrt(obs_DPMM_sample_params['cov'][m]))
    weights = np.ravel(obs_DPMM_sample_params['weight'][m])
    components = [sps.norm(mu, sig) for mu, sig in zip(means, stds)]

    # observed mixture distribution and its DCI update
    obs_mixture = sf.mixture_dist(components, weights)
    return dci_update(init_dist, predict_kde, obs_mixture, Q_map).pdf(x)


# data space: sf.batch_GMM_pdf evaluates the sampled mixtures on qx
pointwiseCI_DPMM['qy']['sampled_post_ys'] = sf.batch_GMM_pdf(
    pointwiseCI_DPMM['qy']['x'], obs_DPMM_sample_params)

# parameter space: one DCI update per sampled mixture
lam_grid = pointwiseCI_DPMM['lamy']['x']
post_ys = np.zeros([this_M, lam_grid.shape[0]])
for m in range(this_M):
    post_ys[m] = _sampled_update_pdf(m, lam_grid)
pointwiseCI_DPMM['lamy']['sampled_post_ys'] = post_ys



In [None]:
# Shapes of the sampled curves in each space
for space_key in ('qy', 'lamy'):
    print(pointwiseCI_DPMM[space_key]['sampled_post_ys'].shape)

In [None]:
# Expected DPMM density/update with pointwise credible bands: Lambda (left), D (right)
fig_update_DPMM_confidence, axes = plt.subplots(1,2)
fig_update_DPMM_confidence.set_figwidth(12)

# labels and stuff (raw strings: '\m', '\L', '\l' are invalid escape sequences)
this_x = {'qy': qx, 'lamy': lamx}
spacename = {'qy': r'$\mathcal{D}$', 'lamy': r'$\Lambda$'}
mid_pdf = {'qy': sf.eval_pdf(qx,obs_DPMM_approx),
           'lamy': sf.eval_pdf(lamx,update_DPMM_approx)}
target_pdf = {'qy': obs_dist.pdf(qx), 'lamy': exact_update.pdf(lamx)}
xlabels = {'qy': '$q$', 'lamy': r'$\lambda$'}

for j, ax in enumerate(axes):
    key = 'qy' if j==1 else 'lamy'
    ax.plot(this_x[key],target_pdf[key],color='xkcd:black',ls='--',alpha=0.7,
            label='Target')
    ax.plot(this_x[key],mid_pdf[key],label='Exp. DPMM',color='C1')

    # equal-tailed pointwise credible interval from the sampled curves
    this_ys = pointwiseCI_DPMM[key]['sampled_post_ys']
    y_lower = np.quantile(this_ys,q=alphaCI/2,axis=0)
    y_upper = np.quantile(this_ys,q=1-alphaCI/2,axis=0)
    ax.fill_between(this_x[key],y_lower,y_upper,edgecolor='xkcd:red',
                    facecolor='xkcd:yellow orange',alpha=0.5,zorder=2,
                    label='{:0.0f}% CI'.format(100*(1-alphaCI)))

    ax.set_title('{}: {}% CI with {}'.format(spacename[key],
                                             int(100*(1-alphaCI)),
                                             'DPMM'))
    ax.legend()
    ax.set_xlabel(xlabels[key])

# register the figure for the save-all cell
this_fig_title = 'fig_TP_update_DPMM_CI.png'
fig_all_master[this_fig_title] = fig_update_DPMM_confidence

# # save just this fig
# fig_update_DPMM_confidence.savefig('../'+this_fig_title)

### Components Plot

Here we plot the results from the update using a DPMM componentwise.

In [None]:
# Weighted pdf of each fitted DPMM component on qx:
# row k holds weights_[k] * N(qx; means_[k], covariances_[k])
this_model = obs_DPMM_approx
these_pdf_outs = np.zeros([this_K, len(qx)])
component_params = zip(this_model.weights_[:this_K],
                       this_model.means_[:this_K],
                       this_model.covariances_[:this_K])
for k, (this_weight, this_mean, this_cov) in enumerate(component_params):
    these_pdf_outs[k, :] = this_weight * sps.multivariate_normal.pdf(
        qx, mean=this_mean, cov=this_cov)

In [None]:
# Data-space view of the fitted DPMM: target and expected density with the topX
# heaviest weighted components (left), posterior mean weights (right)
fig_DPMM_components_dspace, axes = plt.subplots(1,2)
fig_DPMM_components_dspace.set_figwidth(10)

# plot target
axes[0].plot(qx,sf.eval_pdf(qx,obs_dist), ls='--', label='Target',
             color='k',alpha=0.4,zorder=-1)

# plot expected posterior
qy = sf.eval_pdf(qx,obs_DPMM_approx)
axes[0].plot(qx,qy,label='Expected',zorder=0,linewidth=2)

# mean of the forward model's weight distribution
mean_weights = obs_DPMM_forward.weight_dist.mean()

# order the components by mean weight (largest first) and keep the top topX
topX = 5
large_ind_to_small = np.flip(np.argsort(mean_weights))
top_inds = large_ind_to_small[0:topX]
print(top_inds)
colorlist = ['xkcd:orangered','xkcd:crimson','xkcd:orchid']+['C0']*(topX-3)
weight_bars = np.arange(topX)+1

# weight spread from a forward sample; the lower error bar is clipped at the mean
# so bars never extend below zero
weight_std = np.sqrt(np.var(obs_DPMM_forward.weight_dist.rvs(500),
                            ddof=1,axis=0))
low_err = np.min(np.array([weight_std,mean_weights]),axis=0)
err_bars = np.array([low_err,weight_std])

# bar plot (component numbers shown 1-based)
axes[1].bar(weight_bars,mean_weights[top_inds],edgecolor=colorlist,
            linewidth=2,tick_label=top_inds+1,
            yerr=err_bars[:,top_inds],
            error_kw={'capsize': 4})

# weighted pdfs of the same components.
# NOTE(review): the ordering uses the forward model's mean weights, while the curves
# use obs_DPMM_approx.weights_ -- presumably consistent; confirm.
for k,y_pdf in enumerate(these_pdf_outs[top_inds]):
    axes[0].plot(qx,y_pdf, color=colorlist[k],ls='-.',
                 linewidth=2)

axes[0].legend()

# titles (raw string avoids an invalid-escape SyntaxWarning; typo fixed)
axes[0].set_title(r'$\mathcal{D}$: Posterior DPMM Components')
axes[1].set_title('Posterior Average Weights')

# labels
axes[0].set_xlabel('$q$')
axes[1].set_xlabel('Component # $k$')

# register the figure for the save-all cell
this_fig_title = 'fig_TPD_DPMM_post_component.png'
fig_all_master[this_fig_title] = fig_DPMM_components_dspace

# # save just this fig
# fig_DPMM_components_dspace.savefig('../'+this_fig_title)

In [None]:
# 0-based indices of the topX heaviest components
large_ind_to_small[:topX]

In [None]:
# DCI update for each top component whose fitted weight exceeds 0.01,
# stored as {component index: (weight, update)}
this_model = obs_DPMM_approx
these_updates = {}
for k in large_ind_to_small[0:topX]:
    this_weight = this_model.weights_[k]
    if not this_weight > 0.01:
        continue
    this_mean = this_model.means_[k]
    this_cov = this_model.covariances_[k]
    this_obs = sps.multivariate_normal(mean=this_mean, cov=this_cov)
    these_updates[k] = (this_weight, dci_update(init_dist, predict_kde, this_obs, Q_map))

In [None]:
# Weighted update density of each retained component
for key, (comp_weight, comp_update) in these_updates.items():
    plt.plot(lamx, comp_weight * comp_update.pdf(lamx))

In [None]:
# Component indices (0-based) that received a component-wise update
these_updates.keys()

In [None]:
# first key of spacename
next(iter(spacename))

In [None]:
# Componentwise view for the three heaviest components. Left column: DPMM-based
# update with one weighted component update. Right column: exact observed density
# with the matching weighted DPMM component.
fig_update_components, axes = plt.subplots(3,2)
fig_update_components.set_figwidth(8)
fig_update_components.set_figheight(8*1.25)

spacetitles = ['$\\Lambda$: Update Components',
               '$\\mathcal{D}$: Components of DPMM']
for i,ax_row in enumerate(axes):
    this_ind = large_ind_to_small[i]
    comp_label = 'Comp. {}'.format(this_ind+1)
    for j,ax in enumerate(ax_row):
        if j==0:
            this_x = lamx
            # DPMM-based (approximate) update
            ax.plot(this_x,sf.eval_pdf(this_x,update_DPMM_approx),alpha=0.75)

            # weighted update of this component; components with weight <= 0.01
            # were skipped when building these_updates, so guard against a missing key
            if this_ind in these_updates:
                this_weight, this_comp_update = these_updates[this_ind]
                this_component_pdf = this_weight*this_comp_update.pdf(this_x)
                ax.plot(this_x,this_component_pdf, color=colorlist[i],ls='-.',
                        linewidth=2,label=comp_label)

        else:
            this_x = qx
            # exact obs
            ax.plot(this_x,sf.eval_pdf(this_x,obs_dist),alpha=0.75)
            # weighted DPMM component
            ax.plot(this_x,these_pdf_outs[this_ind], color=colorlist[i],ls='-.',
                    linewidth=2,label=comp_label)

        # only draw a legend when something on this axis is labelled
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

        if i==0:
            ax.set_title(spacetitles[j])
        elif i==2:
            xlab = '$\\lambda$' if j==0 else '$q$'
            ax.set_xlabel(xlab)

# register the figure for the save-all cell
this_fig_title = 'fig_TPD_update_component.png'
fig_all_master[this_fig_title] = fig_update_components

# # save just this fig
# fig_update_components.savefig('../'+this_fig_title)

## Compare Convergence Plots

Here we compare the convergence of the data-consistent update using a GKDE vs. a DPMM (defined above). We load the pre-computed MISE data for the observed (and hence the updated) distributions.

In [None]:
# Load the pre-computed GKDE and DPMM error arrays (.npz); fail loudly with a hint
# if any file is missing. allow_pickle=True: only load trusted files.
missing_msg = 'File not found. Did you run the notebook with expensive computations?'
try:
    update_kde_error = np.load(EM.tri_peak_update_MISE_name + '.npz', allow_pickle=True)
    update_DPMM_error = np.load(EM.tri_peak_DPMM_name + '.npz', allow_pickle=True)
    kde_error_data = np.load(EM.tri_peak_MISE_name + '.npz', allow_pickle=True)
except FileNotFoundError:
    print(missing_msg)
    raise

# list the arrays stored in each file
for npz_file in (update_kde_error, update_DPMM_error, kde_error_data):
    print(npz_file.files)

In [None]:
# Error arrays for each (space, error type). The next cell adds a fitted 'line' and
# 'std_err' to every entry. The L1 entries are the same arrays in both spaces
# (NOTE(review): presumably because the DCI update preserves the L1 error).
error_sources = {
    'lamx': [('L1_kde', update_kde_error, 'updateMISE_L1'),
             ('L2_kde', update_kde_error, 'updateMISE_L2'),
             ('L1_DPMM', update_DPMM_error, 'DPMM_err_L1'),
             ('L2_DPMM', update_DPMM_error, 'DPMM_up_err_L2')],
    'qx': [('L1_kde', update_kde_error, 'updateMISE_L1'),
           ('L2_kde', kde_error_data, 'hIQRD_MISE'),
           ('L1_DPMM', update_DPMM_error, 'DPMM_err_L1'),
           ('L2_DPMM', update_DPMM_error, 'DPMM_err_L2')],
}
MISE_lines = {space: {name: {'err_data': src[arr_key]} for name, src, arr_key in entries}
              for space, entries in error_sources.items()}



In [None]:
# shape of the L1 update-error array
np.shape(update_kde_error['updateMISE_L1'])

In [None]:
# Fit log(error) vs. log(n) for every (space, error type). Each error array is
# assumed to be (replicates, len(N_list)); the fitted line and the spread of the log
# errors are stored on the entry for the plotting cell.
for space, this_err_set in MISE_lines.items():
    for errtype, entry in this_err_set.items():
        err = np.asarray(entry['err_data'])
        log_err = np.log(err)

        # repeat the sample sizes once per replicate to match the row-major
        # flattening; np.tile works whether N_list is a list or an array
        n_reps = err.shape[0]
        x_n = np.log(np.tile(N_list, n_reps))

        entry['line'] = np.polynomial.polynomial.Polynomial.fit(
            x_n, log_err.reshape(-1), 1)
        # error bars are drawn on the log-error axis, so use the std of the log
        # errors (previously the std of the raw errors)
        entry['std_err'] = np.std(log_err, ddof=1, axis=0)


In [None]:
# Compare the convergence of the DPMM- and GKDE-based approximations in each space
fig_converge_compare, axes = plt.subplots(1,2,sharey=True)
fig_converge_compare.set_figwidth(8)

colorlist = {'L1_kde': 'xkcd:sky blue', 'L2_kde': 'C0',
             'L1_DPMM': 'xkcd:tangerine','L2_DPMM': 'C1'}
markerlist = {'L1_kde': '^', 'L2_kde': 'o',
              'L1_DPMM': '^','L2_DPMM': 'o'}
labellist = {'L1_kde': '$L^1$ GKDE', 'L2_kde': '$L^2$ GKDE',
             'L1_DPMM': '$L^1$ DPMM','L2_DPMM': '$L^2$ DPMM'}
# raw strings: '\L' and '\m' are invalid escape sequences (SyntaxWarning on 3.12+)
space_titles = {'lamx': r'Convergence Rate in $\Lambda$',
                'qx': r'Convergence Rate in $\mathcal{D}$'}

# fitted lines are drawn against log(n) for the sample sizes in N_list
x_n = np.log(N_list)
for ax, (space, this_err_set) in zip(axes, MISE_lines.items()):
    # plot each error type with its error bars
    for errtype, entry in this_err_set.items():
        ax.errorbar(x_n, entry['line'](x_n), yerr=entry['std_err'],
                    marker=markerlist[errtype], color=colorlist[errtype],
                    barsabove=True, capsize=4, label=labellist[errtype])

    # labels and stuff
    ax.set_title(space_titles[space])
    ax.set_ylabel('Log Error')
    ax.legend()
    ax.set_xlabel('Log Sample Size $n$')


# register the figure for the save-all cell
this_fig_title = 'fig_TPD_compare_converge.png'
fig_all_master[this_fig_title] = fig_converge_compare

# # save just this fig
# fig_converge_compare.savefig('../'+this_fig_title)

# Save All Figs

In [None]:
# List the filename of every registered figure
for fig_filename in fig_all_master:
    print(fig_filename)

In [None]:
# Display one registered figure by filename (e.g. 'fig_exact_update.png');
# leave as None to skip.
check_name = None
if check_name is None:
    print()
else:
    # a bare expression inside an if/else is not auto-displayed, so display explicitly
    from IPython.display import display
    display(fig_all_master[check_name])

**[Advice for fig parameters](http://aeturrell.com/2018/01/31/publication-quality-plots-in-python/)**

The above link provides some good guidance for figure parameters.

In [None]:
# # # save all figs
# for figfilename in fig_all_master:
#     fig_all_master[figfilename].savefig('../'+figfilename,
#                                         dpi=250,bbox_inches='tight')