
_layer_norm_fwd_1pass_kernel error #84

Open
chenwuchen opened this issue Dec 28, 2023 · 8 comments
chenwuchen commented Dec 28, 2023

Title: Error when running multi-GPU training with Mamba

Description:
I am experiencing an issue when running multi-GPU training with Mamba. Specifically, I am getting a TypeError: 'NoneType' object is not a mapping error when running the forward pass of the model. The error occurs when I try to run the model on multiple GPUs using the DataParallel module. However, when I run the model on a single GPU, everything works fine.

I have tried to reproduce the issue with a minimal example, but I was unable to do so. I have also checked the documentation and searched online for similar issues, but I couldn't find anything useful.

Here is the full traceback of the error:

Traceback (most recent call last):
File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
hidden_states, residual = layer(
File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
hidden_states, residual = fused_add_norm_fn(
File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
y, mean, rstd, residual_out = _layer_norm_fwd(
File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in run
timings = {config: self._bench(*args, config=config, **kwargs)
File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in
timings = {config: self._bench(*args, config=config, **kwargs)
File "/mnt/miniconda3/envs/mamba_env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 75, in _bench
full_nargs = {**self.nargs, **current}
TypeError: 'NoneType' object is not a mapping

I am using Python 3.10, PyTorch 1.12.1, causal_conv1d 1.1.1, mamba-ssm 1.1.1, and triton 2.1.0.
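
For reference, the relevant part of my multi-GPU setup looks roughly like this (a sketch rather than my actual training script; the checkpoint name and tensor shapes below are just illustrative):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Any mamba_ssm LM; "state-spaces/mamba-130m" is only an example checkpoint, not the one I use.
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)

# Fine without this wrapper (single GPU); fails as soon as more than one GPU is visible.
model = torch.nn.DataParallel(model)

input_ids = torch.randint(0, 50277, (8, 512), device="cuda")  # batch/seq_len/vocab are illustrative
logits = model(input_ids).logits  # the TypeError is raised inside the Triton autotuner here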


akkikiki commented Jan 5, 2024

@chenwuchen Bumped into the same error.
Solution: Use DDP or torchrun.
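
Roughly, with torchrun that looks like this (a sketch of the standard DistributedDataParallel pattern; the checkpoint name and shapes are placeholders):

# train_ddp.py -- launch with: torchrun --nproc_per_node=2 train_ddp.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# One process per GPU, so there is no shared, thread-unsafe autotuner state between workers.
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device=f"cuda:{local_rank}")
model = DDP(model, device_ids=[local_rank])

input_ids = torch.randint(0, 50277, (4, 512), device=f"cuda:{local_rank}")
logits = model(input_ids).logits  # forward works the same as on a single GPU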


mhamzaerol commented Feb 8, 2024

I was facing the same issue. Upon investigating further, I realized that line 75 of the autotuner.py file in the triton package receives self.nargs=None, which causes problems when building a dictionary from it:

(inside the _bench method)

full_nargs = {**self.nargs, **current}

My take is that this may:

  • be a bug inside the triton package (which should have handled the case self.nargs=None)
  • be an issue in this repo which sets self.nargs=None unintentionally
  • result from an issue/mismatch in the dependencies/environment (for instance, in my case I have CUDA 11.4, not CUDA 11.6)
  • be a more complex conflict between DataParallel and the triton package or this repository itself

But I was able to run DataParallel by replacing line 75 of autotuner.py with the following:

full_nargs = {}
if self.nargs:
    full_nargs.update(self.nargs)
if current:
    full_nargs.update(current)

Though I'm not sure whether training would behave as expected overall.


PheelaV commented Feb 29, 2024

I needed to do the same. I am executing the trainer script from mamba_chat and got the same triton error as above. I patched the package and installed it as follows.

Get the package

git clone https://github.com/openai/triton.git;
git -C triton checkout release/2.1.x;
pip install cmake;

Patch it by editing python/triton/runtime/autotuner.py at line 75

replacing

full_nargs = {**self.nargs, **current}

with

full_nargs = {}
if self.nargs:
    full_nargs.update(self.nargs)
if current:
    full_nargs.update(current)

Proceed to install the patched version:

cd triton/python;
pip install -e .

Install the rest of the mamba dependencies as per normal.

(I am using Python 3.11 in a conda environment, and currently have training running on a pair of RTX 3090s.)
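
To double-check that the patched build is the one actually being picked up (rather than a previously installed wheel), something like this should print 2.1.0 and a path inside the cloned triton/python tree instead of site-packages:

import triton
print(triton.__version__, triton.__file__)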


s22chan commented Mar 27, 2024

Hacking it that way could cause silent errors (especially if different args are passed concurrently to the jit).

Looks like DataParallel is multi-threaded and Triton doesn't appear to be thread-safe.

If you're confident the rest of the forward pass is thread-safe and you must run in threaded mode, you could try to run a single pass first to bootstrap the jit before running it in parallel. I don't think mamba has a config.pre_hook, which is the only thing that would be written per run after benching is complete.
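
Concretely, the warm-up could look something like this (a sketch assuming a mamba_ssm language model; it only pre-tunes the kernels hit by this particular shape/dtype, so the autotuner can still bench again from a worker thread if later calls look different):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda:0")

# Single-threaded pass first: the Triton autotuner benches and caches its configs here.
with torch.no_grad():
    model(torch.randint(0, 50277, (1, 512), device="cuda:0"))

# The threaded replicas created by DataParallel then reuse the cached configs.
model = torch.nn.DataParallel(model)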


AndssY commented Apr 8, 2024

@PheelaV Can you provide more details about how to build triton from source? Thanks!
I can't install triton with Python 3.8.5 in a conda environment.

.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/mlir/IR/Value.h:95:56: error: ‘void* __builtin_memset(void*, int, long unsigned int)’ specified size between 18446744039349813224 and 18446744073709551608 exceeds maximum object size 9223372036854775807 [-Werror=stringop-overflow=]
         95 |   constexpr Value(detail::ValueImpl *impl = nullptr) : impl(impl) {}
...
        270 |   iterator begin() { return (iterator)this->BeginX; }
            |                                       ~~~~~~^~~~~~
      At global scope:
      cc1plus: note: unrecognized command-line option ‘-Wno-covered-switch-default’ may have been intended to silence earlier diagnostics
      cc1plus: all warnings being treated as errors
      gmake[2]: *** [lib/Conversion/TritonGPUToLLVM/CMakeFiles/obj.TritonGPUToLLVM.dir/build.make:163: lib/Conversion/TritonGPUToLLVM/CMakeFiles/obj.TritonGPUToLLVM.dir/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp.o] Error 1
      gmake[2]: Leaving directory 'triton/python/build/cmake.linux-x86_64-cpython-3.8'
      gmake[1]: *** [CMakeFiles/Makefile2:2299: lib/Conversion/TritonGPUToLLVM/CMakeFiles/obj.TritonGPUToLLVM.dir/all] Error 2
      gmake[1]: Leaving directory 'triton/python/build/cmake.linux-x86_64-cpython-3.8'
      gmake: *** [Makefile:149: all] Error 2
...
ERROR: Could not build wheels for triton, which is required to install pyproject.toml-based projects


PheelaV commented Apr 8, 2024

@AndssY Sorry, I can't really; I was using Python 3.10 or 3.11 on Ubuntu 22.04 LTS, and following their readme instructions everything worked out. I think the only thing I had to do out of the ordinary was to check out the specific release branch required by the mamba dependencies (important).

But this whole thing became redundant. I think Mamba was patched, and everything suddenly started to work with just the mamba-ssm and causal-conv1d installs. I still kept the environment with the triton built from source to be safe, but it was no longer necessary for me.

Hope that helps at least a little bit. Good luck with getting it up and running. Feel free to DM me if you still struggle; I think I went through all the hoops one could hit.


AndssY commented Apr 8, 2024

...and following their readme instructions everything worked out. I think the only thing I had to do out of the ordinary was to check out the specific release branch required by the mamba dependencies (important).

@PheelaV Did you install according to the readme of mamba-chat? So mamba-ssm==1.0.1 and triton==release/2.1.x?

I will try Python 3.11 and install it again following the readme of mamba-chat. Thanks very much!


cgz6498 commented Oct 23, 2024

Has it been resolved?
