From b308fb78d1be446bb20bf9b73ee3f6d21f080c31 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Wed, 27 Jan 2021 17:41:50 -0800 Subject: [PATCH] [ONNX] Add binary_cross_entropy_with_logits op to ONNX opset version 12 (#49675) (#50908) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50908 Fixes #47997 Exporting the operator binary_cross_entropy_with_logits to ONNX opset version 12. Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D26050885 Pulled By: SplitInfinity fbshipit-source-id: e4167895eed804739aa50481679500a4d564b360 --- test/onnx/test_pytorch_onnx_onnxruntime.py | 46 ++++++++++++++++++++++ torch/onnx/symbolic_opset12.py | 28 +++++++++++++ 2 files changed, 74 insertions(+) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 1c9f97488f27..42e3e91b5ce4 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -5159,6 +5159,52 @@ def forward(self, input, target): target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) self.run_test(NLLModel(), (input, target)) + + @skipIfUnsupportedMinOpsetVersion(12) + def test_binary_cross_entropy_with_logits(self): + x = torch.randn(5) + y = torch.empty(5).random_(2) + self._bce_logits_loss(x, y) + + x = torch.randn(2, 3, 5, 7) + y = torch.empty(2, 3, 5, 7).random_(2) + weight = torch.tensor([2]) + self._bce_logits_loss(x, y, weight) + + x = torch.FloatTensor([[-0.4089, -1.2471, 0.5907], [-0.4897, -0.8267, -0.7349], [0.5241, -0.1246, -0.4751]]) + y = torch.FloatTensor([[0, 1, 1], [0, 0, 1], [1, 0, 1]]) + pos_weight = torch.empty([3]).random_(2) + self._bce_logits_loss(x, y, pos_weight) + + x = torch.randn(3, 3, 4) + y = torch.empty(3, 3, 4).random_(2) + weight = torch.tensor([3]) + pos_weight = torch.empty([3, 4]).random_(2) + self._bce_logits_loss(x, y, weight, pos_weight) + + def _bce_logits_loss(self, x, y, weight=None, pos_weight=None): + class 
BCEWithLogitsLossNoneWeights(torch.nn.Module): + def forward(self, input, target, weight, pos_weight): + return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, + pos_weight=pos_weight, reduction='none') + + self.run_test(BCEWithLogitsLossNoneWeights(), input=(x, y, weight, pos_weight)) + + class BCEWithLogitsLossMeanWeights(torch.nn.Module): + def forward(self, input, target, weight, pos_weight): + return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, + pos_weight=pos_weight, reduction='mean') + + self.run_test(BCEWithLogitsLossMeanWeights(), input=(x, y, weight, pos_weight)) + + class BCEWithLogitsLossSumWeights(torch.nn.Module): + def forward(self, input, target, weight, pos_weight): + return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight, + pos_weight=pos_weight, reduction='sum') + + self.run_test(BCEWithLogitsLossSumWeights(), input=(x, y, weight, pos_weight)) + + def test_torch_mm(self): class M(torch.nn.Module): def forward(self, mat1, mat2): diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 63a40b555c8e..5a926eef5e1d 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -52,6 +52,34 @@ def nll_loss2d(g, self, target, weight, reduction, ignore_index): return nll_loss(g, self, target, weight, reduction, ignore_index) +@parse_args('v', 'v', 'v', 'v', 'i') +def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduction): + from torch.onnx.symbolic_opset9 import sigmoid, log, sub, neg, mul, add + p = g.op("Constant", value_t=torch.tensor([1])) + sig_x = sigmoid(g, input) + log_sig_x = log(g, sig_x) + sub_1_x = sub(g, p, sig_x) + sub_1_y = sub(g, p, target) + log_1_x = log(g, sub_1_x) + if pos_weight is None or sym_help._is_none(pos_weight): + output = neg(g, add(g, mul(g, target, log_sig_x), mul(g, sub_1_y, log_1_x))) + else: + output = neg(g, add(g, mul(g, mul(g, target, 
log_sig_x), pos_weight), mul(g, sub_1_y, log_1_x))) + + if weight is not None and not sym_help._is_none(weight): + output = mul(g, weight, output) + + reduction = sym_help._maybe_get_const(reduction, 'i') + if reduction == 0: + return output + elif reduction == 1: + return g.op("ReduceMean", output) + elif reduction == 2: + return g.op("ReduceSum", output) + else: + return sym_help._onnx_unsupported("binary_cross_entropy_with_logits with reduction other than none, mean, or sum") + + def celu(g, self, alpha): alpha = sym_help._maybe_get_const(alpha, 'f') # if the input is of type double cast it to float