@@ -5089,7 +5089,7 @@ def apply_fn(input1, input2, target, *params):
         if self.check_gradgrad:
             gradgradcheck(apply_fn, inputs)
 
-    def test_cuda(self, test_case, dtype=None, extra_args=None):
+    def test_cuda(self, test_case, dtype, extra_args=None):
         def convert_dtype(obj, dtype, requires_grad=False):
             if isinstance(obj, torch.Tensor):
                 return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
@@ -5107,12 +5107,11 @@ def convert_dtype(obj, dtype, requires_grad=False):
         gpu_module = self.constructor(*self.constructor_args)
 
         # Convert input, target and module parameters to dtype
-        if dtype is not None:
-            cpu_input = convert_dtype(cpu_input, dtype, True)
-            if cpu_target.is_floating_point() or cpu_target.is_complex():
-                cpu_target = convert_dtype(cpu_target, dtype)
-            cpu_module.type(dtype)
-            gpu_module.type(dtype)
+        cpu_input = convert_dtype(cpu_input, dtype, True)
+        if cpu_target.is_floating_point() or cpu_target.is_complex():
+            cpu_target = convert_dtype(cpu_target, dtype)
+        cpu_module.type(dtype)
+        gpu_module.type(dtype)
 
         # GPU setup
         gpu_input = to_gpu(cpu_input)
@@ -5128,13 +5127,14 @@ def convert_dtype(obj, dtype, requires_grad=False):
 
         cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args)
         gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args)
-        # dtype can be None, so set precision in this way instead of a precision map
+        # dtype used to be able to be None, so set precision in this way instead of a precision map
         # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
         test_case.assertEqualIgnoreType(cpu_output, gpu_output,
                                         atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0)
 
         cpu_gradInput = test_case._backward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args)
         gpu_gradInput = test_case._backward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args)
+        # dtype used to be able to be None, so set precision in this way instead of a precision map
         # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
         test_case.assertEqualIgnoreType(cpu_gradInput, gpu_gradInput,
                                         atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0)
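For reference, a minimal standalone sketch of the two patterns this change relies on: the `convert_dtype` helper (detach from any existing graph, cast, then re-enable gradients on the casted copy) and the dtype-keyed tolerance used now that `dtype` can no longer be `None`. The example input below is illustrative, not taken from the test suite.

```python
import torch

# Sketch of the convert_dtype helper from the diff: detach so the cast copy
# is not tied to a prior autograd graph, cast to the target dtype, then
# opt the new tensor back into gradient tracking if requested.
def convert_dtype(obj, dtype, requires_grad=False):
    if isinstance(obj, torch.Tensor):
        return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
    return obj

cpu_input = convert_dtype(torch.randn(4, 3), torch.half, requires_grad=True)

# Tolerance is keyed off the dtype rather than a precision map:
# low-precision dtypes (half, bfloat16) get a much looser atol.
atol = 1e-1 if cpu_input.dtype in {torch.half, torch.bfloat16} else 4e-4
print(cpu_input.dtype, atol)  # torch.float16 0.1
```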