From 07832d618c84908d7d15a1fc4d5157cbf7fafb05 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Wed, 3 Sep 2025 20:02:09 +0000 Subject: [PATCH 1/3] Miscelanous cleanup --- pyproject.toml | 5 + torchax/README.md | 3 + torchax/dev-requirements.txt | 4 +- torchax/examples/_diffusion.py | 106 ---------------- torchax/examples/_grad_of_attention.py | 76 ----------- .../torchbench_models/BERT_pytorch.py | 52 -------- torchax/examples/train_gpt/requirements.txt | 4 - torchax/pyproject.toml | 3 - torchax/test-requirements.txt | 5 +- torchax/test/test_misc.py | 14 ++- torchax/test/test_tf_integration.py | 51 -------- torchax/test_dist/test_to_device.py | 26 ++++ torchax/torchax/CONTRIBUTING.md | 15 ++- torchax/torchax/tf_integration.py | 119 ------------------ 14 files changed, 60 insertions(+), 423 deletions(-) delete mode 100644 torchax/examples/_diffusion.py delete mode 100644 torchax/examples/_grad_of_attention.py delete mode 100644 torchax/examples/torchbench_models/BERT_pytorch.py delete mode 100644 torchax/examples/train_gpt/requirements.txt delete mode 100644 torchax/test/test_tf_integration.py create mode 100644 torchax/test_dist/test_to_device.py delete mode 100644 torchax/torchax/tf_integration.py diff --git a/pyproject.toml b/pyproject.toml index 6d76b2a22e99..88b68d55cde8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,3 +46,8 @@ dynamic = [ Homepage = "https://github.com/pytorch/xla" Repository = "https://github.com/pytorch/xla" "Bug Tracker" = "https://github.com/pytorch/xla/issues" + +[tool.uv.workspace] +members = [ + "torchax2", +] diff --git a/torchax/README.md b/torchax/README.md index 06d9e26d7dcd..2b1fa8d58f33 100644 --- a/torchax/README.md +++ b/torchax/README.md @@ -97,11 +97,14 @@ inputs = torch.randn(3, 3, 28, 28, device='jax') m = MyModel().to('jax') res = m(inputs) print(type(res)) # outputs torchax.tensor.Tensor +print(res.jax()) # print the underlying Jax Array ``` `torchax.tensor.Tensor` is a `torch.Tensor` subclass that holds a `jax.Array`. You can inspect that JAX array with `res.jax()`. +In other words, despite that the code above looks like PyTorch, it is actually running JAX! + ## What is happening behind the scene We took the approach detailed in the diff --git a/torchax/dev-requirements.txt b/torchax/dev-requirements.txt index 7c0020e5156e..2da02ae8599b 100644 --- a/torchax/dev-requirements.txt +++ b/torchax/dev-requirements.txt @@ -1,5 +1,5 @@ -f https://download.pytorch.org/whl/torch -torch==2.7.1 ; sys_platform == 'darwin' # macOS -torch==2.7.1+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU +torch==2.8.0 ; sys_platform == 'darwin' # macOS +torch==2.8.0+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml` flax==0.10.6 diff --git a/torchax/examples/_diffusion.py b/torchax/examples/_diffusion.py deleted file mode 100644 index 9f7578056b06..000000000000 --- a/torchax/examples/_diffusion.py +++ /dev/null @@ -1,106 +0,0 @@ -import functools - -import torch -from time import time -from diffusers import DiffusionPipeline -from torch.utils import _pytree as pytree - -import torchax -import torchax.functions -from torchax.extra import torch_view, jax_view - -import jax -import torch.func - - -class CompiledModule: - - def __init__(self, model): - weights = model.state_dict() - weights.update(model.named_parameters()) - self._weights = pytree.tree_map_only(torch.Tensor, - torchax.tensor.move_to_device, weights) - self._model = model - - self._func_jitted_torch = None #torch_view(func_mod_jitted) - - def _maybe_move_tensor(self, tensor): - if isinstance( - tensor, torch.Tensor) and not isinstance(tensor, torchax.tensor.Tensor): - return torchax.tensor.move_to_device(tensor) - return tensor - - def _make_jitted(self, args, kwargs): - static = [] - for i, a in enumerate(args): - if not isinstance(a, torch.Tensor): - static.append(i + 1) # weight is 0 - static_argnames = [] - for k, v in kwargs.items(): - if not isinstance(v, torch.Tensor): - static_argnames.append(k) - - def f(weights, *args, **kwargs): - weights, args, kwargs = torchax.tensor.wrap((weights, args, kwargs)) - with torchax.functions.XLAFunctionMode(), torchax.tensor.XLADispatchMode( - ): - res = torch.func.functional_call(self._model, weights, args, kwargs) - if isinstance(res, tuple) and len(res) == 1: - res = res[0] - return torchax.tensor.unwrap(res) - - fjit = jax.jit(f, static_argnames=tuple(static_argnames)) - return torch_view(fjit) - - def forward(self, *args, **kwargs): - (args, kwargs) = pytree.tree_map(self._maybe_move_tensor, (args, kwargs)) - if self._func_jitted_torch is None: - self._func_jitted_torch = self._make_jitted(args, kwargs) - return self._func_jitted_torch(self._weights, *args, **kwargs) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def __getattr__(self, key): - return getattr(self._model, key) - - -def compile_pipe(pipe): - pipe.text_encoder = CompiledModule(pipe.text_encoder) - pipe.text_encoder_2 = CompiledModule(pipe.text_encoder_2) - pipe.unet = CompiledModule(pipe.unet) - pipe.vae = CompiledModule(pipe.vae) - - -def main(): - pipe = DiffusionPipeline.from_pretrained( - # "stabilityai/stable-diffusion-xl-base-0.9", - "stabilityai/stable-diffusion-xl-base-1.0", - use_safetensors=True, - ) - compile_pipe(pipe) - - global_bs = 10 - inference_steps = 20 - resol = 1024 - prompts = ["a photo of an astronaut riding a horse on mars"] * global_bs - print( - f'global batch size {global_bs}', - f'inference steps {inference_steps}', - f'Image resolution {resol}', - flush=True) - - iters = 5 - for i in range(iters): - prompt = prompts - # print('per device prompts len',len(prompt)) - # prompt = prompts[rank] - start = time() - image = pipe( - prompt, num_inference_steps=inference_steps, height=resol, - width=resol).images[0] - print(f'Step {i} inference time {time()-start} sec', flush=True) - - -if __name__ == '__main__': - main() diff --git a/torchax/examples/_grad_of_attention.py b/torchax/examples/_grad_of_attention.py deleted file mode 100644 index 8a8882720837..000000000000 --- a/torchax/examples/_grad_of_attention.py +++ /dev/null @@ -1,76 +0,0 @@ -import jax.numpy as jnp -import jax -from jax.experimental.pallas.ops.tpu import flash_attention - -import torchax -from jax.experimental import mesh_utils -from torchax.ops.jtorch import _tpu_flash_attention - -env = torchax.default_env() -jax.config.update('jax_enable_x64', False) -env._mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh((4,)), - axis_names=("fsdp",), -) -env.use_flash_attention = True - -from torch.nn import functional as F - - -def attn(q, k, v): - q, k, v = env.j2t_iso((q, k, v)) - with env: - x = F.scaled_dot_product_attention(q, k, v, is_causal=True) - x = env.t2j_iso(x) - return jnp.sum(x) - - -import torch - - -class M(torch.nn.Module): - - def __init__(self): - super().__init__() - self.a = torch.nn.Linear(10, 10) - - def forward(self, x): - return self.a(x) - - -m = M() -from torchax.interop import JittableModule - -mjit = JittableModule(m) - -from torch.nn.utils import stateless - - -def f(weights, x): - res = mjit.functional_call('forward', weights, {}, (x,)) - return torch.sum(res) - - -def crossent(x, y): - x, y = env.j2t_iso((x, y)) - res = torch.func.functional_call(m, x, (y,)) - return env.t2j_iso(res) - - -graded = jax.value_and_grad(attn) - -shape = (4, 32, 128, 32) -q = jnp.ones(shape, dtype='bfloat16') -v = jnp.ones(shape, dtype='bfloat16') -k = jnp.ones(shape, dtype='bfloat16') - -env = torchax.default_env() -weights = env.t2j_iso(env.to_xla(mjit.params)) - -from torchax.interop import jax_view - -#print(jax.jit(graded).lower(q, v, k).as_text()) -print( - jax.jit(jax.grad(jax_view(f))).lower(weights, - jax.ShapeDtypeStruct( - (10,), 'float32')).as_text()) diff --git a/torchax/examples/torchbench_models/BERT_pytorch.py b/torchax/examples/torchbench_models/BERT_pytorch.py deleted file mode 100644 index 79ba47b5eaa7..000000000000 --- a/torchax/examples/torchbench_models/BERT_pytorch.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -import time -import torchax -import torchax.interop -import os -import importlib -import sys -import logging -import sys - -root = logging.getLogger() -root.setLevel(logging.DEBUG) - -handler = logging.StreamHandler(sys.stdout) -handler.setLevel(logging.DEBUG) -formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') -handler.setFormatter(formatter) -root.addHandler(handler) - -# NOTE: replace this patch below with your installation -TORCH_BENCH_PATH = os.path.expanduser('~/git/qihqi/benchmark') -# If your directory looks like this_file.py, benchmark/ -sys.path.append(TORCH_BENCH_PATH) -model_name = "torchbenchmark.models.BERT_pytorch" # replace this by the name of the model you're working on -module = importlib.import_module(model_name) -benchmark_cls = getattr(module, "Model", None) -benchmark = benchmark_cls( - test="eval", device="cpu") # test = train or eval device = cuda or cpu - -model, example = benchmark.get_module() - -env = torchax.default_env() -env.config.debug_print_each_op = False -model = env.to_xla(model) -example = env.to_xla(example) -with env: - start = time.perf_counter() - print(model(*example)) - end = time.perf_counter() - print('Eager mode time', end - start) - - -def func_call(state, example): - return torch.func.functional_call(model, state, example, tie_weights=False) - - -jitted = torchax.interop.jax_jit(func_call) -start = time.perf_counter() -print(func_call(model.state_dict(), example)) -end = time.perf_counter() -print('Jitted mode time', end - start) diff --git a/torchax/examples/train_gpt/requirements.txt b/torchax/examples/train_gpt/requirements.txt deleted file mode 100644 index d302f474acfa..000000000000 --- a/torchax/examples/train_gpt/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -tqdm -git+https://github.com/karpathy/minGPT.git@master -datasets -tiktoken diff --git a/torchax/pyproject.toml b/torchax/pyproject.toml index 9407829b76ed..2f30f30e7c68 100644 --- a/torchax/pyproject.toml +++ b/torchax/pyproject.toml @@ -48,6 +48,3 @@ odml = ["jax[cpu]>=0.6.2", "jax[cpu]"] [tool.hatch.build.targets.wheel] packages = ["torchax"] - -[tool.pytest.ini_options] -addopts="-n auto" diff --git a/torchax/test-requirements.txt b/torchax/test-requirements.txt index c64af1807b7b..854a216a074c 100644 --- a/torchax/test-requirements.txt +++ b/torchax/test-requirements.txt @@ -2,9 +2,6 @@ absl-py==2.2.2 immutabledict==4.2.1 pytest==8.3.5 -pytest-xdist==3.6.1 -pytest-forked==1.6.0 -sentencepiece==0.2.0 +sentencepiece expecttest==0.3.0 optax==0.2.4 -tensorflow==2.19.0 diff --git a/torchax/test/test_misc.py b/torchax/test/test_misc.py index b93877a7fd64..c6ce5022aaae 100644 --- a/torchax/test/test_misc.py +++ b/torchax/test/test_misc.py @@ -32,7 +32,6 @@ def forward(self, a, b): def test_to_device(self): env = torchax.default_env() - env.config.debug_print_each_op = True with env: step1 = torch.ones( 100, @@ -42,6 +41,19 @@ def test_to_device(self): step3 = step2.to(dtype=torch.bool, device='jax') self.assertEqual(step3.device.type, 'jax') + def test_to_device_twice(self): + env = torchax.default_env() + env.config.debug_print_each_op = True + with env: + step1 = torch.ones( + 100, + 100, + ) + step2 = torch.triu(step1, diagonal=1) + step3 = step2.to(dtype=torch.bool, device='jax') + step3.to('jax') + self.assertEqual(step3.device.type, 'jax') + if __name__ == '__main__': unittest.main() diff --git a/torchax/test/test_tf_integration.py b/torchax/test/test_tf_integration.py deleted file mode 100644 index 35e58a6c5b0f..000000000000 --- a/torchax/test/test_tf_integration.py +++ /dev/null @@ -1,51 +0,0 @@ -import os -import tempfile -import numpy as np -import tensorflow as tf -import torch -import torch.nn.functional as F -import torchax - -from torchax import tf_integration -from . import base_test_util - - -class Interpolate(torch.nn.Module): - - def forward(self, masks: torch.Tensor) -> torch.Tensor: - masks = F.interpolate( - masks, - size=(500, 500), - mode="bilinear", - align_corners=False, - ) - return masks - - -class TfIntegrationTest(base_test_util.TestCase): - - def setUp(self): - torch.manual_seed(0) - torchax.enable_accuracy_mode() - - def test_interpolate(self): - """Simple model roundtripped through TF savedmodel""" - - # Create model - arg = (torch.randn(3, 3, 200, 200),) - pt_model = Interpolate() - - # Export to SavedModel - with tempfile.TemporaryDirectory() as tempdir: - sm_path = os.path.join(tempdir, "interpolate.savedmodel") - tf_integration.save_torch_module_as_tf_saved_model(pt_model, arg, sm_path) - - # Reload SM and compare results with PT results - loaded_model = tf.saved_model.load(sm_path) - pt_res = pt_model(*arg) - tf_res = torch.tensor(loaded_model.f(*arg)[0].numpy()) - self.assertTrue(torch.allclose(pt_res, tf_res, atol=1e-4)) - - -if __name__ == "__main__": - base_test_util.main() diff --git a/torchax/test_dist/test_to_device.py b/torchax/test_dist/test_to_device.py new file mode 100644 index 000000000000..12c0b94bd2c7 --- /dev/null +++ b/torchax/test_dist/test_to_device.py @@ -0,0 +1,26 @@ +import jax +import torch +import torchax +import unittest + +from jax.sharding import NamedSharding, PartitionSpec as P + +class ToDeviceTest(unittest.TestCase): + + def test_to_device_twice(self): + env = torchax.default_env() + mesh = jax.make_mesh((jax.device_count(), ), ('axis', )) + with env: + step1 = torch.ones( + 100, + 100, + ) + step2 = torch.triu(step1, diagonal=1) + step3 = step2.to(dtype=torch.bool, device='jax') + step3.apply_jax_(jax.device_put, NamedSharding(mesh, P())) + print(step3.to('jax')) + self.assertEqual(step3.device.type, 'jax') + + +if __name__ == '__main__': + unittest.main() diff --git a/torchax/torchax/CONTRIBUTING.md b/torchax/torchax/CONTRIBUTING.md index c61462850652..f908cd2e59bb 100644 --- a/torchax/torchax/CONTRIBUTING.md +++ b/torchax/torchax/CONTRIBUTING.md @@ -1,9 +1,7 @@ -# Contributing to TorchXLA2 +# Contributing to torchax We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from good first issue and help wanted labels. -If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of. - # Developer setup @@ -19,9 +17,17 @@ conda activate pip install --upgrade "jax[cpu]" torch pip install -r test_requirements.txt pip install -e . -pytest test +pip install pytest-xdist # recommended for running test faster +pytest -n auto test ``` +## Setup on GPU or TPU + +Same as Mac setup, except, if you run test using pytest, please also +add `JAX_PLATFORMS=cpu`. The reason is because pytest usually runs +test in multiple threads. CPU device can be accessed concurrently where +TPU devices usually only allow one accesor per process; so it could deadlock. + ### VSCode I use vscode on my Mac. I loosely followed instruction in @@ -35,4 +41,3 @@ The plugins I installed (a subset of the ones listed above) are: I also changed Python interpreter to point at the one in my conda env. That is all the changes I have. - diff --git a/torchax/torchax/tf_integration.py b/torchax/torchax/tf_integration.py deleted file mode 100644 index c9842089bfcf..000000000000 --- a/torchax/torchax/tf_integration.py +++ /dev/null @@ -1,119 +0,0 @@ -# pylint: disable -import os -from typing import Any, Tuple - -from jax.experimental import jax2tf -import tensorflow as tf -import torch -from torchax import export - - -def exported_program_to_tf_function(ep, enable_xla=True): - weights, jax_program = export.exported_program_to_jax(ep) - wrapped = lambda *args: jax_program(weights, (args,)) - avals = export.extract_avals(ep) - input_signature = [ - tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=f"args_{i}") - for i, t in enumerate(avals) - ] - tf_f = tf.function( - jax2tf.convert( - wrapped, - with_gradient=False, - enable_xla=enable_xla, - ), - autograph=False, - input_signature=input_signature, - ) - return tf_f - - -def exported_program_to_tf_module(ep: torch.export.ExportedProgram, - enable_xla=True) -> tf.Module: - tfm = tf.Module() - tfm.f = exported_program_to_tf_function(ep, enable_xla) - return tfm - - -def save_exported_program_as_tf_saved_model( - ep: torch.export.ExportedProgram, - saved_model_dir: os.PathLike, - serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - function_alias: str = "", - enable_xla=True, -): - """This function will export and save a pytorch ExportedProgram to tf.saved_model format. - - The resulting tf.saved_model can be used inference using tf.serving model - server - or further convert to tflite flatbuffer for on-device serving. - - Args: - torch_model: torch.nn.Module - model to export and save - args: Tuple[Any] - a set of args to trace the model with, i.e. - torch_model(*args) must run - saved_model_dir: os.PathLike - location to an empty directory to store the - saved_model - serving_key: str - serving key tag, this is used by tf.serving to know - which function to run. - function_alias: str - passed through saved_model.save, used to tag a - function for inference converter or other tools. - """ - tfm = exported_program_to_tf_module(ep, enable_xla=enable_xla) - signatures = { - serving_key: tfm.f.get_concrete_function(*tfm.f.input_signature) - } - save_options = tf.saved_model.SaveOptions(function_aliases={ - function_alias: tfm.f, - }) - tf.saved_model.save( - tfm, - saved_model_dir, - signatures=signatures, - options=save_options, - ) - - -def save_torch_module_as_tf_saved_model( - torch_model: torch.nn.Module, - args: Tuple[Any], - saved_model_dir: os.PathLike, - serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - function_alias: str = "", - enable_xla=True, -): - """This function will export and save a pytorch nn.Module to tf.saved_model format. - - The resulting tf.saved_model can be used inference using tf.serving model - server - or further convert to tflite flatbuffer for on-device serving. - - Args: - torch_model: torch.nn.Module - model to export and save - args: Tuple[Any] - a set of args to trace the model with, i.e. - torch_model(*args) must run - saved_model_dir: os.PathLike - location to an empty directory to store the - saved_model - serving_key: str - serving key tag, this is used by tf.serving to know - which function to run. - function_alias: str - passed through saved_model.save, used to tag a - function for inference converter or other tools. - """ - ep = torch.export.export(torch_model, args) - save_exported_program_as_tf_saved_model(ep, saved_model_dir, serving_key, - function_alias, enable_xla) - - -def exported_program_to_tflite_flatbuffer(ep: torch.export.ExportedProgram): - tfm = exported_program_to_tf_module(ep) - tf_concrete_func = tfm.f.get_concrete_function(*tfm.f.input_signature) - converter = tf.lite.TFLiteConverter.from_concrete_functions( - [tf_concrete_func], tfm) - tflite_model = converter.convert() - return tflite_model - - -def torch_module_to_tflite_flatbuffer(torch_model: torch.nn.Module, - args: Tuple[Any]): - ep = torch.export.export(torch_model, args) - return exported_program_to_tflite_flatbuffer(ep) From 776cfe08e9791b4721eee8d7927836889f49d46a Mon Sep 17 00:00:00 2001 From: Han Qi Date: Wed, 3 Sep 2025 20:06:12 +0000 Subject: [PATCH 2/3] restore pyproject --- pyproject.toml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 88b68d55cde8..6d76b2a22e99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,8 +46,3 @@ dynamic = [ Homepage = "https://github.com/pytorch/xla" Repository = "https://github.com/pytorch/xla" "Bug Tracker" = "https://github.com/pytorch/xla/issues" - -[tool.uv.workspace] -members = [ - "torchax2", -] From c406bd7cc658bf8d679022eff9030f4c62ca0f83 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Wed, 3 Sep 2025 21:02:21 +0000 Subject: [PATCH 3/3] lint --- torchax/test-requirements.txt | 2 ++ torchax/test/test_misc.py | 22 +++++++++++----------- torchax/test_dist/test_to_device.py | 27 ++++++++++++++------------- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/torchax/test-requirements.txt b/torchax/test-requirements.txt index 854a216a074c..677912bbd04d 100644 --- a/torchax/test-requirements.txt +++ b/torchax/test-requirements.txt @@ -5,3 +5,5 @@ pytest==8.3.5 sentencepiece expecttest==0.3.0 optax==0.2.4 +pytest +pytest-xdist diff --git a/torchax/test/test_misc.py b/torchax/test/test_misc.py index c6ce5022aaae..9214c5b1eac6 100644 --- a/torchax/test/test_misc.py +++ b/torchax/test/test_misc.py @@ -42,17 +42,17 @@ def test_to_device(self): self.assertEqual(step3.device.type, 'jax') def test_to_device_twice(self): - env = torchax.default_env() - env.config.debug_print_each_op = True - with env: - step1 = torch.ones( - 100, - 100, - ) - step2 = torch.triu(step1, diagonal=1) - step3 = step2.to(dtype=torch.bool, device='jax') - step3.to('jax') - self.assertEqual(step3.device.type, 'jax') + env = torchax.default_env() + env.config.debug_print_each_op = True + with env: + step1 = torch.ones( + 100, + 100, + ) + step2 = torch.triu(step1, diagonal=1) + step3 = step2.to(dtype=torch.bool, device='jax') + step3.to('jax') + self.assertEqual(step3.device.type, 'jax') if __name__ == '__main__': diff --git a/torchax/test_dist/test_to_device.py b/torchax/test_dist/test_to_device.py index 12c0b94bd2c7..78794fad704e 100644 --- a/torchax/test_dist/test_to_device.py +++ b/torchax/test_dist/test_to_device.py @@ -5,22 +5,23 @@ from jax.sharding import NamedSharding, PartitionSpec as P + class ToDeviceTest(unittest.TestCase): def test_to_device_twice(self): - env = torchax.default_env() - mesh = jax.make_mesh((jax.device_count(), ), ('axis', )) - with env: - step1 = torch.ones( - 100, - 100, - ) - step2 = torch.triu(step1, diagonal=1) - step3 = step2.to(dtype=torch.bool, device='jax') - step3.apply_jax_(jax.device_put, NamedSharding(mesh, P())) - print(step3.to('jax')) - self.assertEqual(step3.device.type, 'jax') + env = torchax.default_env() + mesh = jax.make_mesh((jax.device_count(),), ('axis',)) + with env: + step1 = torch.ones( + 100, + 100, + ) + step2 = torch.triu(step1, diagonal=1) + step3 = step2.to(dtype=torch.bool, device='jax') + step3.apply_jax_(jax.device_put, NamedSharding(mesh, P())) + print(step3.to('jax')) + self.assertEqual(step3.device.type, 'jax') if __name__ == '__main__': - unittest.main() + unittest.main()