# Tensor Parallelism
---

This notebook explains why tensor parallelism is needed and covers its key concepts: `Column-wise Parallel`, `Row-wise Parallel`, and `Combined Column-wise and Row-wise Parallel`. It also gives an overview of `Device Mesh` and demonstrates how to implement Tensor/Sequence Parallel together with Fully Sharded Data Parallel.

Tensor parallelism (TP) is a technique for training large models. TP splits the weight tensors of individual layers across multiple devices, so each device stores and computes only part of every layer, which reduces per-device memory. Concretely, TP splits a tensor into N chunks along a particular dimension so that each device holds only 1/N of the tensor. Each device computes a partial output from its chunk, and the partial outputs are then combined across devices (for example, with an all-gather or an all-reduce) so that the result matches the unsharded computation. The technique was first proposed in the [Megatron-LM paper](https://arxiv.org/abs/1909.08053) as an efficient model-parallel approach for training large-scale Transformer models. As models grow larger, activation memory becomes the bottleneck, so TP is typically combined with Sequence Parallel (SP), a strategy that partitions along the sequence dimension, for the LayerNorm or RMSNorm layers.


### Why Use Tensor Parallel

In PyTorch, Fully Sharded Data Parallel (FSDP) runs into scaling limits as the number of GPUs used to train a model grows. Scaling further raises challenges that can be addressed by combining Tensor Parallel with FSDP. For example:

- As the number of GPUs grows large (beyond roughly 128 or 256 GPUs), the FSDP all-gather becomes dominated by ring latency. Combining FSDP and TP so that only FSDP communicates across hosts reduces the FSDP world size by a factor of 8 (on hosts with 8 GPUs) and lowers the latency cost.
- When data parallelism hits both convergence and GPU memory limits, so the global batch size cannot be raised above the number of GPUs, applying TP/SP keeps the global batch size in a workable range while still allowing scaling to more GPUs.
- For some models, when the local batch size becomes small, TP/SP can produce matrix multiplication shapes that make better use of floating-point operations (FLOPs).
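The latency argument in the first bullet can be checked with a rough count. The sketch below assumes a hypothetical cluster of 64 hosts with 8 GPUs each and a simple ring all-gather, in which a group of N ranks needs N - 1 sequential steps, so the latency term grows with N. The numbers are illustrative, not measurements.

```python
NUM_HOSTS = 64      # hypothetical cluster size
GPUS_PER_HOST = 8   # 8-GPU hosts, as assumed in the bullet above


def ring_allgather_steps(group_size):
    """Number of sequential steps in a ring all-gather over `group_size` ranks."""
    return group_size - 1


world_size = NUM_HOSTS * GPUS_PER_HOST

# FSDP only: a single FSDP group spans every GPU in the cluster.
fsdp_only_group = world_size
# FSDP + TP: TP uses the 8 GPUs inside each host, so FSDP spans one GPU per host.
tp_size = GPUS_PER_HOST
fsdp_with_tp_group = world_size // tp_size

print(fsdp_only_group, ring_allgather_steps(fsdp_only_group))        # 512 511
print(fsdp_with_tp_group, ring_allgather_steps(fsdp_with_tp_group))  # 64 63
```

Under this simplified model, moving TP inside the host shrinks the FSDP ring by the TP degree, which is where the "reduce the world size by 8" figure comes from.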


In tensor parallelism, the computation of a linear layer can be split across GPUs. This saves memory because each GPU only needs to hold a portion of the weight matrix. PyTorch provides a set of parallel styles (`ParallelStyle`) to configure how model layers are sharded, including `ColwiseParallel`, `RowwiseParallel`, `SequenceParallel`, `PrepareModuleInput`, and `PrepareModuleOutput`.

### Column-wise Parallel

In a column-wise parallel layer, the weight matrix is split evenly along the column dimension. Every GPU receives the same input and performs a regular matrix multiplication with its portion of the weight matrix. The outputs from all GPUs are then concatenated to form the final output. For example, let `A = XB`. We can split `B` along the column dimension into `(B0 B1 B2 … Bn)`. Each device holds one column block and computes a partial result (e.g., device rank 0 holds `XB0`). To obtain the correct result, an `all-gather` is performed on the partial results to concatenate them along the column dimension.

<center><img src="images/columnwise.png" width="550px" height="550px" alt-text="column-wise parallel"/></center>
<center> <a href="https://lightning.ai/docs/pytorch/stable/_images/tp-colwise.jpeg">image source</a></center>
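To make the split concrete, here is a minimal single-process sketch of column-wise parallelism. It uses NumPy as a stand-in (PyTorch does this with `DTensor` across real GPUs), treats each list element as one device, and models the `all-gather` with `np.concatenate`. The shapes and `n_devices = 4` are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_devices = 4                     # number of simulated GPUs (assumed)
X = rng.standard_normal((3, 6))   # input, identical on every device
B = rng.standard_normal((6, 8))   # full weight matrix

# Split B evenly along the column dimension: device i holds B_i of shape (6, 2).
B_shards = np.split(B, n_devices, axis=1)

# Each device multiplies the full input by its own column block.
partials = [X @ B_i for B_i in B_shards]  # each partial result has shape (3, 2)

# all-gather: concatenate the partial results along the column dimension.
A = np.concatenate(partials, axis=1)

print(np.allclose(A, X @ B))  # True
```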

### Row-wise Parallel

This form of parallelism divides the weight matrix evenly along its rows across the available devices. Because each weight shard now has fewer rows, the input must be split the same way along its matching (column) dimension, as shown in the figure below. Each GPU then multiplies its portion of the input by its portion of the weight matrix, and an all-reduce is performed on the outputs from each GPU to form the final output. For example, let `A = XB`. We can split `B` along its rows into `(B0 B1 B2 … Bn)` and `X` along its columns into `(X0 X1 X2 … Xn)`. Each device holds one row block and multiplies it with the matching input block, producing a partial result (e.g., device rank 0 holds `X0B0`). To obtain the correct result, an `all-reduce` sum is performed on the partial results to produce the final output.

<center><img src="images/rowwise.png" width="550px" height="550px" alt-text="row-wise parallel"/></center>
<center><a href="https://lightning.ai/docs/pytorch/stable/_images/tp-rowwise.jpeg">image source</a></center>
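The row-wise case can be simulated the same way. In this NumPy sketch (again a single-process stand-in, with each list element playing one device), `B` is split along its rows, `X` along its columns so the shapes match, and the `all-reduce` sum is modeled by summing the partial results.

```python
import numpy as np

rng = np.random.default_rng(0)
n_devices = 4                     # number of simulated GPUs (assumed)
X = rng.standard_normal((3, 8))
B = rng.standard_normal((8, 5))

# Device i holds a block of rows of B, shape (2, 5) ...
B_shards = np.split(B, n_devices, axis=0)
# ... and the matching block of columns of X, shape (3, 2).
X_shards = np.split(X, n_devices, axis=1)

# Each device produces a full-shaped (3, 5) partial result X_i @ B_i.
partials = [X_i @ B_i for X_i, B_i in zip(X_shards, B_shards)]

# all-reduce (sum): every device ends up with the sum of all partial results.
A = np.sum(partials, axis=0)

print(np.allclose(A, X @ B))  # True
```

Note the contrast with column-wise parallelism: here each partial result already has the full output shape, so the results are summed rather than concatenated.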

### Combined Column-wise and Row-wise Parallel

The column-wise and row-wise parallel styles can be combined to minimize communication across consecutive linear layers, such as the MLP blocks in a Transformer. The output of the column-wise parallel layer stays sharded and is fed directly into the row-wise parallel layer, which avoids a costly data transfer between GPUs in between. A typical illustration of combining column-wise and row-wise parallel linear layers is given in the Megatron-LM paper, *[Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)*.

<center><img src="images/rowcolumnwise.png" width="550px" height="550px" alt-text="combined column-wise and row-wise parallel"/></center>
<center><a href="https://lightning.ai/docs/pytorch/stable/_images/tp-combined.jpeg">image source</a></center>
<i>Note that activation functions between the layers can still be applied without additional communication because they are element-wise; they are omitted from the figure for simplicity.</i> <br/>
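The combined pattern can be checked numerically. This NumPy sketch (a single-process stand-in with 4 simulated devices) builds a two-layer MLP, `relu(X @ W1) @ W2`, shards `W1` column-wise and `W2` row-wise, and shows that one all-reduce at the very end reproduces the unsharded result, because ReLU acts element-wise on each device's hidden slice.

```python
import numpy as np

rng = np.random.default_rng(0)
n_devices = 4                        # number of simulated GPUs (assumed)
X = rng.standard_normal((3, 10))
W1 = rng.standard_normal((10, 32))   # first layer, split column-wise
W2 = rng.standard_normal((32, 5))    # second layer, split row-wise


def relu(z):
    return np.maximum(z, 0.0)


reference = relu(X @ W1) @ W2

W1_shards = np.split(W1, n_devices, axis=1)  # (10, 8) per device
W2_shards = np.split(W2, n_devices, axis=0)  # (8, 5) per device

# Per device: column-parallel matmul -> element-wise ReLU -> row-parallel matmul.
# The (3, 8) hidden slice never leaves its device.
partials = [relu(X @ W1_i) @ W2_i for W1_i, W2_i in zip(W1_shards, W2_shards)]

# The only communication: a single all-reduce (sum) of the final partial outputs.
Y = np.sum(partials, axis=0)

print(np.allclose(Y, reference))  # True
```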


Tensor Parallel shards individual tensors over a set of devices in a distributed environment, communicating through a backend such as NCCL or Gloo. It is a `Single-Program Multiple-Data (SPMD)` sharding algorithm that relies on PyTorch `DTensor` to represent sharded tensors. It also uses the `DeviceMesh` abstraction, which manages ProcessGroups under the hood, for device management and sharding. Without `DeviceMesh`, users would need to manually set up NCCL communicators and CUDA devices on each process before applying any parallelism.

### Brief Overview of Device Mesh

[DeviceMesh](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html) is a higher-level abstraction that manages `ProcessGroup`s. It lets users create inter-node and intra-node process groups without worrying about how to assign ranks correctly to the different sub-groups. Users can also manage the underlying process groups and devices for multi-dimensional parallelism through `DeviceMesh`. The figure below shows a 2D mesh that connects the devices within each host and connects each device with its counterpart on the other hosts in a homogeneous setup.

<center><img src="images/device_mesh.png" width="550px" height="550px" alt-text="device-mesh"/></center>
<center><a href="https://pytorch.org/tutorials/_images/device_mesh.png">image source</a></center>

Tensor Parallel usually works within each host. We can initialize a DeviceMesh that connects 8 GPUs within a host using `init_device_mesh()` as follows:

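```python
from torch.distributed.device_mesh import init_device_mesh

# Launch with 8 processes on one host (e.g., torchrun --nproc_per_node=8).
# init_device_mesh initializes the default process group if it is not set up yet.
tp_mesh = init_device_mesh("cuda", (8,))
```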

The sample code below creates a 2D mesh and accesses the underlying `ProcessGroup` for each mesh dimension.

```python
from torch.distributed.device_mesh import init_device_mesh

# 8 processes arranged as 2 x 4: with 2 hosts of 4 GPUs, "replicate" spans the
# hosts and "shard" spans the GPUs inside one host.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

# Users can access the underlying process group through the `get_group` API.
replicate_group = mesh_2d.get_group(mesh_dim="replicate")
shard_group = mesh_2d.get_group(mesh_dim="shard")
```
See the [DeviceMesh recipe](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html#getting-started-with-devicemesh) for detailed documentation on DeviceMesh and how to use it for custom parallel solutions.
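To see which ranks end up in each process group of the `(2, 4)` mesh, the pure-Python sketch below reproduces the layout under the assumption that `init_device_mesh` numbers global ranks row-major over the mesh shape (2 hosts x 4 GPUs, ranks 0-3 on the first host). It creates no process groups; it only lists the group memberships, and `groups_for` is a hypothetical helper.

```python
NUM_HOSTS, GPUS_PER_HOST = 2, 4  # assumed homogeneous setup

# Row-major rank layout: row = host, column = GPU index within the host.
mesh = [[h * GPUS_PER_HOST + g for g in range(GPUS_PER_HOST)] for h in range(NUM_HOSTS)]

# "shard" dimension: the ranks inside one host (rows of the mesh).
shard_groups = [list(row) for row in mesh]
# "replicate" dimension: each GPU with its counterparts on the other hosts (columns).
replicate_groups = [[mesh[h][g] for h in range(NUM_HOSTS)] for g in range(GPUS_PER_HOST)]


def groups_for(rank):
    """Hypothetical helper: the (shard, replicate) groups that `rank` belongs to."""
    host, gpu = divmod(rank, GPUS_PER_HOST)
    return shard_groups[host], replicate_groups[gpu]


print(shard_groups)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replicate_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(groups_for(5))     # ([4, 5, 6, 7], [1, 5])
```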



### Tensor Parallel Implementation with an MLP Model

This section tests a Tensor Parallel (TP) implementation on a toy MLP model in the Megatron-LM Single-Program Multiple-Data (SPMD) style, showing an end-to-end workflow through the forward pass, backward pass, and optimizer step. The sample model has two `nn.Linear` layers with an element-wise `nn.ReLU` between them. The first linear layer is parallelized column-wise and the second row-wise, so the forward pass needs only one `all-reduce`, at the end of the second linear layer. This avoids communication between the two layers and speeds up training.

**Steps**
- Define the model class

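```python
import torch.nn as nn


class ToyModel(nn.Module):
    """MLP-based model: Linear(10 -> 32) -> ReLU -> Linear(32 -> 5)."""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)  # parallelized column-wise in the plan
        self.relu = nn.ReLU()             # element-wise, so no communication needed
        self.out_proj = nn.Linear(32, 5)  # parallelized row-wise in the plan

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))
```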

- Create a device mesh based on the given world_size
  
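```python
import os

from torch.distributed.device_mesh import init_device_mesh

# WORLD_SIZE is set by torchrun (4 in the run below).
_world_size = int(os.environ["WORLD_SIZE"])

# A 1D mesh over all GPUs: every rank takes part in tensor parallelism.
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()
```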
- Create a custom parallelization plan for the model

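```python
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

...
# Shard in_proj column-wise and out_proj row-wise over the 1D device mesh.
tp_model = parallelize_module(
    module=tp_model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)
...
```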
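To see what each rank actually stores under this plan, the sketch below computes the local parameter shapes of the two `nn.Linear` layers of `ToyModel` on 4 GPUs. `local_linear_shapes` is a hypothetical helper, not a PyTorch API. It assumes that `nn.Linear` stores its weight as `(out_features, in_features)`, so `ColwiseParallel` shards the weight along dim 0 (and shards the bias), while `RowwiseParallel` shards the weight along dim 1 and keeps the bias replicated. It also assumes the sizes divide evenly.

```python
def local_linear_shapes(in_features, out_features, tp_size, style):
    """Hypothetical helper: per-rank (weight, bias) shapes of an nn.Linear under TP."""
    if style == "colwise":
        # Weight (out, in) is sharded on dim 0; the bias is sharded the same way.
        assert out_features % tp_size == 0, "sketch assumes an even split"
        return (out_features // tp_size, in_features), (out_features // tp_size,)
    if style == "rowwise":
        # Weight (out, in) is sharded on dim 1; the bias stays replicated.
        assert in_features % tp_size == 0, "sketch assumes an even split"
        return (out_features, in_features // tp_size), (out_features,)
    raise ValueError(f"unknown style: {style}")


tp_size = 4  # the 4-GPU run below
in_proj = local_linear_shapes(10, 32, tp_size, "colwise")
out_proj = local_linear_shapes(32, 5, tp_size, "rowwise")

print(in_proj)   # ((8, 10), (8,))
print(out_proj)  # ((5, 8), (5,))
```

Each rank's `in_proj` produces an 8-wide slice of the 32-wide hidden activation, which is exactly what its local `out_proj` weight of shape `(5, 8)` consumes, so nothing needs to be exchanged between the two layers.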
Please find the complete code [here](../source_code/tensor_parallel_example.py). You can run the Tensor parallelism sample code using 4 GPUs (within a node) in the cell below. 

!cd ../source_code && srun -p gpu -N 1 --gres=gpu:4 torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=101 --rdzv_endpoint="localhost:5972" tensor_parallel_example.py

**Likely Output:**

```python
...
Starting PyTorch TP example on rank 0.
02/22/2025 07:49:00 AM  Device Mesh created: device_mesh=DeviceMesh('cuda', [0, 1, 2, 3])
Starting PyTorch TP example on rank 2.
Starting PyTorch TP example on rank 3.
Starting PyTorch TP example on rank 1.
02/22/2025 07:49:09 AM  Tensor Parallel training starting...
02/22/2025 07:49:09 AM  Tensor Parallel iter 0 completed
02/22/2025 07:49:10 AM  Tensor Parallel iter 1 completed
02/22/2025 07:49:10 AM  Tensor Parallel iter 2 completed
02/22/2025 07:49:10 AM  Tensor Parallel iter 3 completed
02/22/2025 07:49:10 AM  Tensor Parallel iter 4 completed
02/22/2025 07:49:10 AM  Tensor Parallel iter 5 completed
02/22/2025 07:49:10 AM  Tensor Parallel iter 6 completed
02/22/2025 07:49:10 AM  Tensor Parallel iter 7 completed
02/22/2025 07:49:10 AM  Tensor Parallel iter 8 completed
02/22/2025 07:49:10 AM  Tensor Parallel iter 9 completed
02/22/2025 07:49:10 AM  Tensor Parallel training completed!

```

### Tensor/Sequence parallel with Fully Sharded Data Parallel (TP/SP + FSDP)

This section describes a 2D parallel implementation that combines Tensor/Sequence Parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a Llama 2 model, and walks through the forward and backward passes. FSDP and TP are enabled on separate mesh dimensions: data parallel (`"dp"`) across hosts and tensor parallel (`"tp"`) within each host.

**Illustration steps:**

- Initialize the world size and rank topology
```python
import os

...
# RANK and WORLD_SIZE are set by torchrun.
_rank = int(os.environ["RANK"])
_world_size = int(os.environ["WORLD_SIZE"])
...
```
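- Compute the data-parallel size from the world size and the tensor-parallel size
```python
tp_size = 2  # TP degree; gives the 2 x 2 ("dp", "tp") mesh shown in the output below
assert _world_size % tp_size == 0, "WORLD_SIZE must be divisible by tp_size"
dp_size = _world_size // tp_size
```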
- Create a device mesh with 2 dimensions (data parallel dimension, tensor parallel dimension).

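```python
from torch.distributed.device_mesh import init_device_mesh

...
# Each row of the (dp_size, tp_size) mesh is one TP group; each column is one DP group.
device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")  # logging helper from the complete example
tp_mesh = device_mesh["tp"]  # 1D sub-mesh: this rank's TP group
dp_mesh = device_mesh["dp"]  # 1D sub-mesh: this rank's DP group

...
```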
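As a sanity check on the mesh arithmetic, the pure-Python sketch below assumes `WORLD_SIZE=4` and `tp_size=2` (matching the 4-GPU output further down) and a row-major rank layout, and lists each global rank's `(dp, tp)` coordinate together with the TP and DP groups.

```python
# Assumed values matching the 4-GPU run: WORLD_SIZE=4 and tp_size=2.
world_size, tp_size = 4, 2
assert world_size % tp_size == 0, "world size must be divisible by tp_size"
dp_size = world_size // tp_size

# Row-major layout over (dp_size, tp_size): rank r sits at (r // tp_size, r % tp_size).
mesh = [[dp * tp_size + tp for tp in range(tp_size)] for dp in range(dp_size)]
coords = {rank: divmod(rank, tp_size) for rank in range(world_size)}

# Each row of the mesh is a TP group; each column is a DP group.
tp_groups = mesh
dp_groups = [[mesh[dp][tp] for dp in range(dp_size)] for tp in range(tp_size)]

print(mesh)       # [[0, 1], [2, 3]]
print(coords)     # {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}
print(dp_groups)  # [[0, 2], [1, 3]]
```

The `mesh` list matches the `DeviceMesh('cuda', [[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp'))` line in the sample output.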
- For Tensor Parallel, the input must be identical across all TP ranks, whereas for SP the input can differ across ranks. We use `dp_rank` to set the random seed, mimicking the behavior of a dataloader.

```python
# Ranks in the same TP group share a dp_rank, so seeding with it gives them identical inputs.
dp_rank = dp_mesh.get_local_rank()
```
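The effect of seeding with `dp_rank` can be simulated without GPUs. The sketch below uses Python's `random.Random` as a stand-in for seeding PyTorch's generator, assumes the 4-GPU, `tp_size=2` row-major layout, and uses a hypothetical sequence length of 8 with the vocabulary size from the Llama config. TP peers draw identical token batches, while the two DP replicas draw different ones.

```python
import random

world_size, tp_size = 4, 2      # assumed, matching the 4-GPU run
seq_len, vocab_size = 8, 32000  # hypothetical sequence length; vocab as in the Llama config

batches = {}
for rank in range(world_size):
    dp_rank = rank // tp_size     # row-major mesh: TP peers share the same dp_rank
    rng = random.Random(dp_rank)  # seed with dp_rank, not the global rank
    batches[rank] = [rng.randrange(vocab_size) for _ in range(seq_len)]

assert batches[0] == batches[1]  # TP group [0, 1] sees identical input
assert batches[2] == batches[3]  # TP group [2, 3] sees identical input
assert batches[0] != batches[2]  # different DP replicas see different data
```

Seeding with the global rank instead would give every TP rank a different batch, which breaks the requirement that TP ranks compute on the same input.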
- Instantiate the Llama model and move it to the GPU.

```python
# ModelArgs and Transformer are defined in the complete example's source code.
simple_llama2_config = ModelArgs(dim=256, n_layers=2, n_heads=16, vocab_size=32000)

model = Transformer.from_model_args(simple_llama2_config).to("cuda")
...
```

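- For each transformer block in the Llama model, apply sequence parallel (`SequenceParallel`) to the `attention_norm` and `ffn_norm` layers. Apply column-wise parallel to the attention projections (`wq`, `wk`, `wv`) and to the feed-forward `w1` and `w3`, and row-wise parallel to the attention output projection `wo` and the feed-forward `w2`. `PrepareModuleInput` redistributes the sequence-sharded input (`Shard(1)`) into a replicated tensor before the attention and feed-forward modules, and `RowwiseParallel(output_layouts=Shard(1))` on `wo` returns its output to the sequence-sharded layout.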

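```python
from torch.distributed.tensor import Replicate, Shard  # torch.distributed._tensor in older releases
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

for layer_id, transformer_block in enumerate(model.layers):
    layer_tp_plan = {
        # The norm runs on sequence-sharded activations (SP).
        "attention_norm": SequenceParallel(),
        # Gather the sequence-sharded input into a replicated tensor before attention.
        "attention": PrepareModuleInput(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "attention.wq": ColwiseParallel(),
        ...
        # Reduce-scatter the partial outputs back to sequence sharding.
        "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
        "ffn_norm": SequenceParallel(),
        "feed_forward": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "feed_forward.w1": ColwiseParallel(),
        ...
    }
```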
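The layout transitions in this plan can be traced numerically. The NumPy sketch below is a single-process stand-in for one TP group of 2 ranks and makes several simplifying assumptions: RMSNorm has no learnable weight, and the feed-forward uses a plain ReLU MLP rather than Llama's gated `w1`/`w2`/`w3` form. It shows that (1) the norm can run locally on sequence shards, (2) gathering along dim 1 (`Shard(1)` to `Replicate()`) restores the full activation, and (3) a reduce-scatter along dim 1, modeled as sum-then-split, hands each rank its sequence slice of the correct output.

```python
import numpy as np

rng = np.random.default_rng(0)
tp_size = 2                        # assumed TP degree, as in the 2 x 2 mesh
batch, seq, dim, hidden = 2, 8, 16, 32


def rms_norm(x, eps=1e-6):
    # Normalizes each token over the feature dimension only (no learnable weight here).
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)


x = rng.standard_normal((batch, seq, dim))
w1 = rng.standard_normal((dim, hidden))  # column-parallel projection
w2 = rng.standard_normal((hidden, dim))  # row-parallel projection
reference = np.maximum(rms_norm(x) @ w1, 0.0) @ w2

# Shard(1): each TP rank holds a slice of the sequence dimension.
x_shards = np.split(x, tp_size, axis=1)

# SequenceParallel norm: purely local, because RMSNorm works per token.
normed_shards = [rms_norm(s) for s in x_shards]

# PrepareModuleInput, Shard(1) -> Replicate(): all-gather along dim 1.
full = np.concatenate(normed_shards, axis=1)

# Column-wise then row-wise parallel: each rank computes a partial sum.
w1_shards = np.split(w1, tp_size, axis=1)
w2_shards = np.split(w2, tp_size, axis=0)
partials = [np.maximum(full @ a, 0.0) @ b for a, b in zip(w1_shards, w2_shards)]

# RowwiseParallel(output_layouts=Shard(1)): reduce-scatter = sum, then split on dim 1.
out_shards = np.split(np.sum(partials, axis=0), tp_size, axis=1)

print(np.allclose(np.concatenate(out_shards, axis=1), reference))  # True
```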

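- Apply the per-layer plan to each transformer block on the TP sub-mesh (this call belongs inside the loop above)

```python
    # Inside the `for layer_id, transformer_block in ...` loop, after layer_tp_plan:
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )
```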
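- Wrap the model with FSDP using the dp device mesh

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# FSDP shards parameters across the "dp" dimension only; TP already handles "tp".
sharded_model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
```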
- Run a training loop to perform iterations for the forward and backward passes

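```python
for i in range(num_iterations):
    ...
    # Forward pass through the 2D-parallel (FSDP + TP/SP) model.
    output = sharded_model(inp)
    # Backward pass: computes gradients for the locally held parameter shards.
    output.sum().backward()
    # Update the locally held parameter shards.
    optimizer.step()
    ...
```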

Please find the complete code [here](../source_code/fsdp_tp_example.py). Let's execute the TP/SP + FSDP sample code with 4 GPUs (in a node) by running the cell below. 

!cd ../source_code && srun -p gpu -N 1 --gres=gpu:4 torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=101 --rdzv_endpoint="localhost:5972" fsdp_tp_example.py

**Likely Output using 4 GPUs:**

```python
...
Starting PyTorch 2D (FSDP + TP) example on rank 0.
Starting PyTorch 2D (FSDP + TP) example on rank 3.Starting PyTorch 2D (FSDP + TP) example on rank 2.

Starting PyTorch 2D (FSDP + TP) example on rank 1.
02/22/2025 09:45:20 AM  Device Mesh created: device_mesh=DeviceMesh('cuda', [[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp'))
02/22/2025 09:45:26 AM  Model after parallelization sharded_model=FullyShardedDataParallel(
  (_fsdp_wrapped_module): Transformer(
    (tok_embeddings): Embedding(32000, 256)
    (layers): ModuleList(
      (0-1): 2 x TransformerBlock(
        (attention): Attention(
          (wq): Linear(in_features=256, out_features=256, bias=False)
          (wk): Linear(in_features=256, out_features=256, bias=False)
          (wv): Linear(in_features=256, out_features=256, bias=False)
          (wo): Linear(in_features=256, out_features=256, bias=False)
        )
        (feed_forward): FeedForward(
          (w1): Linear(in_features=256, out_features=768, bias=False)
          (w2): Linear(in_features=768, out_features=256, bias=False)
          (w3): Linear(in_features=256, out_features=768, bias=False)
        )
 ...     

02/22/2025 09:45:28 AM  2D iter 7 complete
02/22/2025 09:45:28 AM  2D iter 8 complete
02/22/2025 09:45:28 AM  2D iter 9 complete
02/22/2025 09:45:28 AM  2D training successfully completed!

```
The key concepts in this notebook were adapted from the [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/tp.html) documentation, and the code from the [PyTorch examples repository on GitHub](https://github.com/pytorch/examples/tree/main/distributed/tensor_parallelism). Let's proceed to the next notebook and learn about `Message Passing and Mixed Precision`. Please click the [Next Link](other-topics.ipynb).

---

## References

- https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/tensor_parallelism_overview.html
- https://pytorch.org/tutorials/intermediate/TP_tutorial.html
- https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/tp.html
- https://github.com/pytorch/examples/tree/main/distributed/tensor_parallelism


## Licensing 

Copyright © 2025 OpenACC-Standard.org. This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). These materials include references to hardware and software developed by other entities; all applicable licensing and copyrights apply.
