Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

AttributeError: 'MultiheadAttention' object has no attribute 'requires_grad' #111279

Closed
quancs opened this issue Oct 14, 2023 · 9 comments
Closed
Assignees
Labels
high priority module: regression It used to work, and now it doesn't oncall: distributed Add this issue/PR to distributed oncall triage queue oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@quancs
Copy link

quancs commented Oct 14, 2023

🐛 Describe the bug

Training a compiled model with DDP fails, but single-GPU training works.

CMD:

python code.py

code (code.py):

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

# Opt into faster reduced-precision float32 matmul kernels (e.g. TF32) where
# the GPU supports them; see the torch.set_float32_matmul_precision docs.
torch.set_float32_matmul_precision('high')


def setup(rank, world_size):
    """Join the default NCCL process group as ``rank`` out of ``world_size``.

    MASTER_ADDR / MASTER_PORT are exported first so that the rendezvous is
    performed at localhost:12355.
    """
    rendezvous = {'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'}
    for key, value in rendezvous.items():
        os.environ[key] = value

    # Create the default process group for this worker.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """Destroy the default process group created by ``setup``."""
    teardown = dist.destroy_process_group
    teardown()


class ToyModel(nn.Module):
    """Small model used to reproduce the DDP + ``torch.compile`` failure.

    It stacks two layers, each a LayerNorm followed by multi-head
    self-attention, with no residual connections. Before each layer, the
    first two input axes are folded into the batch, so attention runs along
    the third axis.
    """

    def __init__(self,):
        super().__init__()
        layers = []
        for l in range(2):
            # Each layer is a [LayerNorm, MultiheadAttention] pair. The feature
            # size 96 must match the last axis of the input given to forward().
            layer = nn.ModuleList([nn.LayerNorm(96), nn.MultiheadAttention(embed_dim=96, num_heads=4, batch_first=True)])
            layers.append(layer)
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both layers to a 4-D input.

        Args:
            x: 4-D tensor whose last dimension is 96.

        Returns:
            Tensor reshaped back to the (B, F, T, H) shape of ``x``.
        """
        # x: [Batch, Freq, Time, Feature]
        B, F, T, H = x.shape

        for m in self.layers:
            # Fold Batch and Freq together so attention runs over Time only.
            x = x.reshape(B * F, T, H)
            x = m[0](x)
            # Self-attention: query, key and value are all ``x``. Calling
            # ``.forward`` directly (instead of ``m[1](x, x, x)``) is what
            # triggers the bug. Dynamo records it as a ``get_attr`` on the
            # module plus a ``call_method``, and the DDPOptimizer in torch 2.1
            # assumed every ``get_attr`` target has ``requires_grad``. Keep this
            # form so the reproduction still fails. ``attn`` is unused.
            x, attn = m[1].forward(x, x, x)
            x = x.reshape(B, F, T, H)

        return x


def demo_basic(rank, world_size):
    """Run one DDP training step of a compiled ``ToyModel`` on ``rank``.

    This is the reproduction for "'MultiheadAttention' object has no attribute
    'requires_grad'". Compiling the model and then wrapping it in DDP routes
    the traced graph through Dynamo's DDP optimizer, which fails on torch
    2.1.x.

    Args:
        rank: index of this worker; also used as the CUDA device index.
        world_size: total number of worker processes.
    """
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    try:
        # Create the model and move it to the GPU with id ``rank``.
        model = ToyModel().to(rank)
        model.compile()  # commenting out this line makes training succeed
        ddp_model = DDP(model, device_ids=[rank])

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        optimizer.zero_grad()
        # The input is created on the CPU. This relies on DDP (with
        # device_ids=[rank]) to move it onto the GPU before forward.
        outputs = ddp_model(torch.randn(2, 129, 100, 96))
        labels = torch.randn(2, 129, 100, 96).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()
    finally:
        # Tear down the process group even when the compiled forward raises,
        # as it does in this reproduction. Otherwise the group is leaked.
        cleanup()

    print("success")


if __name__ == '__main__':
    # Launch one worker process per rank. mp.spawn passes each worker's index
    # as the first argument to demo_basic, followed by ``args``.
    world_size = 2
    mp.spawn(
        demo_basic,
        nprocs=world_size,
        args=(world_size,),
        join=True,
    )

The error reported:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 74, in _wrap
    fn(i, *args)
  File "/data/home/quancs/projects/NBSS_pmt/code2.py", line 61, in demo_basic
    outputs = ddp_model(torch.randn(2, 129, 100, 96))
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1516, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 487, in catch_errors
    return hijacked_callback(frame, cache_entry, hooks, frame_state)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2069, in run
    super().run()
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 719, in run
    and self.step()
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 683, in step
    getattr(self, inst.opname)(inst)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2157, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 833, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 957, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1024, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1009, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/_dynamo/backends/distributed.py", line 268, in compile_fn
    if maybe_param.requires_grad and not self._ignore_parameter(
  File "/data/home/quancs/miniconda3/envs/torch21/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
AttributeError: 'MultiheadAttention' object has no attribute 'requires_grad'

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Versions

Collecting environment information...
PyTorch version: 2.1.1
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.27.9
Libc version: glibc-2.35

Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-87-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 535.129.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz
CPU family: 6
Model: 106
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 6
CPU max MHz: 3400.0000
CPU min MHz: 800.0000
BogoMIPS: 5200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid fsrm md_clear pconfig flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 3 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 80 MiB (64 instances)
L3 cache: 96 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy==1.6.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.0
[pip3] pytorch-lightning==2.1.2
[pip3] pytorch-ranger==0.1.1
[pip3] torch==2.1.1
[pip3] torch-complex==0.4.3
[pip3] torch-optimizer==0.1.0
[pip3] torch-stoi==0.1.2
[pip3] torch-tb-profiler==0.4.3
[pip3] torchaudio==2.1.1
[pip3] torchmetrics==1.2.1
[pip3] torchtnt==0.2.1
[pip3] torchvision==0.16.1
[pip3] triton==2.1.0
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46343
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.8 py310h5eee18b_0
[conda] mkl_random 1.2.4 py310hdb19cb5_0
[conda] numpy 1.23.0 pypi_0 pypi
[conda] numpy-base 1.26.2 py310hb5e798b_0
[conda] pytorch 2.1.1 py3.10_cuda12.1_cudnn8.9.2_0 pytorch
[conda] pytorch-cuda 12.1 ha16c6d3_5 pytorch
[conda] pytorch-lightning 2.1.2 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch-ranger 0.1.1 pypi_0 pypi
[conda] torch-complex 0.4.3 pypi_0 pypi
[conda] torch-optimizer 0.1.0 pypi_0 pypi
[conda] torch-stoi 0.1.2 pypi_0 pypi
[conda] torch-tb-profiler 0.4.3 pypi_0 pypi
[conda] torchaudio 2.1.1 py310_cu121 pytorch
[conda] torchmetrics 1.2.1 pypi_0 pypi
[conda] torchtnt 0.2.1 pypi_0 pypi
[conda] torchtriton 2.1.0 py310 pytorch
[conda] torchvision 0.16.1 py310_cu121 pytorch

cc @ezyang @gchanan @zou3519 @kadeng @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @msaroufim @bdhirsh @anijain2305

@quancs
Copy link
Author

quancs commented Oct 14, 2023

issue related in pytorch-lightning: Lightning-AI/pytorch-lightning#18798

@janeyx99 janeyx99 added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Oct 16, 2023
@addtt
Copy link

addtt commented Nov 22, 2023

In case it helps: I don't use Lightning but Hugging Face Accelerate, and I have the same issue with multi-GPU training. It works fine with torch 2.0.1, but I get this error with 2.1.0 and 2.1.1.

You mentioned it works with strategy="auto", but I'm not sure what that means in Lightning; I couldn't find any information in the docs. Does that still allow parallel training? Have you found a workaround so far?

@quancs
Copy link
Author

quancs commented Nov 23, 2023

In case it helps: I don't use Lightning but Hugging Face Accelerate, and I have the same issue with multi-GPU training. It works fine with torch 2.0.1, but I get this error with 2.1.0 and 2.1.1.

You mentioned it works with strategy="auto", but I'm not sure what that means in Lightning; I couldn't find any information in the docs. Does that still allow parallel training? Have you found a workaround so far?

I guess "auto" means automatically choosing a strategy (parallel training or not) according to the number of GPUs specified, so it allows parallel training when more than one GPU is specified. I don't have a workaround yet.

@addtt
Copy link

addtt commented Nov 23, 2023

Have you tried the exact same setup with torch downgraded to 2.0.1 (and torchvision, triton, or whatever else you need downgraded accordingly)?

@quancs
Copy link
Author

quancs commented Nov 23, 2023

Yes. Torch==2.0.1 works fine.

@quancs quancs changed the title Cannot use compiled model together with the ddp strategy AttributeError: 'MultiheadAttention' object has no attribute 'requires_grad' Dec 24, 2023
@quancs
Copy link
Author

quancs commented Dec 24, 2023

@janeyx99 @addtt The code has been updated to use PyTorch only. Hopefully this helps speed up the fix for this bug.

@janeyx99 janeyx99 added high priority module: regression It used to work, and now it doesn't labels Dec 27, 2023
@janeyx99
Copy link
Contributor

Marking as hi-pri due to it being a regression, cc @wconstab

@bdhirsh
Copy link
Contributor

bdhirsh commented Dec 27, 2023

As a workaround to unblock yourself, you can set:

torch._dynamo.config.optimize_ddp = False

(tested locally)

@bdhirsh
Copy link
Contributor

bdhirsh commented Dec 27, 2023

This patch also fixes the error for me. I'm not sure if it handles the DDPOptimizer's bucketing logic properly though:

diff --git a/torch/_dynamo/backends/distributed.py b/torch/_dynamo/backends/distributed.py
index 024384a97c3..3928c81fba6 100644
--- a/torch/_dynamo/backends/distributed.py
+++ b/torch/_dynamo/backends/distributed.py
@@ -263,12 +263,13 @@ class DDPOptimizer:
                         buckets[0].param_ids.append(id(param))
             elif node.op == "get_attr":
                 maybe_param = getattr(gm, node.target)
-                if maybe_param.requires_grad and not self._ignore_parameter(
+                if isinstance(maybe_param, torch.nn.Parameter) and maybe_param.requires_grad and not self._ignore_parameter(
                     maybe_param
                 ):
                     buckets[0].size += maybe_param.untyped_storage().nbytes()

@yf225 yf225 self-assigned this Jan 9, 2024
@yf225 yf225 added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module and removed triage review labels Jan 9, 2024
@yf225 yf225 removed their assignment Feb 27, 2024
pytorchmergebot pushed a commit that referenced this issue Mar 13, 2024
This PR fixes Issue #111279.

While #111279 reported the issue with `MultiheadAttention`, a minimal reproduction would be:
```python
class ToyModel(nn.Module):
    """Smallest known trigger for the DDPOptimizer ``get_attr`` crash."""

    def __init__(self,):
        super().__init__()
        self.linear = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Calling the submodule's ``forward`` explicitly makes Dynamo emit a
        # ``get_attr`` + ``call_method`` pair. The unpatched DDPOptimizer
        # mishandles that pair and raises.
        return self.linear.forward(x) # Error
        # return self.linear(x) # OK: traced as a single ``call_module``
```

Dynamo treats `self.linear(x)` as `call_module` while treating `self.linear.forward(x)` as a [`get_attr` and a `call_method`](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/nn_module.py#L358-L378). However, existing DDPOptimizer assumes, for a `get_attr` node, `getattr(gm, node.target)` gives a tensor with the `requires_grad` attribute. Existing DDPOptimizer also does not support `call_method` nodes.

This PR adds support for `call_method` and check on `get_attr`. It also checks if a module's parameters have been added to a bucket to support multiple method calls from the same module.

Pull Request resolved: #121771
Approved by: https://github.com/yf225
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
high priority module: regression It used to work, and now it doesn't oncall: distributed Add this issue/PR to distributed oncall triage queue oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

No branches or pull requests

7 participants