result of 2 ** s differs between eager mode and inductor + triton + cuda when in float32 denormal range #125557

Open
vkuzo opened this issue May 5, 2024 · 1 comment
Labels: oncall: pt2, triaged, upstream triton


vkuzo (Contributor) commented May 5, 2024

🐛 Describe the bug

When I calculate 2.0 ** s on CUDA for values of s small enough that the result falls in the float32 denormal (subnormal) range, PyTorch eager mode returns the correct denormalized floating-point number, while torch.compile + inductor + triton returns 0.0. Is this expected?

Repro:

import torch


def test_triton_debug():
    for s_val in (
        # boundary between normal and denormal
        -126, -127,
        # boundary between denormal and not representable
        -149, -150,
    ):
        s = torch.tensor([s_val], dtype=torch.float, device='cuda')

        def calculate(s):
            return 2.0 ** s

        calculate_c = torch.compile(calculate)

        y_ref = calculate(s)
        y_c = calculate_c(s)

        print(s_val)
        print(y_ref)
        print(y_c)
        print()

    # Observed results:
    # s = -126 (last normal):        eager 1.1755e-38, compile 1.1755e-38
    # s = -127 (first denormal):     eager 5.8775e-39, compile 0.0
    # s = -149 (last denormal):      eager 1.4013e-45, compile 0.0
    # s = -150 (not representable):  eager 0.0,        compile 0.0


test_triton_debug()
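For reference, the eager results in the comments above are exactly 2**s rounded to float32, and the float32 bit fields show why -127 and -149 are the denormal boundaries. A minimal CPU-only sketch using only the Python standard library; the helpers `to_f32` and `f32_fields` are hypothetical names for illustration, not PyTorch APIs:

```python
import math
import struct


def to_f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest float32 (round-half-to-even)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def f32_fields(x: float) -> tuple[int, int]:
    """Return (biased exponent field, mantissa field) of x after rounding to float32."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return (bits >> 23) & 0xFF, bits & 0x7FFFFF


for s in (-126, -127, -149, -150):
    exact = math.ldexp(1.0, s)  # 2.0 ** s, exact in float64
    y = to_f32(exact)           # correctly rounded float32 result
    exp_field, mantissa = f32_fields(exact)
    # exponent field 0 means denormal (or zero); nonzero means a normal number
    kind = "normal" if exp_field else ("denormal" if mantissa else "zero")
    print(f"{s}: {y:.4e} exp={exp_field} mantissa={mantissa:#08x} ({kind})")
```

The printed values (1.1755e-38, 5.8775e-39, 1.4013e-45, 0.0000e+00) match the eager column. -150 rounds to zero because 2**-150 lies exactly halfway between 0 and the smallest denormal, and round-half-to-even picks 0.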

Looking at the generated Triton code (https://gist.github.com/vkuzo/63a35ea68a58721f40806a32af04435d), the relevant line appears to be:

tmp2 = libdevice.exp2(tmp1)

This NVIDIA developer forum thread (https://forums.developer.nvidia.com/t/more-accurate-version-of-exp2f-with-no-change-in-performance/243209) hints that this CUDA function does not produce results in the subnormal range. What is the right behavior on the PyTorch side? Can we do better than today's silently diverging results?
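One hypothesis consistent with the observed numbers is that the compiled exp2 flushes denormal results to zero (FTZ). The sketch below models that hypothesis on the CPU with a hypothetical `exp2_f32(s, ftz)` helper restricted to integer s. It reproduces the pattern of the eager and compile results, but it models the symptom only; it is not the actual libdevice implementation.

```python
import math
import struct

FLT_MIN = math.ldexp(1.0, -126)  # smallest normal float32 (2**-126)


def to_f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest float32."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def exp2_f32(s: int, ftz: bool) -> float:
    """Correctly rounded float32 2**s for integer s.

    ftz=True is a hypothetical flush-to-zero model: any denormal result
    (nonzero but below FLT_MIN) is replaced by 0.0.
    """
    y = to_f32(math.ldexp(1.0, s))  # exact 2**s, then rounded to float32
    if ftz and 0.0 < y < FLT_MIN:   # 2**s is never negative, so no sign handling
        return 0.0
    return y


for s in (-126, -127, -149, -150):
    print(s, f"no-ftz {exp2_f32(s, ftz=False):.4e}", f"ftz {exp2_f32(s, ftz=True):.4e}")
```

With ftz=False the model matches eager mode for all four values; with ftz=True it returns 0.0 for -127 and -149, matching the compiled results.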

Versions

https://gist.github.com/vkuzo/b5f49e49e7bb30e4cc59d95a0fd4249f

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

lezcano (Collaborator) commented May 6, 2024

Related: triton-lang/triton#3792

@yanboliang added the triaged and upstream triton labels on May 6, 2024