diff --git a/lab/autograd/__init__.py b/lab/autograd/__init__.py
index bcf6936..ddec35b 100644
--- a/lab/autograd/__init__.py
+++ b/lab/autograd/__init__.py
@@ -12,6 +12,8 @@
 
 from plum import clear_all_cache as _clear_all_cache
 
+import autograd  # Load `autograd` to load all new types.
+
 # noinspection PyUnresolvedReferences
 from .generic import *
 from .linear_algebra import *
diff --git a/lab/jax/__init__.py b/lab/jax/__init__.py
index 81fdbdd..9f5474f 100644
--- a/lab/jax/__init__.py
+++ b/lab/jax/__init__.py
@@ -12,6 +12,8 @@
 
 from plum import clear_all_cache as _clear_all_cache
 
+import jax  # Load `jax` to load all new types.
+
 # noinspection PyUnresolvedReferences
 from .generic import *
 from .linear_algebra import *
diff --git a/lab/tensorflow/__init__.py b/lab/tensorflow/__init__.py
index 153f639..717533c 100644
--- a/lab/tensorflow/__init__.py
+++ b/lab/tensorflow/__init__.py
@@ -12,7 +12,7 @@
 
 from plum import clear_all_cache as _clear_all_cache
 
-import tensorflow as tf
+import tensorflow as tf  # Load `tensorflow` to load all new types.
 
 # noinspection PyUnresolvedReferences
 from .generic import *
diff --git a/lab/torch/__init__.py b/lab/torch/__init__.py
index 9fd9029..ea99be9 100644
--- a/lab/torch/__init__.py
+++ b/lab/torch/__init__.py
@@ -12,6 +12,8 @@
 
 from plum import clear_all_cache as _clear_all_cache
 
+import torch  # Load `torch` to load all new types.
+
 # noinspection PyUnresolvedReferences
 from .generic import *
 from .linear_algebra import *
diff --git a/lab/types.py b/lab/types.py
index 927ceb6..f0784d8 100644
--- a/lab/types.py
+++ b/lab/types.py
@@ -83,12 +83,11 @@ def _module_attr(module, attr):
 _ag_tensor = ModuleType("autograd.tracer", "Box")
 
 # Define JAX module types.
-if (
-    sys.version_info.minor > 7
-):  # jax>0.4 deprecated python-3.7 support, rely on older jax versions
-    _jax_tensor = ModuleType("jaxlib.xla_extension", "ArrayImpl")
-else:
+if sys.version_info.minor <= 7:
+    # `jax` 0.4 deprecated Python 3.7 support. Rely on older JAX versions.
     _jax_tensor = ModuleType("jax.interpreters.xla", "DeviceArray")
+else:
+    _jax_tensor = ModuleType("jaxlib.xla_extension", "ArrayImpl")
 _jax_tracer = ModuleType("jax.core", "Tracer")
 _jax_dtype = ModuleType("jax._src.numpy.lax_numpy", "_ScalarMeta")
 _jax_device = ModuleType("jaxlib.xla_extension", "Device")
@@ -156,31 +155,34 @@ def _module_attr(module, attr):
 _torch_lookup_cache = {}
 
 
+def _name_to_numpy_dtype(name):
+    # We will want to get types from `np`, but the built-in types should be just
+    # those.
+    if name in {"int", "long"}:
+        return int
+    elif name == "bool":
+        return bool
+    elif name == "unicode":
+        return str
+    else:
+        return getattr(np, name)
+
+
 def _torch_lookup(dtype):
     if not _torch_lookup_cache:
         # Cache is empty. Fill it.
 
-        def _from_np(name):
-            # We will want to get types from `np`, but the built-in types should be just
-            # those.
-            if name in {"int", "long"}:
-                return int
-            elif name == "bool":
-                return bool
-            elif name == "unicode":
-                return str
-            else:
-                return getattr(np, name)
-
         # `bool` can occur but isn't in `__all__`.
         for name in np.core.numerictypes.__all__ + ["bool"]:
+            _from_np = _name_to_numpy_dtype(name)
+
             # Check that it is a type.
-            if not isinstance(_from_np(name), type):
+            if not isinstance(_from_np, type):
                 continue
 
             # Attempt to get the PyTorch equivalent.
             try:
-                _torch_lookup_cache[_module_attr("torch", name)] = _from_np(name)
+                _torch_lookup_cache[_module_attr("torch", name)] = _from_np
             except AttributeError:
                 # Could not find the PyTorch equivalent. That's okay.
                 pass
@@ -355,12 +357,7 @@ def dtype_int(dtype: DType):
     name = list(convert(dtype, NPDType).__name__)
     while name and name[0] not in set([str(i) for i in range(10)]):
         name.pop(0)
-    name_int = "int" + "".join(name)
-    if name_int == "int":
-        dtype_int = int
-    else:
-        dtype_int = getattr(np, name_int)
-    return _convert_back(dtype_int, dtype)
+    return _convert_back(_name_to_numpy_dtype("int" + "".join(name)), dtype)
 
 
 @dispatch
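
A note on the last two hunks: extracting `_name_to_numpy_dtype` lets `dtype_int` reuse the special-casing of built-in type names, which matters when the converted dtype's name contains no digits (e.g. a built-in `int` or `bool`). The stripping loop then leaves an empty string, the lookup name is plain "int", and that must map to the built-in `int` rather than an attribute of `np` — exactly the `name_int == "int"` branch the hunk deletes. A minimal standalone sketch of that behaviour, with the helper body copied from the diff and purely illustrative assertions:

import numpy as np


def _name_to_numpy_dtype(name):
    # Built-in type names map to the built-ins themselves; everything else is
    # looked up as an attribute of `np`.
    if name in {"int", "long"}:
        return int
    elif name == "bool":
        return bool
    elif name == "unicode":
        return str
    else:
        return getattr(np, name)


# Illustrative checks, not from the source:
assert _name_to_numpy_dtype("int64") is np.int64  # Digits present: `np` lookup.
assert _name_to_numpy_dtype("int") is int  # No digits left: the built-in `int`.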