[tp] improve documentation #115880

Closed
wants to merge 3 commits
2 changes: 1 addition & 1 deletion docs/source/distributed.tensor.parallel.rst
@@ -32,7 +32,7 @@ Tensor Parallelism supports the following parallel styles:
To simply configure the nn.Module's inputs and outputs with DTensor layouts
and perform necessary layout redistributions, without distributing the module
parameters to DTensors, the following classes can be used in
the ``parallelize_plan``of ``parallelize_module``:
the ``parallelize_plan`` of ``parallelize_module``:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInput
:members:
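For context on the paragraph above (not part of this diff), here is a minimal, hypothetical sketch of using these classes inside a ``parallelize_plan`` to annotate a submodule's inputs and outputs without distributing its parameters. The module FQNs ("attn", "mlp"), the chosen layouts, and the mesh setup are assumptions for illustration only.

```python
# Hypothetical sketch: PrepareModuleInput / PrepareModuleOutput configure DTensor
# layouts for a submodule's inputs/outputs without sharding its parameters.
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    PrepareModuleInput,
    PrepareModuleOutput,
    parallelize_module,
)

tp_mesh = init_device_mesh("cuda", (8,))  # 1-D tensor-parallel mesh (assumes world size 8)
plan = {
    # Convert the input of "attn" to a DTensor and redistribute it to be replicated.
    "attn": PrepareModuleInput(
        input_layouts=Shard(0),
        desired_input_layouts=Replicate(),
    ),
    # Convert the output of "mlp" to a DTensor and redistribute it to be replicated.
    "mlp": PrepareModuleOutput(
        output_layouts=Shard(-1),
        desired_output_layouts=Replicate(),
    ),
}
# model = parallelize_module(model, tp_mesh, plan)  # assumes `model` has "attn"/"mlp" submodules
```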
18 changes: 9 additions & 9 deletions torch/distributed/tensor/parallel/api.py
@@ -36,8 +36,9 @@ def parallelize_module( # type: ignore[return]
to be parallelized.

Users can also specify a different parallel style per module fully qualified name (FQN).
The API supports 2D parallelism natively by accepting an n-dimension device_mesh
and users just need to specify the dimension where we perform tensor parallelism on.

Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D or N-D :class:`DeviceMesh`,
slice it into a 1-D sub-DeviceMesh first and then pass that to this API (e.g. ``device_mesh[\"tp\"]``).

Args:
module (:class:`nn.Module`):
@@ -50,9 +51,10 @@ def parallelize_module( # type: ignore[return]
:class:`ParallelStyle` object which contains how
we prepare input/output for Tensor Parallelism or it can be a
dict of module FQN and its corresponding :class:`ParallelStyle` object.
tp_mesh_dim (int):
tp_mesh_dim (int, deprecated):
The dimension of ``device_mesh`` where we perform
Tensor Parallelism on.
Tensor Parallelism on. This field is deprecated and will be removed in the future.
If you have a 2-D or N-D :class:`DeviceMesh`, consider passing in ``device_mesh[\"tp\"]`` instead.

Return:
A :class:`nn.Module` object parallelized.
@@ -66,11 +68,9 @@ def parallelize_module( # type: ignore[return]
>>> m = parallelize_module(m, ColwiseParallel())
>>>

.. warning::
Currently, there are some constraints which makes it hard for complicated modules
like ``MultiheadAttention`` to work out of box for Tensor or Sequence Parallelism.
We recommend users to try ``ColwiseParallel`` and ``RowwiseParallel`` for each parameter
or submodule and there might be some code changes needed now.
.. note:: For complex module architectures like Attention and MLP layers, we recommend composing
different ``ParallelStyle`` objects together (e.g. ``ColwiseParallel`` and ``RowwiseParallel``) and passing
them as a ``parallelize_plan`` to achieve the desired sharding computation.
"""
torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

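For context (not part of this diff), a minimal sketch of the composition the new note recommends, using a toy two-layer MLP. The module, FQNs ("w1", "w2"), and mesh setup are assumptions, not taken from the PR.

```python
# Hypothetical sketch: compose ColwiseParallel and RowwiseParallel in a
# parallelize_plan dict keyed by module FQN, as the updated docstring suggests.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(16, 64)
        self.w2 = nn.Linear(64, 16)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))

# Assumes torchrun (or equivalent) already initialized the default process group.
tp_mesh = init_device_mesh("cuda", (8,))  # 1-D mesh, as the API now requires
model = parallelize_module(
    ToyMLP().cuda(),
    tp_mesh,
    parallelize_plan={
        "w1": ColwiseParallel(),  # column-sharded; output stays sharded on the last dim
        "w2": RowwiseParallel(),  # row-sharded; consumes the Shard(-1) input, output replicated
    },
)
```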
34 changes: 19 additions & 15 deletions torch/distributed/tensor/parallel/style.py
@@ -43,9 +43,9 @@ class ColwiseParallel(ParallelStyle):
The DTensor layout of the output of the nn.Module, used to ensure that the output of the nn.Module
has the user-desired layout. If not specified, the output tensor is sharded on the last dimension.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of `DTensor` for the module output, default: True.
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
Returns:
A ParallelStyle object that represents Colwise sharding of the nn.Module.
A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module.

Example::
>>> # xdoctest: +SKIP(failing)
@@ -60,6 +60,10 @@ class ColwiseParallel(ParallelStyle):
>>> parallelize_plan={"w1": ColwiseParallel()},
>>> )
>>> ...

.. note:: By default, the ``ColwiseParallel`` output is sharded on the last dimension if ``output_layouts`` is not
specified. If there are operators that require a specific tensor shape (e.g. before the paired ``RowwiseParallel``),
keep in mind that if the output is sharded, the operator might need to be adjusted to the sharded size.
"""

def __init__(
@@ -68,7 +72,7 @@ def __init__(
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
use_local_output: bool = True
) -> None:
):
super().__init__()
self.input_layouts = (input_layouts or Replicate(), )
self.output_layouts = (output_layouts or Shard(-1), )
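The note added above points out that the default output placement is ``Shard(-1)``. As a small sketch (an assumption for illustration, not part of this diff), ``output_layouts`` can be overridden when a downstream, shape-sensitive op needs the full tensor:

```python
# Hypothetical sketch: keep the default Shard(-1) output, or ask ColwiseParallel
# to replicate its output so shape-sensitive ops downstream see the full width.
from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel

colwise_default = ColwiseParallel()                               # output sharded on last dim
colwise_replicated = ColwiseParallel(output_layouts=Replicate())  # output gathered/replicated
```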
@@ -146,9 +150,9 @@ class RowwiseParallel(ParallelStyle):
The DTensor layout of the output of the nn.Module, used to ensure that the output of the nn.Module
has the user-desired layout. If not specified, the output tensor is replicated.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of `DTensor` for the module output, default: True.
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
Returns:
A ParallelStyle object that represents Rowwise sharding of the nn.Module.
A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.

Example::
>>> # xdoctest: +SKIP(failing)
@@ -171,7 +175,7 @@ def __init__(
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
use_local_output: bool = True
) -> None:
):
super().__init__()
self.input_layouts = (input_layouts or Shard(-1), )
self.output_layouts = (output_layouts or Replicate(), )
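Similarly for ``RowwiseParallel`` (a hedged sketch, not part of this diff): the default expects a ``Shard(-1)`` input, e.g. the output of a preceding ``ColwiseParallel``, while ``input_layouts`` can declare a replicated input and let the style redistribute it at runtime:

```python
# Hypothetical sketch: RowwiseParallel defaults (input Shard(-1), output Replicate),
# plus declaring a replicated input for a standalone row-wise layer.
from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import RowwiseParallel

rowwise_paired = RowwiseParallel()                               # expects a Shard(-1) input
rowwise_standalone = RowwiseParallel(input_layouts=Replicate())  # input arrives replicated
```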
@@ -225,20 +229,20 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
class PrepareModuleInput(ParallelStyle):
"""
Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to
input_layouts, and perform layout redistribution according to the desired_input_layouts.
``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``.

Keyword Args:
input_layouts (Union[Placement, Tuple[Placement]]):
The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to
DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, `None` need to be specified
DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, ``None`` needs to be specified
as a placeholder.
desired_input_layouts (Union[Placement, Tuple[Placement]]):
The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module
have the desired DTensor layouts. This argument needs to have the same length with `input_layouts`.
have the desired DTensor layouts. This argument needs to have the same length as ``input_layouts``.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of `DTensor` for the module inputs, default: False.
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False.
Returns:
A ParallelStyle object that prepares the sharding layouts of the nn.Module's inputs.
A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs.

Example::
>>> # xdoctest: +SKIP(failing)
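The full doctest above is collapsed in this diff view. As a rough sketch of the ``None`` placeholder described above (assumptions: the "attn" FQN and that its second forward argument is not a tensor):

```python
# Hypothetical sketch: per-argument layouts for a module whose forward takes
# (hidden_states, attention_mask); the non-DTensor argument gets a None placeholder.
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

prepare_attn_input = PrepareModuleInput(
    input_layouts=(Shard(0), None),             # second input is left untouched
    desired_input_layouts=(Replicate(), None),  # same length as input_layouts
)
# Used as {"attn": prepare_attn_input} inside a parallelize_plan.
```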
@@ -298,18 +302,18 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
class PrepareModuleOutput(ParallelStyle):
"""
Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to
output_layouts, and perform layout redistribution according to the desired_output_layouts.
``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``.

Keyword Args:
output_layouts (Union[Placement, Tuple[Placement]]):
The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to
DTensors if they are `torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to DTensors,
`None` need to be specified as a placeholder.
DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or do not need to be converted to DTensors,
``None`` needs to be specified as a placeholder.
desired_output_layouts (Union[Placement, Tuple[Placement]]):
The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module
have the desired DTensor layouts.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of `DTensor` for the module outputs, default: False.
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: False.
Returns:
A ParallelStyle object that prepares the sharding layouts of the nn.Module's outputs.
