2,034 changes: 1,609 additions & 425 deletions notebooks/deterministic_advi_example.ipynb


234 changes: 162 additions & 72 deletions pymc_extras/inference/dadvi/dadvi.py
@@ -5,88 +5,97 @@
import pytensor.tensor as pt
import xarray

from better_optimize import minimize
from better_optimize import basinhopping, minimize
from better_optimize.constants import minimize_method
from pymc import DictToArrayBijection, Model, join_nonshared_inputs
from pymc.backends.arviz import (
PointFunc,
apply_function_over_dataset,
coords_and_dims_for_inferencedata,
)
from pymc.blocking import RaveledVars
from pymc.util import RandomSeed, get_default_varnames
from pytensor.tensor.variable import TensorVariable

from pymc_extras.inference.laplace_approx.idata import (
add_data_to_inference_data,
add_optimizer_result_to_inference_data,
)
from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
from pymc_extras.inference.laplace_approx.scipy_interface import (
_compile_functions_for_scipy_optimize,
scipy_optimize_funcs_from_loss,
set_optimizer_function_defaults,
)


def fit_dadvi(
model: Model | None = None,
n_fixed_draws: int = 30,
random_seed: RandomSeed = None,
n_draws: int = 1000,
keep_untransformed: bool = False,
include_transformed: bool = False,
optimizer_method: minimize_method = "trust-ncg",
use_grad: bool = True,
use_hessp: bool = True,
use_hess: bool = False,
**minimize_kwargs,
use_grad: bool | None = None,
use_hessp: bool | None = None,
use_hess: bool | None = None,
gradient_backend: str = "pytensor",
compile_kwargs: dict | None = None,
random_seed: RandomSeed = None,
progressbar: bool = True,
**optimizer_kwargs,
) -> az.InferenceData:
"""
Does inference using deterministic ADVI (automatic differentiation
variational inference), DADVI for short.
Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.

For full details see the paper cited in the references:
https://www.jmlr.org/papers/v25/23-1015.html
For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html

Parameters
----------
model : pm.Model
The PyMC model to be fit. If None, the current model context is used.

n_fixed_draws : int
The number of fixed draws to use for the optimisation. More
draws will result in more accurate estimates, but also
increase inference time. Usually, the default of 30 is a good
tradeoff between speed and accuracy.
The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.

random_seed: int
The random seed to use for the fixed draws. Running the optimisation
twice with the same seed should arrive at the same result.
The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
the same result.

n_draws: int
The number of draws to return from the variational approximation.

keep_untransformed: bool
Whether or not to keep the unconstrained variables (such as
logs of positive-constrained parameters) in the output.
include_transformed: bool
Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
output.

optimizer_method: str
Which optimization method to use. The function calls
``scipy.optimize.minimize``, so any of the methods there can
be used. The default is trust-ncg, which uses second-order
information and is generally very reliable. Other methods such
as L-BFGS-B might be faster but potentially more brittle and
may not converge exactly to the optimum.

minimize_kwargs:
Additional keyword arguments to pass to the
``scipy.optimize.minimize`` function. See the documentation of
that function for details.
Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
the optimum.

use_grad:
If True, pass the gradient function to
`scipy.optimize.minimize` (where it is referred to as `jac`).
gradient_backend: str
Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".

use_hessp:
compile_kwargs: dict, optional
Additional keyword arguments to pass to `pytensor.function`

use_grad: bool, optional
If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).

use_hessp: bool, optional
If True, pass the hessian vector product to `scipy.optimize.minimize`.

use_hess:
If True, pass the hessian to `scipy.optimize.minimize`. Note that
this is generally not recommended since its computation can be slow
and memory-intensive if there are many parameters.
use_hess: bool, optional
If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
computation can be slow and memory-intensive if there are many parameters.

progressbar: bool
Whether or not to show a progress bar during optimization. Default is True.

optimizer_kwargs:
Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
that function for details.

Returns
-------
@@ -95,16 +104,25 @@ def fit_dadvi(

References
----------
Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
Variational Inference with a Deterministic Objective: Faster, More
Accurate, and Even More Black Box. Journal of Machine Learning
Research, 25(18), 1–39.
Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
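
Examples
--------
A minimal usage sketch; the toy model, data, and import path below are illustrative
assumptions rather than part of this module:

.. code-block:: python

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.dadvi.dadvi import fit_dadvi

    rng = np.random.default_rng(0)
    observed = rng.normal(loc=1.0, scale=2.0, size=200)

    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 10.0)
        sigma = pm.HalfNormal("sigma", 5.0)
        pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

        # Default second-order optimizer (trust-ncg)
        idata = fit_dadvi(random_seed=1234)

    # Global optimisation via basinhopping, with the inner L-BFGS-B optimizer
    # supplied through ``minimizer_kwargs``
    with model:
        idata_bh = fit_dadvi(
            optimizer_method="basinhopping",
            minimizer_kwargs={"method": "L-BFGS-B"},
            random_seed=1234,
        )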
"""

model = pymc.modelcontext(model) if model is None else model
do_basinhopping = optimizer_method == "basinhopping"
minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})

if do_basinhopping:
# For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
# another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
# if one isn't provided.

optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
minimizer_kwargs["method"] = optimizer_method

initial_point_dict = model.initial_point()
n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]
initial_point = DictToArrayBijection.map(initial_point_dict)
n_params = initial_point.data.shape[0]
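# The mean-field (diagonal Gaussian) approximation has 2 * n_params variational parameters:
# one mean and one log standard deviation per unconstrained model parameter.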

var_params, objective = create_dadvi_graph(
model,
@@ -113,31 +131,65 @@
n_params=n_params,
)

f_fused, f_hessp = _compile_functions_for_scipy_optimize(
objective,
[var_params],
compute_grad=use_grad,
compute_hessp=use_hessp,
compute_hess=use_hess,
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
optimizer_method, use_grad, use_hess, use_hessp
)

derivative_kwargs = {}

if use_grad:
derivative_kwargs["jac"] = True
if use_hessp:
derivative_kwargs["hessp"] = f_hessp
if use_hess:
derivative_kwargs["hess"] = True
f_fused, f_hessp = scipy_optimize_funcs_from_loss(
loss=objective,
inputs=[var_params],
initial_point_dict=None,
use_grad=use_grad,
use_hessp=use_hessp,
use_hess=use_hess,
gradient_backend=gradient_backend,
compile_kwargs=compile_kwargs,
inputs_are_flat=True,
)

result = minimize(
f_fused,
np.zeros(2 * n_params),
method=optimizer_method,
**derivative_kwargs,
**minimize_kwargs,
dadvi_initial_point = {
f"{var_name}_mu": np.zeros_like(value).ravel()
for var_name, value in initial_point_dict.items()
}
dadvi_initial_point.update(
{
f"{var_name}_sigma__log": np.zeros_like(value).ravel()
for var_name, value in initial_point_dict.items()
}
)

dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
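# The packed vector stores all variational means first, followed by all log standard deviations,
# which is the layout assumed when the optimum is split in two further below.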
args = optimizer_kwargs.pop("args", ())

if do_basinhopping:
if "args" not in minimizer_kwargs:
minimizer_kwargs["args"] = args
if "hessp" not in minimizer_kwargs:
minimizer_kwargs["hessp"] = f_hessp
if "method" not in minimizer_kwargs:
minimizer_kwargs["method"] = optimizer_method

result = basinhopping(
func=f_fused,
x0=dadvi_initial_point.data,
progressbar=progressbar,
minimizer_kwargs=minimizer_kwargs,
**optimizer_kwargs,
)

else:
result = minimize(
f=f_fused,
x0=dadvi_initial_point.data,
args=args,
method=optimizer_method,
hessp=f_hessp,
progressbar=progressbar,
**optimizer_kwargs,
)

raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)

opt_var_params = result.x
opt_means, opt_log_sds = np.split(opt_var_params, 2)

@@ -148,9 +200,29 @@ def fit_dadvi(
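# Reparameterisation of the mean-field Gaussian: each draw is mean + exp(log_sd) * z, where z
# are the raw (standard normal) draws.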
draws = opt_means + draws_raw * np.exp(opt_log_sds)
draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)

transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
idata = dadvi_result_to_idata(
draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
)

return transformed_draws
var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
var_name_to_model_var.update(
{f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
)

idata = add_optimizer_result_to_inference_data(
idata=idata,
result=result,
method=optimizer_method,
mu=raveled_optimized,
model=model,
var_name_to_model_var=var_name_to_model_var,
)

idata = add_data_to_inference_data(
idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
)

return idata


def create_dadvi_graph(
@@ -213,10 +285,11 @@ def create_dadvi_graph(
return var_params, objective


def transform_draws(
def dadvi_result_to_idata(
unstacked_draws: xarray.Dataset,
model: Model,
keep_untransformed: bool = False,
include_transformed: bool = False,
progressbar: bool = True,
):
"""
Transforms the unconstrained draws back into the constrained space and packages them into an ArviZ InferenceData object.
@@ -232,9 +305,12 @@ def transform_draws(
n_draws: int
The number of draws to return from the variational approximation.

keep_untransformed: bool
include_transformed: bool
Whether or not to keep the unconstrained variables in the output.

progressbar: bool
Whether or not to show a progress bar during the transformation. Default is True.

Returns
-------
:class:`~arviz.InferenceData`
@@ -243,7 +319,7 @@

filtered_var_names = model.unobserved_value_vars
vars_to_sample = list(
get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
get_default_varnames(filtered_var_names, include_transformed=include_transformed)
)
fn = pytensor.function(model.value_vars, vars_to_sample)
point_func = PointFunc(fn)
@@ -256,6 +332,20 @@
output_var_names=[x.name for x in vars_to_sample],
coords=coords,
dims=dims,
progressbar=progressbar,
)

return transformed_result
constrained_names = [
x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
]
all_varnames = [
x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
]
unconstrained_names = sorted(set(all_varnames) - set(constrained_names))

idata = az.InferenceData(posterior=transformed_result[constrained_names])

if unconstrained_names and include_transformed:
idata["unconstrained_posterior"] = transformed_result[unconstrained_names]

return idata
32 changes: 2 additions & 30 deletions pymc_extras/inference/laplace_approx/find_map.py
@@ -7,7 +7,7 @@
import pymc as pm

from better_optimize import basinhopping, minimize
from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
from better_optimize.constants import minimize_method
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.initial_point import make_initial_point_fn
from pymc.model.transform.optimization import freeze_dims_and_data
@@ -24,40 +24,12 @@
from pymc_extras.inference.laplace_approx.scipy_interface import (
GradientBackend,
scipy_optimize_funcs_from_loss,
set_optimizer_function_defaults,
)

_log = logging.getLogger(__name__)


def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
method_info = MINIMIZE_MODE_KWARGS[method].copy()

if use_hess and use_hessp:
_log.warning(
'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
'same time. When possible "use_hessp" is preferred because it is computationally more efficient. '
'Setting "use_hess" to False.'
)
use_hess = False

use_grad = use_grad if use_grad is not None else method_info["uses_grad"]

if use_hessp is not None and use_hess is None:
use_hess = not use_hessp

elif use_hess is not None and use_hessp is None:
use_hessp = not use_hess

elif use_hessp is None and use_hess is None:
use_hessp = method_info["uses_hessp"]
use_hess = method_info["uses_hess"]
if use_hessp and use_hess:
# If a method could use either hess or hessp, we default to using hessp
use_hess = False

return use_grad, use_hess, use_hessp


def get_nearest_psd(A: np.ndarray) -> np.ndarray:
"""
Compute the nearest positive semi-definite matrix to a given matrix.