# Prepare to try HMC methods

Reference notebook (vectorized NumPyro example from the jax-cosmo paper repository): https://github.com/DifferentiableUniverseInitiative/jax-cosmo-paper/blob/master/notebooks/VectorizedNumPyro.ipynb

In [None]:
from diffatmemulator.diffatmemulator import DiffAtmEmulator
from diffatmemulator.diffatmemulator import Dict_Of_sitesAltitudes,Dict_Of_sitesPressures

In [None]:
from instrument.instrument import Hologram

In [None]:
import numpy as np


import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
jax.config.update("jax_enable_x64", True)

import numpyro
from numpyro import optim
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, HMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal, AutoMultivariateNormal
from numpyro.infer.reparam import NeuTraReparam
from numpyro.handlers import seed, trace, condition

import matplotlib as mpl
from matplotlib import pyplot as plt

import corner
import arviz as az
mpl.rcParams['font.size'] = 20

In [None]:
import os
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# Instrument

In [None]:
# Instantiate the instrument model (Hologram, imported from instrument.instrument).
# NOTE(review): `rebin=2` presumably sets a rebinning factor -- confirm against Hologram.
# NOTE(review): the name `h` is later passed as `h=h` inside `model()`, where it
# appears to be meant as a cosmological parameter -- confirm which is intended.
h = Hologram(rebin=2)

In [None]:
# Wavelength sample provided by the Hologram instance.
# NOTE(review): units and array shape are not visible here -- confirm against Hologram.
wls = h.get_wavelength_sample()

## Emulator

In [None]:
# Observation-site identifier handed to the atmospheric emulator.
obs_str = "LSST"
# Build the emulator for the selected site.
emul =  DiffAtmEmulator(obs_str=obs_str)

## Definition of Forward Model

In [None]:
class ForwardModel(object):
    """Container for the emulator used to build the forward model.

    Parameters
    ----------
    emul : object
        Emulator instance; it is stored unchanged on ``self.e``.
    """

    def __init__(self, emul):
        # `self` was missing from the original signature, which made the
        # class impossible to instantiate with an emulator argument.
        self.e = emul

    def getData(self):
        """Return the model data (placeholder: currently always ``0``)."""
        return 0

In [None]:
from numpyro.distributions.transforms import AffineTransform

def Uniform(name, min_value, max_value):
    """Sample a uniform variable on [min_value, max_value] at site ``name``.

    The draw is made from a base ``Uniform(-3, 3)`` distribution and then
    mapped affinely onto the requested interval, so the base variable
    always lives on the same fixed range regardless of the target bounds.
    """
    # The base interval has width 6, so this scale stretches it to the
    # width of the target interval.
    scale = (max_value - min_value) / 6.
    # Shift so that the base value -3 lands on min_value (and +3 on max_value).
    loc = min_value + 3.*scale

    base = dist.Uniform(-3., 3.)
    rescaled = dist.TransformedDistribution(base, AffineTransform(loc, scale))
    return numpyro.sample(name, rescaled)


In [None]:
# Let's define our model using numpyro
# Paper: https://arxiv.org/pdf/1708.01530.pdf and desy1.py
def model():
    """NumPyro model: draw six parameters from uniform priors, then run a forward model.

    Sample sites: ``pressure``, ``pwv``, ``oz``, ``tau``, ``beta`` and ``A``.
    Returns the tuple ``(cl, P, C)`` computed by the forward-model section.

    NOTE(review): the forward-model section below appears to be carried over
    from the jax-cosmo DES Y1 example. It uses ``jc``, ``Omega_c``, ``sigma8``,
    ``Omega_b``, ``n_s``, ``w0``, ``nzs_s``, ``dz``, ``eta``, ``bias``, ``m``,
    ``nzs_l`` and ``ell``, none of which are defined or imported in the visible
    part of this file, and it passes ``h`` (the Hologram instance created
    earlier in the file) as a cosmology argument. Of the sampled values only
    ``A`` is used afterwards. As written, calling this function presumably
    raises ``NameError`` -- TODO replace with the atmospheric forward model.
    """
    #  atmospheric  params
    pressure = numpyro.sample('pressure', dist.Uniform(700., 780.))
    pwv = numpyro.sample('pwv', dist.Uniform(0., 10.0))
    oz = numpyro.sample('oz', dist.Uniform(0., 600.))
    tau = numpyro.sample('tau', dist.Uniform(0., 0.3))
    beta = numpyro.sample('beta', dist.Uniform(-3., 0.1)) 
    A = numpyro.sample('A', dist.Uniform(0, 1.0))

     
    # Now that params are defined, here is the forward model
    # NOTE(review): `jc` and the cosmological parameters are not defined in the
    # visible file; `h` here resolves to the module-level Hologram instance.
    cosmo = jc.Cosmology(Omega_c=Omega_c, sigma8=sigma8, Omega_b=Omega_b,
                          h=h, n_s=n_s, w0=w0, Omega_k=0., wa=0.)
    
    # Build source nz with redshift systematic bias
    nzs_s_sys = [jc.redshift.systematic_shift(nzi, dzi, zmax=2.0) 
                for nzi, dzi in zip(nzs_s, dz)]
    
    # Define IA model, z0 is fixed
    # `A` is the only sampled parameter consumed by the code below.
    b_ia = jc.bias.des_y1_ia_bias(A, eta, 0.62)

    # Bias for the lenses
    b = [jc.bias.constant_linear_bias(bi) for bi in bias] 
    
    # Define the lensing and number counts probe
    probes = [jc.probes.WeakLensing(nzs_s_sys, 
                                    ia_bias=b_ia,
                                    multiplicative_bias=m),
             jc.probes.NumberCounts(nzs_l, b)]

    cl, C = jc.angular_cl.gaussian_cl_covariance_and_mean(cosmo, ell, probes, 
                                                          f_sky=0.25, sparse=True)
    
    # P is the densified inverse of the sparse covariance; C is then densified too.
    P = jc.sparse.to_dense(jc.sparse.inv(C))
    C = jc.sparse.to_dense(C)
    return cl, P, C


In [None]:
from numpyro.handlers import seed, trace, condition
# So, let's generate the data at the fiducial parameters
# Condition all six sample sites of `model` on fixed fiducial values.
fiducial_model = condition(
    model,
    {
        'pressure': 730.,
        'pwv': 5.0,
        'oz': 300.,
        'tau': 0.,
        'beta': -1.,
        'A': 0.,
    },
)

# Evaluate the conditioned model once under a fixed PRNG seed and keep
# the three values it returns.
with seed(rng_seed=42):
    data, P, C = fiducial_model()
