PyTorch/XLA 2.1 Release

Released by @ManfeiBai on 07 Sep · 30 commits to r2.1 since this release

Cloud TPUs now support the PyTorch 2.1 release via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in the PyTorch 2.1 release, this release introduces several new features and PyTorch/XLA-specific bug fixes.

PJRT is now PyTorch/XLA's officially supported runtime! PJRT brings improved performance, superior usability, and broader device support. PyTorch/XLA r2.1 will be the last release with XRT available as a legacy runtime. Our main release build will not include XRT, but it will be available in a separate package. In most cases, we expect the migration to PJRT to require minimal changes. For more information, see our PJRT documentation.

GSPMD support has been added as an experimental feature to the PyTorch/XLA 2.1 release. GSPMD transforms a single-device program into a partitioned one with the proper collectives, based on user-provided sharding hints. This feature allows developers to write PyTorch programs as if they target a single large device, without any custom sharded computation ops or collective communication calls to scale. We published a blog post explaining the technical details and expected usage; you can also find more detail in this user guide. A minimal usage sketch follows.
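
As an illustration only: the sketch below annotates a tensor with a sharding hint using the 2.1 experimental namespace (torch_xla.experimental.xla_sharding). Consult the user guide for the authoritative API, since module paths and the SPMD enablement mechanism may differ between releases.

    # Minimal GSPMD sketch: shard a tensor's first dimension across all devices.
    import os
    os.environ["XLA_USE_SPMD"] = "1"  # enable SPMD execution mode before any XLA work

    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.experimental.xla_sharding as xs
    from torch_xla.experimental.xla_sharding import Mesh

    # A 2D logical mesh: all devices on the first axis, one slot on the second.
    num_devices = xr.global_runtime_device_count()
    mesh = Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

    # Dim 0 is sharded over mesh axis 0 ('data'); dim 1 is replicated.
    t = torch.randn(8, 4).to(xm.xla_device())
    xs.mark_sharding(t, mesh, (0, None))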

PyTorch/XLA has transitioned from depending on TensorFlow to depending on the new OpenXLA repo. This allows us to reduce our binary size and simplify our build system. Starting with 2.1, PyTorch/XLA releases its TPU wheels on PyPI.

To install PyTorch/XLA 2.1.0 wheels, please find the installation instructions below.

Installing PyTorch and PyTorch/XLA 2.1.0 wheel:

pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html

Please note that you might have to re-install libtpu on your TPU VM, depending on your previous installation:

pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
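
After installation, a quick sanity check like the sketch below (an illustrative snippet, assuming a TPU VM; set PJRT_DEVICE accordingly on other hardware) should create a tensor on the XLA device under the PJRT runtime:

    # Create a small tensor on the XLA device to confirm the runtime is working.
    import os
    os.environ.setdefault("PJRT_DEVICE", "TPU")  # PJRT backend selection

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    t = torch.randn(2, 2, device=device)
    print(t.device, t + t)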

Stable Features

OpenXLA

  • Migrate from pulling XLA via TensorFlow to pulling it from OpenXLA; the TensorFlow pin dependency is sunset (#5202)
  • Instructions to build PyTorch/XLA with OpenXLA can be found in this doc.

PJRT Runtime

  • Move PJRT APIs from experimental to torch_xla.runtime (#5011)
  • Enable PJRT C API Client and other changes for Neuron (#5428)
  • Enable PJRT C API Client for Intel XPU (#4891)
  • Change the pjrt:// init method to xla:// (#5560); see the usage sketch after this list
  • Make TPU detection more robust (#5271)
  • Add runtime.host_index (#5283)
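
For example, initializing torch.distributed with the renamed init method might look like this sketch (illustrative only, assuming a multi-device PJRT environment; the helper name _mp_fn is arbitrary):

    # Spawn one process per device and join the 'xla' process group via xla://.
    import torch.distributed as dist
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_backend  # registers the 'xla' process group backend
    import torch_xla.distributed.xla_multiprocessing as xmp

    def _mp_fn(index):
        dist.init_process_group("xla", init_method="xla://")
        print(f"rank {dist.get_rank()} of {dist.get_world_size()} on {xm.xla_device()}")

    if __name__ == "__main__":
        xmp.spawn(_mp_fn)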

Functionalization

  • Functionalization integration (#4158)
  • Add support for XLA_DISABLE_FUNCTIONALIZATION flag (#4792)

Improvements and Additions

  • Op Lowering
    • squeeze_copy.dims (#5286)
    • native_dropout (#5643)
    • native_dropout_backward (#5642)
    • count_nonzero (#5137)
  • Build System
    • Migrate the build system to Bazel (#4528)

Beta Features

AMP (Automatic Mixed Precision)

  • Added bfloat16 support on TPUs (#5161)
  • Documentation can be found in amp.md; a short usage sketch follows this list
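
A minimal sketch of the pattern described in amp.md (assuming the torch_xla.amp.autocast entry point; exact arguments may vary by release):

    # Run the forward pass in bfloat16 autocast on a TPU.
    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.amp import autocast

    device = xm.xla_device()
    model = torch.nn.Linear(128, 10).to(device)
    data = torch.randn(8, 128, device=device)

    # Ops inside the context run in bfloat16 where considered safe; the rest stay in float32.
    with autocast(device):
        out = model(data)
    xm.mark_step()  # materialize the lazily traced graph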

TorchDynamo

  • Support CPU eager fallback in the Dynamo bridge (#5000)
  • Support torch.compile with SPMD for inference (#5002)
  • Update the Dynamo backend names to openxla and openxla_eval (#5402); a usage sketch follows this list
  • Inference optimization for SPMD inference + torch.compile (#5447, #5446)
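
An illustrative inference sketch with the renamed backends (openxla_eval is the inference-only variant; the model and shapes here are placeholders):

    # Compile a model for inference through the XLA Dynamo backend.
    import torch
    import torch_xla.core.xla_model as xm  # importing torch_xla registers the backends

    device = xm.xla_device()
    model = torch.nn.Linear(64, 64).to(device).eval()
    x = torch.randn(4, 64, device=device)

    compiled = torch.compile(model, backend="openxla_eval")
    with torch.no_grad():
        y = compiled(x)
    print(y.shape)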

Traceable Collectives

  • Adopt traceable all_reduce (#4915)
  • Make xm.all_gather a single graph in Dynamo (#4922)

Experimental Features

GSPMD

  • Add SPMD user guide
  • Enable Input-output aliasing (#5320)
  • Introduce global_runtime_device_count to query the runtime device count (#5129)
  • Support partial replication (#5411)
  • Support tuple partition spec (#5488)
  • Support mark_sharding on IRs (#5301)
  • Make IR sharding custom sharding op (#5433)
  • Introduce Hybrid Device mesh creation (#5147)
  • Introduce SPMD-friendly patched nn.Linear (#5491)
  • Allow dumping post-optimization HLO (#5302)
  • Allow sharding n-d tensor on (n+1)-d Mesh (#5268)
  • Support synchronous distributed checkpointing (#5130, #5170)

Serving Support

StableHLO

  • Add StableHLO user guide (#5523)
  • Add save_as_stablehlo and save_torch_model_as_stablehlo APIs (#5493); a usage sketch follows this list
  • Make StableHLO executable (#5476)
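
As an illustration (the output path and argument order here are assumptions; see the StableHLO user guide for the exact signature):

    # Export a simple model to StableHLO artifacts on disk.
    import torch
    from torch_xla.stablehlo import save_torch_model_as_stablehlo

    model = torch.nn.Linear(10, 10).eval()
    sample_args = (torch.randn(2, 10),)

    # Traces the model and writes the StableHLO program to the given directory.
    save_torch_model_as_stablehlo(model, sample_args, "/tmp/linear_stablehlo")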

Ongoing Development

TorchDynamo

  • Enable single-step graph for training
  • Avoid inter-graph reshapes from aot_autograd
  • Support GSPMD for activation checkpointing

GSPMD

  • Support auto-sharding
  • Benchmark and improve GSPMD for XLA:GPU
  • Integrate with PyTorch’s Distributed Tensor API

GPU

  • Support Multi-host GPU for PJRT runtime
  • Improve performance on torchbench models

Quantization

  • Support PyTorch PT2E quantization workflow

Bug Fixes and Improvements

  • Fix unexpected Dynamo crash due to clear_pending_ir call (#5582)
  • Fix FSDP for models with frozen weights (#5484)
  • Fix data type in Pow with Scalar base and Tensor exponent (#5467)
  • Fix the in-place op crash when applied to self tensors in Dynamo (#5309)