FSDP crashes when submodule calls method that isn't forward() #109385

Open
siddk opened this issue Sep 15, 2023 · 3 comments
Labels
module: fsdp, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

siddk commented Sep 15, 2023

🐛 Describe the bug

I am getting various runtime errors from an FSDP module that wraps multiple child modules, where the forward pass invokes a submodule's non-forward method. The auto-wrap policy wraps each submodule separately. The minimal example below should make this clearer:

Run with (at least 2 GPUs): `torchrun --standalone --nnodes 1 --nproc-per-node 2 <script.py>`

from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import _module_wrap_policy

# This code is a "dummy" VisionTransformer that just implements the offending partial logic from a default pretrained ViT
# Our actual code uses a model loaded from `timm` (PyTorch Image Models); full Gist is linked below.

class VisionTransformer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed_dim = 1024

        # Mimics the "Patch Embedding" for a ViT-Large w/ Patch Size = 14
        self.patch_proj = nn.Conv2d(in_channels=3, out_channels=self.embed_dim, kernel_size=14, stride=14, bias=True)

    def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
        patches = self.patch_proj(imgs)                         # [bsz, embed_dim, 16 = (224 / 14), 16 = (224 / 14)]
        patch_embeddings = patches.flatten(2).transpose(1, 2)   # [bsz, 256 = 16 * 16, 1024]
        return patch_embeddings

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        return self.forward_features(imgs).sum(dim=1)

class LinearProjector(nn.Module):
    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.projector = nn.Linear(in_dim, out_dim)

    def forward(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
        return self.projector(patch_embeddings)

# === Actual Network ===
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.vit, self.projector = VisionTransformer(), LinearProjector(in_dim=1024, out_dim=256)

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        patch_embeddings = self.vit.forward_features(imgs)  # [ERRORS HERE]
        return self.projector(patch_embeddings)

# === Main ===
def bug_fsdp() -> None:
    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Initialize Network
    net = Net()
    
    # FSDP w/ custom module-based autowrap policy
    auto_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer, LinearProjector})
    net = FSDP(
        net,
        auto_wrap_policy=auto_wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
    )

    # Run a forward pass w/ dummy input (bsz = 4) on the current CUDA device
    dummy_input = torch.randn(4, 3, 224, 224, device="cuda")
    net(dummy_input)    # CRASH!

if __name__ == "__main__":
    bug_fsdp()

This results in the following error message:

Traceback (most recent call last):                                                                                                                                                                                    
  File "/mnt/fsx/skaramcheti/code/prismatic-vlms/x-reference/bugs-fixes/bug_fsdp.py", line 74, in <module>                                                                                                            
    bug_fsdp()                                                                                                                                                                                                        
  File "/mnt/fsx/skaramcheti/code/prismatic-vlms/x-reference/bugs-fixes/bug_fsdp.py", line 70, in bug_fsdp                                                                                                            
    net(dummy_input)  # CRASH!                                                                                                                                                                                        
  File "/home/ubuntu/mambaforge/envs/fsdp-debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl                                                                                    
    return forward_call(*args, **kwargs)                                                                                                                                                                              
  File "/home/ubuntu/mambaforge/envs/fsdp-debug/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 748, in forward                                                             
    output = self._fsdp_wrapped_module(*args, **kwargs)                                                                                                                                                              
  File "/home/ubuntu/mambaforge/envs/fsdp-debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl                                                                                    
    return forward_call(*args, **kwargs)                                                                                                                                                                              
  File "/mnt/fsx/skaramcheti/code/prismatic-vlms/x-reference/bugs-fixes/bug_fsdp.py", line 46, in forward                                                                                                             
    patch_embeddings = self.vit.forward_features(imgs)  # [ERRORS HERE]                                                                                                                                               
  File "/mnt/fsx/skaramcheti/code/prismatic-vlms/x-reference/bugs-fixes/bug_fsdp.py", line 22, in forward_features                                                                                                   
    patches = self.patch_proj(imgs)  # [bsz, embed_dim, 16 = (224 / 14), 16 = (224 / 14)]                                                                                                                            
  File "/home/ubuntu/mambaforge/envs/fsdp-debug/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl                                                                                   
    return forward_call(*args, **kwargs)                                                                                                                                                                             
  File "/home/ubuntu/mambaforge/envs/fsdp-debug/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 463, in forward                                                                                          
    return self._conv_forward(input, self.weight, self.bias)                                                                                                                                                         
  File "/home/ubuntu/mambaforge/envs/fsdp-debug/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward                                                                                    
    return F.conv2d(input, weight, bias, self.stride,                                                                                                                                                                 
RuntimeError: Output 0 of ViewBackward0 is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

Further context: I'm working on a project where we take the patch features from a (frozen) Vision Transformer backbone and transform them into a different latent space where they're used to decode other modalities (e.g., depth).

This gist provides an annotated example that reflects our setup a bit better: https://gist.github.com/siddk/db3e8808bed2a9cb90ae62b5338de68d


Some other things I tried (to help speed along debugging) -- all of this is in the linked Gist:
- Setting `use_orig_params=True` results in a different error at the same Conv2d call (`RuntimeError: weight should have at least three dimensions`)
- Freezing the ViT (as required in our original setup) results in yet another error at the Conv2d call (`RuntimeError: GET was unable to find an engine to execute this computation`)

Interestingly, if we monkey-patch the `vit` instance such that `vit.forward = vit.forward_features` and call `self.vit(imgs)` in `Net.forward()`, all of these bugs disappear!
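
For reference, a minimal sketch of that workaround (using a hypothetical `PatchedNet` class that reuses `VisionTransformer` and `LinearProjector` from the repro above). The instance attribute shadows the class-level `forward`, so `self.vit(imgs)` still goes through `nn.Module.__call__`, which is where FSDP registers its pre/post-forward hooks:

class PatchedNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.vit = VisionTransformer()
        self.projector = LinearProjector(in_dim=1024, out_dim=256)

        # Shadow the class-level forward on this instance (before FSDP wrapping)
        # so that __call__ dispatches to forward_features.
        self.vit.forward = self.vit.forward_features

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        patch_embeddings = self.vit(imgs)  # FSDP's all-gather hooks fire here
        return self.projector(patch_embeddings)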

Versions

PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1033-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 525.85.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   46 bits physical, 48 bits virtual
CPU(s):                          96
On-line CPU(s) list:             0-95
Thread(s) per core:              2
Core(s) per socket:              24
Socket(s):                       2
NUMA node(s):                    2
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           85
Model name:                      Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping:                        7
CPU MHz:                         2999.998
BogoMIPS:                        5999.99
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       1.5 MiB
L1i cache:                       1.5 MiB
L2 cache:                        48 MiB
L3 cache:                        71.5 MiB
NUMA node0 CPU(s):               0-23,48-71
NUMA node1 CPU(s):               24-47,72-95
Vulnerability Itlb multihit:     KVM: Mitigation: VMX unsupported
Vulnerability L1tf:              Mitigation; PTE Inversion
Vulnerability Mds:               Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Mmio stale data:   Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:          Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] numpy==1.25.2
[pip3] torch==2.0.1
[pip3] torchaudio==2.0.2
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] blas                      2.116                       mkl    conda-forge
[conda] blas-devel                3.9.0            16_linux64_mkl    conda-forge
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libblas                   3.9.0            16_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            16_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            16_linux64_mkl    conda-forge
[conda] liblapacke                3.9.0            16_linux64_mkl    conda-forge
[conda] mkl                       2022.1.0           h84fe81f_915    conda-forge
[conda] mkl-devel                 2022.1.0           ha770c72_916    conda-forge
[conda] mkl-include               2022.1.0           h84fe81f_915    conda-forge
[conda] numpy                     1.25.2          py310ha4c1d20_0    conda-forge
[conda] pytorch                   2.0.1           py3.10_cuda11.8_cudnn8.7.0_0    pytorch
[conda] pytorch-cuda              11.8                 h7e8668a_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.0.2               py310_cu118    pytorch
[conda] torchtriton               2.0.0                     py310    pytorch
[conda] torchvision               0.15.2              py310_cu118    pytorch

cc @zhaojuanmao @mrshenli @rohan-varma @awgu @fegin @penguinwu

awgu commented Sep 15, 2023

Hi @siddk. This is a known limitation of FSDP. Our design relies on `nn.Module.forward()` to designate the compute for which FSDP should all-gather parameters. This issue has shown up previously for HuggingFace's `generate()` methods.

If you have a way to work around this (e.g., monkey patching), then that would be the shortest path for now.

@awgu added the `triaged` and `module: fsdp` labels Sep 15, 2023
siddk commented Sep 15, 2023

Awesome - thanks @awgu! Luckily, I found the monkey-patching workaround just as I was writing up the minimal example for this bug report... it would be great to add this to the docs somewhere so others don't fall into the same trap!

@lukaemon

Facing the same problem today. Thanks for opening the issue and for the discussion.

awgu added a commit that referenced this issue May 2, 2024
… fwd methods"


FSDP only runs its pre/post-forward hooks on `nn.Module.forward`. This means that if the user runs a custom method meant as a forward pass, then FSDP will not all-gather the parameters. Examples include HuggingFace models' `generate()` (#123962, #100069) or others (#109385).

This PR adds a monkey patching API to allow FSDP pre/post-forward hooks to run on the method.

cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

pytorchmergebot pushed a commit that referenced this issue May 3, 2024
#125394)

FSDP only runs its pre/post-forward hooks on `nn.Module.forward`. This means that if the user runs a custom method meant as a forward pass, then FSDP will not all-gather the parameters. Examples include HuggingFace models' `generate()` (#123962, #100069) or others (#109385).

This PR adds a monkey patching API `register_fsdp_forward_method(module: nn.Module, method_name: str)` to allow FSDP pre/post-forward hooks to run on the method. The function is a no-op if the passed-in `module` is not an FSDP module so that the register function can be called even if the FSDP wrapping changes.

Pull Request resolved: #125394
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
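
For anyone landing here later, a minimal usage sketch of the API described above, layered on top of the original repro (this assumes a PyTorch build that ships #125394, where `register_fsdp_forward_method` is exposed from `torch.distributed.fsdp`):

from torch.distributed.fsdp import register_fsdp_forward_method

net = FSDP(
    Net(),
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    device_id=torch.cuda.current_device(),
)
# Ask FSDP to run its pre/post-forward hooks (and thus the all-gather) when
# Net.forward calls vit.forward_features; no-op if net.vit is not FSDP-wrapped.
register_fsdp_forward_method(net.vit, "forward_features")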