Fix dtype_int and make loading more robust
wesselb committed Mar 7, 2023
1 parent b4599ba commit 474b625
Showing 5 changed files with 29 additions and 26 deletions.
2 changes: 2 additions & 0 deletions lab/autograd/__init__.py
@@ -12,6 +12,8 @@
 
 from plum import clear_all_cache as _clear_all_cache
 
+import autograd  # Load `autograd` to load all new types.
+
 # noinspection PyUnresolvedReferences
 from .generic import *
 from .linear_algebra import *
2 changes: 2 additions & 0 deletions lab/jax/__init__.py
@@ -12,6 +12,8 @@
 
 from plum import clear_all_cache as _clear_all_cache
 
+import jax  # Load `jax` to load all new types.
+
 # noinspection PyUnresolvedReferences
 from .generic import *
 from .linear_algebra import *
2 changes: 1 addition & 1 deletion lab/tensorflow/__init__.py
@@ -12,7 +12,7 @@
 
 from plum import clear_all_cache as _clear_all_cache
 
-import tensorflow as tf
+import tensorflow as tf  # Load `tensorflow` to load all new types.
 
 # noinspection PyUnresolvedReferences
 from .generic import *
2 changes: 2 additions & 0 deletions lab/torch/__init__.py
@@ -12,6 +12,8 @@
 
 from plum import clear_all_cache as _clear_all_cache
 
+import torch  # Load `torch` to load all new types.
+
 # noinspection PyUnresolvedReferences
 from .generic import *
 from .linear_algebra import *
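Why the four entry-point changes above matter: lab/types.py refers to backend types only by module and attribute name (e.g. ModuleType("jaxlib.xla_extension", "ArrayImpl")), so those references can only resolve once the backend package has actually been imported. Importing the framework in each backend's __init__.py guarantees that. The sketch below only illustrates the lazy-resolution idea; it is not LAB's actual ModuleType implementation, and `resolve_lazy_type` is a made-up name for illustration.

import sys


def resolve_lazy_type(module_name, attr_name):
    # Resolve `module_name.attr_name`, assuming the module has already been imported.
    # Referring to a type purely by name avoids importing the backend eagerly.
    if module_name not in sys.modules:
        raise RuntimeError(f"Backend `{module_name}` has not been imported yet.")
    return getattr(sys.modules[module_name], attr_name)


import torch  # Without this explicit import, the lookup below would raise.

print(resolve_lazy_type("torch", "Tensor"))  # <class 'torch.Tensor'>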
47 changes: 22 additions & 25 deletions lab/types.py
@@ -83,12 +83,11 @@ def _module_attr(module, attr):
 _ag_tensor = ModuleType("autograd.tracer", "Box")
 
 # Define JAX module types.
-if (
-    sys.version_info.minor > 7
-):  # jax>0.4 deprecated python-3.7 support, rely on older jax versions
-    _jax_tensor = ModuleType("jaxlib.xla_extension", "ArrayImpl")
-else:
+if sys.version_info.minor <= 7:
+    # `jax` 0.4 deprecated Python 3.7 support. Rely on older JAX versions.
     _jax_tensor = ModuleType("jax.interpreters.xla", "DeviceArray")
+else:
+    _jax_tensor = ModuleType("jaxlib.xla_extension", "ArrayImpl")
 _jax_tracer = ModuleType("jax.core", "Tracer")
 _jax_dtype = ModuleType("jax._src.numpy.lax_numpy", "_ScalarMeta")
 _jax_device = ModuleType("jaxlib.xla_extension", "Device")
@@ -156,31 +155,34 @@ def _module_attr(module, attr):
 _torch_lookup_cache = {}
 
 
+def _name_to_numpy_dtype(name):
+    # We will want to get types from `np`, but the built-in types should be just
+    # those.
+    if name in {"int", "long"}:
+        return int
+    elif name == "bool":
+        return bool
+    elif name == "unicode":
+        return str
+    else:
+        return getattr(np, name)
+
+
 def _torch_lookup(dtype):
     if not _torch_lookup_cache:
         # Cache is empty. Fill it.
 
-        def _from_np(name):
-            # We will want to get types from `np`, but the built-in types should be just
-            # those.
-            if name in {"int", "long"}:
-                return int
-            elif name == "bool":
-                return bool
-            elif name == "unicode":
-                return str
-            else:
-                return getattr(np, name)
-
         # `bool` can occur but isn't in `__all__`.
         for name in np.core.numerictypes.__all__ + ["bool"]:
+            _from_np = _name_to_numpy_dtype(name)
+
             # Check that it is a type.
-            if not isinstance(_from_np(name), type):
+            if not isinstance(_from_np, type):
                 continue
 
             # Attempt to get the PyTorch equivalent.
             try:
-                _torch_lookup_cache[_module_attr("torch", name)] = _from_np(name)
+                _torch_lookup_cache[_module_attr("torch", name)] = _from_np
             except AttributeError:
                 # Could not find the PyTorch equivalent. That's okay.
                 pass
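The nested `_from_np` above is lifted to the module-level `_name_to_numpy_dtype` so that `dtype_int` below can reuse it. A quick illustration of the mapping it implements; importing the private helper here is purely for demonstration and assumes this version of `lab` is installed:

import numpy as np

from lab.types import _name_to_numpy_dtype  # Private helper; imported only to illustrate.

assert _name_to_numpy_dtype("int64") is np.int64
assert _name_to_numpy_dtype("float32") is np.float32
assert _name_to_numpy_dtype("int") is int      # Python built-in rather than a NumPy alias.
assert _name_to_numpy_dtype("bool") is bool    # Likewise a built-in.
assert _name_to_numpy_dtype("unicode") is str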
@@ -355,12 +357,7 @@ def dtype_int(dtype: DType):
     name = list(convert(dtype, NPDType).__name__)
     while name and name[0] not in set([str(i) for i in range(10)]):
         name.pop(0)
-    name_int = "int" + "".join(name)
-    if name_int == "int":
-        dtype_int = int
-    else:
-        dtype_int = getattr(np, name_int)
-    return _convert_back(dtype_int, dtype)
+    return _convert_back(_name_to_numpy_dtype("int" + "".join(name)), dtype)
 
 
 @dispatch
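The `dtype_int` change keeps the name-stripping logic and only swaps the final lookup: instead of `getattr(np, ...)` with a special case for the bare "int" name, the assembled name now goes through `_name_to_numpy_dtype`. Below is a self-contained sketch of the stripping step, leaving out the `convert`/`_convert_back` plumbing that maps backend dtypes to and from NumPy; `_int_name` is a hypothetical helper named here only for illustration.

import numpy as np


def _int_name(np_dtype):
    # Same stripping as in `dtype_int`: drop the alphabetic prefix, keep the bit width.
    name = list(np_dtype.__name__)
    while name and name[0] not in set(str(i) for i in range(10)):
        name.pop(0)
    return "int" + "".join(name)


print(_int_name(np.float64))  # "int64" -> resolved to np.int64 by `_name_to_numpy_dtype`
print(_int_name(np.float16))  # "int16" -> np.int16
print(_int_name(np.bool_))    # "int"   -> the built-in `int`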
