Skip to content

Commit

Permalink
[SymForce] Warn on symengine not found
Browse files Browse the repository at this point in the history
We currently fall back silently, I think this is better. Made it easy
to silence this either by calling symforce.set_symbolic_api or setting
the environment variable, since I know there are users who build
symforce from source without sympy. This should absolutely be something
you do intentionally though, not something that happens silently,
because the performance is so much worse.

GitOrigin-RevId: 124ddd94d51e48997ff82e226286a55de53e3d4d
  • Loading branch information
aaron-skydio authored and nathan-skydio committed Dec 19, 2023
1 parent f893581 commit 2650c4b
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 79 deletions.
35 changes: 18 additions & 17 deletions symforce/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
import sys
import typing as T
import warnings
from dataclasses import dataclass
from types import ModuleType

# -------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -139,11 +139,11 @@ def _find_symengine() -> ModuleType:
return symengine


_symbolic_api: T.Optional[str] = None
_symbolic_api: T.Optional[T.Literal["sympy", "symengine"]] = None
_have_imported_symbolic = False


def _set_symbolic_api(sympy_module: str) -> None:
def _set_symbolic_api(sympy_module: T.Literal["sympy", "symengine"]) -> None:
# Set this as the default symbolic API
global _symbolic_api # pylint: disable=global-statement
_symbolic_api = sympy_module
Expand Down Expand Up @@ -213,16 +213,15 @@ def set_symbolic_api(name: str) -> None:
logger.debug("No SYMFORCE_SYMBOLIC_API set, found and using symengine.")
set_symbolic_api("symengine")
except ImportError:
logger.debug("No SYMFORCE_SYMBOLIC_API set, no symengine found, using sympy.")
set_symbolic_api("sympy")
logger.debug("No SYMFORCE_SYMBOLIC_API set, no symengine found. Will use sympy.")
pass


def get_symbolic_api() -> str:
def get_symbolic_api() -> T.Literal["sympy", "symengine"]:
"""
Return the current symbolic API as a string.
"""
assert _symbolic_api is not None
return _symbolic_api
return _symbolic_api or "sympy"


# --------------------------------------------------------------------------------
Expand All @@ -241,7 +240,7 @@ class AlreadyUsedEpsilon(Exception):
pass


_epsilon = 0.0
_epsilon: T.Any = 0.0
_have_used_epsilon = False


Expand All @@ -267,6 +266,15 @@ def _set_epsilon(new_epsilon: T.Any) -> None:
_epsilon = new_epsilon


@dataclass
class SymbolicEpsilon:
"""
An indicator that SymForce should use a symbolic epsilon
"""

name: str


def set_epsilon_to_symbol(name: str = "epsilon") -> None:
"""
Set the default epsilon for Symforce to a Symbol.
Expand All @@ -277,14 +285,7 @@ def set_epsilon_to_symbol(name: str = "epsilon") -> None:
Args:
name: The name of the symbol for the new default epsilon to use
"""
if get_symbolic_api() == "sympy":
import sympy
elif get_symbolic_api() == "symengine":
sympy = _find_symengine()
else:
raise InvalidSymbolicApiError(get_symbolic_api())

_set_epsilon(sympy.Symbol(name))
_set_epsilon(SymbolicEpsilon(name))


def set_epsilon_to_number(value: T.Any = numeric_epsilon) -> None:
Expand Down
148 changes: 86 additions & 62 deletions symforce/internal/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@
from symforce import logger
from symforce import typing as T

if symforce._symbolic_api is None: # pylint: disable=protected-access
import textwrap

logger.warning(
textwrap.dedent(
"""
No SYMFORCE_SYMBOLIC_API set, no symengine found, using sympy.
For best performance during symbolic manipulation and code generation, use the symengine
symbolic API, which should come installed with SymForce. If you've installed SymForce
without intentionally disabling SymEngine, and you are seeing this warning, this is a
bug - please submit an issue: https://github.com/symforce-org/symforce/issues.
To silence this warning, call `symforce.set_symbolic_api("sympy")` or set
`SYMFORCE_SYMBOLIC_API=sympy` in your environment before importing `symforce.symbolic`.
"""
)
)

# See `symforce/__init__.py` for more information, this is used to check whether things that this
# module depends on are modified after importing
symforce._have_imported_symbolic = True # pylint: disable=protected-access
Expand Down Expand Up @@ -246,42 +265,45 @@


# --------------------------------------------------------------------------------
# Default epsilon
# Create scopes
# --------------------------------------------------------------------------------


from symforce import numeric_epsilon
def create_named_scope(scopes_list: T.List[str]) -> T.Callable:
"""
Return a context manager that adds to the given list of name scopes. This is used to
add scopes to symbol names for namespacing.
"""

