# Building a NIFTy forward model for cluster inference with stellar ages

In [1]:
######## IMPORTS ##############
import os
import jax
import matplotlib.pyplot as pl
import nifty8.re as jft

from models import SigmoidField, logStellarAges #, StarsToLuminosity

In [2]:
######### Parameters ###############
### random seed
# A single master key is split into independent sub-keys, one per random draw
# (initial latent position, optimisation, noise realisation, noise std, star
# positions), so every draw is reproducible from `seed` alone.
seed = 54
key = jax.random.PRNGKey(seed)
(
    key_initial,
    key_optimization,
    key_noise,
    key_noise_std,
    key_data_coordinates,
) = jax.random.split(key, 5)


### paths and plotting

# Output directory for this run; it is passed as `odir` to jft.optimize_kl
# below and (presumably) to the plotting helpers once they are re-enabled.
name = "test"
data_path = "runs/" + name + "/"
# exist_ok=True replaces the racy exists()/makedirs() pair and makes
# re-running this cell a no-op when the directory is already there.
os.makedirs(data_path, exist_ok=True)

# Typeset all figure text with LaTeX (requires a working TeX installation).
pl.rcParams.update({
    "text.usetex": True,
    "font.family": "sans-serif",
    "font.sans-serif": "Helvetica",
})


### position space parameters

# 1D grid of n_pix pixels, each `res` wide.
n_pix = 1000
res = 1.  # pixel size ("resolution")

### mock data parameters

n_data = 100  # number of mock data points (stars)

# Per-star noise standard deviations are drawn uniformly from
# [noise_std_lower, noise_std_upper).
noise_std_lower, noise_std_upper = 1.e-1, 3.e-1
vol = noise_std_upper - noise_std_lower  # width of that interval
std = jax.random.uniform(
    key=key_noise_std,
    shape=(n_data,),
    minval=noise_std_lower,
    maxval=noise_std_upper,
)

# Heteroscedastic Gaussian noise: unit-normal draws scaled by each star's std.
noise = std * jax.random.normal(key=key_noise, shape=(n_data,))

### draw position space coordinates of stars
# Star positions are uniform over the whole extent of the grid, [0, n_pix * res).
stellar_coordinates = jax.random.uniform(
    key=key_data_coordinates,
    shape=(n_data,),
    minval=0,
    maxval=n_pix * res,
)

### hyperparameters for the correlation model
# Priors shared by every jft.CorrelatedFieldMaker built below.

field_fluctuations = (3., .3)  # per-pixel fluctuations: (mean, std)
field_offset_mean = -8.  # mean of the offset from 0
field_offset_std = (.8, .8)  # standard deviation of the offset, and the std of that standard deviation

field_slope = (-5., 1.)  # power-law log-log average slope: (mean, std)


### other model kwargs
# Amplitude parameters passed as the first two arguments of the two
# SigmoidField models (names suggest mean and std of the amplitude prior).
c_mean_amp_mean, c_mean_amp_std = 1., 1.
c_std_amp_mean, c_std_amp_std = 1., 1.

### inference parameters

## General
# Forwarded to jft.optimize_kl as n_samples / n_total_iterations.
inference_kwargs = dict(n_samples=20, n_total_iterations=10)

#### Gradient descent parameters
# KL minimisation settings (passed to jft.optimize_kl as kl_kwargs).
kl_minimizer_dict = {
    "minimize_kwargs": {
        "name": "Minimizer",
        "xtol": 1e-15,
        "cg_kwargs": {"name": None},
        "maxiter": 50,
    },
}

### Sampling parameters, used for curvature estimation
# Conjugate-gradient settings for drawing the linear samples
# (passed to jft.optimize_kl as draw_linear_kwargs).
sample_controller_dict = {
    "cg_name": "Samples",
    "cg_kwargs": {"absdelta": 1e-10, "maxiter": 1000},
}

## non linear sampling parameters, used for GeoVI
# Settings for the non-linear sample update
# (passed to jft.optimize_kl as nonlinearly_update_kwargs).
nl_sampling_minimizer_dict = {
    "minimize_kwargs": {
        "name": "non linear samples",
        "xtol": 1e-12,
        "cg_kwargs": {"name": None},
        "maxiter": 50,
    },
}



