From bb541807d732f53db875878bcd5132493953f248 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Mon, 25 Aug 2025 16:41:39 +0000 Subject: [PATCH] Fix case when both device & dtype are given in .to --- torchax/test/test_misc.py | 47 +++++++++++++++++++++++++++++++++++++++ torchax/torchax/tensor.py | 6 ++--- 2 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 torchax/test/test_misc.py diff --git a/torchax/test/test_misc.py b/torchax/test/test_misc.py new file mode 100644 index 00000000000..b93877a7fd6 --- /dev/null +++ b/torchax/test/test_misc.py @@ -0,0 +1,47 @@ +"""If you don't know which file a test should go, and don't want to make a new file +for a small test. PUt it here +""" +import torch +import unittest +import torchax +import jax +import jax.numpy as jnp + + +class MiscTest(unittest.TestCase): + + def test_extract_jax_kwargs(self): + + class M(torch.nn.Module): + + def forward(self, a, b): + return torch.sin(a) + torch.cos(b) + + weights, func = torchax.extract_jax(M()) + res = func( + weights, + args=(), + kwargs={ + 'a': jnp.array([1, 2, 3]), + 'b': jnp.array([3, 4, 5]) + }) + self.assertTrue( + jnp.allclose( + res, + jnp.sin(jnp.array([1, 2, 3])) + jnp.cos(jnp.array([3, 4, 5])))) + + def test_to_device(self): + env = torchax.default_env() + env.config.debug_print_each_op = True + with env: + step1 = torch.ones( + 100, + 100, + ) + step2 = torch.triu(step1, diagonal=1) + step3 = step2.to(dtype=torch.bool, device='jax') + self.assertEqual(step3.device.type, 'jax') + + +if __name__ == '__main__': + unittest.main() diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index 67bc074177e..a325c51dfc1 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -469,12 +469,12 @@ def _to_copy(self, the_tensor, new_dtype, new_device): arr = self.t2j_copy(the_tensor) res = Tensor(arr, self, the_tensor.requires_grad) - if new_dtype is not None and new_dtype != the_tensor.dtype: - if isinstance(the_tensor, Tensor): + if new_dtype is not None and new_dtype != res.dtype: + if isinstance(res, Tensor): res = res.apply_jax(jnp.astype, mappings.t2j_dtype(new_dtype)) else: with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return the_tensor.to(device=new_device, dtype=new_dtype) + return res.to(device=new_device, dtype=new_dtype) return res def get_and_rotate_prng_key(self,