
Jit Error with CUDA and FP16 -- identifier "aten_add_flat__1" is undefined #47138

Closed

erikwijmans opened this issue Oct 30, 2020 · 9 comments

erikwijmans commented Oct 30, 2020

🐛 Bug

When running a scripted module on a CUDA device with fp16, I get the following error when computing the backward pass:

RuntimeError: default_program(59): error: identifier "aten_add_flat__1" is undefined

default_program(60): error: no operator "=" matches these operands
            operand types are: half = float

To Reproduce

Steps to reproduce the behavior:

Run python repro.py, where repro.py has the following contents:

import torch
from torch import nn
from torch.nn import functional as F

dtype = torch.float16
device = torch.device("cuda", 0)


class MockSEFixupBasicBlock(nn.Module):
    def __init__(self, inplanes, planes):
        super().__init__()

        self.fixup_bias2a = nn.Parameter(torch.zeros(1))
        self.fixup_scale = nn.Parameter(torch.ones(1))
        self.fixup_bias2b = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        identity = x
        out = x

        out = out + self.fixup_bias2a
        out = out * self.fixup_scale + self.fixup_bias2b

        return out * out + identity


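# Scripting the module lets the JIT fuse the chain of elementwise ops in
# forward() into a single CUDA kernel once it runs with fp16 inputs on the GPU.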
net = torch.jit.script(MockSEFixupBasicBlock(64, 64)).to(dtype=dtype, device=device)

inp = torch.randn(16, 64, 16, 16, dtype=dtype, device=device)

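# Several iterations are needed: the profiling executor only swaps in the
# fused kernel after a warm-up run or two, which matches the crash appearing
# on the 2nd or 3rd iteration (noted below).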
for i in range(10):
    for param in net.parameters():
        param.grad = None

    print(i)
    net(inp).mean().backward()

I have minified this as much as I can (I started with the actual module from my network and removed everything that didn't cause the error).

Running this produces the following:

0
1
Traceback (most recent call last):
  File "repro.py", line 36, in <module>
    net(inp).mean().backward()
  File "/private/home/erikwijmans/miniconda3/envs/v4r/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
RuntimeError: default_program(59): error: identifier "aten_add_flat__1" is undefined

default_program(60): error: no operator "=" matches these operands
            operand types are: half = float

2 errors detected in the compilation of "default_program".

nvrtc compilation failed: 

#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)


template<typename T>
__device__ T maximum(T a, T b) {
  return isnan(a) ? a : (a > b ? a : b);
}

template<typename T>
__device__ T minimum(T a, T b) {
  return isnan(a) ? a : (a < b ? a : b);
}


#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
#if defined(__cplusplus)
  struct __align__(2) __half {
    __host__ __device__ __half() { }

  protected:
    unsigned short __x;
  };

  /* All intrinsic functions are only available to nvcc compilers */
  #if defined(__CUDACC__)
    /* Definitions of intrinsics */
    __device__ __half __float2half(const float f) {
      __half val;
      asm("{  cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(val)) : "f"(f));
      return val;
    }

    __device__ float __half2float(const __half h) {
      float val;
      asm("{  cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
      return val;
    }

  #endif /* defined(__CUDACC__) */
#endif /* defined(__cplusplus) */
#undef __HALF_TO_US
#undef __HALF_TO_CUS

typedef __half half;

extern "C" __global__
void func_1(half* t0, half* t1, half* t2, half* t3, half* aten_add_flat, half* aten_add_flat_1, half* aten_add_flat_2) {
{
  float t3_ = __half2float(t3[0]);
  float v = __half2float(t0[16 * ((512 * blockIdx.x + threadIdx.x) / 16) + (512 * blockIdx.x + threadIdx.x) % 16]);
  aten_add_flat_2[512 * blockIdx.x + threadIdx.x] = __float2half(v + t3_);
  float t2_ = __half2float(t2[0]);
  float t1_ = __half2float(t1[0]);
  float aten_add_flat_ = __half2float(aten_add_flat_1[512 * blockIdx.x + threadIdx.x]);
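  // The two nvrtc errors above point at the next two lines: "aten_add_flat__1"
  // is assigned without ever being declared, and the float "aten_add_flat_" is
  // stored back into the half buffer (the "half = float" operand mismatch).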
  aten_add_flat__1 = __float2half((__half2float(t0[16 * ((512 * blockIdx.x + threadIdx.x) / 16) + (512 * blockIdx.x + threadIdx.x) % 16]) + t3_) * t2_ + t1_);
  aten_add_flat_1[512 * blockIdx.x + threadIdx.x] = aten_add_flat_;
  float v_1 = __half2float(t0[16 * ((512 * blockIdx.x + threadIdx.x) / 16) + (512 * blockIdx.x + threadIdx.x) % 16]);
  float v_2 = __half2float(t0[16 * ((512 * blockIdx.x + threadIdx.x) / 16) + (512 * blockIdx.x + threadIdx.x) % 16]);
  float v_3 = __half2float(t0[16 * ((512 * blockIdx.x + threadIdx.x) / 16) + (512 * blockIdx.x + threadIdx.x) % 16]);
  aten_add_flat[512 * blockIdx.x + threadIdx.x] = __float2half(((v_1 + t3_) * t2_ + t1_) * ((v_2 + t3_) * t2_ + t1_) + v_3);
}
}

The exact number of iterations it runs for before erroring seems to be somewhat stochastic, but I have never seen it error on the first iteration and have only seen it error on the 2nd or 3rd.

With slightly different variants of this, I have also seen aten_mul_flat__1 reported as undefined. I assume the root cause is the same, but thought I would point it out.

Expected behavior

Does not crash

Environment

  • PyTorch Version (e.g., 1.0): 1.7.0-py3.6_cuda10.1.243_cudnn7.6.3_0
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): N/A
  • Python version: 3.6
  • CUDA/cuDNN version: 10.1/7603
  • GPU models and configuration: Quadro GP100 and Tesla V100
  • Any other relevant information: Works on the same system(s) with 1.6.0-py3.6_cuda10.1.243_cudnn7.6.3_0

cc @gmagogsfm

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Oct 30, 2020
@github-actions github-actions bot added this to Need triage in JIT Triage Oct 30, 2020
@erikwijmans erikwijmans changed the title Jit Error with CUDA and FP16 -- identifier "aten_mul_flat__1" is undefined Jit Error with CUDA and FP16 -- identifier "aten_add_flat__1" is undefined Oct 31, 2020
nickgg (Contributor) commented Nov 2, 2020

This looks like an issue with the variable uniquing rules in the TE compiler, I'll take a look.

nickgg (Contributor) commented Nov 2, 2020

@erikwijmans I think I should have fixed this in #47229; at least, I used your repro as the test case for it (thanks for the minimal repro, btw!). Please let me know if the issue persists and I'll jump back on it.

@SplitInfinity SplitInfinity moved this from Need triage to In progress in JIT Triage Nov 2, 2020
erikwijmans (Author) commented

@nickgg Glad the repro was helpful! I built from scratch on the branch from the PR and this fixes it (both the repro and the full model). Thank you!

JIT Triage automation moved this from In progress to Done Nov 4, 2020
facebook-github-bot pushed a commit that referenced this issue Nov 6, 2020
…47448)

Summary:
Take 2 of this fix: I removed the repro from the issue, which is a bit flaky due to parallelism. It broke on Windows but isn't specific to Windows or this fix, I think. I'll make sure all the tests pass this time (cc zou3519).

Fixes an issue where fp16 scalars created by the registerizer could be referenced as floats, causing invalid conversions which would crash the NVRTC compile. I also noticed that we were inserting patterns like float(half(float(X))) and added a pass to collapse those down inside the CudaHalfScalarRewriter.

Fixes #47138

Pull Request resolved: #47448

Reviewed By: glaringlee

Differential Revision: D24765070

Pulled By: nickgg

fbshipit-source-id: 5297e647534d53657bef81f4798e8aa6a93d1fbd
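
As an aside for readers following the fix: the float(half(float(X))) collapse described in the commit summary can be illustrated with a small, self-contained sketch. This is not the actual CudaHalfScalarRewriter (which lives in PyTorch's C++ tensor-expression compiler); the types and names below are hypothetical, and the rewrite is shown over a toy cast-expression tree. The collapse is exact when X is already half, because converting a half value to float and back to half loses nothing:

# Toy illustration of collapsing a redundant half/float cast round trip,
# in the spirit of the fix above. NOT PyTorch's actual rewriter; all names
# here are hypothetical.
from dataclasses import dataclass
from typing import Union

@dataclass
class Var:
    name: str
    dtype: str  # "half" or "float"

@dataclass
class Cast:
    dtype: str
    src: "Expr"

Expr = Union[Var, Cast]

def collapse_casts(e: Expr) -> Expr:
    """Rewrite float(half(float(x))) -> float(x) when x is already half."""
    if isinstance(e, Cast):
        src = collapse_casts(e.src)
        if (
            e.dtype == "float"
            and isinstance(src, Cast) and src.dtype == "half"
            and isinstance(src.src, Cast) and src.src.dtype == "float"
            and isinstance(src.src.src, Var) and src.src.src.dtype == "half"
        ):
            # half -> float -> half is lossless, so drop the inner round trip.
            return Cast("float", src.src.src)
        return Cast(e.dtype, src)
    return e

x = Var("t0_val", "half")
expr = Cast("float", Cast("half", Cast("float", x)))
print(collapse_casts(expr))
# Cast(dtype='float', src=Var(name='t0_val', dtype='half'))
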
nikhilmishra000 commented

I am still having this issue:

  • I have torch 1.7.1 on ubuntu 18.04 / python 3.7 / cuda 11.0 (installed via pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html)
  • It fails consistently using the repro.py script from above

erikwijmans (Author) commented

The fix for this wasn't cherry-picked into 1.7.1. It should be in the nightly builds, though.

nikhilmishra000 commented

Oh, got it -- will it be included in 1.8? Do you have a ballpark estimate of when that will be released?

erikwijmans (Author) commented

I am not on the PyTorch team, so I don't know the exact details, but since the fix is in master I assume it will be. I don't have a guess for the 1.8 release timeline, though.

ArijRB commented Jan 8, 2021

I have the same issue; I have pytorch 1.7.1 py3.8_cuda10.2.89_cudnn7.6.5_0.

It fails consistently using the repro.py script from above. Thank you in advance.

ArijRB commented Jan 13, 2021

Hello, any updates on this issue? Thank you in advance.
@nickgg
