[ONNX] Add binary_cross_entropy_with_logits op to ONNX opset version 12 (#49675) (#50908)

Summary:
Pull Request resolved: #50908

Fixes #47997
Exports the operator binary_cross_entropy_with_logits to ONNX opset version 12.
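
For context, a minimal sketch of exporting a module that uses this op (the module, tensors, and file name are illustrative, not part of this change):

import torch

class BCEModel(torch.nn.Module):
    def forward(self, input, target):
        return torch.nn.functional.binary_cross_entropy_with_logits(
            input, target, reduction='mean')

x = torch.randn(4, 3)
y = torch.empty(4, 3).random_(2)
# The op is only supported from opset 12 onwards.
torch.onnx.export(BCEModel(), (x, y), "bce_logits.onnx", opset_version=12)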

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D26050885

Pulled By: SplitInfinity

fbshipit-source-id: e4167895eed804739aa50481679500a4d564b360
BowenBao authored and facebook-github-bot committed Jan 28, 2021
1 parent 1723ab5 commit b308fb7
Showing 2 changed files with 74 additions and 0 deletions.
46 changes: 46 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -5159,6 +5159,52 @@ def forward(self, input, target):
        target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
        self.run_test(NLLModel(), (input, target))


    @skipIfUnsupportedMinOpsetVersion(12)
    def test_binary_cross_entropy_with_logits(self):
        # 1-D input, no weights.
        x = torch.randn(5)
        y = torch.empty(5).random_(2)
        self._bce_logits_loss(x, y)

        # 4-D input with a scalar rescaling weight.
        x = torch.randn(2, 3, 5, 7)
        y = torch.empty(2, 3, 5, 7).random_(2)
        weight = torch.tensor([2])
        self._bce_logits_loss(x, y, weight)

        # 2-D input with a per-class weight for positive examples.
        x = torch.FloatTensor([[-0.4089, -1.2471, 0.5907], [-0.4897, -0.8267, -0.7349], [0.5241, -0.1246, -0.4751]])
        y = torch.FloatTensor([[0, 1, 1], [0, 0, 1], [1, 0, 1]])
        pos_weight = torch.empty([3]).random_(2)
        self._bce_logits_loss(x, y, pos_weight=pos_weight)

        # 3-D input with both weight and pos_weight.
        x = torch.randn(3, 3, 4)
        y = torch.empty(3, 3, 4).random_(2)
        weight = torch.tensor([3])
        pos_weight = torch.empty([3, 4]).random_(2)
        self._bce_logits_loss(x, y, weight, pos_weight)

    def _bce_logits_loss(self, x, y, weight=None, pos_weight=None):
        class BCEWithLogitsLossNoneWeights(torch.nn.Module):
            def forward(self, input, target, weight, pos_weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight,
                                                                            pos_weight=pos_weight, reduction='none')

        self.run_test(BCEWithLogitsLossNoneWeights(), input=(x, y, weight, pos_weight))

        class BCEWithLogitsLossMeanWeights(torch.nn.Module):
            def forward(self, input, target, weight, pos_weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight,
                                                                            pos_weight=pos_weight, reduction='mean')

        self.run_test(BCEWithLogitsLossMeanWeights(), input=(x, y, weight, pos_weight))

        class BCEWithLogitsLossSumWeights(torch.nn.Module):
            def forward(self, input, target, weight, pos_weight):
                return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight,
                                                                            pos_weight=pos_weight, reduction='sum')

        self.run_test(BCEWithLogitsLossSumWeights(), input=(x, y, weight, pos_weight))


    def test_torch_mm(self):
        class M(torch.nn.Module):
            def forward(self, mat1, mat2):
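As background, run_test (a helper defined earlier in this test file) exports the module to ONNX and compares eager PyTorch outputs against ONNX Runtime. A standalone sketch of that pattern, assuming a single tensor output; the helper name and tolerances here are illustrative, not the suite's actual implementation:

import io

import numpy as np
import onnxruntime
import torch

def check_onnx_export(model, args, opset_version=12, rtol=1e-3, atol=1e-7):
    # Export to an in-memory buffer, then run the same inputs through ONNX Runtime.
    f = io.BytesIO()
    torch.onnx.export(model, args, f, opset_version=opset_version)
    session = onnxruntime.InferenceSession(f.getvalue())
    feed = {inp.name: t.numpy() for inp, t in zip(session.get_inputs(), args)}
    (ort_out,) = session.run(None, feed)
    np.testing.assert_allclose(model(*args).numpy(), ort_out, rtol=rtol, atol=atol)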
28 changes: 28 additions & 0 deletions torch/onnx/symbolic_opset12.py
@@ -52,6 +52,34 @@ def nll_loss2d(g, self, target, weight, reduction, ignore_index):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@parse_args('v', 'v', 'v', 'v', 'i')
def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduction):
    from torch.onnx.symbolic_opset9 import sigmoid, log, sub, neg, mul, add
    # Build loss = -(pos_weight * target * log(sigmoid(input))
    #                + (1 - target) * log(1 - sigmoid(input))) out of ONNX primitives.
    p = g.op("Constant", value_t=torch.tensor([1]))
    sig_x = sigmoid(g, input)    # sigmoid(input)
    log_sig_x = log(g, sig_x)    # log(sigmoid(input))
    sub_1_x = sub(g, p, sig_x)   # 1 - sigmoid(input)
    sub_1_y = sub(g, p, target)  # 1 - target
    log_1_x = log(g, sub_1_x)    # log(1 - sigmoid(input))
    if pos_weight is None or sym_help._is_none(pos_weight):
        output = neg(g, add(g, mul(g, target, log_sig_x), mul(g, sub_1_y, log_1_x)))
    else:
        output = neg(g, add(g, mul(g, mul(g, target, log_sig_x), pos_weight), mul(g, sub_1_y, log_1_x)))

    # Optional elementwise rescaling weight.
    if weight is not None and not sym_help._is_none(weight):
        output = mul(g, weight, output)

    # reduction: 0 = 'none', 1 = 'mean', 2 = 'sum'.
    reduction = sym_help._maybe_get_const(reduction, 'i')
    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output)
    elif reduction == 2:
        return g.op("ReduceSum", output)
    else:
        return sym_help._onnx_unsupported("binary_cross_entropy_with_logits with reduction other than none, mean, or sum")


def celu(g, self, alpha):
    alpha = sym_help._maybe_get_const(alpha, 'f')
    # if the input is of type double cast it to float
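For intuition, a quick eager-mode check (illustrative, not part of the commit) that the decomposition built by the symbolic above matches the reference implementation:

import torch

x = torch.randn(4, 3)
y = torch.empty(4, 3).random_(2)
weight = torch.rand(3)
pos_weight = torch.rand(3)

# The same decomposition the symbolic builds, written with eager ops.
sig_x = torch.sigmoid(x)
manual = -(pos_weight * y * torch.log(sig_x) + (1 - y) * torch.log(1 - sig_x))
manual = weight * manual

expected = torch.nn.functional.binary_cross_entropy_with_logits(
    x, y, weight=weight, pos_weight=pos_weight, reduction='none')
torch.testing.assert_allclose(manual, expected)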
