From fdd15f6c800888a8b702fff0fa44e11884afcac8 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Mon, 9 Jul 2018 11:03:09 -0600 Subject: [PATCH 1/2] MSELoss prec test --- test/test_nn.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/test_nn.py b/test/test_nn.py index d3c673f37391c..8e794265db463 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -5933,6 +5933,15 @@ def forward(self, *args): check_sum_reduction=True, desc='scalar' ), + dict( + module_name='MSELoss', + input_fn=lambda: torch.ones(5, 68, 64, 64, dtype=torch.float) / 10, + target_fn=lambda: torch.zeros(5, 68, 64, 64, dtype=torch.float), + reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / + (i.numel() if get_reduction(m) == 'elementwise_mean' else 1)), + check_forward_only=True, + desc='prec', + ), dict( module_name='BCELoss', constructor_args_fn=lambda: (torch.rand(()),), From 2d13c927cc22c015f392c040b3aa9578ef6f9382 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Mon, 9 Jul 2018 15:09:53 -0600 Subject: [PATCH 2/2] MSELoss precision checks --- aten/src/THNN/generic/MSECriterion.c | 6 +++--- test/common_nn.py | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/aten/src/THNN/generic/MSECriterion.c b/aten/src/THNN/generic/MSECriterion.c index e236c8ea61c6d..b7c6e07d0d039 100644 --- a/aten/src/THNN/generic/MSECriterion.c +++ b/aten/src/THNN/generic/MSECriterion.c @@ -14,17 +14,17 @@ void THNN_(MSECriterion_updateOutput)( if (reduction != Reduction::None) { THTensor_(resize1d)(output, 1); - real sum = 0; + accreal sum = 0; TH_TENSOR_APPLY2(real, input, real, target, - real z = (*input_data - *target_data); + accreal z = (*input_data - *target_data); sum += z*z; ); if (reduction == Reduction::ElementwiseMean) sum /= THTensor_(nElement)(input); - THTensor_(set1d)(output, 0, sum); + THTensor_(set1d)(output, 0, (real)sum); return; } diff --git a/test/common_nn.py b/test/common_nn.py index ba161b39f0b2e..6172f4b15adc3 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -1087,6 +1087,7 @@ class CriterionTest(TestBase): def __init__(self, *args, **kwargs): super(CriterionTest, self).__init__(*args, **kwargs) self.should_test_cuda = kwargs.get('test_cuda', True) + self.check_forward_only = kwargs.get('check_forward_only', True) def _get_target(self): return self._get_arg('target', True) @@ -1109,6 +1110,9 @@ def __call__(self, test_case): expected_out = expected_out.item() test_case.assertEqual(out, expected_out) + if self.check_forward_only: + return + test_case.check_criterion_jacobian(module, input, target) self._do_extra_tests(test_case, module, input, target)