# Fitting five atmospheric parameters to a measured air transmission with JAXOPT

- P (pressure)
- pwv (precipitable water vapour)
- oz (ozone)
- tau (aerosols)
- beta (aerosols)


https://github.com/DifferentiableUniverseInitiative/jax-cosmo-paper/blob/master/notebooks/VectorizedNumPyro.ipynb

In [None]:
from diffatmemulator.diffatmemulator import DiffAtmEmulator
from diffatmemulator.diffatmemulator import Dict_Of_sitesAltitudes,Dict_Of_sitesPressures

In [None]:
from collections import OrderedDict

In [None]:
from instrument.instrument import Hologram

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
#from jax.scipy.special import logsumexp
import jax.scipy as jsc

from jax import grad, jit, vmap
from jax import jacfwd, jacrev, hessian

import jaxopt
import optax


jax.config.update("jax_enable_x64", True)



import matplotlib as mpl
from matplotlib import pyplot as plt

import corner
import arviz as az
mpl.rcParams['font.size'] = 15
mpl.rcParams["figure.figsize"] = [8, 8]

In [None]:
import os
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

In [None]:
def plot_params_kde(samples,hdi_probs=None, 
                    patName=None, fname=None, pcut=None,
                   var_names=None, point_estimate="median"):
    """Pair plot (KDE contours and marginals) of posterior samples with arviz.

    :param samples: mapping name -> 1-D array of samples. All arrays are
        expected to be aligned draws of the same joint distribution (same
        length, i-th entries belong together).
    :param hdi_probs: HDI probabilities of the contour levels; defaults to
        ``[0.393, 0.865, 0.989]``, the 1, 2 and 3 sigma levels of a 2-D Gaussian.
    :param patName: if given, label of a legend patch added to the first panel.
    :param fname: if given, the figure is saved to this path and then closed.
    :param pcut: ``(low, up)`` percentiles, e.g. ``(0.5, 99.5)``. A draw is
        kept only if EVERY variable lies strictly inside its own
        ``(low, up)`` percentile range; the mask is shared by all variables so
        the draws stay aligned. Arrays of different lengths raise ValueError.
    :param var_names: variables to plot (forwarded to ``az.plot_pair``).
    :param point_estimate: point estimate drawn (forwarded to ``az.plot_pair``).
    """
    if hdi_probs is None:
        # 1 - exp(-k**2 / 2) for k = 1, 2, 3: the 2-D 1, 2 and 3 sigma levels
        hdi_probs = [0.393, 0.865, 0.989]

    if pcut is not None:
        low, up = pcut[0], pcut[1]
        arrays = {name: np.asarray(value) for name, value in samples.items()}
        # One joint mask: cutting each variable independently and truncating
        # would pair values coming from different draws and corrupt the
        # correlations displayed in the pair plot.
        keep = None
        for value in arrays.values():
            inside = (value > np.percentile(value, low)) & (value < np.percentile(value, up))
            keep = inside if keep is None else keep & inside
        if keep is not None:
            samples = {name: value[keep] for name, value in arrays.items()}

    axs = az.plot_pair(
        samples,
        var_names=var_names,
        figsize=(10, 10),
        kind="kde",
        kde_kwargs={
            "hdi_probs": hdi_probs,
            "contour_kwargs": {"colors": ('r', 'green', 'blue'), "linewidths": 3},
            "contourf_kwargs": {"alpha": 0},
        },
        point_estimate_kwargs={"lw": 3, "c": "b"},
        marginals=True, textsize=20, point_estimate=point_estimate,
    )

    plt.tight_layout()

    if patName is not None:
        # Not imported at notebook level, so bring it into scope here.
        import matplotlib.patches as mpatches
        patName_patch = mpatches.Patch(color='b', label=patName)
        axs[0, 0].legend(handles=[patName_patch], fontsize=40, bbox_to_anchor=(1, 0.7))
    if fname is not None:
        plt.savefig(fname)
        plt.close()

# Instrument

Defines the properties of the instrument used for the measurement. The detector is the CCD of the AuxTel telescope. The Hologram class provides the wavelength sampling that corresponds to the measurement of a spectrum; this wavelength sampling follows the pixel size of the detector.

In [None]:
# Instrument model (AuxTel hologram). rebin=1 presumably means no pixel
# rebinning — TODO confirm against instrument.Hologram.
h = Hologram(rebin=1)

In [None]:
# Wavelength sampling of a measured spectrum (nm, per the axis labels below).
# It is the independent variable of every model evaluation and fit below.
wls = h.get_wavelength_sample()

## Emulator

In [None]:
# Observing site; also read as a global by mean_transm below.
obs_str = "LSST"

# emul1 keeps the emulator's default pressure for the site (presumably the value
# from Dict_Of_sitesPressures — confirm); emul2 forces a pressure of 800 hPa.
emul1 =  DiffAtmEmulator(obs_str=obs_str)
emul2 =  DiffAtmEmulator(obs_str=obs_str,pressure=800.)

In [None]:
# Pressures actually used by the two emulators (shown in hPa in the plot legend).
P1 = emul1.pressure
P2 = emul2.pressure

In [None]:
# Identical airmass, water vapour, ozone and aerosol parameters: only the pressure differs.
transm1 = emul1.vect1d_Alltransparencies(wls,am=1,pwv=4.0,oz=400.,tau=0.1,beta=-1.2)
transm2 = emul2.vect1d_Alltransparencies(wls,am=1,pwv=4.0,oz=400.,tau=0.1,beta=-1.2)

In [None]:
# Compare the air transmission computed at the two pressures.
fig,ax = plt.subplots(1,1,figsize=(6,3))
ax.plot(wls,transm1,'b',label=f"P = {P1:.1f} hPa")
ax.plot(wls,transm2,'r',label=f"P = {P2:.1f} hPa")
# Raw string: "\l" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
ax.set_xlabel(r"$\lambda$ (nm)")
ax.set_ylabel("transmission")
ax.set_title("mean air transparency")
ax.legend()

In [None]:
def mean_transm(x, params,airmass):
    """Mean model of the atmospheric transmission.

    Builds a ``DiffAtmEmulator`` for the site ``obs_str`` (a notebook-level
    global) at the pressure ``params["P"]`` and evaluates
    ``vect1d_Alltransparencies`` at the wavelengths ``x``.

    :param x: wavelengths at which the transmission is evaluated
        (independent variable)
    :type x: array of float, in nm
    :param params: mapping with keys ``"P"`` (pressure, hPa), ``"pwv"``,
        ``"oz"``, ``"tau"`` and ``"beta"``
    :type params: dict-like, e.g. OrderedDict
    :param airmass: airmass of the observation
    :type airmass: float
    :return: model transmission at each wavelength of ``x``
    :rtype: array of float
    """
    # The pressure is a fitted parameter that is only accepted at construction
    # time, so a fresh emulator is built on every call.
    emul = DiffAtmEmulator(obs_str=obs_str, pressure=params["P"])
    return emul.vect1d_Alltransparencies(
        x,
        am=airmass,
        pwv=params["pwv"],
        oz=params["oz"],
        tau=params["tau"],
        beta=params["beta"],
    )

In [None]:
# Independent PRNG keys: rng_key1 draws the measurement noise below and
# rng_key2 is used later to sample parameters (rng_key0 is unused here).
rng_key = jax.random.PRNGKey(42)
rng_key, rng_key0, rng_key1, rng_key2 = jax.random.split(rng_key, 4)

In [None]:
sigma_obs=0.01 # 10 mmag accuracy (absolute std of the noise added to the transmission)

In [None]:
# True parameters of the simulated observation. The insertion order
# (P, pwv, oz, tau, beta) matches the parameter vector used by lik.
airmass = 1.0
par_true=OrderedDict({"P":730.0, "pwv":4.0, "oz":400, "tau": 0.05, "beta": -1.2})

In [None]:
par_true

In [None]:
# Simulated measurement: true model plus Gaussian white noise of std sigma_obs.
TMes = mean_transm(wls,par_true,airmass) + sigma_obs * jax.random.normal(rng_key1,shape=wls.shape)

In [None]:
# Simulated measured transmission with its error bars.
fig,ax = plt.subplots(1,1,figsize=(10,3))
# The label gives legend() an entry (otherwise matplotlib warns "No artists with labels").
ax.errorbar(wls,TMes,yerr=sigma_obs,fmt="o",ms=1,color="k",ecolor="r",lw=1,label="data")
# Raw string: "\l" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
ax.set_xlabel(r"$\lambda$ (nm)")
ax.set_ylabel("transmission")
ax.legend()
ax.grid();

In [None]:
def lik(p,wls,T, sigma_obs=1.0,airmass=1):
    """Gaussian negative log-likelihood (up to a constant) of the transmission data.

    :param p: parameter vector ``[P, pwv, oz, tau, beta]``
    :param wls: wavelengths of the measurement (nm)
    :param T: measured transmission at ``wls``
    :param sigma_obs: noise std of each measurement. With the default 1.0 the
        loss and its Hessian are those of unit noise.
    :param airmass: airmass forwarded to ``mean_transm``
    :return: ``0.5 * sum(((model - T) / sigma_obs) ** 2)``
    """
    params = OrderedDict({"P":p[0], "pwv":p[1], "oz":p[2], "tau":p[3], "beta":p[4]})
    # Evaluate the (expensive) model once; it sits inside every gradient/Hessian.
    resid = mean_transm(wls, params, airmass) - T
    return 0.5 * jnp.sum((resid / sigma_obs) ** 2)

In [None]:
def get_infos(res, model, wls,T,airmass=1, sigma_obs=1.0):
    """Summarise a jaxopt result at its solution.

    :param res: result returned by a jaxopt solver's ``run`` (``res.params`` is read)
    :param model: scalar loss ``model(p, wls, T, sigma_obs=..., airmass=...)``,
        e.g. ``lik``
    :param wls: wavelengths of the measurement
    :param T: measured transmission
    :param airmass: airmass forwarded to ``model``
    :param sigma_obs: noise std forwarded to ``model``; the default 1.0 matches
        ``lik``'s own default
    :return: ``(params, fun_min, jacob_min, inv_hessian_min)``, i.e. the best
        fit, the loss there, its gradient there, and the inverse of its Hessian.
    """
    params = res.params
    # Keyword arguments on purpose: positionally, the 4th parameter of `lik`
    # is sigma_obs, so passing airmass there would change the noise level.
    model_kwargs = {"sigma_obs": sigma_obs, "airmass": airmass}
    fun_min = model(params, wls, T, **model_kwargs)
    jacob_min = jax.jacfwd(model)(params, wls, T, **model_kwargs)
    inv_hessian_min = jax.scipy.linalg.inv(jax.hessian(model)(params, wls, T, **model_kwargs))
    return params, fun_min, jacob_min, inv_hessian_min

## Jaxopt-GradientDescent

In [None]:
# Gradient descent on the loss `lik`; init_params is [P, pwv, oz, tau, beta].
# NOTE(review): sigma_obs is not passed, so lik uses its default sigma_obs=1.0.
# The loss value and the inverse Hessian returned by get_infos are therefore
# those of unit noise: multiply inv_hessian_min by sigma_obs**2 to obtain the
# parameter covariance for noise sigma_obs.
gd = jaxopt.GradientDescent(fun=lik, maxiter=1000)
init_params = jnp.array([730.,4.,300.,.05,-1.])
res = gd.run(init_params,wls=wls, T=TMes)

In [None]:
# Best fit, loss, gradient and inverse Hessian at the solution.
params,fun_min,jacob_min,inv_hessian_min = get_infos(res, lik, wls=wls,T=TMes)
print("params:",params,"\nfun@min:",fun_min,"\njacob@min:",jacob_min,
     "\n invH@min:",inv_hessian_min)

## OptaxSolver - Adam

In [None]:
# Adam (learning rate 0.1) through jaxopt's OptaxSolver, up to 10000 iterations.
# NOTE(review): as above, lik runs with its default sigma_obs=1.0 here.
opt = optax.adam(0.1)
solver = jaxopt.OptaxSolver(opt=opt, fun=lik, maxiter=10000)
init_params = jnp.array([730.,4.,300.,.05,-1.])
res = solver.run(init_params,wls=wls, T=TMes)

In [None]:
params,fun_min,jacob_min,inv_hessian_min = get_infos(res, lik, wls=wls,T=TMes)
print("params:",params,"\nfun@min:",fun_min,"\njacob@min:",jacob_min,
     "\n invH@min:",inv_hessian_min)

## JAXOPT-scipy-Minimize

In [None]:
# Unconstrained BFGS through scipy.optimize, wrapped by jaxopt.
# NOTE(review): lik runs with its default sigma_obs=1.0 here as well.
minimizer = jaxopt.ScipyMinimize(fun=lik,method='BFGS',options={'gtol': 1e-6,'disp': False})
init_params = jnp.array([730.,4.,300.,.05,-1.])
res1 = minimizer.run(init_params, wls=wls, T=TMes)
params,fun_min,jacob_min,inv_hessian_min = get_infos(res1, lik, wls=wls,T=TMes)
print("params:",params,"\nfun@min:",fun_min,"\njacob@min:",jacob_min,
     "\n invH@min:",inv_hessian_min)

## JAXOPT-ScipyBoundedMinimize

In [None]:
# Box-constrained L-BFGS-B. bounds = (lower, upper), each ordered as
# [P, pwv, oz, tau, beta]. Being the last fit, its params / inv_hessian_min are
# the values used by the cells below.
# NOTE(review): lik runs with its default sigma_obs=1.0, so inv_hessian_min is
# the unit-noise inverse Hessian (scale by sigma_obs**2 for a covariance).
lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik, method="L-BFGS-B")
init_params = jnp.array([730.,4.,300.,.05,-1.])
res2 = lbfgsb.run(init_params, 
                  bounds=([700.,0.,0.,0.,-3.0],[800.,10.,550.,0.5,0.]), 
                 wls=wls, T=TMes)
params,fun_min,jacob_min,inv_hessian_min = get_infos(res2, lik, wls=wls, T=TMes)
print("params:",params,"\nfun@min:",fun_min,"\njacob@min:",jacob_min,
     "\n invH@min:",inv_hessian_min)

# loss-landscape

In [None]:
def plot_landscape(ax,model,xdata,ydata, 
                   par_min, idx=(0,1), 
                   bounds=(0.,1.,0.,1.), 
                   model_args=(), model_kwargs=None):
    """Plot a 2-D slice of a loss landscape together with its gradient field.

    The two parameters ``idx`` are scanned over the rectangle
    ``bounds = (xmin, xmax, ymin, ymax)``; all other parameters are held at
    ``par_min``. Filled contours (and a colorbar) of
    ``model(p, xdata, ydata, *model_args, **model_kwargs)`` are drawn on a
    101x101 grid. The ``idx`` components of the gradient are drawn as white
    arrows on a 10x10 grid.

    :param ax: matplotlib Axes to draw on (its figure receives the colorbar)
    :param model: scalar loss ``model(p, xdata, ydata, ...)`` differentiable in ``p``
    :param par_min: full parameter vector defining the slice
    :param idx: indices of the parameters placed on the x and y axes
    :param model_kwargs: extra keyword arguments for ``model`` (default: none)
    """
    if model_kwargs is None:
        model_kwargs = {}
    xmin, xmax, ymin, ymax = bounds
    n_fine, n_coarse = 101, 10

    def loss(p, x, y):
        return model(p, x, y, *model_args, **model_kwargs)

    def slice_points(n):
        # Grid of shape (2, n, n) plus the matching parameter vectors: row
        # j*n + i holds (x_i, y_j) in the idx slots and par_min elsewhere.
        grid = jnp.mgrid[xmin:xmax:n * 1j, ymin:ymax:n * 1j]
        pts = jnp.swapaxes(grid, 0, -1).reshape(-1, 2)
        points = jnp.repeat(par_min[None, :], pts.shape[0], axis=0)
        for i in (0, 1):
            points = points.at[:, idx[i]].set(pts[:, i])
        return grid, points

    (x0, y0), points = slice_points(n_fine)
    v = jit(vmap(loss, in_axes=(0, None, None)))(points, xdata, ydata)
    # Back to the (x, y) "ij" layout of x0 / y0.
    v = jnp.swapaxes(v.reshape(n_fine, n_fine), 0, -1)
    g0 = ax.contourf(x0, y0, v, levels=100)
    ax.contour(x0, y0, v, levels=50, colors='w')

    _, points = slice_points(n_coarse)
    gradients = jit(vmap(grad(loss), in_axes=(0, None, None)))(points, xdata, ydata)

    # Use the largest |component|: the previous int(0.2*max) truncated to 0
    # (invalid quiver scale) for small or all-negative gradients.
    gmax = float(jnp.max(jnp.abs(gradients)))
    scale = 0.2 * gmax if gmax > 0 else None
    ax.quiver(
        points[:, idx[0]],
        points[:, idx[1]],
        gradients[:, idx[0]],
        gradients[:, idx[1]],
        color="white",
        angles='xy',
        scale_units='xy',
        scale=scale,
    )
    ax.set_aspect("auto")
    # Attach the colorbar to the figure owning `ax`, not to a global `fig`.
    ax.figure.colorbar(g0, ax=ax, shrink=0.5)

In [None]:
# P vs tau slice of the loss around the last best fit `params`
# (pwv, oz and beta held at their best-fit values).
fig,ax = plt.subplots(1,1,figsize=(10,10))
plot_landscape(ax,model=lik,
               xdata=wls,
               ydata=TMes,
               par_min=params, 
               idx=(0,3),
               bounds=(600.,900.,0.,.1))
plt.xlabel("P")
plt.ylabel(r"$\tau$")
plt.show()

In [None]:
# pwv vs ozone slice of the loss around `params` (P, tau and beta held fixed).
fig,ax = plt.subplots(1,1,figsize=(10,10))
plot_landscape(ax,model=lik,
               xdata=wls,
               ydata=TMes,
               par_min=params, 
               idx=(1,2),
               bounds=(0.,10.,0.,550.))
plt.xlabel("pwv")
plt.ylabel(r"$oz$")
plt.show()

## Prediction/error bands: how to sample parameters from the Hessian

In [None]:
wls_val = np.linspace(300.,1000,100)

In [None]:
Ttrue_val = mean_transm(wls_val,par_true,airmass)

In [None]:
param_spls = jax.random.multivariate_normal(rng_key2,
                                            mean=params,
                                            cov=inv_hessian_min,
                                            shape=(5000,))

In [None]:
func = jax.vmap(lambda x: mean_transm(wls_val,
                                      OrderedDict({"P":x[0],
                                                   "pwv":x[1],
                                                   "oz":x[2],
                                                   "tau":x[3],
                                                   "beta":x[4]}),airmass))
                                              

In [None]:
Tall_val= func(param_spls)

In [None]:
Tmean_val = jnp.mean(Tall_val,axis=0)
std_T_val = jnp.std(Tall_val,axis=0)

In [None]:
fig=plt.figure(figsize=(10,8))
plt.errorbar(wls,TMes,yerr=sigma_obs,fmt='o', linewidth=2, capsize=0, c='k', label="data")
plt.plot(wls_val,Ttrue_val,c='k',label="true")

plt.fill_between(wls_val, Tmean_val-2*std_T_val, Tmean_val+2*std_T_val, 
                    color="lightblue",label=r"$2-\sigma$")
plt.fill_between(wls_val, Tmean_val-std_T_val, Tmean_val+std_T_val, 
                    color="lightgray",label=r"$1-\sigma$")
# plot mean prediction
plt.plot(wls_val, Tmean_val, "blue", ls="--", lw=2.0, label="mean")


plt.xlabel("$\lambda$")
plt.ylabel("Transmission")
plt.legend()
plt.grid();

In [None]:
par_min = params
par_min

In [None]:
inv_hessian_min

In [None]:
rgn_key, new_key = jax.random.split(rng_key)

In [None]:
samples = jax.random.multivariate_normal(new_key, mean=par_min, cov=inv_hessian_min, shape=(5000,)) 

In [None]:
samples = samples.T

In [None]:
samples.shape

In [None]:
fig,ax=plt.subplots(1,1,figsize=(6,4))
az.plot_posterior({"$P$":samples[0,:]},point_estimate='mean',ax=ax);

# Contours plots : Fisher forecast

$$
\Large
F_{i,j} = \sum_{t:t_{mes}} \frac{1}{\sigma^2} \frac{\partial f(p_{true},t)}{\partial p_i}\frac{\partial f(p_{true},t)}{\partial p_j}
$$
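This generalizes to data with a non-diagonal covariance matrix $C$:
$$
\Large
F_{i,j} = \left(\frac{\partial f(p_{true})}{\partial p_i}\right)^{T} C^{-1} \, \frac{\partial f(p_{true})}{\partial p_j}
$$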

In [None]:
def f(p):
    """Model transmission at the measured wavelengths for a parameter vector.

    Maps ``p = [P, pwv, oz, tau, beta]`` onto the parameter dictionary expected
    by ``mean_transm`` and evaluates it on the notebook-level ``wls`` at the
    notebook-level ``airmass``. It is differentiated below to build the Fisher matrix.
    """
    names = ("P", "pwv", "oz", "tau", "beta")
    par = OrderedDict((name, p[k]) for k, name in enumerate(names))
    return mean_transm(wls, par, airmass)

In [None]:
# True parameter vector; follows par_true's insertion order P, pwv, oz, tau, beta,
# which is also the order f expects.
p_true = np.fromiter(par_true.values(), dtype=float)
p_true

In [None]:
# Jacobian of the model at the truth: one row per wavelength, one column per parameter.
jac = jax.jacfwd(f)(p_true)

In [None]:
jac.shape

In [None]:
# Inverse data covariance: diagonal, 1/sigma_obs**2 (independent, equal noise).
cov_inv = np.zeros((jac.shape[0],jac.shape[0]))
di = np.diag_indices(cov_inv.shape[0])
cov_inv[di]=1./sigma_obs**2

In [None]:
# Fisher matrix F_ab = sum_ij J_ia Cinv_ij J_jb, then symmetrised against round-off.
F = jnp.einsum('ia,ij,jb',jac,cov_inv,jac)
F = 0.5*(F+F.T)

In [None]:
F.shape

In [None]:
from matplotlib.patches import Ellipse

def plot_contours(fisher, pos, inds, nstd=1., ax=None, **kwargs):
    """Plot a 2-D confidence ellipse derived from a Fisher (inverse-covariance) matrix.

    The Fisher matrix is inverted and the 2x2 block of parameters ``inds`` is
    extracted. That block is drawn as an ellipse centred on ``pos[inds]``, whose
    full axes are ``2 * nstd * sqrt(eigenvalues)``. The axis limits are set to
    +/- 1.5 * nstd marginal standard deviations around the centre.

    :param fisher: square Fisher matrix of all parameters
    :param pos: array-like full parameter vector; ``pos[inds]`` is the centre
    :param inds: the two parameter indices ``[ix, iy]`` placed on x and y
    :param nstd: ellipse size in units of standard deviation
    :param ax: Axes to draw on (default: current axes)
    :param kwargs: forwarded to ``matplotlib.patches.Ellipse``
    :return: the Ellipse artist
    """

    def eigsorted(cov):
        # Eigen-decomposition with eigenvalues in decreasing order.
        vals, vecs = np.linalg.eigh(cov)
        order = vals.argsort()[::-1]
        return vals[order], vecs[:, order]

    if ax is None:
        ax = plt.gca()

    cov = np.linalg.inv(fisher)
    # 2x2 block of the (marginalised) covariance for the two requested parameters.
    sub_cov = cov[inds][:, inds]
    sigma_x = np.sqrt(sub_cov[0, 0])
    sigma_y = np.sqrt(sub_cov[1, 1])

    vals, vecs = eigsorted(sub_cov)
    # Orientation of the major axis (first eigenvector), in degrees.
    theta = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))

    # Width and height are "full" widths, not radii.
    width, height = 2 * nstd * np.sqrt(vals)
    center = np.asarray(pos)[inds]
    ellip = Ellipse(xy=center, width=width,
                    height=height, angle=theta, **kwargs)

    ax.add_artist(ellip)
    s1 = 1.5 * nstd * sigma_x
    s2 = 1.5 * nstd * sigma_y
    ax.set_xlim(center[0] - s1, center[0] + s1)
    ax.set_ylim(center[1] - s2, center[1] + s2)
    plt.draw()
    return ellip

In [None]:
# Lower-triangular grid of Fisher-forecast ellipses (nstd=1, centred on the true
# parameters) for every parameter pair, with the true point and the best fit
# `par_min` overlaid.
npar = p_true.shape[0]
pname = ["P","pwv","oz",r"$\tau$",r"$\beta$"]  # axis labels, order P, pwv, oz, tau, beta
plt.figure(figsize=(20, 20))
for i in range(0,npar):
    for j in range(npar):
        if j<i:
            plt.subplot(npar,npar,i*npar+j+1)  # row i, column j
            plt.scatter(p_true[j],p_true[i], label="true")
            plt.scatter(par_min[j],par_min[i], label="mini")
            plt.xlabel(pname[j])
            plt.ylabel(pname[i])            
            plot_contours(F, p_true, [j,i],fill=False,color='C0')
            if j==0 and i==1: plt.legend()  # a single legend, on the first panel

In [None]:
param_spls.shape

In [None]:
# Columns of param_spls follow lik's parameter vector [P, pwv, oz, tau, beta].
data = OrderedDict({"P":param_spls[:,0], "pwv":param_spls[:,1],"oz":param_spls[:,2],"tau":param_spls[:,3],"beta":param_spls[:,4]})

In [None]:
nparams =len(data)
nparams
par_names = ["P","pwv","oz","tau","beta"]  # must match the keys of data / par_true

In [None]:
import arviz.labels as azl

# Display labels used by arviz for each variable.
labeller = azl.MapLabeller(var_name_map=OrderedDict({"P": r"$P$", 
                                         "pwv":r"$H_2O$",
                                         "oz":r"$O_3$",
                                         "tau":r"$\tau$",
                                         "beta":r"$\beta$" 
                                        }))


In [None]:
# KDE pair plot of the sampled parameters.
ax=az.plot_pair(
        data,
        kind="kde",
        labeller=labeller,
        marginal_kwargs={"plot_kwargs": {"lw":3, "c":"blue", "ls":"-"}},
        kde_kwargs={
            "hdi_probs": [0.3, 0.68, 0.9],  # Plot 30%, 68% and 90% HDI contours
            "contour_kwargs":{"colors":None, "cmap":"Blues", "linewidths":3,
                              "linestyles":"-"},
            "contourf_kwargs":{"alpha":0.5},
        },
        point_estimate_kwargs={"lw": 3, "c": "b"},
        marginals=True, textsize=50, point_estimate='median',
    );

# plot true parameter point
# (ax[row, col]: row = y variable, col = x variable; lower triangle only)
for idy in range(nparams):
    for idx in range(idy):
        label_x = par_names[idx]
        label_y = par_names[idy]
        ax[idy,idx].scatter(par_true[label_x],par_true[label_y],c="r",s=150,zorder=10)
        print(idx,idy,label_x,label_y,par_true[label_x] ,par_true[label_y] )  # trace of the overlaid points
        

# Mark the true value on each marginal (diagonal) panel.
for idx,name in enumerate(par_names):
    ax[idx,idx].axvline(par_true[name],c='r',lw=3)
    