Commit b9d567a

wbmc and mochen.bmc authored
fix: missing import numpy (#5533)
* missing

* fix typo

---------

Co-authored-by: mochen.bmc <mochen.bmc@antgroup.com>
1 parent e51d28b commit b9d567a

1 file changed: +2 -1 lines changed

docs/spmd.md

Lines changed: 2 additions & 1 deletion
````diff
@@ -31,7 +31,7 @@ Also, this version of the SPMD is currently only tested.optimized on Google Clou
 ## PyTorch/XLA SPMD Design Overview
 
 
-### Simple Eexample & Sharding Aannotation API
+### Simple Example & Sharding Aannotation API
 
 Users can annotate native PyTorch tensors using the `mark_sharding` API ([src](https://github.com/pytorch/xla/blob/9a5fdf3920c18275cf7dba785193636f1b39ced9/torch_xla/experimental/xla_sharding.py#L388)). This takes `torch.Tensor` as input and returns a `XLAShardedTensor` as output.
 
@@ -42,6 +42,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, partitio
 Invoking `mark_sharding` API takes a user defined logical [mesh](#mesh) and [partition\_spec](#partition-spec) and generates a sharding annotation for the XLA compiler. The sharding spec is attached to the XLATensor. Here is a simple usage example from the [[RFC](https://github.com/pytorch/xla/issues/3871), to illustrate how the sharding annotation API works:
 
 ```python
+import numpy as np
 import torch
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
````
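For context, the snippet this commit patches comes from the `mark_sharding` walkthrough in docs/spmd.md: `mark_sharding` takes a `torch.Tensor`, a logical `Mesh`, and a `partition_spec`, and returns an `XLAShardedTensor`. Below is a minimal sketch of how the fixed imports get exercised, assuming the experimental `torch_xla.experimental.xla_sharding` module at the revision the doc links; the mesh shape and axis names are illustrative choices, not part of this commit:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

# Enable XLA SPMD execution mode (assumed available in this
# experimental revision of torch_xla).
xr.use_spmd()

# Build a logical device mesh over all addressable devices.
# np.array here is why the doc needs `import numpy as np`;
# the (num_devices, 1) shape and axis names are illustrative.
num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, (num_devices, 1), ('data', 'model'))

# mark_sharding annotates a torch.Tensor against the mesh and
# partition_spec, and returns an XLAShardedTensor.
t = torch.randn(8, 4).to(xm.xla_device())
sharded = xs.mark_sharding(t, mesh, ('data', 'model'))
```

Without the `import numpy as np` line this commit adds, the doc's snippet would raise a `NameError` at the `np.array(...)` call that constructs the device mesh.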
