torch.jit.script: RuntimeError: default_program(61): error: no suitable conversion function from "__half" to "float" exists #45953
Labels
oncall: jit
When I ran the DeepSpeed BERT example, I received the following error:
```
Traceback (most recent call last):
File "/home/ec2-user/DeepSpeedExamples/bing_bert/deepspeed_train.py", line 532, in
main()
File "/home/ec2-user/DeepSpeedExamples/bing_bert/deepspeed_train.py", line 525, in main
run(args, model, optimizer, start_epoch)
File "/home/ec2-user/DeepSpeedExamples/bing_bert/deepspeed_train.py", line 491, in run
train(args, index, model, optimizer, pretrain_dataset_provider)
File "/home/ec2-user/DeepSpeedExamples/bing_bert/deepspeed_train.py", line 166, in train
loss = model.network(batch)
File "/home/ec2-user/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/ec2-user/.local/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 743, in forward
loss = self.module(*inputs, **kwargs)
File "/home/ec2-user/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/ec2-user/DeepSpeedExamples/bing_bert/nvidia/modelingpreln.py", line 1144, in forward
sequence_output, pooled_output, masked_token_indexes)
File "/home/ec2-user/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/ec2-user/DeepSpeedExamples/bing_bert/nvidia/modelingpreln.py", line 769, in forward
masked_token_indexes)
File "/home/ec2-user/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/ec2-user/DeepSpeedExamples/bing_bert/nvidia/modelingpreln.py", line 721, in forward
hidden_states = self.transform(hidden_states)
File "/home/ec2-user/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/ec2-user/DeepSpeedExamples/bing_bert/nvidia/modelingpreln.py", line 701, in forward
hidden_states = self.dense_act(hidden_states)
File "/home/ec2-user/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/ec2-user/DeepSpeedExamples/bing_bert/nvidia/modelingpreln.py", line 252, in forward
return bias_gelu(self.bias, F.linear(input, self.weight, None))
RuntimeError: default_program(61): error: no suitable conversion function from "__half" to "float" exists
default_program(61): error: no suitable constructor exists to convert from "float" to "__half"
default_program(66): error: no suitable conversion function from "__half" to "float" exists
default_program(66): error: no suitable constructor exists to convert from "float" to "__half"
default_program(66): error: no operator "*" matches these operands
operand types are: __half * __half
5 errors detected in the compilation of "default_program".
nvrtc compilation failed:
#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)
template<typename T>
__device__ T maximum(T a, T b) {
return isnan(a) ? a : (a > b ? a : b);
}
template<typename T>
__device__ T minimum(T a, T b) {
return isnan(a) ? a : (a < b ? a : b);
}
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
#if defined(__cplusplus)
struct __align__(2) __half {
__host__ __device__ __half() { }
protected:
unsigned short __x;
};
/* All intrinsic functions are only available to nvcc compilers */
#if defined(__CUDACC__)
/* Definitions of intrinsics */
__device__ __half __float2half(const float f) {
__half val;
asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__HALF_TO_US(val)) : "f"(f));
return val;
}
#endif /* defined(__CUDACC__) */
#endif /* defined(__cplusplus) */
#undef __HALF_TO_US
#undef __HALF_TO_CUS
typedef __half half;
extern "C" global
void func_2(half* t0, half* t1, half* aten_mul_flat, half* aten_add_flat, half* aten_div_flat, half* aten_mul_flat_1) {
{
float v = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 1024]);
float v_1 = __half2float(t1[512 * blockIdx.x + threadIdx.x]);
aten_div_flat[512 * blockIdx.x + threadIdx.x] = __float2half((v + v_1) / 1.414209961891174f);
half aten_mul_flat = aten_mul_flat_1[512 * blockIdx.x + threadIdx.x];
aten_mul_flat = __float2half((__half2float(t0[(512 * blockIdx.x + threadIdx.x) % 1024]) + __half2float(t1[512 * blockIdx.x + threadIdx.x])) * 0.5f);
aten_mul_flat_1[512 * blockIdx.x + threadIdx.x] = aten_mul_flat;
float v_2 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 1024]);
float v_3 = __half2float(t1[512 * blockIdx.x + threadIdx.x]);
aten_add_flat[512 * blockIdx.x + threadIdx.x] = __float2half(__half2float(erff(__float2half((v_2 + v_3) / 1.414209961891174f))) + 1.f);
float v_4 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 1024]);
float v_5 = __half2float(t1[512 * blockIdx.x + threadIdx.x]);
float v_6 = __half2float(t0[(512 * blockIdx.x + threadIdx.x) % 1024]);
float v_7 = __half2float(t1[512 * blockIdx.x + threadIdx.x]);
aten_mul_flat[512 * blockIdx.x + threadIdx.x] = __float2half((v_4 + v_5) * 0.5f) * __float2half(__half2float(erff(__float2half((v_6 + v_7) / 1.414209961891174f))) + 1.f);
}
}
```
I found that this is caused by the `torch.jit.script` decorator; if I remove the decorator, everything works fine.
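For reference, here is a minimal sketch of what I believe triggers it, modeled on the `bias_gelu` pattern in `modelingpreln.py`. The exact module code differs; the tensor shapes, the `1.41421` constant, and the warm-up loop are my guesses based on the generated kernel above (512 x 1024, fp16):

```python
# Hypothetical standalone repro; requires a CUDA device with fp16 support.
import torch

@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

bias = torch.randn(1024, dtype=torch.float16, device="cuda")
y = torch.randn(512, 1024, dtype=torch.float16, device="cuda")

# The JIT fuser compiles the fused fp16 kernel only after a couple of
# profiling runs, which is when the nvrtc error above is raised.
for _ in range(3):
    out = bias_gelu(bias, y)
```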
I am using the PyTorch master build from Oct 4th, 2020.
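Besides removing the decorator, a possible workaround (untested here) would be to disable the tensor-expression fuser that emits these `default_program` kernels. These are internal, non-public toggles, so their availability may vary across builds:

```python
import torch

# Turn off the tensor-expression fuser so the op falls back to eager kernels.
torch._C._jit_set_texpr_fuser_enabled(False)
# Or disable GPU fusion in the JIT altogether:
torch._C._jit_override_can_fuse_on_gpu(False)
```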
cc @gmagogsfm