Skip to content

Commit

Permalink
Fix bug and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Jun 22, 2021
1 parent 7288f58 commit 370bc76
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
1 change: 1 addition & 0 deletions lab/generic.py
Expand Up @@ -295,6 +295,7 @@ def to_active_device(a: Numeric): # pragma: no cover
"""


@dispatch
def to_active_device(a: Number):
return a

Expand Down
10 changes: 8 additions & 2 deletions tests/test_generic.py
@@ -1,3 +1,4 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
Expand Down Expand Up @@ -130,13 +131,18 @@ def test_to_active_device_jax(check_lazy_shapes):
# No device specified: should do nothing.
assert B.to_active_device(a) is a

# Move to JAX device.
with B.on_device(jax.devices("cpu")[0]):
assert B.to_active_device(a) is not a
approx(B.to_active_device(a), a)

# Move to CPU without identifier.
with B.device("cpu"):
with B.on_device("cpu"):
assert B.to_active_device(a) is not a
approx(B.to_active_device(a), a)

# Move to CPU with identifier. Also check that capitalisation does not matter.
with B.device("CPU:0"):
with B.on_device("CPU:0"):
assert B.to_active_device(a) is not a
approx(B.to_active_device(a), a)

Expand Down

0 comments on commit 370bc76

Please sign in to comment.