Require PRNG key, update docs
tbmiller-astro committed Feb 6, 2024
1 parent 4f7a8c4 commit dca6629
Showing 11 changed files with 1,756 additions and 1,829 deletions.
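The central API change in this commit is that `sample`, `find_MAP`, `estimate_posterior` and `_train_SVI` now take `rkey` as a required argument instead of falling back to a fixed default such as `jax.random.PRNGKey(3)`. A minimal sketch of how a caller might produce such keys, assuming only that `jax` is installed (the seed 42 and the variable names are illustrative, and no pysersic call is made here):

```python
import jax
import jax.numpy as jnp

# One root key; the seed value 42 is arbitrary.
root_key = jax.random.PRNGKey(42)

# Derive three keys from the root key, e.g. one per fit.
rkeys = jax.random.split(root_key, 3)

# Reusing a key reproduces exactly the same draw ...
draw_a = jax.random.normal(rkeys[0], (4,))
draw_b = jax.random.normal(rkeys[0], (4,))
same_key_repeats = bool(jnp.array_equal(draw_a, draw_b))

# ... while a different key gives a different draw.
draw_c = jax.random.normal(rkeys[1], (4,))
different_keys_differ = not bool(jnp.array_equal(draw_a, draw_c))
```

In the updated `manual-priors.ipynb` notebook, each key produced by `split` is passed on as `fitter.find_MAP(rkey = rkey)`.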
File renamed without changes.
4 changes: 2 additions & 2 deletions docs/source/rendering.rst
Expand Up @@ -24,15 +24,15 @@ An increasingly common method for rendering profiles is to instead render the pr

An immediate problem is that the Sersic profile has a pretty nasty Fourier transform. Instead, following `Hogg & Lang (2013) <https://arxiv.org/abs/1210.6563>`_ we implement a method to model a Sersic profile using a series of Gaussians. Gaussians are much better-behaved numerically, and the Fourier transform is simply a Gaussian. We implement a recent innovation presented in `Shajib (2019) <https://arxiv.org/abs/1906.08263>`_ which presents an analytic method for deriving a mixture of Gaussian approximation to any given profile. We base our implementation off of the implementation of this algorithm in ``lenstronomy``.

However, we noticed that this process leads to some to several numerical instabilities, specifically when tracing the gradient through the Shajib et al. formalism. This causes issues with convergence during inference. This may be due to ``jax``'s default 32 bit implementation. Adjusting the precision value or enabling the 64 bit implementation appears to help alleviate these issues, but these solutions have not been fully vetted. We noticed, however, that the amplitudes of each Gaussian component do vary smoothly with Sersic index, and can thus be easily re-scaled to any flux and radius. Now, ``pySersic`` pre-calculates a set of amplitudes at a grid of Sersic indices and then fits a polynomial as a function of log(n). Empirically this leads to a better fit than as a function of n. These polynomials are then used to decompose a profile for a given n during inference ensuring smooth gradients at only a minor cost in accuracy (less than roughly 1\%). This method is enabled by default but can easily be turned off, going back to the direct algorithm, using the argument ``use_poly_fit_amps``.
However, we noticed that this process leads to several numerical instabilities, specifically when tracing the gradient through the Shajib et al. formalism. This causes issues with convergence during inference. This may be due to ``jax``'s default 32 bit implementation. Adjusting the precision value or enabling the 64 bit implementation appears to help alleviate these issues, but these solutions have not been fully vetted. We noticed, however, that the amplitudes of each Gaussian component do vary smoothly with Sersic index, and can thus be easily re-scaled to any flux and radius. Now, ``pySersic`` pre-calculates a set of amplitudes at a grid of Sersic indices and then interpolates between these values when rendering sources. This interpolation is then used to decompose a profile for a given n during inference, ensuring smooth gradients. This method is enabled by default but can easily be turned off, going back to the direct algorithm, using the argument ``use_interp_amps``.
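
The grid-then-interpolate idea described above can be sketched in plain Python. This is only an illustration: ``toy_amplitude`` is a hypothetical stand-in for the Shajib (2019) decomposition, and linear interpolation replaces the cubic interpolation that ``get_amps_sigmas`` uses in ``pysersic/rendering.py``. The 50-point grid on [0.65, 8] follows the diff.

```python
import bisect
import math


def toy_amplitude(n: float) -> float:
    # Hypothetical smooth function of Sersic index n, standing in for
    # the amplitude of one Gaussian component.
    return math.exp(-0.5 * n) + 0.1 * n


# Precompute once on a grid of Sersic indices.
N_GRID = [0.65 + i * (8.0 - 0.65) / 49 for i in range(50)]
AMP_GRID = [toy_amplitude(n) for n in N_GRID]


def interp_amplitude(n: float) -> float:
    """Linearly interpolate the precomputed amplitudes at Sersic index n."""
    if not N_GRID[0] <= n <= N_GRID[-1]:
        raise ValueError("n is outside the precomputed grid")
    # Index of the right-hand grid point bracketing n.
    hi = min(max(bisect.bisect_left(N_GRID, n), 1), len(N_GRID) - 1)
    lo = hi - 1
    t = (n - N_GRID[lo]) / (N_GRID[hi] - N_GRID[lo])
    return (1.0 - t) * AMP_GRID[lo] + t * AMP_GRID[hi]
```

Because the lookup only blends stored values, its derivative with respect to n never passes through the numerically delicate direct calculation, which is the motivation given in the text.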

Both ``FourierRenderer`` and ``HybridRenderer`` use this Gaussian decomposition and render (at least some) of these components in Fourier space. Both have additional optional arguments:

* ``frac_start`` - fraction of the effective radius for the width of the smallest Gaussian component
* ``frac_end`` - fraction of the effective radius for the width of the largest Gaussian component
* ``n_sigma`` - number of Gaussian components to use
* ``precision`` - Precision value to use in the decomposition described in `Shajib (2019) <https://arxiv.org/abs/1906.08263>`_
* ``use_poly_fit_amps`` - Whether to use a polynomial fit to the amplitudes as a function of n, see above for details.
* ``use_interp_amps`` - Whether to use a computed grid to interpolate Gaussian amplitudes as a function of n, see above for details.
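
As a rough illustration of how the first three arguments interact, the sketch below spaces ``n_sigma`` widths logarithmically between ``frac_start * r_eff`` and ``frac_end * r_eff``, mirroring the ``jnp.logspace`` call in ``get_amps_sigmas``. The helper name ``gaussian_widths`` and the values ``r_eff = 5.0`` and ``frac_start = 0.01`` are assumptions chosen for the example; only ``frac_end = 15`` and ``n_sigma = 15`` are defaults visible in this diff.

```python
import math


def gaussian_widths(r_eff: float, frac_start: float, frac_end: float, n_sigma: int) -> list:
    """Log-spaced Gaussian widths from frac_start * r_eff to frac_end * r_eff."""
    log_lo = math.log10(r_eff * frac_start)
    log_hi = math.log10(r_eff * frac_end)
    step = (log_hi - log_lo) / (n_sigma - 1)
    return [10.0 ** (log_lo + i * step) for i in range(n_sigma)]


# Illustrative values only; r_eff and frac_start are assumed, not defaults from the diff.
widths = gaussian_widths(r_eff=5.0, frac_start=0.01, frac_end=15.0, n_sigma=15)
```

Each width is a fixed multiple of the previous one, so small radii are sampled as densely (in log space) as large ones.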

``FourierRenderer``, as the name implies, renders sources solely in Fourier space. However this can lead to some artifacts, specifically, aliasing if the source is near the edge. This is because the inverse FFT assumes the image is periodic so part of the source that should lie outside the image appears opposite. To help combat this we also implement a version of the hybrid real-Fourier algorithm described in `Lang (2020) <https://arxiv.org/abs/2012.15797>`_ in ``HybridRenderer``. The innovation is to render some of the largest Gaussian components in real space to help avoid the aliasing while maintaining the benefits of rendering in Fourier space. This has one additional argument beyond those described above:

Expand Down
759 changes: 351 additions & 408 deletions examples/example-fit.ipynb

Large diffs are not rendered by default.

Binary file modified examples/example_fit.asdf
Binary file not shown.
35 changes: 18 additions & 17 deletions examples/manual-priors.ipynb
Expand Up @@ -340,7 +340,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand All @@ -349,7 +349,7 @@
"True"
]
},
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -376,7 +376,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand All @@ -400,9 +400,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 2%|▏ | 390/20000 [00:00<00:29, 661.06it/s, Round = 0,step_size = 5.0e-02 loss: -1.537e+04]\n",
" 1%|▏ | 253/20000 [00:00<00:29, 675.30it/s, Round = 1,step_size = 5.0e-03 loss: -1.537e+04]\n",
" 1%|▏ | 254/20000 [00:00<00:30, 657.78it/s, Round = 2,step_size = 5.0e-04 loss: -1.537e+04]\n"
" 2%|▏ | 446/20000 [00:00<00:41, 469.47it/s, Round = 0,step_size = 5.0e-02 loss: -1.537e+04]\n",
" 3%|▎ | 529/20000 [00:01<00:40, 475.77it/s, Round = 1,step_size = 5.0e-03 loss: -1.537e+04]\n",
" 1%|▏ | 280/20000 [00:00<00:44, 440.61it/s, Round = 2,step_size = 5.0e-04 loss: -1.537e+04]\n"
]
},
{
Expand All @@ -411,7 +411,7 @@
"text": [
"\n",
"---- \n",
"MAP parameters - flux = 3.757e+03, r_eff = 6.257, n = 3.641\n",
"MAP parameters - flux = 3.757e+03, r_eff = 6.257, n = 3.642\n",
"---- \n",
"\n",
"Prior for a sersic source:\n",
Expand All @@ -431,9 +431,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 2%|▏ | 409/20000 [00:00<00:29, 662.81it/s, Round = 0,step_size = 5.0e-02 loss: -1.537e+04]\n",
" 1%|▏ | 286/20000 [00:00<00:29, 668.22it/s, Round = 1,step_size = 5.0e-03 loss: -1.537e+04]\n",
" 2%|▏ | 304/20000 [00:00<00:30, 653.99it/s, Round = 2,step_size = 5.0e-04 loss: -1.537e+04]\n"
" 2%|▏ | 422/20000 [00:00<00:42, 464.19it/s, Round = 0,step_size = 5.0e-02 loss: -1.537e+04]\n",
" 2%|▏ | 432/20000 [00:00<00:41, 468.94it/s, Round = 1,step_size = 5.0e-03 loss: -1.537e+04]\n",
" 1%|▏ | 252/20000 [00:00<00:41, 478.65it/s, Round = 2,step_size = 5.0e-04 loss: -1.537e+04]\n"
]
},
{
Expand Down Expand Up @@ -463,9 +463,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 3%|▎ | 528/20000 [00:00<00:29, 658.28it/s, Round = 0,step_size = 5.0e-02 loss: -1.544e+04]\n",
" 1%|▏ | 257/20000 [00:00<00:30, 642.45it/s, Round = 1,step_size = 5.0e-03 loss: -1.544e+04]\n",
" 1%|▏ | 251/20000 [00:00<00:29, 661.67it/s, Round = 2,step_size = 5.0e-04 loss: -1.544e+04]"
" 2%|▏ | 406/20000 [00:00<00:41, 473.05it/s, Round = 0,step_size = 5.0e-02 loss: -1.544e+04]\n",
" 2%|▏ | 487/20000 [00:01<00:41, 471.25it/s, Round = 1,step_size = 5.0e-03 loss: -1.544e+04]\n",
" 1%|▏ | 291/20000 [00:00<00:43, 456.33it/s, Round = 2,step_size = 5.0e-04 loss: -1.544e+04]"
]
},
{
Expand All @@ -490,14 +490,15 @@
"source": [
"from pysersic import FitSingle\n",
"from pysersic.loss import student_t_loss\n",
"from jax.random import PRNGKey, split\n",
"rkeys = split(PRNGKey(42),3) # Generate 3 keys\n",
"\n",
"\n",
"for prior in [prior_default, prior_w_flux, custom_prior]:\n",
"for rkey,prior in zip(rkeys,[prior_default, prior_w_flux, custom_prior]):\n",
" fitter = FitSingle(data=im,rms=sig,mask=mask,psf=psf,prior=prior,loss_func=student_t_loss)\n",
" \n",
" print(prior)\n",
"\n",
" res = fitter.find_MAP()\n",
" res = fitter.find_MAP(rkey = rkey)\n",
" print (\"\\n---- \")\n",
" print (f\"MAP parameters - flux = {res['flux']:.3e}, r_eff = {res['r_eff']:.3f}, n = {res['n']:.3f}\" )\n",
" print (\"---- \\n\")\n"
Expand All @@ -521,7 +522,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"metadata": {},
"outputs": [
{
Expand Down
2,727 changes: 1,359 additions & 1,368 deletions examples/multi-source-fitting.ipynb

Large diffs are not rendered by default.

12 changes: 1 addition & 11 deletions examples/multiband-example.ipynb
Expand Up @@ -9,16 +9,6 @@
"In this example we will showcase the ability of pysersic to jointly model multiple bands while 'linking' morphological parameters across wavelength using a smooth function. For this we will use an example galaxy from the DeCALS survey observed in g,r and z bands."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
Expand Down Expand Up @@ -319,7 +309,7 @@
" wv_to_save= wv_to_save,\n",
" poly_order = 2)\n",
"rkey,_ = jax.random.split(rkey,2)\n",
"multires = MultiFitter.estimate_posterior('svi-flow', rkey = rkey)"
"multires = MultiFitter.estimate_posterior(method = 'svi-flow', rkey = rkey)"
]
},
{
Expand Down
20 changes: 11 additions & 9 deletions pysersic/pysersic.py
Expand Up @@ -19,11 +19,13 @@
from .priors import PySersicMultiPrior, PySersicSourcePrior, base_profile_params
from .rendering import BaseRenderer, HybridRenderer, base_profile_types
from .results import PySersicResults
from jax.typing import ArrayLike

from .loss import gaussian_loss

ArrayLike = Union[np.array, jax.numpy.array]

def identity(x : Callable)-> Callable:
return x

class BaseFitter(ABC):
"""
Expand Down Expand Up @@ -96,15 +98,15 @@ def set_prior(self,parameter: str,


def sample(self,
rkey : jax.random.PRNGKey,
num_samples: int = 1000,
num_warmup: int = 1000,
num_chains: int = 2,
init_strategy: Optional[Callable] = infer.init_to_sample,
sampler_kwargs: Optional[dict] ={},
mcmc_kwargs: Optional[dict] = {},
return_model: Optional[bool] = True,
reparam_func: Optional[Callable] = lambda x: x,
rkey: Optional[jax.random.PRNGKey] = jax.random.PRNGKey(3)
reparam_func: Optional[Callable] = identity,
) -> pandas.DataFrame:
""" Perform inference using a NUTS sampler
Expand Down Expand Up @@ -147,14 +149,14 @@ def sample(self,
def _train_SVI(self,
autoguide: numpyro.infer.autoguide.AutoContinuous,
method:str,
ELBO_loss: Optional[Callable] = infer.Trace_ELBO(1),
rkey: jax.random.PRNGKey,
ELBO_loss: Optional[Callable] = infer.Trace_ELBO(5),
lr_init: Optional[float] = 1e-2,
num_round: Optional[int] = 3,
SVI_kwargs: Optional[dict]= {},
train_kwargs: Optional[dict] = {},
return_model: Optional[bool] = True,
num_sample: Optional[int] = 1_000,
rkey: Optional[jax.random.PRNGKey] = jax.random.PRNGKey(6),
)-> pandas.DataFrame:
"""
Internal function to perform inference using stochastic variational inference.
Expand Down Expand Up @@ -204,8 +206,8 @@ def _train_SVI(self,


def find_MAP(self,
rkey: jax.random.PRNGKey,
return_model: Optional[bool] = True,
rkey: Optional[jax.random.PRNGKey] = jax.random.PRNGKey(3),
purge_extra: Optional[bool] = True):
"""Find the "best-fit" parameters as the maximum a-posteriori and return a dictionary with values for the parameters.
Expand Down Expand Up @@ -250,10 +252,10 @@ def find_MAP(self,
return real_out

def estimate_posterior(self,
rkey : jax.random.PRNGKey,
method : str='laplace',
return_model: bool = True,
num_sample: Optional[int] = 1_000,
rkey: Optional[jax.random.PRNGKey] = jax.random.PRNGKey(6),
num_sample: Optional[int] = 1_000
) -> pandas.DataFrame:
"""Estimate the posterior using a method other than MCMC sampling. Generally faster than MCMC, but could be less accurate.
Current Options are:
Expand Down Expand Up @@ -284,7 +286,7 @@ def estimate_posterior(self,
results = self._train_SVI(guide_func,method=method, lr_init = 0.05, train_kwargs=train_kwargs, return_model = return_model, rkey=rkey, num_sample=num_sample)
elif method=='svi-flow':
train_kwargs = dict(patience = 500, max_train = 20000)
guide_func = partial(infer.autoguide.AutoBNAFNormal, num_flows =3,hidden_factors = [5,5], init_loc_fn = infer.init_to_median)
guide_func = partial(infer.autoguide.AutoBNAFNormal, num_flows =1,hidden_factors = [16,8], init_loc_fn = infer.init_to_median)
results = self._train_SVI(guide_func,method='svi-flow',ELBO_loss= infer.Trace_ELBO(8),train_kwargs=train_kwargs,num_round=3,lr_init = 5e-2, rkey=rkey,return_model = return_model,num_sample=num_sample)
elif method=='svi-mvn':
train_kwargs = dict(patience = 200, max_train = 5000)
Expand Down
22 changes: 11 additions & 11 deletions pysersic/rendering.py
Expand Up @@ -347,7 +347,7 @@ class FourierRenderer(BaseRenderer):
frac_end : float = eqx.field(static=True)
n_sigma : int = eqx.field(static=True)
precision : int = eqx.field(static=True)
use_poly_fit_amps : bool = eqx.field(static=True)
use_interp_amps : bool = eqx.field(static=True)

etas : jax.numpy.array
betas : jax.numpy.array
Expand All @@ -362,7 +362,7 @@ def __init__(self,
frac_end: Optional[float] = 15.,
n_sigma: Optional[int] = 15,
precision: Optional[int] = 10,
use_poly_fit_amps: Optional[bool] = True)-> None:
use_interp_amps: Optional[bool] = True)-> None:
"""Initialize a Fourier renderer class
Parameters
Expand All @@ -379,25 +379,25 @@ def __init__(self,
Number of Gaussian Components, by default 15
precision : Optional[int], optional
precision value used in calculating Gaussian components, see Shajib (2019) for more details, by default 10
use_poly_fit_amps: Optional[bool]
If True, instead of performing the direct calculation in Shajib (2019) at each iteration, a polynomial approximation is fit and used. The amplitudes of each gaussian component amplitudes as a function of Sersic index are fit with a polynomial. This smooth approximation is then used at each interval. While this adds a a little extra to the renderering error budget (roughly 1\%) but is much more numerically stable owing to the smooth gradients. If this matters for you then set this to False and make sure to enable jax's 64 bit capabilities which we find helps the stability.
use_interp_amps: Optional[bool]
If True, instead of performing the direct calculation in Shajib (2019) at each iteration, the amplitudes of each Gaussian component are interpolated as a function of Sersic index from a precomputed grid. This is much more numerically stable owing to the smooth gradients. If the exact calculation matters for you, set this to False and make sure to enable jax's 64 bit capabilities, which we find helps the stability.
"""
super().__init__(im_shape, pixel_PSF)
self.frac_start = frac_start
self.frac_end = frac_end
self.n_sigma = n_sigma
self.precision = precision
self.use_poly_fit_amps = use_poly_fit_amps
self.use_interp_amps = use_interp_amps
self.etas, self.betas = calculate_etas_betas(self.precision)
if not use_poly_fit_amps and not jax.config.x64_enabled:
if not use_interp_amps and not jax.config.x64_enabled:
warnings.warn("!! WARNING !! - Gaussian decomposition can be numerically unstable when using jax's default 32 bit. Please either enable jax 64 bit or set 'use_interp_amps' = True in the renderer kwargs")

#Precompute amplitudes on a grid of Sersic indices for smooth interpolation
self.n_ax = jnp.linspace(.65,8., num = 50)
self.amps_n_ax = jax.vmap( lambda n: sersic_gauss_decomp(1.,1.,n,self.etas,self.betas,self.frac_start,self.frac_end,self.n_sigma)[0] ) (self.n_ax)

def get_amps_sigmas(self, flux, r_eff,n):
if self.use_poly_fit_amps:
if self.use_interp_amps:
amps_norm = interp1d(n,self.n_ax, self.amps_n_ax, method='cubic2')
amps = amps_norm*flux
sigmas = jnp.logspace(jnp.log10(r_eff*self.frac_start),jnp.log10(r_eff*self.frac_end),num = self.n_sigma)
Expand Down Expand Up @@ -481,7 +481,7 @@ def __init__(self,
n_sigma: Optional[int] = 15,
num_pixel_render: Optional[int] = 3,
precision: Optional[int] = 10,
use_poly_fit_amps: Optional[bool] = True)-> None:
use_interp_amps: Optional[bool] = True)-> None:
"""Initialize a HybridRenderer class
Parameters
Expand All @@ -500,10 +500,10 @@ def __init__(self,
Number of components to render in pixel space, counts back from largest component
precision : Optional[int], optional
precision value used in calculating Gaussian components, see Shajib (2019) for more details, by default 10
use_poly_fit_amps: Optional[bool]
If True, instead of performing the direct calculation in Shajib (2019) at each iteration, a polynomial approximation is fit and used. The amplitudes of each gaussian component amplitudes as a function of Sersic index are fit with a polynomial. This smooth approximation is then used at each interval. While this adds a a little extra to the renderering error budget (roughly 1\%) but is much more numerically stable owing to the smooth gradients. If this matters for you then set this to False and make sure to enable jax's 64 bit capabilities which we find helps the stability.
use_interp_amps: Optional[bool]
If True, instead of performing the direct calculation in Shajib (2019) at each iteration, the amplitudes of each Gaussian component are interpolated as a function of Sersic index from a precomputed grid. This is much more numerically stable owing to the smooth gradients. If the exact calculation matters for you, set this to False and make sure to enable jax's 64 bit capabilities, which we find helps the stability.
"""
super().__init__(im_shape, pixel_PSF, frac_start,frac_end,n_sigma,precision, use_poly_fit_amps)
super().__init__(im_shape, pixel_PSF, frac_start,frac_end,n_sigma,precision, use_interp_amps)

self.num_pixel_render = num_pixel_render
self.w_real = jnp.arange(self.n_sigma - self.num_pixel_render, self.n_sigma, dtype=jnp.int32)
Expand Down
2 changes: 1 addition & 1 deletion pysersic/results.py
Expand Up @@ -337,7 +337,7 @@ def save_result(self,fname:str):
tree['input_data']['psf'] = np.array(self.psf)
tree['input_data']['mask'] = np.array(self.mask)
tree['loss_func'] = str(self.loss_func)
tree['renderer'] = str(self.renderer)
tree['rendere_type'] = str( type(self.renderer) )
tree['method_used'] = self.runtype
if self.runtype == 'svi':
tree['svi_method_used'] = self.svi_method_used
Expand Down