Skip to content

Commit

Permalink
Allow registering of custom exception handlers for potential_fn compu…
Browse files Browse the repository at this point in the history
…tations (#3168)

* [RFC] Allow registering excpetion handlers for potential function computations

In some cases evaluation the potential funciton may result in numerical issues. Currently the code hard-codes the handling of a RuntimeError raised when matrices are (numerically) singular. This PR adds the ability to register custom exception handlers. This allows other code depending on pyro to register custom exception handlers without having to modify core pyro code.

There are some other places in which `potential_fn` is called that could benefit from being guarded by these handlers (one is `HMC._find_reasonable_step_size`). I'm not sure what the right thing to do there is when encountering numerical isssues, but happy to add this in as needed.

* Fix typing lint, typos.

* Warn instead of raise, fix typing import.

* Make isort happy (hopefully)

* Check for instance rather than type equality in _handle_torch_singular

* Handle numerical issues also in HMC._find_reasonable_step_size

* isort once more

* Fix black format
  • Loading branch information
Balandat committed Jan 3, 2023
1 parent e082c09 commit 1ec2c39
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 6 deletions.
13 changes: 11 additions & 2 deletions pyro/infer/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pyro.infer.mcmc.adaptation import WarmupAdapter
from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
from pyro.infer.mcmc.util import initialize_model
from pyro.ops.integrator import potential_grad, velocity_verlet
from pyro.ops.integrator import _EXCEPTION_HANDLERS, potential_grad, velocity_verlet
from pyro.util import optional, torch_isnan


Expand Down Expand Up @@ -173,7 +173,16 @@ def _find_reasonable_step_size(self, z):
# We are going to find a step_size which make accept_prob (Metropolis correction)
# near the target_accept_prob. If accept_prob:=exp(-delta_energy) is small,
# then we have to decrease step_size; otherwise, increase step_size.
potential_energy = self.potential_fn(z)
try:
potential_energy = self.potential_fn(z)
# handle exceptions as defined in the exception registry
except Exception as e:
if any(h(e) for h in _EXCEPTION_HANDLERS.values()):
# skip finding reasonable step size
return step_size
else:
raise e

r, r_unscaled = self._sample_r(name="r_presample_0")
energy_current = self._kinetic_energy(r_unscaled) + potential_energy
# This is required so as to avoid issues with autograd when model
Expand Down
49 changes: 45 additions & 4 deletions pyro/ops/integrator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Callable, Dict

from torch.autograd import grad

# Registry for exception handlers that can be used to catch certain failures
# during computation of `potential_fn` within `potential_grad`.
_EXCEPTION_HANDLERS: Dict[str, Callable[[Exception], bool]] = {}


def velocity_verlet(
z, r, potential_fn, kinetic_grad, step_size, num_steps=1, z_grads=None
Expand Down Expand Up @@ -74,15 +81,49 @@ def potential_grad(potential_fn, z):
node.requires_grad_(True)
try:
potential_energy = potential_fn(z)
# deal with singular matrices
except RuntimeError as e:
if "singular" in str(e) or "input is not positive-definite" in str(e):
# handle exceptions as defined in the exception registry
except Exception as e:
if any(h(e) for h in _EXCEPTION_HANDLERS.values()):
grads = {k: v.new_zeros(v.shape) for k, v in z.items()}
return grads, z_nodes[0].new_tensor(float("nan"))
else:
raise e

grads = grad(potential_energy, z_nodes)
for node in z_nodes:
node.requires_grad_(False)
return dict(zip(z_keys, grads)), potential_energy.detach()


def register_exception_handler(
name: str, handler: Callable[[Exception], bool], warn_on_overwrite: bool = True
) -> None:
"""
Register an exception handler for handling (primarily numerical) errors
when evaluating the potential function.
:param name: name of the handler (must be unique).
:param handler: A callable mapping an Exception to a boolean. Exceptions
that evaluate to true in any of the handlers are handled in the computation
of the potential energy.
:param warn_on_overwrite: If True, warns when overwriting a handler already
registered under the provided name.
"""
if name in _EXCEPTION_HANDLERS and warn_on_overwrite:
warnings.warn(
f"Overwriting Exception handler already registered under key {name}.",
RuntimeWarning,
)
_EXCEPTION_HANDLERS[name] = handler


def _handle_torch_singular(exception: Exception) -> bool:
"""Exception handler for errors thrown on (numerically) singular matrices."""
# the actual type of the exception thrown is torch._C._LinAlgError
if isinstance(exception, RuntimeError):
msg = str(exception)
return "singular" in msg or "input is not positive-definite" in msg
return False


# Register default exception handler
register_exception_handler("torch_singular", _handle_torch_singular)

0 comments on commit 1ec2c39

Please sign in to comment.