[ONNX] Reimplement _var_mean to ensure non-negative (#47240)
Summary:
The current `_var_mean` implementation cannot guarantee a non-negative variance because it computes `E(X^2) - (E(X))^2`. The two forms agree in exact arithmetic, but numerically, when the number of elements is large and X is close to 0, the subtraction of two nearly equal terms can produce small negative values (as our unit test shows). The new implementation computes `E((X - E(X))^2)`, which is non-negative by construction, since the expectation of a square is always non-negative.

The unit test passes with the new implementation but fails with the existing one, so the change is good to go.
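As a quick illustration of the failure mode (a sketch only, not part of the commit: the shape mirrors the new unit test, and the exact size of the rounding error is platform-dependent):

    import torch

    x = torch.randn(1, 2, 3, 300, 300)
    # Old formulation: E(X^2) - (E(X))^2 subtracts two nearly equal reductions,
    # so float32 rounding can push individual entries just below zero ...
    var_old = (x * x).mean(dim=1) - x.mean(dim=1) ** 2
    # ... after which sqrt() turns them into NaN.
    nan_count = torch.isnan(torch.sqrt(var_old)).sum()
    # New formulation: E((X - E(X))^2) averages squares, so it is always >= 0.
    var_new = ((x - x.mean(dim=1, keepdim=True)) ** 2).mean(dim=1)
    assert (var_new >= 0).all()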

Pull Request resolved: #47240

Reviewed By: ejguan

Differential Revision: D24735729

Pulled By: bzinodev

fbshipit-source-id: 136f448dd16622b2b46f40cdf6cb2fccf357c48d
jiafatom authored and facebook-github-bot committed Nov 7, 2020
1 parent f23a2a1 commit 6e69a24
Showing 3 changed files with 70 additions and 69 deletions.
121 changes: 57 additions & 64 deletions test/onnx/expect/TestOperators.test_std.expect
@@ -3,33 +3,9 @@ producer_name: "pytorch"
 producer_version: "CURRENT_VERSION"
 graph {
   node {
     input: "0"
-    input: "0"
     output: "1"
-    name: "Mul_0"
-    op_type: "Mul"
-  }
-  node {
-    input: "1"
-    output: "2"
-    name: "ReduceMean_1"
-    op_type: "ReduceMean"
-    attribute {
-      name: "axes"
-      ints: 0
-      ints: 1
-      type: INTS
-    }
-    attribute {
-      name: "keepdims"
-      i: 1
-      type: INT
-    }
-  }
-  node {
-    input: "0"
-    output: "3"
-    name: "ReduceMean_2"
+    name: "ReduceMean_0"
     op_type: "ReduceMean"
     attribute {
       name: "axes"
@@ -45,13 +21,13 @@
   }
   node {
     input: "0"
-    output: "4"
-    name: "Shape_3"
+    output: "2"
+    name: "Shape_1"
     op_type: "Shape"
   }
   node {
-    output: "5"
-    name: "Constant_4"
+    output: "3"
+    name: "Constant_2"
     op_type: "Constant"
     attribute {
       name: "value"
@@ -64,10 +40,10 @@
     }
   }
   node {
-    input: "4"
-    input: "5"
-    output: "6"
-    name: "Gather_5"
+    input: "2"
+    input: "3"
+    output: "4"
+    name: "Gather_3"
     op_type: "Gather"
     attribute {
       name: "axis"
@@ -76,9 +52,9 @@
     }
   }
   node {
-    input: "6"
-    output: "7"
-    name: "ReduceProd_6"
+    input: "4"
+    output: "5"
+    name: "ReduceProd_4"
     op_type: "ReduceProd"
     attribute {
       name: "keepdims"
@@ -87,23 +63,40 @@
     }
   }
   node {
-    input: "3"
-    input: "3"
-    output: "8"
-    name: "Mul_7"
-    op_type: "Mul"
+    input: "0"
+    input: "1"
+    output: "6"
+    name: "Sub_5"
+    op_type: "Sub"
   }
   node {
-    input: "2"
-    input: "8"
-    output: "9"
-    name: "Sub_8"
-    op_type: "Sub"
+    input: "6"
+    input: "6"
+    output: "7"
+    name: "Mul_6"
+    op_type: "Mul"
   }
   node {
     input: "7"
-    output: "10"
-    name: "Cast_9"
+    output: "8"
+    name: "ReduceMean_7"
+    op_type: "ReduceMean"
+    attribute {
+      name: "axes"
+      ints: 0
+      ints: 1
+      type: INTS
+    }
+    attribute {
+      name: "keepdims"
+      i: 1
+      type: INT
+    }
+  }
+  node {
+    input: "5"
+    output: "9"
+    name: "Cast_8"
     op_type: "Cast"
     attribute {
       name: "to"
@@ -112,15 +105,15 @@
     }
   }
   node {
+    input: "8"
     input: "9"
-    input: "10"
-    output: "11"
-    name: "Mul_10"
+    output: "10"
+    name: "Mul_9"
     op_type: "Mul"
   }
   node {
-    output: "12"
-    name: "Constant_11"
+    output: "11"
+    name: "Constant_10"
     op_type: "Constant"
     attribute {
       name: "value"
@@ -131,24 +124,24 @@
       type: TENSOR
     }
   }
   node {
+    input: "9"
+    input: "11"
+    output: "12"
+    name: "Sub_11"
+    op_type: "Sub"
+  }
+  node {
     input: "10"
     input: "12"
     output: "13"
-    name: "Sub_12"
-    op_type: "Sub"
+    name: "Div_12"
+    op_type: "Div"
   }
   node {
-    input: "11"
     input: "13"
     output: "14"
-    name: "Div_13"
-    op_type: "Div"
-  }
-  node {
-    input: "14"
-    output: "15"
-    name: "Sqrt_14"
+    name: "Sqrt_13"
     op_type: "Sqrt"
   }
   name: "torch-jit-export"
@@ -172,7 +165,7 @@
     }
   }
   output {
-    name: "15"
+    name: "14"
     type {
       tensor_type {
         elem_type: 1
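For readers tracing the new expect graph above, the node chain computes the following (a hedged PyTorch transcription, not part of the commit; the comments name the node producing each intermediate value):

    import torch

    x = torch.randn(2, 3, 4)
    mean = x.mean(dim=(0, 1), keepdim=True)            # ReduceMean_0 -> "1"
    n = x.shape[0] * x.shape[1]                        # Shape_1, Gather_3, ReduceProd_4 -> "5"
    sub = x - mean                                     # Sub_5 -> "6"
    var = (sub * sub).mean(dim=(0, 1), keepdim=True)   # Mul_6, ReduceMean_7 -> "8"
    unbiased = var * n / (n - 1)                       # Cast_8, Mul_9, Sub_11, Div_12 -> "13"
    std = torch.sqrt(unbiased)                         # Sqrt_13 -> "14"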
9 changes: 9 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1866,6 +1866,15 @@ def forward(self, input):
         model = VarianceUnbiased()
         self.run_test(model, x)
 
+        class VarianceSqrt(torch.nn.Module):
+            def forward(self, input):
+                y = torch.var(input, 1)
+                return torch.sqrt(y + 1e-8)
+
+        x = torch.randn(1, 2, 3, 300, 300)
+        model = VarianceSqrt()
+        self.run_test(model, x)
+
     def test_var_along_dims(self):
         class Variance(torch.nn.Module):
             def forward(self, input):
9 changes: 4 additions & 5 deletions torch/onnx/symbolic_opset9.py
@@ -2219,20 +2219,19 @@ def gather(g, self, dim, index, sparse_grad=False):
 
 @parse_args('v', 'is', 'b', 'i')
 def _var_mean(g, input, dim, unbiased, keepdim):
-    sqrd = g.op("Mul", input, input)
     if dim is None:
-        sqrdmean = g.op("ReduceMean", sqrd, keepdims_i=0)
         mean = g.op("ReduceMean", input, keepdims_i=0)
         num_elements = numel(g, input)
     else:
-        sqrdmean = g.op("ReduceMean", sqrd, axes_i=dim, keepdims_i=keepdim)
         mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim)
         redudced_dims = g.op("Shape", input)
         # dim could contain one or multiple dimensions
         redudced_dims = g.op("Gather", redudced_dims, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0)
         num_elements = g.op("ReduceProd", redudced_dims, keepdims_i=0)
-    meansqrd = g.op("Mul", mean, mean)
-    var = g.op("Sub", sqrdmean, meansqrd)
+    sub_v = g.op("Sub", input, mean)
+    sqr_sub = g.op("Mul", sub_v, sub_v)
+    keepdim_mean = 0 if dim is None else keepdim
+    var = g.op("ReduceMean", sqr_sub, axes_i=dim, keepdims_i=keepdim_mean)
     # Correct bias in calculating variance, by dividing it over (N - 1) instead on N
     if unbiased:
         num_elements = g.op("Cast", num_elements, to_i=sym_help.cast_pytorch_to_onnx['Float'])
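To exercise the updated symbolic end to end, a minimal export sketch (an illustration under the assumption that any opset >= 9 export routes torch.std/torch.var through _var_mean; not part of the commit):

    import io
    import torch

    class Std(torch.nn.Module):
        def forward(self, x):
            # std = sqrt(var); with the old E(X^2) - (E(X))^2 graph, Sqrt could
            # receive a slightly negative variance and emit NaN.
            return torch.std(x, dim=(0, 1), unbiased=True, keepdim=True)

    # Emits the Sub/Mul/ReduceMean chain shown in the expect file above.
    torch.onnx.export(Std(), torch.randn(2, 3, 4), io.BytesIO(), opset_version=9)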
