[Inductor] atomic_add does not support bf16 #97016

Open
JonasGeiping opened this issue Mar 17, 2023 · 11 comments
Labels
feature A request for a proper, new feature. module: inductor oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@JonasGeiping

JonasGeiping commented Mar 17, 2023

🐛 Describe the bug

This may be known already, but Triton does not support atomic_add with bf16; see https://github.com/openai/triton/blob/c9740f0870f6ae2480acd2a76a5fb4c920bc5ce5/python/triton/language/semantic.py#L904.

This is not a problem in eager mode, only with torch.compile as it currently works; ideally, this op should not be selected in that case?

I made a minified repro below, but there is probably an even easier way to replicate this; I'm just unsure how exactly to trigger atomic_add.
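
For example, a scatter-style accumulation such as index_put_ with accumulate=True on a bf16 CUDA tensor should hit this path under torch.compile (a minimal sketch distilled from the minified repro below, not independently verified):

import torch

def scatter_accumulate(dst, idx, src):
    # accumulate=True is what shows up as aten.index_put.default(..., True)
    # in the repro below and lowers to tl.atomic_add in Inductor
    return dst.index_put_((idx,), src, accumulate=True)

dst = torch.zeros(512, 768, device="cuda", dtype=torch.bfloat16)
idx = torch.randint(0, 512, (80,), device="cuda")
src = torch.randn(80, 768, device="cuda", dtype=torch.bfloat16)

scatter_accumulate(dst, idx, src)                 # eager mode: works
torch.compile(scatter_accumulate)(dst, idx, src)  # inductor: generates tl.atomic_add on bf16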

Error logs

    raise ValueError("atomic_" + op + " does not support " + str(element_ty))
ValueError: atomic_add does not support bf16

The above exception was the direct cause of the following exception:
triton.compiler.CompilationError: at 11:85:
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 61440
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 768)
    x2 = xindex
    x0 = xindex % 768
    tmp0 = tl.load(in_ptr0 + (x1), None)
    tmp1 = tl.load(in_ptr1 + (x2), None).to(tl.float32)
    tl.atomic_add(out_ptr0 + (x0 + (768*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, None)
                                                                                     ^

Minified repro

import torch._inductor.overrides

import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
torch._dynamo.config.load_config(b'\x80\x02}q\x00(X\x0b\x00\x00\x00output_codeq\x01\x89X\r\x00\x00\x00log_file_nameq\x02NX\x07\x00\x00\x00verboseq\x03\x89X\x11\x00\x00\x00output_graph_codeq\x04\x89X\x12\x00\x00\x00verify_correctnessq\x05\x89X\x12\x00\x00\x00minimum_call_countq\x06K\x01X\x15\x00\x00\x00dead_code_eliminationq\x07\x88X\x10\x00\x00\x00cache_size_limitq\x08K@X\x0e\x00\x00\x00specialize_intq\t\x88X\x0e\x00\x00\x00dynamic_shapesq\n\x89X\x18\x00\x00\x00assume_static_by_defaultq\x0b\x89X\x10\x00\x00\x00guard_nn_modulesq\x0c\x89X\x1b\x00\x00\x00traceable_tensor_subclassesq\rc__builtin__\nset\nq\x0e]q\x0f\x85q\x10Rq\x11X\x0f\x00\x00\x00suppress_errorsq\x12\x89X\x15\x00\x00\x00replay_record_enabledq\x13\x89X \x00\x00\x00rewrite_assert_with_torch_assertq\x14\x88X\x12\x00\x00\x00print_graph_breaksq\x15\x89X\x07\x00\x00\x00disableq\x16\x89X*\x00\x00\x00allowed_functions_module_string_ignorelistq\x17h\x0e]q\x18(X\x0b\x00\x00\x00torch._refsq\x19X\x0c\x00\x00\x00torch._primsq\x1aX\x13\x00\x00\x00torch.distributionsq\x1bX\r\x00\x00\x00torch._decompq\x1cX\r\x00\x00\x00torch.testingq\x1de\x85q\x1eRq\x1fX\x12\x00\x00\x00repro_forward_onlyq \x89X\x0f\x00\x00\x00repro_toleranceq!G?PbM\xd2\xf1\xa9\xfcX\x16\x00\x00\x00capture_scalar_outputsq"\x89X \x00\x00\x00capture_dynamic_output_shape_opsq#\x89X\x19\x00\x00\x00enforce_cond_guards_matchq$\x88X\x0c\x00\x00\x00optimize_ddpq%\x88X\x1a\x00\x00\x00raise_on_ctx_manager_usageq&\x88X\x1c\x00\x00\x00raise_on_unsafe_aot_autogradq\'\x89X\x17\x00\x00\x00raise_on_backend_changeq(\x89X\x18\x00\x00\x00error_on_nested_fx_traceq)\x88X\t\x00\x00\x00allow_rnnq*\x89X\x08\x00\x00\x00base_dirq+X;\x00\x00\x00/home/jonas/miniconda3/envs/dl/lib/python3.10/site-packagesq,X\x0e\x00\x00\x00debug_dir_rootq-XN\x00\x00\x00/home/jonas/Dropbox/Documents_Hyperion/Python/cramming-dev/torch_compile_debugq.X)\x00\x00\x00DO_NOT_USE_legacy_non_fake_example_inputsq/\x89X\x13\x00\x00\x00_save_config_ignoreq0h\x0e]q1(X!\x00\x00\x00skipfiles_inline_module_allowlistq2X\x12\x00\x00\x00constant_functionsq3X\x0b\x00\x00\x00repro_levelq4X\x0b\x00\x00\x00repro_afterq5e\x85q6Rq7u.')
torch._inductor.config.load_config(b'\x80\x02}q\x00(X\x05\x00\x00\x00debugq\x01\x89X\x10\x00\x00\x00disable_progressq\x02\x88X\x10\x00\x00\x00verbose_progressq\x03\x89X\x0b\x00\x00\x00cpp_wrapperq\x04\x89X\x03\x00\x00\x00dceq\x05\x89X\x14\x00\x00\x00static_weight_shapesq\x06\x88X\x0c\x00\x00\x00size_assertsq\x07\x88X\x10\x00\x00\x00pick_loop_ordersq\x08\x88X\x0f\x00\x00\x00inplace_buffersq\t\x88X\x11\x00\x00\x00benchmark_harnessq\n\x88X\x0f\x00\x00\x00epilogue_fusionq\x0b\x89X\x15\x00\x00\x00epilogue_fusion_firstq\x0c\x89X\x0f\x00\x00\x00pattern_matcherq\r\x88X\n\x00\x00\x00reorderingq\x0e\x89X\x0c\x00\x00\x00max_autotuneq\x0f\x89X\x15\x00\x00\x00search_autotune_cacheq\x10\x89X\x17\x00\x00\x00realize_reads_thresholdq\x11K\x04X\x17\x00\x00\x00realize_bytes_thresholdq\x12M\xd0\x07X\x1b\x00\x00\x00realize_acc_reads_thresholdq\x13K\x08X\x0f\x00\x00\x00fallback_randomq\x14\x89X\x12\x00\x00\x00implicit_fallbacksq\x15\x88X\x0b\x00\x00\x00tune_layoutq\x16\x89X\x11\x00\x00\x00aggressive_fusionq\x17\x89X\x0f\x00\x00\x00max_fusion_sizeq\x18K@X\x1b\x00\x00\x00unroll_reductions_thresholdq\x19K\x08X\x0e\x00\x00\x00comment_originq\x1a\x89X\x10\x00\x00\x00benchmark_kernelq\x1b\x89X\x12\x00\x00\x00developer_warningsq\x1c\x89X\x0f\x00\x00\x00compile_threadsq\x1dK\x10X\x11\x00\x00\x00global_cache_pathq\x1eNX\x13\x00\x00\x00kernel_name_max_opsq\x1fK\nX\r\x00\x00\x00shape_paddingq \x89X\x0e\x00\x00\x00permute_fusionq!\x89X\x1a\x00\x00\x00profiler_mark_wrapper_callq"\x89X\x18\x00\x00\x00_raise_error_for_testingq#\x89X\x0c\x00\x00\x00_profile_varq$X\x00\x00\x00\x00q%X\x11\x00\x00\x00profile_bandwidthq&\x89X\x17\x00\x00\x00profile_bandwidth_regexq\'h%X\x0b\x00\x00\x00cpp.threadsq(J\xff\xff\xff\xffX\x13\x00\x00\x00cpp.dynamic_threadsq)\x89X\x0b\x00\x00\x00cpp.simdlenq*NX\x12\x00\x00\x00cpp.min_chunk_sizeq+M\x00\x10X\x07\x00\x00\x00cpp.cxxq,NX\x03\x00\x00\x00g++q-\x86q.X\x19\x00\x00\x00cpp.enable_kernel_profileq/\x89X\x12\x00\x00\x00cpp.weight_prepackq0\x88X\x11\x00\x00\x00triton.cudagraphsq1\x89X\x17\x00\x00\x00triton.debug_sync_graphq2\x89X\x18\x00\x00\x00triton.debug_sync_kernelq3\x89X\x15\x00\x00\x00triton.dense_indexingq4\x89X\x10\x00\x00\x00triton.max_tilesq5K\x02X\x19\x00\x00\x00triton.autotune_pointwiseq6\x88X\'\x00\x00\x00triton.tiling_prevents_pointwise_fusionq7\x88X\'\x00\x00\x00triton.tiling_prevents_reduction_fusionq8\x88X\x1b\x00\x00\x00triton.ordered_kernel_namesq9\x89X\x1f\x00\x00\x00triton.descriptive_kernel_namesq:\x89X\x1c\x00\x00\x00triton.persistent_reductionsq;\x88X\x10\x00\x00\x00triton.max_blockq<}q=(X\x01\x00\x00\x00Xq>M\x00\x08X\x01\x00\x00\x00Yq?M\x00\x04X\x01\x00\x00\x00Zq@M\x00\x04uX\r\x00\x00\x00trace.enabledqA\x89X\x0f\x00\x00\x00trace.debug_logqB\x88X\x0e\x00\x00\x00trace.info_logqC\x89X\x0e\x00\x00\x00trace.fx_graphqD\x88X\x1a\x00\x00\x00trace.fx_graph_transformedqE\x88X\x13\x00\x00\x00trace.ir_pre_fusionqF\x88X\x14\x00\x00\x00trace.ir_post_fusionqG\x88X\x11\x00\x00\x00trace.output_codeqH\x88X\x13\x00\x00\x00trace.graph_diagramqI\x89X\x15\x00\x00\x00trace.compile_profileqJ\x89X\x10\x00\x00\x00trace.upload_tarqKNu.')
torch._functorch.config.load_config(b'\x80\x02}q\x00(X\x11\x00\x00\x00use_functionalizeq\x01\x88X\x0f\x00\x00\x00use_fake_tensorq\x02\x88X\x16\x00\x00\x00fake_tensor_allow_metaq\x03\x88X\x0c\x00\x00\x00debug_assertq\x04\x88X\x14\x00\x00\x00debug_fake_cross_refq\x05\x89X\x11\x00\x00\x00debug_partitionerq\x06\x89X\x0c\x00\x00\x00debug_graphsq\x07\x89X\x0b\x00\x00\x00debug_jointq\x08\x89X\x14\x00\x00\x00static_weight_shapesq\t\x88X\x03\x00\x00\x00cseq\n\x88X\x10\x00\x00\x00max_dist_from_bwq\x0bK\x03X\t\x00\x00\x00log_levelq\x0cK\x14u.')

from torch.nn import *
class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    
    
    def forward(self, arg27_1, mm_1, full):
        index_put = torch.ops.aten.index_put.default(full, [arg27_1], mm_1, True);  full = arg27_1 = mm_1 = None
        return (index_put,)
        
args = [((80,), (1,), torch.int64, 'cuda'), ((80, 768), (768, 1), torch.bfloat16, 'cuda'), ((512, 768), (768, 1), torch.bfloat16, 'cuda')]
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
mod = make_fx(Repro(), tracing_mode='real')(*args)

from torch._inductor.compile_fx import compile_fx_inner
from torch._dynamo.debug_utils import same_two_models

compiled = compile_fx_inner(mod, args)
ref = compiled(args)
torch.cuda.synchronize() # Ensures that segfaults are surfaced

Versions

PyTorch version: 2.1.0.dev20230316
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Pop!_OS 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.9 (main, Mar  1 2023, 18:23:06) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.2.0-76060200-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.5.119
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3050 Ti Laptop GPU
Nvidia driver version: 525.85.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.1.0.dev20230316
[pip3] torchaudio==2.0.0.dev20230317
[pip3] torchvision==0.16.0.dev20230317
[conda] blas                      1.0                         mkl  
[conda] lion-pytorch              0.0.7                    pypi_0    pypi
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0           py310h7f8727e_0  
[conda] mkl_fft                   1.3.1           py310hd6ae3a3_0  
[conda] mkl_random                1.2.2           py310h00e6091_0  
[conda] numpy                     1.23.5          py310hd5efca6_0  
[conda] numpy-base                1.23.5          py310h8e6c178_0  
[conda] pytorch                   2.1.0.dev20230316 py3.10_cuda11.8_cudnn8.7.0_0    pytorch-nightly
[conda] pytorch-cuda              11.8                 h7e8668a_3    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] torchaudio                2.0.0.dev20230317     py310_cu118    pytorch-nightly
[conda] torchtriton               2.1.0+2c32f43999           py310    pytorch-nightly
[conda] torchvision               0.16.0.dev20230317     py310_cu118    pytorch-nightly

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @soumith @ngimel @Xia-Weiwen @desertfire

@Chillee
Contributor

Chillee commented Mar 21, 2023

I remember that bfloat16 atomics on older versions of CUDA were horrendously slow, so perhaps that's why Triton doesn't support them. @ngimel any thoughts on this now?

@Chillee Chillee added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 21, 2023
@peterbell10
Collaborator

According to the CUDA programming guide, there is an atomicAdd for bfloat16; however:

The 16-bit __nv_bfloat16 floating-point version of atomicAdd() is only supported by devices of compute capability 8.x and higher.

Whereas Triton targets 7.0+. PyTorch works around this limitation with an atomicCAS loop.
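
For anyone curious, here is a rough, non-atomic Python illustration of what that fallback has to do (the real thing is an atomicCAS loop in PyTorch's CUDA headers; this sketch only shows how the 16-bit element is rewritten inside its containing 32-bit word, with the actual compare-and-swap and retry omitted):

import struct

def bf16_bits_to_f32(bits):
    # bfloat16 is just the top 16 bits of a float32
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

def f32_to_bf16_bits(x):
    # truncating conversion; the real kernel also handles rounding
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16

def cas_style_bf16_add(buf, byte_offset, value):
    # Hardware CAS operates on aligned 32-bit words, so the 16-bit element is
    # updated by rewriting only its half of the containing word.
    word_offset = byte_offset & ~3          # align down to the 32-bit word
    high_half = bool(byte_offset & 2)       # which half holds our element (little-endian)
    old_word = struct.unpack_from("<I", buf, word_offset)[0]
    old_bits = (old_word >> 16) if high_half else (old_word & 0xFFFF)
    new_bits = f32_to_bf16_bits(bf16_bits_to_f32(old_bits) + value)
    if high_half:
        new_word = (new_bits << 16) | (old_word & 0xFFFF)
    else:
        new_word = (old_word & 0xFFFF0000) | new_bits
    # On the GPU this store would be atomicCAS(addr, old_word, new_word),
    # retried in a loop until the word has not changed underneath us.
    struct.pack_into("<I", buf, word_offset, new_word)

buf = bytearray(4)               # two bf16 zeros packed into one 32-bit word
cas_style_bf16_add(buf, 2, 1.5)  # add 1.5 to the second element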

@ngimel
Collaborator

ngimel commented Mar 22, 2023

Triton has functions that work on 8.0+ only; atomic bf16 can be one of those.

@peterbell10
Collaborator

Ah okay, I could have a go at adding it to Triton if you like?

@ngimel
Collaborator

ngimel commented Mar 22, 2023

Yeah that would be great!

@peterbell10
Collaborator

I tried it out, but it looks like the PTX atom.add.bf16 requires 9.0+, despite atomicAdd in CUDA being supported on 8.0+. In those cases nvcc actually generates a CAS loop: https://gcc.godbolt.org/z/TG7j4Kjbs

@daadaada

FYI: I'm adding tl.atomic_add for bf16: triton-lang/triton#1689

@zjjott
Contributor

zjjott commented Aug 3, 2023

+1

@quancs

quancs commented Sep 27, 2023

Similar error:

triton.compiler.CompilationError: at 11:85:
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 5120
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 40
    x2 = xindex
    x1 = (xindex // 40)
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x2), xmask).to(tl.float32)
    tl.atomic_add(out_ptr0 + (x1 + (128*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
                                                                                     ^

@plotfi

plotfi commented Nov 11, 2023

I have made some progress here. Working on it at: https://github.com/plotfi/triton/commits/plotfi-atomic-add-bf16

I want to make it so that if you are on Hopper it will use the native bf16 atomic_add, and use the fallback for Ampere.
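
A user-level version of that dispatch would just be a compute-capability check (a hypothetical helper for illustration, not how the Triton patch itself is structured):

import torch

def has_native_bf16_atomic_add():
    # PTX atom.add on bf16 needs sm_90 (Hopper); Ampere (sm_80) would take the CAS fallback.
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9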

@eellison eellison added the feature A request for a proper, new feature. label Nov 30, 2023
@bhack
Contributor

bhack commented Apr 2, 2024

What is the status of this? It seems that some triton upstream PRs were rejected?
The last one was triton-lang/triton#2708.
