Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions torchax/test/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""If you don't know which file a test should go, and don't want to make a new file
for a small test. PUt it here
"""
import torch
import unittest
import torchax
import jax
import jax.numpy as jnp


class MiscTest(unittest.TestCase):

def test_extract_jax_kwargs(self):

class M(torch.nn.Module):

def forward(self, a, b):
return torch.sin(a) + torch.cos(b)

weights, func = torchax.extract_jax(M())
res = func(
weights,
args=(),
kwargs={
'a': jnp.array([1, 2, 3]),
'b': jnp.array([3, 4, 5])
})
self.assertTrue(
jnp.allclose(
res,
jnp.sin(jnp.array([1, 2, 3])) + jnp.cos(jnp.array([3, 4, 5]))))

def test_to_device(self):
env = torchax.default_env()
env.config.debug_print_each_op = True
with env:
step1 = torch.ones(
100,
100,
)
step2 = torch.triu(step1, diagonal=1)
step3 = step2.to(dtype=torch.bool, device='jax')
self.assertEqual(step3.device.type, 'jax')


if __name__ == '__main__':
unittest.main()
6 changes: 3 additions & 3 deletions torchax/torchax/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,12 +469,12 @@ def _to_copy(self, the_tensor, new_dtype, new_device):
arr = self.t2j_copy(the_tensor)
res = Tensor(arr, self, the_tensor.requires_grad)

if new_dtype is not None and new_dtype != the_tensor.dtype:
if isinstance(the_tensor, Tensor):
if new_dtype is not None and new_dtype != res.dtype:
if isinstance(res, Tensor):
res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype))
else:
with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
return the_tensor.to(device=new_device, dtype=new_dtype)
return res.to(device=new_device, dtype=new_dtype)
return res

def get_and_rotate_prng_key(self,
Expand Down