
fix comments: fix reduction message, delete duplicate test
hwangdeyu committed Jan 6, 2021
1 parent 0992510 commit cdc08ce
Showing 3 changed files with 11 additions and 35 deletions.
39 changes: 10 additions & 29 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -5092,56 +5092,40 @@ def forward(self, input, target):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_binary_cross_entropy_with_logits(self):
         x = torch.randn(5)
-        y = torch.randn(5)
+        y = torch.empty(5).random_(2)
         self._bce_logits_loss(x, y)
 
         x = torch.randn(2, 3, 5, 7)
-        y = torch.randn(2, 3, 5, 7)
-        self._bce_logits_loss(x, y)
+        y = torch.empty(2, 3, 5, 7).random_(2)
+        weight = torch.tensor([2])
+        self._bce_logits_loss(x, y, weight)
 
         x = torch.FloatTensor([[-0.4089, -1.2471, 0.5907], [-0.4897, -0.8267, -0.7349], [0.5241, -0.1246, -0.4751]])
         y = torch.FloatTensor([[0, 1, 1], [0, 0, 1], [1, 0, 1]])
-        self._bce_logits_loss(x, y)
+        pos_weight = torch.empty([3]).random_(2)
+        self._bce_logits_loss(x, y, pos_weight)
 
-        x = torch.randn(3, 3, requires_grad=True)
-        y = torch.empty(3, 3).random_(2)
-        weight = torch.tensor([3], dtype=torch.float)
-        pos_weight = torch.ones([1])
+        x = torch.randn(3, 3, 4)
+        y = torch.empty(3, 3, 4).random_(2)
+        weight = torch.tensor([3])
+        pos_weight = torch.empty([3, 4]).random_(2)
         self._bce_logits_loss(x, y, weight, pos_weight)
 
     def _bce_logits_loss(self, x, y, weight=None, pos_weight=None):
-        class BCEWithLogitsLossNone(torch.nn.Module):
-            def forward(self, input, target):
-                return torch.nn.functional.binary_cross_entropy_with_logits(input, target, reduction='none')
-
-        self.run_test(BCEWithLogitsLossNone(), input=(x, y))
-
         class BCEWithLogitsLossNoneWeights(torch.nn.Module):
             def forward(self, input, target, weight, pos_weight):
                 return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight,
                                                                             pos_weight=pos_weight, reduction='none')
 
         self.run_test(BCEWithLogitsLossNoneWeights(), input=(x, y, weight, pos_weight))
 
-        class BCEWithLogitsLossMean(torch.nn.Module):
-            def forward(self, input, target):
-                return torch.nn.functional.binary_cross_entropy_with_logits(input, target, reduction='mean')
-
-        self.run_test(BCEWithLogitsLossMean(), input=(x, y))
-
         class BCEWithLogitsLossMeanWeights(torch.nn.Module):
             def forward(self, input, target, weight, pos_weight):
                 return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight,
                                                                             pos_weight=pos_weight, reduction='mean')
 
         self.run_test(BCEWithLogitsLossMeanWeights(), input=(x, y, weight, pos_weight))
 
-        class BCEWithLogitsLossSum(torch.nn.Module):
-            def forward(self, input, target):
-                return torch.nn.functional.binary_cross_entropy_with_logits(input, target, reduction='sum')
-
-        self.run_test(BCEWithLogitsLossSum(), input=(x, y))
-
         class BCEWithLogitsLossSumWeights(torch.nn.Module):
             def forward(self, input, target, weight, pos_weight):
                 return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=weight,
@@ -5371,10 +5355,7 @@ def forward(self, x):
 
     @skipIfONNXShapeInference(False)
     @skipIfUnsupportedMinOpsetVersion(13)
-<<<<<<< HEAD
-=======
     @skipIfUnsupportedOpsetVersion([13])
->>>>>>> 616da7c185f710f612b63701d8adb19e544913a9
     def test_if_list(self):
         class IfModel(torch.nn.Module):
             def forward(self, x, y, cond):
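
A note on the test change above: binary cross-entropy interprets the target as a probability in [0, 1], so the old torch.randn targets fell outside the loss's domain, while torch.empty(...).random_(2) draws proper 0/1 labels. A minimal sketch, not part of this commit, of the call pattern the updated test exercises (shapes and values chosen to mirror the new test):

import torch
import torch.nn.functional as F

x = torch.randn(3, 3, 4)                     # logits, arbitrary reals
y = torch.empty(3, 3, 4).random_(2)          # valid 0/1 targets
weight = torch.tensor([3.0])                 # per-element rescaling, broadcast over the loss
pos_weight = torch.empty([3, 4]).random_(2)  # extra weight on positive examples, broadcast

for reduction in ("none", "mean", "sum"):
    loss = F.binary_cross_entropy_with_logits(
        x, y, weight=weight, pos_weight=pos_weight, reduction=reduction)
    print(reduction, tuple(loss.shape))      # (3, 3, 4) for 'none', () otherwise
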
5 changes: 0 additions & 5 deletions torch/csrc/autograd/VariableTypeManual.cpp
@@ -301,11 +301,6 @@ Tensor& resize_as_(
     at::resize_as_(self_, the_template_, optional_memory_format);
   }
 
-  // Handle fw grad
-  if (self.fw_grad(/* level */ 0).defined()) {
-    AT_ERROR("cannot resize variables that has a forward grad");
-  }
-
   // Handle fw grad
   if (self.fw_grad(/* level */ 0).defined()) {
     AT_ERROR("cannot resize variables that has a forward grad");
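
The C++ change above only removes an accidentally duplicated forward-grad guard; the surviving copy still forbids resizing a tensor that carries a forward-mode gradient. As a rough illustration, not part of this commit and using the Python forward-mode AD helpers available in later PyTorch releases, this is the kind of code the guard rejects:

import torch
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    primal = torch.randn(3)
    dual = fwAD.make_dual(primal, torch.ones(3))  # dual now carries a forward grad
    # dual.resize_(5)  # uncommenting should trip the retained guard:
    #                  # "cannot resize variables that has a forward grad"
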
2 changes: 1 addition & 1 deletion torch/onnx/symbolic_opset12.py
@@ -77,7 +77,7 @@ def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduc
     elif reduction == 2:
         return g.op("ReduceSum", output)
     else:
-        return sym_help._onnx_unsupported("binary_cross_entropy_with_logits")
+        return sym_help._onnx_unsupported("binary_cross_entropy_with_logits with reduction other than none, mean, or sum")
 
 
 def celu(g, self, alpha):
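
The symbolic_opset12.py change only rewords the error raised for an unrecognized reduction value; the supported branches are unchanged. As the diff shows, reduction == 2 lowers to ONNX ReduceSum, and by the usual PyTorch convention (an assumption here, since the diff shows only this branch) 'none' passes the per-element output through and 'mean' lowers to ReduceMean. A small sketch, independent of the commit, checking that 'mean' and 'sum' really are reductions of the 'none' output, which is exactly what this lowering relies on:

import torch
import torch.nn.functional as F

x = torch.randn(4)               # logits
y = torch.empty(4).random_(2)    # 0/1 targets

per_elem = F.binary_cross_entropy_with_logits(x, y, reduction="none")

# 'mean' and 'sum' equal ReduceMean / ReduceSum applied to the 'none' output,
# matching the g.op("ReduceMean", ...) / g.op("ReduceSum", ...) emission.
assert torch.allclose(F.binary_cross_entropy_with_logits(x, y, reduction="mean"), per_elem.mean())
assert torch.allclose(F.binary_cross_entropy_with_logits(x, y, reduction="sum"), per_elem.sum())
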
