Skip to content

Commit

Permalink
Add Complex and dtype_int
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Dec 14, 2021
1 parent 18fa9e0 commit c13718e
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 2 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ True
```
Int # Integers
Float # Floating-point numbers
Complex # Complex numbers
Bool # Booleans
Number # Numbers
Numeric # Numerical objects, including booleans
Expand Down Expand Up @@ -251,6 +252,7 @@ log_2_pi
```
dtype(a)
dtype_float(a)
dtype_int(a)
promote_dtypes(*dtype)
issubdtype(dtype1, dtype2)
Expand Down
8 changes: 7 additions & 1 deletion lab/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from . import B, dispatch

__all__ = ["Shape", "Dimension", "dispatch_unwrap_dimensions"]
__all__ = ["Shape", "Dimension", "unwrap_dimension", "dispatch_unwrap_dimensions"]

_dispatch = Dispatcher()

Expand Down Expand Up @@ -131,6 +131,12 @@ def __truediv__(self, other):
def __rtruediv__(self, other):
return other / self.dim

def __floordiv__(self, other):
return self.dim // other

def __rfloordiv__(self, other):
return other // self.dim

def __neg__(self):
return -self.dim

Expand Down
22 changes: 21 additions & 1 deletion lab/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
__all__ = [
"Int",
"Float",
"Complex",
"Bool",
"Number",
"NPNumeric",
Expand Down Expand Up @@ -46,6 +47,7 @@
"issubdtype",
"promote_dtypes",
"dtype_float",
"dtype_int",
"NP",
"AG",
"TF",
Expand Down Expand Up @@ -143,8 +145,9 @@ def _module_attr(module, attr):
# Numeric types:
Int = Union(*([int, Dimension] + np.sctypes["int"] + np.sctypes["uint"]), alias="Int")
Float = Union(*([float] + np.sctypes["float"]), alias="Float")
Complex = Union(*([complex] + np.sctypes["complex"]), alias="Complex")
Bool = Union(bool, np.bool_, alias="Bool")
Number = Union(Int, Bool, Float, alias="Number")
Number = Union(Int, Bool, Float, Complex, alias="Number")
NPNumeric = Union(np.ndarray, alias="NPNumeric")
AGNumeric = Union(_ag_tensor, alias="AGNumeric")
TFNumeric = Union(_tf_tensor, _tf_variable, _tf_indexedslices, alias="TFNumeric")
Expand Down Expand Up @@ -345,6 +348,23 @@ def dtype_float(x):
return promote_dtypes(dtype(x), np.float16)


@dispatch
def dtype_int(x):
"""Get the data type of an object and get the integer equivalent.
Args:
x (object): Object to get data type of.
Returns:
dtype: Data type of `x`, but ensured to be integer.
"""
x_dtype = dtype(x)
name = list(convert(x_dtype, NPDType).__name__)
while name and name[0] not in set([str(i) for i in range(10)]):
name.pop(0)
return _convert_back(getattr(np, "int" + "".join(name)), x_dtype)


# Random state types:
NPRandomState = Union(np.random.RandomState, alias="NPRandomState")
AGRandomState = Union(NPRandomState, alias="AGRandomState")
Expand Down
2 changes: 2 additions & 0 deletions tests/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def test_dimension():
assert isinstance(d / 5, float)
assert d / 5 == 1
assert 5 / d == 1
assert d // 2 == 2
assert 11 // d == 2
assert -d is -5
assert d ** 2 is 25

Expand Down
17 changes: 17 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@ def test_numeric(check_lazy_shapes):
assert isinstance(1, B.Int)
assert isinstance(np.int32(1), B.Int)
assert isinstance(np.uint64(1), B.Int)

assert isinstance(1.0, B.Float)
assert isinstance(np.float32(1), B.Float)

assert isinstance(1 + 0j, B.Complex)
assert isinstance(np.complex64(1), B.Complex)

assert isinstance(True, B.Bool)
assert isinstance(np.bool_(True), B.Bool)

assert isinstance(np.uint(1), B.Number)
assert isinstance(np.float64(1), B.Number)
assert isinstance(np.complex64(1), B.Number)

# Test NumPy.
assert isinstance(np.array(1), B.NPNumeric)
Expand Down Expand Up @@ -180,6 +187,16 @@ def test_dtype_float(check_lazy_shapes):
assert B.dtype_float(1) is np.float64


def test_dtype_int(check_lazy_shapes):
assert B.dtype_int(np.float32(1)) is np.int32
assert B.dtype_int(np.float64(1)) is np.int64
assert B.dtype_int(1) is int
# Test conversion back to right framework type. This conversion is thoroughly
# tested for `B.promote_dtypes`.
assert B.dtype_float(tf.constant(1.0, dtype=tf.float32)) is tf.int32
assert B.dtype_float(tf.constant(1.0, dtype=tf.float64)) is tf.int64


@pytest.mark.parametrize(
"t, FWRandomState",
[
Expand Down

0 comments on commit c13718e

Please sign in to comment.