Skip to content

Commit

Permalink
[RFC] Allow registering excpetion handlers for potential function com…
Browse files Browse the repository at this point in the history
…putations

In some cases evaluation the potential funciton may result in numerical issues. Currently the code hard-codes the handling of a RuntimeError raised when matrices are (numerically) singular. This PR adds the ability to register custom exception handlers. This allows other code depending on pyro to register custom exception handlers without having to modify core pyro code.

There are some other places in which `potential_fn` is called that could benefit from being guarded by these handlers (one is `HMC._find_reasonable_step_size`). I'm not sure what the right thing to do there is when encountering numerical isssues, but happy to add this in as needed.
  • Loading branch information
Balandat committed Dec 29, 2022
1 parent 3422c3a commit fb580d1
Showing 1 changed file with 44 additions and 4 deletions.
48 changes: 44 additions & 4 deletions pyro/ops/integrator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Callable

from torch.autograd import grad

# Registry for exception handlers that can be used to catch certain failures
# during computation of `potential_fn` within `potential_grad`.
_EXCEPTION_HANDLERS = {}


def velocity_verlet(
z, r, potential_fn, kinetic_grad, step_size, num_steps=1, z_grads=None
Expand Down Expand Up @@ -74,15 +80,49 @@ def potential_grad(potential_fn, z):
node.requires_grad_(True)
try:
potential_energy = potential_fn(z)
# deal with singular matrices
except RuntimeError as e:
if "singular" in str(e) or "input is not positive-definite" in str(e):
# handle exceptions as defined in the exception registry
except Exception as e:
if any(h(e) for h in _EXCEPTION_HANDLERS.values()):
grads = {k: v.new_zeros(v.shape) for k, v in z.items()}
return grads, z_nodes[0].new_tensor(float("nan"))
else:
raise e

grads = grad(potential_energy, z_nodes)
for node in z_nodes:
node.requires_grad_(False)
return dict(zip(z_keys, grads)), potential_energy.detach()


def register_exception_handler(
name: str, handler: Callable[[Exception], bool], overwrite: bool = False
) -> None:
"""
Register an exception handler for handling (primarily numerical) errors
when evaluating the potential function.
:param name: name of the handler (must be unique).
:param handler: A callable mapping an exception to a boolean. Exceptions
that evaluate to true in any of the handlers are handled in teh computation
of the potential energy.
:param overwrite: If True, overwrite handlers already registerd under the
provided name.
"""
if name in _EXCEPTION_HANDLERS and not overwrite:
raise RuntimeError(
f"Exception handler already registered under key {name}. "
"Use `overwrite=True` to force overwriting the handler."
)
_EXCEPTION_HANDLERS[name] = handler


def _handle_torch_singular(exception: Exception) -> bool:
"""Exception handler for errors thrown on (numerically) singular matrices."""
if type(exception) == RuntimeError:
return "singular" in str(exception) or "input is not positive-definite" in str(
exception
)
return False


# Register default exception handler
register_exception_handler("torch_singular", _handle_torch_singular)

0 comments on commit fb580d1

Please sign in to comment.