In [3]:
####  Correlated field models


def _make_correlated_field(prefix):
    """Build and finalize a 1D correlated field using this notebook's shared priors.

    All fields in this cell use identical hyperparameters (offset, fluctuations,
    log-log slope, "power" non-parametric kind) on the same n_pix / res grid.
    Only the parameter-name prefix differs, which keeps their latent parameters
    distinct. Building them in one place stops the copies from drifting apart.

    Args:
        prefix: name prefix for the field's latent parameters.

    Returns:
        The model returned by ``jft.CorrelatedFieldMaker.finalize()``.
    """
    maker = jft.CorrelatedFieldMaker(prefix)
    maker.set_amplitude_total_offset(field_offset_mean, field_offset_std)
    maker.add_fluctuations(
        n_pix,
        distances=res,
        fluctuations=field_fluctuations,
        loglogavgslope=field_slope,
        non_parametric_kind="power",
    )
    return maker.finalize()


# Background mean/std fields and the cluster field c_ex.
bg_mean = _make_correlated_field("bg_mean")
bg_std = _make_correlated_field("bg_std")
c_ex = _make_correlated_field("c_ex")

###

# NOTE(review): c_mean and c_std are both built on the *same* c_ex field, so
# they presumably share one spatial profile and differ only in their amplitude
# parameters. Confirm this coupling is intended.
c_mean = SigmoidField(c_mean_amp_mean, c_mean_amp_std, c_ex, 'c_mean')
c_std = SigmoidField(c_std_amp_mean, c_std_amp_std, c_ex, 'c_std')

####

lSA = logStellarAges(
    bg_mean_field=bg_mean,
    c_mean_field=c_mean,
    bg_std_field=bg_std,
    c_std_field=c_std,
    n_data=n_data,
    coordinates=stellar_coordinates,
)
model = lSA
# TODO: luminosity = StarsToLuminosity(lSA)  (also re-enable the StarsToLuminosity import)


In [4]:
########## draw random data

# Mock ground truth: a random latent position from the model's own init,
# mapped through the forward model. The mock data is that truth plus the
# per-star noise realisation drawn in the parameters cell.
true_latent_position = jft.Vector(model.init(key_initial))
truth = model(true_latent_position)
data = truth + noise

# TODO: re-enable once plot1d and its inputs (grid, data_kernel, true_pdf,
# qdf, ppf, and prior_samples from the next cell) are available here:
# plot1d(data_path, 'prior', grid, data, std, data_kernel, truth, true_pdf, prior_samples, qdf.cf, ppf, qdf)

NameError: name 'luminosity' is not defined

In [None]:
########## likelihood, initial latent space position and inference

# Heteroscedastic Gaussian likelihood with diagonal noise given by the per-star
# `std`. Both noise operators are passed as callables, the form used in the
# NIFTy.re examples. Giving noise_std_inv explicitly supplies what the geoVI
# (non-linear) sample update needs instead of leaving it to be derived.
log_likelihood = jft.Gaussian(
    data,
    noise_cov_inv=lambda x: x / std**2,
    noise_std_inv=lambda x: x / std,
).amend(model)

# key_initial already produced the true latent position (data cell). Since
# log_likelihood.init draws from the same model prior, reusing that key would
# (presumably) start the optimisation exactly at the truth and make the mock
# test meaningless, so derive fresh, independent keys from it instead.
key_start, key_prior = jax.random.split(key_initial)
position = jft.Vector(log_likelihood.init(key_start))

### prior samples, just for plotting
prior_samples_keys = jax.random.split(key_prior, 20)
prior_samples = [jft.Vector(log_likelihood.init(k)) for k in prior_samples_keys]

# Variational inference; results and checkpoints are written to data_path.
sl, state = jft.optimize_kl(
    log_likelihood,
    position,
    **inference_kwargs,
    key=key_optimization,
    kl_kwargs=kl_minimizer_dict,
    draw_linear_kwargs=sample_controller_dict,
    nonlinearly_update_kwargs=nl_sampling_minimizer_dict,
    odir=data_path,
)

# TODO: re-enable once plot1d and its inputs are defined in this notebook:
# plot1d(data_path, 'results', grid, data, std, data_kernel, truth, true_pdf, sl, qdf.cf, ppf, qdf)


