Skip to content

Commit

Permalink
Fix deprecation warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed May 20, 2022
1 parent 5d40300 commit 3d736a2
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 5 deletions.
2 changes: 1 addition & 1 deletion lab/numpy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,4 +331,4 @@ def argsort(a: Numeric, axis: Int = -1, descending: bool = False):

@dispatch
def quantile(a: Numeric, q: Numeric, axis: Union[Int, None] = None):
    """Compute the `q`-th quantile of `a` along `axis`.

    Uses linear interpolation between data points. NumPy 1.22 renamed the
    `interpolation` keyword to `method` and deprecated the old name, so
    `method="linear"` is used to avoid the deprecation warning.

    Args:
        a (tensor): Input tensor.
        q (tensor): Quantile(s) in `[0, 1]` to compute.
        axis (int, optional): Axis to reduce over. Defaults to flattening `a`.

    Returns:
        tensor: The requested quantile(s).
    """
    return np.quantile(a, q, axis=axis, method="linear")
2 changes: 1 addition & 1 deletion lab/torch/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def cholesky_solve(a: Numeric, b: Numeric):

@dispatch
def triangular_solve(a: Numeric, b: Numeric, lower_a: bool = True):
    """Solve the triangular system `a x = b`.

    `torch.triangular_solve(b, a, ...)` is deprecated; its replacement
    `torch.linalg.solve_triangular(a, b, ...)` swaps the argument order and
    returns the solution tensor directly instead of a tuple, so no `[0]`
    indexing is needed.

    Args:
        a (tensor): Triangular coefficient matrix.
        b (tensor): Right-hand side.
        lower_a (bool, optional): `a` is lower-triangular. Defaults to `True`.

    Returns:
        tensor: Solution `x`.
    """
    return torch.linalg.solve_triangular(a, b, upper=not lower_a)


_toeplitz_solve = torch_register(toeplitz_solve, s_toeplitz_solve)
Expand Down
16 changes: 14 additions & 2 deletions lab/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,27 @@ def _torch_lookup(dtype):
if not _torch_lookup_cache:
# Cache is empty. Fill it.

def _from_np(name):
# We will want to get types from `np`, but the built-in types should be just
# those.
if name in {"int", "long"}:
return int
elif name == "bool":
return bool
elif name == "unicode":
return str
else:
return getattr(np, name)

# `bool` can occur but isn't in `__all__`.
for name in np.core.numerictypes.__all__ + ["bool"]:
# Check that it is a type.
if not isinstance(getattr(np, name), type):
if not isinstance(_from_np(name), type):
continue

# Attempt to get the PyTorch equivalent.
try:
_torch_lookup_cache[_module_attr("torch", name)] = getattr(np, name)
_torch_lookup_cache[_module_attr("torch", name)] = _from_np(name)
except AttributeError:
# Could not find the PyTorch equivalent. That's okay.
pass
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shaping.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,16 @@ def test_take_consistency(check_lazy_shapes):
{"axis": Value(0, 1, -1)},
)

# Test PyTorch separately, because it has a separate implementation for framework
# masks or indices.
for indices_or_mask in [
torch.tensor([True, True, False], dtype=torch.bool),
torch.tensor([0, 1], dtype=torch.int32),
torch.tensor([0, 1], dtype=torch.int64),
]:
a = B.randn(torch.float32, 3, 3)
approx(B.take(a, indices_or_mask), a[[0, 1]])


def test_take_consistency_order(check_lazy_shapes):
# Check order of indices.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_data_type(check_lazy_shapes):
assert convert(jnp.float32, B.TorchDType) is torch.float32

# `torch.bool` has a manual addition, so test it separately.
assert convert(torch.bool, B.NPDType) is np.bool
assert convert(torch.bool, B.NPDType) is bool


def test_dtype(check_lazy_shapes):
Expand Down

0 comments on commit 3d736a2

Please sign in to comment.