You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/gpu.md
+19-7Lines changed: 19 additions & 7 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -11,14 +11,14 @@ You can either use a local machine with GPU attached or a GPU VM on the cloud. F
11
11
### Docker
12
12
Pytorch/XLA currently publish prebuilt docker images and wheels with cuda11.7/8 and python 3.8. We recommend users to create a docker container with corresponding config. For a full list of docker images and wheels, please refer to [this doc](https://github.com/pytorch/xla#available-docker-images-and-wheels).
Make sure `PATH` and `LD_LIBRARY_PATH` environment variables account for cuda. Please do a `echo $PATH` and `echo $LD_LIBRARY_PATH` to verify. If not, please follow [link](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#mandatory-actions) to do so. Example:
Make sure `PATH` and `LD_LIBRARY_PATH` environment variables account for cuda. See the [above](https://github.com/pytorch/xla/blob/master/docs/gpu.md#check-environment-variable) for more info.
Copy file name to clipboardExpand all lines: docs/spmd.md
+7-7Lines changed: 7 additions & 7 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -33,7 +33,7 @@ Also, this version of the SPMD is currently only tested.optimized on Google Clou
33
33
34
34
### Simple Example & Sharding Aannotation API
35
35
36
-
Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/9a5fdf3920c18275cf7dba785193636f1b39ced9/torch_xla/experimental/xla_sharding.py#L388)). This takes `torch.Tensor` as input and returns a `XLAShardedTensor` as output.
36
+
Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharding.py#L452)). This takes `torch.Tensor` as input and returns a `XLAShardedTensor` as output.
We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/028df4da388468fa9a41b1f98ea08bfce13b4c63/torch_xla/experimental/xla_sharding.py#L16). The axes of the logical Mesh can be named. Here is an example:
103
+
We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharding.py#L17). The axes of the logical Mesh can be named. Here is an example:
104
104
105
105
```python
106
106
import torch_xla.runtime as xr
107
-
from torch_xla.experimental.xla_shardingimport Mesh
107
+
from torch_xla.distributed.spmdimport Mesh
108
108
109
109
# Assuming you are running on a TPU host that has 8 devices attached
110
110
num_devices= xr.global_runtime_device_count()
@@ -130,7 +130,7 @@ In general, SPMD programs should create a single mesh and reuse it for all shard
130
130
Mesh nicely abstracts how the physical device mesh is constructed. Users can arrange devices inany shape and order using the logical mesh. However, one can define a more performant mesh based on the physical topology, especially when it involves Data Center Network (DCN) cross slice connections. HybridMesh creates a mesh which gives good performance out of the box for such multislice environments. It accepts ici\_mesh\_shape and dcn\_mesh\_shape which denote logical mesh shapes of inner and outer network.
131
131
132
132
```python
133
-
from torch_xla.experimental.xla_shardingimport HybridMesh
133
+
from torch_xla.distributed.spmdimport HybridMesh
134
134
135
135
# This example is assuming 2 slices of v4-8.
136
136
# - ici_mesh_shape: shape of the logical mesh for inner connected devices.
@@ -198,7 +198,7 @@ The main use case for `XLAShardedTensor` [[RFC](https://github.com/pytorch/xla/i
198
198
*`XLAShardedTensor`is a `torch.Tensor` subclass and works directly with native torch ops and`module.layers`. We use `__torch_dispatch__` to send `XLAShardedTensor` to the XLA backend. PyTorch/XLA retrieves attached sharding annotations to trace the graph and invokes XLA SPMDPartitioner.
199
199
* Internally, `XLAShardedTensor` (and its global\_tensor input) is backed by `XLATensor` with a special data structure holding references to the sharded device data.
200
200
* The sharded tensor after lazy execution may be gathered and materialized back to the host asglobal\_tensor when requested on the host (e.g., printing the value of the global tensor.
201
-
* The handles to the local shards are materialized strictly after the lazy execution. `XLAShardedTensor` exposes [local\_shards](https://github.com/pytorch/xla/blob/909f28fa4c1a44efcd21051557b3bcf2d399620d/torch_xla/experimental/xla_sharded_tensor.py#L111) to return the local shards on addressable devices as <code>List[[XLAShard](https://github.com/pytorch/xla/blob/909f28fa4c1a44efcd21051557b3bcf2d399620d/torch_xla/experimental/xla_sharded_tensor.py#L12)]</code>.
201
+
* The handles to the local shards are materialized strictly after the lazy execution. `XLAShardedTensor` exposes [local\_shards](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L117) to return the local shards on addressable devices as <code>List[[XLAShard](https://github.com/pytorch/xla/blob/4e8e5511555073ce8b6d1a436bf808c9333dcac6/torch_xla/distributed/spmd/xla_sharded_tensor.py#L12)]</code>.
202
202
203
203
There is also an ongoing effort to integrate <code>XLAShardedTensor</code> into <code>DistributedTensor</code>API to support XLA backend [[RFC](https://github.com/pytorch/pytorch/issues/92909)].
0 commit comments