From 4fb0dfea21dffa3a48e426f6f3380bb973971bdf Mon Sep 17 00:00:00 2001 From: Roy Li Date: Mon, 4 Jun 2018 19:57:44 -0700 Subject: [PATCH 1/5] Stop BCELoss from returning negative results --- aten/src/THCUNN/BCECriterion.cu | 6 +++--- aten/src/THNN/generic/BCECriterion.c | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/aten/src/THCUNN/BCECriterion.cu b/aten/src/THCUNN/BCECriterion.cu index dfef9082de949..0bfd7eab4d73a 100644 --- a/aten/src/THCUNN/BCECriterion.cu +++ b/aten/src/THCUNN/BCECriterion.cu @@ -31,7 +31,7 @@ struct bce_functor Dtype input = thrust::get<0>(x); Dtype t = thrust::get<1>(x); assert(input >= 0. && input <= 1.); - return - (t * THCNumerics::log(input + eps()) + (Acctype(1)- t) * THCNumerics::log(Acctype(1) - input + eps())); + return - (t * THCNumerics::log(input + eps()) + (Acctype(1)- t) * THCNumerics::log(Acctype(1) - input)); } }; @@ -47,7 +47,7 @@ struct bce_updateOutput_no_reduce_functor assert(*input >= 0. && *input <= 1.); *output = ScalarConvert::to( -(*target * THCNumerics::log(*input + eps()) + - (Acctype(1) - *target) * THCNumerics::log(Acctype(1) - *input + eps()))); + (Acctype(1) - *target) * THCNumerics::log(Acctype(1) - *input))); } }; @@ -63,7 +63,7 @@ struct bce_functor_weights Dtype w = thrust::get<2>(x); assert(input >= 0. && input <= 1.); return - w * (t * THCNumerics::log(input + eps()) + - (Acctype(1) - t) * THCNumerics::log(Acctype(1) - input + eps())); + (Acctype(1) - t) * THCNumerics::log(Acctype(1) - input)); } }; diff --git a/aten/src/THNN/generic/BCECriterion.c b/aten/src/THNN/generic/BCECriterion.c index 29bdabe7c633c..5faaa83a350c5 100644 --- a/aten/src/THNN/generic/BCECriterion.c +++ b/aten/src/THNN/generic/BCECriterion.c @@ -24,7 +24,7 @@ void THNN_(BCECriterion_updateOutput)( THAssertMsg(x >= 0. && x <= 1., "input value should be between 0~1, but got %f", (double) x); - *output_data = -(log(x + EPS) * y + log(1. - x + EPS) * (1. - y)); + *output_data = -(log(x + EPS) * y + log(1. - x) * (1. - y)); ); if (weights) { THTensor_(cmul)(output, output, weights); @@ -43,7 +43,7 @@ void THNN_(BCECriterion_updateOutput)( THAssertMsg(x >= 0. && x <= 1., "input value should be between 0~1, but got %f", (double) x); - sum -= (log(x + EPS) * y + log(1. - x + EPS) * (1. - y)) * w; + sum -= (log(x + EPS) * y + log(1. - x) * (1. - y)) * w; ); } else { TH_TENSOR_APPLY2(real, input, real, target, @@ -52,7 +52,7 @@ void THNN_(BCECriterion_updateOutput)( THAssertMsg(x >= 0. && x <= 1., "input value should be between 0~1, but got %f", (double) x); - sum -= log(x + EPS) * y + log(1. - x + EPS) * (1. - y); + sum -= log(x + EPS) * y + log(1. - x) * (1. - y); ); } From 22df92422b7002dc82df9e5f5b19038b579a8269 Mon Sep 17 00:00:00 2001 From: Roy Li Date: Tue, 5 Jun 2018 15:42:50 -0700 Subject: [PATCH 2/5] check explicitly for 0 before taking log --- aten/src/THCUNN/BCECriterion.cu | 20 +++++++++++++++----- aten/src/THNN/generic/BCECriterion.c | 13 ++++++++++--- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/aten/src/THCUNN/BCECriterion.cu b/aten/src/THCUNN/BCECriterion.cu index 0bfd7eab4d73a..b40d264b5bb65 100644 --- a/aten/src/THCUNN/BCECriterion.cu +++ b/aten/src/THCUNN/BCECriterion.cu @@ -21,6 +21,15 @@ inline __host__ __device__ float eps() { return 1e-12f; } template <> inline __host__ __device__ double eps() { return 1e-12; } +template +inline __host__ __device__ T safe_log(T a) { + if (a == 0.) + { + return THCNumerics::log(eps()); + } + return THCNumerics::log(a); +} + template struct bce_functor { @@ -31,7 +40,8 @@ struct bce_functor Dtype input = thrust::get<0>(x); Dtype t = thrust::get<1>(x); assert(input >= 0. && input <= 1.); - return - (t * THCNumerics::log(input + eps()) + (Acctype(1)- t) * THCNumerics::log(Acctype(1) - input)); + return - (t * safe_log(ScalarConvert::to(input)) + + (Acctype(1)- t) * safe_log(Acctype(1) - input)); } }; @@ -46,8 +56,8 @@ struct bce_updateOutput_no_reduce_functor { assert(*input >= 0. && *input <= 1.); *output = ScalarConvert::to( - -(*target * THCNumerics::log(*input + eps()) + - (Acctype(1) - *target) * THCNumerics::log(Acctype(1) - *input))); + -(*target * safe_log(ScalarConvert::to(*input)) + + (Acctype(1) - *target) * safe_log(Acctype(1) - *input))); } }; @@ -62,8 +72,8 @@ struct bce_functor_weights Dtype t = thrust::get<1>(x); Dtype w = thrust::get<2>(x); assert(input >= 0. && input <= 1.); - return - w * (t * THCNumerics::log(input + eps()) + - (Acctype(1) - t) * THCNumerics::log(Acctype(1) - input)); + return - w * (t * safe_log(ScalarConvert::to(input)) + + (Acctype(1) - t) * safe_log(Acctype(1) - input)); } }; diff --git a/aten/src/THNN/generic/BCECriterion.c b/aten/src/THNN/generic/BCECriterion.c index 5faaa83a350c5..13b415ed62187 100644 --- a/aten/src/THNN/generic/BCECriterion.c +++ b/aten/src/THNN/generic/BCECriterion.c @@ -4,6 +4,13 @@ #define EPS 1e-12 +static inline real safe_log(real a) { + if (a == 0.) { + return log(EPS); + } + return log(a); +} + void THNN_(BCECriterion_updateOutput)( THNNState *state, THTensor *input, @@ -24,7 +31,7 @@ void THNN_(BCECriterion_updateOutput)( THAssertMsg(x >= 0. && x <= 1., "input value should be between 0~1, but got %f", (double) x); - *output_data = -(log(x + EPS) * y + log(1. - x) * (1. - y)); + *output_data = -(safe_log(x) * y + safe_log(1. - x) * (1. - y)); ); if (weights) { THTensor_(cmul)(output, output, weights); @@ -43,7 +50,7 @@ void THNN_(BCECriterion_updateOutput)( THAssertMsg(x >= 0. && x <= 1., "input value should be between 0~1, but got %f", (double) x); - sum -= (log(x + EPS) * y + log(1. - x) * (1. - y)) * w; + sum -= (safe_log(x) * y + safe_log(1. - x) * (1. - y)) * w; ); } else { TH_TENSOR_APPLY2(real, input, real, target, @@ -52,7 +59,7 @@ void THNN_(BCECriterion_updateOutput)( THAssertMsg(x >= 0. && x <= 1., "input value should be between 0~1, but got %f", (double) x); - sum -= log(x + EPS) * y + log(1. - x) * (1. - y); + sum -= safe_log(x) * y + safe_log(1. - x) * (1. - y); ); } From 229472b1081d9227debbe4cf76814b6bd60108f6 Mon Sep 17 00:00:00 2001 From: Roy Li Date: Wed, 6 Jun 2018 19:00:58 -0700 Subject: [PATCH 3/5] add tests --- test/test_nn.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/test_nn.py b/test/test_nn.py index 13b7013200b80..28a127b3a54fe 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4116,6 +4116,16 @@ def func(root): gradcheck(func, [v]) gradgradcheck(func, [v]) + def test_bce_loss_always_nonnegative(self): + target = torch.ones(5) + input = torch.ones(5) + assert ((nn.BCELoss()(input,target) < 0).sum() == 0) + + target = torch.zeros(5) + input = torch.zeros(5) + print((nn.BCELoss()(input,target) < 0).sum) + assert ((nn.BCELoss()(input,target) < 0).sum() == 0) + def test_bce_with_logits_raises_if_target_and_input_are_different_size(self): target = torch.rand(5) input = torch.rand(5, 1) From 654ac994535c90fab49cf6ce3abb47ffa90a7840 Mon Sep 17 00:00:00 2001 From: Roy Li Date: Wed, 6 Jun 2018 19:19:00 -0700 Subject: [PATCH 4/5] fix lint --- test/test_nn.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 28a127b3a54fe..19c60b6876c4a 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4119,12 +4119,11 @@ def func(root): def test_bce_loss_always_nonnegative(self): target = torch.ones(5) input = torch.ones(5) - assert ((nn.BCELoss()(input,target) < 0).sum() == 0) + assert ((nn.BCELoss()(input, target) < 0).sum() == 0) target = torch.zeros(5) input = torch.zeros(5) - print((nn.BCELoss()(input,target) < 0).sum) - assert ((nn.BCELoss()(input,target) < 0).sum() == 0) + assert ((nn.BCELoss()(input, target) < 0).sum() == 0) def test_bce_with_logits_raises_if_target_and_input_are_different_size(self): target = torch.rand(5) From b5c1fd84c15a75d2f16f9cd3418ca73d61045ac3 Mon Sep 17 00:00:00 2001 From: Roy Li Date: Thu, 7 Jun 2018 13:36:19 -0700 Subject: [PATCH 5/5] address comments --- aten/src/THCUNN/BCECriterion.cu | 2 +- test/test_nn.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/THCUNN/BCECriterion.cu b/aten/src/THCUNN/BCECriterion.cu index b40d264b5bb65..3624588015c8a 100644 --- a/aten/src/THCUNN/BCECriterion.cu +++ b/aten/src/THCUNN/BCECriterion.cu @@ -41,7 +41,7 @@ struct bce_functor Dtype t = thrust::get<1>(x); assert(input >= 0. && input <= 1.); return - (t * safe_log(ScalarConvert::to(input)) - + (Acctype(1)- t) * safe_log(Acctype(1) - input)); + + (Acctype(1) - t) * safe_log(Acctype(1) - input)); } }; diff --git a/test/test_nn.py b/test/test_nn.py index 19c60b6876c4a..7cbecbcc27d4b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4119,11 +4119,11 @@ def func(root): def test_bce_loss_always_nonnegative(self): target = torch.ones(5) input = torch.ones(5) - assert ((nn.BCELoss()(input, target) < 0).sum() == 0) + self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0) target = torch.zeros(5) input = torch.zeros(5) - assert ((nn.BCELoss()(input, target) < 0).sum() == 0) + self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0) def test_bce_with_logits_raises_if_target_and_input_are_different_size(self): target = torch.rand(5)