# Fully Sharded Data Parallel (FSDP)
---

The goal of this notebook is to walk you through the implementation of the training strategy called Fully Sharded Data Parallel (FSDP) and show how it improves on data parallelism. This content is adapted from the [PyTorch FSDP tutorial](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) for learning purposes.

In distributed data parallel (DDP) training, each worker or device holds a full replica of the model parameters, gradients, and optimizer states, and processes its own batch of data. After each backward pass, the communication backend (`NCCL` or `Gloo`) applies an all-reduce operation to sum the gradients across workers or devices, so that every replica applies the same update. [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) is a type of data parallelism that shards model parameters, optimizer states, and gradients across DDP ranks. As a result, the per-GPU memory footprint of FSDP training is smaller than that of DDP, which makes training very large models feasible at the cost of extra communication. Much of that overhead can be hidden by overlapping communication with computation.
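
To make the memory claim concrete, here is a rough back-of-the-envelope sketch of the per-GPU memory needed for model states only (parameters, gradients, and optimizer states). It assumes fp32 values and an optimizer that keeps two state tensors per parameter, as Adadelta and Adam do. The function name `model_state_bytes` and the 1-billion-parameter model are hypothetical; activations, buffers, and the full parameters that FSDP temporarily gathers for the unit being computed are ignored, so real usage will be higher.

```python
def model_state_bytes(num_params, world_size, sharded,
                      bytes_per_value=4, optimizer_states_per_param=2):
    """Rough per-rank bytes for parameters + gradients + optimizer states."""
    # One value for the parameter, one for its gradient, plus optimizer states.
    values_per_param = 1 + 1 + optimizer_states_per_param
    total = num_params * values_per_param * bytes_per_value
    # DDP replicates everything on every rank; FSDP splits it across ranks.
    return total / world_size if sharded else total


num_params = 1_000_000_000  # hypothetical model size
world_size = 4              # e.g. one node with 4 GPUs

ddp_gib = model_state_bytes(num_params, world_size, sharded=False) / 2**30
fsdp_gib = model_state_bytes(num_params, world_size, sharded=True) / 2**30
print(f"DDP : {ddp_gib:.1f} GiB per GPU")
print(f"FSDP: {fsdp_gib:.1f} GiB per GPU")
```

Under these assumptions the model-state footprint shrinks by a factor equal to the number of ranks.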

#### FSDP Workflow

The FSDP workflow has three phases: the constructor, the forward pass, and the backward pass.

- **In constructor**: Shard the model parameters so that each rank or device keeps only its own shard.
- **In forward path**: Each FSDP unit runs an `all-gather` operation to collect the parameter shards from all ranks or devices and recover the full parameters, runs the forward computation, and then discards the shards it gathered from other ranks.
- **In backward path**: Each FSDP unit again runs an `all-gather` to recover the full parameters, runs the backward computation, discards the gathered shards, and then runs a `reduce-scatter` operation to synchronize the gradients.

<center><img src="images/fsdp-arc.png" width="550px" height="550px" alt-text="fsdp workflow"/></center>
<center> FSDP Workflow <a href="https://pytorch.org/tutorials/_images/fsdp_workflow.png" >[view image source]</a> </center>
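
The constructor and forward steps above can be pictured with a small single-process simulation in plain Python. This is only a conceptual sketch, not how FSDP is implemented: the helper names `shard` and `all_gather` are hypothetical, ranks are modelled as entries of a list, and the flat parameter list is zero-padded so that it divides evenly across ranks.

```python
def shard(flat_params, world_size):
    """Split a flat parameter list into equal shards, zero-padding the tail."""
    shard_size = -(-len(flat_params) // world_size)  # ceiling division
    padding = shard_size * world_size - len(flat_params)
    padded = flat_params + [0.0] * padding
    return [padded[r * shard_size:(r + 1) * shard_size]
            for r in range(world_size)]


def all_gather(shards, original_len):
    """Concatenate the shards from all ranks and strip the padding."""
    full = [value for s in shards for value in s]
    return full[:original_len]


params = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
world_size = 4

# Constructor: each rank keeps only its own shard.
shards = shard(params, world_size)
for rank, local_shard in enumerate(shards):
    print(f"rank {rank} holds {local_shard}")

# Forward: rebuild the full parameters, compute, then drop the gathered copy.
full = all_gather(shards, len(params))
assert full == params
del full  # only the local shard stays resident on each rank
```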

#### FSDP Sharding Process

To understand how FSDP sharding works, it helps to decompose the `all-reduce` operation used in DDP into a `reduce-scatter` followed by an `all-gather`, as shown in the figure below. During the backward pass, FSDP runs the `reduce-scatter`, which leaves each rank or device with one shard of the summed gradients. Each rank then updates the corresponding shard of the parameters in the optimizer step. In the next forward pass, FSDP runs an `all-gather` to collect and combine the updated parameter shards.

<center><img src="images/fsdp-allreduce.png" width="550px" height="550px" alt-text="fsdp workflow"/></center>
<center> FSDP All Reduce <a href="https://pytorch.org/tutorials/_images/fsdp_sharding.png" >[view image source]</a> </center>
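
The decomposition can be checked with a tiny single-process simulation. The three helpers below are hypothetical stand-ins for the real collectives: each takes one list of values per rank and returns what every rank would hold afterwards. The sketch assumes the gradient length divides evenly by the number of ranks.

```python
def all_reduce(grads_per_rank):
    """Every rank ends up with the element-wise sum over all ranks."""
    summed = [sum(values) for values in zip(*grads_per_rank)]
    return [list(summed) for _ in grads_per_rank]


def reduce_scatter(grads_per_rank):
    """Rank r ends up with only the r-th chunk of the element-wise sum."""
    world_size = len(grads_per_rank)
    chunk = len(grads_per_rank[0]) // world_size
    summed = [sum(values) for values in zip(*grads_per_rank)]
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def all_gather(chunks_per_rank):
    """Every rank ends up with the concatenation of all ranks' chunks."""
    full = [value for chunk in chunks_per_rank for value in chunk]
    return [list(full) for _ in chunks_per_rank]


# Two ranks, four gradient values each.
grads = [[1, 2, 3, 4], [10, 20, 30, 40]]

grad_shards = reduce_scatter(grads)
assert grad_shards == [[11, 22], [33, 44]]

# reduce-scatter followed by all-gather gives the same result as all-reduce.
assert all_gather(grad_shards) == all_reduce(grads)
```

In FSDP the two halves are separated in time: the `reduce-scatter` runs in the backward pass on gradients, and the `all-gather` runs in the next forward pass on the updated parameters.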


#### FSDP Implementation 

Let's use a toy model to demonstrate the process using the MNIST dataset. You can also apply the steps highlighted below to train larger models.

**Steps**:

- Define the functions that initialize and clean up the distributed process group.

```python

import os

import torch.distributed as dist


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
```
- Model definition. For example, we define a toy model for handwritten digit classification.

```python
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        ...

    def forward(self, x):

        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        ...
        return output
```
- Define the train and test functions, which calculate the loss and use an all-reduce operation to aggregate it across ranks.

```python
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    ddp_loss = torch.zeros(2).to(rank)
    ...
    for batch_idx, (data, target) in enumerate(train_loader):
        ...
        loss.backward()
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(data)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    ...

def test(model, rank, world_size, test_loader):
    ...
        for data, target in test_loader:
            ...
            ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
            ddp_loss[2] += len(data)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)

    if rank == 0:
        test_loss = ddp_loss[0] / ddp_loss[2]
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
            100. * ddp_loss[1] / ddp_loss[2]))
```
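
To see why the all-reduce is needed before printing, consider this single-process sketch of the aggregation in `test`. Each rank only sees its own slice of the test set, so it accumulates partial totals; summing them across ranks yields the global statistics. The helper `all_reduce_sum` is a hypothetical stand-in for `dist.all_reduce(..., op=dist.ReduceOp.SUM)`, and the per-rank numbers are made up for illustration.

```python
def all_reduce_sum(per_rank):
    """Every rank ends up with the element-wise sum over all ranks."""
    total = [sum(values) for values in zip(*per_rank)]
    return [list(total) for _ in per_rank]


# Made-up per-rank accumulators: [summed loss, correct predictions, samples seen]
per_rank = [
    [120.0, 2450, 2500],
    [130.0, 2440, 2500],
    [110.0, 2460, 2500],
    [140.0, 2430, 2500],
]

reduced = all_reduce_sum(per_rank)

# After the all-reduce every rank holds the same totals; rank 0 reports them.
loss_sum, correct, seen = reduced[0]
print("Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)".format(
    loss_sum / seen, correct, seen, 100.0 * correct / seen))
```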
- Finally, define the main function, which sets up distributed training, wraps the model in FSDP, and runs the train and test loops.

```python

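import functools

import torch
import torch.optim as optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.optim.lr_scheduler import StepLR
from torch.utils.data.distributed import DistributedSampler


def fsdp_main(rank, world_size, args):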
    setup(rank, world_size)
    ...
    sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)
    ...
    
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )
    torch.cuda.set_device(rank)
    ...
    model = Net().to(rank)
    # No auto_wrap_policy is passed here, so the whole model becomes one FSDP unit.
    model = FSDP(model)

    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    init_start_event.record()
    for epoch in range(1, args.epochs + 1):
        train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        test(model, rank, world_size, test_loader)
        scheduler.step()

    ...

    cleanup()
```
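
The two `DistributedSampler` objects in `fsdp_main` are what give each rank a different slice of the data. The sketch below mimics the index assignment in plain Python for the non-shuffled case; it assumes `shuffle=False` and `drop_last=False`, and the function name `sampler_indices` is hypothetical. With `shuffle=True`, as used for the training sampler, the index list is permuted before it is split.

```python
import math


def sampler_indices(dataset_len, num_replicas, rank):
    """Indices one rank would see (sketch: shuffle=False, drop_last=False)."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating samples from the start so every rank gets the same count.
    indices += indices[:total_size - dataset_len]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(f"rank {rank}: {sampler_indices(10, 4, rank)}")
```

With 10 samples and 4 ranks, every rank receives 3 indices, and the last two ranks each reuse one sample from the start of the dataset as padding.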

You can open the complete FSDP code [from here](../source_code/test_fsdp.py). Please run the cell below to see the output.

In [None]:
!cd ../source_code && srun -p gpu -N 1 --gres=gpu:4 python test_fsdp.py  

**Likely output**:

```text
...

Train Epoch: 10 	Loss: 0.026925
Test set: Average loss: 0.0272, Accuracy: 9916/10000 (99.16%)

CUDA event elapsed time: 80.868640625sec
FullyShardedDataParallel(
  (_fsdp_wrapped_module): Net(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (dropout1): Dropout(p=0.25, inplace=False)
    (dropout2): Dropout(p=0.5, inplace=False)
    (fc1): Linear(in_features=9216, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=10, bias=True)
  )
)

```
The output shows that the whole model is wrapped in a single FSDP unit. This hurts both compute and memory efficiency: a single blocking `all-gather` call has to fetch the parameters of every layer at once, so communication cannot overlap with computation between layers. To remedy this, we define an `auto_wrap_policy` from the size-based auto-wrap policy with `min_num_params` raised to 20000 and pass it to the FSDP wrapper, as shown below.

```python
my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=20000)
torch.cuda.set_device(rank)
model = Net().to(rank)
model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
```

Todo: In the [FSDP code](../source_code/test_fsdp.py), comment out lines `#134 and #144`, then uncomment lines `#136 and #145`. Please run the cell below to see the output.

In [None]:
!cd ../source_code && srun -p gpu -N 1 --gres=gpu:4 python test_fsdp.py

**Likely output**:
```text
...
Train Epoch: 10 	Loss: 0.023942
Test set: Average loss: 0.0255, Accuracy: 9922/10000 (99.22%)

CUDA event elapsed time: 79.628578125sec
FullyShardedDataParallel(
  (_fsdp_wrapped_module): Net(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (dropout1): Dropout(p=0.25, inplace=False)
    (dropout2): Dropout(p=0.5, inplace=False)
    (fc1): FullyShardedDataParallel(
      (_fsdp_wrapped_module): Linear(in_features=9216, out_features=128, bias=True)
    )
    (fc2): Linear(in_features=128, out_features=10, bias=True)
  )
)
```
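
The nesting in this output can be checked against per-layer parameter counts. The sketch below computes them from the layer shapes printed above and applies a simplified version of the size-based rule: a layer becomes its own FSDP unit when it has at least `min_num_params` parameters. The helper names are hypothetical, and the real `size_based_auto_wrap_policy` handles nested modules and exclusions that this sketch ignores.

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of a square-kernel Conv2d layer."""
    return in_channels * out_channels * kernel_size * kernel_size + out_channels


def linear_params(in_features, out_features):
    """Weights plus biases of a Linear layer."""
    return in_features * out_features + out_features


# Layer shapes taken from the printed model.
layer_params = {
    "conv1": conv2d_params(1, 32, 3),
    "conv2": conv2d_params(32, 64, 3),
    "fc1": linear_params(9216, 128),
    "fc2": linear_params(128, 10),
}

min_num_params = 20000
wrapped = [name for name, count in layer_params.items() if count >= min_num_params]

for name, count in layer_params.items():
    verdict = "own FSDP unit" if name in wrapped else "stays in the root unit"
    print(f"{name}: {count:>9,} parameters -> {verdict}")
```

Only `fc1` crosses the 20000-parameter threshold, which matches the single nested `FullyShardedDataParallel` wrapper in the output.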
Now the model is wrapped in two FSDP units: the root module and `fc1`. If you profile this code, you will find that the auto-wrap policy reduces FSDP peak memory usage.

`CPU offloading` can be used if you have a very large model that does not fit into GPU memory even with FSDP. Currently, PyTorch only supports parameter and gradient CPU offload. It can be enabled by passing `cpu_offload=CPUOffload(offload_params=True)` to the FSDP wrapper:

```python
from torch.distributed.fsdp import CPUOffload

model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy, cpu_offload=CPUOffload(offload_params=True))

```

In this notebook, we covered the concept of sharding and how FSDP implements it. Let's move on to the next notebook to learn about `Model parallelism`. To proceed, please click the [Next Link](model-parallelism.ipynb).

---
## References

- https://github.com/pytorch/tutorials/blob/main/intermediate_source/FSDP_tutorial.rst
- https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
- https://github.com/pytorch/tutorials/blob/main/intermediate_source/FSDP_advanced_tutorial.rst


## Licensing 

Copyright © 2025 OpenACC-Standard.org. This material is released by OpenACC-Standard.org, in collaboration with NVIDIA Corporation, under the Creative Commons Attribution 4.0 International (CC BY 4.0). These materials include references to hardware and software developed by other entities; all applicable licensing and copyrights apply.