
Commit 1097fe0

gchanan authored and facebook-github-bot committed
Remove CriterionTest.test_cuda code for dtype None. (#45316)
Summary:
Pull Request resolved: #45316

It's never used.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D23919449

Pulled By: gchanan

fbshipit-source-id: f9aaeeabf3940389156bfc01bc3118d348ca4cf6
1 parent a4486fe commit 1097fe0

File tree

1 file changed: +8 -8 lines changed


torch/testing/_internal/common_nn.py

Lines changed: 8 additions & 8 deletions
@@ -5089,7 +5089,7 @@ def apply_fn(input1, input2, target, *params):
         if self.check_gradgrad:
             gradgradcheck(apply_fn, inputs)

-    def test_cuda(self, test_case, dtype=None, extra_args=None):
+    def test_cuda(self, test_case, dtype, extra_args=None):
         def convert_dtype(obj, dtype, requires_grad=False):
             if isinstance(obj, torch.Tensor):
                 return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
@@ -5107,12 +5107,11 @@ def convert_dtype(obj, dtype, requires_grad=False):
         gpu_module = self.constructor(*self.constructor_args)

         # Convert input, target and module parameters to dtype
-        if dtype is not None:
-            cpu_input = convert_dtype(cpu_input, dtype, True)
-            if cpu_target.is_floating_point() or cpu_target.is_complex():
-                cpu_target = convert_dtype(cpu_target, dtype)
-            cpu_module.type(dtype)
-            gpu_module.type(dtype)
+        cpu_input = convert_dtype(cpu_input, dtype, True)
+        if cpu_target.is_floating_point() or cpu_target.is_complex():
+            cpu_target = convert_dtype(cpu_target, dtype)
+        cpu_module.type(dtype)
+        gpu_module.type(dtype)

         # GPU setup
         gpu_input = to_gpu(cpu_input)
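
For context, with dtype now a required argument the conversion block above runs unconditionally. Below is a minimal standalone sketch of the torch.Tensor branch of convert_dtype as it appears in this diff; the example tensor and the printed result are illustrative only and are not part of the test suite.

    import torch

    def convert_dtype(obj, dtype, requires_grad=False):
        # Detach, cast to the requested dtype, then re-enable autograd if asked.
        if isinstance(obj, torch.Tensor):
            return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
        return obj  # non-tensor objects pass through in this trimmed sketch

    cpu_input = convert_dtype(torch.randn(4, 3), torch.half, requires_grad=True)
    print(cpu_input.dtype, cpu_input.requires_grad)  # torch.float16 True
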
@@ -5128,13 +5127,14 @@ def convert_dtype(obj, dtype, requires_grad=False):

         cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args)
         gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args)
-        # dtype can be None, so set precision in this way instead of a precision map
+        # dtype used to be able to be None, so set precision in this way instead of a precision map
         # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
         test_case.assertEqualIgnoreType(cpu_output, gpu_output,
                                         atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0)

         cpu_gradInput = test_case._backward_criterion(cpu_module, cpu_input, cpu_target, extra_args=extra_args)
         gpu_gradInput = test_case._backward_criterion(gpu_module, gpu_input, gpu_target, extra_args=extra_args)
+        # dtype used to be able to be None, so set precision in this way instead of a precision map
         # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
         test_case.assertEqualIgnoreType(cpu_gradInput, gpu_gradInput,
                                         atol=1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4, rtol=0)
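
The assertions above keep the inline tolerance expression rather than introducing a per-dtype precision map; with dtype no longer allowed to be None, the expression always sees a concrete dtype. A small sketch of the same selection logic, assuming it were factored into a helper (criterion_atol is a hypothetical name, not part of common_nn.py):

    import torch

    def criterion_atol(dtype):
        # Looser absolute tolerance for reduced-precision dtypes,
        # tighter tolerance otherwise, mirroring the assertions above.
        return 1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4

    print(criterion_atol(torch.half))   # 0.1
    print(criterion_atol(torch.float))  # 0.0004
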