@contextlib.contextmanager
def named_scope(scope: str) -> T.Iterator[None]:
scopes_list.append(scope)

def epsilon() -> T.Any:
"""
The default epsilon for SymForce
# The body of the with block is executed inside the yield, this ensures we release the
# scope if something in the block throws
try:
yield None
finally:
scopes_list.pop()

Library functions that require an epsilon argument should use a function signature like::
return named_scope

def foo(x: Scalar, epsilon: Scalar = sf.epsilon()) -> Scalar:
...

This makes it easy to configure entire expressions that make extensive use of epsilon to either
use no epsilon (i.e. 0), or a symbol, or a numerical value. It also means that by setting the
default to a symbol, you can confidently generate code without worrying about having forgotten
to pass an epsilon argument to one of these functions.
# Nested scopes created with `sf.scope`, initialized to empty (symbols created with no added scope)
__scopes__ = []

For more information on how we use epsilon to prevent singularities, see the Epsilon Tutorial
in the SymForce docs here: https://symforce.org/tutorials/epsilon_tutorial.html

For purely numerical code that just needs a good default numerical epsilon, see
:data:`symforce.symbolic.numeric_epsilon`.
def set_scope(scope: str) -> None:
global __scopes__ # pylint: disable=global-statement
__scopes__ = scope.split(".") if scope else []

Returns:
The current default epsilon. This is typically some kind of "Scalar", like a float or a
:class:`Symbol <symforce.symbolic.Symbol>`.
"""
symforce._have_used_epsilon = True # pylint: disable=protected-access

return symforce._epsilon # pylint: disable=protected-access
def get_scope() -> str:
return ".".join(__scopes__)


scope = create_named_scope(__scopes__)

# --------------------------------------------------------------------------------
# Override Symbol and symbols
# --------------------------------------------------------------------------------
Expand Down Expand Up @@ -343,6 +365,48 @@ def new_symbol(
raise symforce.InvalidSymbolicApiError(sympy.__package__)


# --------------------------------------------------------------------------------
# Default epsilon
# --------------------------------------------------------------------------------


from symforce import numeric_epsilon


def epsilon() -> T.Any:
"""
The default epsilon for SymForce
Library functions that require an epsilon argument should use a function signature like::
def foo(x: Scalar, epsilon: Scalar = sf.epsilon()) -> Scalar:
...
This makes it easy to configure entire expressions that make extensive use of epsilon to either
use no epsilon (i.e. 0), or a symbol, or a numerical value. It also means that by setting the
default to a symbol, you can confidently generate code without worrying about having forgotten
to pass an epsilon argument to one of these functions.
For more information on how we use epsilon to prevent singularities, see the Epsilon Tutorial
in the SymForce docs here: https://symforce.org/tutorials/epsilon_tutorial.html
For purely numerical code that just needs a good default numerical epsilon, see
:data:`symforce.symbolic.numeric_epsilon`.
Returns:
The current default epsilon. This is typically some kind of "Scalar", like a float or a
:class:`Symbol <symforce.symbolic.Symbol>`.
"""
# pylint: disable=protected-access

if isinstance(symforce._epsilon, symforce.SymbolicEpsilon):
symforce._epsilon = sympy.Symbol(symforce._epsilon.name)

symforce._have_used_epsilon = True

return symforce._epsilon


# --------------------------------------------------------------------------------
# Typing
# --------------------------------------------------------------------------------
Expand Down Expand Up @@ -683,43 +747,3 @@ def _get_subs_dict(*args: T.Any, dont_flatten_args: bool = False, **kwargs: T.An
)
else:
raise symforce.InvalidSymbolicApiError(sympy.__package__)

# --------------------------------------------------------------------------------
# Create scopes
# --------------------------------------------------------------------------------


def create_named_scope(scopes_list: T.List[str]) -> T.Callable:
"""
Return a context manager that adds to the given list of name scopes. This is used to
add scopes to symbol names for namespacing.
"""

@contextlib.contextmanager
def named_scope(scope: str) -> T.Iterator[None]:
scopes_list.append(scope)

# The body of the with block is executed inside the yield, this ensures we release the
# scope if something in the block throws
try:
yield None
finally:
scopes_list.pop()

return named_scope


# Nested scopes created with `sf.scope`, initialized to empty (symbols created with no added scope)
__scopes__ = []


def set_scope(scope: str) -> None:
global __scopes__ # pylint: disable=global-statement
__scopes__ = scope.split(".") if scope else []


def get_scope() -> str:
return ".".join(__scopes__)


scope = create_named_scope(__scopes__)

0 comments on commit 2650c4b

Please sign in to comment.