Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions torchax/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,14 @@ inputs = torch.randn(3, 3, 28, 28, device='jax')
m = MyModel().to('jax')
res = m(inputs)
print(type(res)) # outputs torchax.tensor.Tensor
print(res.jax()) # print the underlying Jax Array
```

`torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds
a `jax.Array`. You can inspect that JAX array with `res.jax()`.

In other words, despite that the code above looks like PyTorch, it is actually running JAX!

## What is happening behind the scene

We took the approach detailed in the
Expand Down
4 changes: 2 additions & 2 deletions torchax/dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-f https://download.pytorch.org/whl/torch
torch==2.7.1 ; sys_platform == 'darwin' # macOS
torch==2.7.1+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU
torch==2.8.0 ; sys_platform == 'darwin' # macOS
torch==2.8.0+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU
yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml`
flax==0.10.6
106 changes: 0 additions & 106 deletions torchax/examples/_diffusion.py

This file was deleted.

76 changes: 0 additions & 76 deletions torchax/examples/_grad_of_attention.py

This file was deleted.

52 changes: 0 additions & 52 deletions torchax/examples/torchbench_models/BERT_pytorch.py

This file was deleted.

4 changes: 0 additions & 4 deletions torchax/examples/train_gpt/requirements.txt

This file was deleted.

3 changes: 0 additions & 3 deletions torchax/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,3 @@ odml = ["jax[cpu]>=0.6.2", "jax[cpu]"]

[tool.hatch.build.targets.wheel]
packages = ["torchax"]

[tool.pytest.ini_options]
addopts="-n auto"
7 changes: 3 additions & 4 deletions torchax/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
absl-py==2.2.2
immutabledict==4.2.1
pytest==8.3.5
pytest-xdist==3.6.1
pytest-forked==1.6.0
sentencepiece==0.2.0
sentencepiece
expecttest==0.3.0
optax==0.2.4
tensorflow==2.19.0
pytest
pytest-xdist
12 changes: 12 additions & 0 deletions torchax/test/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ def forward(self, a, b):
jnp.sin(jnp.array([1, 2, 3])) + jnp.cos(jnp.array([3, 4, 5]))))

def test_to_device(self):
env = torchax.default_env()
with env:
step1 = torch.ones(
100,
100,
)
step2 = torch.triu(step1, diagonal=1)
step3 = step2.to(dtype=torch.bool, device='jax')
self.assertEqual(step3.device.type, 'jax')

def test_to_device_twice(self):
env = torchax.default_env()
env.config.debug_print_each_op = True
with env:
Expand All @@ -40,6 +51,7 @@ def test_to_device(self):
)
step2 = torch.triu(step1, diagonal=1)
step3 = step2.to(dtype=torch.bool, device='jax')
step3.to('jax')
self.assertEqual(step3.device.type, 'jax')


Expand Down
51 changes: 0 additions & 51 deletions torchax/test/test_tf_integration.py

This file was deleted.

27 changes: 27 additions & 0 deletions torchax/test_dist/test_to_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import jax
import torch
import torchax
import unittest

from jax.sharding import NamedSharding, PartitionSpec as P


class ToDeviceTest(unittest.TestCase):

def test_to_device_twice(self):
env = torchax.default_env()
mesh = jax.make_mesh((jax.device_count(),), ('axis',))
with env:
step1 = torch.ones(
100,
100,
)
step2 = torch.triu(step1, diagonal=1)
step3 = step2.to(dtype=torch.bool, device='jax')
step3.apply_jax_(jax.device_put, NamedSharding(mesh, P()))
print(step3.to('jax'))
self.assertEqual(step3.device.type, 'jax')


if __name__ == '__main__':
unittest.main()
Loading